~ubuntu-branches/ubuntu/raring/sphinxtrain/raring-proposed

« back to all changes in this revision

Viewing changes to src/programs/bw/main.c

  • Committer: Package Import Robot
  • Author(s): Samuel Thibault
  • Date: 2013-01-02 04:10:21 UTC
  • Revision ID: package-import@ubuntu.com-20130102041021-ynsizmz33fx02hea
Tags: upstream-1.0.8
ImportĀ upstreamĀ versionĀ 1.0.8

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/* ====================================================================
 
2
 * Copyright (c) 1995-2000 Carnegie Mellon University.  All rights 
 
3
 * reserved.
 
4
 *
 
5
 * Redistribution and use in source and binary forms, with or without
 
6
 * modification, are permitted provided that the following conditions
 
7
 * are met:
 
8
 *
 
9
 * 1. Redistributions of source code must retain the above copyright
 
10
 *    notice, this list of conditions and the following disclaimer. 
 
11
 *
 
12
 * 2. Redistributions in binary form must reproduce the above copyright
 
13
 *    notice, this list of conditions and the following disclaimer in
 
14
 *    the documentation and/or other materials provided with the
 
15
 *    distribution.
 
16
 *
 
17
 * This work was supported in part by funding from the Defense Advanced 
 
18
 * Research Projects Agency and the National Science Foundation of the 
 
19
 * United States of America, and the CMU Sphinx Speech Consortium.
 
20
 *
 
21
 * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND 
 
22
 * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
 
23
 * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 
24
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
 
25
 * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 
26
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
 
27
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 
 
28
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 
 
29
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
 
30
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
 
31
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
32
 *
 
33
 * ====================================================================
 
34
 *
 
35
 */
 
36
/*********************************************************************
 
37
 *
 
38
 * File: main.c
 
39
 * 
 
40
 * Description: 
 
41
 *      This is the top level routine for SPHINX-III Baum-Welch
 
42
 *      reestimation.
 
43
 * 
 
44
 * Author: 
 
45
 *      Eric Thayer (eht+@cmu.edu) 20-Jun-95
 
46
 * 
 
47
 *********************************************************************/
 
48
 
 
49
#include "train_cmd_ln.h"
 
50
#include "forward.h"
 
51
#include "viterbi.h"
 
52
#include "next_utt_states.h"
 
53
#include "baum_welch.h"
 
54
#include "accum.h"
 
55
 
 
56
#include <s3/common.h>
 
57
#include <s3/mk_phone_list.h>
 
58
#include <s3/cvt2triphone.h>
 
59
#include <s3/mk_sseq.h>
 
60
#include <s3/mk_trans_seq.h>
 
61
#include <s3/model_inventory.h>
 
62
#include <s3/model_def_io.h>
 
63
#include <s3/s3ts2cb_io.h>
 
64
#include <s3/mllr.h>
 
65
#include <s3/mllr_io.h>
 
66
#include <s3/ts2cb.h>
 
67
#include <s3/s3cb2mllr_io.h>
 
68
#include <sys_compat/misc.h>
 
69
#include <sys_compat/time.h>
 
70
#include <sys_compat/file.h>
 
71
 
 
72
#include <sphinxbase/ckd_alloc.h>
 
73
#include <sphinxbase/profile.h>
 
74
#include <sphinxbase/feat.h>
 
75
 
 
76
#include <stdio.h>
 
77
#include <stdlib.h>
 
78
#include <string.h>
 
79
#include <math.h>
 
80
#include <assert.h>
 
81
 
 
82
#define DUMP_RETRY_PERIOD       3       /* If a count dump fails, retry every # of sec's */
 
83
 
 
84
/* the following parameters are used for MMIE training */
 
85
#define LOG_ZERO        -1.0E10
 
86
static float32 lm_scale = 11.5;
 
87
 
 
88
/* FIXME: Should go in libutil */
 
89
static char *
 
90
string_join(const char *base, ...)
 
91
{
 
92
    va_list args;
 
93
    size_t len;
 
94
    const char *c;
 
95
    char *out;
 
96
 
 
97
    va_start(args, base);
 
98
    len = strlen(base);
 
99
    while ((c = va_arg(args, const char *)) != NULL) {
 
100
        len += strlen(c);
 
101
    }
 
102
    len++;
 
103
    va_end(args);
 
104
 
 
105
    out = ckd_calloc(len, 1);
 
106
    va_start(args, base);
 
107
    strcpy(out, base);
 
108
    while ((c = va_arg(args, const char *)) != NULL) {
 
109
        strcat(out, c);
 
110
    }
 
111
    va_end(args);
 
112
 
 
113
    return out;
 
114
}
 
115
 
 
116
static void
 
117
print_all_timers(bw_timers_t *timers, int32 n_frame)
 
118
{
 
119
    printf(" utt %4.3fx %4.3fe"
 
120
           " upd %4.3fx %4.3fe"
 
121
           " fwd %4.3fx %4.3fe"
 
122
           " bwd %4.3fx %4.3fe"
 
123
           " gau %4.3fx %4.3fe"
 
124
           " rsts %4.3fx %4.3fe"
 
125
           " rstf %4.3fx %4.3fe"
 
126
           " rstu %4.3fx %4.3fe",
 
127
 
 
128
        timers->utt_timer.t_cpu/(n_frame*0.01),
 
129
        (timers->utt_timer.t_cpu > 0 ? timers->utt_timer.t_elapsed / timers->utt_timer.t_cpu : 0.0),
 
130
 
 
131
        timers->upd_timer.t_cpu/(n_frame*0.01),
 
132
        (timers->upd_timer.t_cpu > 0 ? timers->upd_timer.t_elapsed / timers->upd_timer.t_cpu : 0.0),
 
133
 
 
134
        timers->fwd_timer.t_cpu/(n_frame*0.01),
 
135
        (timers->fwd_timer.t_cpu > 0 ? timers->fwd_timer.t_elapsed / timers->fwd_timer.t_cpu : 0.0),
 
136
 
 
137
        timers->bwd_timer.t_cpu/(n_frame*0.01),
 
138
        (timers->bwd_timer.t_cpu > 0 ? timers->bwd_timer.t_elapsed / timers->bwd_timer.t_cpu : 0.0),
 
139
 
 
140
        timers->gau_timer.t_cpu/(n_frame*0.01),
 
141
        (timers->gau_timer.t_cpu > 0 ? timers->gau_timer.t_elapsed / timers->gau_timer.t_cpu : 0.0),
 
142
 
 
143
        timers->rsts_timer.t_cpu/(n_frame*0.01),
 
144
        (timers->rsts_timer.t_cpu > 0 ? timers->rsts_timer.t_elapsed / timers->rsts_timer.t_cpu : 0.0),
 
145
 
 
146
        timers->rstf_timer.t_cpu/(n_frame*0.01),
 
147
        (timers->rstf_timer.t_cpu > 0 ? timers->rstf_timer.t_elapsed / timers->rstf_timer.t_cpu : 0.0),
 
148
 
 
149
        timers->rstu_timer.t_cpu/(n_frame*0.01),
 
150
        (timers->rstu_timer.t_cpu > 0 ? timers->rstu_timer.t_elapsed / timers->rstu_timer.t_cpu : 0.0));
 
151
    printf("\n");
 
152
}
 
153
 
 
154
 
 
155
/*********************************************************************
 
156
 *
 
157
 * Function: 
 
158
 *      main_initialize
 
159
 * 
 
160
 * Description: 
 
161
 *      Construct data structures and precompute values necessary
 
162
 *      for Baum-Welch reestimation.
 
163
 *
 
164
 * Function Inputs: 
 
165
 *      - argc
 
166
 *              The number of command line arguments
 
167
 *      - argv
 
168
 *              Array of command line argument strings
 
169
 *      - out_inv
 
170
 *              The model inventory data structure created
 
171
 *              by this routine.  (see libmodinv/modinv.c)
 
172
 *      - lex
 
173
 *              A word -> phone dictionary for the training set.
 
174
 *
 
175
 * Global Inputs: 
 
176
 *      None
 
177
 * 
 
178
 * Return Values: 
 
179
 *      - S3_SUCCESS
 
180
 *              This value is returned when no error condition
 
181
 *              has been detected.
 
182
 *      - S3_ERROR
 
183
 *              This value is returned when an error condition
 
184
 *              has been detected.
 
185
 * 
 
186
 * Global Outputs: 
 
187
 *      None
 
188
 *
 
189
 * Errors: 
 
190
 * 
 
191
 * Pre-Conditions: 
 
192
 * 
 
193
 * Post-Conditions: 
 
194
 *
 
195
 *********************************************************************/
 
196
 
 
197
static int
 
198
main_initialize(int argc,
 
199
                char *argv[],
 
200
                model_inventory_t **out_inv,
 
201
                lexicon_t **out_lex,
 
202
                model_def_t **out_mdef,
 
203
                feat_t **out_feat)
 
