~ubuntu-branches/ubuntu/trusty/sunpinyin/trusty-proposed

« back to all changes in this revision

Viewing changes to python/trie.py

  • Committer: Bazaar Package Importer
  • Author(s): Zhengpeng Hou
  • Date: 2010-09-06 12:23:46 UTC
  • Revision ID: james.westby@ubuntu.com-20100906122346-yamofztk2j5p85fs
Tags: upstream-2.0.2
ImportĀ upstreamĀ versionĀ 2.0.2

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#!/usr/bin/python
 
2
 
 
3
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER.
 
4
 
5
# Copyright (c) 2007 Sun Microsystems, Inc. All Rights Reserved.
 
6
 
7
# The contents of this file are subject to the terms of either the GNU Lesser
 
8
# General Public License Version 2.1 only ("LGPL") or the Common Development and
 
9
# Distribution License ("CDDL")(collectively, the "License"). You may not use this
 
10
# file except in compliance with the License. You can obtain a copy of the CDDL at
 
11
# http://www.opensource.org/licenses/cddl1.php and a copy of the LGPLv2.1 at
 
12
# http://www.opensource.org/licenses/lgpl-license.php. See the License for the 
 
13
# specific language governing permissions and limitations under the License. When
 
14
# distributing the software, include this License Header Notice in each file and
 
15
# include the full text of the License in the License file as well as the
 
16
# following notice:
 
17
 
18
# NOTICE PURSUANT TO SECTION 9 OF THE COMMON DEVELOPMENT AND DISTRIBUTION LICENSE
 
19
# (CDDL)
 
20
# For Covered Software in this distribution, this License shall be governed by the
 
21
# laws of the State of California (excluding conflict-of-law provisions).
 
22
# Any litigation relating to this License shall be subject to the jurisdiction of
 
23
# the Federal Courts of the Northern District of California and the state courts
 
24
# of the State of California, with venue lying in Santa Clara County, California.
 
25
 
26
# Contributor(s):
 
27
 
28
# If you wish your version of this file to be governed by only the CDDL or only
 
29
# the LGPL Version 2.1, indicate your decision by adding "[Contributor]" elects to
 
30
# include this software in this distribution under the [CDDL or LGPL Version 2.1]
 
31
# license." If you don't indicate a single choice of license, a recipient has the
 
32
# option to distribute your version of this file under either the CDDL or the LGPL
 
33
# Version 2.1, or to extend the choice of license to its licensees as provided
 
34
# above. However, if you add LGPL Version 2.1 code and therefore, elected the LGPL
 
35
# Version 2 license, then the option applies only if the new code is made subject
 
36
# to such option by the copyright holder. 
 
37
 
 
38
__all__ = ['Trie', 'DATrie', 'match_longest', 'get_ambiguious_length']
 
39
 
 
40
from math import log
 
41
import struct
 
42
 
 
43
class Trie (object):
 
44
    class TrieNode:
 
45
        def __init__ (self):
 
46
            self.val = 0
 
47
            self.trans = {}
 
48
 
 
49
    def __init__(self):
 
50
        self.root = Trie.TrieNode()
 
51
 
 
52
    def add(self, word, value=1):
 
53
        curr_node = self.root
 
54
        for ch in word:
 
55
            try: 
 
56
                curr_node = curr_node.trans[ch]
 
57
            except:
 
58
                curr_node.trans[ch] = Trie.TrieNode()
 
59
                curr_node = curr_node.trans[ch]
 
60
 
 
61
        curr_node.val = value
 
62
 
 
63
    def walk (self, trienode, ch):
 
64
        if ch in trienode.trans:
 
65
            trienode = trienode.trans[ch]
 
66
            return trienode, trienode.val
 
67
        else:
 
68
            return None, 0
 
69
 
 
70
class FlexibleList (list):
 
71
    def __check_size (self, index):
 
72
        if index >= len(self):
 
73
            self.extend ([0] * (index-len(self)+1))
 
74
 
 
75
    def __getitem__ (self, index):
 
