12
12
sympy.stats.rv_interface
15
from __future__ import print_function, division
15
17
from sympy import (Basic, S, Expr, Symbol, Tuple, And, Add, Eq, lambdify,
16
18
sympify, Equality, solve, Lambda, DiracDelta)
17
from sympy.core.sets import FiniteSet, ProductSet
19
from sympy.core.compatibility import reduce
20
from sympy.sets.sets import FiniteSet, ProductSet
18
21
from sympy.abc import x
19
from functools import reduce
22
24
class RandomDomain(Basic):
221
224
def __new__(cls, pspace, symbol):
222
assert isinstance(symbol, Symbol)
223
assert isinstance(pspace, PSpace)
225
if not isinstance(symbol, Symbol):
226
raise TypeError("symbol should be of type Symbol")
227
if not isinstance(pspace, PSpace):
228
raise TypeError("pspace variable should be of type PSpace")
224
229
return Basic.__new__(cls, pspace, symbol)
234
238
symbol = property(lambda self: self.args[1])
235
239
name = property(lambda self: self.symbol.name)
241
def _eval_is_positive(self):
242
return self.symbol.is_positive
244
def _eval_is_integer(self):
245
return self.symbol.is_integer
247
def _eval_is_real(self):
248
return self.symbol.is_real or self.pspace.is_real
238
251
def is_commutative(self):
239
252
return self.symbol.is_commutative
260
273
for value in space.values:
261
274
rs_space_dict[value] = space
263
symbols = FiniteSet(val.symbol for val in list(rs_space_dict.keys()))
276
symbols = FiniteSet(*[val.symbol for val in rs_space_dict.keys()])
265
278
# Overlapping symbols
266
279
if len(symbols) < sum(len(space.symbols) for space in spaces):
315
328
def sample(self):
316
329
return dict([(k, v) for space in self.spaces
317
for k, v in list(space.sample().items())])
330
for k, v in space.sample().items()])
320
333
class ProductDomain(RandomDomain):
338
351
domains2.append(domain)
340
353
domains2.extend(domain.domains)
341
domains2 = FiniteSet(domains2)
354
domains2 = FiniteSet(*domains2)
343
356
if all(domain.is_Finite for domain in domains2):
344
357
from sympy.stats.frv import ProductFiniteDomain
448
461
def given(expr, condition=None, **kwargs):
462
""" Conditional Random Expression
450
463
From a random expression and a condition on that expression creates a new
451
464
probability space from the condition and returns the same expression on that
452
465
conditional probability space.
457
470
>>> from sympy.stats import given, density, Die
458
471
>>> X = Die('X', 6)
459
472
>>> Y = given(X, X>3)
461
474
{4: 1/3, 5: 1/3, 6: 1/3}
476
Following convention, if the condition is a random symbol then that symbol
479
>>> from sympy.stats import Normal
480
>>> from sympy import pprint
481
>>> from sympy.abc import z
483
>>> X = Normal('X', 0, 1)
484
>>> Y = Normal('Y', 0, 1)
485
>>> pprint(density(X + Y, Y)(z), use_unicode=False)
464
496
if not random_symbols(condition) or pspace_independent(expr, condition):
499
if isinstance(condition, RandomSymbol):
500
condition = Eq(condition, condition.symbol)
467
502
condsymbols = random_symbols(condition)
468
503
if (isinstance(condition, Equality) and len(condsymbols) == 1 and
469
504
not isinstance(pspace(expr).domain, ConditionalDomain)):
486
def expectation(expr, condition=None, numsamples=None, **kwargs):
521
def expectation(expr, condition=None, numsamples=None, evaluate=True, **kwargs):
488
523
Returns the expected value of a random expression
492
528
expr : Expr containing RandomSymbols
493
529
The expression of which you want to compute the expectation value
494
530
given : Expr containing RandomSymbols
517
553
if not random_symbols(expr): # expr isn't random?
519
555
if numsamples: # Computing by monte carlo sampling?
520
return sampling_E(expr, condition, numsamples=numsamples, **kwargs)
556
return sampling_E(expr, condition, numsamples=numsamples)
522
558
# Create new expr and recompute E
523
559
if condition is not None: # If there is a condition
524
return expectation(given(expr, condition, **kwargs), **kwargs)
560
return expectation(given(expr, condition), evaluate=evaluate)
526
562
# A few known statements for efficiency
528
564
if expr.is_Add: # We know that E is Linear
529
return Add(*[expectation(arg, **kwargs) for arg in expr.args])
565
return Add(*[expectation(arg, evaluate=evaluate)
566
for arg in expr.args])
531
568
# Otherwise case is simple, pass work off to the ProbabilitySpace
532
return pspace(expr).integrate(expr, **kwargs)
535
def probability(condition, given_condition=None, numsamples=None, **kwargs):
569
result = pspace(expr).integrate(expr)
570
if evaluate and hasattr(result, 'doit'):
571
return result.doit(**kwargs)
576
def probability(condition, given_condition=None, numsamples=None,
577
evaluate=True, **kwargs):
537
579
Probability that a condition is true, optionally given a second condition
541
584
expr : Relational containing RandomSymbols
542
585
The condition of which you want to compute the probability
543
586
given_condition : Relational containing RandomSymbols
571
614
return probability(given(condition, given_condition, **kwargs), **kwargs)
573
616
# Otherwise pass work off to the ProbabilitySpace
574
return pspace(condition).probability(condition, **kwargs)
617
result = pspace(condition).probability(condition, **kwargs)
618
if evaluate and hasattr(result, 'doit'):
577
624
class Density(Basic):
587
def doit(self, **kwargs):
634
def doit(self, evaluate=True, **kwargs):
588
635
expr, condition = self.expr, self.condition
589
636
if condition is not None:
590
637
# Recompute on new conditional expr
591
638
expr = given(expr, condition, **kwargs)
592
639
if not random_symbols(expr):
593
return Lambda(x, DiracDelta(x-expr))
594
return pspace(expr).compute_density(expr, **kwargs)
597
def density(expr, condition=None, **kwargs):
640
return Lambda(x, DiracDelta(x - expr))
641
if (isinstance(expr, RandomSymbol) and
642
hasattr(expr.pspace, 'distribution') and
643
isinstance(pspace(expr), SinglePSpace)):
644
return expr.pspace.distribution
645
result = pspace(expr).compute_density(expr, **kwargs)
647
if evaluate and hasattr(result, 'doit'):
653
def density(expr, condition=None, evaluate=True, numsamples=None, **kwargs):
599
Probability density of a random expression
601
Optionally given a second condition
655
Probability density of a random expression, optionally given a second
603
658
This density will take on different forms for different types of
605
Discrete variables produce Dicts.
606
Continuous variables produce Lambdas.
659
probability spaces. Discrete variables produce Dicts. Continuous
660
variables produce Lambdas.
665
expr : Expr containing RandomSymbols
666
The expression of which you want to compute the density value
667
condition : Relational containing RandomSymbols
668
A conditional expression. density(X>1, X>0) is density of X>1 given X>0
670
Enables sampling and approximates the density with this many samples
611
675
>>> from sympy.stats import density, Die, Normal
612
676
>>> from sympy import Symbol
614
679
>>> D = Die('D', 6)
615
>>> X = Normal('x', 0, 1)
680
>>> X = Normal(x, 0, 1)
618
683
{1: 1/6, 2: 1/6, 3: 1/6, 4: 1/6, 5: 1/6, 6: 1/6}
684
>>> density(2*D).dict
620
685
{2: 1/6, 4: 1/6, 6: 1/6, 8: 1/6, 10: 1/6, 12: 1/6}
622
Lambda(x, sqrt(2)*exp(-x**2/2)/(2*sqrt(pi)))
687
sqrt(2)*exp(-x**2/2)/(2*sqrt(pi))
624
return Density(expr, condition).doit(**kwargs)
627
def cdf(expr, condition=None, **kwargs):
691
return sampling_density(expr, condition, numsamples=numsamples,
694
return Density(expr, condition).doit(evaluate=evaluate, **kwargs)
697
def cdf(expr, condition=None, evaluate=True, **kwargs):
629
699
Cumulative Distribution Function of a random expression.
644
714
>>> D = Die('D', 6)
645
715
>>> X = Normal('X', 0, 1)
648
718
{1: 1/6, 2: 1/6, 3: 1/6, 4: 1/6, 5: 1/6, 6: 1/6}
650
720
{1: 1/6, 2: 1/3, 3: 1/2, 4: 2/3, 5: 5/6, 6: 1}
659
729
return cdf(given(expr, condition, **kwargs), **kwargs)
661
731
# Otherwise pass work off to the ProbabilitySpace
662
return pspace(expr).compute_cdf(expr, **kwargs)
732
result = pspace(expr).compute_cdf(expr, **kwargs)
734
if evaluate and hasattr(result, 'doit'):
665
740
def where(condition, given_condition=None, **kwargs):
703
778
>>> from sympy.stats import Die, sample
704
779
>>> X, Y, Z = Die('X', 6), Die('Y', 6), Die('Z', 6)
706
>>> die_roll = sample(X+Y+Z) # A random realization of three dice
781
>>> die_roll = sample(X + Y + Z) # A random realization of three dice
708
783
return next(sample_iter(expr, condition, numsamples=1))
777
853
if condition: # Check that these values satisfy the condition
778
854
gd = given_fn(*args)
779
if not isinstance(gd, bool):
855
if gd != True and gd != False:
780
856
raise ValueError(
781
857
"Conditions must not contain free symbols")
782
if gd is False: # If the values don't satisfy then try again
858
if not gd: # If the values don't satisfy then try again
799
875
ps = pspace(expr)
803
878
while count < numsamples:
804
879
d = ps.sample() # a dictionary that maps RVs to values
806
881
if condition is not None: # Check that these values satisfy the condition
807
882
gd = condition.xreplace(d)
808
if not isinstance(gd, bool):
883
if gd != True and gd != False:
809
884
raise ValueError("Conditions must not contain free symbols")
810
if gd is False: # If the values don't satisfy then try again
885
if not gd: # If the values don't satisfy then try again
813
888
yield expr.xreplace(d)
862
samples = sample_iter(condition, given_condition,
938
samples = sample_iter(expr, given_condition,
863
939
numsamples=numsamples, **kwargs)
865
941
result = Add(*list(samples)) / numsamples
947
def sampling_density(expr, given_condition=None, numsamples=1, **kwargs):
949
Sampling version of density
959
for result in sample_iter(expr, given_condition,
960
numsamples=numsamples, **kwargs):
961
results[result] = results.get(result, 0) + 1
872
965
def dependent(a, b):
888
981
>>> dependent(2*X + Y, -Y)
890
>>> X, Y = given(Tuple(X, Y), Eq(X+Y,3))
983
>>> X, Y = given(Tuple(X, Y), Eq(X + Y, 3))
891
984
>>> dependent(X, Y)
924
1017
>>> independent(2*X + Y, -Y)
926
>>> X, Y = given(Tuple(X, Y), Eq(X+Y,3))
1019
>>> X, Y = given(Tuple(X, Y), Eq(X + Y, 3))
927
1020
>>> independent(X, Y)
939
1032
Tests for independence between a and b by checking if their PSpaces have
940
1033
overlapping symbols. This is a sufficient but not necessary condition for
941
1034
independence and is intended to be used internally.
943
pspace_independent(a,b) implies independent(a,b)
944
independent(a,b) does not imply pspace_independent(a,b)
1039
pspace_independent(a, b) implies independent(a, b)
1040
independent(a, b) does not imply pspace_independent(a, b)
946
1042
a_symbols = pspace(b).symbols
947
1043
b_symbols = pspace(a).symbols
969
1065
def __getattr__(self, attr):
971
return self.args[list(self._argnames).index(attr)]
1067
return self.args[self._argnames.index(attr)]
972
1068
except ValueError:
973
1069
raise AttributeError("'%s' object has not attribute '%s'" % (
974
1070
type(self).__name__, attr))