204
{
 
205
    model_inventory_t *inv;     /* the model inventory */
 
206
    lexicon_t *lex;             /* the lexicon to be returned to the caller */
 
207
    model_def_t *mdef;
 
208
    feat_t *feat;
 
209
    uint32 n_map;
 
210
    uint32 n_ts;
 
211
    uint32 n_cb;
 
212
    uint32 n_mllr;
 
213
    int mixw_reest;
 
214
    int tmat_reest;
 
215
    int mean_reest;
 
216
    int var_reest;
 
217
    int did_restore = FALSE;
 
218
    const char *fn;
 
219
    int32 *mllr_idx = NULL;
 
220
    const char *hmmdir;
 
221
    const char *mdeffn, *meanfn, *varfn, *mixwfn, *tmatfn, *fdictfn;
 
222
 
 
223
    /* Note these are forward transforms for use
 
224
       in training.  The inverse transform of the accumulators is now
 
225
       done externally by mllr_transform. */
 
226
    float32 ****sxfrm_a = NULL;
 
227
    float32 ***sxfrm_b = NULL;
 
228
 
 
229
    E_INFO("Compiled on %s at %s\n", __DATE__, __TIME__);
 
230
 
 
231
    /* define, parse and (partially) validate the command line */
 
232
    train_cmd_ln_parse(argc, argv);
 
233
 
 
234
    feat =
 
235
        feat_init(cmd_ln_str("-feat"),
 
236
                  cmn_type_from_str(cmd_ln_str("-cmn")),
 
237
                  cmd_ln_boolean("-varnorm"),
 
238
                  agc_type_from_str(cmd_ln_str("-agc")),
 
239
                  1, cmd_ln_int32("-ceplen"));
 
240
    *out_feat = feat;
 
241
 
 
242
 
 
243
    if (cmd_ln_str("-lda")) {
 
244
        E_INFO("Reading linear feature transformation from %s\n",
 
245
               cmd_ln_str("-lda"));
 
246
        if (feat_read_lda(feat,
 
247
                          cmd_ln_str("-lda"),
 
248
                          cmd_ln_int32("-ldadim")) < 0)
 
249
            return -1;
 
250
    }
 
251
 
 
252
    if (cmd_ln_str("-svspec")) {
 
253
        int32 **subvecs;
 
254
        E_INFO("Using subvector specification %s\n", 
 
255
               cmd_ln_str("-svspec"));
 
256
        if ((subvecs = parse_subvecs(cmd_ln_str("-svspec"))) == NULL)
 
257
            return -1;
 
258
        if ((feat_set_subvecs(feat, subvecs)) < 0)
 
259
            return -1;
 
260
    }
 
261
 
 
262
    if (cmd_ln_exists("-agcthresh")
 
263
        && 0 != strcmp(cmd_ln_str("-agc"), "none")) {
 
264
        agc_set_threshold(feat->agc_struct,
 
265
                          cmd_ln_float32("-agcthresh"));
 
266
    }
 
267
 
 
268
    if (feat->cmn_struct
 
269
        && cmd_ln_exists("-cmninit")) {
 
270
        char *c, *cc, *vallist;
 
271
        int32 nvals;
 
272
 
 
273
        vallist = ckd_salloc(cmd_ln_str("-cmninit"));
 
274
        c = vallist;
 
275
        nvals = 0;
 
276
        while (nvals < feat->cmn_struct->veclen
 
277
               && (cc = strchr(c, ',')) != NULL) {
 
278
            *cc = '\0';
 
279
            feat->cmn_struct->cmn_mean[nvals] = FLOAT2MFCC(atof(c));
 
280
            c = cc + 1;
 
281
            ++nvals;
 
282
        }
 
283
        if (nvals < feat->cmn_struct->veclen && *c != '\0') {
 
284
            feat->cmn_struct->cmn_mean[nvals] = FLOAT2MFCC(atof(c));
 
285
        }
 
286
        ckd_free(vallist);
 
287
    }
 
288
 
 
289
 
 
290
    /* create a new model inventory structure */
 
291
    *out_inv = inv = mod_inv_new();
 
292
 
 
293
    mod_inv_set_n_feat(inv, feat_dimension1(feat));
 
294
 
 
295
    mdeffn = cmd_ln_str("-moddeffn");
 
296
    meanfn = cmd_ln_str("-meanfn");
 
297
    varfn = cmd_ln_str("-varfn");
 
298
    mixwfn = cmd_ln_str("-mixwfn");
 
299
    tmatfn = cmd_ln_str("-tmatfn");
 
300
    fdictfn = cmd_ln_str("-fdictfn");
 
301
 
 
302
    /* Note: this will leak a small amount of memory but we really
 
303
     * don't care. */
 
304
    if ((hmmdir = cmd_ln_str("-hmmdir")) != NULL) {
 
305
        if (mdeffn == NULL)
 
306
            mdeffn = string_join(hmmdir, "/mdef", NULL);
 
307
        if (meanfn == NULL)
 
308
            meanfn = string_join(hmmdir, "/means", NULL);
 
309
        if (varfn == NULL)
 
310
            varfn = string_join(hmmdir, "/variances", NULL);
 
311
        if (mixwfn == NULL)
 
312
            mixwfn = string_join(hmmdir, "/mixture_weights", NULL);
 
313
        if (tmatfn == NULL)
 
314
            tmatfn = string_join(hmmdir, "/transition_matrices", NULL);
 
315
        if (fdictfn == NULL)
 
316
            fdictfn = string_join(hmmdir, "/noisedict", NULL);
 
317
    }
 
318
    E_INFO("Reading %s\n", mdeffn);
 
319
    
 
320
    /* Read in the model definitions.  Defines the set of
 
321
       CI phones and context dependent phones.  Defines the
 
322
       transition matrix tying and state level tying. */
 
323
    if (model_def_read(&mdef, mdeffn) != S3_SUCCESS) {
 
324
        return S3_ERROR;
 
325
    }
 
326
 
 
327
    *out_mdef = mdef;
 
328
 
 
329
    fn = cmd_ln_str("-ts2cbfn");
 
330
    if (fn == NULL) {
 
331
        E_FATAL("Specify -ts2cbfn\n");
 
332
    }
 
333
    if (strcmp(fn, SEMI_LABEL) == 0) {
 
334
        mdef->cb = semi_ts2cb(mdef->n_tied_state);
 
335
        n_ts = mdef->n_tied_state;
 
336
        n_cb = 1;
 
337
    }
 
338
    else if (strcmp(fn, CONT_LABEL) == 0) {
 
339
        mdef->cb = cont_ts2cb(mdef->n_tied_state);
 
340
        n_ts = mdef->n_tied_state;
 
341
        n_cb = mdef->n_tied_state;
 
342
    }
 
343
    else if (strcmp(fn, PTM_LABEL) == 0) {
 
344
      mdef->cb = ptm_ts2cb(mdef);
 
345
      n_ts = mdef->n_tied_state;
 
346
      n_cb = mdef->acmod_set->n_ci;
 
347
    }
 
348
    else if (s3ts2cb_read(fn,
 
349
                          &mdef->cb,
 
350
                          &n_ts,
 
351
                          &n_cb) != S3_SUCCESS) {
 
352
        return S3_ERROR;
 
353
    }
 
354
 
 
355
    inv->acmod_set = mdef->acmod_set;
 
356
    inv->mdef = mdef;
 
357
 
 
358
    if (mod_inv_read_mixw(inv, mdef, mixwfn,
 
359
                          cmd_ln_float32("-mwfloor")) != S3_SUCCESS)
 
360
        return S3_ERROR;
 
361
    
 
362
    if (n_ts != inv->n_mixw) {
 
363
        E_WARN("%u mappings from tied-state to cb, but %u tied-state in %s\n",
 
364
               mdef->n_cb, inv->n_mixw, mixwfn);
 
365
    }
 
366
 
 
367
    if (mod_inv_read_tmat(inv, tmatfn,
 
368
                          cmd_ln_float32("-tpfloor")) != S3_SUCCESS)
 
369
        return S3_ERROR;
 
370
 
 
371
    if (mod_inv_read_gauden(inv, meanfn, varfn,
 
372
                            cmd_ln_float32("-varfloor"),
 
373
                            cmd_ln_int32("-topn"),
 
374
                            cmd_ln_int32("-fullvar")) != S3_SUCCESS) {
 
375
            if (!cmd_ln_int32("-fullvar")) {
 
376
                    return S3_ERROR;
 
377
            }
 
378
            else {
 
379
                    /* If reading full variances failed, try reading
 
380
                     * them as diagonal variances (allows us to
 
381
                     * initialize full vars from diagonal ones) */
 
382
                    if (mod_inv_read_gauden(inv, meanfn, varfn,
 
383
                                            cmd_ln_float32("-varfloor"),
 
384
                                            cmd_ln_int32("-topn"),
 
385
                                            FALSE) != S3_SUCCESS) {
 
386
                            return S3_ERROR;
 
387
                    }
 
388
            }
 
389
            
 
390
    }
 
391
 
 
392
    /* If we want to use diagonals only, and we didn't read diagonals
 
393
     * above, then we have to extract them here. */
 
394
    if (cmd_ln_int32("-diagfull") && inv->gauden->var == NULL) {
 
395
            /* Extract diagonals and use them for Gaussian computation. */
 
396
            gauden_t *g;
 
397
            uint32 i, j, k, l;
 
398
 
 
399
            g = inv->gauden;
 
400
            g->var = gauden_alloc_param(g->n_mgau,
 
401
                                        g->n_feat,
 
402
                                        g->n_density,
 
403
                                        g->veclen);
 
404
            for (i = 0; i < g->n_mgau; ++i)
 
405
                    for (j = 0; j < g->n_feat; ++j)
 
406
                            for (k = 0; k < g->n_density; ++k)
 
407
                                    for (l = 0; l < g->veclen[j]; ++l)
 
408
                                            g->var[i][j][k][l] =
 
409
                                                    g->fullvar[i][j][k][l][l];
 
410
            gauden_free_param_full(g->fullvar);
 
411
            g->fullvar = NULL;
 
412
            gauden_floor_variance(g);
 
413
    }
 
414
    
 
415
    if (gauden_eval_precomp(inv->gauden) != S3_SUCCESS) {
 
416
        E_ERROR("Problems precomputing values used during Gaussian density evaluation\n");
 
417
 
 
418
        return S3_ERROR;
 
419
    }
 
420
 
 
421
    if (inv->gauden->n_mgau != n_cb) {
 
422
        printf("# of codebooks in mean/var files, %u, inconsistent with ts2cb mapping %u\n", inv->gauden->n_mgau, n_cb);
 
423
    }
 
424
 
 
425
    mixw_reest = cmd_ln_int32("-mixwreest");
 
426
    mean_reest = cmd_ln_int32("-meanreest");
 
427
    var_reest  = cmd_ln_int32("-varreest");
 
428
    tmat_reest = cmd_ln_int32("-tmatreest");
 
429
 
 
430
    E_INFO("Will %sreestimate mixing weights.\n",
 
431
           (mixw_reest ? "" : "NOT "));
 
432
    E_INFO("Will %sreestimate means.\n",
 
433
           (mean_reest ? "" : "NOT "));
 
434
    E_INFO("Will %sreestimate variances.\n",
 
435
           (var_reest ? "" : "NOT "));
 
436
 
 
437
    if (cmd_ln_int32("-mixwreest")) {
 
438
        if (mod_inv_alloc_mixw_acc(inv) != S3_SUCCESS)
 
439
            return S3_ERROR;
 
440
    }
 
441
 
 
442
    E_INFO("Will %sreestimate transition matrices\n",
 
443
           (cmd_ln_int32("-tmatreest") ? "" : "NOT "));
 
444
    if (cmd_ln_int32("-tmatreest")) {
 
445
        if (mod_inv_alloc_tmat_acc(inv) != S3_SUCCESS)
 
446
            return S3_ERROR;
 
447
    }
 
448
 
 
449
    if (cmd_ln_int32("-meanreest") ||
 
450
        cmd_ln_int32("-varreest")) {
 
451
        if (mod_inv_alloc_gauden_acc(inv) != S3_SUCCESS)
 
452
            return S3_ERROR;
 
453
    }
 
454
 
 
455
    E_INFO("Reading main lexicon: %s\n",
 
456
           cmd_ln_str("-dictfn"));
 
457
 
 
458
    lex = lexicon_read(NULL,
 
459
                       cmd_ln_str("-dictfn"),
 
460
                       mdef->acmod_set);
 
461
    if (lex == NULL)
 
462
        return S3_ERROR;
 
463
    
 
464
    if (fdictfn) {
 
465
        E_INFO("Reading filler lexicon: %s\n",
 
466
               fdictfn);
 
467
        (void)lexicon_read(lex,
 
468
                           fdictfn,
 
469
                           mdef->acmod_set);
 
470
    }
 
471
 
 
472
    *out_lex = lex;
 
473
 
 
474
 
 
475
    /*
 
476
     * Configure corpus module (controls sequencing/access of per utterance data)
 
477
     */
 
478
 
 
479
    /* set the data directory and extension for cepstrum files */
 
480
    corpus_set_mfcc_dir(cmd_ln_str("-cepdir"));
 
481
    corpus_set_mfcc_ext(cmd_ln_str("-cepext"));
 
482
 
 
483
    if (cmd_ln_str("-lsnfn")) {
 
484
        /* use a LSN file which has all the transcripts */
 
485
        corpus_set_lsn_filename(cmd_ln_str("-lsnfn"));
 
486
    }
 
487
    else {
 
488
        /* set the data directory and extension for word transcript
 
489
           files */
 
490
        corpus_set_sent_dir(cmd_ln_str("-sentdir"));
 
491
        corpus_set_sent_ext(cmd_ln_str("-sentext"));
 
492
    }
 
493
 
 
494
    if (cmd_ln_str("-ctlfn")) {
 
495
        corpus_set_ctl_filename(cmd_ln_str("-ctlfn"));
 
496
    }
 
497
 
 
498
    if (cmd_ln_str("-phsegdir")) {
 
499
            corpus_set_phseg_dir(cmd_ln_str("-phsegdir"));
 
500
            corpus_set_phseg_ext(cmd_ln_str("-phsegext"));
 
501
    }
 
502
 
 
503
    if (cmd_ln_str("-accumdir")) {
 
504
        char fn[MAXPATHLEN+1];
 
505
        FILE *fp;
 
506
 
 
507
        sprintf(fn, "%s/ckpt", cmd_ln_str("-accumdir"));
 
508
        
 
509
        fp = fopen(fn, "r");
 
510
        if (fp != NULL) {
 
511
            const uint32* feat_veclen;
 
512
            fclose(fp);
 
513
 
 
514
            E_INFO("RESTORING CHECKPOINTED COUNTS IN %s\n", cmd_ln_str("-accumdir"));
 
515
            
 
516
            feat_veclen = (uint32 *)feat_stream_lengths(feat);
 
517
                    
 
518
            if (mod_inv_restore_acc(inv,
 
519
                                    cmd_ln_str("-accumdir"),
 
520
                                    mixw_reest,
 
521
                                    mean_reest,
 
522
                                    var_reest,
 
523
                                    tmat_reest,
 
524
                                    feat_veclen) != S3_SUCCESS) {
 
525
                E_FATAL("Unable to restore checkpoint information\n");
 
526
            }
 
527
 
 
528
            if (corpus_ckpt_set_interval(fn) != S3_SUCCESS) {
 
529
                E_FATAL("Unable to restore corpus state information\n");
 
530
            }
 
531
            
 
532
            E_INFO("Resuming at utt %u\n", corpus_get_begin());
 
533
            did_restore = TRUE;
 
534
        }
 
535
    }
 
536
 
 
537
    if (!did_restore) {
 
538
        if (cmd_ln_int32("-nskip") && cmd_ln_int32("-runlen")) {
 
539
            corpus_set_interval(cmd_ln_int32("-nskip"),
 
540
                            cmd_ln_int32("-runlen"));
 
541
        } else if (cmd_ln_int32("-part") && cmd_ln_int32("-npart")) {
 
542
            corpus_set_partition(cmd_ln_int32("-part"),
 
543
                             cmd_ln_int32("-npart"));
 
544
        }
 
545
    }
 
546
 
 
547
    /* BEWARE: this function call must be done after all the other corpus
 
548
       configuration */
 
549
    corpus_init();
 
550
 
 
551
    if (cmd_ln_str("-mllrmat")) {
 
552
        uint32 *tmp_veclen, *feat_veclen;
 
553
        uint32 tmp_n_mllrcls;
 
554
        uint32 tmp_n_stream;
 
555
        uint32 j;
 
556
 
 
557
        if (read_reg_mat(cmd_ln_str("-mllrmat"),
 
558
                         &tmp_veclen,
 
559
                         &tmp_n_mllrcls,
 
560
                         &tmp_n_stream,
 
561
                         &sxfrm_a, &sxfrm_b) != S3_SUCCESS) {
 
562
            E_FATAL("Unable to read %s\n", cmd_ln_str("-mllrmat"));
 
563
        }
 
564
 
 
565
        if (feat_dimension1(feat) != tmp_n_stream) {
 
566
            E_FATAL("# feature streams in -mllrmat %s != # feature streams configured on cmd ln\n");
 
567
        }
 
568
        
 
569
        feat_veclen = (uint32 *)feat_stream_lengths(feat);
 
570
 
 
571
        for (j = 0; j < tmp_n_stream; j++) {
 
572
            if (feat_veclen[j] != tmp_veclen[j]) {
 
573
                E_FATAL("# components of stream %u in -mllrmat inconsistent w/ -feat config (%u != %u)\n",
 
574
                        j, tmp_veclen[j], feat_veclen[j]);
 
575
            }
 
576
        }
 
577
        ckd_free((void *)tmp_veclen);
 
578
 
 
579
        fn = cmd_ln_str("-cb2mllrfn");
 
580
        if (fn != NULL) {
 
581
            if (strcmp(fn, ".1cls.") == 0) {
 
582
                mllr_idx = ckd_calloc(inv->gauden->n_mgau, sizeof(int32));
 
583
                n_mllr = 1;
 
584
                n_map = inv->gauden->n_mgau;
 
585
            }
 
586
            else if (s3cb2mllr_read(cmd_ln_str("-cb2mllrfn"),
 
587
                                    &mllr_idx,
 
588
                                    &n_map,
 
589
                                    &n_mllr) != S3_SUCCESS) {
 
590
                return S3_ERROR;
 
591
            }
 
592
            if (n_map != inv->gauden->n_mgau) {
 
593
                E_FATAL("cb2mllr maps %u cb, but read %u cb from files\n",
 
594
                        n_map, inv->gauden->n_mgau);
 
595
            }
 
596
        }
 
597
 
 
598
        /* Transform the means using the speaker transform if available. */
 
599
        mllr_transform_mean(inv->gauden->mean,
 
600
                            inv->gauden->var,
 
601
                            0, inv->gauden->n_mgau,
 
602
                            inv->gauden->n_feat,
 
603
                            inv->gauden->n_density,
 
604
                            inv->gauden->veclen,
 
605
                            sxfrm_a, sxfrm_b,
 
606
                            mllr_idx, n_mllr);
 
607
        ckd_free(mllr_idx);
 
608
        free_mllr_A(sxfrm_a, n_mllr, tmp_n_stream);
 
609
        free_mllr_B(sxfrm_b, n_mllr, tmp_n_stream);
 
610
    }
 
611
 
 
612
    return S3_SUCCESS;
 
613
}
 
