2
* Copyright 2008-2012 NVIDIA Corporation
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
* you may not use this file except in compliance with the License.
6
* You may obtain a copy of the License at
8
* http://www.apache.org/licenses/LICENSE-2.0
10
* Unless required by applicable law or agreed to in writing, software
11
* distributed under the License is distributed on an "AS IS" BASIS,
12
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
* See the License for the specific language governing permissions and
14
* limitations under the License.
17
#include <thrust/iterator/iterator_traits.h>
18
#include <thrust/detail/raw_reference_cast.h>
19
#include <thrust/system/detail/generic/scalar/binary_search.h>
32
template<typename Context,
33
typename RandomAccessIterator1,
34
typename RandomAccessIterator2,
35
typename RandomAccessIterator3,
36
typename RandomAccessIterator4,
37
typename StrictWeakOrdering>
38
__device__ __thrust_forceinline__
39
RandomAccessIterator4 set_union(Context context,
40
RandomAccessIterator1 first1,
41
RandomAccessIterator1 last1,
42
RandomAccessIterator2 first2,
43
RandomAccessIterator2 last2,
44
RandomAccessIterator3 temporary,
45
RandomAccessIterator4 result,
46
StrictWeakOrdering comp)
48
typedef typename thrust::iterator_difference<RandomAccessIterator1>::type difference1;
49
typedef typename thrust::iterator_difference<RandomAccessIterator2>::type difference2;
51
difference1 n1 = last1 - first1;
52
difference2 n2 = last2 - first2;
54
if(n1 == 0 && n2 == 0) return result;
56
// for each element in the second range
57
// count the number of matches in the first range
58
// initialize rank1 to an impossible result
59
difference1 rank1 = difference1(-1);
61
if(context.thread_index() < n2)
63
RandomAccessIterator2 x = first2;
64
x += context.thread_index();
66
// count the number of previous occurrances of x in the second range
67
difference2 sub_rank2 = x - thrust::system::detail::generic::scalar::lower_bound(first2,x,raw_reference_cast(*x),comp);
69
// count the number of equivalent elements of x in the first range
70
thrust::pair<RandomAccessIterator1,RandomAccessIterator1> matches =
71
thrust::system::detail::generic::scalar::equal_range(first1,last1,raw_reference_cast(*x),comp);
73
difference2 num_matches = matches.second - matches.first;
75
// the element should be output if its rank is gequal to the number of matches
76
if(sub_rank2 >= num_matches)
78
rank1 = (matches.second - first1);
82
// for the second range of elements,
83
// mark in the scratch array if we need
85
RandomAccessIterator3 temp = temporary + context.thread_index();
86
*temp = (rank1 >= difference1(0)) ? 1 : 0;
90
// inclusive scan the scratch array
91
block::inclusive_scan_n(context, temporary, n2, thrust::plus<int>());
93
// find the rank of each element in the first range in the second range
94
// modulo the fact that some elements of the second range will not appear in the output
95
// these irrelevant elements should be skipped when computing ranks
96
// note that every element of the first range gets output
97
difference2 rank2 = 0;
98
if(context.thread_index() < n1)
100
RandomAccessIterator1 x = first1;
101
x += context.thread_index();
103
// lower_bound ensures that x sorts before any equivalent element of input2
104
// this ensures stability
105
rank2 = thrust::system::detail::generic::scalar::lower_bound(first2, last2, raw_reference_cast(*x), comp) - first2;
107
// since the temporary array contains, for each element inclusive,
108
// the number of previous active elements from the second range,
109
// we can compute the final rank2 simply by using the current value
110
// of rank2 as an index into the temporary array
111
if(rank2 > difference2(0))
113
// subtract one during the index because the scan was inclusive
114
rank2 = temporary[rank2-1];
118
// scatter elements from the first range to their place in the output
119
if(context.thread_index() < n1)
121
RandomAccessIterator1 src = first1 + context.thread_index();
122
RandomAccessIterator4 dst = result + context.thread_index() + rank2;
127
// scatter elements from the second range
128
if(context.thread_index() < n2 && (rank1 >= difference1(0)))
130
// find the index to write our element
131
unsigned int num_elements_from_second_range_before_me = 0;
132
if(context.thread_index() > 0)
134
RandomAccessIterator3 src = temporary;
135
src += context.thread_index() - 1;
136
num_elements_from_second_range_before_me = *src;
139
RandomAccessIterator2 src = first2;
140
src += context.thread_index();
142
RandomAccessIterator4 dst = result;
143
dst += num_elements_from_second_range_before_me + rank1;
148
// finding the size of the result:
149
// range 1: all of range 1 gets output, so add n1
150
// range 2: the temporary array contains, for each element inclusive,
151
// the cumulative number of elements from the second range to output
152
// add the cumulative sum at the final element of the second range
153
// but carefully handle the case where the range is empty
154
// XXX we could handle empty input as a special case at the beginning of the function
156
return result + n1 + (n2 ? temporary[n2-1] : 0);
159
} // end namespace block
160
} // end namespace detail
161
} // end namespace cuda
162
} // end namespace system
163
} // end namespace thrust