~ubuntu-branches/ubuntu/saucy/python-scipy/saucy

« back to all changes in this revision

Viewing changes to Lib/sandbox/pyem/online_em.py

  • Committer: Bazaar Package Importer
  • Author(s): Ondrej Certik
  • Date: 2008-06-16 22:58:01 UTC
  • mfrom: (2.1.24 intrepid)
  • Revision ID: james.westby@ubuntu.com-20080616225801-irdhrpcwiocfbcmt
Tags: 0.6.0-12
* The description updated to match the current SciPy (Closes: #489149).
* Standards-Version bumped to 3.8.0 (no action needed)
* Build-Depends: netcdf-dev changed to libnetcdf-dev

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# /usr/bin/python
2
 
# Last Change: Wed Dec 06 09:00 PM 2006 J
3
 
 
4
 
#---------------------------------------------
5
 
# This is not meant to be used yet !!!! I am 
6
 
# not sure how to integrate this stuff inside
7
 
# the package yet. The cases are:
8
 
#   - we have a set of data, and we want to test online EM 
9
 
#   compared to normal EM 
10
 
#   - we do not have all the data before putting them in online EM:
11
 
#   eg current frame depends on previous frame in some way.
12
 
 
13
 
# TODO:
14
 
#   - Add biblio
15
 
#   - Look back at articles for discussion for init, regularization and 
16
 
#   convergence rates
17
 
#   - the function sufficient_statistics does not really return SS. This is not a
18
 
#   big problem, but it would be better to really return them as the name implied.
19
 
 
20
 
import numpy as N
21
 
from numpy import mean
22
 
from numpy.testing import assert_array_almost_equal, assert_array_equal
23
 
 
24
 
from gmm_em import ExpMixtureModel, GMM, EM
25
 
from gauss_mix import GM
26
 
from kmean import kmean
27
 
import densities2 as D
28
 
 
29
 
import copy
30
 
from numpy.random import seed
31
 
 
32
 
# Clamp
33
 
clamp   = 1e-8
34
 
 
35
 
# Error classes
36
 
class OnGmmError(Exception):
37
 
    """Base class for exceptions in this module."""
38
 
    pass
39
 
 
40
 
class OnGmmParamError:
41
 
    """Exception raised for errors in gmm params
42
 
 
43
 
    Attributes:
44
 
        expression -- input expression in which the error occurred
45
 
        message -- explanation of the error
46
 
    """
47
 
    def __init__(self, message):
48
 
        self.message    = message
49
 
    
50
 
    def __str__(self):
51
 
        return self.message
52
 
 
53
 
class OnGMM(ExpMixtureModel):
54
 
    """A Class for 'online' (ie recursive) EM. Instead
55
 
    of running the E step on the whole data, the sufficient statistics
56
 
    are updated for each new frame of data, and used in the (unchanged)
57
 
    M step"""
58
 
    def init_random(self, init_data):
59
 
        """ Init the model at random."""
60
 
        k   = self.gm.k
61
 
        d   = self.gm.d
62
 
        if self.gm.mode == 'diag':
63
 
            w           = N.ones(k) / k
64
 
 
65
 
            # Init the internal state of EM
66
 
            self.cx     = N.outer(w, mean(init_data, 0))
67
 
            self.cxx    = N.outer(w, mean(init_data ** 2, 0))
68
 
 
69
 
            # w, mu and va init is the same that in the standard case
70
 
            (code, label)   = kmean(init_data, init_data[0:k, :], niter)
71
 
            mu          = code.copy()
72
 
            va          = N.zeros((k, d))
73
 
            for i in range(k):
74
 
                for j in range(d):
75
 
                    va [i,j] = N.cov(init_data[N.where(label==i), j], rowvar = 0)
76
 
        else:
77
 
            raise OnGmmParamError("""init_online not implemented for
78
 
                    mode %s yet""", mode)
79
 
 
80
 
        self.gm.set_param(w, mu, va)
81
 
        # c* are the parameters which are computed at every step (ie
82
 
        # when a new frame is taken into account
83
 
        self.cw     = self.gm.w
84
 
        self.cmu    = self.gm.mu
85
 
        self.cva    = self.gm.va
86
 
 
87
 
        # p* are the parameters used when computing gaussian densities
88
 
        # they are always the same than c* in the online case
89
 
        self.pw     = self.cw
90
 
        self.pmu    = self.cmu
91
 
        self.pva    = self.cva
92
 
 
93
 
    def init_kmean(self, init_data, niter = 5):
94
 
        """ Init the model using kmean."""
95
 
        k   = self.gm.k
96
 
        d   = self.gm.d
97
 
        if self.gm.mode == 'diag':
98
 
            w           = N.ones(k) / k
99
 
 
100
 
            # Init the internal state of EM
101
 
            self.cx     = N.outer(w, mean(init_data, 0))
102
 
            self.cxx    = N.outer(w, mean(init_data ** 2, 0))
103
 
 
104
 
            # w, mu and va init is the same that in the standard case
105
 
            (code, label)   = kmean(init_data, init_data[0:k, :], niter)
106
 
            mu          = code.copy()
107
 
            va          = N.zeros((k, d))
108
 
            for i in range(k):
109
 
                for j in range(d):
110
 
                    va [i,j] = N.cov(init_data[N.where(label==i), j], rowvar = 0)
111
 
        else:
112
 
            raise OnGmmParamError("""init_online not implemented for
113
 
                    mode %s yet""", mode)
114
 
 
115
 
        self.gm.set_param(w, mu, va)
116
 
        # c* are the parameters which are computed at every step (ie
117
 
        # when a new frame is taken into account
118
 
        self.cw     = self.gm.w
119
 
        self.cmu    = self.gm.mu
120
 
        self.cva    = self.gm.va
121
 
 
122
 
        # p* are the parameters used when computing gaussian densities
123
 
        # they are the same than c* in the online case
124
 
        # self.pw     = self.cw.copy()
125
 
        # self.pmu    = self.cmu.copy()
126
 
        # self.pva    = self.cva.copy()
127
 
        self.pw     = self.cw
128
 
        self.pmu    = self.cmu
129
 
        self.pva    = self.cva
130
 
 
131
 
    def __init__(self, gm, init_data, init = 'kmean'):
132
 
        self.gm = gm
133
 
        
134
 
        # Possible init methods
135
 
        init_methods = {'kmean' : self.init_kmean}
136
 
 
137
 
        self.init   = init_methods[init]
138
 
 
139
 
    def compute_sufficient_statistics_frame(self, frame, nu):
140
 
        """ sufficient_statistics(frame, nu) for one frame of data
141
 
        
142
 
        frame has to be rank 1 !"""
143
 
        gamma   = multiple_gauss_den_frame(frame, self.pmu, self.pva)
144
 
        gamma   *= self.pw
145
 
        gamma   /= N.sum(gamma)
146
 
        # <1>(t) = cw(t), self.cw = cw(t), each element is one component running weight
147
 
        #self.cw        = (1 - nu) * self.cw + nu * gamma
148
 
        self.cw *= (1 - nu)
149
 
        self.cw += nu * gamma
150
 
 
151
 
        for k in range(self.gm.k):
152
 
            self.cx[k]   = (1 - nu) * self.cx[k] + nu * frame * gamma[k]
153
 
            self.cxx[k]  = (1 - nu) * self.cxx[k] + nu * frame ** 2 * gamma[k]
154
 
 
155
 
    def update_em_frame(self):
156
 
        for k in range(self.gm.k):
157
 
            self.cmu[k]  = self.cx[k] / self.cw[k]
158
 
            self.cva[k]  = self.cxx[k] / self.cw[k] - self.cmu[k] ** 2
159
 
    
160
 
import _rawden
161
 
 
162
 
class OnGMM1d(ExpMixtureModel):
163
 
    """Special purpose case optimized for 1d dimensional cases.
164
 
    
165
 
    Require each frame to be a float, which means the API is a bit
166
 
    different than OnGMM. You are trading elegance for speed here !"""
167
 
    def init_kmean(self, init_data, niter = 5):
168
 
        """ Init the model using kmean."""
169
 
        assert init_data.ndim == 1
170
 
        k   = self.gm.k
171
 
        w   = N.ones(k) / k
172
 
 
173
 
        # Init the internal state of EM
174
 
        self.cx     = w * mean(init_data)
175
 
        self.cxx    = w * mean(init_data ** 2)
176
 
 
177
 
        # w, mu and va init is the same that in the standard case
178
 
        (code, label)   = kmean(init_data[:, N.newaxis], \
179
 
                init_data[0:k, N.newaxis], niter)
180
 
        mu          = code.copy()
181
 
        va          = N.zeros((k, 1))
182
 
        for i in range(k):
183
 
            va[i] = N.cov(init_data[N.where(label==i)], rowvar = 0)
184
 
 
185
 
        self.gm.set_param(w, mu, va)
186
 
        # c* are the parameters which are computed at every step (ie
187
 
        # when a new frame is taken into account
188
 
        self.cw     = self.gm.w
189
 
        self.cmu    = self.gm.mu[:, 0]
190
 
        self.cva    = self.gm.va[:, 0]
191
 
 
192
 
        # p* are the parameters used when computing gaussian densities
193
 
        # they are the same than c* in the online case
194
 
        # self.pw     = self.cw.copy()
195
 
        # self.pmu    = self.cmu.copy()
196
 
        # self.pva    = self.cva.copy()
197
 
        self.pw     = self.cw
198
 
        self.pmu    = self.cmu
199
 
        self.pva    = self.cva
200
 
 
201
 
    def __init__(self, gm, init_data, init = 'kmean'):
202
 
        self.gm = gm
203
 
        if self.gm.d is not 1:
204
 
            raise RuntimeError("expects 1d gm only !")
205
 
 
206
 
        # Possible init methods
207
 
        init_methods    = {'kmean' : self.init_kmean}
208
 
        self.init       = init_methods[init]
209
 
 
210
 
    def compute_sufficient_statistics_frame(self, frame, nu):
211
 
        """expects frame and nu to be float. Returns
212
 
        cw, cxx and cxx, eg the sufficient statistics."""
213
 
        _rawden.compute_ss_frame_1d(frame, self.cw, self.cmu, self.cva, 
214
 
                self.cx, self.cxx, nu)
215
 
        return self.cw, self.cx, self.cxx
216
 
 
217
 
    def update_em_frame(self, cw, cx, cxx):
218
 
        """Update EM state using SS as returned by 
219
 
        compute_sufficient_statistics_frame. """
220
 
        self.cmu    = cx / cw
221
 
        self.cva    = cxx / cw - self.cmu ** 2
222
 
 
223
 
    def compute_em_frame(self, frame, nu):
224
 
        """Run a whole em step for one frame. frame and nu should be float;
225
 
        if you don't need to split E and M steps, this is faster than calling 
226
 
        compute_sufficient_statistics_frame and update_em_frame"""
227
 
        _rawden.compute_em_frame_1d(frame, self.cw, self.cmu, self.cva, \
228
 
                self.cx, self.cxx, nu)
229
 
#class OnlineEM:
230
 
#    def __init__(self, ogm):
231
 
#        """Init Online Em algorithm with ogm, an instance of OnGMM."""
232
 
#        if not isinstance(ogm, OnGMM):
233
 
#            raise TypeError("expect a OnGMM instance for the model")
234
 
#
235
 
#    def init_em(self):
236
 
#        pass
237
 
#
238
 
#    def train(self, data, nu):
239
 
#        pass
240
 
#
241
 
#    def train_frame(self, frame, nu):
242
 
#        pass
243
 
 
244
 
def multiple_gauss_den_frame(data, mu, va):
245
 
    """Helper function to generate several Gaussian
246
 
    pdf (different parameters) from one frame of data.
247
 
    
248
 
    Semantics depending on data's rank
249
 
        - rank 0: mu and va expected to have rank 0 or 1
250
 
        - rank 1: mu and va expected to have rank 2."""
251
 
    if N.ndim(data) == 0:
252
 
        # scalar case
253
 
        k   = mu.size
254
 
        inva    = 1/va
255
 
        fac     = (2*N.pi) ** (-1/2.0) * N.sqrt(inva)
256
 
        y       = ((data-mu) ** 2) * -0.5 * inva
257
 
        return   fac * N.exp(y.ravel())
258
 
    elif N.ndim(data) == 1:
259
 
        # multi variate case (general case)
260
 
        k   = mu.shape[0]
261
 
        y   = N.zeros(k, data.dtype)
262
 
        if mu.size == va.size:
263
 
            # diag case
264
 
            for i in range(k):
265
 
                #y[i] = D.gauss_den(data, mu[i], va[i])
266
 
                # This is a bit hackish: _diag_gauss_den implementation's
267
 
                # changes can break this, but I don't see how to easily fix this
268
 
                y[i] = D._diag_gauss_den(data, mu[i], va[i], False, -1)
269
 
            return y
270
 
        else:
271
 
            raise RuntimeError("full not implemented yet")
272
 
            #for i in range(K):
273
 
            #    y[i] = D.gauss_den(data, mu[i, :], 
274
 
            #                va[d*i:d*i+d, :])
275
 
            #return y.T
276
 
    else:
277
 
        raise RuntimeError("frame should be rank 0 or 1 only")
278
 
        
279
 
 
280
 
if __name__ == '__main__':
281
 
    d       = 1
282
 
    k       = 2
283
 
    mode    = 'diag'
284
 
    nframes = int(5e3)
285
 
    emiter  = 4
286
 
    seed(5)
287
 
 
288
 
    #+++++++++++++++++++++++++++++++++++++++++++++++++
289
 
    # Generate a model with k components, d dimensions
290
 
    #+++++++++++++++++++++++++++++++++++++++++++++++++
291
 
    w, mu, va   = GM.gen_param(d, k, mode, spread = 1.5)
292
 
    gm          = GM.fromvalues(w, mu, va)
293
 
    # Sample nframes frames  from the model
294
 
    data        = gm.sample(nframes)
295
 
 
296
 
    #++++++++++++++++++++++++++++++++++++++++++
297
 
    # Approximate the models with classical EM
298
 
    #++++++++++++++++++++++++++++++++++++++++++
299
 
    # Init the model
300
 
    lgm = GM(d, k, mode)
301
 
    gmm = GMM(lgm, 'kmean')
302
 
    gmm.init(data)
303
 
 
304
 
    gm0    = copy.copy(gmm.gm)
305
 
    # The actual EM, with likelihood computation
306
 
    like    = N.zeros(emiter)
307
 
    for i in range(emiter):
308
 
        g, tgd  = gmm.sufficient_statistics(data)
309
 
        like[i] = N.sum(N.log(N.sum(tgd, 1)), axis = 0)
310
 
        gmm.update_em(data, g)
311
 
 
312
 
    #++++++++++++++++++++++++++++++++++++++++
313
 
    # Approximate the models with online EM
314
 
    #++++++++++++++++++++++++++++++++++++++++
315
 
    ogm     = GM(d, k, mode)
316
 
    ogmm    = OnGMM(ogm, 'kmean')
317
 
    init_data   = data[0:nframes / 20, :]
318
 
    ogmm.init(init_data)
319
 
 
320
 
    # Forgetting param
321
 
    ku          = 0.005
322
 
    t0          = 200
323
 
    lamb        = 1 - 1/(N.arange(-1, nframes-1) * ku + t0)
324
 
    nu0         = 0.2
325
 
    nu          = N.zeros((len(lamb), 1))
326
 
    nu[0]       = nu0
327
 
    for i in range(1, len(lamb)):
328
 
        nu[i]   = 1./(1 + lamb[i] / nu[i-1])
329
 
 
330
 
    print "meth1"
331
 
    # object version of online EM
332
 
    for t in range(nframes):
333
 
        ogmm.compute_sufficient_statistics_frame(data[t], nu[t])
334
 
        ogmm.update_em_frame()
335
 
 
336
 
    ogmm.gm.set_param(ogmm.cw, ogmm.cmu, ogmm.cva)
337
 
 
338
 
    # 1d optimized version
339
 
    ogm2        = GM(d, k, mode)
340
 
    ogmm2       = OnGMM1d(ogm2, 'kmean')
341
 
    ogmm2.init(init_data[:, 0])
342
 
 
343
 
    print "meth2"
344
 
    # object version of online EM
345
 
    for t in range(nframes):
346
 
        ogmm2.compute_sufficient_statistics_frame(data[t, 0], nu[t])
347
 
        ogmm2.update_em_frame()
348
 
 
349
 
    #ogmm2.gm.set_param(ogmm2.cw, ogmm2.cmu, ogmm2.cva)
350
 
 
351
 
    print ogmm.cw
352
 
    print ogmm2.cw
353
 
    #+++++++++++++++
354
 
    # Draw the model
355
 
    #+++++++++++++++
356
 
    print "drawing..."
357
 
    import pylab as P
358
 
    P.subplot(2, 1, 1)
359
 
 
360
 
    if not d == 1:
361
 
        # Draw what is happening
362
 
        P.plot(data[:, 0], data[:, 1], '.', label = '_nolegend_')
363
 
 
364
 
        h   = gm.plot()    
365
 
        [i.set_color('g') for i in h]
366
 
        h[0].set_label('true confidence ellipsoides')
367
 
 
368
 
        h   = gm0.plot()    
369
 
        [i.set_color('k') for i in h]
370
 
        h[0].set_label('initial confidence ellipsoides')
371
 
 
372
 
        h   = lgm.plot()    
373
 
        [i.set_color('r') for i in h]
374
 
        h[0].set_label('confidence ellipsoides found by EM')
375
 
 
376
 
        h   = ogmm.gm.plot()    
377
 
        [i.set_color('m') for i in h]
378
 
        h[0].set_label('confidence ellipsoides found by Online EM')
379
 
 
380
 
        # P.legend(loc = 0)
381
 
    else:
382
 
        # Real confidence ellipses
383
 
        h   = gm.plot1d()
384
 
        [i.set_color('g') for i in h['pdf']]
385
 
        h['pdf'][0].set_label('true pdf')
386
 
 
387
 
        # Initial confidence ellipses as found by kmean
388
 
        h0  = gm0.plot1d()
389
 
        [i.set_color('k') for i in h0['pdf']]
390
 
        h0['pdf'][0].set_label('initial pdf')
391
 
 
392
 
        # Values found by EM
393
 
        hl  = lgm.plot1d(fill = 1, level = 0.66)
394
 
        [i.set_color('r') for i in hl['pdf']]
395
 
        hl['pdf'][0].set_label('pdf found by EM')
396
 
 
397
 
        P.legend(loc = 0)
398
 
 
399
 
        # Values found by Online EM
400
 
        hl  = ogmm.gm.plot1d(fill = 1, level = 0.66)
401
 
        [i.set_color('m') for i in hl['pdf']]
402
 
        hl['pdf'][0].set_label('pdf found by Online EM')
403
 
 
404
 
        P.legend(loc = 0)
405
 
 
406
 
    P.subplot(2, 1, 2)
407
 
    P.plot(nu)
408
 
    P.title('Learning rate')
409
 
    P.show()