~malept/ubuntu/lucid/python2.6/dev-dependency-fix

« back to all changes in this revision

Viewing changes to Lib/test/test_heapq.py

  • Committer: Bazaar Package Importer
  • Author(s): Matthias Klose
  • Date: 2009-02-13 12:51:00 UTC
  • Revision ID: james.westby@ubuntu.com-20090213125100-uufgcb9yeqzujpqw
Tags: upstream-2.6.1
ImportĀ upstreamĀ versionĀ 2.6.1

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
"""Unittests for heapq."""
 
2
 
 
3
import random
 
4
import unittest
 
5
from test import test_support
 
6
import sys
 
7
 
 
8
# We do a bit of trickery here to be able to test both the C implementation
 
9
# and the Python implementation of the module.
 
10
 
 
11
# Make it impossible to import the C implementation anymore.
 
12
sys.modules['_heapq'] = 0
 
13
# We must also handle the case that heapq was imported before.
 
14
if 'heapq' in sys.modules:
 
15
    del sys.modules['heapq']
 
16
 
 
17
# Now we can import the module and get the pure Python implementation.
 
18
import heapq as py_heapq
 
19
 
 
20
# Restore everything to normal.
 
21
del sys.modules['_heapq']
 
22
del sys.modules['heapq']
 
23
 
 
24
# This is now the module with the C implementation.
 
25
import heapq as c_heapq
 
26
 
 
27
 
 
28
class TestHeap(unittest.TestCase):
 
29
    module = None
 
30
 
 
31
    def test_push_pop(self):
 
32
        # 1) Push 256 random numbers and pop them off, verifying all's OK.
 
33
        heap = []
 
34
        data = []
 
35
        self.check_invariant(heap)
 
36
        for i in range(256):
 
37
            item = random.random()
 
38
            data.append(item)
 
39
            self.module.heappush(heap, item)
 
40
            self.check_invariant(heap)
 
41
        results = []
 
42
        while heap:
 
43
            item = self.module.heappop(heap)
 
44
            self.check_invariant(heap)
 
45
            results.append(item)
 
46
        data_sorted = data[:]
 
47
        data_sorted.sort()
 
48
        self.assertEqual(data_sorted, results)
 
49
        # 2) Check that the invariant holds for a sorted array
 
50
        self.check_invariant(results)
 
51
 
 
52
        self.assertRaises(TypeError, self.module.heappush, [])
 
53
        try:
 
54
            self.assertRaises(TypeError, self.module.heappush, None, None)
 
55
            self.assertRaises(TypeError, self.module.heappop, None)
 
56
        except AttributeError:
 
57
            pass
 
58
 
 
59
    def check_invariant(self, heap):
 
60
        # Check the heap invariant.
 
61
        for pos, item in enumerate(heap):
 
62
            if pos: # pos 0 has no parent
 
63
                parentpos = (pos-1) >> 1
 
64
                self.assert_(heap[parentpos] <= item)
 
65
 
 
66
    def test_heapify(self):
 
67
        for size in range(30):
 
68
            heap = [random.random() for dummy in range(size)]
 
69
            self.module.heapify(heap)
 
70
            self.check_invariant(heap)
 
71
 
 
72
        self.assertRaises(TypeError, self.module.heapify, None)
 
73
 
 
74
    def test_naive_nbest(self):
 
75
        data = [random.randrange(2000) for i in range(1000)]
 
76
        heap = []
 
77
        for item in data:
 
78
            self.module.heappush(heap, item)
 
79
            if len(heap) > 10:
 
80
                self.module.heappop(heap)
 
81
        heap.sort()
 
82
        self.assertEqual(heap, sorted(data)[-10:])
 
83
 
 
84
    def heapiter(self, heap):
 
85
        # An iterator returning a heap's elements, smallest-first.
 
86
        try:
 
87
            while 1:
 
88
                yield self.module.heappop(heap)
 
89
        except IndexError:
 
90
            pass
 
91
 
 
92
    def test_nbest(self):
 
93
        # Less-naive "N-best" algorithm, much faster (if len(data) is big
 
94
        # enough <wink>) than sorting all of data.  However, if we had a max
 
95
        # heap instead of a min heap, it could go faster still via
 
