~dinko-metalac/calculus-app2/trunk

« back to all changes in this revision

Viewing changes to lib/py/sympy/stats/rv.py

  • Committer: dinko.metalac at gmail
  • Date: 2015-04-14 13:28:14 UTC
  • Revision ID: dinko.metalac@gmail.com-20150414132814-j25k3qd7sq3warup
new sympy

Show diffs side-by-side

added added

removed removed

Lines of Context:
12
12
sympy.stats.rv_interface
13
13
"""
14
14
 
 
15
from __future__ import print_function, division
 
16
 
15
17
from sympy import (Basic, S, Expr, Symbol, Tuple, And, Add, Eq, lambdify,
16
18
        sympify, Equality, solve, Lambda, DiracDelta)
17
 
from sympy.core.sets import FiniteSet, ProductSet
 
19
from sympy.core.compatibility import reduce
 
20
from sympy.sets.sets import FiniteSet, ProductSet
18
21
from sympy.abc import x
19
 
from functools import reduce
20
22
 
21
23
 
22
24
class RandomDomain(Basic):
130
132
 
131
133
    is_Finite = None
132
134
    is_Continuous = None
 
135
    is_real = None
133
136
 
134
137
    @property
135
138
    def domain(self):
219
222
    """
220
223
 
221
224
    def __new__(cls, pspace, symbol):
222
 
        assert isinstance(symbol, Symbol)
223
 
        assert isinstance(pspace, PSpace)
 
225
        if not isinstance(symbol, Symbol):
 
226
            raise TypeError("symbol should be of type Symbol")
 
227
        if not isinstance(pspace, PSpace):
 
228
            raise TypeError("pspace variable should be of type PSpace")
224
229
        return Basic.__new__(cls, pspace, symbol)
225
230
 
226
 
    is_bounded = True
227
231
    is_finite = True
228
232
    is_Symbol = True
229
233
    is_Atom = True
234
238
    symbol = property(lambda self: self.args[1])
235
239
    name   = property(lambda self: self.symbol.name)
236
240
 
 
241
    def _eval_is_positive(self):
 
242
        return self.symbol.is_positive
 
243
 
 
244
    def _eval_is_integer(self):
 
245
        return self.symbol.is_integer
 
246
 
 
247
    def _eval_is_real(self):
 
248
        return self.symbol.is_real or self.pspace.is_real
 
249
 
237
250
    @property
238
251
    def is_commutative(self):
239
252
        return self.symbol.is_commutative
260
273
            for value in space.values:
261
274
                rs_space_dict[value] = space
262
275
 
263
 
        symbols = FiniteSet(val.symbol for val in list(rs_space_dict.keys()))
 
276
        symbols = FiniteSet(*[val.symbol for val in rs_space_dict.keys()])
264
277
 
265
278
        # Overlapping symbols
266
279
        if len(symbols) < sum(len(space.symbols) for space in spaces):
287
300
 
288
301
    @property
289
302
    def symbols(self):
290
 
        return FiniteSet(val.symbol for val in list(self.rs_space_dict.keys()))
 
303
        return FiniteSet(*[val.symbol for val in self.rs_space_dict.keys()])
291
304
 
292
305
    @property
293
306
    def spaces(self):
314
327
 
315
328
    def sample(self):
316
329
        return dict([(k, v) for space in self.spaces
317
 
            for k, v in list(space.sample().items())])
 
330
            for k, v in space.sample().items()])
318
331
 
319
332
 
320
333
class ProductDomain(RandomDomain):
338
351
                domains2.append(domain)
339
352
            else:
340
353
                domains2.extend(domain.domains)
341
 
        domains2 = FiniteSet(domains2)
 
354
        domains2 = FiniteSet(*domains2)
342
355
 
343
356
        if all(domain.is_Finite for domain in domains2):
344
357
            from sympy.stats.frv import ProductFiniteDomain
356
369
 
357
370
    @property
358
371
    def symbols(self):
359
 
        return FiniteSet(sym for domain in self.domains
360
 
                             for sym    in domain.symbols)
 
