3
# Copyright (C) 2011-2013 Daiki Ueno <ueno@gnu.org>
4
# Copyright (C) 2011-2013 Red Hat, Inc.
6
# This program is free software: you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation, either version 3 of the License, or
9
# (at your option) any later version.
11
# This program is distributed in the hope that it will be useful,
12
# but WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
# GNU General Public License for more details.
16
# You should have received a copy of the GNU General Public License
17
# along with this program. If not, see <http://www.gnu.org/licenses/>.
24
# Pattern for one N-gram entry line: "<number><ws><words>[<ws><number>]".
#   group 1: leading number (read as the entry's cost)
#   group 2: the words; may contain spaces but no tabs
#   group 3: optional trailing number (read as the backoff), None when absent
# Written as a raw string so the regex engine, not the Python parser,
# interprets the "\t" escapes.
# Caveat: because group 2 is lazy and the separator before group 3 accepts
# a space, an entry without a backoff whose last word is purely numeric
# (e.g. "-1.0\tfoo 3") captures that word as the backoff.
NGRAM_LINE_REGEX = r'^([-0-9.]+)[ \t]+([^\t]+?)(?:[ \t]+([-0-9.]+))?$'
26
# NOTE(review): this copy of the file has lost its indentation and several
# original lines; the bare integers are leftover original line numbers and
# gaps between them mark omitted lines. Restore from upstream before running.
#
# SortedGenerator reads N-gram entries from `infile` (sections introduced by
# "\N-grams" header lines) and writes marisa tries plus packed binary N-gram
# tables to files whose names start with `output_prefix`.
# Python 2 code (print statements, dict.iteritems, sorted(cmp=...)).
class SortedGenerator(object):
# Store the input stream and output prefix and prepare empty containers.
#   infile: readable text stream; entries are pulled with readline().
#   output_prefix: path prefix shared by every file this class writes.
27
def __init__(self, infile, output_prefix):
28
self.__infile = infile
29
self.__output_prefix = output_prefix
30
self.__ngram_line_regex = re.compile(NGRAM_LINE_REGEX)
# One dict per N-gram order (index n - 1 holds the n-grams), each mapping a
# tuple of vocabulary-trie key ids to a (cost, backoff) pair. `x` is unused.
32
self.__ngram_entries = [{} for x in range(0, NGRAM)]
# Keysets collect keys while the unigram section is read; the tries are then
# built from them. The vocab set holds whole tokens; the input set holds the
# part of "a/b" tokens before the "/".
34
self.__vocab_keyset = marisa.Keyset()
35
self.__input_keyset = marisa.Keyset()
37
self.__vocab_trie = marisa.Trie()
38
self.__input_trie = marisa.Trie()
# Fragment of a method whose `def` line is missing from this copy: it logs
# progress while reading, then the smallest cost recorded. Python's %
# operator ignores the "l" length modifier, so "%lf" formats like "%f".
43
print "reading N-grams"
46
print "min cost = %lf" % self.__min_cost
# Build the vocabulary and input tries from the unigram section.
# Reads lines until one starting with "\1-grams", then for each parsed entry
# adds the token (strv[1]) to the vocab keyset and, for tokens other than
# <s>, </s> and <UNK>, adds the part before "/" to the input keyset.
# The loop structure and the assignment of `strv` (presumably
# match.groups()) are on lines missing from this copy — confirm upstream.
48
def __read_tries(self):
50
line = self.__infile.readline()
53
if line.startswith("\\1-grams"):
58
line = self.__infile.readline()
64
match = self.__ngram_line_regex.match(line)
68
self.__vocab_keyset.push_back(strv[1])
69
if not strv[1] in ("<s>", "</s>", "<UNK>"):
# The statement guarded by this test is missing from this copy; presumably
# it skips tokens without a "/" separator — confirm.
70
if "/" not in strv[1]:
# NOTE(review): unpacking into two names raises ValueError when the token
# contains more than one "/"; split("/", 1) would tolerate that. `input`
# shadows the builtin and `output` is unused.
72
(input, output) = strv[1].split("/")
73
self.__input_keyset.push_back(input)
75
self.__vocab_trie.build(self.__vocab_keyset)
76
self.__input_trie.build(self.__input_keyset)
# Parse every N-gram section (orders 1..NGRAM) into self.__ngram_entries.
# Each entry's words are mapped to vocabulary-trie key ids; (cost, backoff)
# is stored under the tuple of ids, and the smallest cost other than -99 is
# tracked in self.__min_cost. `ids`, `word` and `strv` come from lines
# missing from this copy.
# NOTE(review): __read_tries has already read past the unigram section of the
# same stream; presumably the stream is rewound on a missing line — confirm.
78
def __read_ngrams(self):
80
for n in range(1, NGRAM + 1):
82
line = self.__infile.readline()
85
if line.startswith("\\%s-grams:" % n):
89
line = self.__infile.readline()
95
match = self.__ngram_line_regex.match(line)
99
ngram = strv[1].split(" ")
102
agent = marisa.Agent()
103
agent.set_query(word)
# The statement run for words absent from the vocab trie is missing from
# this copy; if it only skips the word, the id tuple stored for this entry
# would be shorter than n — confirm upstream.
104
if not self.__vocab_trie.lookup(agent):
106
ids.append(agent.key_id())
107
cost = float(strv[0])
# Costs equal to -99 are excluded from the running minimum.
108
if cost != -99 and cost < self.__min_cost:
109
self.__min_cost = cost
# NOTE(review): if strv is match.groups(), strv[2] is None for entries with
# no backoff (the regex's third group is optional) and float(None) raises
# TypeError; presumably the missing lines before this guard it — confirm.
112
backoff = float(strv[2])
113
self.__ngram_entries[n - 1][tuple(ids)] = (cost, backoff)
# Fragment of a method whose `def` line is missing from this copy: it pins
# the quantization scale to a fixed minimum cost of -8.0, replacing the
# minimum found while reading, then writes the N-gram tables.
116
self.__min_cost = -8.0
118
self.__write_ngrams()
# Save the vocabulary trie to "<output_prefix>.1gram.index" and the input
# trie to "<output_prefix>.input".
120
def __write_tries(self):
121
self.__vocab_trie.save(self.__output_prefix + ".1gram.index")
122
self.__input_trie.save(self.__output_prefix + ".input")
# Write the 1-, 2- and (if any) 3-gram tables as packed binary records.
# Offsets recorded while writing a lower order let higher-order records refer
# back to it. `offset`, `unigram_offsets` and `bigram_offsets` are set up on
# lines missing from this copy.
124
def __write_ngrams(self):
# Map a cost onto 0..65535 as int(cost * 65535 / min_cost), clamped at both
# ends: a cost/min_cost ratio above 1 saturates at 65535 and a negative ratio
# becomes 0. Raises ZeroDivisionError if min_cost is 0.
125
def quantize(cost, min_cost):
126
return max(0, min(65535, int(cost * 65535 / min_cost)))
# Order (header, ids) items by their packed header bytes. Python 2 only:
# the cmp builtin and sorted(cmp=...) do not exist in Python 3, where
# key=lambda item: item[0] yields the same order.
128
def cmp_header(a, b):
129
return cmp(a[0], b[0])
131
print "writing 1-gram file"
# NOTE(review): no close() for this file (or the 2-/3-gram files) is visible
# in this copy; confirm they are closed, or use `with`.
133
unigram_file = open("%s.1gram" % self.__output_prefix, "wb")
# Unigrams are written in id-tuple order. Each record is "=HHH" (three
# native-order unsigned 16-bit values, no padding): quantized cost, quantized
# backoff, and a third value on a line missing from this copy.
135
for ids, value in sorted(self.__ngram_entries[0].iteritems()):
136
unigram_offsets[ids[0]] = offset
137
s = struct.pack("=HHH",
138
quantize(value[0], self.__min_cost),
139
quantize(value[1], self.__min_cost),
142
unigram_file.write(s)
146
print "writing 2-gram file"
148
bigram_file = open("%s.2gram" % self.__output_prefix, "wb")
# Each bigram record is an "=LL" header (second word id, offset recorded for
# the first word's unigram) followed by "=HH" (quantized cost, backoff).
# unigram_offsets[...] raises KeyError if the first word has no unigram.
# Records are sorted by the packed header bytes; with native byte order that
# is byte-wise, not numeric, ordering on little-endian hosts — presumably the
# reader compares the same bytes; confirm.
149
keys = self.__ngram_entries[1].keys()
150
items = [(struct.pack("=LL", ids[1], unigram_offsets[ids[0]]), ids) for ids in keys]
152
for header, ids in sorted(items, cmp=cmp_header):
153
value = self.__ngram_entries[1][ids]
154
bigram_offsets[ids] = offset
155
s = struct.pack("=HH",
156
quantize(value[0], self.__min_cost),
157
quantize(value[1], self.__min_cost))
158
bigram_file.write(header + s)
# The 3-gram table is written only when trigrams were read. Each record is an
# "=LL" header (third word id, offset recorded for the leading bigram)
# followed by "=H" (quantized cost only; no backoff is stored).
# bigram_offsets[...] raises KeyError if the leading bigram was not read.
162
if len(self.__ngram_entries[2]) > 0:
163
print "writing 3-gram file"
164
trigram_file = open("%s.3gram" % self.__output_prefix, "wb")
165
keys = self.__ngram_entries[2].keys()
166
items = [(struct.pack("=LL", ids[2], bigram_offsets[(ids[0], ids[1])]), ids) for ids in keys]
167
for header, ids in sorted(items, cmp=cmp_header):
168
value = self.__ngram_entries[2][ids]
169
s = struct.pack("=H",
170
quantize(value[0], self.__min_cost))
171
trigram_file.write(header + s)
174
# Command-line entry point: parse the model file and output prefix, then
# construct the generator. Statements after the constructor call are not
# visible in this copy (indentation and some lines are also lost).
if __name__ == '__main__':
178
parser = argparse.ArgumentParser(description='sortlm')
# NOTE(review): `infile` is optional (nargs='?') but has no default, so when
# only OUTPUT_PREFIX is given argparse sets args.infile to None and the
# generator's first readline() would fail; default=sys.stdin was presumably
# intended — confirm.
179
parser.add_argument('infile', nargs='?', type=argparse.FileType('r'),
181
help='language model file')
182
parser.add_argument('output_prefix', metavar='OUTPUT_PREFIX', type=str,
183
help='output file prefix')
184
args = parser.parse_args()
186
generator = SortedGenerator(args.infile, args.output_prefix)