2
# Copyright (C) 2005-2011 the SQLAlchemy authors and contributors <see AUTHORS file>
4
# This module is part of SQLAlchemy and is released under
5
# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
"""Collection classes and helpers."""
13
from langhelpers import symbol
14
from compat import time_func, threading
16
EMPTY_SET = frozenset()
19
class NamedTuple(tuple):
20
"""tuple() subclass that adds labeled names.
26
def __new__(cls, vals, labels=None):
27
t = tuple.__new__(cls, vals)
29
t.__dict__.update(zip(labels, vals))
34
return [l for l in self._labels if l is not None]
36
class ImmutableContainer(object):
    """Mixin that turns item and attribute mutation into ``TypeError``."""

    def _immutable(self, *arg, **kw):
        """Refuse the operation; accepts any arguments and always raises.

        :raises TypeError: always, naming the concrete class.
        """
        raise TypeError("%s object is immutable" % self.__class__.__name__)

    # Item assignment, item deletion and attribute assignment all share
    # the same refusal.
    __delitem__ = __setitem__ = __setattr__ = _immutable
42
class immutabledict(ImmutableContainer, dict):
44
clear = pop = popitem = setdefault = \
45
update = ImmutableContainer._immutable
47
def __new__(cls, *args):
48
new = dict.__new__(cls)
49
dict.__init__(new, *args)
52
def __init__(self, *args):
56
return immutabledict, (dict(self), )
60
return immutabledict(d)
62
d2 = immutabledict(self)
67
return "immutabledict(%s)" % dict.__repr__(self)
69
class Properties(object):
    """Provide a __getattr__/__setattr__ interface over a dict."""

    def __init__(self, data):
        """Wrap *data*, the mapping that backs all item/attribute access."""
        # Store through __dict__ directly: __setattr__ is overridden below
        # and would otherwise write '_data' into the backing mapping.
        self.__dict__['_data'] = data

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        # Iterates the stored values, not the keys.
        return self._data.itervalues()

    def __add__(self, other):
        """Return a plain list: this object's values followed by *other*'s."""
        return list(self) + list(other)

    def __setitem__(self, key, object):
        self._data[key] = object

    def __getitem__(self, key):
        return self._data[key]

    def __delitem__(self, key):
        del self._data[key]

    def __setattr__(self, key, object):
        # Attribute assignment is routed into the backing mapping.
        self._data[key] = object

    def __getstate__(self):
        return {'_data': self.__dict__['_data']}

    def __setstate__(self, state):
        self.__dict__['_data'] = state['_data']

    def __getattr__(self, key):
        """Look *key* up in the backing mapping.

        :raises AttributeError: if *key* is not present.
        """
        try:
            return self._data[key]
        except KeyError:
            raise AttributeError(key)

    def __contains__(self, key):
        return key in self._data

    def as_immutable(self):
        """Return an immutable proxy for this :class:`.Properties`."""

        return ImmutableProperties(self._data)

    def update(self, value):
        self._data.update(value)

    def get(self, key, default=None):
        """Return the value stored under *key*, or *default* if absent."""
        if key in self:
            return self[key]
        else:
            return default

    def keys(self):
        return self._data.keys()

    def has_key(self, key):
        return key in self._data
134
class OrderedProperties(Properties):
    """Provide a __getattr__/__setattr__ interface with an OrderedDict
    as backing store."""

    def __init__(self):
        # Always starts empty; ordering comes from the OrderedDict.
        Properties.__init__(self, OrderedDict())
141
class ImmutableProperties(ImmutableContainer, Properties):
    """Provide immutable dict/object attribute to an underlying dictionary."""
145
class OrderedDict(dict):
    """A dict that returns keys/values/items in the order they were added.

    Insertion order is tracked in ``self._list``; every mutating method
    keeps that list in step with the underlying dict.
    """

    def __init__(self, ____sequence=None, **kwargs):
        """Build from an optional mapping / iterable of pairs plus keywords."""
        self._list = []
        if ____sequence is None:
            if kwargs:
                self.update(**kwargs)
        else:
            self.update(____sequence, **kwargs)

    def clear(self):
        # Reset the order list together with the contents.
        self._list = []
        dict.clear(self)

    def copy(self):
        return self.__copy__()

    def __copy__(self):
        return OrderedDict(self)

    def sort(self, *arg, **kw):
        """Reorder the keys in place; arguments are those of ``list.sort``."""
        self._list.sort(*arg, **kw)

    def update(self, ____sequence=None, **kwargs):
        """Add items from a mapping or an iterable of pairs, then keywords."""
        if ____sequence is not None:
            if hasattr(____sequence, 'keys'):
                for key in ____sequence.keys():
                    self.__setitem__(key, ____sequence[key])
            else:
                for key, value in ____sequence:
                    self[key] = value
        if kwargs:
            self.update(kwargs)

    def setdefault(self, key, value):
        """Return the value for *key*, storing *value* first if absent."""
        if key not in self:
            self.__setitem__(key, value)
            return value
        else:
            return self.__getitem__(key)

    def __iter__(self):
        return iter(self._list)

    def values(self):
        return [self[key] for key in self._list]

    def itervalues(self):
        return iter([self[key] for key in self._list])

    def keys(self):
        # Hand out a copy so callers cannot disturb the order list.
        return list(self._list)

    def iterkeys(self):
        return iter(self.keys())

    def items(self):
        return [(key, self[key]) for key in self.keys()]

    def iteritems(self):
        return iter(self.items())

    def __setitem__(self, key, object):
        if key not in self:
            try:
                self._list.append(key)
            except AttributeError:
                # work around Python pickle loads() with
                # dict subclass (seems to ignore __setstate__?)
                self._list = [key]
        dict.__setitem__(self, key, object)

    def __delitem__(self, key):
        dict.__delitem__(self, key)
        self._list.remove(key)

    def pop(self, key, *default):
        """Remove *key* and return its value, or *default* if given."""
        present = key in self
        value = dict.pop(self, key, *default)
        # Only touch the order list when the key was really stored.
        if present:
            self._list.remove(key)
        return value

    def popitem(self):
        item = dict.popitem(self)
        self._list.remove(item[0])
        return item
234
class OrderedSet(set):
    """A set that iterates its elements in insertion order.

    Order is tracked in ``self._list``; membership tests use the
    underlying ``set``.  In-place operations return ``self`` so that the
    augmented-assignment aliases (``&=``, ``^=``, ``-=``, ``|=``) work.
    """

    def __init__(self, d=None):
        set.__init__(self)
        self._list = []
        if d is not None:
            self.update(d)

    def add(self, element):
        if element not in self:
            self._list.append(element)
        set.add(self, element)

    def remove(self, element):
        # set.remove raises KeyError first if the element is absent.
        set.remove(self, element)
        self._list.remove(element)

    def insert(self, pos, element):
        """Add *element* at index *pos*; no effect if already present."""
        if element not in self:
            self._list.insert(pos, element)
        set.add(self, element)

    def discard(self, element):
        if element in self:
            self._list.remove(element)
            set.remove(self, element)

    def clear(self):
        set.clear(self)
        self._list = []

    def __getitem__(self, key):
        return self._list[key]

    def __iter__(self):
        return iter(self._list)

    def __add__(self, other):
        return self.union(other)

    def __repr__(self):
        return '%s(%r)' % (self.__class__.__name__, self._list)

    __str__ = __repr__

    def update(self, iterable):
        for element in iterable:
            self.add(element)
        return self

    __ior__ = update

    def union(self, other):
        result = self.__class__(self)
        result.update(other)
        return result

    __or__ = union

    def intersection(self, other):
        # Materialise once: *other* may be a one-shot iterator.
        other = set(other)
        return self.__class__(a for a in self if a in other)

    __and__ = intersection

    def symmetric_difference(self, other):
        # Needs both membership tests and a second pass over *other*, so
        # anything that is not already a set is collected first (keeping
        # its order).
        if not isinstance(other, (set, frozenset)):
            other = self.__class__(other)
        result = self.__class__(a for a in self if a not in other)
        result.update(a for a in other if a not in self)
        return result

    __xor__ = symmetric_difference

    def difference(self, other):
        other = set(other)
        return self.__class__(a for a in self if a not in other)

    __sub__ = difference

    def intersection_update(self, other):
        other = set(other)
        set.intersection_update(self, other)
        self._list = [a for a in self._list if a in other]
        return self

    __iand__ = intersection_update

    def symmetric_difference_update(self, other):
        # Accept any iterable, not only objects exposing ``_list``.
        if not isinstance(other, (set, frozenset)):
            other = self.__class__(other)
        set.symmetric_difference_update(self, other)
        self._list = [a for a in self._list if a in self]
        self._list += [a for a in other if a in self]
        return self

    __ixor__ = symmetric_difference_update

    def difference_update(self, other):
        set.difference_update(self, other)
        self._list = [a for a in self._list if a in self]
        return self

    __isub__ = difference_update
338
class IdentitySet(object):
339
"""A set that considers only object id() for uniqueness.
341
This strategy has edge cases for builtin types- it's possible to have
342
two 'foo' strings in one of these sets, for example. Use sparingly.
348
def __init__(self, iterable=None):
349
self._members = dict()
354
def add(self, value):
355
self._members[id(value)] = value
357
def __contains__(self, value):
358
return id(value) in self._members
360
def remove(self, value):
361
del self._members[id(value)]
363
def discard(self, value):
371
pair = self._members.popitem()
374
raise KeyError('pop from an empty set')
377
self._members.clear()
379
def __cmp__(self, other):
380
raise TypeError('cannot compare sets using cmp()')
382
def __eq__(self, other):
383
if isinstance(other, IdentitySet):
384
return self._members == other._members
388
def __ne__(self, other):
389
if isinstance(other, IdentitySet):
390
return self._members != other._members
394
def issubset(self, iterable):
395
other = type(self)(iterable)
397
if len(self) > len(other):
399
for m in itertools.ifilterfalse(other._members.__contains__,
400
self._members.iterkeys()):
404
def __le__(self, other):
405
if not isinstance(other, IdentitySet):
406
return NotImplemented
407
return self.issubset(other)
409
def __lt__(self, other):
410
if not isinstance(other, IdentitySet):
411
return NotImplemented
412
return len(self) < len(other) and self.issubset(other)
414
def issuperset(self, iterable):
415
other = type(self)(iterable)
417
if len(self) < len(other):
420
for m in itertools.ifilterfalse(self._members.__contains__,
421
other._members.iterkeys()):
425
def __ge__(self, other):
426
if not isinstance(other, IdentitySet):
427
return NotImplemented
428
return self.issuperset(other)
430
def __gt__(self, other):
431
if not isinstance(other, IdentitySet):
432
return NotImplemented
433
return len(self) > len(other) and self.issuperset(other)
435
def union(self, iterable):
436
result = type(self)()
437
# testlib.pragma exempt:__hash__
438
result._members.update(
439
self._working_set(self._member_id_tuples()).union(_iter_id(iterable)))
442
def __or__(self, other):
443
if not isinstance(other, IdentitySet):
444
return NotImplemented
445
return self.union(other)
447
def update(self, iterable):
448
self._members = self.union(iterable)._members
450
def __ior__(self, other):
451
if not isinstance(other, IdentitySet):
452
return NotImplemented
456
def difference(self, iterable):
457
result = type(self)()
458
# testlib.pragma exempt:__hash__
459
result._members.update(
460
self._working_set(self._member_id_tuples()).difference(_iter_id(iterable)))
463
def __sub__(self, other):
464
if not isinstance(other, IdentitySet):
465
return NotImplemented
466
return self.difference(other)
468
def difference_update(self, iterable):
469
self._members = self.difference(iterable)._members
471
def __isub__(self, other):
472
if not isinstance(other, IdentitySet):
473
return NotImplemented
474
self.difference_update(other)
477
def intersection(self, iterable):
478
result = type(self)()
479
# testlib.pragma exempt:__hash__
480
result._members.update(
481
self._working_set(self._member_id_tuples()).intersection(_iter_id(iterable)))
484
def __and__(self, other):
485
if not isinstance(other, IdentitySet):
486
return NotImplemented
487
return self.intersection(other)
489
def intersection_update(self, iterable):
490
self._members = self.intersection(iterable)._members
492
def __iand__(self, other):
493
if not isinstance(other, IdentitySet):
494
return NotImplemented
495
self.intersection_update(other)
498
def symmetric_difference(self, iterable):
499
result = type(self)()
500
# testlib.pragma exempt:__hash__
501
result._members.update(
502
self._working_set(self._member_id_tuples()).symmetric_difference(_iter_id(iterable)))
505
def _member_id_tuples(self):
506
return ((id(v), v) for v in self._members.itervalues())
508
def __xor__(self, other):
509
if not isinstance(other, IdentitySet):
510
return NotImplemented
511
return self.symmetric_difference(other)
513
def symmetric_difference_update(self, iterable):
514
self._members = self.symmetric_difference(iterable)._members
516
def __ixor__(self, other):
    """In-place symmetric difference (``self ^= other``).

    Returns ``NotImplemented`` unless *other* is an ``IdentitySet``.
    """
    if not isinstance(other, IdentitySet):
        return NotImplemented
    # Must call the *_update variant: symmetric_difference() builds a new
    # set and would leave this one unchanged.
    self.symmetric_difference_update(other)
    return self