614
 
 
615
void
 
616
main_reestimate(model_inventory_t *inv,
 
617
                lexicon_t *lex,
 
618
                model_def_t *mdef,
 
619
                feat_t *feat,
 
620
                int32 viterbi)
 
621
{
 
622
    vector_t *mfcc;     /* utterance cepstra */ 
 
623
    int32 n_frame;      /* # of cepstrum frames  */
 
624
    uint32 svd_n_frame; /* # of cepstrum frames  */
 
625
    vector_t **f;               /* independent feature streams derived
 
626
                                 * from cepstra */
 
627
    state_t *state_seq;         /* sentence HMM state sequence for the
 
628
                                   utterance */
 
629
    uint32 n_state = 0; /* # of sentence HMM states */
 
630
    float64 total_log_lik;      /* total log liklihood over corpus */
 
631
    float64 log_lik;            /* log liklihood for an utterance */
 
632
    uint32 total_frames;        /* # of frames over the corpus */
 
633
    float64 a_beam;             /* alpha pruning beam */
 
634
    float64 b_beam;             /* beta pruning beam */
 
635
    float32 spthresh;           /* state posterior probability threshold */
 
636
    uint32 seq_no;      /* sequence # of utterance in corpus */
 
637
    uint32 mixw_reest;  /* if TRUE, reestimate mixing weights */
 
638
    uint32 tmat_reest;  /* if TRUE, reestimate transition probability matrices */
 
639
    uint32 mean_reest;  /* if TRUE, reestimate means */
 
640
    uint32 var_reest;   /* if TRUE, reestimate variances */
 
641
    char *trans;
 
642
    const char *pdumpdir;
 
643
    FILE *pdumpfh;
 
644
    uint32 in_veclen;
 
645
 
 
646
    bw_timers_t* timers = NULL;
 
647
    int32 profile;
 
648
 
 
649
    int32 pass2var;
 
650
    int32 var_is_full;
 
651
 
 
652
    uint32 n_utt;
 
653
 
 
654
    s3phseg_t *phseg = NULL;
 
655
 
 
656
    uint32 maxuttlen;
 
657
    uint32 n_frame_skipped = 0;
 
658
 
 
659
    uint32 ckpt_intv = 0;
 
660
    uint32 no_retries = 0;
 
661
 
 
662
    uint32 outputfullpath = 0;
 
663
    uint32 fullsuffixmatch = 0;
 
664
 
 
665
    E_INFO("Reestimation: %s\n",
 
666
        (viterbi ? "Viterbi" : "Baum-Welch"));
 
667
 
 
668
    profile = cmd_ln_int32("-timing");
 
669
    if (profile) {
 
670
        E_INFO("Generating profiling information consumes significant CPU resources.\n");
 
671
        E_INFO("If you are not interested in profiling, use -timing no\n");
 
672
    }
 
673
    outputfullpath = cmd_ln_int32("-outputfullpath");
 
674
    fullsuffixmatch = cmd_ln_int32("-fullsuffixmatch");
 
675
 
 
676
    corpus_set_full_suffix_match(fullsuffixmatch);
 
677
 
 
678
    if (profile) {
 
679
        timers = ckd_calloc(1, sizeof(bw_timers_t));
 
680
        ptmr_init(&timers->utt_timer);
 
681
        ptmr_init(&timers->upd_timer);
 
682
        ptmr_init(&timers->fwd_timer);
 
683
        ptmr_init(&timers->bwd_timer);
 
684
        ptmr_init(&timers->gau_timer);
 
685
        ptmr_init(&timers->rsts_timer);
 
686
        ptmr_init(&timers->rstf_timer);
 
687
        ptmr_init(&timers->rstu_timer);
 
688
    }
 
689
 
 
690
    mixw_reest = cmd_ln_int32("-mixwreest");
 
691
    tmat_reest = cmd_ln_int32("-tmatreest");
 
692
    mean_reest = cmd_ln_int32("-meanreest");
 
693
    var_reest = cmd_ln_int32("-varreest");
 
694
    pass2var = cmd_ln_int32("-2passvar");
 
695
    var_is_full = cmd_ln_int32("-fullvar");
 
696
    pdumpdir = cmd_ln_str("-pdumpdir");
 
697
    in_veclen = cmd_ln_int32("-ceplen");
 
698
 
 
699
    if (cmd_ln_str("-ckptintv")) {
 
700
        ckpt_intv = cmd_ln_int32("-ckptintv");
 
701
    }
 
702
 
 
703
    if (cmd_ln_str("-accumdir") == NULL) {
 
704
        E_WARN("NO ACCUMDIR SET.  No counts will be written; assuming debug\n");
 
705
    }
 
706
 
 
707
    if (!mixw_reest && !tmat_reest && !mean_reest && !var_reest) {
 
708
        E_WARN("No reestimation specified!  None done.\n");
 
709
        
 
710
        return;
 
711
    }
 
712
 
 
713
    total_log_lik = 0;
 
714
    total_frames = 0;
 
715
 
 
716
    a_beam = cmd_ln_float64("-abeam");
 
717
    b_beam = cmd_ln_float64("-bbeam");
 
718
    spthresh = cmd_ln_float32("-spthresh");
 
719
    maxuttlen = cmd_ln_int32("-maxuttlen");
 
720
 
 
721
    /* Begin by skipping over some (possibly zero) # of utterances.
 
722
     * Continue to process utterances until there are no more (either EOF
 
723
     * or end of run). */
 
724
 
 
725
    seq_no = corpus_get_begin();
 
726
 
 
727
    printf("column defns\n");
 
728
    printf("\t<seq>\n");
 
729
    printf("\t<id>\n");
 
730
    printf("\t<n_frame_in>\n");
 
731
    printf("\t<n_frame_del>\n");
 
732
    printf("\t<n_state_shmm>\n");
 
733
    printf("\t<avg_states_alpha>\n");
 
734
    if (!cmd_ln_int32("-viterbi")) {
 
735
        printf("\t<avg_states_beta>\n");
 
736
        printf("\t<avg_states_reest>\n");
 
737
        printf("\t<avg_posterior_prune>\n");
 
738
    }
 
739
    printf("\t<frame_log_lik>\n");
 
740
    printf("\t<utt_log_lik>\n");
 
741
    printf("\t... timing info ... \n");
 
742
 
 
743
    n_utt = 0;
 
744
 
 
745
    while (corpus_next_utt()) {
 
746
        /* Zero timers before utt processing begins */
 
747
        if (timers) {
 
748
            ptmr_reset(&timers->utt_timer);
 
749
            ptmr_reset(&timers->upd_timer);
 
750
            ptmr_reset(&timers->fwd_timer);
 
751
            ptmr_reset(&timers->bwd_timer);
 
752
            ptmr_reset(&timers->gau_timer);
 
753
            ptmr_reset(&timers->rsts_timer);
 
754
            ptmr_reset(&timers->rstf_timer);
 
755
            ptmr_reset(&timers->rstu_timer);
 
756
        }
 
757
        
 
758
        if (timers)
 
759
            ptmr_start(&timers->utt_timer);
 
760
 
 
761
        printf("utt> %5u %25s", 
 
762
               seq_no,
 
763
               (outputfullpath ? corpus_utt_full_name() : corpus_utt()));
 
764
 
 
765
        if (corpus_get_generic_featurevec(&mfcc, &n_frame, in_veclen) < 0) {
 
766
                E_FATAL("Can't read input features\n");
 
767
        }
 
768
 
 
769
        printf(" %4u", n_frame);
 
770
 
 
771
        if (n_frame < 9) {
 
772
            E_WARN("utt %s too short\n", corpus_utt());
 
773
            if (mfcc) {
 
774
                ckd_free(mfcc[0]);
 
775
                ckd_free(mfcc);
 
776
            }
 
777
            continue;
 
778
        }
 
779
 
 
780
        if ((maxuttlen > 0) && (n_frame > maxuttlen)) {
 
781
            E_INFO("utt # frames > -maxuttlen; skipping\n");
 
782
            n_frame_skipped += n_frame;
 
783
            if (mfcc) {
 
784
                ckd_free(mfcc[0]);
 
785
                ckd_free(mfcc);
 
786
            }
 
787
 
 
788
            continue;
 
789
        }
 
790
 
 
791
        svd_n_frame = n_frame;
 
792
 
 
793
        /* Hack to not apply the LDA, it will be applied later during accum_dir
 
794
         * Pretty useless thing to be honest, what to do with CMN after that for example?
 
795
         */
 
796
        if (cmd_ln_boolean("-ldaaccum")) {
 
797
            float32 ***lda = feat->lda;
 
798
            feat->lda = NULL;
 
799
            f = feat_array_alloc(feat, n_frame + feat_window_size(feat));
 
800
            feat_s2mfc2feat_live(feat, mfcc, &n_frame, TRUE, TRUE, f);
 
801
            feat->lda = lda;
 
802
        } else {
 
803
            f = feat_array_alloc(feat, n_frame + feat_window_size(feat));
 
804
            feat_s2mfc2feat_live(feat, mfcc, &n_frame, TRUE, TRUE, f);
 
805
        }
 
806
 
 
807
        printf(" %4u", n_frame - svd_n_frame);
 
808
 
 
809
        /* Get the transcript */
 
810
        corpus_get_sent(&trans);
 
811
 
 
812
        /* Get the phone segmentation */
 
813
        corpus_get_phseg(inv->acmod_set, &phseg);
 
814
 
 
815
        /* Open a dump file if required. */
 
816
        if (pdumpdir) {
 
817
                char *pdumpfn, *uttid;
 
818
 
 
819
                uttid = (outputfullpath ? corpus_utt_full_name() : corpus_utt());
 
820
                pdumpfn = ckd_calloc(strlen(pdumpdir) + 1
 
821
                                     + strlen(uttid)
 
822
                                     + strlen(".pdump") + 1, 1);
 
823
                strcpy(pdumpfn, pdumpdir);
 
824
                strcat(pdumpfn, "/");
 
825
                strcat(pdumpfn, uttid);
 
826
                strcat(pdumpfn, ".pdump");
 
827
                if ((pdumpfh = fopen(pdumpfn, "w")) == NULL)
 
828
                        E_FATAL_SYSTEM("Failed to open %s for writing", pdumpfn);
 
829
                ckd_free(pdumpfn);
 
830
        }
 
831
        else
 
832
                pdumpfh = NULL;
 
833
 
 
834
        if (timers)
 
835
            ptmr_start(&timers->upd_timer);
 
836
        /* create a sentence HMM */
 
837
        state_seq = next_utt_states(&n_state, lex, inv, mdef, trans);
 
838
        printf(" %5u", n_state);
 
839
        
 
840
        if (state_seq == NULL) {
 
841
            E_WARN("Skipped utterance '%s'\n", trans);
 
842
        } else if (!viterbi) {
 
843
            /* accumulate reestimation sums for the utterance */
 
844
            if (baum_welch_update(&log_lik,
 
845
                                  f, n_frame,
 
846
                                  state_seq, n_state,
 
847
                                  inv,
 
848
                                  a_beam,
 
849
                                  b_beam,
 
850
                                  spthresh,
 
851
                                  phseg,
 
852
                                  mixw_reest,
 
853
                                  tmat_reest,
 
854
                                  mean_reest,
 
855
                                  var_reest,
 
856
                                  pass2var,
 
857
                                  var_is_full,
 
858
                                  pdumpfh,
 
859
                                  timers,
 
860
                                  feat) == S3_SUCCESS) {
 
861
                total_frames += n_frame;
 
862
                total_log_lik += log_lik;
 
863
                
 
864
                printf(" %e %e",
 
865
                       (n_frame > 0 ? log_lik / n_frame : 0.0),
 
866
                       log_lik);
 
867
            }
 
868
 
 
869
        } else {
 
870
            /* Viterbi search and accumulate in it */
 
871
            if (viterbi_update(&log_lik,
 
872
                               f, n_frame,
 
873
                               state_seq, n_state,
 
874
                               inv,
 
875
                               a_beam,
 
876
                               spthresh,
 
877
                               phseg,
 
878
                               mixw_reest,
 
879
                               tmat_reest,
 
880
                               mean_reest,
 
881
                               var_reest,
 
882
                               pass2var,
 
883
                               var_is_full,
 
884
                               pdumpfh, 
 
885
                               timers,
 
886
                               feat) == S3_SUCCESS) {
 
887
                total_frames += n_frame;
 
888
                total_log_lik += log_lik;
 
889
                printf(" %e %e",
 
890
                       (n_frame > 0 ? log_lik / n_frame : 0.0),
 
891
                       log_lik);
 
892
            }
 
893
        }
 
894
 
 
895
        if (timers)
 
896
            ptmr_stop(&timers->upd_timer);
 
897
 
 
898
        if (pdumpfh)
 
899
                fclose(pdumpfh);
 
900
        free(mfcc[0]);
 
901
        ckd_free(mfcc);
 
902
        feat_array_free(f);
 
903
        free(trans);    /* alloc'ed using strdup() */
 
904
 
 
905
        seq_no++;
 
906
        n_utt++;
 
907
 
 
908
        if (timers)
 
909
            ptmr_stop(&timers->utt_timer);
 
910
    
 
911
        if (profile)
 
912
            print_all_timers(timers, n_frame);
 
913
 
 
914
        printf("\n");
 
915
        fflush(stdout);
 
916
 
 
917
        if ((ckpt_intv > 0) &&
 
918
            ((n_utt % ckpt_intv) == 0) &&
 
919
            (cmd_ln_str("-accumdir") != NULL)) {
 
920
            while (accum_dump(cmd_ln_str("-accumdir"),
 
921
                              inv,
 
922
                              mixw_reest,
 
923
                              tmat_reest,
 
924
                              mean_reest,
 
925
                              var_reest,
 
926
                              pass2var,
 
927
                              var_is_full,
 
928
                              TRUE) != S3_SUCCESS) {
 
929
                static int notified = FALSE;
 
930
                time_t t;
 
931
                char time_str[64];
 
932
                
 
933
                /*
 
934
                 * If we were not able to dump the parameters, write one log entry
 
935
                 * about the failure
 
936
                 */
 
937
                if (notified == FALSE) {
 
938
                    t = time(NULL);
 
939
                    strcpy(time_str, (const char *)ctime((const time_t *)&t));
 
940
                    /* nuke the newline at the end of this. */
 
941
                    time_str[strlen(time_str)-1] = '\0';
 
942
                    E_WARN("Ckpt count dump failed on %s.  Retrying dump every %3.1f hour until success.\n",
 
943
                           time_str, DUMP_RETRY_PERIOD/3600.0);
 
944
                    
 
945
                    notified = TRUE;
 
946
                    no_retries++;
 
947
                    if(no_retries>10){ 
 
948
                      E_FATAL("Failed to get the files after 10 retries(about 5 minutes).\n ");
 
949
                    }
 
950
                }
 
951
                sleep(DUMP_RETRY_PERIOD);
 
952
            }
 
953
        }
 
954
    }
 
955
 
 
956
    printf("overall> stats %u (-%u) %e %e",
 
957
           total_frames,
 
958
           n_frame_skipped,
 
959
           (total_frames > 0 ? total_log_lik / total_frames : 0.0),
 
960
           total_log_lik);
 
961
    if (profile) {
 
962
        printf(" %4.3fx %4.3fe",
 
963
               (total_frames > 0 ? timers->utt_timer.t_tot_cpu/(total_frames*0.01) : 0.0),
 
964
               (timers->utt_timer.t_tot_cpu > 0 ? timers->utt_timer.t_tot_elapsed / timers->utt_timer.t_tot_cpu : 0.0));
 
965
    }    
 
966
    printf("\n");
 
967
    fflush(stdout);
 
968
 
 
969
    no_retries=0;
 
970
    /* dump the accumulators to a file system */
 
971
    while (cmd_ln_str("-accumdir") != NULL &&
 
972
           accum_dump(cmd_ln_str("-accumdir"), inv,
 
973
                      mixw_reest,
 
974
                      tmat_reest,
 
975
                      mean_reest,
 
976
                      var_reest,
 
977
                      pass2var,
 
978
                      var_is_full,
 
979
                      FALSE) != S3_SUCCESS) {
 
980
        static int notified = FALSE;
 
981
        time_t t;
 
982
        char time_str[64];
 
983
 
 
984
        /*
 
985
         * If we were not able to dump the parameters, write one log entry
 
986
         * about the failure
 
987
         */
 
988
        if (notified == FALSE) {
 
989
            t = time(NULL);
 
990
            strcpy(time_str, (const char *)ctime((const time_t *)&t));
 
991
            /* nuke the newline at the end of this. */
 
992
            time_str[strlen(time_str)-1] = '\0';
 
993
            E_WARN("Count dump failed on %s.  Retrying dump every %3.1f hour until success.\n",
 
994
                   time_str, DUMP_RETRY_PERIOD/3600.0);
 
995
 
 
996
            notified = TRUE;
 
997
            no_retries++;
 
998
            if(no_retries>10){ 
 
999
              E_FATAL("Failed to get the files after 10 retries(about 5 minutes).\n ");
 
1000
            }
 
1001
        }
 
1002
        
 
1003
        sleep(DUMP_RETRY_PERIOD);
 
1004
 
 
1005
 
 
1006
    }
 
1007
 
 
1008
    if (profile) {
 
1009
        ckd_free(timers);
 
1010
    }
 
1011
 
 
1012
    /* Write a log entry on success */
 
1013
    if (cmd_ln_str("-accumdir"))
 
1014
        E_INFO("Counts saved to %s\n", cmd_ln_str("-accumdir"));
 
1015
    else
 
1016
        E_INFO("Counts NOT saved.\n");
 
1017
}
 
