3
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER.
5
# Copyright (c) 2007 Sun Microsystems, Inc. All Rights Reserved.
7
# The contents of this file are subject to the terms of either the GNU Lesser
8
# General Public License Version 2.1 only ("LGPL") or the Common Development and
9
# Distribution License ("CDDL")(collectively, the "License"). You may not use this
10
# file except in compliance with the License. You can obtain a copy of the CDDL at
11
# http://www.opensource.org/licenses/cddl1.php and a copy of the LGPLv2.1 at
12
# http://www.opensource.org/licenses/lgpl-license.php. See the License for the
13
# specific language governing permissions and limitations under the License. When
14
# distributing the software, include this License Header Notice in each file and
15
# include the full text of the License in the License file as well as the
18
# NOTICE PURSUANT TO SECTION 9 OF THE COMMON DEVELOPMENT AND DISTRIBUTION LICENSE
20
# For Covered Software in this distribution, this License shall be governed by the
21
# laws of the State of California (excluding conflict-of-law provisions).
22
# Any litigation relating to this License shall be subject to the jurisdiction of
23
# the Federal Courts of the Northern District of California and the state courts
24
# of the State of California, with venue lying in Santa Clara County, California.
28
# If you wish your version of this file to be governed by only the CDDL or only
29
# the LGPL Version 2.1, indicate your decision by adding "[Contributor]" elects to
30
# include this software in this distribution under the [CDDL or LGPL Version 2.1]
31
# license." If you don't indicate a single choice of license, a recipient has the
32
# option to distribute your version of this file under either the CDDL or the LGPL
33
# Version 2.1, or to extend the choice of license to its licensees as provided
34
# above. However, if you add LGPL Version 2.1 code and therefore, elected the LGPL
35
# Version 2 license, then the option applies only if the new code is made subject
36
# to such option by the copyright holder.
38
# Names exported by a star-import of this module.
# NOTE: 'get_ambiguious_length' is spelled exactly like the function defined
# further down in this file, so the entry must keep that spelling.
__all__ = ['Trie', 'DATrie', 'match_longest', 'get_ambiguious_length']
50
self.root = Trie.TrieNode()
52
def add(self, word, value=1):
56
curr_node = curr_node.trans[ch]
58
curr_node.trans[ch] = Trie.TrieNode()
59
curr_node = curr_node.trans[ch]
63
def walk (self, trienode, ch):
64
if ch in trienode.trans:
65
trienode = trienode.trans[ch]
66
return trienode, trienode.val
class FlexibleList (list):
    """A list that grows on demand instead of raising IndexError.

    Reading or writing an integer index at or beyond the current length
    first pads the list with 0 up to and including that index. Note that
    this makes a plain read (`lst[i]`) a mutating operation when `i` is
    past the end.

    Negative indices are never padded; they behave exactly as on a plain
    list, including raising IndexError when out of range. Slice indices
    are passed straight through to `list` without any padding.
    """

    def __check_size (self, index):
        """Pad the list with 0 so that position `index` exists."""
        # A slice cannot be compared with the length; let `list` handle it
        # with its normal (non-growing) slice semantics.
        if isinstance (index, slice):
            return

        shortfall = index - len(self) + 1
        if shortfall > 0:
            self.extend ([0] * shortfall)

    def __getitem__ (self, index):
        """Return the item at `index`, growing the list first if needed."""
        self.__check_size (index)
        return list.__getitem__(self, index)

    def __setitem__ (self, index, value):
        """Store `value` at `index`, growing the list first if needed."""
        self.__check_size (index)
        return list.__setitem__(self, index, value)
