~cosmos-door/+junk/libkkc-data

« back to all changes in this revision

Viewing changes to tools/sortlm.py

  • Committer: Mitsuya Shibata
  • Date: 2013-07-06 16:06:31 UTC
  • Revision ID: mty.shibata@gmail.com-20130706160631-rpwsfk1k5fvznehm
Initial commit of Debian packaging.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#!/usr/bin/python
 
2
 
 
3
# Copyright (C) 2011-2013 Daiki Ueno <ueno@gnu.org>
 
4
# Copyright (C) 2011-2013 Red Hat, Inc.
 
5
 
 
6
# This program is free software: you can redistribute it and/or modify
 
7
# it under the terms of the GNU General Public License as published by
 
8
# the Free Software Foundation, either version 3 of the License, or
 
9
# (at your option) any later version.
 
10
 
 
11
# This program is distributed in the hope that it will be useful,
 
12
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
14
# GNU General Public License for more details.
 
15
 
 
16
# You should have received a copy of the GNU General Public License
 
17
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
18
 
 
19
import struct
 
20
import marisa
 
21
import re
 
22
 
 
23
# Highest n-gram order handled (unigrams, bigrams, trigrams).
NGRAM = 3

# One ARPA n-gram line: "<log10-cost> <token(s)> [<backoff>]".
# Group 1: cost, group 2: the token field (tab-free, non-greedy),
# group 3 (optional): backoff weight.  Raw string so the \t escapes
# are passed through to the regex engine verbatim.
NGRAM_LINE_REGEX = r'^([-0-9.]+)[ \t]+([^\t]+?)(?:[ \t]+([-0-9.]+))?$'
 