1018
 
 
1019
/* x=log(a) y=log(b), log_add(x,y) = log(a+b) */
 
1020
float64
 
1021
log_add(float64 x, float64 y)
 
1022
{
 
1023
  float64 z;
 
1024
  
 
1025
  if (x<y)
 
1026
    return log_add(y, x);
 
1027
  if (y == LOG_ZERO)
 
1028
    return x;
 
1029
  else
 
1030
    {
 
1031
      z = exp(y-x);
 
1032
      return x+log(1.0+z);
 
1033
    }
 
1034
}
 
1035
 
 
1036
/* forward-backward computation on lattice */
 
1037
int
 
1038
lat_fwd_bwd(s3lattice_t *lat)
 
1039
{
 
1040
  int i, j;
 
1041
  uint32 id;
 
1042
  float64 ac_score, lm_score;
 
1043
 
 
1044
  /* step forward */
 
1045
  for (i=0; i<lat->n_arcs; i++) {
 
1046
    /* initialise alpha */
 
1047
    lat->arc[i].alpha = LOG_ZERO;
 
1048
    if (lat->arc[i].good_arc == 1) {
 
1049
      /* get the acoustic and lm socre for a word hypothesis */
 
1050
      ac_score = lat->arc[i].ac_score / lm_scale;
 
1051
      lm_score = lat->arc[i].lm_score;
 
1052
 
 
1053
      /* compute alpha */
 
1054
      for (j=0; j<lat->arc[i].n_prev_arcs; j++) {
 
1055
        id = lat->arc[i].prev_arcs[j];
 
1056
        if (id == 0) {
 
1057
          if (lat->arc[i].sf == 1) {
 
1058
            lat->arc[i].alpha = log_add(lat->arc[i].alpha, 0);
 
1059
          }
 
1060
        }
 
1061
        else {
 
1062
          if (lat->arc[id-1].good_arc == 1)
 
1063
            lat->arc[i].alpha = log_add(lat->arc[i].alpha, lat->arc[id-1].alpha);
 
1064
        }
 
1065
      }
 
1066
      lat->arc[i].alpha += ac_score + lm_score;
 
1067
    }
 
1068
  }
 
1069
 
 
1070
  /* initialise overall log-likelihood */
 
1071
  lat->prob = LOG_ZERO;
 
1072
 
 
1073
  /* step backward */
 
1074
  for (i=lat->n_arcs-1; i>=0 ;i--) {
 
1075
    /* initialise beta */
 
1076
    lat->arc[i].beta = LOG_ZERO;
 
1077
 
 
1078
    if (lat->arc[i].good_arc == 1) {
 
1079
      /* get the acoustic and lm socre for a word hypothesis */
 
1080
      ac_score = lat->arc[i].ac_score / lm_scale;
 
1081
      lm_score = lat->arc[i].lm_score;
 
1082
 
 
1083
      /* compute beta */
 
1084
      for (j=0; j<lat->arc[i].n_next_arcs; j++) {
 
1085
        id = lat->arc[i].next_arcs[j];
 
1086
        if (id == 0) {
 
1087
          lat->arc[i].beta = log_add(lat->arc[i].beta, 0);
 
1088
        }
 
1089
        else {
 
1090
          if (lat->arc[id-1].good_arc == 1);
 
1091
          lat->arc[i].beta = log_add(lat->arc[i].beta, lat->arc[id-1].beta);
 
1092
        }
 
1093
      }
 
1094
      lat->arc[i].beta += ac_score + lm_score;
 
1095
 
 
1096
      /* compute overall log-likelihood loglid=beta(1)=alpha(Q) */
 
1097
      if (lat->arc[i].sf == 1)
 
1098
        lat->prob = log_add(lat->prob, lat->arc[i].beta);
 
1099
    }
 
1100
  }
 
1101
 
 
1102
  /* compute gamma */
 
1103
  for (i=0; i<lat->n_arcs; i++)
 
1104
    {
 
1105
      /* initialise gamma */
 
1106
      lat->arc[i].gamma = LOG_ZERO;
 
1107
      if (lat->arc[i].good_arc == 1)
 
1108
        {
 
1109
          ac_score = lat->arc[i].ac_score / lm_scale;
 
1110
          lm_score = lat->arc[i].lm_score;
 
1111
          lat->arc[i].gamma = lat->arc[i].alpha + lat->arc[i].beta - (ac_score + lm_score + lat->prob);
 
1112
        }
 
1113
    }
 
1114
 
 
1115
  /* compute the posterior probability of the true path */
 
1116
  lat->postprob = 0;
 
1117
  for (i=lat->n_arcs-lat->n_true_arcs; i<lat->n_arcs; i++)
 
1118
    lat->postprob += lat->arc[i].gamma;
 
1119
  
 
1120
  return S3_SUCCESS;
 
1121
}
 