523
return type(self)(self._members.itervalues())
528
return len(self._members)
531
return self._members.itervalues()
534
raise TypeError('set objects are unhashable')
537
return '%s(%r)' % (type(self).__name__, self._members.values())
540
class OrderedIdentitySet(IdentitySet):
541
class _working_set(OrderedSet):
542
# a testing pragma: exempt the OIDS working set from the test suite's
543
# "never call the user's __hash__" assertions. this is a big hammer,
544
# but it's safe here: IDS operates on (id, instance) tuples in the
546
__sa_hash_exempt__ = True
548
def __init__(self, iterable=None):
549
IdentitySet.__init__(self)
550
self._members = OrderedDict()
556
if sys.version_info >= (2, 5):
    class PopulateDict(dict):
        """A dict which populates missing values via a creation function.

        Note the creation function takes a key, unlike
        collections.defaultdict.

        """

        def __init__(self, creator):
            self.creator = creator

        def __missing__(self, key):
            # Store the created value so the creator runs once per key.
            self[key] = val = self.creator(key)
            return val
else:
    class PopulateDict(dict):
        """A dict which populates missing values via a creation function."""

        def __init__(self, creator):
            self.creator = creator

        def __getitem__(self, key):
            try:
                return dict.__getitem__(self, key)
            except KeyError:
                self[key] = value = self.creator(key)
                return value
