~ubuntu-branches/ubuntu/saucy/sphinxtrain/saucy

« back to all changes in this revision

Viewing changes to src/programs/map_adapt/main.c

  • Committer: Package Import Robot
  • Author(s): Samuel Thibault
  • Date: 2013-01-02 04:10:21 UTC
  • Revision ID: package-import@ubuntu.com-20130102041021-ynsizmz33fx02hea
Tags: upstream-1.0.8
ImportĀ upstreamĀ versionĀ 1.0.8

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/* -*- c-file-style: "bsd"; c-basic-offset: 4 -*- */
 
2
/*********************************************************************
 
3
 *
 
4
 * $Header$
 
5
 *
 
6
 * Carnegie Mellon ARPA Speech Group
 
7
 *
 
8
 * Copyright (c) 1996-2005 Carnegie Mellon University.
 
9
 * All rights reserved.
 
10
 *********************************************************************
 
11
 *
 
12
 * File: src/programs/map_adapt/main.c
 
13
 * 
 
14
 * Description: 
 
15
 *      Do one pass of MAP re-estimation (adaptation).
 
16
 *
 
17
 *      See "Speaker Adaptation Based on MAP Estimation of HMM
 
18
 *      Parameters", Chin-Hui Lee and Jean-Luc Gauvain, Proceedings
 
19
 *      of ICASSP 1993, p. II-558 for the details of prior density
 
20
 *      estimation and forward-backward MAP.
 
21
 * 
 
22
 * Author: 
 
23
 *      David Huggins-Daines <dhuggins@cs.cmu.edu>
 
24
 *
 
25
 *********************************************************************/
 
26
 
 
27
#include <s3/common.h>
 
28
#include <sys_compat/file.h>
 
29
#include <s3/model_inventory.h>
 
30
#include <s3/model_def_io.h>
 
31
#include <s3/s3gau_io.h>
 
32
#include <s3/s3mixw_io.h>
 
33
#include <s3/s3tmat_io.h>
 
34
#include <s3/s3acc_io.h>
 
35
#include <s3/s3.h>
 
36
 
 
37
#include <sphinxbase/matrix.h>
 
38
#include <sphinxbase/err.h>
 
39
 
 
40
#include <stdio.h>
 
41
#include <math.h>
 
42
#include <assert.h>
 
43
#include <string.h>
 
44
 
 
45
#include "parse_cmd_ln.h"
 
46
 
 
47
static void
 
48
check_consistency(const char *filename,
 
49
                  uint32 n_mgau, uint32 n_mgau_rd,
 
50
                  uint32 n_stream, uint32 n_stream_rd,
 
51
                  uint32 n_density, uint32 n_density_rd,
 
52
                  const uint32 *veclen, 
 
53
                  const uint32 *veclen_rd)
 
54
{
 
55
    uint32 s;
 
56
 
 
57
    if (n_mgau != n_mgau_rd)
 
58
        E_FATAL("Number of codebooks is mismatched in %s\n",filename);
 
59
    if (n_stream != n_stream_rd)
 
60
        E_FATAL("Number of streams is mismatched in %s\n",filename);
 
61
    if (n_density != n_density_rd)
 
62
        E_FATAL("Number of gaussians is mismatched in %s\n",filename);
 
63
    for (s = 0; s < n_stream; ++s)
 
64
        if (veclen[s] != veclen_rd[s])
 
65
            E_FATAL("Vector length of stream %u mismatched in %s\n",
 
66
                    s, filename);
 
67
}
 
68
 
 
69
static float32 ***
 
70
estimate_tau(vector_t ***si_mean, vector_t ***si_var, float32 ***si_mixw,
 
71
             uint32 n_cb, uint32 n_stream, uint32 n_density, uint32 n_mixw, const uint32 *veclen,
 
72
             vector_t ***wt_mean, float32 ***wt_mixw, float32 ***wt_dcount)
 