76
        self.__check_size (index)
 
77
        return list.__getitem__(self, index)
 
78
 
 
79
    def __setitem__ (self, index, value):
 
80
        self.__check_size (index)
 
81
        return list.__setitem__(self, index, value)
 
82
 
 
83
def character_based_encoder (ch, range=('a', 'z')):
 
84
    ret = ord(ch) - ord(range[0]) + 1
 
85
    if ret <= 0: ret = ord(range[1]) + 1
 
86
    return ret
 
87
 
 
88
class DATrie (object):
 
89
    def __init__(self, chr_encoder=character_based_encoder):
 
90
        self.root = 0
 
91
        self.chr_encoder = chr_encoder
 
92
        self.clear()
 
93
 
 
94
    def clear (self):
 
95
        self.base  = FlexibleList ()
 
96
        self.check = FlexibleList ()
 
97
        self.value = FlexibleList ()
 
98
 
 
99
    def walk (self, s, ch):
 
100
        c = self.chr_encoder (ch)
 
101
        t = abs(self.base[s]) + c
 
102
 
 
103
        if t<len(self.check) and self.check[t] == s and self.base[t]:
 
104
            if self.value: 
 
105
                v = self.value[t]
 
106
            else: 
 
107
                v = -1 if self.base[t] < 0 else 0
 
108
            return t, v
 
109
        else:
 
110
            return 0, 0
 
111
 
 
112
    def find_base (self, s, children, i=1):
 
113
        if s == 0 or not children:
 
114
            return s
 
115
 
 
116
        i = max (i, 1)
 
117
        loop_times = 0
 
118
        while True:
 
119
            for ch in children:
 
120
                k = i + self.chr_encoder (ch)
 
121
                if self.base[k] or self.check[k] or k == s:
 
122
                    loop_times += 1
 
123
                    i += int (log (loop_times, 2)) + 1
 
124
                    break
 
125
            else:
 
126
                break
 
127
 
 
128
        return i
 
129
 
 
130
    def build (self, words, values=None):
 
131
        assert (not values or (len(words) == len(values)))
 
132
        itval = iter(values) if values else None
 
133
 
 
134
        trie = Trie()
 
135
        for w in words:
 
136
            trie.add (w, itval.next() if itval else -1)
 
137
 
 
138
        self.construct_from_trie (trie, values!=None)
 
139
 
 
140
    def construct_from_trie (self, trie, with_value=True, progress_cb=None, progress_cb_thr=100):
 
141
        nodes = [(trie.root, 0)]
 
142
        find_from = 1
 
143
        loop_times = 0
 
144
 
 
145
        while nodes:
 
146
            trienode, s = nodes.pop(0)
 
147
            find_from = b = self.find_base (s, trienode.trans, find_from)
 
148
            self.base[s] = -b if trienode.val else b
 
149
            if with_value: self.value[s] = trienode.val
 
150
 
 
151
            for ch in trienode.trans:
 
152
                c = self.chr_encoder (ch)
 
153
                t = abs(self.base[s]) + c
 
154
                self.check[t] = s if s else -1
 
155
 
 
156
                nodes.append ((trienode.trans[ch], t))
 
157
 
 
158
            loop_times += 1
 
159
            if loop_times == progress_cb_thr:
 
160
                loop_times = 0
 
161
                if progress_cb:
 
162
                    progress_cb ()
 
163
 
 
164
        for i in xrange (self.chr_encoder (max(trie.root.trans))+1):
 
165
            if self.check[i] == -1:
 
166
                self.check[i] = 0
 
167
 
 
168
    def save (self, fname):
 
169
        f = open (fname, 'w+')
 
170
        l = len (self.base)
 
171
 
 
172
        using_32bits = l > 2**15
 
173
        elm_size = 4 if using_32bits else 2
 
174
        fmt_str = '%di'%l if using_32bits else '%dh'%l
 
175
 
 
176
        # the data types here should be aligned with those in datrie.h
 
177
        f.write (struct.pack ('I', l))
 
