2
* Copyright 2008-2012 NVIDIA Corporation
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
* you may not use this file except in compliance with the License.
6
* You may obtain a copy of the License at
8
* http://www.apache.org/licenses/LICENSE-2.0
10
* Unless required by applicable law or agreed to in writing, software
11
* distributed under the License is distributed on an "AS IS" BASIS,
12
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
* See the License for the specific language governing permissions and
14
* limitations under the License.
18
* Copyright Jens Maurer 2000-2001
19
* Distributed under the Boost Software License, Version 1.0. (See
20
* accompanying file LICENSE_1_0.txt or copy at
21
* http://www.boost.org/LICENSE_1_0.txt)
26
#include <thrust/detail/config.h>
27
#include <thrust/pair.h>
28
#include <thrust/random/uniform_real_distribution.h>
39
// this version samples the normal distribution directly
40
// and uses the non-standard math function erfcinv
41
// NOTE(review): this chunk is a lossy extraction -- braces, access specifiers and
// the body of the `if` inside sample() are missing, and original line numbers are
// interleaved with the code. Comments below describe only what is visible.
//
// Maps one draw from the engine onto the standard normal quantile function,
// Phi^{-1}(p) = -sqrt(2) * erfcinv(2 * p), then applies mean and stddev.
template<typename RealType>
42
class normal_distribution_nvcc
45
template<typename UniformRandomNumberGenerator>
47
// Returns mean + stddev * Z for a standard normal Z derived from urng().
//   urng   - engine; must provide result_type and static min / max members
//   mean   - location added to the scaled variate
//   stddev - scale factor applied to the variate
RealType sample(UniformRandomNumberGenerator &urng, const RealType mean, const RealType stddev)
49
typedef typename UniformRandomNumberGenerator::result_type uint_type;
50
// Width of the engine's output interval, in the engine's integer type.
const uint_type urng_range = UniformRandomNumberGenerator::max - UniformRandomNumberGenerator::min;
52
// Constants for conversion
53
// S1 scales an integer offset in [0, urng_range] onto [0, 1].
const RealType S1 = static_cast<RealType>(1) / urng_range;
54
// S2 is half a step; adding it keeps p strictly positive, so erfcinv(2 * p)
// never sees 0 (where it diverges to +infinity).
const RealType S2 = S1 / 2;
56
// Not const: presumably its sign is flipped inside the reflection branch of the
// `if` below, whose body is not visible here -- confirm against upstream.
RealType S3 = static_cast<RealType>(-1.4142135623730950488016887242097); // -sqrt(2)
58
// Get the integer value
59
// Offset from the engine minimum; lies in [0, urng_range] assuming urng() honours
// its [min, max] contract.
uint_type u = urng() - UniformRandomNumberGenerator::min;
61
// Ensure the conversion to float will give a value in the range [0,0.5)
62
// Upper-half offsets are presumably reflected into the lower half here (paired
// with the S3 sign flip, exploiting the symmetry of the normal distribution).
if(u > (urng_range / 2))
68
// Convert to floating point in [0,0.5)
69
// NOTE(review): if u <= urng_range / 2 here, p lies in [S2, 0.5 + S2], so the upper
// end can slightly exceed 0.5; 2 * p still stays inside erfcinv's domain (0, 2).
RealType p = u*S1 + S2;
71
// Apply inverse error function
72
// With S3 == -sqrt(2) this is mean + stddev * Phi^{-1}(p).
return mean + stddev * S3 * erfcinv(2 * p);
80
// this version samples the normal distribution using
81
// the Box-Muller transform: it scales cos/sin(2*pi*r1) by
// sqrt(-2*log(1 - r2)) for uniforms r1, r2. (This is not Marsaglia's "polar
// method", which avoids trigonometric functions by rejection-sampling a point
// in the unit disk.)
82
// NOTE(review): this chunk is a lossy extraction -- braces, constructor bodies and
// parts of sample() are missing, and original line numbers are interleaved.
template<typename RealType>
83
class normal_distribution_portable
86
normal_distribution_portable()
90
// Copy constructor; copies m_valid from other (any further member copies are
// not visible in this chunk).
normal_distribution_portable(const normal_distribution_portable &other)
91
: m_valid(other.m_valid)
99
// note that we promise to call this member function with the same mean and stddev
// (presumably because one call may consume state cached by the previous call --
// see the m_valid selector below).
100
template<typename UniformRandomNumberGenerator>
102
// Produces a normal variate from urng. The visible code computes a (presumably
// standard normal) `result` from the cached state; the final use of mean and
// stddev is not visible in this chunk.
RealType sample(UniformRandomNumberGenerator &urng, const RealType mean, const RealType stddev)
104
// implementation from Boost
105
// allow for Koenig lookup
106
using std::sqrt; using std::log; using std::sin; using std::cos;
110
// Presumably used to refresh m_r1 / m_r2 (those draws are not visible here).
uniform_real_distribution<RealType> u01;
113
// Box-Muller radius. Using 1 - m_r2 rather than m_r2 keeps the log argument
// nonzero, assuming u01 yields values in [0, 1).
m_cached_rho = sqrt(-RealType(2) * log(RealType(1)-m_r2));
122
const RealType pi = RealType(3.14159265358979323846);
124
// m_valid selects cos or sin of the same angle, so a single (m_r1, m_r2) pair
// can supply two independent normals; the code toggling m_valid is not visible.
RealType result = m_cached_rho * (m_valid ?
125
cos(RealType(2)*pi*m_r1) :
126
sin(RealType(2)*pi*m_r1));
132
// Cached uniforms and the derived radius, shared between the cos and sin halves.
RealType m_r1, m_r2, m_cached_rho;
136
// Compile-time selection of the sampling strategy: the erfcinv-based
// normal_distribution_nvcc when the device compiler is nvcc, otherwise
// normal_distribution_portable (presumably under an #else/#endif that is not
// visible in this chunk).
template<typename RealType>
137
struct normal_distribution_base
139
#if THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_NVCC
140
typedef normal_distribution_nvcc<RealType> type;
142
typedef normal_distribution_portable<RealType> type;