1122
 
 
1123
/* mmie training: take random left and right context for viterbi run */
 
1124
int
 
1125
mmi_rand_train(model_inventory_t *inv,
 
1126
               model_def_t *mdef,
 
1127
               lexicon_t *lex,
 
1128
               vector_t **f,
 
1129
               s3lattice_t *lat,
 
1130
               float64 a_beam,
 
1131
               uint32 mean_reest,
 
1132
               uint32 var_reest,
 
1133
               feat_t *fcb)
 
1134
{
 
1135
  uint32 k, n;
 
1136
  uint32 n_rand;/* random number */
 
1137
  uint32 n_max_run;/* the maximum number of viterbi run */
 
1138
  char pword[128], cword[128], nword[128];      /* previous, current, next word */
 
1139
  vector_t **arc_f = NULL;/* feature vector for a word arc */
 
1140
  uint32 n_word_obs;/* frames of a word arc */
 
1141
  uint32 rand_prev_id, rand_next_id;/* randomly selected previous and next arc id */
 
1142
  uint32 *lphone, *rphone;        /* the last and first phone of previous and next word hypothesis */
 
1143
  state_t *state_seq;/* HMM state sequence for an arc */
 
1144
  uint32 n_state = 0;/* number of HMM states */
 
1145
  float64 log_lik;/* log-likelihood of an arc */
 
1146
  
 
1147
  /* viterbi run on each arc */
 
1148
  printf(" %5u", lat->n_arcs);
 
1149
  
 
1150
  for(n=0; n<lat->n_arcs; n++) {
 
1151
 
 
1152
    /* total observations of this arc */
 
1153
    /* this is not very accurate, as it consumes one more frame for each word at the end */
 
1154
    n_word_obs = lat->arc[n].ef - lat->arc[n].sf + 1;
 
1155
    
 
1156
    /* get the feature for this arc */
 
1157
    arc_f = (vector_t **) ckd_calloc(n_word_obs, sizeof(vector_t *));
 
1158
    for (k=0; k<n_word_obs; k++)
 
1159
      arc_f[k] = f[k+lat->arc[n].sf-1];
 
1160
    
 
1161
    /* in case the viterbi run fails at a certain left and right context,
 
1162
       at most randomly pick context n_prev_arcs * n_next_arcs times */
 
1163
    n_max_run = lat->arc[n].n_prev_arcs * lat->arc[n].n_next_arcs;
 
1164
    
 
1165
    /* seed the random-number generator with current time */
 
1166
    srand( (unsigned)time( NULL ) );
 
1167
    
 
1168
    /* randomly pick the left and right context */
 
1169
    while (n_max_run > 0 && lat->arc[n].good_arc == 0) {
 
1170
      
 
1171
      /* get left arc id */
 
1172
      if (lat->arc[n].n_prev_arcs == 1) {
 
1173
        n_rand = 0;
 
1174
      }
 
1175
      else {
 
1176
        n_rand = (uint32) (((double) rand() / (((double) RAND_MAX) + 1)) * lat->arc[n].n_prev_arcs );
 
1177
      }
 
1178
      rand_prev_id = lat->arc[n].prev_arcs[n_rand];
 
1179
 
 
1180
      /* get right arc id */
 
1181
      if (lat->arc[n].n_next_arcs == 1) {
 
1182
        n_rand = 0;
 
1183
      }
 
1184
      else {
 
1185
        n_rand = (uint32) (((double) rand() / (((double) RAND_MAX) + 1)) * lat->arc[n].n_next_arcs );
 
1186
      }
 
1187
      rand_next_id = lat->arc[n].next_arcs[n_rand];
 
1188
      
 
1189
      /* get the triphone list */
 
1190
      strcpy(cword, lat->arc[n].word);
 
1191
      if (rand_prev_id == 0)
 
1192
        strcpy(pword, "<s>");
 
1193
      else
 
1194
        strcpy(pword, lat->arc[rand_prev_id-1].word);
 
1195
      lphone = mk_boundary_phone(pword, 0, lex);
 
1196
      if (rand_next_id == 0)
 
1197
        strcpy(nword, "</s>");
 
1198
      else
 
1199
        strcpy(nword, lat->arc[rand_next_id-1].word);
 
1200
      rphone = mk_boundary_phone(nword, 1, lex);
 
1201
 
 
1202
      state_seq = next_utt_states_mmie(&n_state, lex, inv, mdef, cword, lphone, rphone);
 
1203
 
 
1204
      /* viterbi compuation to get the acoustic score for a word hypothesis */
 
1205
      if (mmi_viterbi_run(&log_lik,
 
1206
                          arc_f, n_word_obs,
 
1207
                          state_seq, n_state,
 
1208
                          inv,
 
1209
                          a_beam) == S3_SUCCESS) {
 
1210
        lat->arc[n].good_arc = 1;
 
1211
        lat->arc[n].ac_score = log_lik;
 
1212
        lat->arc[n].best_prev_arc = rand_prev_id;
 
1213
        lat->arc[n].best_next_arc = rand_next_id;
 
1214
      }
 
1215
 
 
1216
      n_max_run--;
 
1217
      ckd_free(lphone);
 
1218
      ckd_free(rphone);
 
1219
    }
 
1220
    
 
1221
    ckd_free(arc_f);
 
1222
    
 
1223
    if (lat->arc[n].good_arc == 0) {
 
1224
      E_INFO("arc_%d is ignored (viterbi run failed)\n", n+1);
 
1225
    }
 
1226
  }
 
1227
 
 
1228
  /* lattice-based forward-backward computation */
 
1229
  lat_fwd_bwd(lat);
 
1230
 
 
1231
  /* update Gaussian parameters */
 
1232
  for (n=0; n<lat->n_arcs; n++) {
 
1233
    
 
1234
    /* only if the arc was successful in viterbi run */
 
1235
    if (lat->arc[n].good_arc == 1) {
 
1236
      
 
1237
      /* total observations of this arc */
 
1238
      n_word_obs = lat->arc[n].ef - lat->arc[n].sf + 1;
 
1239
      arc_f = (vector_t **) ckd_calloc(n_word_obs, sizeof(vector_t *));
 
1240
      for (k=0; k<n_word_obs; k++)
 
1241
        arc_f[k] = f[k+lat->arc[n].sf-1];
 
1242
      
 
1243
      /* get the randomly picked left and right context */
 
1244
      rand_prev_id = lat->arc[n].best_prev_arc;
 
1245
      rand_next_id = lat->arc[n].best_next_arc;
 
1246
      
 
1247
      /* get the triphone list */
 
1248
      strcpy(cword, lat->arc[n].word);
 
1249
      if (rand_prev_id == 0)
 
1250
        strcpy(pword, "<s>");
 
1251
      else
 
1252
        strcpy(pword, lat->arc[rand_prev_id-1].word);
 
1253
      lphone = mk_boundary_phone(pword, 0, lex);
 
1254
      if (rand_next_id == 0)
 
1255
        strcpy(nword, "</s>");
 
1256
      else
 
1257
        strcpy(nword, lat->arc[rand_next_id-1].word);
 
1258
      rphone = mk_boundary_phone(nword, 1, lex);
 
1259
      
 
1260
      /* make state list */
 
1261
      state_seq = next_utt_states_mmie(&n_state, lex, inv, mdef, cword, lphone, rphone);
 
1262
      
 
1263
      /* viterbi update model parameters */
 
1264
      if (mmi_viterbi_update(arc_f, n_word_obs,
 
1265
                             state_seq, n_state,
 
1266
                             inv,
 
1267
                             a_beam,
 
1268
                             mean_reest,
 
1269
                             var_reest,
 
1270
                             lat->arc[n].gamma,
 
1271
                             fcb) != S3_SUCCESS) {
 
1272
        E_ERROR("arc_%d is ignored (viterbi update failed)\n", n+1);
 
1273
      }
 
1274
      ckd_free(arc_f);
 
1275
      ckd_free(lphone);
 
1276
      ckd_free(rphone);
 
1277
    }
 
1278
  }
 
1279
  
 
1280
  return S3_SUCCESS;
 
1281
}
 
