1
/******************************************************************************
2
* Copyright (c) 2011, Duane Merrill. All rights reserved.
3
* Copyright (c) 2011-2013, NVIDIA CORPORATION. All rights reserved.
5
* Redistribution and use in source and binary forms, with or without
6
* modification, are permitted provided that the following conditions are met:
7
* * Redistributions of source code must retain the above copyright
8
* notice, this list of conditions and the following disclaimer.
9
* * Redistributions in binary form must reproduce the above copyright
10
* notice, this list of conditions and the following disclaimer in the
11
* documentation and/or other materials provided with the distribution.
12
* * Neither the name of the NVIDIA CORPORATION nor the
13
* names of its contributors may be used to endorse or promote products
14
* derived from this software without specific prior written permission.
16
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
20
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27
******************************************************************************/
31
* cub::WarpScanShfl provides SHFL-based variants of parallel prefix scan across CUDA warps.
36
#include "../../thread/thread_operators.cuh"
37
#include "../../util_type.cuh"
38
#include "../../util_ptx.cuh"
39
#include "../../util_namespace.cuh"
41
/// Optional outer namespace(s)
48
* \brief WarpScanShfl provides SHFL-based variants of parallel prefix scan across CUDA warps.
51
typename T, ///< Data type being scanned
52
int LOGICAL_WARPS, ///< Number of logical warps entrant
53
int LOGICAL_WARP_THREADS> ///< Number of threads per logical warp
57
/******************************************************************************
58
* Constants and typedefs
59
******************************************************************************/
63
/// The number of warp scan steps: Log2 of the logical warp width. The scan loops in this class
/// iterate STEP over [0, STEPS) and shuffle by an offset of (1 << STEP) at each step.
64
STEPS = Log2<LOGICAL_WARP_THREADS>::VALUE,
66
// The 5-bit SHFL mask for logically splitting warps into sub-segments starts 8-bits up
67
// The mask is built from an unsigned all-ones word: left-shifting the negative literal -1 is
// undefined behavior before C++20, whereas the unsigned shift yields the same low five bits portably.
SHFL_C = ((0xFFFFFFFFu << STEPS) & 31) << 8,
70
/// Shared memory storage layout type (aliased to NullType in this SHFL-based variant)
71
typedef NullType TempStorage;
74
/******************************************************************************
76
******************************************************************************/
81
/******************************************************************************
83
******************************************************************************/
86
__device__ __forceinline__ WarpScanShfl(
87
TempStorage &temp_storage,
96
/******************************************************************************
98
******************************************************************************/
101
/// Broadcast: returns a T assembled from the value held by lane \p src_lane, moved one
/// ShuffleWord at a time through 32-bit shfl.idx instructions.
__device__ __forceinline__ T Broadcast(
102
T input, ///< [in] The value to broadcast
103
int src_lane) ///< [in] Which warp lane is to do the broadcasting
105
// Word type used to move T through the shuffle unit
typedef typename WordAlignment<T>::ShuffleWord ShuffleWord;
107
// Number of ShuffleWord-sized pieces needed to cover T (rounded up)
const int WORDS = (sizeof(T) + sizeof(ShuffleWord) - 1) / sizeof(ShuffleWord);
109
// View the result object and the input as arrays of ShuffleWord
ShuffleWord *output_alias = reinterpret_cast<ShuffleWord *>(&output);
110
ShuffleWord *input_alias = reinterpret_cast<ShuffleWord *>(&input);
113
// Shuffle each word independently
for (int WORD = 0; WORD < WORDS; ++WORD)
115
// Widen the word to 32 bits for the b32 shuffle
unsigned int shuffle_word = input_alias[WORD];
116
// NOTE(review): the control operand here is LOGICAL_WARP_THREADS - 1 with no segment-mask bits,
// unlike SHFL_C used by the shfl.up scans -- confirm this is intended when LOGICAL_WARP_THREADS < 32.
asm("shfl.idx.b32 %0, %1, %2, %3;"
117
: "=r"(shuffle_word) : "r"(shuffle_word), "r"(src_lane), "r"(LOGICAL_WARP_THREADS - 1));
118
// Narrow the shuffled word back into the result
output_alias[WORD] = (ShuffleWord) shuffle_word;
125
//---------------------------------------------------------------------
126
// Inclusive operations
127
//---------------------------------------------------------------------
129
/// Inclusive prefix sum with aggregate (single-SHFL).
/// Selected through the Int2Type<true> marker; the running value is carried in one 32-bit
/// register and combined with unsigned 32-bit adds (add.u32).
130
__device__ __forceinline__ void InclusiveSum(
131
T input, ///< [in] Calling thread's input item.
132
T &output, ///< [out] Calling thread's output item. May be aliased with \p input.
133
T &warp_aggregate, ///< [out] Warp-wide aggregate reduction of input items.
134
Int2Type<true> single_shfl) ///< [in] Marker type indicating whether only one SHFL instruction is required
136
// Reinterpret the input's storage as a 32-bit word.
// NOTE(review): if sizeof(T) < sizeof(unsigned int) this cast reads bytes beyond \p input -- confirm
// how narrower primitive types are expected to behave on this path.
unsigned int temp = reinterpret_cast<unsigned int &>(input);
138
// Iterate scan steps
140
for (int STEP = 0; STEP < STEPS; STEP++)
142
// Use predicate set from SHFL to guard against invalid peers:
// r0 takes the shuffled-up value (offset 1 << STEP, control word SHFL_C) and the
// add of the thread's own value (%4) is executed only when predicate p is set.
147
" shfl.up.b32 r0|p, %1, %2, %3;"
148
" @p add.u32 r0, r0, %4;"
151
: "=r"(temp) : "r"(temp), "r"(1 << STEP), "r"(SHFL_C), "r"(temp));
156
// Grab aggregate from last warp lane
157
warp_aggregate = Broadcast(output, LOGICAL_WARP_THREADS - 1);
161
/// Inclusive prefix sum with aggregate (multi-SHFL).
/// Selected through the Int2Type<false> marker; computes the sum by calling InclusiveScan
/// with a Sum() functor.
162
__device__ __forceinline__ void InclusiveSum(
163
T input, ///< [in] Calling thread's input item.
164
T &output, ///< [out] Calling thread's output item. May be aliased with \p input.
165
T &warp_aggregate, ///< [out] Warp-wide aggregate reduction of input items.
166
Int2Type<false> single_shfl) ///< [in] Marker type indicating whether only one SHFL instruction is required
168
// Delegate to generic scan
169
InclusiveScan(input, output, Sum(), warp_aggregate);
173
/// Inclusive prefix sum with aggregate (specialized for float).
/// Same shuffle-then-predicated-add step as the single-SHFL integer path, but using add.f32
/// on float register operands.
174
__device__ __forceinline__ void InclusiveSum(
175
float input, ///< [in] Calling thread's input item.
176
float &output, ///< [out] Calling thread's output item. May be aliased with \p input.
177
float &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items.
181
// Iterate scan steps
183
for (int STEP = 0; STEP < STEPS; STEP++)
185
// Use predicate set from SHFL to guard against invalid peers:
// \p output is both the source and the destination of each step; the shuffled-up value
// (offset 1 << STEP, control word SHFL_C) is added to it only when predicate p is set.
190
" shfl.up.b32 r0|p, %1, %2, %3;"
191
" @p add.f32 r0, r0, %4;"
194
: "=f"(output) : "f"(output), "r"(1 << STEP), "r"(SHFL_C), "f"(output));
197
// Grab aggregate from last warp lane
198
warp_aggregate = Broadcast(output, LOGICAL_WARP_THREADS - 1);
202
/// Inclusive prefix sum with aggregate (specialized for unsigned long long).
/// Each step splits the 64-bit running value into two 32-bit halves, shuffles each half
/// separately, and recombines them with a carry-propagating add.
203
__device__ __forceinline__ void InclusiveSum(
204
unsigned long long input, ///< [in] Calling thread's input item.
205
unsigned long long &output, ///< [out] Calling thread's output item. May be aliased with \p input.
206
unsigned long long &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items.
210
// Iterate scan steps
212
for (int STEP = 0; STEP < STEPS; STEP++)
214
// Use predicate set from SHFL to guard against invalid peers:
//   - unpack \p output into {lo, hi}
//   - shuffle lo and hi up by (1 << STEP) into r0 and r1 (both write predicate p)
//   - when p is set, add lo with carry-out (add.cc) and hi with carry-in (addc)
//   - repack {r0, r1} into \p output
222
" mov.b64 {lo, hi}, %1;"
223
" shfl.up.b32 r0|p, lo, %2, %3;"
224
" shfl.up.b32 r1|p, hi, %2, %3;"
225
" @p add.cc.u32 r0, r0, lo;"
226
" @p addc.u32 r1, r1, hi;"
227
" mov.b64 %0, {r0, r1};"
229
: "=l"(output) : "l"(output), "r"(1 << STEP), "r"(SHFL_C));
232
// Grab aggregate from last warp lane
233
warp_aggregate = Broadcast(output, LOGICAL_WARP_THREADS - 1);
237
/// Inclusive prefix sum with aggregate (generic).
/// Builds an Int2Type marker from the properties of \p _T and forwards to the matching
/// single-SHFL or multi-SHFL InclusiveSum overload.
238
template <typename _T>
239
__device__ __forceinline__ void InclusiveSum(
240
_T input, ///< [in] Calling thread's input item.
241
_T &output, ///< [out] Calling thread's output item. May be aliased with \p input.
242
_T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items.
244
// Whether sharing can be done with a single SHFL instruction (vs multiple SHFL instructions):
// true only when _T is a primitive type no larger than unsigned int
245
Int2Type<(Traits<_T>::PRIMITIVE) && (sizeof(_T) <= sizeof(unsigned int))> single_shfl;
247
// Tag-dispatch on the marker type
InclusiveSum(input, output, warp_aggregate, single_shfl);
251
/// Inclusive prefix sum.
/// Convenience overload for callers that do not need the warp-wide aggregate.
252
__device__ __forceinline__ void InclusiveSum(
253
T input, ///< [in] Calling thread's input item.
254
T &output) ///< [out] Calling thread's output item. May be aliased with \p input.
257
// Forward to the aggregate-producing overload; the aggregate is not handed back to the caller
InclusiveSum(input, output, warp_aggregate);
261
/// Inclusive scan with aggregate.
/// Generic path for an arbitrary binary \p scan_op: each step fetches a peer's partial result
/// with ShuffleUp and combines it with the calling thread's partial result.
262
template <typename ScanOp>
263
__device__ __forceinline__ void InclusiveScan(
264
T input, ///< [in] Calling thread's input item.
265
T &output, ///< [out] Calling thread's output item. May be aliased with \p input.
266
ScanOp scan_op, ///< [in] Binary scan operator
267
T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items.
271
// Iterate scan steps
273
for (int STEP = 0; STEP < STEPS; STEP++)
275
// Grab addend from peer (the shuffle offset doubles every step)
276
const int OFFSET = 1 << STEP;
277
T temp = ShuffleUp(output, OFFSET);
279
// Perform scan op if from a valid peer: only lanes at or beyond OFFSET combine,
// and the peer's value is passed as the left operand
280
if (lane_id >= OFFSET)
281
output = scan_op(temp, output);
284
// Grab aggregate from last warp lane
285
warp_aggregate = Broadcast(output, LOGICAL_WARP_THREADS - 1);
290
/// Inclusive scan.
/// Convenience overload for callers that do not need the warp-wide aggregate.
template <typename ScanOp>
291
__device__ __forceinline__ void InclusiveScan(
292
T input, ///< [in] Calling thread's input item.
293
T &output, ///< [out] Calling thread's output item. May be aliased with \p input.
294
ScanOp scan_op) ///< [in] Binary scan operator
297
// Forward to the aggregate-producing overload; the aggregate is not handed back to the caller
InclusiveScan(input, output, scan_op, warp_aggregate);
301
//---------------------------------------------------------------------
302
// Exclusive operations
303
//---------------------------------------------------------------------
305
/// Exclusive scan with aggregate.
/// Computes an inclusive scan first, then shifts the inclusive results up by one lane.
306
template <typename ScanOp>
307
__device__ __forceinline__ void ExclusiveScan(
308
T input, ///< [in] Calling thread's input item.
309
T &output, ///< [out] Calling thread's output item. May be aliased with \p input.
310
T identity, ///< [in] Identity value
311
ScanOp scan_op, ///< [in] Binary scan operator
312
T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items.
314
// Compute inclusive scan (this call also produces \p warp_aggregate)
316
InclusiveScan(input, inclusive, scan_op, warp_aggregate);
318
// Grab result from predecessor
319
T exclusive = ShuffleUp(inclusive, 1);
321
// Lane 0 is handled separately from the other lanes.
// NOTE(review): presumably lane 0 receives \p identity and every other lane receives
// `exclusive` -- confirm against the full conditional expression.
output = (lane_id == 0) ?
328
/// Exclusive scan.
/// Convenience overload (with identity) for callers that do not need the warp-wide aggregate.
template <typename ScanOp>
329
__device__ __forceinline__ void ExclusiveScan(
330
T input, ///< [in] Calling thread's input item.
331
T &output, ///< [out] Calling thread's output item. May be aliased with \p input.
332
T identity, ///< [in] Identity value
333
ScanOp scan_op) ///< [in] Binary scan operator
336
// Forward to the aggregate-producing overload; the aggregate is not handed back to the caller
ExclusiveScan(input, output, identity, scan_op, warp_aggregate);
340
/// Exclusive scan with aggregate, without identity.
/// Computes an inclusive scan first, then shifts the inclusive results up by one lane.
341
template <typename ScanOp>
342
__device__ __forceinline__ void ExclusiveScan(
343
T input, ///< [in] Calling thread's input item.
344
T &output, ///< [out] Calling thread's output item. May be aliased with \p input.
345
ScanOp scan_op, ///< [in] Binary scan operator
346
T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items.
348
// Compute inclusive scan (this call also produces \p warp_aggregate)
350
InclusiveScan(input, inclusive, scan_op, warp_aggregate);
352
// Grab result from predecessor. No lane is special-cased here.
// NOTE(review): lane 0 has no predecessor, so its \p output is whatever ShuffleUp yields
// for it -- confirm callers treat that value as undefined.
353
output = ShuffleUp(inclusive, 1);
357
/// Exclusive scan without identity.
/// Convenience overload for callers that do not need the warp-wide aggregate.
358
template <typename ScanOp>
359
__device__ __forceinline__ void ExclusiveScan(
360
T input, ///< [in] Calling thread's input item.
361
T &output, ///< [out] Calling thread's output item. May be aliased with \p input.
362
ScanOp scan_op) ///< [in] Binary scan operator
365
// Forward to the aggregate-producing overload; the aggregate is not handed back to the caller
ExclusiveScan(input, output, scan_op, warp_aggregate);
371
CUB_NS_POSTFIX // Optional outer namespace(s)