372
        return FiniteSet(*[sym for domain in self.domains
 
373
                               for sym    in domain.symbols])
361
374
 
362
375
    @property
363
376
    def domains(self):
446
459
 
447
460
 
448
461
def given(expr, condition=None, **kwargs):
449
 
    """
 
462
    """ Conditional Random Expression
450
463
    From a random expression and a condition on that expression creates a new
451
464
    probability space from the condition and returns the same expression on that
452
465
    conditional probability space.
457
470
    >>> from sympy.stats import given, density, Die
458
471
    >>> X = Die('X', 6)
459
472
    >>> Y = given(X, X>3)
460
 
    >>> density(Y)
 
473
    >>> density(Y).dict
461
474
    {4: 1/3, 5: 1/3, 6: 1/3}
 
475
 
 
476
    Following convention, if the condition is a random symbol then that symbol
 
477
    is considered fixed.
 
478
 
 
479
    >>> from sympy.stats import Normal
 
480
    >>> from sympy import pprint
 
481
    >>> from sympy.abc import z
 
482
 
 
483
    >>> X = Normal('X', 0, 1)
 
484
    >>> Y = Normal('Y', 0, 1)
 
485
    >>> pprint(density(X + Y, Y)(z), use_unicode=False)
 
486
                    2
 
487
           -(-Y + z)
 
488
           -----------
 
489
      ___       2
 
490
    \/ 2 *e
 
491
    ------------------
 
492
             ____
 
493
         2*\/ pi
462
494
    """
463
495
 
464
496
    if not random_symbols(condition) or pspace_independent(expr, condition):
465
497
        return expr
466
498
 
 
499
    if isinstance(condition, RandomSymbol):
 
500
        condition = Eq(condition, condition.symbol)
 
501
 
467
502
    condsymbols = random_symbols(condition)
468
503
    if (isinstance(condition, Equality) and len(condsymbols) == 1 and
469
504
        not isinstance(pspace(expr).domain, ConditionalDomain)):
483
518
    return expr
484
519
 
485
520
 
486
 
def expectation(expr, condition=None, numsamples=None, **kwargs):
 
521
def expectation(expr, condition=None, numsamples=None, evaluate=True, **kwargs):
487
522
    """
488
523
    Returns the expected value of a random expression
489
524
 
490
525
    Parameters
491
 
    ----------
 
526
    ==========
 
527
 
492
528
    expr : Expr containing RandomSymbols
493
529
        The expression of which you want to compute the expectation value
494
530
    given : Expr containing RandomSymbols
517
553
    if not random_symbols(expr):  # expr isn't random?
518
554
        return expr
519
555
    if numsamples:  # Computing by monte carlo sampling?
520
 
        return sampling_E(expr, condition, numsamples=numsamples, **kwargs)
 
556
        return sampling_E(expr, condition, numsamples=numsamples)
521
557
 
522
558
    # Create new expr and recompute E
523
559
    if condition is not None:  # If there is a condition
524
 
        return expectation(given(expr, condition, **kwargs), **kwargs)
 
560
        return expectation(given(expr, condition), evaluate=evaluate)
525
561
 
526
562
    # A few known statements for efficiency
527
563
 
528
564
    if expr.is_Add:  # We know that E is Linear
529
 
        return Add(*[expectation(arg, **kwargs) for arg in expr.args])
 
565
        return Add(*[expectation(arg, evaluate=evaluate)
 
566
                     for arg in expr.args])
530
567
 
531
568
    # Otherwise case is simple, pass work off to the ProbabilitySpace
532
 
    return pspace(expr).integrate(expr, **kwargs)
533
 
 
534
 
 
535
 
def probability(condition, given_condition=None, numsamples=None, **kwargs):
 
569
    result = pspace(expr).integrate(expr)
 
570
    if evaluate and hasattr(result, 'doit'):
 
571
        return result.doit(**kwargs)
 
572
    else:
 
573
        return result
 