73
{
 
74
    float32 ***map_tau;
 
75
    uint32 i, j, k, m;
 
76
 
 
77
    E_INFO("Estimating tau hyperparameter from variances and observations\n");
 
78
    map_tau = (float32 ***)ckd_calloc_3d(n_mixw, n_stream, n_density, sizeof(float32));
 
79
    for (i = 0; i < n_mixw; ++i) {
 
80
        for (j = 0; j < n_stream; ++j) {
 
81
            for (k = 0; k < n_density; ++k) {
 
82
                float32 tau_nom, tau_dnom;
 
83
 
 
84
                tau_nom = veclen[j] * wt_mixw[i][j][k];
 
85
                tau_dnom = 0.0f;
 
86
                for (m = 0; m < veclen[j]; ++m) {
 
87
                    float32 ydiff, wvar, dnom, ml_mu, si_mu, si_sigma;
 
88
 
 
89
                    if (n_mixw != n_cb && n_cb == 1) {/* Semi-continuous. */
 
90
                        dnom = wt_dcount[0][j][k];
 
91
                        si_mu = si_mean[0][j][k][m];
 
92
                        si_sigma = si_var[0][j][k][m];
 
93
                        ml_mu = dnom ? wt_mean[0][j][k][m] / dnom : si_mu;
 
94
                    }
 
95
                    else { /* Continuous. */
 
96
                        dnom = wt_dcount[i][j][k];
 
97
                        si_mu = si_mean[i][j][k][m];
 
98
                        si_sigma = si_var[i][j][k][m];
 
99
                        ml_mu = dnom ? wt_mean[i][j][k][m] / dnom : si_mu;
 
100
                    }
 
101
 
 
102
                    ydiff = ml_mu - si_mu;
 
103
                    /* Gauvain/Lee's estimation of this makes no
 
104
                     * sense as I read it, it seems to simply
 
105
                     * equal the precision matrix.  We want to use
 
106
                     * the variance anyway, because higher
 
107
                     * variance in the SI models should lead to
 
108
                     * stronger adaptation. */
 
109
                    /* And this is still less than ideal, because
 
110
                     * it way overestimates tau for SCHMM due to
 
111
                     * the large number of mixtures.  So for
 
112
                     * semi-continuous models you probably want to
 
113
                     * use -fixedtau. */
 
114
                    wvar = si_mixw[i][j][k] * si_sigma;
 
115
                    tau_dnom += dnom * ydiff * wvar * ydiff;
 
116
                }
 
117
                if (tau_dnom > 1e-5 && tau_nom > 1e-5)
 
118
                    map_tau[i][j][k] = tau_nom / tau_dnom;
 
119
                else
 
120
                    map_tau[i][j][k] = 1000.0f; /* FIXME: Something big, I guess. */
 
121
#if 0
 
122
                E_INFO("map_tau[%d][%d][%d] = %f / %f = %f\n",
 
123
                       i, j, k, tau_nom, tau_dnom, map_tau[i][j][k]);
 
124
#endif
 
125
            }
 
126
        }
 
127
    }
 
128
 
 
129
    return map_tau;
 
130
}
 
131
 
 
132
static int
 
133
map_mixw_reest(float32 ***map_tau, float32 fixed_tau,
 
134
               float32 ***si_mixw, float32 ***wt_mixw, float32 ***map_mixw,
 
135
               float32 mwfloor, uint32 n_mixw, uint32 n_stream, uint32 n_density)
 