96
        # heapify'ing all of data (linear time), then doing 10 heappops
 
97
        # (10 log-time steps).
 
98
        data = [random.randrange(2000) for i in range(1000)]
 
99
        heap = data[:10]
 
100
        self.module.heapify(heap)
 
101
        for item in data[10:]:
 
102
            if item > heap[0]:  # this gets rarer the longer we run
 
103
                self.module.heapreplace(heap, item)
 
104
        self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
 
105
 
 
106
        self.assertRaises(TypeError, self.module.heapreplace, None)
 
107
        self.assertRaises(TypeError, self.module.heapreplace, None, None)
 
108
        self.assertRaises(IndexError, self.module.heapreplace, [], None)
 
109
 
 
110
    def test_nbest_with_pushpop(self):
 
111
        data = [random.randrange(2000) for i in range(1000)]
 
112
        heap = data[:10]
 
113
        self.module.heapify(heap)
 
114
        for item in data[10:]:
 
115
            self.module.heappushpop(heap, item)
 
116
        self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
 
117
        self.assertEqual(self.module.heappushpop([], 'x'), 'x')
 
118
 
 
119
    def test_heappushpop(self):
 
120
        h = []
 
121
        x = self.module.heappushpop(h, 10)
 
122
        self.assertEqual((h, x), ([], 10))
 
123
 
 
124
        h = [10]
 
125
        x = self.module.heappushpop(h, 10.0)
 
126
        self.assertEqual((h, x), ([10], 10.0))
 
127
        self.assertEqual(type(h[0]), int)
 
128
        self.assertEqual(type(x), float)
 
129
 
 
130
        h = [10];
 
131
        x = self.module.heappushpop(h, 9)
 
132
        self.assertEqual((h, x), ([10], 9))
 
133
 
 
134
        h = [10];
 
135
        x = self.module.heappushpop(h, 11)
 
136
        self.assertEqual((h, x), ([11], 10))
 
137
 
 
138
    def test_heapsort(self):
 
139
        # Exercise everything with repeated heapsort checks
 
140
        for trial in xrange(100):
 
141
            size = random.randrange(50)
 
142
            data = [random.randrange(25) for i in range(size)]
 
143
            if trial & 1:     # Half of the time, use heapify
 
144
                heap = data[:]
 
145
                self.module.heapify(heap)
 
146
            else:             # The rest of the time, use heappush
 
147
                heap = []
 
148
                for item in data:
 
149
                    self.module.heappush(heap, item)
 
150
            heap_sorted = [self.module.heappop(heap) for i in range(size)]
 
151
            self.assertEqual(heap_sorted, sorted(data))
 
152
 
 
153
    def test_merge(self):
 
154
        inputs = []
 
155
        for i in xrange(random.randrange(5)):
 
156
            row = sorted(random.randrange(1000) for j in range(random.randrange(10)))
 
157
            inputs.append(row)
 
158
        self.assertEqual(sorted(chain(*inputs)), list(self.module.merge(*inputs)))
 
159
        self.assertEqual(list(self.module.merge()), [])
 
160
 
 
161
    def test_merge_stability(self):
 
162
        class Int(int):
 
163
            pass
 
164
        inputs = [[], [], [], []]
 
165
        for i in range(20000):
 
166
            stream = random.randrange(4)
 
167
            x = random.randrange(500)
 
168
            obj = Int(x)
 
169
            obj.pair = (x, stream)
 
170
            inputs[stream].append(obj)
 
171
        for stream in inputs:
 
172
            stream.sort()
 
173
        result = [i.pair for i in self.module.merge(*inputs)]
 
174
        self.assertEqual(result, sorted(result))
 
175
 
 
176
    def test_nsmallest(self):
 
177
        data = [(random.randrange(2000), i) for i in range(1000)]
 
178
        for f in (None, lambda x:  x[0] * 547 % 2000):
 
179
            for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
 
180
                self.assertEqual(self.module.nsmallest(n, data), sorted(data)[:n])
 
181
                self.assertEqual(self.module.nsmallest(n, data, key=f),
 
182
                                 sorted(data, key=f)[:n])
 
183
 
 
184
    def test_nlargest(self):
 
185
        data = [(random.randrange(2000), i) for i in range(1000)]
 
