~ubuntu-branches/ubuntu/saucy/python-scipy/saucy

« back to all changes in this revision

Viewing changes to Lib/sandbox/pyem/tests/test_online_em.py

  • Committer: Bazaar Package Importer
  • Author(s): Ondrej Certik
  • Date: 2008-06-16 22:58:01 UTC
  • mfrom: (2.1.24 intrepid)
  • Revision ID: james.westby@ubuntu.com-20080616225801-irdhrpcwiocfbcmt
Tags: 0.6.0-12
* The description updated to match the current SciPy (Closes: #489149).
* Standards-Version bumped to 3.8.0 (no action needed)
* Build-Depends: netcdf-dev changed to libnetcdf-dev

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
#! /usr/bin/env python
2
 
# Last Change: Wed Dec 06 09:00 PM 2006 J
3
 
 
4
 
import copy
5
 
 
6
 
import sys
7
 
from numpy.testing import *
8
 
 
9
 
import numpy as N
10
 
from numpy.random import seed
11
 
 
12
 
set_package_path()
13
 
from pyem import GM, GMM
14
 
from pyem.online_em import OnGMM, OnGMM1d
15
 
restore_path()
16
 
 
17
 
# #Optional:
18
 
# set_local_path()
19
 
# # import modules that are located in the same directory as this file.
20
 
# restore_path()
21
 
 
22
 
# Error precision allowed (nb of decimals)
23
 
AR_AS_PREC  = 12
24
 
KM_ITER     = 5
25
 
 
26
 
class OnlineEmTest(NumpyTestCase):
27
 
    def _create_model(self, d, k, mode, nframes, emiter):
28
 
        #+++++++++++++++++++++++++++++++++++++++++++++++++
29
 
        # Generate a model with k components, d dimensions
30
 
        #+++++++++++++++++++++++++++++++++++++++++++++++++
31
 
        w, mu, va   = GM.gen_param(d, k, mode, spread = 1.5)
32
 
        gm          = GM.fromvalues(w, mu, va)
33
 
        # Sample nframes frames  from the model
34
 
        data        = gm.sample(nframes)
35
 
 
36
 
        #++++++++++++++++++++++++++++++++++++++++++
37
 
        # Approximate the models with classical EM
38
 
        #++++++++++++++++++++++++++++++++++++++++++
39
 
        # Init the model
40
 
        lgm = GM(d, k, mode)
41
 
        gmm = GMM(lgm, 'kmean')
42
 
        gmm.init(data, niter = KM_ITER)
43
 
 
44
 
        self.gm0    = copy.copy(gmm.gm)
45
 
        # The actual EM, with likelihood computation
46
 
        for i in range(emiter):
47
 
            g, tgd  = gmm.sufficient_statistics(data)
48
 
            gmm.update_em(data, g)
49
 
 
50
 
        self.data   = data
51
 
        self.gm     = lgm
52
 
    
53
 
class test_on_off_eq(OnlineEmTest):
54
 
    def check_1d(self, level = 1):
55
 
        d       = 1
56
 
        k       = 2
57
 
        mode    = 'diag'
58
 
        nframes = int(1e2)
59
 
        emiter  = 3
60
 
 
61
 
        seed(1)
62
 
        self._create_model(d, k, mode, nframes, emiter)
63
 
        self._check(d, k, mode, nframes, emiter)
64
 
 
65
 
    def check_2d(self, level = 1):
66
 
        d       = 2
67
 
        k       = 2
68
 
        mode    = 'diag'
69
 
        nframes = int(1e2)
70
 
        emiter  = 3
71
 
 
72
 
        seed(1)
73
 
        self._create_model(d, k, mode, nframes, emiter)
74
 
        self._check(d, k, mode, nframes, emiter)
75
 
 
76
 
    def check_5d(self, level = 5):
77
 
        d       = 5
78
 
        k       = 2
79
 
        mode    = 'diag'
80
 
        nframes = int(1e2)
81
 
        emiter  = 3
82
 
 
83
 
        seed(1)
84
 
        self._create_model(d, k, mode, nframes, emiter)
85
 
        self._check(d, k, mode, nframes, emiter)
86
 
 
87
 
    def _check(self, d, k, mode, nframes, emiter):
88
 
        #++++++++++++++++++++++++++++++++++++++++
89
 
        # Approximate the models with online EM
90
 
        #++++++++++++++++++++++++++++++++++++++++
91
 
        # Learn the model with Online EM
92
 
        ogm         = GM(d, k, mode)
93
 
        ogmm        = OnGMM(ogm, 'kmean')
94
 
        init_data   = self.data
95
 
        ogmm.init(init_data, niter = KM_ITER)
96
 
 
97
 
        # Check that online kmean init is the same than kmean offline init
98
 
        ogm0    = copy.copy(ogm)
99
 
        assert_array_equal(ogm0.w, self.gm0.w)
100
 
        assert_array_equal(ogm0.mu, self.gm0.mu)
101
 
        assert_array_equal(ogm0.va, self.gm0.va)
102
 
 
103
 
        # Forgetting param
104
 
        lamb    = N.ones((nframes, 1))
105
 
        lamb[0] = 0
106
 
        nu0             = 1.0
107
 
        nu              = N.zeros((len(lamb), 1))
108
 
        nu[0]   = nu0
109
 
        for i in range(1, len(lamb)):
110
 
            nu[i]       = 1./(1 + lamb[i] / nu[i-1])
111
 
 
112
 
        # object version of online EM: the p* arguments are updated only at each 
113
 
        # epoch, which is equivalent to on full EM iteration on the 
114
 
        # classic EM algorithm
115
 
        ogmm.pw    = ogmm.cw.copy()
116
 
        ogmm.pmu   = ogmm.cmu.copy()
117
 
        ogmm.pva   = ogmm.cva.copy()
118
 
        for e in range(emiter):
119
 
            for t in range(nframes):
120
 
                ogmm.compute_sufficient_statistics_frame(self.data[t], nu[t])
121
 
                ogmm.update_em_frame()
122
 
 
123
 
            # Change pw args only a each epoch 
124
 
            ogmm.pw  = ogmm.cw.copy()
125
 
            ogmm.pmu = ogmm.cmu.copy()
126
 
            ogmm.pva = ogmm.cva.copy()
127
 
 
128
 
        # For equivalence between off and on, we allow a margin of error,
129
 
        # because of round-off errors.
130
 
        print " Checking precision of equivalence with offline EM trainer "
131
 
        maxtestprec = 18
132
 
        try :
133
 
            for i in range(maxtestprec):
134
 
                    assert_array_almost_equal(self.gm.w, ogmm.pw, decimal = i)
135
 
                    assert_array_almost_equal(self.gm.mu, ogmm.pmu, decimal = i)
136
 
                    assert_array_almost_equal(self.gm.va, ogmm.pva, decimal = i)
137
 
            print "\t !! Precision up to %d decimals !! " % i
138
 
        except AssertionError:
139
 
            if i < AR_AS_PREC:
140
 
                print """\t !!NOT OK: Precision up to %d decimals only, 
141
 
                    outside the allowed range (%d) !! """ % (i, AR_AS_PREC)
142
 
                raise AssertionError
143
 
            else:
144
 
                print "\t !!OK: Precision up to %d decimals !! " % i
145
 
 
146
 
class test_on(OnlineEmTest):
147
 
    def check_consistency(self):
148
 
        d       = 1
149
 
        k       = 2
150
 
        mode    = 'diag'
151
 
        nframes = int(5e2)
152
 
        emiter  = 4
153
 
 
154
 
        self._create_model(d, k, mode, nframes, emiter)
155
 
        self._run_pure_online(d, k, mode, nframes)
156
 
    
157
 
    def check_1d_imp(self):
158
 
        d       = 1
159
 
        k       = 2
160
 
        mode    = 'diag'
161
 
        nframes = int(5e2)
162
 
        emiter  = 4
163
 
 
164
 
        self._create_model(d, k, mode, nframes, emiter)
165
 
        gmref   = self._run_pure_online(d, k, mode, nframes)
166
 
        gmtest  = self._run_pure_online_1d(d, k, mode, nframes)
167
 
    
168
 
        assert_array_almost_equal(gmref.w, gmtest.w, AR_AS_PREC)
169
 
        assert_array_almost_equal(gmref.mu, gmtest.mu, AR_AS_PREC)
170
 
        assert_array_almost_equal(gmref.va, gmtest.va, AR_AS_PREC)
171
 
 
172
 
    def _run_pure_online_1d(self, d, k, mode, nframes):
173
 
        #++++++++++++++++++++++++++++++++++++++++
174
 
        # Approximate the models with online EM
175
 
        #++++++++++++++++++++++++++++++++++++++++
176
 
        ogm     = GM(d, k, mode)
177
 
        ogmm    = OnGMM1d(ogm, 'kmean')
178
 
        init_data   = self.data[0:nframes / 20, :]
179
 
        ogmm.init(init_data[:, 0])
180
 
 
181
 
        # Forgetting param
182
 
        ku              = 0.005
183
 
        t0              = 200
184
 
        lamb    = 1 - 1/(N.arange(-1, nframes-1) * ku + t0)
185
 
        nu0             = 0.2
186
 
        nu              = N.zeros((len(lamb), 1))
187
 
        nu[0]   = nu0
188
 
        for i in range(1, len(lamb)):
189
 
            nu[i]       = 1./(1 + lamb[i] / nu[i-1])
190
 
 
191
 
        # object version of online EM
192
 
        for t in range(nframes):
193
 
            # the assert are here to check we do not create copies
194
 
            # unvoluntary for parameters
195
 
            a, b, c = ogmm.compute_sufficient_statistics_frame(self.data[t, 0], nu[t])
196
 
            ogmm.update_em_frame(a, b, c)
197
 
 
198
 
        ogmm.gm.set_param(ogmm.cw, ogmm.cmu[:, N.newaxis], ogmm.cva[:, N.newaxis])
199
 
 
200
 
        return ogmm.gm
201
 
    def _run_pure_online(self, d, k, mode, nframes):
202
 
        #++++++++++++++++++++++++++++++++++++++++
203
 
        # Approximate the models with online EM
204
 
        #++++++++++++++++++++++++++++++++++++++++
205
 
        ogm     = GM(d, k, mode)
206
 
        ogmm    = OnGMM(ogm, 'kmean')
207
 
        init_data   = self.data[0:nframes / 20, :]
208
 
        ogmm.init(init_data)
209
 
 
210
 
        # Forgetting param
211
 
        ku              = 0.005
212
 
        t0              = 200
213
 
        lamb    = 1 - 1/(N.arange(-1, nframes-1) * ku + t0)
214
 
        nu0             = 0.2
215
 
        nu              = N.zeros((len(lamb), 1))
216
 
        nu[0]   = nu0
217
 
        for i in range(1, len(lamb)):
218
 
            nu[i]       = 1./(1 + lamb[i] / nu[i-1])
219
 
 
220
 
        # object version of online EM
221
 
        for t in range(nframes):
222
 
            # the assert are here to check we do not create copies
223
 
            # unvoluntary for parameters
224
 
            assert ogmm.pw is ogmm.cw
225
 
            assert ogmm.pmu is ogmm.cmu
226
 
            assert ogmm.pva is ogmm.cva
227
 
            ogmm.compute_sufficient_statistics_frame(self.data[t], nu[t])
228
 
            ogmm.update_em_frame()
229
 
 
230
 
        ogmm.gm.set_param(ogmm.cw, ogmm.cmu, ogmm.cva)
231
 
 
232
 
        return ogmm.gm
233
 
if __name__ == "__main__":
234
 
    NumpyTest().run()