136
{
 
137
    uint32 i, j, k;
 
138
 
 
139
    E_INFO("Re-estimating mixture weights using MAP\n");
 
140
    for (i = 0; i < n_mixw; ++i) {
 
141
        for (j = 0; j < n_stream; ++j) {
 
142
            float32 sum_tau, sum_nu, sum_wt_mixw;
 
143
 
 
144
            sum_tau = sum_nu = sum_wt_mixw = 0.0f;
 
145
            for (k = 0; k < n_density; ++k)
 
146
                sum_tau += (map_tau != NULL) ? map_tau[i][j][k] : fixed_tau;
 
147
            for (k = 0; k < n_density; ++k) {
 
148
                float32 nu;
 
149
 
 
150
                /* NOTE: We estimate nu such that the SI mixture
 
151
                 * weight is the mode of the posterior distribution,
 
152
                 * hence the + 1.  This allows the MAP estimate to
 
153
                 * converge to the SI one in the case of no adaptation
 
154
                 * data (clearly, this is desirable!) */
 
155
                nu = si_mixw[i][j][k] * sum_tau + 1;
 
156
                sum_nu += nu;
 
157
                sum_wt_mixw += wt_mixw[i][j][k];
 
158
            }
 
159
 
 
160
            for (k = 0; k < n_density; ++k) {
 
161
                float32 tau, nu;
 
162
 
 
163
                tau = (map_tau != NULL) ? map_tau[i][j][k] : fixed_tau;
 
164
                nu = si_mixw[i][j][k] * sum_tau + 1;
 
165
 
 
166
                map_mixw[i][j][k] = (nu - 1 + wt_mixw[i][j][k])
 
167
                    / (sum_nu - n_density + sum_wt_mixw);
 
168
                /* Floor mixture weights - otherwise they will be
 
169
                   negative in cases where si_mixw is very small.
 
170
                   FIXME: This might be an error in my implementation?  */
 
171
                if (map_mixw[i][j][k] < mwfloor)
 
172
                    map_mixw[i][j][k] = mwfloor;
 
173
#if 0
 
174
                printf("%d %d %d tau %f map_mixw %f =\n"
 
175
                       "      nu %f - 1     +     wt_mixw %f\n"
 
176
                       "/ sum_nu %f  - %d   + sum_wt_mixw %f\n",
 
177
                       i, j, k, tau, 
 
178
                       map_mixw[i][j][k], nu, wt_mixw[i][j][k],
 
179
                       sum_nu, n_density, sum_wt_mixw);
 
180
#endif
 
181
            }
 
182
        }
 
183
    }
 
184
    return S3_SUCCESS;
 
185
}
 
186
 
 
187
static int
 
188
map_tmat_reest(float32 ***si_tmat, float32 ***wt_tmat,
 
189
               float32 ***map_tmat, float32 tpfloor,
 
190
               uint32 n_tmat, uint32 n_state)
 
191
{
 
192
    uint32 t, i, j;
 
193
 
 
194
    E_INFO("Re-estimating transition probabilities using MAP\n");
 
195
    for (t = 0; t < n_tmat; ++t) {
 
196
        for (i = 0; i < n_state-1; ++i) {
 
197
            float32 sum_si_tmat = 0.0f, sum_wt_tmat = 0.0f;
 
198
 
 
199
            for (j = 0; j < n_state; ++j) {
 
200
                sum_si_tmat += si_tmat[t][i][j];
 
201
                sum_wt_tmat += si_tmat[t][i][j];
 
202
            }
 
203
            for (j = 0; j < n_state; ++j) {
 
204
                if (si_tmat[t][i][j] + wt_tmat[t][i][j] < 0) continue;
 
205
 
 
206
                map_tmat[t][i][j] =
 
207
                    (si_tmat[t][i][j] + wt_tmat[t][i][j])
 
208
                    / (sum_si_tmat + sum_wt_tmat);
 
209
                if (map_tmat[t][i][j] < 0.0f) {
 
210
                    E_WARN("map_tmat[%d][%d][%d] < 0 (%f)\n",
 
211
                           t, i, j, map_tmat[t][i][j]);
 
212
                    map_tmat[t][i][j] = 0.0f;
 
213
                }
 
214
            }
 
215
        }
 
216
    }
 
217
 
 
218
    return S3_SUCCESS;
 
219
}
 
220
 
 
221
static int32
 
222
bayes_mean_reest(vector_t ***si_mean, vector_t ***si_var,
 
223
                 vector_t ***wt_mean, vector_t ***wt_var,
 
224
                 float32 ***wt_dcount, int32 pass2var,
 
225
                 vector_t ***map_mean, float32 varfloor,
 
226
                 uint32 i, uint32 j, uint32 k, const uint32 *veclen)
 