574
 
 
575
 
 
576
def probability(condition, given_condition=None, numsamples=None,
 
577
                evaluate=True, **kwargs):
536
578
    """
537
579
    Probability that a condition is true, optionally given a second condition
538
580
 
539
581
    Parameters
540
 
    ----------
 
582
    ==========
 
583
 
541
584
    expr : Relational containing RandomSymbols
542
585
        The condition of which you want to compute the probability
543
586
    given_condition : Relational containing RandomSymbols
571
614
        return probability(given(condition, given_condition, **kwargs), **kwargs)
572
615
 
573
616
    # Otherwise pass work off to the ProbabilitySpace
574
 
    return pspace(condition).probability(condition, **kwargs)
 
617
    result = pspace(condition).probability(condition, **kwargs)
 
618
    if evaluate and hasattr(result, 'doit'):
 
619
        return result.doit()
 
620
    else:
 
621
        return result
575
622
 
576
623
 
577
624
class Density(Basic):
584
631
        else:
585
632
            return None
586
633
 
587
 
    def doit(self, **kwargs):
 
634
    def doit(self, evaluate=True, **kwargs):
588
635
        expr, condition = self.expr, self.condition
589
636
        if condition is not None:
590
637
            # Recompute on new conditional expr
591
638
            expr = given(expr, condition, **kwargs)
592
639
        if not random_symbols(expr):
593
 
            return Lambda(x, DiracDelta(x-expr))
594
 
        return pspace(expr).compute_density(expr, **kwargs)
595
 
 
596
 
 
597
 
def density(expr, condition=None, **kwargs):
 
640
            return Lambda(x, DiracDelta(x - expr))
 
641
        if (isinstance(expr, RandomSymbol) and
 
642
            hasattr(expr.pspace, 'distribution') and
 
643
            isinstance(pspace(expr), SinglePSpace)):
 
644
            return expr.pspace.distribution
 
645
        result = pspace(expr).compute_density(expr, **kwargs)
 
646
 
 
647
        if evaluate and hasattr(result, 'doit'):
 
648
            return result.doit()
 
649
        else:
 
650
            return result
 
651
 
 
652
 
 
653
def density(expr, condition=None, evaluate=True, numsamples=None, **kwargs):
598
654
    """
599
 
    Probability density of a random expression
600
 
 
601
 
    Optionally given a second condition
 
655
    Probability density of a random expression, optionally given a second
 
656
    condition.
602
657
 
603
658
    This density will take on different forms for different types of
604
 
    probability spaces.
605
 
    Discrete variables produce Dicts.
606
 
    Continuous variables produce Lambdas.
 
659
    probability spaces. Discrete variables produce Dicts. Continuous
 
660
    variables produce Lambdas.
 
661
 
 
662
    Parameters
 
663
    ==========
 
664
 
 
665
    expr : Expr containing RandomSymbols
 
666
        The expression of which you want to compute the density value
 
667
    condition : Relational containing RandomSymbols
 
668
        A conditional expression. density(X>1, X>0) is density of X>1 given X>0
 
669
    numsamples : int
 
670
        Enables sampling and approximates the density with this many samples
607
671
 
608
672
    Examples
609
673
    ========
611
675
    >>> from sympy.stats import density, Die, Normal
612
676
    >>> from sympy import Symbol
613
677
 
 
678
    >>> x = Symbol('x')
614
679
    >>> D = Die('D', 6)
615
 
    >>> X = Normal('x', 0, 1)
 
680
    >>> X = Normal(x, 0, 1)
616
681
 
617
 
    >>> density(D)
 
682
    >>> density(D).dict
618
683
    {1: 1/6, 2: 1/6, 3: 1/6, 4: 1/6, 5: 1/6, 6: 1/6}
619
 
    >>> density(2*D)
 
684
    >>> density(2*D).dict
620
685
    {2: 1/6, 4: 1/6, 6: 1/6, 8: 1/6, 10: 1/6, 12: 1/6}
621
 
    >>> density(X)
622
 
    Lambda(x, sqrt(2)*exp(-x**2/2)/(2*sqrt(pi)))
 