186
        for f in (None, lambda x:  x[0] * 547 % 2000):
 
187
            for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
 
188
                self.assertEqual(self.module.nlargest(n, data),
 
189
                                 sorted(data, reverse=True)[:n])
 
190
                self.assertEqual(self.module.nlargest(n, data, key=f),
 
191
                                 sorted(data, key=f, reverse=True)[:n])
 
192
 
 
193
class TestHeapPython(TestHeap):
 
194
    module = py_heapq
 
195
 
 
196
class TestHeapC(TestHeap):
 
197
    module = c_heapq
 
198
 
 
199
    def test_comparison_operator(self):
 
200
        # Issue 3501: Make sure heapq works with both __lt__ and __le__
 
201
        def hsort(data, comp):
 
202
            data = map(comp, data)
 
203
            self.module.heapify(data)
 
204
            return [self.module.heappop(data).x for i in range(len(data))]
 
205
        class LT:
 
206
            def __init__(self, x):
 
207
                self.x = x
 
208
            def __lt__(self, other):
 
209
                return self.x > other.x
 
210
        class LE:
 
211
            def __init__(self, x):
 
212
                self.x = x
 
213
            def __le__(self, other):
 
214
                return self.x >= other.x
 
215
        data = [random.random() for i in range(100)]
 
216
        target = sorted(data, reverse=True)
 
217
        self.assertEqual(hsort(data, LT), target)
 
218
        self.assertEqual(hsort(data, LE), target)
 
219
 
 
220
 
 
221
#==============================================================================
 
222
 
 
223
class LenOnly:
 
224
    "Dummy sequence class defining __len__ but not __getitem__."
 
225
    def __len__(self):
 
226
        return 10
 
227
 
 
228
class GetOnly:
 
229
    "Dummy sequence class defining __getitem__ but not __len__."
 
230
    def __getitem__(self, ndx):
 
231
        return 10
 
232
 
 
233
class CmpErr:
 
234
    "Dummy element that always raises an error during comparison"
 
235
    def __cmp__(self, other):
 
236
        raise ZeroDivisionError
 
237
 
 
238
def R(seqn):
 
239
    'Regular generator'
 
240
    for i in seqn:
 
241
        yield i
 
242
 
 
243
class G:
 
244
    'Sequence using __getitem__'
 
245
    def __init__(self, seqn):
 
246
        self.seqn = seqn
 
247
    def __getitem__(self, i):
 
248
        return self.seqn[i]
 
249
 
 
250
class I:
 
251
    'Sequence using iterator protocol'
 
252
    def __init__(self, seqn):
 
253
        self.seqn = seqn
 
254
        self.i = 0
 
255
    def __iter__(self):
 
256
        return self
 
257
    def next(self):
 
258
        if self.i >= len(self.seqn): raise StopIteration
 
259
        v = self.seqn[self.i]
 
260
        self.i += 1
 
261
        return v
 
262
 
 
263
class Ig:
 
264
    'Sequence using iterator protocol defined with a generator'
 
265
    def __init__(self, seqn):
 
266
        self.seqn = seqn
 
267
        self.i = 0
 
268
    def __iter__(self):
 
269
        for val in self.seqn:
 
270
            yield val
 
271
 
 
272
class X:
 
273
    'Missing __getitem__ and __iter__'
 
274
    def __init__(self, seqn):
 
275
        self.seqn = seqn
 
276
        self.i = 0
 
277
    def next(self):
 
278
        if self.i >= len(self.seqn): raise StopIteration
 
279
        v = self.seqn[self.i]
 
280
        self.i += 1
 
281
        return v
 
282
 
 
283
class N:
 
284
    'Iterator missing next()'
 
285
    def __init__(self, seqn):
 
286
        self.seqn = seqn
 
287
        self.i = 0
 
288
    def __iter__(self):
 
289
        return self
 
290
 
 
291
class E:
 
292
    'Test propagation of exceptions'
 
293
    def __init__(self, seqn):
 
294
        self.seqn = seqn
 
295
        self.i = 0
 
296
    def __iter__(self):
 
297
        return self
 
298
    def next(self):
 
299
        3 // 0
 
300
 
 
301
class S:
 