1282
 
 
1283
/* mmie training: take the best left and right context for viterbi run */
 
1284
int
 
1285
mmi_best_train(model_inventory_t *inv,
 
1286
               model_def_t *mdef,
 
1287
               lexicon_t *lex,
 
1288
               vector_t **f,
 
1289
               s3lattice_t *lat,
 
1290
               float64 a_beam,
 
1291
               uint32 mean_reest,
 
1292
               uint32 var_reest,
 
1293
               feat_t *fcb)
 
1294
{
 
1295
  uint32 i, j, k, n;
 
1296
  char pword[128], cword[128], nword[128];      /* previous, current and next word hypothesis */
 
1297
  vector_t **arc_f = NULL;/* feature vector for a word arc */
 
1298
  uint32 n_word_obs;/* frames of a word arc */
 
1299
  uint32 prev_id, next_id;/* previous and next arc id */
 
1300
  uint32 *lphone, *rphone;/* the last and first phone of previous and next arc */
 
1301
  uint32 prev_lphone, prev_rphone;/* the lphone and rphone of previous viterbi run on arc */
 
1302
  state_t *state_seq;/* HMM state sequence for an arc */
 
1303
  uint32 n_state = 0;/* number of HMM states */
 
1304
  float64 log_lik;/* log-likelihood of an arc */
 
1305
  
 
1306
  /* viterbi run on each arc */
 
1307
  printf(" %5u", lat->n_arcs);
 
1308
  
 
1309
  for(n=0; n<lat->n_arcs; n++) {
 
1310
    
 
1311
    /* total observations of this arc */
 
1312
    /* this is not very accurate, as it consumes one more frame for each word at the end */
 
1313
    n_word_obs = lat->arc[n].ef - lat->arc[n].sf + 1;
 
1314
    
 
1315
    /* get the feature for this arc */
 
1316
    arc_f = (vector_t **) ckd_calloc(n_word_obs, sizeof(vector_t *));
 
1317
    for (k=0; k<n_word_obs; k++)
 
1318
      arc_f[k] = f[k+lat->arc[n].sf-1];
 
1319
    
 
1320
    /* now try to find the best left and right context for viterbi run */
 
1321
    /* current word hypothesis */
 
1322
    strcpy(cword, lat->arc[n].word);
 
1323
    
 
1324
    /* initialise previous lphone */
 
1325
    prev_lphone = 0;
 
1326
    
 
1327
    /* try all left context */
 
1328
    for (i=0; i<lat->arc[n].n_prev_arcs; i++) {
 
1329
      /* preceding word */
 
1330
      prev_id = lat->arc[n].prev_arcs[i];
 
1331
      if (prev_id == 0) {
 
1332
        strcpy(pword, "<s>");
 
1333
      }
 
1334
      else {
 
1335
        strcpy(pword, lat->arc[prev_id-1].word);
 
1336
      }
 
1337
      
 
1338
      /* get the left boundary triphone */
 
1339
      lphone = mk_boundary_phone(pword, 0, lex);
 
1340
      
 
1341
      /* if the previous preceeding arc has different context as the new one */
 
1342
      if (*lphone != prev_lphone || i == 0) {
 
1343
        
 
1344
        /* initialize rphone */
 
1345
        prev_rphone = 0;
 
1346
        
 
1347
        /* try all right context */
 
1348
        for(j=0; j<lat->arc[n].n_next_arcs; j++) {
 
1349
          /* succeeding word */
 
1350
          next_id = lat->arc[n].next_arcs[j];
 
1351
          if (next_id == 0)
 
1352
            strcpy(nword, "</s>");
 
1353
          else
 
1354
            strcpy(nword, lat->arc[next_id-1].word);
 
1355
            
 
1356
          /* get the right boundary triphone */
 
1357
          rphone = mk_boundary_phone(nword, 1, lex);
 
1358
            
 
1359
          /* if the previous succeeding arc has different context as the new one */
 
1360
          if (*rphone != prev_rphone || j == 0) {
 
1361
                
 
1362
            /* make state list */
 
1363
            state_seq = next_utt_states_mmie(&n_state, lex, inv, mdef, cword, lphone, rphone);
 
1364
                
 
1365
            /* viterbi compuation to get the acoustic score for a word hypothesis */
 
1366
            if (mmi_viterbi_run(&log_lik,
 
1367
                                arc_f, n_word_obs,
 
1368
                                state_seq, n_state,
 
1369
                                inv,
 
1370
                                a_beam) == S3_SUCCESS) {
 
1371
              if (lat->arc[n].good_arc == 0) {
 
1372
                lat->arc[n].good_arc = 1;
 
1373
                lat->arc[n].ac_score = log_lik;
 
1374
                lat->arc[n].best_prev_arc = lat->arc[n].prev_arcs[i];
 
1375
                lat->arc[n].best_next_arc = lat->arc[n].next_arcs[j];
 
1376
              }
 
1377
              else if (log_lik > lat->arc[n].ac_score) {
 
1378
                lat->arc[n].ac_score = log_lik;
 
1379
                lat->arc[n].best_prev_arc = lat->arc[n].prev_arcs[i];
 
1380
                lat->arc[n].best_next_arc = lat->arc[n].next_arcs[j];
 
1381
              }
 
1382
            }
 
1383
            /* save the current right context */
 
1384
            prev_rphone = *rphone;
 
1385
          }
 
1386
          ckd_free(rphone);
 
1387
        }
 
1388
        /* save the current left context */
 
1389
        prev_lphone = *lphone;
 
1390
      }
 
1391
      ckd_free(lphone);
 
1392
    }
 
1393
    
 
1394
    ckd_free(arc_f);
 
1395
    
 
1396
    if (lat->arc[n].good_arc == 0) {
 
1397
      E_INFO("arc_%d is ignored (viterbi run failed)\n", n+1);
 
1398
    }
 
1399
  }
 
1400
  
 
1401
  /* lattice-based forward-backward computation */
 
1402
  lat_fwd_bwd(lat);
 
1403
  
 
1404
  /* update Gaussian parameters */
 
1405
  for (n=0; n<lat->n_arcs; n++) {
 
1406
    
 
1407
    /* only if the arc was successful in viterbi run */
 
1408
    if (lat->arc[n].good_arc == 1) {
 
1409
      
 
1410
      /* total observations of this arc */
 
1411
      n_word_obs = lat->arc[n].ef - lat->arc[n].sf + 1;
 
1412
      arc_f = (vector_t **) ckd_calloc(n_word_obs, sizeof(vector_t *));
 
1413
      for (k=0; k<n_word_obs; k++)
 
1414
        arc_f[k] = f[k+lat->arc[n].sf-1];
 
1415
      
 
1416
      /* get the best left and right context */
 
1417
      prev_id = lat->arc[n].best_prev_arc;
 
1418
      next_id = lat->arc[n].best_next_arc;
 
1419
      
 
1420
      /* get best triphone list */
 
1421
      strcpy(cword, lat->arc[n].word);
 
1422
      if (prev_id == 0)
 
1423
        strcpy(pword, "<s>");
 
1424
      else
 
1425
        strcpy(pword, lat->arc[prev_id-1].word);
 
1426
      lphone = mk_boundary_phone(pword, 0, lex);
 
1427
      if (next_id == 0)
 
1428
        strcpy(nword, "</s>");
 
1429
      else
 
1430
        strcpy(nword, lat->arc[next_id-1].word);
 
1431
      rphone = mk_boundary_phone(nword, 1, lex);
 
1432
      
 
1433
      /* make state list */
 
1434
      state_seq = next_utt_states_mmie(&n_state, lex, inv, mdef, cword, lphone, rphone);
 
1435
      
 
1436
      /* viterbi update model parameters */
 
1437
      if (mmi_viterbi_update(arc_f, n_word_obs,
 
1438
                             state_seq, n_state,
 
1439
                             inv,
 
1440
                             a_beam,
 
1441
                             mean_reest,
 
1442
                             var_reest,
 
1443
                             lat->arc[n].gamma,
 
1444
                             fcb) != S3_SUCCESS) {
 
1445
        E_ERROR("arc_%d is ignored (viterbi update failed)\n", n+1);
 
1446
      }
 
1447
      ckd_free(arc_f);
 
1448
      ckd_free(lphone);
 
1449
      ckd_free(rphone);
 
1450
    }
 
1451
  }
 
1452
  
 
1453
  return S3_SUCCESS;
 
1454
}
 