585
# define collections that are capable of storing
586
# ColumnElement objects as hashable keys/elements.
589
ordered_column_set = OrderedSet
590
populate_column_dict = PopulateDict
592
def unique_list(seq, hashfunc=None):
    """Return the items of *seq* in first-seen order, duplicates removed.

    :param seq: iterable of items.
    :param hashfunc: optional callable mapping an item to the hashable
      value used to decide whether it has been seen; when omitted the
      item itself is used.
    """
    seen = set()
    result = []
    for item in seq:
        # Compute the marker once per item.
        marker = hashfunc(item) if hashfunc else item
        if marker not in seen:
            seen.add(marker)
            result.append(item)
    return result
603
class UniqueAppender(object):
    """Appends items to a collection ensuring uniqueness.

    Additional appends() of the same object are ignored.  Membership is
    determined by identity (``is a``) not equality (``==``).
    """

    def __init__(self, data, via=None):
        """Wrap *data*.

        :param data: the target collection.
        :param via: optional name of the method on *data* used to add an
          item; otherwise ``append`` is used if present, else ``add``.
        """
        self.data = data
        # Maps id(item) -> True for every item already forwarded.
        self._unique = {}
        if via:
            self._data_appender = getattr(data, via)
        elif hasattr(data, 'append'):
            self._data_appender = data.append
        elif hasattr(data, 'add'):
            self._data_appender = data.add

    def append(self, item):
        id_ = id(item)
        if id_ not in self._unique:
            self._data_appender(item)
            self._unique[id_] = True

    def __iter__(self):
        return iter(self.data)
