~ubuntu-branches/ubuntu/utopic/sphinxtrain/utopic

« back to all changes in this revision

Viewing changes to python/cmusphinx/lattice_prune.py

  • Committer: Package Import Robot
  • Author(s): Samuel Thibault
  • Date: 2013-01-02 04:10:21 UTC
  • Revision ID: package-import@ubuntu.com-20130102041021-ynsizmz33fx02hea
Tags: upstream-1.0.8
ImportĀ upstreamĀ versionĀ 1.0.8

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#!/usr/bin/env python
 
2
 
 
3
import os
 
4
import sys
 
5
import lattice
 
6
import sphinxbase
 
7
 
 
8
 
 
9
 
 
10
if __name__ == '__main__':
 
11
    if len(sys.argv) != 11:
 
12
        sys.stderr.write("Usage: %s ABEAM NBEAM LMWEIGHT LMFILE DENLATDIR PRUNED_DENLATDIR FILELIST TRANSFILE FILECOUNT FILEOFFSET\n" % (sys.argv[0]))
 
13
        sys.exit(1)
 
14
 
 
15
    # print command line
 
16
    command = ''
 
17
    for argv in sys.argv:
 
18
        command += argv + ' '
 
19
    print "%s\n" % command
 
20
 
 
21
    abeam, nbeam, lw, lmfile, denlatdir, pruned_denlatdir, ctlfile, transfile, filecount, fileoffset = sys.argv[1:]
 
22
 
 
23
    abeam = float(abeam)
 
24
    nbeam = float(nbeam)
 
25
    lw = float(lw)
 
26
    start = int(fileoffset)
 
27
    end = int(fileoffset) + int(filecount)
 
28
 
 
29
    # load language model
 
30
    lm = sphinxbase.NGramModel(lmfile)
 
31
 
 
32
    # read control file
 
33
    f = open(ctlfile, 'r')
 
34
    ctl = f.readlines()
 
35
    f.close()
 
36
 
 
37
    # read transcription file
 
38
    f = open(transfile, 'r')
 
39
    ref = f.readlines()
 
40
    f.close()
 
41
 
 
42
    sentcount = 0
 
43
    wer = 0
 
44
    nodecount = 0
 
45
    edgecount = 0
 
46
    density = 0
 
47
    # prune lattices one by one
 
48
    for i in range(start, end):
 
49
        c = ctl[i].strip()
 
50
        r = ref[i].split()
 
51
        del r[-1]
 
52
        if r[0] != '<s>': r.insert(0, '<s>')
 
53
        if r[-1] != '</s>': r.append('</s>')
 
54
        r = filter(lambda x: not lattice.is_filler(x), r)
 
55
 
 
56
        print "process sent: %s" % c
 
57
        
 
58
        # load lattice
 
59
        print "\t load lattice ..."
 
60
        dag = lattice.Dag(os.path.join(denlatdir, c + ".lat.gz"))
 
61
        dag.bypass_fillers()
 
62
        dag.remove_unreachable()
 
63
 
 
64
        # prune lattice
 
65
        dag.edges_unigram_score(lm,lw)
 
66
        dag.dt_posterior()
 
67
 
 
68
        # edge pruning
 
69
        print "\t edge pruning ..."
 
70
        dag.forward_edge_prune(abeam)
 
71
        dag.backward_edge_prune(abeam)
 
72
        dag.remove_unreachable()
 
73
 
 
74
        # node pruning
 
75
        print "\t node pruning ..."
 
76
        dag.post_node_prune(nbeam)
 
77
        dag.remove_unreachable()
 
78
 
 
79
        # calculate error
 
80
        err, bt = dag.minimum_error(r)
 
81
 
 
82
        # save pruned lattice
 
83
        print "\t saving pruned lattice ...\n"
 
84
        dag.dag2sphinx(os.path.join(pruned_denlatdir, c + ".lat.gz"))
 
85
 
 
86
        sentcount += 1
 
87
        nodecount += dag.n_nodes()
 
88
        edgecount += dag.n_edges()
 
89
        wer += float(err) / len(r)
 
90
        density += float(dag.n_edges())/len(r)
 
91
 
 
92
    print "Average Lattice Word Error Rate: %.2f%%" % (wer / sentcount * 100)
 
93
    print "Average Lattice Density: %.2f" % (float(density) / sentcount)
 
94
    print "Average Number of Node: %.2f" % (float(nodecount) / sentcount)
 
95
    print "Average Number of Arc: %.2f" % (float(edgecount) / sentcount)
 
96
    print "ALL DONE"