1455
 
 
1456
/* mmie training: use context-independent hmms for word boundary models */
 
1457
int
 
1458
mmi_ci_train(model_inventory_t *inv,
 
1459
             model_def_t *mdef,
 
1460
             lexicon_t *lex,
 
1461
             vector_t **f,
 
1462
             s3lattice_t *lat,
 
1463
             float64 a_beam,
 
1464
             uint32 mean_reest,
 
1465
             uint32 var_reest,
 
1466
             feat_t *fcb)
 
1467
{
 
1468
  uint32 k, n;
 
1469
  vector_t **arc_f = NULL;/* feature vector for a word arc */
 
1470
  uint32 n_word_obs;/* frames of a word arc */
 
1471
  state_t *state_seq;/* HMM state sequence for an arc */
 
1472
  uint32 n_state = 0;/* number of HMM states */
 
1473
  float64 log_lik;/* log-likelihood of an arc */
 
1474
  
 
1475
  /* viterbi run on each arc */
 
1476
  printf(" %5u", lat->n_arcs);
 
1477
 
 
1478
  for(n=0; n<lat->n_arcs; n++) {
 
1479
    
 
1480
    /* total observations of this arc */
 
1481
    /* this is not very accurate, as it consumes one more frame for each word at the end */
 
1482
    n_word_obs = lat->arc[n].ef - lat->arc[n].sf + 1;
 
1483
    
 
1484
    /* get the feature for this arc */
 
1485
    arc_f = (vector_t **) ckd_calloc(n_word_obs, sizeof(vector_t *));
 
1486
    for (k=0; k<n_word_obs; k++)
 
1487
      arc_f[k] = f[k+lat->arc[n].sf-1];
 
1488
    
 
1489
    /* make state list */
 
1490
    state_seq = next_utt_states(&n_state, lex, inv, mdef, lat->arc[n].word);
 
1491
    
 
1492
    /* viterbi compuation to get the acoustic score for a word hypothesis */
 
1493
    if (mmi_viterbi_run(&log_lik,
 
1494
                        arc_f, n_word_obs,
 
1495
                        state_seq, n_state,
 
1496
                        inv,
 
1497
                        a_beam) == S3_SUCCESS) {
 
1498
      lat->arc[n].good_arc = 1;
 
1499
      lat->arc[n].ac_score = log_lik;
 
1500
    }
 
1501
    
 
1502
    ckd_free(arc_f);
 
1503
    
 
1504
    if (lat->arc[n].good_arc == 0) {
 
1505
      E_INFO("arc_%d is ignored (viterbi run failed)\n", n+1);
 
1506
    }
 
1507
  }
 
1508
  
 
1509
  /* lattice-based forward-backward computation */
 
1510
  lat_fwd_bwd(lat);
 
1511
  
 
1512
  /* update Gaussian parameters */
 
1513
  for (n=0; n<lat->n_arcs; n++) {
 
1514
    
 
1515
    /* only if the arc was successful in viterbi run */
 
1516
    if (lat->arc[n].good_arc == 1) {
 
1517
      
 
1518
      /* total observations of this arc */
 
1519
      n_word_obs = lat->arc[n].ef - lat->arc[n].sf + 1;
 
1520
      arc_f = (vector_t **) ckd_calloc(n_word_obs, sizeof(vector_t *));
 
1521
      for (k=0; k<n_word_obs; k++)
 
1522
        arc_f[k] = f[k+lat->arc[n].sf-1];
 
1523
      
 
1524
      /* make state list */
 
1525
      state_seq = next_utt_states(&n_state, lex, inv, mdef, lat->arc[n].word);
 
1526
      
 
1527
      /* viterbi update model parameters */
 
1528
      if (mmi_viterbi_update(arc_f, n_word_obs,
 
1529
                             state_seq, n_state,
 
1530
                             inv,
 
1531
                             a_beam,
 
1532
                             mean_reest,
 
1533
                             var_reest,
 
1534
                             lat->arc[n].gamma,
 
1535
                             fcb) != S3_SUCCESS) {
 
1536
        E_ERROR("arc_%d is ignored (viterbi update failed)\n", n+1);
 
1537
      }
 
1538
      
 
1539
      ckd_free(arc_f);
 
1540
    }
 
1541
  }
 
1542
  
 
1543
  return S3_SUCCESS;
 
1544
}
 
1545
 
 
1546
/* main mmie training program */
 
1547
void
 
1548
main_mmi_reestimate(model_inventory_t *inv,
 
1549
                    lexicon_t *lex,
 
1550
                    model_def_t *mdef,
 
1551
                    feat_t *feat)
 