178
        f.write (struct.pack ('H', elm_size))
 
179
        f.write (struct.pack ('H', 1 if self.value else 0))
 
180
 
 
181
        f.write (struct.pack (fmt_str, *self.base))
 
182
        f.write (struct.pack (fmt_str, *self.check))
 
183
 
 
184
        if self.value:
 
185
            if len(self.value) < l: self.value[l-1] = 0
 
186
            f.write (struct.pack ('%di'%l, *self.value))
 
187
 
 
188
        f.close()
 
189
 
 
190
    def output_static_c_arrays (self, fname):
 
191
        f = open(fname, 'w+')
 
192
        l = len (self.base)
 
193
 
 
194
        type = "int" if l > 2**15 else "short"
 
195
 
 
196
        f.write (self.__to_c_array (self.base,  type,  "base"))
 
197
        f.write (self.__to_c_array (self.check, type,  "check"))
 
198
        f.write (self.__to_c_array (self.value, "int", "value"))
 
199
 
 
200
        f.close()
 
201
 
 
202
    def __to_c_array (self, array, type, name):
 
203
        return "static %s %s[] = {%s};\n\n" % (type, name, ', '.join (str(i) for i in array))
 
204
 
 
205
    def load (self, fname):
 
206
        f = open (fname, 'r')
 
207
 
 
208
        l = struct.unpack ('I', f.read(4))[0]
 
209
        elm_size = struct.unpack ('H', f.read(2))[0]
 
210
        has_value = struct.unpack ('H', f.read(2))[0]
 
211
 
 
212
        fmt_str = '%di'%l if elm_size == 4 else '%dh'%l
 
213
        self.base  = struct.unpack (fmt_str, f.read(l*elm_size))
 
214
        self.check = struct.unpack (fmt_str, f.read(l*elm_size))
 
215
        self.value = struct.unpack ('%di'%l, f.read(l*4)) if has_value else []
 
216
 
 
217
        f.close()
 
218
 
 
219
def search (trie, word):
 
220
    curr_node = trie.root
 
221
 
 
222
    for ch in word:
 
223
        curr_node, val = trie.walk (curr_node, ch)
 
224
        if not curr_node: 
 
225
            break
 
226
    else:
 
227
        return val
 
228
 
 
229
    return 0
 
230
 
 
231
def match_longest (trie, word):
 
232
    l = ret_l = ret_v = 0
 
233
    curr_node = trie.root
 
234
 
 
235
    for ch in word:
 
236
        curr_node, val = trie.walk (curr_node, ch)
 
237
        if not curr_node: 
 
238
            break
 
239
 
 
240
        l += 1
 
241
        if val: 
 
242
            ret_l, ret_v = l, val
 
243
 
 
244
    return ret_v, ret_l
 
245
 
 
246
def get_ambiguious_length (trie, str, word_len):
 
247
    i = 1
 
248
    while i < word_len and i < len(str):
 
249
        wid, l = match_longest(trie, str[i:])
 
250
        if word_len < i + l:
 
251
            word_len = i + l
 
252
        i += 1
 
253
    return i
 
254
 
 
255
def test ():
 
256
    from pinyin_data import valid_syllables
 
257
 
 
258
    trie = Trie()
 
259
    for s in valid_syllables:
 
260
        trie.add (s, valid_syllables[s])
 
261
 
 
262
    for s in valid_syllables:
 
263
        v, l = match_longest (trie, s+'b')
 
264
        assert (len(s) == l and valid_syllables[s] == v)
 
265
 
 
266
    datrie = DATrie()
 
267
    datrie.construct_from_trie (trie)
 
268
 
 
269
    datrie.save ('/tmp/trie_test')
 
270
    datrie.load ('/tmp/trie_test')
 
271
 
 
272
    for s in valid_syllables:
 
273
        v, l = match_longest (datrie, s+'b')
 
274
        assert (len(s) == l and valid_syllables[s] == v)
 
275
 
 
276
    print 'test executed successfully'
 
277
 
 
278
if __name__ == "__main__":
 
279
    test ()