302
    'Test immediate stop'
 
303
    def __init__(self, seqn):
 
304
        pass
 
305
    def __iter__(self):
 
306
        return self
 
307
    def next(self):
 
308
        raise StopIteration
 
309
 
 
310
from itertools import chain, imap
 
311
def L(seqn):
 
312
    'Test multiple tiers of iterators'
 
313
    return chain(imap(lambda x:x, R(Ig(G(seqn)))))
 
314
 
 
315
class TestErrorHandling(unittest.TestCase):
 
316
    # only for C implementation
 
317
    module = c_heapq
 
318
 
 
319
    def test_non_sequence(self):
 
320
        for f in (self.module.heapify, self.module.heappop):
 
321
            self.assertRaises(TypeError, f, 10)
 
322
        for f in (self.module.heappush, self.module.heapreplace,
 
323
                  self.module.nlargest, self.module.nsmallest):
 
324
            self.assertRaises(TypeError, f, 10, 10)
 
325
 
 
326
    def test_len_only(self):
 
327
        for f in (self.module.heapify, self.module.heappop):
 
328
            self.assertRaises(TypeError, f, LenOnly())
 
329
        for f in (self.module.heappush, self.module.heapreplace):
 
330
            self.assertRaises(TypeError, f, LenOnly(), 10)
 
331
        for f in (self.module.nlargest, self.module.nsmallest):
 
332
            self.assertRaises(TypeError, f, 2, LenOnly())
 
333
 
 
334
    def test_get_only(self):
 
335
        for f in (self.module.heapify, self.module.heappop):
 
336
            self.assertRaises(TypeError, f, GetOnly())
 
337
        for f in (self.module.heappush, self.module.heapreplace):
 
338
            self.assertRaises(TypeError, f, GetOnly(), 10)
 
339
        for f in (self.module.nlargest, self.module.nsmallest):
 
340
            self.assertRaises(TypeError, f, 2, GetOnly())
 
341
 
 
342
    def test_get_only(self):
 
343
        seq = [CmpErr(), CmpErr(), CmpErr()]
 
344
        for f in (self.module.heapify, self.module.heappop):
 
345
            self.assertRaises(ZeroDivisionError, f, seq)
 
346
        for f in (self.module.heappush, self.module.heapreplace):
 
347
            self.assertRaises(ZeroDivisionError, f, seq, 10)
 
348
        for f in (self.module.nlargest, self.module.nsmallest):
 
349
            self.assertRaises(ZeroDivisionError, f, 2, seq)
 
350
 
 
351
    def test_arg_parsing(self):
 
352
        for f in (self.module.heapify, self.module.heappop,
 
353
                  self.module.heappush, self.module.heapreplace,
 
354
                  self.module.nlargest, self.module.nsmallest):
 
355
            self.assertRaises(TypeError, f, 10)
 
356
 
 
357
    def test_iterable_args(self):
 
358
        for f in (self.module.nlargest, self.module.nsmallest):
 
359
            for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
 
360
                for g in (G, I, Ig, L, R):
 
361
                    self.assertEqual(f(2, g(s)), f(2,s))
 
362
                self.assertEqual(f(2, S(s)), [])
 
363
                self.assertRaises(TypeError, f, 2, X(s))
 
364
                self.assertRaises(TypeError, f, 2, N(s))
 
365
                self.assertRaises(ZeroDivisionError, f, 2, E(s))
 
366
 
 
367
 
 
368
#==============================================================================
 
369
 
 
370
 
 
371
def test_main(verbose=None):
 
372
    from types import BuiltinFunctionType
 
373
 
 
374
    test_classes = [TestHeapPython, TestHeapC, TestErrorHandling]
 
375
    test_support.run_unittest(*test_classes)
 
376
 
 
377
    # verify reference counting
 
378
    if verbose and hasattr(sys, "gettotalrefcount"):
 
379
        import gc
 
380
        counts = [None] * 5
 
381
        for i in xrange(len(counts)):
 
382
            test_support.run_unittest(*test_classes)
 
383
            gc.collect()
 
384
            counts[i] = sys.gettotalrefcount()
 
385
        print counts
 
386
 
 
387
if __name__ == "__main__":
 
388
    test_main(verbose=True)