2
# Last Change: Wed Dec 06 09:00 PM 2006 J
7
from numpy.testing import *
10
from numpy.random import seed
13
from pyem import GM, GMM
14
from pyem.online_em import OnGMM, OnGMM1d
19
# # import modules that are located in the same directory as this file.
22
# Allowed error precision (number of decimals)
26
# Base class for the online EM tests: builds a random reference Gaussian
# mixture, samples data from it and fits it with the classical (offline)
# EM trainer, so that subclasses can compare online EM against that fit.
# NOTE(review): this chunk lost its indentation, has the original line
# numbers interleaved as bare integer lines, and is missing several
# original lines (e.g. where lgm, KM_ITER, copy and N are defined or
# imported) -- restore from the full file before running.
class OnlineEmTest(NumpyTestCase):
27
# Build the reference model and its offline EM fit.
#   d       : number of dimensions
#   k       : number of mixture components
#   mode    : passed through to GM.gen_param (presumably the covariance
#             type -- confirm)
#   nframes : number of frames sampled from the generated model
#   emiter  : number of offline EM iterations
# Visible side effect: sets self.gm0. Subclasses also read self.data and
# self.gm; those are presumably assigned in lines missing from this chunk
# (after the EM loop) -- TODO confirm.
def _create_model(self, d, k, mode, nframes, emiter):
28
#+++++++++++++++++++++++++++++++++++++++++++++++++
29
# Generate a model with k components, d dimensions
30
#+++++++++++++++++++++++++++++++++++++++++++++++++
31
w, mu, va = GM.gen_param(d, k, mode, spread = 1.5)
32
gm = GM.fromvalues(w, mu, va)
33
# Sample nframes frames from the model
34
data = gm.sample(nframes)
36
#++++++++++++++++++++++++++++++++++++++++++
37
# Approximate the models with classical EM
38
#++++++++++++++++++++++++++++++++++++++++++
41
# NOTE(review): lgm is not defined in the visible lines; presumably an
# empty GM(d, k, mode) created in a missing line -- confirm.
gmm = GMM(lgm, 'kmean')
42
gmm.init(data, niter = KM_ITER)
44
# Snapshot of the model right after k-means initialisation; the online
# tests compare their own initialisation against it. copy.copy is a
# shallow copy: if the EM updates below modify gmm.gm's arrays in place
# instead of rebinding them, gm0 would see those changes -- TODO confirm.
self.gm0 = copy.copy(gmm.gm)
45
# The actual EM iterations: sufficient statistics, then the EM update.
# The second value returned by sufficient_statistics (tgd) is not used in
# the visible lines.
46
for i in range(emiter):
47
g, tgd = gmm.sufficient_statistics(data)
48
gmm.update_em(data, g)
53
# Checks that online EM reproduces the offline EM fit built by
# OnlineEmTest._create_model when two conditions hold. First, the step
# sizes mimic a running mean. Second, the p* parameters are refreshed
# only once per pass over the data.
# NOTE(review): in each check_* method below, the values of d, k, mode,
# nframes and emiter come from lines missing from this chunk -- confirm.
# This is Python 2 code (print statements).
class test_on_off_eq(OnlineEmTest):
54
# Offline/online equivalence on a 1-dimensional model (presumably d == 1).
# `level` is presumably the legacy NumpyTestCase test level; it is not used
# in the visible lines.
def check_1d(self, level = 1):
62
self._create_model(d, k, mode, nframes, emiter)
63
self._check(d, k, mode, nframes, emiter)
65
# Offline/online equivalence on a 2-dimensional model (presumably d == 2).
def check_2d(self, level = 1):
73
self._create_model(d, k, mode, nframes, emiter)
74
self._check(d, k, mode, nframes, emiter)
76
# Offline/online equivalence on a 5-dimensional model (presumably d == 5).
def check_5d(self, level = 5):
84
self._create_model(d, k, mode, nframes, emiter)
85
self._check(d, k, mode, nframes, emiter)
87
# Train OnGMM on self.data for emiter passes, then compare the result
# with the offline fit in self.gm at increasing decimal precision.
# _create_model must be called first, because this method reads
# self.data, self.gm and self.gm0.
def _check(self, d, k, mode, nframes, emiter):
88
#++++++++++++++++++++++++++++++++++++++++
89
# Approximate the models with online EM
90
#++++++++++++++++++++++++++++++++++++++++
91
# Learn the model with Online EM
93
# NOTE(review): ogm, init_data and ogm0 are defined in lines missing from
# this chunk (presumably an empty GM, a slice of self.data, and the model
# right after ogmm.init) -- confirm.
ogmm = OnGMM(ogm, 'kmean')
95
ogmm.init(init_data, niter = KM_ITER)
97
# Check that the online k-means init is identical to the offline k-means
# init (assert_array_equal: exact equality, not approximate).
99
assert_array_equal(ogm0.w, self.gm0.w)
100
assert_array_equal(ogm0.mu, self.gm0.mu)
101
assert_array_equal(ogm0.va, self.gm0.va)
104
# Step sizes: with lamb all ones, the recurrence below reduces to
# nu[i] = nu[i-1] / (1 + nu[i-1]). If nu[0] == 1 this gives
# nu[t] = 1/(t+1), a plain running mean over the frames of one pass.
lamb = N.ones((nframes, 1))
107
# NOTE(review): nu[0] must be non-zero before the loop (presumably set in a
# line missing from this chunk). If it is left at 0, lamb[1] / nu[0]
# divides by zero and every later nu[i] collapses to 0.
nu = N.zeros((len(lamb), 1))
109
for i in range(1, len(lamb)):
110
nu[i] = 1./(1 + lamb[i] / nu[i-1])
112
# object version of online EM: the p* arguments are updated only at each
113
# epoch, which is equivalent to one full EM iteration of the
114
# classic EM algorithm
115
ogmm.pw = ogmm.cw.copy()
116
ogmm.pmu = ogmm.cmu.copy()
117
ogmm.pva = ogmm.cva.copy()
118
for e in range(emiter):
119
for t in range(nframes):
120
ogmm.compute_sufficient_statistics_frame(self.data[t], nu[t])
121
ogmm.update_em_frame()
123
# Refresh the p* arguments only once per epoch
124
ogmm.pw = ogmm.cw.copy()
125
ogmm.pmu = ogmm.cmu.copy()
126
ogmm.pva = ogmm.cva.copy()
128
# For equivalence between offline and online EM, we allow a margin of
# error because of round-off errors.
129
# (the precision actually reached is reported below)
130
print " Checking precision of equivalence with offline EM trainer "
133
# Increase the required precision one decimal at a time until an
# assertion fails.
# NOTE(review): the `try:` matching the `except` below is in lines missing
# from this chunk, as is (presumably) an `else:` for the final OK message.
# In the visible lines the except branch only prints, so a precision
# shortfall does not fail the test. Confirm whether a missing line checks
# i against AR_AS_PREC and re-raises.
for i in range(maxtestprec):
134
assert_array_almost_equal(self.gm.w, ogmm.pw, decimal = i)
135
assert_array_almost_equal(self.gm.mu, ogmm.pmu, decimal = i)
136
assert_array_almost_equal(self.gm.va, ogmm.pva, decimal = i)
137
print "\t !! Precision up to %d decimals !! " % i
138
except AssertionError:
140
print """\t !!NOT OK: Precision up to %d decimals only,
141
outside the allowed range (%d) !! """ % (i, AR_AS_PREC)
144
print "\t !!OK: Precision up to %d decimals !! " % i
146
# Pure online EM tests: parameters are updated after every frame. Unlike
# test_on_off_eq, the p* arguments are not refreshed separately per epoch.
# NOTE(review): d, k, mode, nframes and emiter used by the check_* methods,
# and ogm, ku and t0 used by the helpers, are defined in lines missing from
# this chunk -- confirm.
class test_on(OnlineEmTest):
147
# Smoke test that runs pure online EM on the sampled data. The only
# checks are the identity asserts inside _run_pure_online; the result
# itself is not asserted on.
def check_consistency(self):
154
self._create_model(d, k, mode, nframes, emiter)
155
self._run_pure_online(d, k, mode, nframes)
157
# Checks that the 1-D trainer OnGMM1d gives the same weights, means and
# variances as the generic OnGMM on the same data, to AR_AS_PREC decimals.
# AR_AS_PREC is passed as the third positional argument of
# assert_array_almost_equal, which is `decimal`.
# NOTE(review): this relies on both _run_pure_online* helpers returning a
# model with w/mu/va attributes. Their return statements are not visible
# in this chunk (presumably `return ogmm.gm`) -- confirm.
def check_1d_imp(self):
164
self._create_model(d, k, mode, nframes, emiter)
165
gmref = self._run_pure_online(d, k, mode, nframes)
166
gmtest = self._run_pure_online_1d(d, k, mode, nframes)
168
assert_array_almost_equal(gmref.w, gmtest.w, AR_AS_PREC)
169
assert_array_almost_equal(gmref.mu, gmtest.mu, AR_AS_PREC)
170
assert_array_almost_equal(gmref.va, gmtest.va, AR_AS_PREC)
172
# Run pure online EM with the 1-D trainer OnGMM1d on the first column of
# self.data, then write the learnt parameters into ogmm.gm.
# NOTE(review): `nframes / 20` relies on Python 2 integer division. Under
# Python 3 it yields a float and the slice fails; `//` would be needed.
def _run_pure_online_1d(self, d, k, mode, nframes):
173
#++++++++++++++++++++++++++++++++++++++++
174
# Approximate the models with online EM
175
#++++++++++++++++++++++++++++++++++++++++
177
ogmm = OnGMM1d(ogm, 'kmean')
178
# k-means initialisation on the first nframes/20 frames, first column only.
init_data = self.data[0:nframes / 20, :]
179
ogmm.init(init_data[:, 0])
184
# Step-size schedule built from ku and t0. NOTE(review): if ku and t0 are
# both ints, `1/(...)` is integer division under Python 2 -- confirm they
# are floats.
lamb = 1 - 1/(N.arange(-1, nframes-1) * ku + t0)
186
# NOTE(review): nu[0] must be set to a non-zero value (presumably in a
# missing line) or the recurrence below divides by zero.
nu = N.zeros((len(lamb), 1))
188
for i in range(1, len(lamb)):
189
nu[i] = 1./(1 + lamb[i] / nu[i-1])
191
# object version of online EM
192
for t in range(nframes):
193
# Unlike _run_pure_online, there are no identity asserts here. The 1-D
194
# trainer returns its sufficient statistics, which are passed back
# explicitly to update_em_frame.
195
a, b, c = ogmm.compute_sufficient_statistics_frame(self.data[t, 0], nu[t])
196
ogmm.update_em_frame(a, b, c)
198
# N.newaxis adds a trailing axis to cmu/cva, presumably turning (k,)
# vectors into the (k, 1) shape GM.set_param expects for d == 1 -- confirm.
ogmm.gm.set_param(ogmm.cw, ogmm.cmu[:, N.newaxis], ogmm.cva[:, N.newaxis])
201
# Run pure online EM with the generic OnGMM trainer on self.data, updating
# the parameters after every frame. The learnt parameters are then
# written into ogmm.gm.
def _run_pure_online(self, d, k, mode, nframes):
202
#++++++++++++++++++++++++++++++++++++++++
203
# Approximate the models with online EM
204
#++++++++++++++++++++++++++++++++++++++++
206
ogmm = OnGMM(ogm, 'kmean')
207
# NOTE(review): init_data is computed here, but the call that uses it
# (presumably ogmm.init(init_data)) is in lines missing from this chunk.
init_data = self.data[0:nframes / 20, :]
213
# NOTE(review): same step-size schedule as in _run_pure_online_1d; the two
# copies must be kept in sync (a shared helper would avoid the duplication).
lamb = 1 - 1/(N.arange(-1, nframes-1) * ku + t0)
215
nu = N.zeros((len(lamb), 1))
217
for i in range(1, len(lamb)):
218
nu[i] = 1./(1 + lamb[i] / nu[i-1])
220
# object version of online EM
221
for t in range(nframes):
222
# These asserts check that the p* parameters are the very same objects
223
# as the c* ones, i.e. that no unintended copies are made.
224
assert ogmm.pw is ogmm.cw
225
assert ogmm.pmu is ogmm.cmu
226
assert ogmm.pva is ogmm.cva
227
ogmm.compute_sufficient_statistics_frame(self.data[t], nu[t])
228
ogmm.update_em_frame()
230
ogmm.gm.set_param(ogmm.cw, ogmm.cmu, ogmm.cva)
233
if __name__ == "__main__":