~ubuntu-branches/ubuntu/trusty/libthrust/trusty

« back to all changes in this revision

Viewing changes to detail/device/cuda/detail/stable_radix_sort_key.inl

  • Committer: Bazaar Package Importer
  • Author(s): Andreas Beckmann
  • Date: 2011-05-28 09:32:48 UTC
  • Revision ID: james.westby@ubuntu.com-20110528093248-np3euv5sj7fw3nyv
Tags: upstream-1.4.0
Import upstream version 1.4.0

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/*
 
2
 *  Copyright 2008-2011 NVIDIA Corporation
 
3
 *
 
4
 *  Licensed under the Apache License, Version 2.0 (the "License");
 
5
 *  you may not use this file except in compliance with the License.
 
6
 *  You may obtain a copy of the License at
 
7
 *
 
8
 *      http://www.apache.org/licenses/LICENSE-2.0
 
9
 *
 
10
 *  Unless required by applicable law or agreed to in writing, software
 
11
 *  distributed under the License is distributed on an "AS IS" BASIS,
 
12
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 
13
 *  See the License for the specific language governing permissions and
 
14
 *  limitations under the License.
 
15
 */
 
16
 
 
17
#include <thrust/detail/config.h>
 
18
 
 
19
// do not attempt to compile this file with any other compiler
 
20
#if THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_NVCC
 
21
 
 
22
#include <limits>
 
23
 
 
24
#include <thrust/device_ptr.h>
 
25
#include <thrust/gather.h>
 
26
#include <thrust/reduce.h>
 
27
#include <thrust/sequence.h>
 
28
#include <thrust/transform.h>
 
29
#include <thrust/iterator/iterator_traits.h>
 
30
 
 
31
#include <thrust/detail/raw_buffer.h>
 
32
#include <thrust/detail/type_traits.h>
 
33
 
 
34
#include "stable_radix_sort_bits.h"
 
35
 
 
36
namespace thrust
 
37
{
 
38
namespace detail
 
39
{
 
40
namespace device
 
41
{
 
42
namespace cuda
 
43
{
 
44
namespace detail
 
45
{
 
46
 
 
47
//////////////////
 
48
// 8 BIT TYPES //
 
49
//////////////////
 
50
 
 
51
template <typename KeyType>
 
52
void stable_radix_sort_key_small_dev(KeyType * keys, unsigned int num_elements)
 
53
{
 
54
    // encode the small types in 32-bit unsigned ints
 
55
    thrust::detail::raw_cuda_device_buffer<unsigned int> full_keys(num_elements);
 
56
 
 
57
    thrust::transform(thrust::device_ptr<KeyType>(keys), 
 
58
                      thrust::device_ptr<KeyType>(keys) + num_elements,
 
59
                      full_keys.begin(),
 
60
                      encode_uint<KeyType>());
 
61
 
 
62
    // sort the 32-bit unsigned ints
 
63
    stable_radix_sort(full_keys.begin(), full_keys.end());
 
64
    
 
65
    // decode the 32-bit unsigned ints
 
66
    thrust::transform(full_keys.begin(),
 
67
                      full_keys.end(),
 
68
                      thrust::device_ptr<KeyType>(keys),
 
69
                      decode_uint<KeyType>());
 
70
}
 
71
 
 
72
template <typename KeyType>
 
73
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
 
74
                               thrust::detail::integral_constant<int, 1>)
 
75
{
 
76
    stable_radix_sort_key_small_dev(keys, num_elements);
 
77
}
 
78
 
 
79
 
 
80
//////////////////
 
81
// 16 BIT TYPES //
 
82
//////////////////
 
83
 
 
84
    
 
85
template <typename KeyType>
 
86
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
 
87
                               thrust::detail::integral_constant<int, 2>)
 
88
{
 
89
    stable_radix_sort_key_small_dev(keys, num_elements);
 
90
}
 
91
 
 
92
 
 
93
//////////////////
 
94
// 32 BIT TYPES //
 
95
//////////////////
 
96
 
 
97
template <typename KeyType> 
 
98
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
 
99
                               thrust::detail::integral_constant<int, 4>,
 
100
                               thrust::detail::integral_constant<bool, true>,   
 
101
                               thrust::detail::integral_constant<bool, false>)  // uint32
 
102
{
 
103
    radix_sort((unsigned int *) keys, num_elements, encode_uint<KeyType>(), encode_uint<KeyType>());
 
104
}
 
105
 
 
106
template <typename KeyType> 
 
107
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
 
108
                               thrust::detail::integral_constant<int, 4>,
 
109
                               thrust::detail::integral_constant<bool, true>,
 
110
                               thrust::detail::integral_constant<bool, true>)   // int32
 
111
{
 
112
    radix_sort((unsigned int*) keys, num_elements, encode_uint<KeyType>(), decode_uint<KeyType>());
 
113
}
 
114
 
 
115
template <typename KeyType> 
 
116
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
 
117
                               thrust::detail::integral_constant<int, 4>,
 
118
                               thrust::detail::integral_constant<bool, false>,
 
119
                               thrust::detail::integral_constant<bool, true>)  // float32
 