227
{
 
228
    uint32 m;
 
229
 
 
230
    /* Textbook MAP estimator for single Gaussian.
 
231
       This works better if tau is unknown. */
 
232
    for (m = 0; m < veclen[j]; ++m) {
 
233
        if (wt_dcount[i][j][k]) {
 
234
            float32 mlmean, mlvar;
 
235
 
 
236
            mlmean = wt_mean[i][j][k][m] / wt_dcount[i][j][k];
 
237
            if (pass2var)
 
238
                mlvar = wt_var[i][j][k][m] / wt_dcount[i][j][k];
 
239
            else
 
240
                mlvar = (wt_var[i][j][k][m] / wt_dcount[i][j][k]
 
241
                         - mlmean * mlmean);
 
242
            /* Perfectly normal if -2passvar isn't specified. */
 
243
            if (mlvar < 0.0f) {
 
244
                if (pass2var)
 
245
                    E_WARN("mlvar[%d][%d][%d][%d] < 0 (%f)\n", i,j,k,m,mlvar);
 
246
                mlvar = varfloor;
 
247
            }
 
248
            map_mean[i][j][k][m] =
 
249
                (wt_dcount[i][j][k] * si_var[i][j][k][m] * mlmean
 
250
                 + mlvar * si_mean[i][j][k][m])
 
251
                / (wt_dcount[i][j][k] * si_var[i][j][k][m] + mlvar);
 
252
        }
 
253
        else
 
254
            map_mean[i][j][k][m] = si_mean[i][j][k][m];
 
255
    }
 
256
    return S3_SUCCESS;
 
257
}
 
258
 
 
259
static int
 
260
map_mean_reest(float32 tau, vector_t ***si_mean, vector_t ***wt_mean,
 
261
               float32 ***wt_dcount, vector_t ***map_mean,
 
262
               uint32 i, uint32 j, uint32 k, const uint32 *veclen)
 
263
{
 
264
    uint32 m;
 
265
 
 
266
    /* CH Lee mean update equation.  Use this if
 
267
       you want to experiment with values of tau. */
 
268
    for (m = 0; m < veclen[j]; ++m) {
 
269
        if (wt_dcount[i][j][k])
 
270
            map_mean[i][j][k][m] =
 
271
                (tau * si_mean[i][j][k][m] + wt_mean[i][j][k][m])
 
272
                / (tau + wt_dcount[i][j][k]);
 
273
        else
 
274
            map_mean[i][j][k][m] = si_mean[i][j][k][m];
 
275
    }
 
276
    return S3_SUCCESS;
 
277
}
 
278
 
 
279
static int
 
280
map_var_reest(float32 tau, vector_t ***si_mean, vector_t ***si_var,
 
281
              vector_t ***wt_mean, vector_t ***wt_var, float32 ***wt_dcount,
 
282
              vector_t ***map_mean, vector_t ***map_var, float32 varfloor,
 
283
              uint32 i, uint32 j, uint32 k, const uint32 *veclen)
 
284
{
 
285
    uint32 m;
 
286
 
 
287
    for (m = 0; m < veclen[j]; ++m) {
 
288
        float32 alpha, beta, mdiff;
 
289
 
 
290
        /* Somewhat different estimates of alpha and beta from the
 
291
         * ones given in Gauvain & Lee.  These actually converge to
 
292
         * the SI variance with no observations, and also seem to
 
293
         * perform better in at least one case.  */
 
294
        alpha = tau + 1;
 
295
        beta = tau * si_var[i][j][k][m];
 
296
 
 
297
        mdiff = si_mean[i][j][k][m] - map_mean[i][j][k][m];
 
298
        /* This should be the correct update equation for diagonal
 
299
         * covariance matrices. */
 
300
        map_var[i][j][k][m] = (beta
 
301
                               + wt_var[i][j][k][m]
 
302
                               + tau * mdiff * mdiff)
 
303
            / (alpha - 1 + wt_dcount[i][j][k]);
 
304
        if (map_var[i][j][k][m] < 0.0f) {
 
305
            /* This is bad and shouldn't happen! */
 
306
            E_WARN("mapvar[%d][%d][%d][%d] < 0 (%f)\n", i,j,k,m, map_var[i][j][k][m]);
 
307
            map_var[i][j][k][m] = varfloor;
 
308
        }
 
309
        if (map_var[i][j][k][m] < varfloor)
 
310
            map_var[i][j][k][m] = varfloor;
 
311
    }
 
312
    return S3_SUCCESS;
 
313
}
 