1552
{
 
1553
  vector_t *mfcc;/* utterance cepstra */
 
1554
  int32 n_frame;/* # of cepstrum frames  */
 
1555
  uint32 svd_n_frame;        /* # of cepstrum frames  */
 
1556
  vector_t **f;/* independent feature streams derived from cepstra */
 
1557
  float32 ***lda = NULL;
 
1558
  uint32 total_frames;        /* # of frames over the corpus */
 
1559
  float64 a_beam;/* alpha pruning beam */
 
1560
  float64 b_beam;/* beta pruning beam */
 
1561
  float32 spthresh;        /* state posterior probability threshold */
 
1562
  uint32 seq_no;/* sequence # of utterance in corpus */
 
1563
  uint32 mean_reest;        /* if TRUE, reestimate means */
 
1564
  uint32 var_reest;        /* if TRUE, reestimate variances */
 
1565
 
 
1566
  const char *lat_dir;        /* lattice directory */
 
1567
  const char *lat_ext;/* denominator or numerator lattice */
 
1568
  const char *mmi_type;/* different methods to get left and right context for Viterbi run on lattice */
 
1569
  uint32 n_mmi_type = 0;/* convert the mmi_type string to a int */
 
1570
  s3lattice_t *lat = NULL;/* input lattice */
 
1571
  float64 total_log_postprob = 0;/* total posterior probability of the correct hypotheses */
 
1572
  uint32 n_utt_fail = 0;        /* number of sentences failed */
 
1573
  uint32 i;
 
1574
 
 
1575
  char *trans;
 
1576
  uint32 in_veclen;
 
1577
  uint32 n_utt;
 
1578
 
 
1579
  uint32 no_retries=0;
 
1580
 
 
1581
  uint32 maxuttlen;
 
1582
  uint32 n_frame_skipped = 0;
 
1583
 
 
1584
  /* get rid of unnecessary arguments */
 
1585
  if (cmd_ln_int32("-2passvar")) {
 
1586
    E_FATAL("for MMIE training, set -2passvar to no\n");
 
1587
  }
 
1588
  if (cmd_ln_int32("-fullvar")) {
 
1589
    E_FATAL("current MMIE training don't support full variance matrix, set -fullvar to no\n");
 
1590
  }
 
1591
  if (cmd_ln_int32("-timing")) {
 
1592
    E_FATAL("current MMIE training don't support timing, set -timing to no\n");
 
1593
  }
 
1594
  if (cmd_ln_int32("-mixwreest")) {
 
1595
    E_FATAL("current MMIE training don't support mixture weight reestimation, set -mixwreest to no\n");
 
1596
  }
 
1597
  if (cmd_ln_int32("-tmatreest")) {
 
1598
    E_FATAL("current MMIE training don't support transition matrix reestimation, set -tmatreest to no\n");
 
1599
  }
 
1600
  if (cmd_ln_int32("-outputfullpath")) {
 
1601
    E_FATAL("current MMIE training don't support outputfullpath, set -outputfullpath to no\n");
 
1602
  }
 
1603
  if (cmd_ln_int32("-fullsuffixmatch")) {
 
1604
    E_FATAL("current MMIE training don't support fullsuffixmatch, set -fullsuffixmatch to no\n");
 
1605
  }
 
1606
  if (cmd_ln_str("-ckptintv")) {
 
1607
    E_FATAL("current MMIE training don't support ckptintv, remove -ckptintv\n");
 
1608
  }
 
1609
  if (cmd_ln_str("-pdumpdir")) {
 
1610
    E_FATAL("current MMIE training don't support pdumpdir, set -pdumpdir to no\n");
 
1611
  }
 
1612
 
 
1613
  /* get lattice related parameters */
 
1614
  lat_dir = cmd_ln_str("-latdir");
 
1615
  lat_ext = cmd_ln_str("-latext");
 
1616
  if (strcmp(lat_ext, "denlat") != 0 && strcmp(lat_ext, "numlat") != 0) {
 
1617
    E_FATAL("-latext should be either denlat or numlat\n");
 
1618
  }
 
1619
  else {
 
1620
    printf("MMIE training for %s \n", lat_ext);
 
1621
  }
 
1622
  mmi_type = cmd_ln_str("-mmie_type");
 
1623
  if (strcmp(mmi_type, "rand") == 0) {
 
1624
    n_mmi_type = 1;
 
1625
    printf("MMIE training: take random left and right context for Viterbi run \n");
 
1626
  }
 
1627
  else if (strcmp(mmi_type, "best") == 0) {
 
1628
    n_mmi_type = 2;
 
1629
    printf("MMIE training: take the best left and right context for Viterbi run \n");
 
1630
  }
 
1631
  else if (strcmp(mmi_type, "ci") == 0) {
 
1632
    printf("MMIE training: use context-independent hmms for boundary word models \n");
 
1633
    n_mmi_type = 3;
 
1634
  }
 
1635
  else {
 
1636
    E_FATAL("-mmie_type should be rand, best or ci\n");
 
1637
  }
 
1638
  lm_scale = cmd_ln_float32("-lw");
 
1639
 
 
1640
  mean_reest = cmd_ln_int32("-meanreest");
 
1641
  var_reest = cmd_ln_int32("-varreest");
 
1642
  in_veclen = cmd_ln_int32("-ceplen");
 
1643
  
 
1644
  /* Read in an LDA matrix for accumulation. */
 
1645
  if (cmd_ln_str("-lda")) {
 
1646
        feat_read_lda(feat, cmd_ln_str("-lda"), 
 
1647
                            cmd_ln_int32("-ldadim"));
 
1648
        lda = feat->lda;
 
1649
  }
 
1650
 
 
1651
  if (cmd_ln_str("-accumdir") == NULL) {
 
1652
    E_WARN("NO ACCUMDIR SET.  No counts will be written; assuming debug\n");
 
1653
    return;
 
1654
  }
 
1655
 
 
1656
  if (!mean_reest && !var_reest) {
 
1657
    E_FATAL("No reestimation specified! Nothing done. Set -meanreest or -varreest \n");
 
1658
    return;
 
1659
  }
 
1660
 
 
1661
  total_frames = 0;
 
1662
 
 
1663
  a_beam = cmd_ln_float64("-abeam");
 
1664
  b_beam = cmd_ln_float64("-bbeam");
 
1665
  spthresh = cmd_ln_float32("-spthresh");
 
1666
  maxuttlen = cmd_ln_int32("-maxuttlen");
 
1667
 
 
1668
  /* Begin by skipping over some (possibly zero) # of utterances.
 
1669
   * Continue to process utterances until there are no more (either EOF
 
1670
   * or end of run). */
 
1671
  seq_no = corpus_get_begin();
 
1672
 
 
1673
  printf("column defns\n");
 
1674
  printf("\t<seq>\n");
 
1675
  printf("\t<id>\n");
 
1676
  printf("\t<n_frame_in>\n");
 
1677
  printf("\t<n_frame_del>\n");
 
1678
  printf("\t<lattice_cat>\n");
 
1679
  printf("\t<n_word>\n");
 
1680
  printf("\t<lattice_log_postprob>\n");
 
1681
 
 
1682
  /* accumulate density for each training sentence */
 
1683
  n_utt = 0;
 
1684
  while (corpus_next_utt()) {
 
1685
    printf("utt> %5u %25s",  seq_no, corpus_utt());
 
1686
    
 
1687
    if (corpus_get_generic_featurevec(&mfcc, &n_frame, in_veclen) < 0) {
 
1688
        E_FATAL("Can't read input features\n");
 
1689
    }
 
1690
    
 
1691
    printf(" %4u", n_frame);
 
1692
      
 
1693
    if (n_frame < 9) {
 
1694
      E_WARN("utt %s too short\n", corpus_utt());
 
1695
      if (mfcc) {
 
1696
        ckd_free(mfcc[0]);
 
1697
        ckd_free(mfcc);
 
1698
      }
 
1699
      continue;
 
1700
    }
 
1701
      
 
1702
    if ((maxuttlen > 0) && (n_frame > maxuttlen)) {
 
1703
      E_INFO("utt # frames > -maxuttlen; skipping\n");
 
1704
      n_frame_skipped += n_frame;
 
1705
      if (mfcc) {
 
1706
        ckd_free(mfcc[0]);
 
1707
        ckd_free(mfcc);
 
1708
      }
 
1709
      continue;
 
1710
    }
 
1711
      
 
1712
  
 
1713
    svd_n_frame = n_frame;
 
1714
      
 
1715
    f = feat_array_alloc(feat, n_frame + feat_window_size(feat));
 
1716
    feat_s2mfc2feat_live(feat, mfcc, &n_frame, TRUE, TRUE, f);
 
1717
      
 
1718
    printf(" %4u", n_frame - svd_n_frame);
 
1719
      
 
1720
    /* Get the transcript */
 
1721
    corpus_get_sent(&trans);
 
1722
 
 
1723
    /* accumulate density counts on lattice */
 
1724
    if (corpus_load_lattice(&lat, lat_dir, lat_ext) == S3_SUCCESS) {
 
1725
      
 
1726
      /* different type of mmie training */
 
1727
      switch (n_mmi_type) {
 
1728
        /* take random left and right context for viterbi run */
 
1729
      case 1:
 
1730
        {
 
1731
          if (mmi_rand_train(inv, mdef, lex, f, lat,
 
1732
                             a_beam, mean_reest,
 
1733
                             var_reest, feat) == S3_SUCCESS) {
 
1734
            total_log_postprob += lat->postprob;
 
1735
            printf("   %e", lat->postprob);
 
1736
          }
 
1737
          else {
 
1738
            n_utt_fail++;
 
1739
          }
 
1740
          break;
 
1741
        }
 
1742
        /* take the best left and right context for viterbi run */
 
1743
      case 2:
 
1744
        {
 
1745
          if (mmi_best_train(inv, mdef, lex, f, lat,
 
1746
                              a_beam, mean_reest,
 
1747
                             var_reest, feat) == S3_SUCCESS) {
 
1748
            total_log_postprob += lat->postprob;
 
1749
            printf("   %e", lat->postprob);
 
1750
          }
 
1751
          else {
 
1752
            n_utt_fail++;
 
1753
          }
 
1754
          break;
 
1755
        }
 
1756
        /* use context-independent hmms for word boundary models */
 
1757
      case 3:
 
1758
        {
 
1759
          if (mmi_ci_train(inv, mdef, lex, f, lat,
 
1760
                           a_beam, mean_reest,
 
1761
                           var_reest, feat) == S3_SUCCESS) {
 
1762
            total_log_postprob += lat->postprob;
 
1763
            printf("   %e", lat->postprob);
 
1764
          }
 
1765
          else {
 
1766
            n_utt_fail++;
 
1767
          }
 
1768
          break;
 
1769
        }
 
1770
        /* mmi_type error */
 
1771
      default:
 
1772
        {
 
1773
          E_FATAL("Invalid -mmie_type, try rand, best or ci \n");
 
1774
          break;
 
1775
        }
 
1776
      }
 
1777
      
 
1778
      /* free memory for lattice */
 
1779
      for(i=0; i<lat->n_arcs; i++) {
 
1780
        ckd_free(lat->arc[i].prev_arcs);
 
1781
        ckd_free(lat->arc[i].next_arcs);
 
1782
      }
 
1783
      ckd_free(lat->arc);
 
1784
      ckd_free(lat);
 
1785
    }
 
1786
    else {
 
1787
      E_WARN("Can't read input lattice");
 
1788
    }
 
1789
    
 
1790
    free(mfcc[0]);
 
1791
    ckd_free(mfcc);
 
1792
    feat_array_free(f);
 
1793
    free(trans);
 
1794
      
 
1795
    seq_no++;
 
1796
    n_utt++;
 
1797
 
 
1798
    printf("\n");
 
1799
  }
 
1800
    
 
1801
  printf ("overall> stats %u (-%u) %e %e",
 
1802
          n_utt-n_utt_fail,
 
1803
          n_utt_fail,
 
1804
          (n_utt-n_utt_fail>0 ? total_log_postprob/(n_utt-n_utt_fail) : 0.0),
 
1805
          total_log_postprob);  
 
1806
  printf("\n");
 
1807
    
 
1808
  no_retries=0;
 
1809
  /* dump the accumulators to a file system */
 
1810
  while (cmd_ln_str("-accumdir") != NULL &&
 
1811
         accum_mmie_dump(cmd_ln_str("-accumdir"),
 
1812
                         lat_ext,
 
1813
                         inv,
 
1814
                         mean_reest,
 
1815
                         var_reest) != S3_SUCCESS) {
 
1816
    static int notified = FALSE;
 
1817
    time_t t;
 
1818
    char time_str[64];
 
1819
      
 
1820
    /*
 
1821
     * If we were not able to dump the parameters, write one log entry
 
1822
     * about the failure
 
1823
     */
 
1824
    if (notified == FALSE) {
 
1825
      t = time(NULL);
 
1826
      strcpy(time_str, (const char *)ctime((const time_t *)&t));
 
1827
      /* nuke the newline at the end of this. */
 
1828
      time_str[strlen(time_str)-1] = '\0';
 
1829
      E_WARN("Count dump failed on %s.  Retrying dump every %3.1f hour until success.\n",
 
1830
             time_str, DUMP_RETRY_PERIOD/3600.0);
 
1831
      
 
1832
      notified = TRUE;
 
1833
      no_retries++;
 
1834
      if(no_retries>10){ 
 
1835
        E_FATAL("Failed to get the files after 10 retries(about 5 minutes).\n ");
 
1836
      }
 
1837
    }
 
1838
    sleep(DUMP_RETRY_PERIOD);
 
1839
  }
 
1840
    
 
1841
  /* Write a log entry on success */
 
1842
  if (cmd_ln_str("-accumdir"))
 
1843
    E_INFO("Counts saved to %s\n", cmd_ln_str("-accumdir"));
 
1844
  else
 
1845
    E_INFO("Counts NOT saved.\n");
 
1846
}
 
1847
 
 
1848
int main(int argc, char *argv[])
 
1849
{
 
1850
    model_inventory_t *inv;
 
1851
    lexicon_t *lex = NULL;
 
1852
    model_def_t *mdef = NULL;
 
1853
    feat_t *feat = NULL;
 
1854
    
 
1855
    if (main_initialize(argc, argv,
 
1856
                        &inv, &lex, &mdef, &feat) != S3_SUCCESS) {
 
1857
        E_FATAL("initialization failed\n");
 
1858
    }
 
1859
 
 
1860
    if (cmd_ln_int32("-mmie")) {
 
1861
      main_mmi_reestimate(inv, lex, mdef, feat);
 
1862
    }
 
1863
    else {
 
1864
      main_reestimate(inv, lex, mdef, feat, cmd_ln_int32("-viterbi"));
 
1865
    }
 
1866
    
 
1867
    if (feat)
 
1868
        feat_free(feat);
 
1869
    if (mdef)
 
1870
        model_def_free(mdef);    
 
1871
    if (inv)
 
1872
        mod_inv_free(inv);
 
1873
    if (lex)
 
1874
        lexicon_free(lex);
 
1875
 
 
1876
    return 0;
 
1877
}