1
# Copyright 2009 Matt Chaput. All rights reserved.
3
# Redistribution and use in source and binary forms, with or without
4
# modification, are permitted provided that the following conditions are met:
6
# 1. Redistributions of source code must retain the above copyright notice,
7
# this list of conditions and the following disclaimer.
9
# 2. Redistributions in binary form must reproduce the above copyright
10
# notice, this list of conditions and the following disclaimer in the
11
# documentation and/or other materials provided with the distribution.
13
# THIS SOFTWARE IS PROVIDED BY MATT CHAPUT ``AS IS'' AND ANY EXPRESS OR
14
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
15
# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
16
# EVENT SHALL MATT CHAPUT OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
17
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
18
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
19
# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
20
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
21
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
22
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24
# The views and conclusions contained in the software and documentation are
25
# those of the authors and should not be interpreted as representing official
26
# policies, either expressed or implied, of Matt Chaput.
29
This module implements an FST/FSA writer and reader. An FST (Finite State
30
Transducer) stores a directed acyclic graph with values associated with the
31
leaves. Common elements of the values are pushed inside the tree. An FST that
32
does not store values is a regular FSA.
34
The format of the leaf values is pluggable using subclasses of the Values
37
Whoosh uses these structures to store a directed acyclic word graph (DAWG) for
38
use in (at least) spell checking.
43
from array import array
44
from hashlib import sha1 # @UnresolvedImport
46
from whoosh.compat import (b, u, BytesIO, xrange, iteritems, iterkeys,
47
bytes_type, text_type, izip, array_tobytes)
48
from whoosh.filedb.structfile import StructFile
49
from whoosh.system import (_INT_SIZE, pack_byte, pack_int, pack_uint,
50
pack_long, emptybytes)
51
from whoosh.util import utf8encode, utf8decode, varint
54
class FileVersionError(Exception):
58
class InactiveCursor(Exception):
66
ARC_HAS_ACCEPT_VAL = 16
73
"""Base for classes that describe how to encode and decode FST values.
78
"""Returns True if v is a valid object that can be stored by this
82
raise NotImplementedError
86
"""Returns the "common" part of the two values, for whatever "common"
87
means for this class. For example, a string implementation would return
88
the common shared prefix, for an int implementation it would return
89
the minimum of the two numbers.
91
If there is no common part, this method should return None.
94
raise NotImplementedError
98
"""Adds the given prefix (the result of a call to common()) to the
102
raise NotImplementedError
105
def subtract(v, prefix):
106
"""Subtracts the "common" part (the prefix) from the given value.
109
raise NotImplementedError
112
def write(dbfile, v):
113
"""Writes value v to a file.
116
raise NotImplementedError
120
"""Reads a value from the given file.
123
raise NotImplementedError
126
def skip(cls, dbfile):
127
"""Skips over a value in the given file.
134
"""Returns a str (Python 2.x) or bytes (Python 3) representation of
135
the given value. This is used for calculating node digests, so it
136
should be unique but fast to calculate, and does not have to be
140
raise NotImplementedError
144
raise NotImplementedError
147
class IntValues(Values):
148
"""Stores integer values in an FST.
153
return isinstance(v, int) and v >= 0
157
if v1 is None or v2 is None:
172
def subtract(v, base):
180
def write(dbfile, v):
185
return dbfile.read_uint()
189
dbfile.seek(_INT_SIZE, 1)
196
class SequenceValues(Values):
197
"""Abstract base class for value types that store sequences.
202
return isinstance(self, (list, tuple))
206
if v1 is None or v2 is None:
210
while i < len(v1) and i < len(v2):
232
def subtract(v, prefix):
237
if len(v) == len(prefix):
239
if len(v) < len(prefix) or len(prefix) == 0:
240
raise ValueError((v, prefix))
241
return v[len(prefix):]
244
def write(dbfile, v):
    """Store the sequence ``v`` in ``dbfile`` as a pickle.

    :param dbfile: file object providing a ``write_pickle`` method.
    :param v: the sequence value to serialize.
    """
    dbfile.write_pickle(v)
249
return dbfile.read_pickle()
252
class BytesValues(SequenceValues):
253
"""Stores bytes objects (str in Python 2.x) in an FST.
258
return isinstance(v, bytes_type)
261
def write(dbfile, v):
262
dbfile.write_int(len(v))
267
length = dbfile.read_int()
268
return dbfile.read(length)
272
length = dbfile.read_int()
273
dbfile.seek(length, 1)
280
class ArrayValues(SequenceValues):
281
"""Stores array.array objects in an FST.
284
def __init__(self, typecode):
285
self.typecode = typecode
286
self.itemsize = array(self.typecode).itemsize
288
def is_valid(self, v):
    """Return True if ``v`` is an ``array.array`` whose typecode equals
    this instance's ``typecode``.
    """
    if not isinstance(v, array):
        return False
    return v.typecode == self.typecode
292
def write(dbfile, v):
    """Write the array ``v`` to ``dbfile``.

    Three things are written in order: the array's typecode converted
    with ``b()``, the number of items as an int, and the array items
    themselves via ``write_array``.

    :param dbfile: file object providing ``write``, ``write_int`` and
        ``write_array``.
    :param v: the ``array.array`` to store.
    """
    dbfile.write(b(v.typecode))
    dbfile.write_int(len(v))
    dbfile.write_array(v)
297
def read(self, dbfile):
    """Read one stored array from ``dbfile`` and return it.

    :param dbfile: file object positioned at the start of a stored value;
        must provide ``read``, ``read_int`` and ``read_array``.
    :returns: whatever ``dbfile.read_array`` returns for this instance's
        typecode and the stored item count.
    """
    # The stored typecode byte is consumed but not consulted: the array is
    # always decoded using this instance's own typecode.
    dbfile.read(1)
    length = dbfile.read_int()
    return dbfile.read_array(self.typecode, length)
302
def skip(self, dbfile):
    """Advance ``dbfile`` past one stored array without decoding it.

    :param dbfile: file object positioned at the start of a stored value;
        must provide ``seek`` and ``read_int``.
    """
    # A stored value starts with a one-byte typecode (``read`` consumes it
    # with ``dbfile.read(1)``). Step over it first, otherwise the item
    # count would be read from the wrong offset.
    dbfile.seek(1, 1)
    length = dbfile.read_int()
    dbfile.seek(length * self.itemsize, 1)
308
return array_tobytes(v)
311
class IntListValues(SequenceValues):
312
"""Stores lists of positive, increasing integers (that is, lists of
313
integers where each number is >= 0 and each number is greater than or equal
314
to the number that precedes it) in an FST.
319
if isinstance(v, (list, tuple)):
322
for i in xrange(1, len(v)):
323
if not isinstance(v[i], int) or v[i] < v[i - 1]:
329
def write(dbfile, v):
331
dbfile.write_varint(len(v))
335
dbfile.write_varint(delta)
340
length = dbfile.read_varint()
344
for _ in xrange(length):
345
base += dbfile.read_varint()
354
# Node-like interface wrappers
357
"""A slow but easier-to-use wrapper for FSA/DAWGs. Translates the low-level
358
arc-based interface of GraphReader into Node objects with methods to follow
362
def __init__(self, owner, address, accept=False):
364
self.address = address
371
return iterkeys(self._edges)
373
def __contains__(self, key):
374
if self._edges is None:
376
return key in self._edges
380
if self.address is None:
383
d = dict((arc.label, Node(owner, arc.target, arc.accept))
384
for arc in self.owner.iter_arcs(self.address))
388
if self._edges is None:
390
return self._edges.keys()
393
if self._edges is None:
398
if self._edges is None:
400
return self._edges[key]
402
def flatten(self, sofar=emptybytes):
405
for key in sorted(self):
406
node = self.edge(key)
407
for result in node.flatten(sofar + key):
410
def flatten_strings(self):
    """Return a generator that passes each key from ``self.flatten()``
    through ``utf8decode`` and yields the decoded result.
    """
    # utf8decode returns a tuple; element 0 is the decoded value.
    return (utf8decode(k)[0] for k in self.flatten())
414
class ComboNode(Node):
415
"""Base class for nodes that blend the nodes of two different graphs.
417
Concrete subclasses need to implement the ``edge()`` method and possibly
418
override the ``accept`` property.
421
def __init__(self, a, b):
426
return "<%s %r %r>" % (self.__class__.__name__, self.a, self.b)
428
def __contains__(self, key):
429
return key in self.a or key in self.b
432
return iter(set(self.a) | set(self.b))
436
return self.a.accept or self.b.accept
439
class UnionNode(ComboNode):
440
"""Makes two graphs appear to be the union of the two graphs.
446
if key in a and key in b:
447
return UnionNode(a.edge(key), b.edge(key))
454
class IntersectionNode(ComboNode):
455
"""Makes two graphs appear to be the intersection of the two graphs.
461
if key in a and key in b:
462
return IntersectionNode(a.edge(key), b.edge(key))
467
class BaseCursor(object):
468
"""Base class for a cursor-type object for navigating an FST/word graph,
469
represented by a :class:`GraphReader` object.
471
>>> cur = GraphReader(dawgfile).cursor()
472
>>> for key in cur.follow():
475
The cursor "rests" on arcs in the FSA/FST graph, rather than nodes.
479
"""Returns True if this cursor is still active, that is it has not
480
read past the last arc in the graph.
483
raise NotImplementedError
486
"""Returns the label bytes of the current arc.
489
raise NotImplementedError
492
"""Returns a sequence of the label bytes for the path from the root
496
raise NotImplementedError
498
def prefix_bytes(self):
499
"""Returns the label bytes for the path from the root to the current
500
arc as a single joined bytes object.
503
return emptybytes.join(self.prefix())
505
def prefix_string(self):
506
"""Returns the labels of the path from the root to the current arc as
507
a decoded unicode string.
510
return utf8decode(self.prefix_bytes())[0]
513
"""Returns a sequence of label bytes representing the next closest
517
for label in self.prefix():
520
while not c.stopped():
524
def peek_key_bytes(self):
525
"""Returns the next closest key in the graph as a single bytes object.
528
return emptybytes.join(self.peek_key())
530
def peek_key_string(self):
531
"""Returns the next closest key in the graph as a decoded unicode
535
return utf8decode(self.peek_key_bytes())[0]
538
"""Returns True if the current arc leads to a stop state.
541
raise NotImplementedError
544
"""Returns the value at the current arc, if reading an FST.
547
raise NotImplementedError
550
"""Returns True if the current arc leads to an accept state (the end
554
raise NotImplementedError
556
def at_last_arc(self):
557
"""Returns True if the current arc is the last outgoing arc from the
561
raise NotImplementedError
564
"""Moves to the next outgoing arc from the previous node.
567
raise NotImplementedError
570
"""Follows the current arc.
573
raise NotImplementedError
575
def switch_to(self, label):
576
"""Switch to the sibling arc with the given label bytes.
580
_at_last_arc = self.at_last_arc
581
_next_arc = self.next_arc
585
if thislabel == label:
587
if thislabel > label or _at_last_arc():
591
def skip_to(self, key):
592
"""Moves the cursor to the path represented by the given key bytes.
595
_accept = self.accept
596
_prefix = self.prefix
597
_next_arc = self.next_arc
602
thiskey = list(_prefix())
603
if keylist == thiskey:
605
elif keylist > thiskey:
610
"""Yields the keys in the graph, starting at the current position.
613
_is_active = self.is_active
614
_accept = self.accept
615
_stopped = self.stopped
616
_follow = self.follow
617
_next_arc = self.next_arc
618
_prefix_bytes = self.prefix_bytes
624
yield _prefix_bytes()
631
"""Yields (key, value) tuples in an FST, starting at the current
635
for key in self.flatten():
636
yield key, self.value()
638
def flatten_strings(self):
    """Return a generator that passes each key from ``self.flatten()``
    through ``utf8decode`` and yields the decoded result.
    """
    # utf8decode returns a tuple; element 0 is the decoded value.
    return (utf8decode(k)[0] for k in self.flatten())
641
def find_path(self, path):
642
"""Follows the labels in the given path, starting at the current
646
path = to_labels(path)
647
_switch_to = self.switch_to
648
_follow = self.follow
649
_stopped = self.stopped
652
for i, label in enumerate(path):
655
if not _switch_to(label):
658
if i < len(path) - 1:
664
class Cursor(BaseCursor):
665
def __init__(self, graph, root=None, stack=None):
667
self.vtype = graph.vtype
668
self.root = root if root is not None else graph.default_root()
674
def _current_attr(self, name):
678
return getattr(stack[-1], name)
681
return bool(self.stack)
684
return self._current_attr("target") is None
687
return self._current_attr("accept")
689
def at_last_arc(self):
    """Return the ``lastarc`` attribute of the current arc, as looked up
    by ``self._current_attr``.
    """
    return self._current_attr("lastarc")
693
return self._current_attr("label")
698
self._push(self.graph.arc_at(self.root))
701
return self.__class__(self.graph, self.root, copy.deepcopy(self.stack))
707
return (arc.label for arc in stack)
709
# Override: more efficient implementation using graph methods directly
714
for label in self.prefix():
716
arc = copy.copy(self.stack[-1])
718
while not arc.accept and arc.target is not None:
719
graph.arc_at(arc.target, arc)
728
raise Exception("No value type")
733
v = vtype.add(v, current.value)
734
if current.accept and current.acceptval is not None:
735
v = vtype.add(v, current.acceptval)
743
while stack and stack[-1].lastarc:
747
self.graph.arc_at(current.endpos, current)
751
address = self._current_attr("target")
753
raise Exception("Can't follow a stop arc")
754
self._push(self.graph.arc_at(address))
757
# Override: more efficient implementation manipulating the stack
758
def skip_to(self, key):
764
_follow = self.follow
765
_next_arc = self.next_arc
767
i = self._pop_to_prefix(key)
768
while stack and i < len(key):
769
curlabel = stack[-1].label
771
if curlabel == keylabel:
774
elif curlabel > keylabel:
779
# Override: more efficient implementation using find_arc
780
def switch_to(self, label):
786
if label == current.label:
789
arc = self.graph.find_arc(current.endpos, label, current)
792
def _push(self, arc):
793
if self.vtype and self.stack:
795
sums.append(self.vtype.add(sums[-1], self.stack[-1].value))
796
self.stack.append(arc)
803
def _pop_to_prefix(self, key):
809
maxpre = min(len(stack), len(key))
810
while i < maxpre and key[i] == stack[i].label:
812
if stack[i].label > key[i]:
815
while len(stack) > i + 1:
821
class UncompiledNode(object):
822
# Represents an "in-memory" node used by the GraphWriter before it is
827
def __init__(self, owner):
839
return "<%r>" % ([(a.label, a.value) for a in self.arcs],)
842
if self._digest is None:
844
vtype = self.owner.vtype
845
for arc in self.arcs:
848
d.update(pack_long(arc.target))
852
d.update(vtype.to_bytes(arc.value))
855
self._digest = d.digest()
861
def last_value(self, label):
    """Return the value stored on the last arc in ``self.arcs``.

    :param label: the label the last arc is expected to carry; a mismatch
        trips an assertion.
    """
    newest = self.arcs[-1]
    assert newest.label == label
    return newest.value
865
def add_arc(self, label, target):
    """Append a new ``Arc(label, target)`` to the end of ``self.arcs``.

    :param label: label for the new arc.
    :param target: target for the new arc.
    """
    new_arc = Arc(label, target)
    self.arcs.append(new_arc)
868
def replace_last(self, label, target, accept, acceptval=None):
870
assert arc.label == label, "%r != %r" % (arc.label, label)
873
arc.acceptval = acceptval
875
def delete_last(self, label, target):
    """Remove the last arc from ``self.arcs``.

    The arc is popped first and then checked: assertions verify that it
    carried the expected ``label`` and ``target``.

    :param label: label the removed arc is expected to have.
    :param target: target the removed arc is expected to have.
    """
    removed = self.arcs.pop()
    assert removed.label == label
    assert removed.target == target
880
def set_last_value(self, label, value):
882
assert arc.label == label, "%r->%r" % (arc.label, label)
885
def prepend_value(self, prefix):
886
add = self.owner.vtype.add
887
for arc in self.arcs:
888
arc.value = add(prefix, arc.value)
890
self.value = add(prefix, self.value)
895
Represents a directed arc between two nodes in an FSA/FST graph.
897
The ``lastarc`` attribute is True if this is the last outgoing arc from the
901
__slots__ = ("label", "target", "accept", "value", "lastarc", "acceptval",
904
def __init__(self, label=None, target=None, value=None, accept=False,
907
:param label: The label bytes for this arc. For a word graph, this will
909
:param target: The address of the node at the endpoint of this arc.
910
:param value: The inner FST value at the endpoint of this arc.
911
:param accept: Whether the endpoint of this arc is an accept state
912
(eg the end of a valid word).
913
:param acceptval: If the endpoint of this arc is an accept state, the
914
final FST value for that accepted state.
922
self.acceptval = acceptval
926
return "<%r-%s %s%s>" % (self.label, self.target,
927
"." if self.accept else "",
928
(" %r" % self.value) if self.value else "")
930
def __eq__(self, other):
931
if (isinstance(other, self.__class__) and self.accept == other.accept
932
and self.lastarc == other.lastarc and self.target == other.target
933
and self.value == other.value and self.label == other.label):
940
class GraphWriter(object):
941
"""Writes an FSA/FST graph to disk.
943
Call ``insert(key)`` to insert keys into the graph. You must
944
insert keys in sorted order. Call ``close()`` to finish the graph and close
947
>>> gw = GraphWriter(my_file)
948
>>> gw.insert("alfa")
949
>>> gw.insert("bravo")
950
>>> gw.insert("charlie")
953
The graph writer can write separate graphs for multiple fields. Use
954
``start_field(name)`` and ``finish_field()`` to separate fields.
956
>>> gw = GraphWriter(my_file)
957
>>> gw.start_field("content")
958
>>> gw.insert("alfalfa")
959
>>> gw.insert("apple")
960
>>> gw.finish_field()
961
>>> gw.start_field("title")
962
>>> gw.insert("artichoke")
963
>>> gw.finish_field()
969
def __init__(self, dbfile, vtype=None, merge=None):
971
:param dbfile: the file to write to.
972
:param vtype: a :class:`Values` class to use for storing values. This
973
is only necessary if you will be storing values for the keys.
974
:param merge: a function that takes two values and returns a single
975
value. This is called if you insert two identical keys with values.
986
dbfile.write(b("GRPH"))
987
dbfile.write_int(self.version)
990
self._infield = False
992
def start_field(self, fieldname):
993
"""Starts a new graph for the given field.
997
raise ValueError("Field name cannot be equivalent to False")
1000
self.fieldname = fieldname
1002
self.nodes = [UncompiledNode(self)]
1004
self._inserted = False
1005
self._infield = True
1007
def finish_field(self):
1008
"""Finishes the graph for the current field.
1011
if not self._infield:
1012
raise Exception("Called finish_field before start_field")
1013
self._infield = False
1015
self.fieldroots[self.fieldname] = self._finish()
1016
self.fieldname = None
1019
"""Finishes the current graph and closes the underlying file.
1022
if self.fieldname is not None:
1024
dbfile = self.dbfile
1025
here = dbfile.tell()
1026
dbfile.write_pickle(self.fieldroots)
1028
dbfile.seek(4 + _INT_SIZE) # Seek past magic and version number
1029
dbfile.write_uint(here)
1032
def insert(self, key, value=None):
1033
"""Inserts the given key into the graph.
1035
:param key: a sequence of bytes objects, a bytes object, or a string.
1036
:param value: an optional value to encode in the graph along with the
1037
key. If the writer was not instantiated with a value type, passing
1038
a value here will raise an error.
1041
if not self._infield:
1042
raise Exception("Inserted %r before starting a field" % key)
1043
self._inserted = True
1044
key = to_labels(key) # Python 3 sucks
1047
lastkey = self.lastkey
1050
raise KeyError("Can't store a null key %r" % (key,))
1051
if lastkey and lastkey > key:
1052
raise KeyError("Keys out of order %r..%r" % (lastkey, key))
1054
# Find the common prefix shared by this key and the previous one
1056
for i in xrange(min(len(lastkey), len(key))):
1057
if lastkey[i] != key[i]:
1060
# Compile the nodes after the prefix, since they're not shared
1061
self._freeze_tail(prefixlen + 1)
1063
# Create new nodes for the parts of this key after the shared prefix
1064
for char in key[prefixlen:]:
1065
node = UncompiledNode(self)
1066
# Create an arc to this node on the previous node
1067
nodes[-1].add_arc(char, node)
1069
# Mark the last node as an accept state
1070
lastnode = nodes[-1]
1071
lastnode.accept = True
1074
if value is not None and not vtype.is_valid(value):
1075
raise ValueError("%r is not valid for %s" % (value, vtype))
1077
# Push value commonalities through the tree
1079
for i in xrange(1, prefixlen + 1):
1081
parent = nodes[i - 1]
1082
lastvalue = parent.last_value(key[i - 1])
1083
if lastvalue is not None:
1084
common = vtype.common(value, lastvalue)
1085
suffix = vtype.subtract(lastvalue, common)
1086
parent.set_last_value(key[i - 1], common)
1087
node.prepend_value(suffix)
1089
common = suffix = None
1090
value = vtype.subtract(value, common)
1093
# If this key is a duplicate, merge its value with the value of
1094
# the previous (same) key
1095
lastnode.value = self.merge(lastnode.value, value)
1097
nodes[prefixlen].set_last_value(key[prefixlen], value)
1099
raise Exception("Value %r but no value type" % value)
1103
def _freeze_tail(self, prefixlen):
1105
lastkey = self.lastkey
1106
downto = max(1, prefixlen)
1108
while len(nodes) > downto:
1111
inlabel = lastkey[len(nodes) - 1]
1113
self._compile_targets(node)
1114
accept = node.accept or len(node.arcs) == 0
1115
address = self._compile_node(node)
1116
parent.replace_last(inlabel, address, accept, node.value)
1121
# Minimize nodes in the last word's suffix
1122
self._freeze_tail(0)
1123
# Compile remaining targets
1124
self._compile_targets(root)
1125
return self._compile_node(root)
1127
def _compile_targets(self, node):
1128
for arc in node.arcs:
1129
if isinstance(arc.target, UncompiledNode):
1131
if len(n.arcs) == 0:
1132
arc.accept = n.accept = True
1133
arc.target = self._compile_node(n)
1135
def _compile_node(self, uncnode):
1138
if len(uncnode.arcs) == 0:
1140
address = self._write_node(uncnode)
1142
d = uncnode.digest()
1143
address = seen.get(d)
1145
address = self._write_node(uncnode)
1149
def _write_node(self, uncnode):
1151
dbfile = self.dbfile
1159
# What does it mean for an arc to stop but not be accepted?
1161
self.node_count += 1
1163
buf = StructFile(BytesIO())
1164
nodestart = dbfile.tell()
1166
#self.arccount += numarcs
1169
arcstart = buf.tell()
1170
for i, arc in enumerate(arcs):
1177
flags += MULTIBYTE_LABEL
1178
if i == numarcs - 1:
1184
if arc.value is not None:
1185
flags += ARC_HAS_VAL
1186
if arc.acceptval is not None:
1187
flags += ARC_HAS_ACCEPT_VAL
1189
buf.write(pack_byte(flags))
1191
buf.write(varint(len(label)))
1193
if target is not None:
1194
buf.write(pack_uint(target))
1195
if arc.value is not None:
1196
vtype.write(buf, arc.value)
1197
if arc.acceptval is not None:
1198
vtype.write(buf, arc.acceptval)
1201
thissize = here - arcstart
1204
fixedsize = thissize
1205
elif fixedsize > 0 and thissize != fixedsize:
1209
# Write a fake arc containing the fixed size and number of arcs
1210
dbfile.write_byte(255) # FIXED_SIZE
1211
dbfile.write_int(fixedsize)
1212
dbfile.write_int(numarcs)
1213
self.fixed_count += 1
1214
dbfile.write(buf.file.getvalue())
1221
class BaseGraphReader(object):
1222
def cursor(self, rootname=None):
    """Return a new ``Cursor`` over this graph.

    :param rootname: passed to ``self.root()`` to obtain the root handed
        to the cursor.
    """
    start = self.root(rootname)
    return Cursor(self, start)
1225
def has_root(self, rootname):
    """Abstract: this base implementation always raises.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError
1228
def root(self, rootname=None):
    """Abstract: this base implementation always raises.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError
1233
def arc_at(self, address, arc=None):
    """Abstract: this base implementation always raises.

    ``arc`` is optional, matching ``GraphReader.arc_at`` and the
    single-argument ``graph.arc_at(address)`` calls made by ``Cursor``.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError
1236
def iter_arcs(self, address, arc=None):
    """Abstract: this base implementation always raises.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError
1239
def find_arc(self, address, label, arc=None):
1241
for arc in self.iter_arcs(address, arc):
1242
thislabel = arc.label
1243
if thislabel == label:
1245
elif thislabel > label:
1248
# Convenience methods
1250
def list_arcs(self, address):
    """Return a list holding a shallow copy of every arc produced by
    ``self.iter_arcs(address)``.

    Each arc is copied as it is yielded, so the result holds distinct
    objects even if the iterator reuses a single arc instance.
    """
    return [copy.copy(arc) for arc in self.iter_arcs(address)]
1253
def arc_dict(self, address):
    """Return a dict mapping each arc's label to a shallow copy of that
    arc, for every arc produced by ``self.iter_arcs(address)``.
    """
    pairs = ((arc.label, copy.copy(arc))
             for arc in self.iter_arcs(address))
    return dict(pairs)
1257
def find_path(self, path, arc=None, address=None):
1258
path = to_labels(path)
1261
address = arc.target
1266
address = self._root
1271
if not self.find_arc(address, label, arc):
1273
address = arc.target
1277
class GraphReader(BaseGraphReader):
1278
def __init__(self, dbfile, rootname=None, vtype=None, filebase=0):
1279
self.dbfile = dbfile
1281
self.filebase = filebase
1283
dbfile.seek(filebase)
1284
magic = dbfile.read(4)
1285
if magic != b("GRPH"):
1286
raise FileVersionError
1287
self.version = dbfile.read_int()
1288
dbfile.seek(dbfile.read_uint())
1289
self.roots = dbfile.read_pickle()
1292
if rootname is None and len(self.roots) == 1:
1293
# If there's only one root, just use it. Have to wrap a list around
1294
# the keys() method here because of Python 3.
1295
rootname = list(self.roots.keys())[0]
1296
if rootname is not None:
1297
self._root = self.root(rootname)
1304
def has_root(self, rootname):
    """Return True if ``rootname`` is a key of ``self.roots``."""
    known_roots = self.roots
    return rootname in known_roots
1307
def root(self, rootname=None):
1308
if rootname is None:
1311
return self.roots[rootname]
1313
def default_root(self):
1316
def arc_at(self, address, arc=None):
1318
self.dbfile.seek(address)
1319
return self._read_arc(arc)
1321
def iter_arcs(self, address, arc=None):
1323
_read_arc = self._read_arc
1325
self.dbfile.seek(address)
1332
def find_arc(self, address, label, arc=None):
1334
dbfile = self.dbfile
1335
dbfile.seek(address)
1337
# If records are fixed size, we can do a binary search
1338
finfo = self._read_fixed_info()
1341
address = dbfile.tell()
1343
return self._binary_search(address, size, count, label, arc)
1345
# If records aren't fixed size, fall back to the parent's linear
1347
return BaseGraphReader.find_arc(self, address, label, arc)
1351
def _read_arc(self, toarc=None):
1352
toarc = toarc or Arc()
1353
dbfile = self.dbfile
1354
flags = dbfile.read_byte()
1356
# This is a fake arc containing fixed size information; skip it
1357
# and read the next arc
1358
dbfile.seek(_INT_SIZE * 2, 1)
1359
flags = dbfile.read_byte()
1360
toarc.label = self._read_label(flags)
1361
return self._read_arc_data(flags, toarc)
1363
def _read_label(self, flags):
1364
dbfile = self.dbfile
1365
if flags & MULTIBYTE_LABEL:
1366
length = dbfile.read_varint()
1369
label = dbfile.read(length)
1372
def _read_fixed_info(self):
1373
dbfile = self.dbfile
1375
flags = dbfile.read_byte()
1377
size = dbfile.read_int()
1378
count = dbfile.read_int()
1379
return (size, count)
1383
def _read_arc_data(self, flags, arc):
1384
dbfile = self.dbfile
1385
accept = arc.accept = bool(flags & ARC_ACCEPT)
1386
arc.lastarc = flags & ARC_LAST
1387
if flags & ARC_STOP:
1390
arc.target = dbfile.read_uint()
1391
if flags & ARC_HAS_VAL:
1392
arc.value = self.vtype.read(dbfile)
1395
if accept and flags & ARC_HAS_ACCEPT_VAL:
1396
arc.acceptval = self.vtype.read(dbfile)
1397
arc.endpos = dbfile.tell()
1400
def _binary_search(self, address, size, count, label, arc):
1401
dbfile = self.dbfile
1402
_read_label = self._read_label
1407
mid = (lo + hi) // 2
1408
midaddr = address + mid * size
1409
dbfile.seek(midaddr)
1410
flags = dbfile.read_byte()
1411
midlabel = self._read_label(flags)
1412
if midlabel == label:
1413
arc.label = midlabel
1414
return self._read_arc_data(flags, arc)
1415
elif midlabel < label:
1424
"""Takes a string and returns a list of bytestrings, suitable for use as
1425
a key or path in an FSA/FST graph.
1428
# Convert to tuples of bytestrings (must be tuples so they can be hashed)
1431
# I hate the Python 3 bytes object so friggin much
1432
if keytype is tuple or keytype is list:
1433
if not all(isinstance(e, bytes_type) for e in key):
1434
raise TypeError("%r contains a non-bytestring")
1437
elif isinstance(key, bytes_type):
1438
key = tuple(key[i:i + 1] for i in xrange(len(key)))
1439
elif isinstance(key, text_type):
1440
key = tuple(utf8encode(key[i:i + 1])[0] for i in xrange(len(key)))
1442
raise TypeError("Don't know how to convert %r" % key)
1446
# Within edit distance function
1448
def within(graph, text, k=1, prefix=0, address=None):
1449
"""Yields a series of keys in the given graph within ``k`` edit distance of
1450
``text``. If ``prefix`` is greater than 0, all keys must match the first
1451
``prefix`` characters of ``text``.
1454
text = to_labels(text)
1456
address = graph._root
1461
prefixchars = text[:prefix]
1462
arc = graph.find_path(prefixchars, address=address)
1465
sofar = emptybytes.join(prefixchars)
1466
address = arc.target
1469
stack = [(address, k, prefix, sofar, accept)]
1473
# Have we already tried this state?
1478
address, k, i, sofar, accept = state
1479
# If we're at the end of the text (or deleting enough chars would get
1480
# us to the end and still within K), and we're in the accept state,
1481
# yield the current result
1482
if (len(text) - i <= k) and accept:
1483
yield utf8decode(sofar)[0]
1485
# If we're in the stop state, give up
1491
arc = graph.find_arc(address, text[i])
1493
stack.append((arc.target, k, i + 1, sofar + text[i],
1495
# If K is already 0, can't do any more edits
1500
arcs = graph.arc_dict(address)
1502
stack.extend((arc.target, k, i, sofar + char, arc.accept)
1503
for char, arc in iteritems(arcs))
1505
# Deletion, replacement, and transpo only work before the end
1511
stack.append((address, k, i + 1, sofar, False))
1513
for char2, arc in iteritems(arcs):
1515
stack.append((arc.target, k, i + 1, sofar + char2, arc.accept))
1517
if i < len(text) - 1:
1519
if char != char2 and char2 in arcs:
1520
# Find arc from next char to this char
1521
target = arcs[char2].target
1523
arc = graph.find_arc(target, char)
1525
stack.append((arc.target, k, i + 2,
1526
sofar + char2 + char, arc.accept))
1531
def dump_graph(graph, address=None, tab=0, out=None):
1533
address = graph._root
1537
here = "%06d" % address
1538
for i, arc in enumerate(graph.list_arcs(address)):
1543
out.write(" " * tab)
1544
out.write("%r %r %s %r\n"
1545
% (arc.label, arc.target, arc.accept, arc.value))
1546
if arc.target is not None:
1547
dump_graph(graph, arc.target, tab + 1, out=out)