25
 
 
26
class SortedGenerator(object):
    """Convert an ARPA-format N-gram language model into sorted binary files.

    read() scans the ARPA file twice: first to build marisa tries over the
    vocabulary tokens and their input (reading) halves, then to collect every
    n-gram's (cost, backoff) pair keyed by tuples of vocab-trie key ids.
    write() emits:

        <prefix>.1gram.index, <prefix>.input   -- the two saved tries
        <prefix>.1gram/.2gram/.3gram           -- fixed-width sorted records

    Ported to run on Python 3 (print function, dict.items, key-based sort)
    while remaining valid Python 2; observable behavior is unchanged.
    """

    def __init__(self, infile, output_prefix):
        """infile: readable file object holding the ARPA model;
        output_prefix: path prefix for every file written by write()."""
        self.__infile = infile
        self.__output_prefix = output_prefix
        self.__ngram_line_regex = re.compile(NGRAM_LINE_REGEX)

        # One {key-id tuple: (cost, backoff)} mapping per n-gram order.
        self.__ngram_entries = [{} for _ in range(NGRAM)]

        self.__vocab_keyset = marisa.Keyset()
        self.__input_keyset = marisa.Keyset()

        self.__vocab_trie = marisa.Trie()
        self.__input_trie = marisa.Trie()

        self.__min_cost = 0.0

    def read(self):
        """Read the ARPA file: build the tries, then load all n-grams."""
        print("reading N-grams")
        self.__read_tries()
        self.__read_ngrams()
        print("min cost = %lf" % self.__min_cost)

    def __read_tries(self):
        # Skip the header up to the start of the unigram section.
        while True:
            line = self.__infile.readline()
            if line == "":
                break
            if line.startswith("\\1-grams"):
                break

        # Every unigram token goes into the vocab trie; tokens shaped
        # "<reading>/<surface>" additionally contribute their reading to
        # the input trie.  Sentinels (<s>, </s>, <UNK>) are vocab-only.
        while True:
            line = self.__infile.readline()
            if line == "":
                break
            line = line.strip()
            if line == "":
                # A blank line terminates the unigram section.
                break
            match = self.__ngram_line_regex.match(line)
            if not match:
                continue
            strv = match.groups()
            self.__vocab_keyset.push_back(strv[1])
            if strv[1] not in ("<s>", "</s>", "<UNK>"):
                if "/" not in strv[1]:
                    continue
                # Strict two-way unpack, as in the original: a token with
                # more than one "/" raises ValueError (malformed input).
                reading, _surface = strv[1].split("/")
                self.__input_keyset.push_back(reading)

        self.__vocab_trie.build(self.__vocab_keyset)
        self.__input_trie.build(self.__input_keyset)

    def __read_ngrams(self):
        # Second pass: translate every n-gram into a tuple of vocab-trie
        # key ids and remember its (cost, backoff) pair.
        self.__infile.seek(0)
        for n in range(1, NGRAM + 1):
            # Find the "\<n>-grams:" section header.
            while True:
                line = self.__infile.readline()
                if line == "":
                    break
                if line.startswith("\\%s-grams:" % n):
                    break

            while True:
                line = self.__infile.readline()
                if line == "":
                    break
                line = line.strip()
                if line == "":
                    break
                match = self.__ngram_line_regex.match(line)
                if not match:
                    continue
                strv = match.groups()
                ids = []
                for word in strv[1].split(" "):
                    agent = marisa.Agent()
                    agent.set_query(word)
                    if not self.__vocab_trie.lookup(agent):
                        # NOTE(review): a word missing from the trie is
                        # skipped, silently shortening the id tuple --
                        # behavior preserved from the original.
                        continue
                    ids.append(agent.key_id())
                cost = float(strv[0])
                # -99 is the conventional ARPA "no probability" marker;
                # it must not drag the minimum cost down.
                if cost != -99 and cost < self.__min_cost:
                    self.__min_cost = cost
                backoff = float(strv[2]) if strv[2] else 0.0
                self.__ngram_entries[n - 1][tuple(ids)] = (cost, backoff)

    def write(self):
        """Write the tries and the binary n-gram files."""
        # The quantization range is pinned to -8.0 rather than the
        # minimum observed by read(), so output files use a fixed scale.
        self.__min_cost = -8.0
        self.__write_tries()
        self.__write_ngrams()

    def __write_tries(self):
        # Persist both tries next to the binary n-gram files.
        self.__vocab_trie.save(self.__output_prefix + ".1gram.index")
        self.__input_trie.save(self.__output_prefix + ".input")

    def __write_ngrams(self):
        def quantize(cost, min_cost):
            # Map cost in [min_cost, 0] linearly onto [0, 65535], clamped.
            return max(0, min(65535, int(cost * 65535 / min_cost)))

        def header_key(item):
            # Sort records by their packed binary header only.  Headers are
            # unique (one offset per id), so this matches the original
            # cmp-based sort exactly.
            return item[0]

        print("writing 1-gram file")
        unigram_offsets = {}
        offset = 0
        # One record per unigram in key-id order:
        # (cost, backoff, reserved) as three uint16s.
        with open("%s.1gram" % self.__output_prefix, "wb") as unigram_file:
            for ids, value in sorted(self.__ngram_entries[0].items()):
                unigram_offsets[ids[0]] = offset
                unigram_file.write(struct.pack(
                    "=HHH",
                    quantize(value[0], self.__min_cost),
                    quantize(value[1], self.__min_cost),
                    0))  # reserved
                offset += 1

        print("writing 2-gram file")
        bigram_offsets = {}
        offset = 0
        # Header packs (last word id, parent unigram offset) as two uint32s;
        # the payload is (cost, backoff) as two uint16s.
        entries = [(struct.pack("=LL", ids[1], unigram_offsets[ids[0]]), ids)
                   for ids in self.__ngram_entries[1]]
        with open("%s.2gram" % self.__output_prefix, "wb") as bigram_file:
            for header, ids in sorted(entries, key=header_key):
                value = self.__ngram_entries[1][ids]
                bigram_offsets[ids] = offset
                bigram_file.write(header + struct.pack(
                    "=HH",
                    quantize(value[0], self.__min_cost),
                    quantize(value[1], self.__min_cost)))
                offset += 1

        if len(self.__ngram_entries[2]) > 0:
            print("writing 3-gram file")
            # Header packs (last word id, parent bigram offset); the
            # payload is the quantized cost alone (no trigram backoff).
            entries = [(struct.pack("=LL", ids[2],
                                    bigram_offsets[(ids[0], ids[1])]), ids)
                       for ids in self.__ngram_entries[2]]
            with open("%s.3gram" % self.__output_prefix, "wb") as trigram_file:
                for header, ids in sorted(entries, key=header_key):
                    value = self.__ngram_entries[2][ids]
                    trigram_file.write(header + struct.pack(
                        "=H",
                        quantize(value[0], self.__min_cost)))
 
173
 
 
174
if __name__ == '__main__':
    import sys
    import argparse

    # Command-line driver: read the ARPA model from INFILE (or stdin)
    # and write the binary files under OUTPUT_PREFIX.
    arg_parser = argparse.ArgumentParser(description='sortlm')
    arg_parser.add_argument('infile', nargs='?', type=argparse.FileType('r'),
                            default=sys.stdin,
                            help='language model file')
    arg_parser.add_argument('output_prefix', metavar='OUTPUT_PREFIX', type=str,
                            help='output file prefix')
    cli_args = arg_parser.parse_args()

    sorter = SortedGenerator(cli_args.infile, cli_args.output_prefix)
    sorter.read()
    sorter.write()