686
    >>> density(X)(x)
 
687
    sqrt(2)*exp(-x**2/2)/(2*sqrt(pi))
623
688
    """
624
 
    return Density(expr, condition).doit(**kwargs)
625
 
 
626
 
 
627
 
def cdf(expr, condition=None, **kwargs):
 
689
 
 
690
    if numsamples:
 
691
        return sampling_density(expr, condition, numsamples=numsamples,
 
692
                **kwargs)
 
693
 
 
694
    return Density(expr, condition).doit(evaluate=evaluate, **kwargs)
 
695
 
 
696
 
 
697
def cdf(expr, condition=None, evaluate=True, **kwargs):
628
698
    """
629
699
    Cumulative Distribution Function of a random expression.
630
700
 
644
714
    >>> D = Die('D', 6)
645
715
    >>> X = Normal('X', 0, 1)
646
716
 
647
 
    >>> density(D)
 
717
    >>> density(D).dict
648
718
    {1: 1/6, 2: 1/6, 3: 1/6, 4: 1/6, 5: 1/6, 6: 1/6}
649
719
    >>> cdf(D)
650
720
    {1: 1/6, 2: 1/3, 3: 1/2, 4: 2/3, 5: 5/6, 6: 1}
659
729
        return cdf(given(expr, condition, **kwargs), **kwargs)
660
730
 
661
731
    # Otherwise pass work off to the ProbabilitySpace
662
 
    return pspace(expr).compute_cdf(expr, **kwargs)
 
732
    result = pspace(expr).compute_cdf(expr, **kwargs)
 
733
 
 
734
    if evaluate and hasattr(result, 'doit'):
 
735
        return result.doit()
 
736
    else:
 
737
        return result
663
738
 
664
739
 
665
740
def where(condition, given_condition=None, **kwargs):
703
778
    >>> from sympy.stats import Die, sample
704
779
    >>> X, Y, Z = Die('X', 6), Die('Y', 6), Die('Z', 6)
705
780
 
706
 
    >>> die_roll = sample(X+Y+Z) # A random realization of three dice
 
781
    >>> die_roll = sample(X + Y + Z) # A random realization of three dice
707
782
    """
708
783
    return next(sample_iter(expr, condition, numsamples=1))
709
784
 
717
792
    numsamples: Length of the iterator (defaults to infinity)
718
793
 
719
794
    Examples
720
 
    --------
 
795
    ========
 
796
 
721
797
    >>> from sympy.stats import Normal, sample_iter
722
798
    >>> X = Normal('X', 0, 1)
723
799
    >>> expr = X*X + 3
765
841
        fn(*args)
766
842
        if condition:
767
843
            given_fn(*args)
768
 
    except:
 
844
    except Exception:
769
845
        raise TypeError("Expr/condition too complex for lambdify")
770
846
 
771
847
    def return_generator():
776
852
 
777
853
            if condition:  # Check that these values satisfy the condition
778
854
                gd = given_fn(*args)
779
 
                if not isinstance(gd, bool):
 
855
                if gd != True and gd != False:
780
856
                    raise ValueError(
781
857
                        "Conditions must not contain free symbols")
782
 
                if gd is False:  # If the values don't satisfy then try again
 
858
                if not gd:  # If the values don't satisfy then try again
783
859
                    continue
784
860
 
785
861
            yield fn(*args)
799
875
        ps = pspace(expr)
800
876
 
801
877
    count = 0
802
 
 
803
878
    while count < numsamples:
804
879
        d = ps.sample()  # a dictionary that maps RVs to values
805
880
 
806
881
        if condition is not None:  # Check that these values satisfy the condition
807
882
            gd = condition.xreplace(d)
808
 
            if not isinstance(gd, bool):
 
883
            if gd != True and gd != False:
809
884
                raise ValueError("Conditions must not contain free symbols")
810
 
            if gd is False:  # If the values don't satisfy then try again
 
885
            if not gd:  # If the values don't satisfy then try again
811
886
                continue
812
887
 
813
888
        yield expr.xreplace(d)
814
 
 
815
889
        count += 1
816
890
 
817
891
 
824
898
    ========
825
899
    P
826
900
    sampling_E
 
901
    sampling_density
827
902
    """
