~ubuntu-branches/debian/sid/lammps/sid

« back to all changes in this revision

Viewing changes to lib/kokkos/TPL/cub/warp/specializations/warp_scan_shfl.cuh

  • Committer: Package Import Robot
  • Author(s): Anton Gladky
  • Date: 2015-04-29 23:44:49 UTC
  • mfrom: (5.1.3 experimental)
  • Revision ID: package-import@ubuntu.com-20150429234449-mbhy9utku6hp6oq8
Tags: 0~20150313.gitfa668e1-1
Upload into unstable.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/******************************************************************************
 
2
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 
3
 * Copyright (c) 2011-2013, NVIDIA CORPORATION.  All rights reserved.
 
4
 * 
 
5
 * Redistribution and use in source and binary forms, with or without
 
6
 * modification, are permitted provided that the following conditions are met:
 
7
 *     * Redistributions of source code must retain the above copyright
 
8
 *       notice, this list of conditions and the following disclaimer.
 
9
 *     * Redistributions in binary form must reproduce the above copyright
 
10
 *       notice, this list of conditions and the following disclaimer in the
 
11
 *       documentation and/or other materials provided with the distribution.
 
12
 *     * Neither the name of the NVIDIA CORPORATION nor the
 
13
 *       names of its contributors may be used to endorse or promote products
 
14
 *       derived from this software without specific prior written permission.
 
15
 * 
 
16
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 
17
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 
18
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 
19
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 
20
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 
21
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 
22
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 
23
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 
24
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 
25
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
26
 *
 
27
 ******************************************************************************/
 
28
 
 
29
/**
 
30
 * \file
 
31
 * cub::WarpScanShfl provides SHFL-based variants of parallel prefix scan across CUDA warps.
 
32
 */
 
33
 
 
34
#pragma once
 
35
 
 
36
#include "../../thread/thread_operators.cuh"
 
37
#include "../../util_type.cuh"
 
38
#include "../../util_ptx.cuh"
 
39
#include "../../util_namespace.cuh"
 
40
 
 
41
/// Optional outer namespace(s)
 
42
CUB_NS_PREFIX
 
43
 
 
44
/// CUB namespace
 