83
def character_based_encoder (ch, range=('a', 'z')):
84
ret = ord(ch) - ord(range[0]) + 1
85
if ret <= 0: ret = ord(range[1]) + 1
88
class DATrie (object):
89
def __init__(self, chr_encoder=character_based_encoder):
91
self.chr_encoder = chr_encoder
95
self.base = FlexibleList ()
96
self.check = FlexibleList ()
97
self.value = FlexibleList ()
99
def walk (self, s, ch):
100
c = self.chr_encoder (ch)
101
t = abs(self.base[s]) + c
103
if t<len(self.check) and self.check[t] == s and self.base[t]:
107
v = -1 if self.base[t] < 0 else 0
112
def find_base (self, s, children, i=1):
113
if s == 0 or not children:
120
k = i + self.chr_encoder (ch)
121
if self.base[k] or self.check[k] or k == s:
123
i += int (log (loop_times, 2)) + 1
130
def build (self, words, values=None):
131
assert (not values or (len(words) == len(values)))
132
itval = iter(values) if values else None
136
trie.add (w, itval.next() if itval else -1)
138
self.construct_from_trie (trie, values!=None)
140
def construct_from_trie (self, trie, with_value=True, progress_cb=None, progress_cb_thr=100):
141
nodes = [(trie.root, 0)]
146
trienode, s = nodes.pop(0)
147
find_from = b = self.find_base (s, trienode.trans, find_from)
148
self.base[s] = -b if trienode.val else b
149
if with_value: self.value[s] = trienode.val
151
for ch in trienode.trans:
152
c = self.chr_encoder (ch)
153
t = abs(self.base[s]) + c
154
self.check[t] = s if s else -1
156
nodes.append ((trienode.trans[ch], t))
159
if loop_times == progress_cb_thr:
164
for i in xrange (self.chr_encoder (max(trie.root.trans))+1):
165
if self.check[i] == -1:
168
def save (self, fname):
169
f = open (fname, 'w+')
172
using_32bits = l > 2**15
173
elm_size = 4 if using_32bits else 2
174
fmt_str = '%di'%l if using_32bits else '%dh'%l
176
# the data types here should be aligned with those in datrie.h
177
f.write (struct.pack ('I', l))
178
f.write (struct.pack ('H', elm_size))
179
f.write (struct.pack ('H', 1 if self.value else 0))
181
f.write (struct.pack (fmt_str, *self.base))
182
f.write (struct.pack (fmt_str, *self.check))
185
if len(self.value) < l: self.value[l-1] = 0
186
f.write (struct.pack ('%di'%l, *self.value))
190
def output_static_c_arrays (self, fname):
191
f = open(fname, 'w+')
194
type = "int" if l > 2**15 else "short"
196
f.write (self.__to_c_array (self.base, type, "base"))
197
f.write (self.__to_c_array (self.check, type, "check"))
198
f.write (self.__to_c_array (self.value, "int", "value"))
def __to_c_array (self, array, type, name):
    """Render `array` as the source text of a static C array definition.

    array -- iterable whose items are converted with str() and joined
             with ', ' to form the initializer list
    type  -- C element type written after `static` (the visible callers
             pass "int" or "short"); the name shadows the builtin but is
             kept because it is part of the signature
    name  -- identifier of the generated C array

    Returns a string of the form `static <type> <name>[] = {...};`
    followed by one blank line.
    """
    elements = ', '.join (str(i) for i in array)
    return "static %s %s[] = {%s};\n\n" % (type, name, elements)
205
def load (self, fname):
206
f = open (fname, 'r')
208
l = struct.unpack ('I', f.read(4))[0]
209
elm_size = struct.unpack ('H', f.read(2))[0]
210
has_value = struct.unpack ('H', f.read(2))[0]
212
fmt_str = '%di'%l if elm_size == 4 else '%dh'%l
213
self.base = struct.unpack (fmt_str, f.read(l*elm_size))
214
self.check = struct.unpack (fmt_str, f.read(l*elm_size))
215
self.value = struct.unpack ('%di'%l, f.read(l*4)) if has_value else []
219
def search (trie, word):
220
curr_node = trie.root
223
curr_node, val = trie.walk (curr_node, ch)
231
def match_longest (trie, word):
232
l = ret_l = ret_v = 0
233
curr_node = trie.root
236
curr_node, val = trie.walk (curr_node, ch)
242
ret_l, ret_v = l, val
246
def get_ambiguious_length (trie, str, word_len):
248
while i < word_len and i < len(str):
249
wid, l = match_longest(trie, str[i:])
256
from pinyin_data import valid_syllables
259
for s in valid_syllables:
260
trie.add (s, valid_syllables[s])
262
for s in valid_syllables:
263
v, l = match_longest (trie, s+'b')
264
assert (len(s) == l and valid_syllables[s] == v)
267
datrie.construct_from_trie (trie)
269
datrie.save ('/tmp/trie_test')
270
datrie.load ('/tmp/trie_test')
272
for s in valid_syllables:
273
v, l = match_longest (datrie, s+'b')
274
assert (len(s) == l and valid_syllables[s] == v)
276
print 'test executed successfully'
278
if __name__ == "__main__":