120
{
 
121
    radix_sort((unsigned int*) keys, num_elements, encode_uint<KeyType>(), decode_uint<KeyType>());
 
122
}
 
123
 
 
124
template <typename KeyType>
 
125
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
 
126
                               thrust::detail::integral_constant<int, 4>)
 
127
{
 
128
    stable_radix_sort_key_dev(keys, num_elements,
 
129
                              thrust::detail::integral_constant<int, 4>(),
 
130
                              thrust::detail::integral_constant<bool, std::numeric_limits<KeyType>::is_exact>(),
 
131
                              thrust::detail::integral_constant<bool, std::numeric_limits<KeyType>::is_signed>());
 
132
}
 
133
 
 
134
//////////////////
 
135
// 64 BIT TYPES //
 
136
//////////////////
 
137
 
 
138
template <typename KeyType,
 
139
          typename LowerBits, typename UpperBits, 
 
140
          typename LowerBitsExtractor, typename UpperBitsExtractor>
 
141
void stable_radix_sort_key_large_dev(KeyType * keys, unsigned int num_elements,
 
142
                                     LowerBitsExtractor extract_lower_bits,
 
143
                                     UpperBitsExtractor extract_upper_bits)
 
144
{
 
145
    // first sort on the lower 32-bits of the keys
 
146
    thrust::detail::raw_cuda_device_buffer<unsigned int> partial_keys(num_elements);
 
147
    thrust::transform(thrust::device_ptr<KeyType>(keys), 
 
148
                      thrust::device_ptr<KeyType>(keys) + num_elements,
 
149
                      partial_keys.begin(),
 
150
                      extract_lower_bits);
 
151
 
 
152
    thrust::detail::raw_cuda_device_buffer<unsigned int> permutation(num_elements);
 
153
    thrust::sequence(permutation.begin(), permutation.end());
 
154
    
 
155
    stable_radix_sort_by_key((LowerBits *) thrust::raw_pointer_cast(&partial_keys[0]),
 
156
                             (LowerBits *) thrust::raw_pointer_cast(&partial_keys[0]) + num_elements,
 
157
                             thrust::raw_pointer_cast(&permutation[0]));
 
158
 
 
159
    // permute full keys so lower bits are sorted
 
160
    thrust::detail::raw_cuda_device_buffer<KeyType> permuted_keys(num_elements);
 
161
    thrust::gather(permutation.begin(), permutation.end(),
 
162
                   thrust::device_ptr<KeyType>(keys),
 
163
                   permuted_keys.begin());
 
164
    
 
165
    // now sort on the upper 32 bits of the keys
 
166
    thrust::transform(permuted_keys.begin(),
 
167
                      permuted_keys.end(),
 
168
                      partial_keys.begin(),
 
169
                      extract_upper_bits);
 
170
    thrust::sequence(permutation.begin(), permutation.end());
 
171
    
 
172
    stable_radix_sort_by_key((UpperBits *) thrust::raw_pointer_cast(&partial_keys[0]),
 
173
                             (UpperBits *) thrust::raw_pointer_cast(&partial_keys[0]) + num_elements,
 
174
                             thrust::raw_pointer_cast(&permutation[0]));
 
175
 
 
176
    // store sorted keys
 
177
    thrust::gather(permutation.begin(), permutation.end(),
 
178
                   permuted_keys.begin(),
 
179
                   thrust::device_ptr<KeyType>(keys));
 
180
}
 