629
def to_list(x, default=None):
    """Coerce *x* to a list-like value.

    ``None`` yields *default*; a list or tuple is returned unchanged;
    anything else is wrapped in a one-element list.
    """
    if x is None:
        return default
    if not isinstance(x, (list, tuple)):
        return [x]
    else:
        return x
640
if not isinstance(x, set):
641
return set(to_list(x))
645
def to_column_set(x):
648
if not isinstance(x, column_set):
649
return column_set(to_list(x))
653
def update_copy(d, _new=None, **kw):
654
"""Copy the given dict and update with the given values."""
662
def flatten_iterator(x):
663
"""Given an iterator of which further sub-elements may also be
664
iterators, flatten the sub-elements into a single iterator.
668
if not isinstance(elem, basestring) and hasattr(elem, '__iter__'):
669
for y in flatten_iterator(elem):
674
class WeakIdentityMapping(weakref.WeakKeyDictionary):
675
"""A WeakKeyDictionary with an object identity index.
677
Adds a .by_id dictionary to a regular WeakKeyDictionary. Trades
678
performance during mutation operations for accelerated lookups by id().
680
The usual cautions about weak dictionaries and iteration also apply to
684
_none = symbol('none')
687
weakref.WeakKeyDictionary.__init__(self)
691
def __setitem__(self, object, value):
693
self.by_id[oid] = value
694
if oid not in self._weakrefs:
695
self._weakrefs[oid] = self._ref(object)
696
weakref.WeakKeyDictionary.__setitem__(self, object, value)
698
def __delitem__(self, object):
699
del self._weakrefs[id(object)]
700
del self.by_id[id(object)]
701
weakref.WeakKeyDictionary.__delitem__(self, object)
703
def setdefault(self, object, default=None):
704
value = weakref.WeakKeyDictionary.setdefault(self, object, default)
707
self.by_id[oid] = default
708
if oid not in self._weakrefs:
709
self._weakrefs[oid] = self._ref(object)
712
def pop(self, object, default=_none):
713
if default is self._none:
714
value = weakref.WeakKeyDictionary.pop(self, object)
716
value = weakref.WeakKeyDictionary.pop(self, object, default)
717
if id(object) in self.by_id:
718
del self._weakrefs[id(object)]
719
del self.by_id[id(object)]
723
item = weakref.WeakKeyDictionary.popitem(self)
725
del self._weakrefs[oid]
731
# in 3k, MutableMapping calls popitem()
732
self._weakrefs.clear()
735
weakref.WeakKeyDictionary.clear(self)
737
def update(self, *a, **kw):
738
raise NotImplementedError
740
def _cleanup(self, wr, key=None):
744
del self._weakrefs[key]
745
except (KeyError, AttributeError): # pragma: no cover
746
pass # pragma: no cover
749
except (KeyError, AttributeError): # pragma: no cover
750
pass # pragma: no cover
752
class _keyed_weakref(weakref.ref):
753
def __init__(self, object, callback):
754
weakref.ref.__init__(self, object, callback)
755
self.key = id(object)
757
def _ref(self, object):
758
return self._keyed_weakref(object, self._cleanup)
761
class LRUCache(dict):
762
"""Dictionary with 'squishy' removal of least
766
def __init__(self, capacity=100, threshold=.5):
767
self.capacity = capacity
768
self.threshold = threshold
770
def __getitem__(self, key):
771
item = dict.__getitem__(self, key)
772
item[2] = time_func()
776
return [i[1] for i in dict.values(self)]
778
def setdefault(self, key, value):
785
def __setitem__(self, key, value):
786
item = dict.get(self, key)
788
item = [key, value, time_func()]
789
dict.__setitem__(self, key, item)
794
def _manage_size(self):
795
while len(self) > self.capacity + self.capacity * self.threshold:
796
bytime = sorted(dict.values(self),
797
key=operator.itemgetter(2),
799
for item in bytime[self.capacity:]:
803
# if we couldn't find a key, most
804
# likely some other thread broke in
805
# on us. loop around and try again
809
class ScopedRegistry(object):
810
"""A Registry that can store one or multiple instances of a single
811
class on the basis of a "scope" function.
813
The object implements ``__call__`` as the "getter", so by
814
calling ``myregistry()`` the contained object is returned
815
for the current scope.
818
a callable that returns a new object to be placed in the registry
821
a callable that will return a key to store/retrieve an object.
824
def __init__(self, createfunc, scopefunc):
825
"""Construct a new :class:`.ScopedRegistry`.
827
:param createfunc: A creation function that will generate
828
a new value for the current scope, if none is present.
830
:param scopefunc: A function that returns a hashable
831
token representing the current scope (such as, current
835
self.createfunc = createfunc
836
self.scopefunc = scopefunc
840
key = self.scopefunc()
842
return self.registry[key]
844
return self.registry.setdefault(key, self.createfunc())
847
"""Return True if an object is present in the current scope."""
849
return self.scopefunc() in self.registry
852
"""Set the value forthe current scope."""
854
self.registry[self.scopefunc()] = obj
857
"""Clear the current scope, if any."""
860
del self.registry[self.scopefunc()]
864
class ThreadLocalRegistry(ScopedRegistry):
865
"""A :class:`.ScopedRegistry` that uses a ``threading.local()``
866
variable for storage.
869
def __init__(self, createfunc):
870
self.createfunc = createfunc
871
self.registry = threading.local()
875
return self.registry.value
876
except AttributeError:
877
val = self.registry.value = self.createfunc()
881
return hasattr(self.registry, "value")
884
self.registry.value = obj
888
del self.registry.value
889
except AttributeError:
892
def _iter_id(iterable):
893
"""Generator: ((id(o), o) for o in iterable)."""
895
for item in iterable: