2
* Copyright 2008-2011 NVIDIA Corporation
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
* you may not use this file except in compliance with the License.
6
* You may obtain a copy of the License at
8
* http://www.apache.org/licenses/LICENSE-2.0
10
* Unless required by applicable law or agreed to in writing, software
11
* distributed under the License is distributed on an "AS IS" BASIS,
12
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
* See the License for the specific language governing permissions and
14
* limitations under the License.
18
/*! \file internal_functional.inl
19
* \brief Non-public functionals used to implement algorithm internals.
24
#include <thrust/tuple.h>
25
#include <thrust/iterator/iterator_traits.h>
26
#include <memory> // for ::new
33
// unary_negate: function object that evaluates the stored predicate `pred`
// on one argument and returns the logical negation of its result.
// unary_negate does not need to know argument_type
34
template <typename Predicate>
37
typedef bool result_type;
42
// Initializes the member `pred` from the given predicate.
explicit unary_negate(const Predicate& pred) : pred(pred) {}
46
// Returns !pred(x). The predicate's result is explicitly converted to bool
// before negation, so only its truth value matters.
bool operator()(const T& x)
48
return !bool(pred(x));
52
// binary_negate: function object that evaluates the stored predicate `pred`
// on two arguments and returns the logical negation of its result.
// binary_negate does not need to know first_argument_type or second_argument_type
53
template <typename Predicate>
56
typedef bool result_type;
61
// Initializes the member `pred` from the given predicate.
explicit binary_negate(const Predicate& pred) : pred(pred) {}
63
// Returns !pred(x, y). T1 and T2 are deduced independently on each call,
// which is why no first/second argument typedefs are required. The
// predicate's result is converted to bool before negation.
template <typename T1, typename T2>
65
bool operator()(const T1& x, const T2& y)
67
return !bool(pred(x,y));
71
// not1: factory that wraps `pred` in a unary_negate, producing a function
// object that computes !pred(x).
template<typename Predicate>
73
thrust::detail::unary_negate<Predicate> not1(const Predicate &pred)
75
return thrust::detail::unary_negate<Predicate>(pred);
78
// not2: factory that wraps `pred` in a binary_negate, producing a function
// object that computes !pred(x, y).
template<typename Predicate>
80
thrust::detail::binary_negate<Predicate> not2(const Predicate &pred)
82
return thrust::detail::binary_negate<Predicate>(pred);
86
// convert a predicate to a 0 or 1 integral value
// predicate_to_integral: evaluates the stored predicate on one argument and
// maps true -> IntegralType(1), false -> IntegralType(0).
87
template <typename Predicate, typename IntegralType>
88
struct predicate_to_integral
93
// Initializes the member `pred` from the given predicate.
explicit predicate_to_integral(const Predicate& pred) : pred(pred) {}
97
// NOTE(review): operator() is declared to return bool, so the IntegralType
// value produced below is converted to bool on return and callers receive a
// bool rather than an IntegralType. The 0/1 values survive that round trip,
// but declaring the return type as IntegralType would match the stated
// intent ("a 0 or 1 integral value") and avoid the extra conversion.
bool operator()(const T& x)
99
return pred(x) ? IntegralType(1) : IntegralType(0);
104
// note that detail::equal_to does not force conversion from T2 -> T1 as equal_to does
// i.e. operator() deduces the right-hand operand type T2 separately from the
// class template parameter T1, so rhs keeps its own type.
105
template <typename T1>
108
typedef bool result_type;
110
// NOTE(review): the body of this operator is not present in this listing;
// presumably it returns lhs == rhs -- confirm against the full file.
template <typename T2>
112
bool operator()(const T1& lhs, const T2& rhs) const
118
// note that equal_to_value does not force conversion from T2 -> T1 as equal_to does
// equal_to_value: unary predicate that stores a right-hand value `rhs` of
// type T2; operator() takes a left-hand argument whose type T1 is deduced
// per call.
119
template <typename T2>
120
struct equal_to_value
124
// NOTE(review): this constructor is not explicit, so a T2 converts
// implicitly to equal_to_value<T2>; mark it explicit if that is unintended.
equal_to_value(const T2& rhs) : rhs(rhs) {}
126
// NOTE(review): the body of this operator is not present in this listing;
// presumably it returns lhs == rhs -- confirm against the full file.
template <typename T1>
128
bool operator()(const T1& lhs) const
134
// tuple_binary_predicate: adapts a binary predicate into a unary one that
// takes a tuple t and returns pred(get<0>(t), get<1>(t)).
template <typename Predicate>
135
struct tuple_binary_predicate
137
typedef bool result_type;
140
// Initializes the member `pred` from the given predicate.
tuple_binary_predicate(const Predicate& p) : pred(p) {}
142
// Applies pred to the first two elements of t.
template<typename Tuple>
144
bool operator()(const Tuple& t) const
146
return pred(thrust::get<0>(t), thrust::get<1>(t));
152
// tuple_not_binary_predicate: adapts a binary predicate into a unary one that
// takes a tuple t and returns !pred(get<0>(t), get<1>(t)).
template <typename Predicate>
153
struct tuple_not_binary_predicate
155
typedef bool result_type;
158
// Initializes the member `pred` from the given predicate.
tuple_not_binary_predicate(const Predicate& p) : pred(p) {}
160
// NOTE(review): this negates pred(...) directly, whereas binary_negate
// negates bool(pred(...)); the two differ only for predicate result types
// that overload operator!.
template<typename Tuple>
162
bool operator()(const Tuple& t) const
164
return !pred(thrust::get<0>(t), thrust::get<1>(t));
170
// generate_functor: holds a Generator and, when applied to an element,
// assigns to that element (presumably a value produced by the generator --
// the assignment itself is not present in this listing).
template<typename Generator>
171
struct generate_functor
173
typedef void result_type;
176
generate_functor(Generator g)
179
// operator() does not take a non-const lvalue reference because some iterators
180
// produce temporary proxy references when dereferenced. for example,
181
// consider the temporary tuple of references produced by zip_iterator.
182
// such temporaries cannot bind to a non-const lvalue reference.
184
// to work around this, accept a const reference (which is bindable to a temporary),
185
// and const_cast in the implementation.
187
// XXX change to an rvalue reference upon c++0x (which either a named variable
188
// or temporary can bind to)
// NOTE(review): writing through the const_cast below is only well-defined
// when x does not refer to an object that is itself const; callers must pass
// mutable elements or proxy references.
191
void operator()(const T &x)
193
// we have to be naughty and const_cast this to get it to work
194
T &lvalue = const_cast<T&>(x);
196
// this assigns correctly whether x is a true reference or proxy
204
// zipped_binary_op: adapts a binary function to take a single tuple (by
// value) and return m_binary_op(get<0>(t), get<1>(t)) as ResultType.
template<typename ResultType, typename BinaryFunction>
205
struct zipped_binary_op
207
typedef ResultType result_type;
210
zipped_binary_op(BinaryFunction binary_op)
211
: m_binary_op(binary_op) {}
213
// Unpacks the first two tuple elements and forwards them to m_binary_op.
template<typename Tuple>
215
inline result_type operator()(Tuple t)
217
return m_binary_op(thrust::get<0>(t), thrust::get<1>(t));
220
// The wrapped binary function, stored by value.
BinaryFunction m_binary_op;
223
// host_unary_transform_functor: host variant of the element-wise unary
// transform (selected by unary_transform_functor for host spaces).
// Given a tuple t, writes f(get<0>(t)) into get<1>(t).
// The tuple is taken by value, so the write reaches the caller only when
// element 1 is a reference or reference proxy (presumably t is a tuple of
// references, e.g. from a zip_iterator -- confirm at call sites).
template<typename UnaryFunction>
224
struct host_unary_transform_functor
226
typedef void result_type;
230
host_unary_transform_functor(UnaryFunction f_)
233
template<typename Tuple>
235
inline result_type operator()(Tuple t)
237
thrust::get<1>(t) = f(thrust::get<0>(t));
241
// device_unary_transform_functor: non-host variant of the element-wise unary
// transform (selected by unary_transform_functor when Space is not
// convertible to host_space_tag). Writes f(get<0>(t)) into get<1>(t); the
// tuple is taken by value, so element 1 must be a reference or proxy for the
// write to be observable.
template<typename UnaryFunction>
242
struct device_unary_transform_functor
244
typedef void result_type;
248
device_unary_transform_functor(UnaryFunction f_)
251
// add __host__ to allow the omp backend to compile with nvcc
252
template<typename Tuple>
254
inline result_type operator()(Tuple t)
256
thrust::get<1>(t) = f(thrust::get<0>(t));
261
// unary_transform_functor: compile-time selection of the unary transform
// functor. Via eval_if, when Space is convertible to thrust::host_space_tag
// the selected type is host_unary_transform_functor<UnaryFunction>;
// otherwise it is device_unary_transform_functor<UnaryFunction>.
// Callers presumably read the result from the nested ::type -- confirm.
template<typename Space, typename UnaryFunction>
262
struct unary_transform_functor
263
: thrust::detail::eval_if<
264
thrust::detail::is_convertible<Space, thrust::host_space_tag>::value,
265
thrust::detail::identity_<host_unary_transform_functor<UnaryFunction> >,
266
thrust::detail::identity_<device_unary_transform_functor<UnaryFunction> >
271
// host_binary_transform_functor: host variant of the element-wise binary
// transform (selected by binary_transform_functor for host spaces).
// Given a tuple t, writes f(get<0>(t), get<1>(t)) into get<2>(t); the tuple
// is taken by value, so element 2 must be a reference or proxy for the write
// to be observable.
template <typename BinaryFunction>
272
struct host_binary_transform_functor
276
host_binary_transform_functor(BinaryFunction f_)
280
template <typename Tuple>
282
void operator()(Tuple t)
284
thrust::get<2>(t) = f(thrust::get<0>(t), thrust::get<1>(t));
286
}; // end host_binary_transform_functor
289
// device_binary_transform_functor: non-host variant of the element-wise
// binary transform (selected by binary_transform_functor when Space is not
// convertible to host_space_tag). Writes f(get<0>(t), get<1>(t)) into
// get<2>(t); element 2 must be a reference or proxy for the write to be
// observable, since the tuple is taken by value.
template <typename BinaryFunction>
290
struct device_binary_transform_functor
294
device_binary_transform_functor(BinaryFunction f_)
298
// add __host__ to allow the omp backend to compile with nvcc
299
template <typename Tuple>
301
void operator()(Tuple t)
303
thrust::get<2>(t) = f(thrust::get<0>(t), thrust::get<1>(t));
305
}; // end device_binary_transform_functor
308
// binary_transform_functor: compile-time selection of the binary transform
// functor. Via eval_if, when Space is convertible to thrust::host_space_tag
// the selected type is host_binary_transform_functor<BinaryFunction>;
// otherwise it is device_binary_transform_functor<BinaryFunction>.
// Callers presumably read the result from the nested ::type -- confirm.
template<typename Space, typename BinaryFunction>
309
struct binary_transform_functor
310
: thrust::detail::eval_if<
311
thrust::detail::is_convertible<Space, thrust::host_space_tag>::value,
312
thrust::detail::identity_<host_binary_transform_functor<BinaryFunction> >,
313
thrust::detail::identity_<device_binary_transform_functor<BinaryFunction> >
318
// host_unary_transform_if_functor: host variant of the conditional unary
// transform. For a tuple t, element 0 is the input, element 1 is tested by
// pred, and element 2 is the output: when pred(get<1>(t)) is true it writes
// unary_op(get<0>(t)) into get<2>(t); otherwise get<2>(t) is left untouched.
template <typename UnaryFunction, typename Predicate>
319
struct host_unary_transform_if_functor
321
UnaryFunction unary_op;
324
host_unary_transform_if_functor(UnaryFunction _unary_op, Predicate _pred)
325
: unary_op(_unary_op), pred(_pred) {}
327
template <typename Tuple>
329
void operator()(Tuple t)
331
if(pred(thrust::get<1>(t)))
332
thrust::get<2>(t) = unary_op(thrust::get<0>(t));
334
}; // end host_unary_transform_if_functor
337
// device_unary_transform_if_functor: non-host variant of the conditional
// unary transform. When pred(get<1>(t)) is true it writes
// unary_op(get<0>(t)) into get<2>(t); otherwise get<2>(t) is left untouched.
template <typename UnaryFunction, typename Predicate>
338
struct device_unary_transform_if_functor
340
UnaryFunction unary_op;
343
device_unary_transform_if_functor(UnaryFunction _unary_op, Predicate _pred)
344
: unary_op(_unary_op), pred(_pred) {}
346
// add __host__ to allow the omp backend to compile with nvcc
347
template <typename Tuple>
349
void operator()(Tuple t)
351
if(pred(thrust::get<1>(t)))
352
thrust::get<2>(t) = unary_op(thrust::get<0>(t));
354
}; // end device_unary_transform_if_functor
357
// unary_transform_if_functor: compile-time selection of the conditional
// unary transform functor. Via eval_if, when Space is convertible to
// thrust::host_space_tag the selected type is
// host_unary_transform_if_functor<UnaryFunction,Predicate>; otherwise it is
// device_unary_transform_if_functor<UnaryFunction,Predicate>.
// Callers presumably read the result from the nested ::type -- confirm.
template<typename Space, typename UnaryFunction, typename Predicate>
358
struct unary_transform_if_functor
359
: thrust::detail::eval_if<
360
thrust::detail::is_convertible<Space, thrust::host_space_tag>::value,
361
thrust::detail::identity_<host_unary_transform_if_functor<UnaryFunction,Predicate> >,
362
thrust::detail::identity_<device_unary_transform_if_functor<UnaryFunction,Predicate> >
367
// host_binary_transform_if_functor: host variant of the conditional binary
// transform. For a tuple t, elements 0 and 1 are the inputs, element 2 is
// tested by pred, and element 3 is the output: when pred(get<2>(t)) is true
// it writes binary_op(get<0>(t), get<1>(t)) into get<3>(t); otherwise
// get<3>(t) is left untouched.
template <typename BinaryFunction, typename Predicate>
368
struct host_binary_transform_if_functor
370
BinaryFunction binary_op;
373
host_binary_transform_if_functor(BinaryFunction _binary_op, Predicate _pred)
374
: binary_op(_binary_op), pred(_pred) {}
376
template <typename Tuple>
378
void operator()(Tuple t)
380
if(pred(thrust::get<2>(t)))
381
thrust::get<3>(t) = binary_op(thrust::get<0>(t), thrust::get<1>(t));
383
}; // end host_binary_transform_if_functor
386
// device_binary_transform_if_functor: non-host variant of the conditional
// binary transform. When pred(get<2>(t)) is true it writes
// binary_op(get<0>(t), get<1>(t)) into get<3>(t); otherwise get<3>(t) is
// left untouched.
template <typename BinaryFunction, typename Predicate>
387
struct device_binary_transform_if_functor
389
BinaryFunction binary_op;
392
device_binary_transform_if_functor(BinaryFunction _binary_op, Predicate _pred)
393
: binary_op(_binary_op), pred(_pred) {}
395
// add __host__ to allow the omp backend to compile with nvcc
396
template <typename Tuple>
398
void operator()(Tuple t)
400
if(pred(thrust::get<2>(t)))
401
thrust::get<3>(t) = binary_op(thrust::get<0>(t), thrust::get<1>(t));
403
}; // end device_binary_transform_if_functor
406
// binary_transform_if_functor: compile-time selection of the conditional
// binary transform functor. Via eval_if, when Space is convertible to
// thrust::host_space_tag the selected type is
// host_binary_transform_if_functor<BinaryFunction,Predicate>; otherwise it is
// device_binary_transform_if_functor<BinaryFunction,Predicate>.
// Callers presumably read the result from the nested ::type -- confirm.
template<typename Space, typename BinaryFunction, typename Predicate>
407
struct binary_transform_if_functor
408
: thrust::detail::eval_if<
409
thrust::detail::is_convertible<Space, thrust::host_space_tag>::value,
410
thrust::detail::identity_<host_binary_transform_if_functor<BinaryFunction,Predicate> >,
411
thrust::detail::identity_<device_binary_transform_if_functor<BinaryFunction,Predicate> >
417
// host_destroy_functor: host variant of the element-destruction functor
// (selected by destroy_functor for host spaces); operator() takes a T&.
// NOTE(review): the body of operator() is not present in this listing;
// presumably it invokes T's destructor on x -- confirm against the full file.
struct host_destroy_functor
420
void operator()(T &x) const
423
} // end operator()()
424
}; // end host_destroy_functor
428
// device_destroy_functor: non-host variant of the element-destruction functor
// (selected by destroy_functor when Space is not convertible to
// host_space_tag); operator() takes a T&.
// NOTE(review): the body of operator() is not present in this listing;
// presumably it invokes T's destructor on x -- confirm against the full file.
struct device_destroy_functor
430
// add __host__ to allow the omp backend to compile with nvcc
432
void operator()(T &x) const
435
} // end operator()()
436
}; // end device_destroy_functor
439
// destroy_functor: compile-time selection of the destruction functor. Via
// eval_if, when Space is convertible to thrust::host_space_tag the selected
// type is host_destroy_functor<T>; otherwise it is device_destroy_functor<T>.
// Callers presumably read the result from the nested ::type -- confirm.
template<typename Space, typename T>
440
struct destroy_functor
441
: thrust::detail::eval_if<
442
thrust::detail::is_convertible<Space, thrust::host_space_tag>::value,
443
thrust::detail::identity_<host_destroy_functor<T> >,
444
thrust::detail::identity_<device_destroy_functor<T> >
449
// fill_functor: nullary generator constructed from an exemplar value; its
// const operator() takes no arguments and returns a T.
// NOTE(review): the struct header and the body of operator() are not present
// in this listing; presumably operator() returns a copy of exemplar.
template <typename T>
454
fill_functor(const T& _exemplar)
455
: exemplar(_exemplar) {}
458
T operator()(void) const
466
// uninitialized_fill_functor: copy-constructs the stored exemplar into the
// storage referred to by its argument using placement new rather than
// assignment, so it does not call operator= on the object at &x.
struct uninitialized_fill_functor
470
uninitialized_fill_functor(T x):exemplar(x){}
473
// Constructs a copy of exemplar at &x. Any object already living at &x is
// not destroyed first.
// NOTE(review): placement new is declared in <new>; this file includes
// <memory> for it, which may only provide <new> transitively -- consider
// including <new> directly.
void operator()(T &x)
475
::new(static_cast<void*>(&x)) T(exemplar);
476
} // end operator()()
477
}; // end uninitialized_fill_functor
480
// this predicate tests two two-element tuples
481
// we first use a Compare for the first element
482
// if the first elements are equivalent, we use
483
// < for the second elements
484
// i.e. a lexicographic order: lhs precedes rhs if comp(lhs0, rhs0), or if
// the first elements are equivalent (neither comp(lhs0, rhs0) nor
// comp(rhs0, lhs0)) and lhs1 < rhs1. Both tuples are taken by value.
template<typename Compare>
485
struct compare_first_less_second
487
compare_first_less_second(Compare c)
490
template<typename T1, typename T2>
492
bool operator()(T1 lhs, T2 rhs)
494
return comp(thrust::get<0>(lhs), thrust::get<0>(rhs)) || (!comp(thrust::get<0>(rhs), thrust::get<0>(lhs)) && thrust::get<1>(lhs) < thrust::get<1>(rhs));
498
}; // end compare_first_less_second
501
} // end namespace detail
502
} // end namespace thrust