314
 
 
315
static int
 
316
map_update(void)
 
317
{
 
318
    float32 ***si_mixw = NULL;
 
319
    float32 ***si_tmat = NULL;
 
320
    vector_t ***si_mean = NULL;
 
321
    vector_t ***si_var = NULL;
 
322
 
 
323
    vector_t ***wt_mean = NULL;
 
324
    vector_t ***wt_var = NULL;
 
325
    float32 ***wt_mixw = NULL;
 
326
    float32 ***wt_tmat = NULL;
 
327
    float32 ***wt_dcount = NULL;
 
328
    int32 pass2var;
 
329
 
 
330
    float32 ***map_mixw = NULL;
 
331
    float32 ***map_tmat = NULL;
 
332
    vector_t ***map_mean = NULL;
 
333
    vector_t ***map_var = NULL;
 
334
    float32 ***map_tau = NULL;
 
335
    float32 fixed_tau = 10.0f;
 
336
    float32 mwfloor = 1e-5f;
 
337
    float32 varfloor = 1e-5f;
 
338
    float32 tpfloor = 1e-4f;
 
339
 
 
340
    uint32 n_mixw, n_mixw_rd;
 
341
    uint32 n_tmat, n_tmat_rd, n_state, n_state_rd;
 
342
    uint32 n_cb, n_cb_rd;
 
343
    uint32 n_stream, n_stream_rd;
 
344
    uint32 n_density, n_density_rd;
 
345
    uint32 *veclen = NULL;
 
346
    uint32 *veclen_rd = NULL;
 
347
 
 
348
    const char **accum_dir;
 
349
    const char *si_mixw_fn;
 
350
    const char *map_mixw_fn;
 
351
    const char *si_tmat_fn;
 
352
    const char *map_tmat_fn;
 
353
    const char *si_mean_fn;
 
354
    const char *map_mean_fn;
 
355
    const char *si_var_fn;
 
356
    const char *map_var_fn;
 
357
 
 
358
    uint32 i, j, k;
 
359
 
 
360
    accum_dir = cmd_ln_str_list("-accumdir");
 
361
    si_mean_fn = cmd_ln_str("-meanfn");
 
362
    si_var_fn = cmd_ln_str("-varfn");
 
363
    si_tmat_fn = cmd_ln_str("-tmatfn");
 
364
    si_mixw_fn = cmd_ln_str("-mixwfn");
 
365
    map_mean_fn = cmd_ln_str("-mapmeanfn");
 
366
    map_var_fn = cmd_ln_str("-mapvarfn");
 
367
    map_tmat_fn = cmd_ln_str("-maptmatfn");
 
368
    map_mixw_fn = cmd_ln_str("-mapmixwfn");
 
369
 
 
370
    /* Must be at least one accum dir. */
 
371
    if (accum_dir == NULL)
 
372
        E_FATAL("Must specify at least one -accumdir\n");
 
373
 
 
374
    /* Must have means and variances. */
 
375
    if (si_mean_fn == NULL || si_var_fn == NULL || si_mixw_fn == NULL)
 
376
        E_FATAL("Must specify baseline means, variances, and mixture weights\n");
 
377
 
 
378
    /* Must specify output means. */
 
379
    if (map_mean_fn == NULL)
 
380
        E_FATAL("Must at least specify output MAP means\n");
 
381
 
 
382
    /* Read SI model parameters. */
 
383
    if (s3gau_read(si_mean_fn, &si_mean,
 
384
                   &n_cb, &n_stream, &n_density, &veclen) != S3_SUCCESS)
 
385
        E_FATAL("Couldn't read %s\n", si_mean_fn);
 
386
    if (s3gau_read(si_var_fn, &si_var,
 
387
                   &n_cb_rd, &n_stream_rd, &n_density_rd, &veclen_rd) != S3_SUCCESS)
 
388
        E_FATAL("Couldn't read %s\n", si_var_fn);
 
389
    check_consistency(si_var_fn, n_cb, n_cb_rd, n_stream, n_stream_rd,
 
390
                      n_density, n_density_rd, veclen, veclen_rd);
 
391
    /* Don't free veclen_rd, as rdacc_den needs it. */
 
392
 
 
393
    /* Read and normalize SI mixture weights. */
 
394
    if (si_mixw_fn) {
 
395
        mwfloor = cmd_ln_float32("-mwfloor");
 
396
        if (s3mixw_read(si_mixw_fn, &si_mixw, &n_mixw, &n_stream_rd, &n_density_rd)
 
397
            != S3_SUCCESS)
 
398
            E_FATAL("Couldn't read %s\n", si_mixw_fn);
 
399
        for (i = 0; i < n_mixw; ++i) {
 
400
            for (j = 0; j < n_stream; ++j) {
 
401
                float32 sum_si_mixw = 0.0f;
 
402
                for (k = 0; k < n_density; ++k) {
 
403
                    if (si_mixw[i][j][k] < mwfloor)
 
404
                        si_mixw[i][j][k] = mwfloor;
 
405
                    sum_si_mixw += si_mixw[i][j][k];
 
406
                }
 
407
                for (k = 0; k < n_density; ++k)
 
408
                    si_mixw[i][j][k] /= sum_si_mixw;
 
409
            }
 
410
        }
 
411
    }
 
412
 
 
413
    /* Read SI transition matrices. */
 
414
    /* FIXME: We may want to normalize these if we do more interesting
 
415
     * estimation of the eta hyperparameters (i.e. using tau) */
 
416
    if (si_tmat_fn) {
 
417
        tpfloor = cmd_ln_float32("-tpfloor");
 
418
        if (s3tmat_read(si_tmat_fn, &si_tmat, &n_tmat, &n_state)
 
419
            != S3_SUCCESS)
 
420
            E_FATAL("Couldn't read %s\n", si_tmat_fn);
 
421
    }
 
422
 
 
423
    /* Read observation counts. */
 
424
    for (i = 0; accum_dir[i]; ++i) {
 
425
        E_INFO("Reading and accumulating observation counts from %s\n",
 
426
               accum_dir[i]);
 
427
        if (rdacc_den(accum_dir[i],
 
428
                      &wt_mean,
 
429
                      &wt_var,          
 
430
                      &pass2var,        
 
431
                      &wt_dcount,
 
432
                      &n_cb_rd,
 
433
                      &n_stream_rd,
 
434
                      &n_density_rd,
 
435
                      &veclen_rd) != S3_SUCCESS)
 
436
            E_FATAL("Error in reading densities from %s\n", accum_dir[i]);
 
437
        check_consistency(accum_dir[i],
 
438
                          n_cb, n_cb_rd, n_stream, n_stream_rd,
 
439
                          n_density, n_density_rd, veclen, veclen_rd);
 
440
        if (pass2var && map_var_fn)
 
441
            E_FATAL("Variance re-estimation requested, but -2passvar was specified in bw.");
 
442
        if (map_mixw_fn || !cmd_ln_int32("-fixedtau")) {
 
443
            if (rdacc_mixw(accum_dir[i],
 
444
                           &wt_mixw,
 
445
                           &n_mixw_rd, &n_stream_rd, &n_density_rd) != S3_SUCCESS)
 
446
                E_FATAL("Error in reading mixture weights from %s\n", accum_dir[i]);
 
447
            check_consistency(accum_dir[i],
 
448
                              n_mixw, n_mixw_rd, n_stream, n_stream_rd,
 
449
                              n_density, n_density_rd, veclen, veclen_rd);
 
450
        }
 
451
        if (map_tmat_fn) {
 
452
            if (rdacc_tmat(accum_dir[i],
 
453
                           &wt_tmat,
 
454
                           &n_tmat_rd, &n_state_rd) != S3_SUCCESS)
 
455
                E_FATAL("Error in reading transition matrices from %s\n", accum_dir[i]);
 
456
            if (n_tmat_rd != n_tmat || n_state_rd != n_state)
 
457
                E_FATAL("Mimsatch in tranition matrices from %s\n", accum_dir[i]);
 
458
        }
 
459
    }
 
460
    ckd_free(veclen_rd);
 
461
 
 
462
    /* Allocate MAP parameters */
 
463
    map_mean  = gauden_alloc_param(n_cb, n_stream, n_density, veclen);
 
464
    if (map_var_fn)
 
465
        map_var = gauden_alloc_param(n_cb, n_stream, n_density, veclen);
 
466
    if (map_mixw_fn)
 
467
        map_mixw = (float32 ***)ckd_calloc_3d(n_mixw, n_stream, n_density, sizeof(float32));
 
468
    if (map_tmat_fn)
 
469
        map_tmat = (float32 ***)ckd_calloc_3d(n_tmat, n_state-1, n_state, sizeof(float32));
 
470
 
 
471
    /* Optionally estimate prior tau hyperparameter for each HMM
 
472
     * (all other prior parameters can be derived from it). */
 
473
    if (cmd_ln_int32("-fixedtau")) {
 
474
        fixed_tau = cmd_ln_float32("-tau");
 
475
        E_INFO("tau hyperparameter fixed at %f\n", fixed_tau);
 
476
    }
 
477
    else
 
478
        map_tau = estimate_tau(si_mean, si_var, si_mixw,
 
479
                               n_cb, n_stream, n_density, n_mixw, veclen,
 
480
                               wt_mean, wt_mixw, wt_dcount);
 
481
 
 
482
    /* Re-estimate mixture weights. */
 
483
    if (map_mixw) {
 
484
        map_mixw_reest(map_tau, fixed_tau,
 
485
                       si_mixw, wt_mixw, map_mixw, mwfloor,
 
486
                       n_mixw, n_stream, n_density);
 
487
    }
 
488
 
 
489
    /* Re-estimate transition matrices. */
 
490
    if (map_tmat)
 
491
        map_tmat_reest(si_tmat, wt_tmat, map_tmat, tpfloor,
 
492
                       n_tmat, n_state);
 
493
 
 
494
    /* Re-estimate means and variances */
 
495
    if (cmd_ln_int32("-bayesmean"))
 
496
        E_INFO("Re-estimating means using Bayesian interpolation\n");
 
497
    else
 
498
        E_INFO("Re-estimating means using MAP\n");
 
499
    if (n_mixw != n_cb && n_cb == 1)
 
500
        E_INFO("Interpolating tau hyperparameter for semi-continuous models\n");
 
501
    if (map_var)
 
502
        E_INFO("Re-estimating variances using MAP\n");
 
503
 
 
504
    for (i = 0; i < n_cb; ++i) {
 
505
        for (j = 0; j < n_stream; ++j) {
 
506
            for (k = 0; k < n_density; ++k) {
 
507
                float32 tau;
 
508
 
 
509
                if (map_tau == NULL)
 
510
                    tau = fixed_tau;
 
511
                else {
 
512
                    /* Interpolate tau for semi-continuous models. */
 
513
                    if (n_mixw != n_cb && n_cb == 1) {
 
514
                        int m;
 
515
 
 
516
                        tau = 0.0f;
 
517
                        for (m = 0; m < n_mixw; ++m)
 
518
                            tau += map_tau[m][j][k];
 
519
                        tau /= n_mixw;
 
520
#if 0
 
521
                        printf("SC tau[%d][%d] = %f\n", j, k, tau);
 
522
#endif
 
523
                    }
 
524
                    else /* Continuous. */
 
525
                        tau = map_tau[i][j][k];
 
526
                }
 
527
 
 
528
                /* Means re-estimation. */
 
529
                if (cmd_ln_int32("-bayesmean"))
 
530
                    bayes_mean_reest(si_mean, si_var,
 
531
                                     wt_mean, wt_var,
 
532
                                     wt_dcount, pass2var,
 
533
                                     map_mean, varfloor,
 
534
                                     i, j, k, veclen);
 
535
                else
 
536
                    map_mean_reest(tau, si_mean, wt_mean, wt_dcount,
 
537
                                   map_mean, i, j, k, veclen);
 
538
 
 
539
 
 
540
                /* Variance re-estimation.  Doesn't work with
 
541
                 * -2passvar, and in many cases this can actually
 
542
                 * degrade accuracy, so use it with caution. */
 
543
                if (map_var)
 
544
                    map_var_reest(tau, si_mean, si_var, wt_mean, wt_var,
 
545
                                  wt_dcount, map_mean, map_var, varfloor,
 
546
                                  i, j, k, veclen);
 
547
            }
 
548
        }
 
549
    }
 
550
 
 
551
    if (map_mean_fn)
 
552
        if (s3gau_write(map_mean_fn,
 
553
                        (const vector_t ***)map_mean,
 
554
                        n_cb,
 
555
                        n_stream,
 
556
                        n_density,
 
557
                        veclen) != S3_SUCCESS)
 
558
            E_FATAL("Unable to write MAP mean to %s\n",map_mean_fn);
 
559
 
 
560
    if (map_var && map_var_fn)
 
561
        if (s3gau_write(map_var_fn,
 
562
                        (const vector_t ***)map_var,
 
563
                        n_cb,
 
564
                        n_stream,
 
565
                        n_density,
 
566
                        veclen) != S3_SUCCESS)
 
567
            E_FATAL("Unable to write MAP variance to %s\n",map_var_fn);
 
568
 
 
569
    if (map_mixw && map_mixw_fn)
 
570
        if (s3mixw_write(map_mixw_fn,
 
571
                         map_mixw,
 
572
                         n_mixw,
 
573
                         n_stream,
 
574
                         n_density)!= S3_SUCCESS)
 
575
            E_FATAL("Unable to write MAP mixture weights to %s\n",map_mixw_fn);
 
576
 
 
577
    if (map_tmat && map_tmat_fn)
 
578
        if (s3tmat_write(map_tmat_fn,
 
579
                         map_tmat,
 
580
                         n_tmat,
 
581
                         n_state)!= S3_SUCCESS)
 
582
            E_FATAL("Unable to write MAP transition matrices to %s\n",map_tmat_fn);
 
583
 
 
584
    ckd_free(veclen);
 
585
    gauden_free_param(si_mean);
 
586
    gauden_free_param(si_var);
 
587
    if (si_mixw)
 
588
        ckd_free_3d(si_mixw);
 
589
    if (si_tmat)
 
590
        ckd_free_3d(si_tmat);
 
591
    gauden_free_param(wt_mean);
 
592
    gauden_free_param(wt_var);
 
593
    ckd_free_3d(wt_dcount);
 
594
    if (map_mean)
 
595
        gauden_free_param(map_mean);
 
596
    if (map_var)
 
597
        gauden_free_param(map_var);
 
598
    if (map_tau)
 
599
        ckd_free_3d(map_tau);
 
600
    if (map_mixw)
 
601
        ckd_free_3d(map_mixw);
 
602
    if (map_tmat)
 
603
        ckd_free_3d(map_tmat);
 
604
    
 
605
    return S3_SUCCESS;
 
606
}
 
607
 
 
608
int
 
609
main(int argc, char *argv[])
 
610
{
 
611
    /* define, parse and (partially) validate the command line */
 
612
    parse_cmd_ln(argc, argv);
 
613
 
 
614
    if (map_update() != S3_SUCCESS) {
 
615
        exit(1);
 
616
    }
 
617
 
 
618
    exit(0);
 
619
}