~barry/ubuntu/raring/python-whoosh/hg1423

« back to all changes in this revision

Viewing changes to src/whoosh/support/dawg.py

  • Committer: Barry Warsaw
  • Date: 2013-01-23 16:36:20 UTC
  • mfrom: (1.2.20)
  • Revision ID: barry@python.org-20130123163620-wmrpb5uhvx68bo4x
* Pull from upstream Mercurial r1423 for Python 3.3 support.
* d/control:
  - Add B-D and B-D-I on python3-* packages.
  - Added X-Python3-Version: >= 3.2
  - Added python3-whoosh binary package.
* d/patches, d/patches/fix-setup.patch: Fix typo in setup.py and remove
  --pep8 flag from [pytest] section of setup.cfg since it doesn't work.
* d/*.install: Added python3-whoosh.install and updated paths.
* d/rules:
  - Add appropriate targets for Python 3 build.
  - Add get-{packaged-}orig-source for grabbing from upstream Mercurial.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright 2009 Matt Chaput. All rights reserved.
2
 
#
3
 
# Redistribution and use in source and binary forms, with or without
4
 
# modification, are permitted provided that the following conditions are met:
5
 
#
6
 
#    1. Redistributions of source code must retain the above copyright notice,
7
 
#       this list of conditions and the following disclaimer.
8
 
#
9
 
#    2. Redistributions in binary form must reproduce the above copyright
10
 
#       notice, this list of conditions and the following disclaimer in the
11
 
#       documentation and/or other materials provided with the distribution.
12
 
#
13
 
# THIS SOFTWARE IS PROVIDED BY MATT CHAPUT ``AS IS'' AND ANY EXPRESS OR
14
 
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
15
 
# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
16
 
# EVENT SHALL MATT CHAPUT OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
17
 
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
18
 
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
19
 
# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
20
 
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
21
 
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
22
 
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23
 
#
24
 
# The views and conclusions contained in the software and documentation are
25
 
# those of the authors and should not be interpreted as representing official
26
 
# policies, either expressed or implied, of Matt Chaput.
27
 
 
28
 
"""
29
 
This module implements an FST/FSA writer and reader. An FST (Finite State
30
 
Transducer) stores a directed acyclic graph with values associated with the
31
 
leaves. Common elements of the values are pushed inside the tree. An FST that
32
 
does not store values is a regular FSA.
33
 
 
34
 
The format of the leaf values is pluggable using subclasses of the Values
35
 
class.
36
 
 
37
 
Whoosh uses these structures to store a directed acyclic word graph (DAWG) for
38
 
use in (at least) spell checking.
39
 
"""
40
 
 
41
 
 
42
 
import sys, copy
43
 
from array import array
44
 
from hashlib import sha1  # @UnresolvedImport
45
 
 
46
 
from whoosh.compat import (b, u, BytesIO, xrange, iteritems, iterkeys,
47
 
                           bytes_type, text_type, izip, array_tobytes)
48
 
from whoosh.filedb.structfile import StructFile
49
 
from whoosh.system import (_INT_SIZE, pack_byte, pack_int, pack_uint,
50
 
                           pack_long, emptybytes)
51
 
from whoosh.util import utf8encode, utf8decode, varint
52
 
 
53
 
 
54
 
class FileVersionError(Exception):
55
 
    pass
56
 
 
57
 
 
58
 
class InactiveCursor(Exception):
59
 
    pass
60
 
 
61
 
 
62
 
ARC_LAST = 1
63
 
ARC_ACCEPT = 2
64
 
ARC_STOP = 4
65
 
ARC_HAS_VAL = 8
66
 
ARC_HAS_ACCEPT_VAL = 16
67
 
MULTIBYTE_LABEL = 32
68
 
 
69
 
 
70
 
# FST Value types
71
 
 
72
 
class Values(object):
73
 
    """Base for classes the describe how to encode and decode FST values.
74
 
    """
75
 
 
76
 
    @staticmethod
77
 
    def is_valid(v):
78
 
        """Returns True if v is a valid object that can be stored by this
79
 
        class.
80
 
        """
81
 
 
82
 
        raise NotImplementedError
83
 
 
84
 
    @staticmethod
85
 
    def common(v1, v2):
86
 
        """Returns the "common" part of the two values, for whatever "common"
87
 
        means for this class. For example, a string implementation would return
88
 
        the common shared prefix, for an int implementation it would return
89
 
        the minimum of the two numbers.
90
 
        
91
 
        If there is no common part, this method should return None.
92
 
        """
93
 
 
94
 
        raise NotImplementedError
95
 
 
96
 
    @staticmethod
97
 
    def add(prefix, v):
98
 
        """Adds the given prefix (the result of a call to common()) to the
99
 
        given value.
100
 
        """
101
 
 
102
 
        raise NotImplementedError
103
 
 
104
 
    @staticmethod
105
 
    def subtract(v, prefix):
106
 
        """Subtracts the "common" part (the prefix) from the given value.
107
 
        """
108
 
 
109
 
        raise NotImplementedError
110
 
 
111
 
    @staticmethod
112
 
    def write(dbfile, v):
113
 
        """Writes value v to a file.
114
 
        """
115
 
 
116
 
        raise NotImplementedError
117
 
 
118
 
    @staticmethod
119
 
    def read(dbfile):
120
 
        """Reads a value from the given file.
121
 
        """
122
 
 
123
 
        raise NotImplementedError
124
 
 
125
 
    @classmethod
126
 
    def skip(cls, dbfile):
127
 
        """Skips over a value in the given file.
128
 
        """
129
 
 
130
 
        cls.read(dbfile)
131
 
 
132
 
    @staticmethod
133
 
    def to_bytes(v):
134
 
        """Returns a str (Python 2.x) or bytes (Python 3) representation of
135
 
        the given value. This is used for calculating node digests, so it
136
 
        should be unique but fast to calculate, and does not have to be
137
 
        parseable.
138
 
        """
139
 
 
140
 
        raise NotImplementedError
141
 
 
142
 
    @staticmethod
143
 
    def merge(v1, v2):
144
 
        raise NotImplementedError
145
 
 
146
 
 
147
 
class IntValues(Values):
148
 
    """Stores integer values in an FST.
149
 
    """
150
 
 
151
 
    @staticmethod
152
 
    def is_valid(v):
153
 
        return isinstance(v, int) and v >= 0
154
 
 
155
 
    @staticmethod
156
 
    def common(v1, v2):
157
 
        if v1 is None or v2 is None:
158
 
            return None
159
 
        if v1 == v2:
160
 
            return v1
161
 
        return min(v1, v2)
162
 
 
163
 
    @staticmethod
164
 
    def add(base, v):
165
 
        if base is None:
166
 
            return v
167
 
        if v is None:
168
 
            return base
169
 
        return base + v
170
 
 
171
 
    @staticmethod
172
 
    def subtract(v, base):
173
 
        if v is None:
174
 
            return None
175
 
        if base is None:
176
 
            return v
177
 
        return v - base
178
 
 
179
 
    @staticmethod
180
 
    def write(dbfile, v):
181
 
        dbfile.write_uint(v)
182
 
 
183
 
    @staticmethod
184
 
    def read(dbfile):
185
 
        return dbfile.read_uint()
186
 
 
187
 
    @staticmethod
188
 
    def skip(dbfile):
189
 
        dbfile.seek(_INT_SIZE, 1)
190
 
 
191
 
    @staticmethod
192
 
    def to_bytes(v):
193
 
        return pack_int(v)
194
 
 
195
 
 
196
 
class SequenceValues(Values):
197
 
    """Abstract base class for value types that store sequences.
198
 
    """
199
 
 
200
 
    @staticmethod
201
 
    def is_valid(v):
202
 
        return isinstance(self, (list, tuple))
203
 
 
204
 
    @staticmethod
205
 
    def common(v1, v2):
206
 
        if v1 is None or v2 is None:
207
 
            return None
208
 
 
209
 
        i = 0
210
 
        while i < len(v1) and i < len(v2):
211
 
            if v1[i] != v2[i]:
212
 
                break
213
 
            i += 1
214
 
 
215
 
        if i == 0:
216
 
            return None
217
 
        if i == len(v1):
218
 
            return v1
219
 
        if i == len(v2):
220
 
            return v2
221
 
        return v1[:i]
222
 
 
223
 
    @staticmethod
224
 
    def add(prefix, v):
225
 
        if prefix is None:
226
 
            return v
227
 
        if v is None:
228
 
            return prefix
229
 
        return prefix + v
230
 
 
231
 
    @staticmethod
232
 
    def subtract(v, prefix):
233
 
        if prefix is None:
234
 
            return v
235
 
        if v is None:
236
 
            return None
237
 
        if len(v) == len(prefix):
238
 
            return None
239
 
        if len(v) < len(prefix) or len(prefix) == 0:
240
 
            raise ValueError((v, prefix))
241
 
        return v[len(prefix):]
242
 
 
243
 
    @staticmethod
244
 
    def write(dbfile, v):
245
 
        dbfile.write_pickle(v)
246
 
 
247
 
    @staticmethod
248
 
    def read(dbfile):
249
 
        return dbfile.read_pickle()
250
 
 
251
 
 
252
 
class BytesValues(SequenceValues):
253
 
    """Stores bytes objects (str in Python 2.x) in an FST.
254
 
    """
255
 
 
256
 
    @staticmethod
257
 
    def is_valid(v):
258
 
        return isinstance(v, bytes_type)
259
 
 
260
 
    @staticmethod
261
 
    def write(dbfile, v):
262
 
        dbfile.write_int(len(v))
263
 
        dbfile.write(v)
264
 
 
265
 
    @staticmethod
266
 
    def read(dbfile):
267
 
        length = dbfile.read_int()
268
 
        return dbfile.read(length)
269
 
 
270
 
    @staticmethod
271
 
    def skip(dbfile):
272
 
        length = dbfile.read_int()
273
 
        dbfile.seek(length, 1)
274
 
 
275
 
    @staticmethod
276
 
    def to_bytes(v):
277
 
        return v
278
 
 
279
 
 
280
 
class ArrayValues(SequenceValues):
281
 
    """Stores array.array objects in an FST.
282
 
    """
283
 
 
284
 
    def __init__(self, typecode):
285
 
        self.typecode = typecode
286
 
        self.itemsize = array(self.typecode).itemsize
287
 
 
288
 
    def is_valid(self, v):
289
 
        return isinstance(v, array) and v.typecode == self.typecode
290
 
 
291
 
    @staticmethod
292
 
    def write(dbfile, v):
293
 
        dbfile.write(b(v.typecode))
294
 
        dbfile.write_int(len(v))
295
 
        dbfile.write_array(v)
296
 
 
297
 
    def read(self, dbfile):
298
 
        typecode = u(dbfile.read(1))
299
 
        length = dbfile.read_int()
300
 
        return dbfile.read_array(self.typecode, length)
301
 
 
302
 
    def skip(self, dbfile):
303
 
        length = dbfile.read_int()
304
 
        dbfile.seek(length * self.itemsize, 1)
305
 
 
306
 
    @staticmethod
307
 
    def to_bytes(v):
308
 
        return array_tobytes(v)
309
 
 
310
 
 
311
 
class IntListValues(SequenceValues):
312
 
    """Stores lists of positive, increasing integers (that is, lists of
313
 
    integers where each number is >= 0 and each number is greater than or equal
314
 
    to the number that precedes it) in an FST.
315
 
    """
316
 
 
317
 
    @staticmethod
318
 
    def is_valid(v):
319
 
        if isinstance(v, (list, tuple)):
320
 
            if len(v) < 2:
321
 
                return True
322
 
            for i in xrange(1, len(v)):
323
 
                if not isinstance(v[i], int) or v[i] < v[i - 1]:
324
 
                    return False
325
 
            return True
326
 
        return False
327
 
 
328
 
    @staticmethod
329
 
    def write(dbfile, v):
330
 
        base = 0
331
 
        dbfile.write_varint(len(v))
332
 
        for x in v:
333
 
            delta = x - base
334
 
            assert delta >= 0
335
 
            dbfile.write_varint(delta)
336
 
            base = x
337
 
 
338
 
    @staticmethod
339
 
    def read(dbfile):
340
 
        length = dbfile.read_varint()
341
 
        result = []
342
 
        if length > 0:
343
 
            base = 0
344
 
            for _ in xrange(length):
345
 
                base += dbfile.read_varint()
346
 
                result.append(base)
347
 
        return result
348
 
 
349
 
    @staticmethod
350
 
    def to_bytes(v):
351
 
        return b(repr(v))
352
 
 
353
 
 
354
 
# Node-like interface wrappers
355
 
 
356
 
class Node(object):
357
 
    """A slow but easier-to-use wrapper for FSA/DAWGs. Translates the low-level
358
 
    arc-based interface of GraphReader into Node objects with methods to follow
359
 
    edges.
360
 
    """
361
 
 
362
 
    def __init__(self, owner, address, accept=False):
363
 
        self.owner = owner
364
 
        self.address = address
365
 
        self._edges = None
366
 
        self.accept = accept
367
 
 
368
 
    def __iter__(self):
369
 
        if not self._edges:
370
 
            self._load()
371
 
        return iterkeys(self._edges)
372
 
 
373
 
    def __contains__(self, key):
374
 
        if self._edges is None:
375
 
            self._load()
376
 
        return key in self._edges
377
 
 
378
 
    def _load(self):
379
 
        owner = self.owner
380
 
        if self.address is None:
381
 
            d = {}
382
 
        else:
383
 
            d = dict((arc.label, Node(owner, arc.target, arc.accept))
384
 
                     for arc in self.owner.iter_arcs(self.address))
385
 
        self._edges = d
386
 
 
387
 
    def keys(self):
388
 
        if self._edges is None:
389
 
            self._load()
390
 
        return self._edges.keys()
391
 
 
392
 
    def all_edges(self):
393
 
        if self._edges is None:
394
 
            self._load()
395
 
        return self._edges
396
 
 
397
 
    def edge(self, key):
398
 
        if self._edges is None:
399
 
            self._load()
400
 
        return self._edges[key]
401
 
 
402
 
    def flatten(self, sofar=emptybytes):
403
 
        if self.accept:
404
 
            yield sofar
405
 
        for key in sorted(self):
406
 
            node = self.edge(key)
407
 
            for result in node.flatten(sofar + key):
408
 
                yield result
409
 
 
410
 
    def flatten_strings(self):
411
 
        return (utf8decode(k)[0] for k in self.flatten())
412
 
 
413
 
 
414
 
class ComboNode(Node):
415
 
    """Base class for nodes that blend the nodes of two different graphs.
416
 
    
417
 
    Concrete subclasses need to implement the ``edge()`` method and possibly
418
 
    override the ``accept`` property.
419
 
    """
420
 
 
421
 
    def __init__(self, a, b):
422
 
        self.a = a
423
 
        self.b = b
424
 
 
425
 
    def __repr__(self):
426
 
        return "<%s %r %r>" % (self.__class__.__name__, self.a, self.b)
427
 
 
428
 
    def __contains__(self, key):
429
 
        return key in self.a or key in self.b
430
 
 
431
 
    def __iter__(self):
432
 
        return iter(set(self.a) | set(self.b))
433
 
 
434
 
    @property
435
 
    def accept(self):
436
 
        return self.a.accept or self.b.accept
437
 
 
438
 
 
439
 
class UnionNode(ComboNode):
440
 
    """Makes two graphs appear to be the union of the two graphs.
441
 
    """
442
 
 
443
 
    def edge(self, key):
444
 
        a = self.a
445
 
        b = self.b
446
 
        if key in a and key in b:
447
 
            return UnionNode(a.edge(key), b.edge(key))
448
 
        elif key in a:
449
 
            return a.edge(key)
450
 
        else:
451
 
            return b.edge(key)
452
 
 
453
 
 
454
 
class IntersectionNode(ComboNode):
455
 
    """Makes two graphs appear to be the intersection of the two graphs.
456
 
    """
457
 
 
458
 
    def edge(self, key):
459
 
        a = self.a
460
 
        b = self.b
461
 
        if key in a and key in b:
462
 
            return IntersectionNode(a.edge(key), b.edge(key))
463
 
 
464
 
 
465
 
# Cursor
466
 
 
467
 
class BaseCursor(object):
468
 
    """Base class for a cursor-type object for navigating an FST/word graph,
469
 
    represented by a :class:`GraphReader` object.
470
 
    
471
 
    >>> cur = GraphReader(dawgfile).cursor()
472
 
    >>> for key in cur.follow():
473
 
    ...   print(repr(key))
474
 
    
475
 
    The cursor "rests" on arcs in the FSA/FST graph, rather than nodes.
476
 
    """
477
 
 
478
 
    def is_active(self):
479
 
        """Returns True if this cursor is still active, that is it has not
480
 
        read past the last arc in the graph.
481
 
        """
482
 
 
483
 
        raise NotImplementedError
484
 
 
485
 
    def label(self):
486
 
        """Returns the label bytes of the current arc.
487
 
        """
488
 
 
489
 
        raise NotImplementedError
490
 
 
491
 
    def prefix(self):
492
 
        """Returns a sequence of the label bytes for the path from the root
493
 
        to the current arc.
494
 
        """
495
 
 
496
 
        raise NotImplementedError
497
 
 
498
 
    def prefix_bytes(self):
499
 
        """Returns the label bytes for the path from the root to the current
500
 
        arc as a single joined bytes object.
501
 
        """
502
 
 
503
 
        return emptybytes.join(self.prefix())
504
 
 
505
 
    def prefix_string(self):
506
 
        """Returns the labels of the path from the root to the current arc as
507
 
        a decoded unicode string.
508
 
        """
509
 
 
510
 
        return utf8decode(self.prefix_bytes())[0]
511
 
 
512
 
    def peek_key(self):
513
 
        """Returns a sequence of label bytes representing the next closest
514
 
        key in the graph.
515
 
        """
516
 
 
517
 
        for label in self.prefix():
518
 
            yield label
519
 
        c = self.copy()
520
 
        while not c.stopped():
521
 
            c.follow()
522
 
            yield c.label()
523
 
 
524
 
    def peek_key_bytes(self):
525
 
        """Returns the next closest key in the graph as a single bytes object.
526
 
        """
527
 
 
528
 
        return emptybytes.join(self.peek_key())
529
 
 
530
 
    def peek_key_string(self):
531
 
        """Returns the next closest key in the graph as a decoded unicode
532
 
        string.
533
 
        """
534
 
 
535
 
        return utf8decode(self.peek_key_bytes())[0]
536
 
 
537
 
    def stopped(self):
538
 
        """Returns True if the current arc leads to a stop state.
539
 
        """
540
 
 
541
 
        raise NotImplementedError
542
 
 
543
 
    def value(self):
544
 
        """Returns the value at the current arc, if reading an FST.
545
 
        """
546
 
 
547
 
        raise NotImplementedError
548
 
 
549
 
    def accept(self):
550
 
        """Returns True if the current arc leads to an accept state (the end
551
 
        of a valid key).
552
 
        """
553
 
 
554
 
        raise NotImplementedError
555
 
 
556
 
    def at_last_arc(self):
557
 
        """Returns True if the current arc is the last outgoing arc from the
558
 
        previous node.
559
 
        """
560
 
 
561
 
        raise NotImplementedError
562
 
 
563
 
    def next_arc(self):
564
 
        """Moves to the next outgoing arc from the previous node.
565
 
        """
566
 
 
567
 
        raise NotImplementedError
568
 
 
569
 
    def follow(self):
570
 
        """Follows the current arc.
571
 
        """
572
 
 
573
 
        raise NotImplementedError
574
 
 
575
 
    def switch_to(self, label):
576
 
        """Switch to the sibling arc with the given label bytes.
577
 
        """
578
 
 
579
 
        _label = self.label
580
 
        _at_last_arc = self.at_last_arc
581
 
        _next_arc = self.next_arc
582
 
 
583
 
        while True:
584
 
            thislabel = _label()
585
 
            if thislabel == label:
586
 
                return True
587
 
            if thislabel > label or _at_last_arc():
588
 
                return False
589
 
            _next_arc()
590
 
 
591
 
    def skip_to(self, key):
592
 
        """Moves the cursor to the path represented by the given key bytes.
593
 
        """
594
 
 
595
 
        _accept = self.accept
596
 
        _prefix = self.prefix
597
 
        _next_arc = self.next_arc
598
 
 
599
 
        keylist = list(key)
600
 
        while True:
601
 
            if _accept():
602
 
                thiskey = list(_prefix())
603
 
                if keylist == thiskey:
604
 
                    return True
605
 
                elif keylist > thiskey:
606
 
                    return False
607
 
            _next_arc()
608
 
 
609
 
    def flatten(self):
610
 
        """Yields the keys in the graph, starting at the current position.
611
 
        """
612
 
 
613
 
        _is_active = self.is_active
614
 
        _accept = self.accept
615
 
        _stopped = self.stopped
616
 
        _follow = self.follow
617
 
        _next_arc = self.next_arc
618
 
        _prefix_bytes = self.prefix_bytes
619
 
 
620
 
        if not _is_active():
621
 
            raise InactiveCursor
622
 
        while _is_active():
623
 
            if _accept():
624
 
                yield _prefix_bytes()
625
 
            if not _stopped():
626
 
                _follow()
627
 
                continue
628
 
            _next_arc()
629
 
 
630
 
    def flatten_v(self):
631
 
        """Yields (key, value) tuples in an FST, starting at the current
632
 
        position.
633
 
        """
634
 
 
635
 
        for key in self.flatten():
636
 
            yield key, self.value()
637
 
 
638
 
    def flatten_strings(self):
639
 
        return (utf8decode(k)[0] for k in self.flatten())
640
 
 
641
 
    def find_path(self, path):
642
 
        """Follows the labels in the given path, starting at the current
643
 
        position.
644
 
        """
645
 
 
646
 
        path = to_labels(path)
647
 
        _switch_to = self.switch_to
648
 
        _follow = self.follow
649
 
        _stopped = self.stopped
650
 
 
651
 
        first = True
652
 
        for i, label in enumerate(path):
653
 
            if not first:
654
 
                _follow()
655
 
            if not _switch_to(label):
656
 
                return False
657
 
            if _stopped():
658
 
                if i < len(path) - 1:
659
 
                    return False
660
 
            first = False
661
 
        return True
662
 
 
663
 
 
664
 
class Cursor(BaseCursor):
665
 
    def __init__(self, graph, root=None, stack=None):
666
 
        self.graph = graph
667
 
        self.vtype = graph.vtype
668
 
        self.root = root if root is not None else graph.default_root()
669
 
        if stack:
670
 
            self.stack = stack
671
 
        else:
672
 
            self.reset()
673
 
 
674
 
    def _current_attr(self, name):
675
 
        stack = self.stack
676
 
        if not stack:
677
 
            raise InactiveCursor
678
 
        return getattr(stack[-1], name)
679
 
 
680
 
    def is_active(self):
681
 
        return bool(self.stack)
682
 
 
683
 
    def stopped(self):
684
 
        return self._current_attr("target") is None
685
 
 
686
 
    def accept(self):
687
 
        return self._current_attr("accept")
688
 
 
689
 
    def at_last_arc(self):
690
 
        return self._current_attr("lastarc")
691
 
 
692
 
    def label(self):
693
 
        return self._current_attr("label")
694
 
 
695
 
    def reset(self):
696
 
        self.stack = []
697
 
        self.sums = [None]
698
 
        self._push(self.graph.arc_at(self.root))
699
 
 
700
 
    def copy(self):
701
 
        return self.__class__(self.graph, self.root, copy.deepcopy(self.stack))
702
 
 
703
 
    def prefix(self):
704
 
        stack = self.stack
705
 
        if not stack:
706
 
            raise InactiveCursor
707
 
        return (arc.label for arc in stack)
708
 
 
709
 
    # Override: more efficient implementation using graph methods directly
710
 
    def peek_key(self):
711
 
        if not self.stack:
712
 
            raise InactiveCursor
713
 
 
714
 
        for label in self.prefix():
715
 
            yield label
716
 
        arc = copy.copy(self.stack[-1])
717
 
        graph = self.graph
718
 
        while not arc.accept and arc.target is not None:
719
 
            graph.arc_at(arc.target, arc)
720
 
            yield arc.label
721
 
 
722
 
    def value(self):
723
 
        stack = self.stack
724
 
        if not stack:
725
 
            raise InactiveCursor
726
 
        vtype = self.vtype
727
 
        if not vtype:
728
 
            raise Exception("No value type")
729
 
 
730
 
        v = self.sums[-1]
731
 
        current = stack[-1]
732
 
        if current.value:
733
 
            v = vtype.add(v, current.value)
734
 
        if current.accept and current.acceptval is not None:
735
 
            v = vtype.add(v, current.acceptval)
736
 
        return v
737
 
 
738
 
    def next_arc(self):
739
 
        stack = self.stack
740
 
        if not stack:
741
 
            raise InactiveCursor
742
 
 
743
 
        while stack and stack[-1].lastarc:
744
 
            self.pop()
745
 
        if stack:
746
 
            current = stack[-1]
747
 
            self.graph.arc_at(current.endpos, current)
748
 
            return current
749
 
 
750
 
    def follow(self):
751
 
        address = self._current_attr("target")
752
 
        if address is None:
753
 
            raise Exception("Can't follow a stop arc")
754
 
        self._push(self.graph.arc_at(address))
755
 
        return self
756
 
 
757
 
    # Override: more efficient implementation manipulating the stack
758
 
    def skip_to(self, key):
759
 
        key = to_labels(key)
760
 
        stack = self.stack
761
 
        if not stack:
762
 
            raise InactiveCursor
763
 
 
764
 
        _follow = self.follow
765
 
        _next_arc = self.next_arc
766
 
 
767
 
        i = self._pop_to_prefix(key)
768
 
        while stack and i < len(key):
769
 
            curlabel = stack[-1].label
770
 
            keylabel = key[i]
771
 
            if curlabel == keylabel:
772
 
                _follow()
773
 
                i += 1
774
 
            elif curlabel > keylabel:
775
 
                return
776
 
            else:
777
 
                _next_arc()
778
 
 
779
 
    # Override: more efficient implementation using find_arc
780
 
    def switch_to(self, label):
781
 
        stack = self.stack
782
 
        if not stack:
783
 
            raise InactiveCursor
784
 
 
785
 
        current = stack[-1]
786
 
        if label == current.label:
787
 
            return True
788
 
        else:
789
 
            arc = self.graph.find_arc(current.endpos, label, current)
790
 
            return arc
791
 
 
792
 
    def _push(self, arc):
793
 
        if self.vtype and self.stack:
794
 
            sums = self.sums
795
 
            sums.append(self.vtype.add(sums[-1], self.stack[-1].value))
796
 
        self.stack.append(arc)
797
 
 
798
 
    def pop(self):
799
 
        self.stack.pop()
800
 
        if self.vtype:
801
 
            self.sums.pop()
802
 
 
803
 
    def _pop_to_prefix(self, key):
804
 
        stack = self.stack
805
 
        if not stack:
806
 
            raise InactiveCursor
807
 
 
808
 
        i = 0
809
 
        maxpre = min(len(stack), len(key))
810
 
        while i < maxpre and key[i] == stack[i].label:
811
 
            i += 1
812
 
        if stack[i].label > key[i]:
813
 
            self.current = None
814
 
            return
815
 
        while len(stack) > i + 1:
816
 
            self.pop()
817
 
        self.next_arc()
818
 
        return i
819
 
 
820
 
 
821
 
class UncompiledNode(object):
822
 
    # Represents an "in-memory" node used by the GraphWriter before it is
823
 
    # written to disk.
824
 
 
825
 
    compiled = False
826
 
 
827
 
    def __init__(self, owner):
828
 
        self.owner = owner
829
 
        self._digest = None
830
 
        self.clear()
831
 
 
832
 
    def clear(self):
833
 
        self.arcs = []
834
 
        self.value = None
835
 
        self.accept = False
836
 
        self.inputcount = 0
837
 
 
838
 
    def __repr__(self):
839
 
        return "<%r>" % ([(a.label, a.value) for a in self.arcs],)
840
 
 
841
 
    def digest(self):
842
 
        if self._digest is None:
843
 
            d = sha1()
844
 
            vtype = self.owner.vtype
845
 
            for arc in self.arcs:
846
 
                d.update(arc.label)
847
 
                if arc.target:
848
 
                    d.update(pack_long(arc.target))
849
 
                else:
850
 
                    d.update(b("z"))
851
 
                if arc.value:
852
 
                    d.update(vtype.to_bytes(arc.value))
853
 
                if arc.accept:
854
 
                    d.update(b("T"))
855
 
            self._digest = d.digest()
856
 
        return self._digest
857
 
 
858
 
    def edges(self):
859
 
        return self.arcs
860
 
 
861
 
    def last_value(self, label):
862
 
        assert self.arcs[-1].label == label
863
 
        return self.arcs[-1].value
864
 
 
865
 
    def add_arc(self, label, target):
866
 
        self.arcs.append(Arc(label, target))
867
 
 
868
 
    def replace_last(self, label, target, accept, acceptval=None):
869
 
        arc = self.arcs[-1]
870
 
        assert arc.label == label, "%r != %r" % (arc.label, label)
871
 
        arc.target = target
872
 
        arc.accept = accept
873
 
        arc.acceptval = acceptval
874
 
 
875
 
    def delete_last(self, label, target):
876
 
        arc = self.arcs.pop()
877
 
        assert arc.label == label
878
 
        assert arc.target == target
879
 
 
880
 
    def set_last_value(self, label, value):
881
 
        arc = self.arcs[-1]
882
 
        assert arc.label == label, "%r->%r" % (arc.label, label)
883
 
        arc.value = value
884
 
 
885
 
    def prepend_value(self, prefix):
886
 
        add = self.owner.vtype.add
887
 
        for arc in self.arcs:
888
 
            arc.value = add(prefix, arc.value)
889
 
        if self.accept:
890
 
            self.value = add(prefix, self.value)
891
 
 
892
 
 
893
 
class Arc(object):
894
 
    """
895
 
    Represents a directed arc between two nodes in an FSA/FST graph.
896
 
    
897
 
    The ``lastarc`` attribute is True if this is the last outgoing arc from the
898
 
    previous node.
899
 
    """
900
 
 
901
 
    __slots__ = ("label", "target", "accept", "value", "lastarc", "acceptval",
902
 
                 "endpos")
903
 
 
904
 
    def __init__(self, label=None, target=None, value=None, accept=False,
905
 
                 acceptval=None):
906
 
        """
907
 
        :param label:The label bytes for this arc. For a word graph, this will
908
 
            be a character.
909
 
        :param target: The address of the node at the endpoint of this arc.
910
 
        :param value: The inner FST value at the endpoint of this arc.
911
 
        :param accept: Whether the endpoint of this arc is an accept state
912
 
            (eg the end of a valid word).
913
 
        :param acceptval: If the endpoint of this arc is an accept state, the
914
 
            final FST value for that accepted state.
915
 
        """
916
 
 
917
 
        self.label = label
918
 
        self.target = target
919
 
        self.value = value
920
 
        self.accept = accept
921
 
        self.lastarc = None
922
 
        self.acceptval = acceptval
923
 
        self.endpos = None
924
 
 
925
 
    def __repr__(self):
926
 
        return "<%r-%s %s%s>" % (self.label, self.target,
927
 
                                 "." if self.accept else "",
928
 
                                 (" %r" % self.value) if self.value else "")
929
 
 
930
 
    def __eq__(self, other):
931
 
        if (isinstance(other, self.__class__) and self.accept == other.accept
932
 
            and self.lastarc == other.lastarc and self.target == other.target
933
 
            and self.value == other.value and self.label == other.label):
934
 
            return True
935
 
        return False
936
 
 
937
 
 
938
 
# Graph writer
939
 
 
940
 
class GraphWriter(object):
941
 
    """Writes an FSA/FST graph to disk.
942
 
    
943
 
    Call ``insert(key)`` to insert keys into the graph. You must
944
 
    insert keys in sorted order. Call ``close()`` to finish the graph and close
945
 
    the file.
946
 
    
947
 
    >>> gw = GraphWriter(my_file)
948
 
    >>> gw.insert("alfa")
949
 
    >>> gw.insert("bravo")
950
 
    >>> gw.insert("charlie")
951
 
    >>> gw.close()
952
 
    
953
 
    The graph writer can write separate graphs for multiple fields. Use
954
 
    ``start_field(name)`` and ``finish_field()`` to separate fields.
955
 
    
956
 
    >>> gw = GraphWriter(my_file)
957
 
    >>> gw.start_field("content")
958
 
    >>> gw.insert("alfalfa")
959
 
    >>> gw.insert("apple")
960
 
    >>> gw.finish_field()
961
 
    >>> gw.start_field("title")
962
 
    >>> gw.insert("artichoke")
963
 
    >>> gw.finish_field()
964
 
    >>> gw.close()
965
 
    """
966
 
 
967
 
    version = 1
968
 
 
969
 
    def __init__(self, dbfile, vtype=None, merge=None):
970
 
        """
971
 
        :param dbfile: the file to write to.
972
 
        :param vtype: a :class:`Values` class to use for storing values. This
973
 
            is only necessary if you will be storing values for the keys.
974
 
        :param merge: a function that takes two values and returns a single
975
 
            value. This is called if you insert two identical keys with values.
976
 
        """
977
 
 
978
 
        self.dbfile = dbfile
979
 
        self.vtype = vtype
980
 
        self.merge = merge
981
 
        self.fieldroots = {}
982
 
        self.arc_count = 0
983
 
        self.node_count = 0
984
 
        self.fixed_count = 0
985
 
 
986
 
        dbfile.write(b("GRPH"))
987
 
        dbfile.write_int(self.version)
988
 
        dbfile.write_uint(0)
989
 
 
990
 
        self._infield = False
991
 
 
992
 
    def start_field(self, fieldname):
993
 
        """Starts a new graph for the given field.
994
 
        """
995
 
 
996
 
        if not fieldname:
997
 
            raise ValueError("Field name cannot be equivalent to False")
998
 
        if self._infield:
999
 
            self.finish_field()
1000
 
        self.fieldname = fieldname
1001
 
        self.seen = {}
1002
 
        self.nodes = [UncompiledNode(self)]
1003
 
        self.lastkey = ''
1004
 
        self._inserted = False
1005
 
        self._infield = True
1006
 
 
1007
 
    def finish_field(self):
1008
 
        """Finishes the graph for the current field.
1009
 
        """
1010
 
 
1011
 
        if not self._infield:
1012
 
            raise Exception("Called finish_field before start_field")
1013
 
        self._infield = False
1014
 
        if self._inserted:
1015
 
            self.fieldroots[self.fieldname] = self._finish()
1016
 
        self.fieldname = None
1017
 
 
1018
 
    def close(self):
1019
 
        """Finishes the current graph and closes the underlying file.
1020
 
        """
1021
 
 
1022
 
        if self.fieldname is not None:
1023
 
            self.finish_field()
1024
 
        dbfile = self.dbfile
1025
 
        here = dbfile.tell()
1026
 
        dbfile.write_pickle(self.fieldroots)
1027
 
        dbfile.flush()
1028
 
        dbfile.seek(4 + _INT_SIZE)  # Seek past magic and version number
1029
 
        dbfile.write_uint(here)
1030
 
        dbfile.close()
1031
 
 
1032
 
    def insert(self, key, value=None):
1033
 
        """Inserts the given key into the graph.
1034
 
        
1035
 
        :param key: a sequence of bytes objects, a bytes object, or a string.
1036
 
        :param value: an optional value to encode in the graph along with the
1037
 
            key. If the writer was not instantiated with a value type, passing
1038
 
            a value here will raise an error.
1039
 
        """
1040
 
 
1041
 
        if not self._infield:
1042
 
            raise Exception("Inserted %r before starting a field" % key)
1043
 
        self._inserted = True
1044
 
        key = to_labels(key)  # Python 3 sucks
1045
 
 
1046
 
        vtype = self.vtype
1047
 
        lastkey = self.lastkey
1048
 
        nodes = self.nodes
1049
 
        if len(key) < 1:
1050
 
            raise KeyError("Can't store a null key %r" % (key,))
1051
 
        if lastkey and lastkey > key:
1052
 
            raise KeyError("Keys out of order %r..%r" % (lastkey, key))
1053
 
 
1054
 
        # Find the common prefix shared by this key and the previous one
1055
 
        prefixlen = 0
1056
 
        for i in xrange(min(len(lastkey), len(key))):
1057
 
            if lastkey[i] != key[i]:
1058
 
                break
1059
 
            prefixlen += 1
1060
 
        # Compile the nodes after the prefix, since they're not shared
1061
 
        self._freeze_tail(prefixlen + 1)
1062
 
 
1063
 
        # Create new nodes for the parts of this key after the shared prefix
1064
 
        for char in key[prefixlen:]:
1065
 
            node = UncompiledNode(self)
1066
 
            # Create an arc to this node on the previous node
1067
 
            nodes[-1].add_arc(char, node)
1068
 
            nodes.append(node)
1069
 
        # Mark the last node as an accept state
1070
 
        lastnode = nodes[-1]
1071
 
        lastnode.accept = True
1072
 
 
1073
 
        if vtype:
1074
 
            if value is not None and not vtype.is_valid(value):
1075
 
                raise ValueError("%r is not valid for %s" % (value, vtype))
1076
 
 
1077
 
            # Push value commonalities through the tree
1078
 
            common = None
1079
 
            for i in xrange(1, prefixlen + 1):
1080
 
                node = nodes[i]
1081
 
                parent = nodes[i - 1]
1082
 
                lastvalue = parent.last_value(key[i - 1])
1083
 
                if lastvalue is not None:
1084
 
                    common = vtype.common(value, lastvalue)
1085
 
                    suffix = vtype.subtract(lastvalue, common)
1086
 
                    parent.set_last_value(key[i - 1], common)
1087
 
                    node.prepend_value(suffix)
1088
 
                else:
1089
 
                    common = suffix = None
1090
 
                value = vtype.subtract(value, common)
1091
 
 
1092
 
            if key == lastkey:
1093
 
                # If this key is a duplicate, merge its value with the value of
1094
 
                # the previous (same) key
1095
 
                lastnode.value = self.merge(lastnode.value, value)
1096
 
            else:
1097
 
                nodes[prefixlen].set_last_value(key[prefixlen], value)
1098
 
        elif value:
1099
 
            raise Exception("Value %r but no value type" % value)
1100
 
 
1101
 
        self.lastkey = key
1102
 
 
1103
 
    def _freeze_tail(self, prefixlen):
1104
 
        nodes = self.nodes
1105
 
        lastkey = self.lastkey
1106
 
        downto = max(1, prefixlen)
1107
 
 
1108
 
        while len(nodes) > downto:
1109
 
            node = nodes.pop()
1110
 
            parent = nodes[-1]
1111
 
            inlabel = lastkey[len(nodes) - 1]
1112
 
 
1113
 
            self._compile_targets(node)
1114
 
            accept = node.accept or len(node.arcs) == 0
1115
 
            address = self._compile_node(node)
1116
 
            parent.replace_last(inlabel, address, accept, node.value)
1117
 
 
1118
 
    def _finish(self):
1119
 
        nodes = self.nodes
1120
 
        root = nodes[0]
1121
 
        # Minimize nodes in the last word's suffix
1122
 
        self._freeze_tail(0)
1123
 
        # Compile remaining targets
1124
 
        self._compile_targets(root)
1125
 
        return self._compile_node(root)
1126
 
 
1127
 
    def _compile_targets(self, node):
1128
 
        for arc in node.arcs:
1129
 
            if isinstance(arc.target, UncompiledNode):
1130
 
                n = arc.target
1131
 
                if len(n.arcs) == 0:
1132
 
                    arc.accept = n.accept = True
1133
 
                arc.target = self._compile_node(n)
1134
 
 
1135
 
    def _compile_node(self, uncnode):
1136
 
        seen = self.seen
1137
 
 
1138
 
        if len(uncnode.arcs) == 0:
1139
 
            # Leaf node
1140
 
            address = self._write_node(uncnode)
1141
 
        else:
1142
 
            d = uncnode.digest()
1143
 
            address = seen.get(d)
1144
 
            if address is None:
1145
 
                address = self._write_node(uncnode)
1146
 
                seen[d] = address
1147
 
        return address
1148
 
 
1149
 
    def _write_node(self, uncnode):
1150
 
        vtype = self.vtype
1151
 
        dbfile = self.dbfile
1152
 
        arcs = uncnode.arcs
1153
 
        numarcs = len(arcs)
1154
 
 
1155
 
        if not numarcs:
1156
 
            if uncnode.accept:
1157
 
                return None
1158
 
            else:
1159
 
                # What does it mean for an arc to stop but not be accepted?
1160
 
                raise Exception
1161
 
        self.node_count += 1
1162
 
 
1163
 
        buf = StructFile(BytesIO())
1164
 
        nodestart = dbfile.tell()
1165
 
        #self.count += 1
1166
 
        #self.arccount += numarcs
1167
 
 
1168
 
        fixedsize = -1
1169
 
        arcstart = buf.tell()
1170
 
        for i, arc in enumerate(arcs):
1171
 
            self.arc_count += 1
1172
 
            target = arc.target
1173
 
            label = arc.label
1174
 
 
1175
 
            flags = 0
1176
 
            if len(label) > 1:
1177
 
                flags += MULTIBYTE_LABEL
1178
 
            if i == numarcs - 1:
1179
 
                flags += ARC_LAST
1180
 
            if arc.accept:
1181
 
                flags += ARC_ACCEPT
1182
 
            if target is None:
1183
 
                flags += ARC_STOP
1184
 
            if arc.value is not None:
1185
 
                flags += ARC_HAS_VAL
1186
 
            if arc.acceptval is not None:
1187
 
                flags += ARC_HAS_ACCEPT_VAL
1188
 
 
1189
 
            buf.write(pack_byte(flags))
1190
 
            if len(label) > 1:
1191
 
                buf.write(varint(len(label)))
1192
 
            buf.write(label)
1193
 
            if target is not None:
1194
 
                buf.write(pack_uint(target))
1195
 
            if arc.value is not None:
1196
 
                vtype.write(buf, arc.value)
1197
 
            if arc.acceptval is not None:
1198
 
                vtype.write(buf, arc.acceptval)
1199
 
 
1200
 
            here = buf.tell()
1201
 
            thissize = here - arcstart
1202
 
            arcstart = here
1203
 
            if fixedsize == -1:
1204
 
                fixedsize = thissize
1205
 
            elif fixedsize > 0 and thissize != fixedsize:
1206
 
                fixedsize = 0
1207
 
 
1208
 
        if fixedsize > 0:
1209
 
            # Write a fake arc containing the fixed size and number of arcs
1210
 
            dbfile.write_byte(255)  # FIXED_SIZE
1211
 
            dbfile.write_int(fixedsize)
1212
 
            dbfile.write_int(numarcs)
1213
 
            self.fixed_count += 1
1214
 
        dbfile.write(buf.file.getvalue())
1215
 
 
1216
 
        return nodestart
1217
 
 
1218
 
 
1219
 
# Graph reader
1220
 
 
1221
 
class BaseGraphReader(object):
1222
 
    def cursor(self, rootname=None):
1223
 
        return Cursor(self, self.root(rootname))
1224
 
 
1225
 
    def has_root(self, rootname):
1226
 
        raise NotImplementedError
1227
 
 
1228
 
    def root(self, rootname=None):
1229
 
        raise NotImplementedError
1230
 
 
1231
 
    # Low level methods
1232
 
 
1233
 
    def arc_at(self, address, arc):
1234
 
        raise NotImplementedError
1235
 
 
1236
 
    def iter_arcs(self, address, arc=None):
1237
 
        raise NotImplementedError
1238
 
 
1239
 
    def find_arc(self, address, label, arc=None):
1240
 
        arc = arc or Arc()
1241
 
        for arc in self.iter_arcs(address, arc):
1242
 
            thislabel = arc.label
1243
 
            if thislabel == label:
1244
 
                return arc
1245
 
            elif thislabel > label:
1246
 
                return None
1247
 
 
1248
 
    # Convenience methods
1249
 
 
1250
 
    def list_arcs(self, address):
1251
 
        return list(copy.copy(arc) for arc in self.iter_arcs(address))
1252
 
 
1253
 
    def arc_dict(self, address):
1254
 
        return dict((arc.label, copy.copy(arc))
1255
 
                    for arc in self.iter_arcs(address))
1256
 
 
1257
 
    def find_path(self, path, arc=None, address=None):
1258
 
        path = to_labels(path)
1259
 
 
1260
 
        if arc:
1261
 
            address = arc.target
1262
 
        else:
1263
 
            arc = Arc()
1264
 
 
1265
 
        if address is None:
1266
 
            address = self._root
1267
 
 
1268
 
        for label in path:
1269
 
            if address is None:
1270
 
                return None
1271
 
            if not self.find_arc(address, label, arc):
1272
 
                return None
1273
 
            address = arc.target
1274
 
        return arc
1275
 
 
1276
 
 
1277
 
class GraphReader(BaseGraphReader):
1278
 
    def __init__(self, dbfile, rootname=None, vtype=None, filebase=0):
1279
 
        self.dbfile = dbfile
1280
 
        self.vtype = vtype
1281
 
        self.filebase = filebase
1282
 
 
1283
 
        dbfile.seek(filebase)
1284
 
        magic = dbfile.read(4)
1285
 
        if magic != b("GRPH"):
1286
 
            raise FileVersionError
1287
 
        self.version = dbfile.read_int()
1288
 
        dbfile.seek(dbfile.read_uint())
1289
 
        self.roots = dbfile.read_pickle()
1290
 
 
1291
 
        self._root = None
1292
 
        if rootname is None and len(self.roots) == 1:
1293
 
            # If there's only one root, just use it. Have to wrap a list around
1294
 
            # the keys() method here because of Python 3.
1295
 
            rootname = list(self.roots.keys())[0]
1296
 
        if rootname is not None:
1297
 
            self._root = self.root(rootname)
1298
 
 
1299
 
    def close(self):
1300
 
        self.dbfile.close()
1301
 
 
1302
 
    # Overrides
1303
 
 
1304
 
    def has_root(self, rootname):
1305
 
        return rootname in self.roots
1306
 
 
1307
 
    def root(self, rootname=None):
1308
 
        if rootname is None:
1309
 
            return self._root
1310
 
        else:
1311
 
            return self.roots[rootname]
1312
 
 
1313
 
    def default_root(self):
1314
 
        return self._root
1315
 
 
1316
 
    def arc_at(self, address, arc=None):
1317
 
        arc = arc or Arc()
1318
 
        self.dbfile.seek(address)
1319
 
        return self._read_arc(arc)
1320
 
 
1321
 
    def iter_arcs(self, address, arc=None):
1322
 
        arc = arc or Arc()
1323
 
        _read_arc = self._read_arc
1324
 
 
1325
 
        self.dbfile.seek(address)
1326
 
        while True:
1327
 
            _read_arc(arc)
1328
 
            yield arc
1329
 
            if arc.lastarc:
1330
 
                break
1331
 
 
1332
 
    def find_arc(self, address, label, arc=None):
1333
 
        arc = arc or Arc()
1334
 
        dbfile = self.dbfile
1335
 
        dbfile.seek(address)
1336
 
 
1337
 
        # If records are fixed size, we can do a binary search
1338
 
        finfo = self._read_fixed_info()
1339
 
        if finfo:
1340
 
            size, count = finfo
1341
 
            address = dbfile.tell()
1342
 
            if count > 2:
1343
 
                return self._binary_search(address, size, count, label, arc)
1344
 
 
1345
 
        # If records aren't fixed size, fall back to the parent's linear
1346
 
        # search method
1347
 
        return BaseGraphReader.find_arc(self, address, label, arc)
1348
 
 
1349
 
    # Implementations
1350
 
 
1351
 
    def _read_arc(self, toarc=None):
1352
 
        toarc = toarc or Arc()
1353
 
        dbfile = self.dbfile
1354
 
        flags = dbfile.read_byte()
1355
 
        if flags == 255:
1356
 
            # This is a fake arc containing fixed size information; skip it
1357
 
            # and read the next arc
1358
 
            dbfile.seek(_INT_SIZE * 2, 1)
1359
 
            flags = dbfile.read_byte()
1360
 
        toarc.label = self._read_label(flags)
1361
 
        return self._read_arc_data(flags, toarc)
1362
 
 
1363
 
    def _read_label(self, flags):
1364
 
        dbfile = self.dbfile
1365
 
        if flags & MULTIBYTE_LABEL:
1366
 
            length = dbfile.read_varint()
1367
 
        else:
1368
 
            length = 1
1369
 
        label = dbfile.read(length)
1370
 
        return label
1371
 
 
1372
 
    def _read_fixed_info(self):
1373
 
        dbfile = self.dbfile
1374
 
 
1375
 
        flags = dbfile.read_byte()
1376
 
        if flags == 255:
1377
 
            size = dbfile.read_int()
1378
 
            count = dbfile.read_int()
1379
 
            return (size, count)
1380
 
        else:
1381
 
            return None
1382
 
 
1383
 
    def _read_arc_data(self, flags, arc):
1384
 
        dbfile = self.dbfile
1385
 
        accept = arc.accept = bool(flags & ARC_ACCEPT)
1386
 
        arc.lastarc = flags & ARC_LAST
1387
 
        if flags & ARC_STOP:
1388
 
            arc.target = None
1389
 
        else:
1390
 
            arc.target = dbfile.read_uint()
1391
 
        if flags & ARC_HAS_VAL:
1392
 
            arc.value = self.vtype.read(dbfile)
1393
 
        else:
1394
 
            arc.value = None
1395
 
        if accept and flags & ARC_HAS_ACCEPT_VAL:
1396
 
            arc.acceptval = self.vtype.read(dbfile)
1397
 
        arc.endpos = dbfile.tell()
1398
 
        return arc
1399
 
 
1400
 
    def _binary_search(self, address, size, count, label, arc):
1401
 
        dbfile = self.dbfile
1402
 
        _read_label = self._read_label
1403
 
 
1404
 
        lo = 0
1405
 
        hi = count
1406
 
        while lo < hi:
1407
 
            mid = (lo + hi) // 2
1408
 
            midaddr = address + mid * size
1409
 
            dbfile.seek(midaddr)
1410
 
            flags = dbfile.read_byte()
1411
 
            midlabel = self._read_label(flags)
1412
 
            if midlabel == label:
1413
 
                arc.label = midlabel
1414
 
                return self._read_arc_data(flags, arc)
1415
 
            elif midlabel < label:
1416
 
                lo = mid + 1
1417
 
            else:
1418
 
                hi = mid
1419
 
        if lo == count:
1420
 
            return None
1421
 
 
1422
 
 
1423
 
def to_labels(key):
1424
 
    """Takes a string and returns a list of bytestrings, suitable for use as
1425
 
    a key or path in an FSA/FST graph.
1426
 
    """
1427
 
 
1428
 
    # Convert to tuples of bytestrings (must be tuples so they can be hashed)
1429
 
    keytype = type(key)
1430
 
 
1431
 
    # I hate the Python 3 bytes object so friggin much
1432
 
    if keytype is tuple or keytype is list:
1433
 
        if not all(isinstance(e, bytes_type) for e in key):
1434
 
            raise TypeError("%r contains a non-bytestring")
1435
 
        if keytype is list:
1436
 
            key = tuple(key)
1437
 
    elif isinstance(key, bytes_type):
1438
 
        key = tuple(key[i:i + 1] for i in xrange(len(key)))
1439
 
    elif isinstance(key, text_type):
1440
 
        key = tuple(utf8encode(key[i:i + 1])[0] for i in xrange(len(key)))
1441
 
    else:
1442
 
        raise TypeError("Don't know how to convert %r" % key)
1443
 
    return key
1444
 
 
1445
 
 
1446
 
# Within edit distance function
1447
 
 
1448
 
def within(graph, text, k=1, prefix=0, address=None):
1449
 
    """Yields a series of keys in the given graph within ``k`` edit distance of
1450
 
    ``text``. If ``prefix`` is greater than 0, all keys must match the first
1451
 
    ``prefix`` characters of ``text``.
1452
 
    """
1453
 
 
1454
 
    text = to_labels(text)
1455
 
    if address is None:
1456
 
        address = graph._root
1457
 
 
1458
 
    sofar = emptybytes
1459
 
    accept = False
1460
 
    if prefix:
1461
 
        prefixchars = text[:prefix]
1462
 
        arc = graph.find_path(prefixchars, address=address)
1463
 
        if arc is None:
1464
 
            return
1465
 
        sofar = emptybytes.join(prefixchars)
1466
 
        address = arc.target
1467
 
        accept = arc.accept
1468
 
 
1469
 
    stack = [(address, k, prefix, sofar, accept)]
1470
 
    seen = set()
1471
 
    while stack:
1472
 
        state = stack.pop()
1473
 
        # Have we already tried this state?
1474
 
        if state in seen:
1475
 
            continue
1476
 
        seen.add(state)
1477
 
 
1478
 
        address, k, i, sofar, accept = state
1479
 
        # If we're at the end of the text (or deleting enough chars would get
1480
 
        # us to the end and still within K), and we're in the accept state,
1481
 
        # yield the current result
1482
 
        if (len(text) - i <= k) and accept:
1483
 
            yield utf8decode(sofar)[0]
1484
 
 
1485
 
        # If we're in the stop state, give up
1486
 
        if address is None:
1487
 
            continue
1488
 
 
1489
 
        # Exact match
1490
 
        if i < len(text):
1491
 
            arc = graph.find_arc(address, text[i])
1492
 
            if arc:
1493
 
                stack.append((arc.target, k, i + 1, sofar + text[i],
1494
 
                              arc.accept))
1495
 
        # If K is already 0, can't do any more edits
1496
 
        if k < 1:
1497
 
            continue
1498
 
        k -= 1
1499
 
 
1500
 
        arcs = graph.arc_dict(address)
1501
 
        # Insertions
1502
 
        stack.extend((arc.target, k, i, sofar + char, arc.accept)
1503
 
                     for char, arc in iteritems(arcs))
1504
 
 
1505
 
        # Deletion, replacement, and transpo only work before the end
1506
 
        if i >= len(text):
1507
 
            continue
1508
 
        char = text[i]
1509
 
 
1510
 
        # Deletion
1511
 
        stack.append((address, k, i + 1, sofar, False))
1512
 
        # Replacement
1513
 
        for char2, arc in iteritems(arcs):
1514
 
            if char2 != char:
1515
 
                stack.append((arc.target, k, i + 1, sofar + char2, arc.accept))
1516
 
        # Transposition
1517
 
        if i < len(text) - 1:
1518
 
            char2 = text[i + 1]
1519
 
            if char != char2 and char2 in arcs:
1520
 
                # Find arc from next char to this char
1521
 
                target = arcs[char2].target
1522
 
                if target:
1523
 
                    arc = graph.find_arc(target, char)
1524
 
                    if arc:
1525
 
                        stack.append((arc.target, k, i + 2,
1526
 
                                      sofar + char2 + char, arc.accept))
1527
 
 
1528
 
 
1529
 
# Utility functions
1530
 
 
1531
 
def dump_graph(graph, address=None, tab=0, out=None):
1532
 
    if address is None:
1533
 
        address = graph._root
1534
 
    if out is None:
1535
 
        out = sys.stdout
1536
 
 
1537
 
    here = "%06d" % address
1538
 
    for i, arc in enumerate(graph.list_arcs(address)):
1539
 
        if i == 0:
1540
 
            out.write(here)
1541
 
        else:
1542
 
            out.write(" " * 6)
1543
 
        out.write("  " * tab)
1544
 
        out.write("%r %r %s %r\n"
1545
 
                  % (arc.label, arc.target, arc.accept, arc.value))
1546
 
        if arc.target is not None:
1547
 
            dump_graph(graph, arc.target, tab + 1, out=out)
1548
 
 
1549