/*
 *  Copyright 2008-2011 NVIDIA Corporation
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

#include <thrust/detail/config.h>

// do not attempt to compile this file with any other compiler
#if THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_NVCC

#include <limits>

#include <thrust/device_ptr.h>
#include <thrust/gather.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>
#include <thrust/transform.h>
#include <thrust/iterator/iterator_traits.h>

#include <thrust/detail/raw_buffer.h>
#include <thrust/detail/type_traits.h>

#include "stable_radix_sort_bits.h"
51
// Stable sort of keys narrower than 32 bits (the 1- and 2-byte dispatch
// targets below).
//
// Each key is encoded into a 32-bit unsigned int with encode_uint<KeyType>.
// The temporary array of encoded words is sorted with stable_radix_sort.
// The words are then decoded back into `keys` with decode_uint<KeyType>.
// NOTE(review): this presumes encode_uint/decode_uint (stable_radix_sort_bits.h)
// are order-preserving inverses of each other -- confirm there.
//
//   keys         - raw pointer to device memory holding the keys (wrapped in
//                  thrust::device_ptr below); overwritten with the sorted keys
//   num_elements - number of keys in `keys`
//
// Allocates one temporary device buffer of num_elements unsigned ints.
template <typename KeyType>
void stable_radix_sort_key_small_dev(KeyType * keys, unsigned int num_elements)
{
    // zero or one key is already sorted; skip the allocation and kernels
    if (num_elements < 2)
        return;

    // encode the small types in 32-bit unsigned ints
    thrust::detail::raw_cuda_device_buffer<unsigned int> full_keys(num_elements);

    thrust::transform(thrust::device_ptr<KeyType>(keys),
                      thrust::device_ptr<KeyType>(keys) + num_elements,
                      full_keys.begin(),
                      encode_uint<KeyType>());

    // sort the 32-bit unsigned ints
    stable_radix_sort(full_keys.begin(), full_keys.end());

    // decode the 32-bit unsigned ints back into the caller's storage
    thrust::transform(full_keys.begin(),
                      full_keys.end(),
                      thrust::device_ptr<KeyType>(keys),
                      decode_uint<KeyType>());
}
72
// sizeof(KeyType) == 1: forwarded unchanged to the small-key path
// (stable_radix_sort_key_small_dev), which sorts via 32-bit encoded words.
template <typename KeyType>
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
                               thrust::detail::integral_constant<int, 1>)
{
    stable_radix_sort_key_small_dev<KeyType>(keys, num_elements);
}
85
// sizeof(KeyType) == 2: same treatment as 1-byte keys -- delegate to the
// small-key path (stable_radix_sort_key_small_dev).
template <typename KeyType>
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
                               thrust::detail::integral_constant<int, 2>)
{
    stable_radix_sort_key_small_dev<KeyType>(keys, num_elements);
}
97
// 4-byte unsigned integer keys (numeric_limits: is_exact, !is_signed).
//
// The keys are handed to radix_sort directly as 32-bit words.
// encode_uint<KeyType> is passed for both functor arguments.
// NOTE(review): presumably encode_uint is the identity for unsigned 32-bit
// keys, which is why no separate decoder is used -- confirm in
// stable_radix_sort_bits.h.
template <typename KeyType>
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
                               thrust::detail::integral_constant<int, 4>,
                               thrust::detail::integral_constant<bool, true>,
                               thrust::detail::integral_constant<bool, false>) // uint32
{
    unsigned int * words = (unsigned int *) keys;

    radix_sort(words, num_elements,
               encode_uint<KeyType>(),
               encode_uint<KeyType>());
}
106
// 4-byte signed integer keys (numeric_limits: is_exact, is_signed).
//
// The keys are sorted in place as 32-bit words by radix_sort.
// encode_uint<KeyType> and decode_uint<KeyType> are supplied as the functor
// pair, presumably mapping the signed order onto unsigned bit order and
// back -- see stable_radix_sort_bits.h.
template <typename KeyType>
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
                               thrust::detail::integral_constant<int, 4>,
                               thrust::detail::integral_constant<bool, true>,
                               thrust::detail::integral_constant<bool, true>) // int32
{
    unsigned int * words = (unsigned int*) keys;

    radix_sort(words, num_elements,
               encode_uint<KeyType>(),
               decode_uint<KeyType>());
}
115
// 4-byte floating-point keys (numeric_limits: !is_exact, is_signed).
//
// The float bit patterns are sorted in place as 32-bit words by radix_sort.
// encode_uint<KeyType> and decode_uint<KeyType> are supplied around the
// sort; these presumably apply an order-preserving float<->uint
// transformation -- confirm in stable_radix_sort_bits.h.
template <typename KeyType>
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
                               thrust::detail::integral_constant<int, 4>,
                               thrust::detail::integral_constant<bool, false>,
                               thrust::detail::integral_constant<bool, true>) // float32
{
    unsigned int * words = (unsigned int*) keys;

    radix_sort(words, num_elements,
               encode_uint<KeyType>(),
               decode_uint<KeyType>());
}
124
// sizeof(KeyType) == 4: re-dispatch on two properties of the key type,
// std::numeric_limits<KeyType>::is_exact (integer vs. floating point) and
// ::is_signed, selecting the uint32 / int32 / float32 overloads above.
template <typename KeyType>
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
                               thrust::detail::integral_constant<int, 4>)
{
    typedef std::numeric_limits<KeyType> limits;

    typedef thrust::detail::integral_constant<int, 4>                 width_tag;
    typedef thrust::detail::integral_constant<bool, limits::is_exact>  exact_tag;
    typedef thrust::detail::integral_constant<bool, limits::is_signed> signed_tag;

    stable_radix_sort_key_dev(keys, num_elements, width_tag(), exact_tag(), signed_tag());
}
138
// Stable sort of 64-bit keys as two 32-bit radix passes (least-significant
// word first).
//
// Pass 1: extract_lower_bits maps every key to a 32-bit word.
//         A permutation 0..n-1 is sorted by those words with
//         stable_radix_sort_by_key.
//         The full keys are gathered into that order.
// Pass 2: extract_upper_bits is applied to the permuted keys.
//         A fresh permutation is sorted by the resulting words.
//         The keys are gathered back into `keys`.
// Because the second pass is (by name) a stable sort, keys whose upper words
// are equal keep the lower-word order established by pass 1.
//
//   KeyType            - element type of `keys` (a 64-bit type at the call
//                        sites in this file)
//   LowerBits/UpperBits - 32-bit type each pass reinterprets the partial keys
//                        as before sorting, which selects that type's ordering
//                        (e.g. the int64 caller uses `int` for the upper word)
//   extract_lower_bits / extract_upper_bits
//                      - unary functors KeyType -> unsigned int
//   keys               - raw device pointer; overwritten with the sorted keys
//   num_elements       - number of keys
//
// Temporaries: two device buffers of num_elements unsigned ints plus one of
// num_elements KeyType.
template <typename KeyType,
          typename LowerBits, typename UpperBits,
          typename LowerBitsExtractor, typename UpperBitsExtractor>
void stable_radix_sort_key_large_dev(KeyType * keys, unsigned int num_elements,
                                     LowerBitsExtractor extract_lower_bits,
                                     UpperBitsExtractor extract_upper_bits)
{
    // zero or one key is already sorted; this also avoids taking
    // &partial_keys[0] / &permutation[0] of empty buffers below
    if (num_elements < 2)
        return;

    // first sort on the lower 32-bits of the keys
    thrust::detail::raw_cuda_device_buffer<unsigned int> partial_keys(num_elements);
    thrust::transform(thrust::device_ptr<KeyType>(keys),
                      thrust::device_ptr<KeyType>(keys) + num_elements,
                      partial_keys.begin(),
                      extract_lower_bits);

    thrust::detail::raw_cuda_device_buffer<unsigned int> permutation(num_elements);
    thrust::sequence(permutation.begin(), permutation.end());

    stable_radix_sort_by_key((LowerBits *) thrust::raw_pointer_cast(&partial_keys[0]),
                             (LowerBits *) thrust::raw_pointer_cast(&partial_keys[0]) + num_elements,
                             thrust::raw_pointer_cast(&permutation[0]));

    // permute full keys so lower bits are sorted
    thrust::detail::raw_cuda_device_buffer<KeyType> permuted_keys(num_elements);
    thrust::gather(permutation.begin(), permutation.end(),
                   thrust::device_ptr<KeyType>(keys),
                   permuted_keys.begin());

    // now sort on the upper 32 bits of the keys (reusing both 32-bit buffers)
    thrust::transform(permuted_keys.begin(),
                      permuted_keys.end(),
                      partial_keys.begin(),
                      extract_upper_bits);

    thrust::sequence(permutation.begin(), permutation.end());

    stable_radix_sort_by_key((UpperBits *) thrust::raw_pointer_cast(&partial_keys[0]),
                             (UpperBits *) thrust::raw_pointer_cast(&partial_keys[0]) + num_elements,
                             thrust::raw_pointer_cast(&permutation[0]));

    // write the fully sorted keys back to the caller's storage
    thrust::gather(permutation.begin(), permutation.end(),
                   permuted_keys.begin(),
                   thrust::device_ptr<KeyType>(keys));
}
183
// 8-byte unsigned integer keys (numeric_limits: is_exact, !is_signed).
//
// Uses the two-pass large-key sort.
// Both 32-bit halves are ordered as `unsigned int`.
template <typename KeyType>
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
                               thrust::detail::integral_constant<int, 8>,
                               thrust::detail::integral_constant<bool, true>,
                               thrust::detail::integral_constant<bool, false>) // uint64
{
    typedef lower_32_bits<KeyType> low_word;
    typedef upper_32_bits<KeyType> high_word;

    stable_radix_sort_key_large_dev<KeyType, unsigned int, unsigned int, low_word, high_word>
        (keys, num_elements, low_word(), high_word());
}
193
// 8-byte signed integer keys (numeric_limits: is_exact, is_signed).
//
// Uses the two-pass large-key sort.
// The lower half is ordered as `unsigned int`.
// The upper half (which carries the sign bit) is ordered as `int`.
template <typename KeyType>
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
                               thrust::detail::integral_constant<int, 8>,
                               thrust::detail::integral_constant<bool, true>,
                               thrust::detail::integral_constant<bool, true>) // int64
{
    typedef lower_32_bits<KeyType> low_word;
    typedef upper_32_bits<KeyType> high_word;

    stable_radix_sort_key_large_dev<KeyType, unsigned int, int, low_word, high_word>
        (keys, num_elements, low_word(), high_word());
}
203
// 8-byte floating-point keys (numeric_limits: !is_exact, is_signed).
//
// The storage is reinterpreted as unsigned long long and sorted by the
// two-pass large-key path, with both halves ordered as `unsigned int`.
// The extractors are instantiated for KeyType even though they are applied
// to the 64-bit words.
// NOTE(review): this presumes lower_32_bits<double>/upper_32_bits<double>
// accept the raw 64-bit pattern and yield order-preserving halves -- confirm
// in stable_radix_sort_bits.h.
template <typename KeyType>
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
                               thrust::detail::integral_constant<int, 8>,
                               thrust::detail::integral_constant<bool, false>,
                               thrust::detail::integral_constant<bool, true>) // float64
{
    typedef unsigned long long     uint64;
    typedef lower_32_bits<KeyType> low_word;
    typedef upper_32_bits<KeyType> high_word;

    uint64 * bit_patterns = reinterpret_cast<uint64 *>(keys);

    stable_radix_sort_key_large_dev<uint64, unsigned int, unsigned int, low_word, high_word>
        (bit_patterns, num_elements, low_word(), high_word());
}
214
// sizeof(KeyType) == 8: choose between the uint64 / int64 / float64
// overloads above using std::numeric_limits<KeyType>.
// is_exact separates integer from floating-point types; is_signed separates
// signed from unsigned integers.
template <typename KeyType>
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
                               thrust::detail::integral_constant<int, 8>)
{
    typedef std::numeric_limits<KeyType> key_limits;

    typedef thrust::detail::integral_constant<bool, key_limits::is_exact>  is_integer_key;
    typedef thrust::detail::integral_constant<bool, key_limits::is_signed> is_signed_key;

    stable_radix_sort_key_dev(keys,
                              num_elements,
                              thrust::detail::integral_constant<int, 8>(),
                              is_integer_key(),
                              is_signed_key());
}
228
// Stably sorts the keys in [first, last), writing the result back into that
// range.
//
// The range is converted to a raw KeyType pointer via &*first.  It is then
// dispatched on sizeof(KeyType) to the 1/2/4/8-byte overloads of
// stable_radix_sort_key_dev.
//
//   RandomAccessIterator - must address contiguous device storage (a
//                          "trivial" iterator); its value_type is the key type
//
// Preconditions (not checked):
//   - KeyType is an arithmetic type (see TODO);
//   - the range holds fewer than 2^32 elements, because the count is narrowed
//     to unsigned int.
template<typename RandomAccessIterator>
void stable_radix_sort(RandomAccessIterator first,
                       RandomAccessIterator last)
{
    typedef typename thrust::iterator_traits<RandomAccessIterator>::value_type KeyType;

    // TODO static_assert< is_arithmetic<KeyType> >

    // Zero or one key is already sorted.  Returning here also avoids
    // evaluating &*first on an empty range and dispatching zero-length work.
    if (last - first < 2)
        return;

    // RandomAccessIterator should be a trivial iterator
    KeyType * keys = thrust::raw_pointer_cast(&*first);

    // we only handle < 2^32 elements right now
    __THRUST_DISABLE_MSVC_POSSIBLE_LOSS_OF_DATA_WARNING( \
        unsigned int num_elements = last - first);

    // dispatch on sizeof(KeyType)
    stable_radix_sort_key_dev(keys, num_elements, thrust::detail::integral_constant<int, sizeof(KeyType)>());
}
248
} // end namespace detail
249
} // end namespace cuda
250
} // end namespace device
251
} // end namespace detail
252
} // end namespace thrust
254
#endif // THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_NVCC