181
 
 
182
    
 
183
template <typename KeyType>
 
184
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
 
185
                               thrust::detail::integral_constant<int, 8>,
 
186
                               thrust::detail::integral_constant<bool, true>,
 
187
                               thrust::detail::integral_constant<bool, false>)  // uint64
 
188
{
 
189
    stable_radix_sort_key_large_dev<KeyType, unsigned int, unsigned int, lower_32_bits<KeyType>, upper_32_bits<KeyType> >
 
190
        (keys, num_elements, lower_32_bits<KeyType>(), upper_32_bits<KeyType>());
 
191
}
 
192
 
 
193
template <typename KeyType>
 
194
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
 
195
                               thrust::detail::integral_constant<int, 8>,
 
196
                               thrust::detail::integral_constant<bool, true>,
 
197
                               thrust::detail::integral_constant<bool, true>)   // int64
 
198
{
 
199
    stable_radix_sort_key_large_dev<KeyType, unsigned int, int, lower_32_bits<KeyType>, upper_32_bits<KeyType> >
 
200
        (keys, num_elements, lower_32_bits<KeyType>(), upper_32_bits<KeyType>());
 
201
}
 
202
 
 
203
template <typename KeyType>
 
204
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
 
205
                               thrust::detail::integral_constant<int, 8>,
 
206
                               thrust::detail::integral_constant<bool, false>,
 
207
                               thrust::detail::integral_constant<bool, true>)  // float64
 
208
{
 
209
    typedef unsigned long long uint64;
 
210
    stable_radix_sort_key_large_dev<uint64, unsigned int, unsigned int, lower_32_bits<KeyType>, upper_32_bits<KeyType> >
 
211
        (reinterpret_cast<uint64 *>(keys), num_elements, lower_32_bits<KeyType>(), upper_32_bits<KeyType>());
 
212
}
 
213
 
 
214
template <typename KeyType>
 
215
void stable_radix_sort_key_dev(KeyType * keys, unsigned int num_elements,
 
216
                               thrust::detail::integral_constant<int, 8>)
 
217
{
 
218
    stable_radix_sort_key_dev(keys, num_elements,
 
219
                              thrust::detail::integral_constant<int, 8>(),
 
220
                              thrust::detail::integral_constant<bool, std::numeric_limits<KeyType>::is_exact>(),
 
221
                              thrust::detail::integral_constant<bool, std::numeric_limits<KeyType>::is_signed>());
 
222
}
 
223
 
 
224
/////////////////
 
225
// Entry Point //
 
226
/////////////////
 
227
 
 
228
template<typename RandomAccessIterator>
 
229
void stable_radix_sort(RandomAccessIterator first,
 
230
                       RandomAccessIterator last)
 
231
{
 
232
    typedef typename thrust::iterator_traits<RandomAccessIterator>::value_type KeyType;
 
233
 
 
234
    // TODO static_assert< is_arithmetic<KeyType> >
 
235
 
 
236
    // RandomAccessIterator should be a trivial iterator
 
237
    KeyType * keys = thrust::raw_pointer_cast(&*first);
 
238
 
 
239
    // we only handle < 2^32 elements right now
 
240
    __THRUST_DISABLE_MSVC_POSSIBLE_LOSS_OF_DATA_WARNING( \
 
241
        unsigned int num_elements = last - first);
 
242
 
 
243
    // dispatch on sizeof(KeyType)
 
244
    stable_radix_sort_key_dev(keys, num_elements, thrust::detail::integral_constant<int, sizeof(KeyType)>());
 
245
}
 
246
 
 
247
 
 
248
} // end namespace detail
 
249
} // end namespace cuda
 
250
} // end namespace device
 
251
} // end namespace detail
 
252
} // end namespace thrust
 
253
 
 
254
#endif // THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_NVCC
 
255