828
903
 
829
904
    count_true = 0
833
908
                          numsamples=numsamples, **kwargs)
834
909
 
835
910
    for x in samples:
836
 
        if not isinstance(x, bool):
 
911
        if x != True and x != False:
837
912
            raise ValueError("Conditions must not contain free symbols")
838
913
 
839
 
        if x is True:
 
914
        if x:
840
915
            count_true += 1
841
916
        else:
842
917
            count_false += 1
848
923
        return result
849
924
 
850
925
 
851
 
def sampling_E(condition, given_condition=None, numsamples=1,
 
926
def sampling_E(expr, given_condition=None, numsamples=1,
852
927
               evalf=True, **kwargs):
853
928
    """
854
929
    Sampling version of E
857
932
    ========
858
933
    P
859
934
    sampling_P
 
935
    sampling_density
860
936
    """
861
937
 
862
 
    samples = sample_iter(condition, given_condition,
 
938
    samples = sample_iter(expr, given_condition,
863
939
                          numsamples=numsamples, **kwargs)
864
940
 
865
941
    result = Add(*list(samples)) / numsamples
868
944
    else:
869
945
        return result
870
946
 
 
947
def sampling_density(expr, given_condition=None, numsamples=1, **kwargs):
 
948
    """
 
949
    Sampling version of density
 
950
 
 
951
    See Also
 
952
    ========
 
953
    density
 
954
    sampling_P
 
955
    sampling_E
 
956
    """
 
957
 
 
958
    results = {}
 
959
    for result in sample_iter(expr, given_condition,
 
960
                              numsamples=numsamples, **kwargs):
 
961
        results[result] = results.get(result, 0) + 1
 
962
    return results
 
963
 
871
964
 
872
965
def dependent(a, b):
873
966
    """
887
980
    False
888
981
    >>> dependent(2*X + Y, -Y)
889
982
    True
890
 
    >>> X, Y = given(Tuple(X, Y), Eq(X+Y,3))
 
983
    >>> X, Y = given(Tuple(X, Y), Eq(X + Y, 3))
891
984
    >>> dependent(X, Y)
892
985
    True
893
986
 
923
1016
    True
924
1017
    >>> independent(2*X + Y, -Y)
925
1018
    False
926
 
    >>> X, Y = given(Tuple(X, Y), Eq(X+Y,3))
 
1019
    >>> X, Y = given(Tuple(X, Y), Eq(X + Y, 3))
927
1020
    >>> independent(X, Y)
928
1021
    False
929
1022
 
939
1032
    Tests for independence between a and b by checking if their PSpaces have
940
1033
    overlapping symbols. This is a sufficient but not necessary condition for
941
1034
    independence and is intended to be used internally.
942
 
    Note:
943
 
    pspace_independent(a,b) implies independent(a,b)
944
 
    independent(a,b) does not imply pspace_independent(a,b)
 
1035
 
 
1036
    Notes
 
1037
    =====
 
1038
 
 
1039
    pspace_independent(a, b) implies independent(a, b)
 
1040
    independent(a, b) does not imply pspace_independent(a, b)
945
1041
    """
946
1042
    a_symbols = pspace(b).symbols
947
1043
    b_symbols = pspace(a).symbols
968
1064
 
969
1065
    def __getattr__(self, attr):
970
1066
        try:
971
 
            return self.args[list(self._argnames).index(attr)]
 
1067
            return self.args[self._argnames.index(attr)]
972
1068
        except ValueError:
973
1069
            raise AttributeError("'%s' object has not attribute '%s'" % (
974
1070
                type(self).__name__, attr))
979
1075
 
980
1076
    Raises ValueError with message if condition is not True
981
1077
    """
982
 
    if condition is not True:
 
1078
    if condition != True:
983
1079
        raise ValueError(message)