45
namespace cub {
 
46
 
 
47
/**
 
48
 * \brief WarpScanShfl provides SHFL-based variants of parallel prefix scan across CUDA warps.
 
49
 */
 
50
template <
 
51
    typename    T,                      ///< Data type being scanned
 
52
    int         LOGICAL_WARPS,          ///< Number of logical warps entrant
 
53
    int         LOGICAL_WARP_THREADS>   ///< Number of threads per logical warp
 
54
struct WarpScanShfl
 
55
{
 
56
 
 
57
    /******************************************************************************
 
58
     * Constants and typedefs
 
59
     ******************************************************************************/
 
60
 
 
61
    enum
 
62
    {
 
63
        /// The number of warp scan steps
 
64
        STEPS = Log2<LOGICAL_WARP_THREADS>::VALUE,
 
65
 
 
66
        // The 5-bit SHFL mask for logically splitting warps into sub-segments starts 8-bits up
 
67
        SHFL_C = ((-1 << STEPS) & 31) << 8,
 
68
    };
 
69
 
 
70
    /// Shared memory storage layout type
 
71
    typedef NullType TempStorage;
 
72
 
 
73
 
 
74
    /******************************************************************************
 
75
     * Thread fields
 
76
     ******************************************************************************/
 
77
 
 
78
    int             warp_id;
 
79
    int             lane_id;
 
80
 
 
81
    /******************************************************************************
 
82
     * Construction
 
83
     ******************************************************************************/
 
84
 
 
85
    /// Constructor
 
86
    __device__ __forceinline__ WarpScanShfl(
 
87
        TempStorage &temp_storage,
 
88
        int warp_id,
 
89
        int lane_id)
 
90
    :
 
91
        warp_id(warp_id),
 
92
        lane_id(lane_id)
 
93
    {}
 
94
 
 
95
 
 
96
    /******************************************************************************
 
97
     * Operation
 
98
     ******************************************************************************/
 
99
 
 
100
    /// Broadcast
 
101
    __device__ __forceinline__ T Broadcast(
 
102
        T               input,              ///< [in] The value to broadcast
 
103
        int             src_lane)           ///< [in] Which warp lane is to do the broadcasting
 
104
    {
 
105
        typedef typename WordAlignment<T>::ShuffleWord ShuffleWord;
 
106
 
 
107
        const int       WORDS           = (sizeof(T) + sizeof(ShuffleWord) - 1) / sizeof(ShuffleWord);
 
108
        T               output;
 
109
        ShuffleWord     *output_alias   = reinterpret_cast<ShuffleWord *>(&output);
 
110
        ShuffleWord     *input_alias    = reinterpret_cast<ShuffleWord *>(&input);
 
111
 
 
112
        #pragma unroll
 
113
        for (int WORD = 0; WORD < WORDS; ++WORD)
 
114
        {
 
115
            unsigned int shuffle_word = input_alias[WORD];
 
116
            asm("shfl.idx.b32 %0, %1, %2, %3;"
 
117
                : "=r"(shuffle_word) : "r"(shuffle_word), "r"(src_lane), "r"(LOGICAL_WARP_THREADS - 1));
 
118
            output_alias[WORD] = (ShuffleWord) shuffle_word;
 
119
        }
 
120
 
 
121
        return output;
 
122
    }
 
123
 
 
124
 
 
125
    //---------------------------------------------------------------------
 
126
    // Inclusive operations
 
127
    //---------------------------------------------------------------------
 
128
 
 
129
    /// Inclusive prefix sum with aggregate (single-SHFL)
 
130
    __device__ __forceinline__ void InclusiveSum(
 
131
        T               input,              ///< [in] Calling thread's input item.
 
132
        T               &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
 
133
        T               &warp_aggregate,    ///< [out] Warp-wide aggregate reduction of input items.
 
134
        Int2Type<true>  single_shfl)
 
135
    {
 
136
        unsigned int temp = reinterpret_cast<unsigned int &>(input);
 
137
 
 
138
        // Iterate scan steps
 
139
        #pragma unroll
 
140
        for (int STEP = 0; STEP < STEPS; STEP++)
 
141
        {
 
142
            // Use predicate set from SHFL to guard against invalid peers
 
143
            asm(
 
144
                "{"
 
145
                "  .reg .u32 r0;"
 
146
                "  .reg .pred p;"
 
147
                "  shfl.up.b32 r0|p, %1, %2, %3;"
 
148
                "  @p add.u32 r0, r0, %4;"
 
149
                "  mov.u32 %0, r0;"
 
150
                "}"
 
151
                : "=r"(temp) : "r"(temp), "r"(1 << STEP), "r"(SHFL_C), "r"(temp));
 
152
        }
 
153
 
 
154
        output = temp;
 
155
 
 
156
        // Grab aggregate from last warp lane
 
157
        warp_aggregate = Broadcast(output, LOGICAL_WARP_THREADS - 1);
 
158
    }
 
159
 
 
160
 
 
161
    /// Inclusive prefix sum with aggregate (multi-SHFL)
 
162
    __device__ __forceinline__ void InclusiveSum(
 
163
        T               input,              ///< [in] Calling thread's input item.
 
164
        T               &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
 
165
        T               &warp_aggregate,    ///< [out] Warp-wide aggregate reduction of input items.
 
166
        Int2Type<false> single_shfl)        ///< [in] Marker type indicating whether only one SHFL instruction is required
 
167
    {
 
168
        // Delegate to generic scan
 
169
        InclusiveScan(input, output, Sum(), warp_aggregate);
 
170
    }
 
171
 
 
172
 
 
173
    /// Inclusive prefix sum with aggregate (specialized for float)
 
174
    __device__ __forceinline__ void InclusiveSum(
 
175
        float           input,              ///< [in] Calling thread's input item.
 
176
        float           &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
 
177
        float           &warp_aggregate)    ///< [out] Warp-wide aggregate reduction of input items.
 
178
    {
 
179
        output = input;
 
180
 
 
181
        // Iterate scan steps
 
182
        #pragma unroll
 
183
        for (int STEP = 0; STEP < STEPS; STEP++)
 
184
        {
 
185
            // Use predicate set from SHFL to guard against invalid peers
 
186
            asm(
 
187
                "{"
 
188
                "  .reg .f32 r0;"
 
189
                "  .reg .pred p;"
 
190
                "  shfl.up.b32 r0|p, %1, %2, %3;"
 
191
                "  @p add.f32 r0, r0, %4;"
 
192
                "  mov.f32 %0, r0;"
 
193
                "}"
 
194
                : "=f"(output) : "f"(output), "r"(1 << STEP), "r"(SHFL_C), "f"(output));
 
195
        }
 
196
 
 
197
        // Grab aggregate from last warp lane
 
198
        warp_aggregate = Broadcast(output, LOGICAL_WARP_THREADS - 1);
 
199
    }
 
200
 
 
201
 
 
202
    /// Inclusive prefix sum with aggregate (specialized for unsigned long long)
 
203
    __device__ __forceinline__ void InclusiveSum(
 
204
        unsigned long long  input,              ///< [in] Calling thread's input item.
 
205
        unsigned long long  &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
 
206
        unsigned long long  &warp_aggregate)    ///< [out] Warp-wide aggregate reduction of input items.
 
207
    {
 
208
        output = input;
 
209
 
 
210
        // Iterate scan steps
 
211
        #pragma unroll
 
212
        for (int STEP = 0; STEP < STEPS; STEP++)
 
213
        {
 
214
            // Use predicate set from SHFL to guard against invalid peers
 
215
            asm(
 
216
                "{"
 
217
                "  .reg .u32 r0;"
 
218
                "  .reg .u32 r1;"
 
219
                "  .reg .u32 lo;"
 
220
                "  .reg .u32 hi;"
 
221
                "  .reg .pred p;"
 
222
                "  mov.b64 {lo, hi}, %1;"
 
223
                "  shfl.up.b32 r0|p, lo, %2, %3;"
 
224
                "  shfl.up.b32 r1|p, hi, %2, %3;"
 
225
                "  @p add.cc.u32 r0, r0, lo;"
 
226
                "  @p addc.u32 r1, r1, hi;"
 
227
                "  mov.b64 %0, {r0, r1};"
 
228
                "}"
 
229
                : "=l"(output) : "l"(output), "r"(1 << STEP), "r"(SHFL_C));
 
230
        }
 
231
 
 
232
        // Grab aggregate from last warp lane
 
233
        warp_aggregate = Broadcast(output, LOGICAL_WARP_THREADS - 1);
 
234
    }
 
235
 
 
236
 
 
237
    /// Inclusive prefix sum with aggregate (generic)
 
238
    template <typename _T>
 
239
    __device__ __forceinline__ void InclusiveSum(
 
240
        _T               input,             ///< [in] Calling thread's input item.
 
241
        _T               &output,           ///< [out] Calling thread's output item.  May be aliased with \p input.
 
242
        _T               &warp_aggregate)   ///< [out] Warp-wide aggregate reduction of input items.
 
243
    {
 
244
        // Whether sharing can be done with a single SHFL instruction (vs multiple SFHL instructions)
 
245
        Int2Type<(Traits<_T>::PRIMITIVE) && (sizeof(_T) <= sizeof(unsigned int))> single_shfl;
 
246
 
 
247
        InclusiveSum(input, output, warp_aggregate, single_shfl);
 
248
    }
 
249
 
 
250
 
 
251
    /// Inclusive prefix sum
 
252
    __device__ __forceinline__ void InclusiveSum(
 
253
        T               input,              ///< [in] Calling thread's input item.
 
254
        T               &output)            ///< [out] Calling thread's output item.  May be aliased with \p input.
 
255
    {
 
256
        T warp_aggregate;
 
257
        InclusiveSum(input, output, warp_aggregate);
 
258
    }
 
259
 
 
260
 
 
261
    /// Inclusive scan with aggregate
 
262
    template <typename ScanOp>
 
263
    __device__ __forceinline__ void InclusiveScan(
 
264
        T               input,              ///< [in] Calling thread's input item.
 
265
        T               &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
 
266
        ScanOp          scan_op,            ///< [in] Binary scan operator
 
267
        T               &warp_aggregate)    ///< [out] Warp-wide aggregate reduction of input items.
 
268
    {
 
269
        output = input;
 
270
 
 
271
        // Iterate scan steps
 
272
        #pragma unroll
 
273
        for (int STEP = 0; STEP < STEPS; STEP++)
 
274
        {
 
275
            // Grab addend from peer
 
276
            const int OFFSET = 1 << STEP;
 
277
            T temp = ShuffleUp(output, OFFSET);
 
278
 
 
279
            // Perform scan op if from a valid peer
 
280
            if (lane_id >= OFFSET)
 
281
                output = scan_op(temp, output);
 
282
        }
 
283
 
 
284
        // Grab aggregate from last warp lane
 
285
        warp_aggregate = Broadcast(output, LOGICAL_WARP_THREADS - 1);
 
286
    }
 
287
 
 
288
 
 
289
    /// Inclusive scan
 
290
    template <typename ScanOp>
 
291
    __device__ __forceinline__ void InclusiveScan(
 
292
        T               input,              ///< [in] Calling thread's input item.
 
293
        T               &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
 
294
        ScanOp          scan_op)            ///< [in] Binary scan operator
 
295
    {
 
296
        T warp_aggregate;
 
297
        InclusiveScan(input, output, scan_op, warp_aggregate);
 
298
    }
 
299
 
 
300
 
 
301
    //---------------------------------------------------------------------
 
302
    // Exclusive operations
 
303
    //---------------------------------------------------------------------
 
304
 
 
305
    /// Exclusive scan with aggregate
 
306
    template <typename ScanOp>
 
307
    __device__ __forceinline__ void ExclusiveScan(
 
308
        T               input,              ///< [in] Calling thread's input item.
 
309
        T               &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
 
310
        T               identity,           ///< [in] Identity value
 
311
        ScanOp          scan_op,            ///< [in] Binary scan operator
 
312
        T               &warp_aggregate)    ///< [out] Warp-wide aggregate reduction of input items.
 
313
    {
 
314
        // Compute inclusive scan
 
315
        T inclusive;
 
316
        InclusiveScan(input, inclusive, scan_op, warp_aggregate);
 
317
 
 
318
        // Grab result from predecessor
 
319
        T exclusive = ShuffleUp(inclusive, 1);
 
320
 
 
321
        output = (lane_id == 0) ?
 
322
            identity :
 
323
            exclusive;
 
324
    }
 
325
 
 
326
 
 
327
    /// Exclusive scan
 
328
    template <typename ScanOp>
 
329
    __device__ __forceinline__ void ExclusiveScan(
 
330
        T               input,              ///< [in] Calling thread's input item.
 
331
        T               &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
 
332
        T               identity,           ///< [in] Identity value
 
333
        ScanOp          scan_op)            ///< [in] Binary scan operator
 
334
    {
 
335
        T warp_aggregate;
 
336
        ExclusiveScan(input, output, identity, scan_op, warp_aggregate);
 
337
    }
 
338
 
 
339
 
 
340
    /// Exclusive scan with aggregate, without identity
 
341
    template <typename ScanOp>
 
342
    __device__ __forceinline__ void ExclusiveScan(
 
343
        T               input,              ///< [in] Calling thread's input item.
 
344
        T               &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
 
345
        ScanOp          scan_op,            ///< [in] Binary scan operator
 
346
        T               &warp_aggregate)    ///< [out] Warp-wide aggregate reduction of input items.
 
347
    {
 
348
        // Compute inclusive scan
 
349
        T inclusive;
 
350
        InclusiveScan(input, inclusive, scan_op, warp_aggregate);
 
351
 
 
352
        // Grab result from predecessor
 
353
        output = ShuffleUp(inclusive, 1);
 
354
    }
 
355
 
 
356
 
 
357
    /// Exclusive scan without identity
 
358
    template <typename ScanOp>
 
359
    __device__ __forceinline__ void ExclusiveScan(
 
360
        T               input,              ///< [in] Calling thread's input item.
 
361
        T               &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
 
362
        ScanOp          scan_op)            ///< [in] Binary scan operator
 
363
    {
 
364
        T warp_aggregate;
 
365
        ExclusiveScan(input, output, scan_op, warp_aggregate);
 
366
    }
 
367
};
 
368
 
 
369
 
 
370
}               // CUB namespace
 
371
CUB_NS_POSTFIX  // Optional outer namespace(s)