~ubuntu-branches/ubuntu/feisty/python-numpy/feisty

« back to all changes in this revision

Viewing changes to numpy/lib/tests/test_shape_base.py

  • Committer: Bazaar Package Importer
  • Author(s): Matthias Klose
  • Date: 2006-07-12 10:00:24 UTC
  • Revision ID: james.westby@ubuntu.com-20060712100024-5lw9q2yczlisqcrt
Tags: upstream-0.9.8
Import upstream version 0.9.8

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
 
 
2
from numpy.testing import *
 
3
set_package_path()
 
4
import numpy.lib;
 
5
from numpy.lib import *
 
6
from numpy.core import *
 
7
restore_path()
 
8
 
 
9
class test_apply_along_axis(ScipyTestCase):
    """Checks apply_along_axis by applying len along axis 0."""

    def check_simple(self):
        # len applied down axis 0 should report the row count once per column.
        data = ones((20, 10), 'd')
        actual = apply_along_axis(len, 0, data)
        n_columns = shape(data)[1]
        assert_array_equal(actual, len(data) * ones(n_columns))

    def check_simple101(self, level=11):
        # This test causes segmentation fault (Numeric 23.3,23.6,Python 2.3.4)
        # when enabled and the array has more than 100 columns. See Issue 202.
        data = ones((10, 101), 'd')
        actual = apply_along_axis(len, 0, data)
        n_columns = shape(data)[1]
        assert_array_equal(actual, len(data) * ones(n_columns))
 
18
 
 
19
class test_array_split(ScipyTestCase):
    """Tests for array_split with integer section counts and index lists."""

    def check_integer_0_split(self):
        # Asking for zero sections is expected to raise ValueError.
        a = arange(10)
        try:
            array_split(a, 0)
        except ValueError:
            pass
        else:
            assert(0) # it should have thrown a value error

    def check_integer_split(self):
        # Each entry pairs a section count with the pieces expected when
        # arange(10) is split into that many sections.
        a = arange(10)
        cases = [
            (1, [arange(10)]),
            (2, [arange(5), arange(5, 10)]),
            (3, [arange(4), arange(4, 7), arange(7, 10)]),
            (4, [arange(3), arange(3, 6), arange(6, 8), arange(8, 10)]),
            (5, [arange(2), arange(2, 4), arange(4, 6), arange(6, 8),
                 arange(8, 10)]),
            (6, [arange(2), arange(2, 4), arange(4, 6), arange(6, 8),
                 arange(8, 9), arange(9, 10)]),
            (7, [arange(2), arange(2, 4), arange(4, 6), arange(6, 7),
                 arange(7, 8), arange(8, 9), arange(9, 10)]),
            (8, [arange(2), arange(2, 4), arange(4, 5), arange(5, 6),
                 arange(6, 7), arange(7, 8), arange(8, 9), arange(9, 10)]),
            (9, [arange(2), arange(2, 3), arange(3, 4), arange(4, 5),
                 arange(5, 6), arange(6, 7), arange(7, 8), arange(8, 9),
                 arange(9, 10)]),
            (10, [arange(1), arange(1, 2), arange(2, 3), arange(3, 4),
                  arange(4, 5), arange(5, 6), arange(6, 7), arange(7, 8),
                  arange(8, 9), arange(9, 10)]),
            # One more section than elements: the last piece is empty.
            (11, [arange(1), arange(1, 2), arange(2, 3), arange(3, 4),
                  arange(4, 5), arange(5, 6), arange(6, 7), arange(7, 8),
                  arange(8, 9), arange(9, 10), array([])]),
        ]
        for n_sections, desired in cases:
            res = array_split(a, n_sections)
            compare_results(res, desired)

    def check_integer_split_2D_rows(self):
        # Two rows split three ways along axis 0: one row each, then empty.
        row = arange(10)
        grid = array([row, row])
        pieces = array_split(grid, 3, axis=0)
        expected = [array([row]), array([row]), array([])]
        compare_results(pieces, expected)

    def check_integer_split_2D_cols(self):
        # Ten columns split three ways along the last axis: widths 4, 3, 3.
        row = arange(10)
        grid = array([row, row])
        pieces = array_split(grid, 3, axis=-1)
        bounds = ((0, 4), (4, 7), (7, 10))
        expected = [array([arange(lo, hi), arange(lo, hi)])
                    for lo, hi in bounds]
        compare_results(pieces, expected)

    def check_integer_split_2D_default(self):
        """ This will fail if we change default axis
        """
        row = arange(10)
        grid = array([row, row])
        pieces = array_split(grid, 3)
        expected = [array([row]), array([row]), array([])]
        compare_results(pieces, expected)
    #perhaps should check higher dimensions

    def check_index_split_simple(self):
        # Split points 1, 5 and 7 cut arange(10) into four consecutive runs.
        a = arange(10)
        cut_points = [1, 5, 7]
        pieces = array_split(a, cut_points, axis=-1)
        expected = [arange(0, 1), arange(1, 5), arange(5, 7), arange(7, 10)]
        compare_results(pieces, expected)

    def check_index_split_low_bound(self):
        # A split point at index 0 yields an empty leading piece.
        a = arange(10)
        cut_points = [0, 5, 7]
        pieces = array_split(a, cut_points, axis=-1)
        expected = [array([]), arange(0, 5), arange(5, 7), arange(7, 10)]
        compare_results(pieces, expected)

    def check_index_split_high_bound(self):
        # Split points at or beyond the array length yield empty pieces.
        a = arange(10)
        cut_points = [0, 5, 7, 10, 12]
        pieces = array_split(a, cut_points, axis=-1)
        expected = [array([]), arange(0, 5), arange(5, 7), arange(7, 10),
                    array([]), array([])]
        compare_results(pieces, expected)
 
121
 
 
122
class test_split(ScipyTestCase):
    """* split is essentially the same as array_split, except that
         it tests whether splitting will result in an equal split.
         Only that case is tested here.
    *"""

    def check_equal_split(self):
        # Ten elements divide evenly into two halves.
        a = arange(10)
        halves = split(a, 2)
        expected = [arange(5), arange(5, 10)]
        compare_results(halves, expected)

    def check_unequal_split(self):
        # Ten elements cannot be divided evenly in three; expect ValueError.
        a = arange(10)
        try:
            split(a, 3)
        except ValueError:
            pass
        else:
            assert(0) # should raise an error
 
140
 
 
141
class test_atleast_1d(ScipyTestCase):
 
142
    def check_0D_array(self):
 
143
        a = array(1); b = array(2);
 
144
        res=map(atleast_1d,[a,b])
 
145
        desired = [array([1]),array([2])]
 
146
        assert_array_equal(res,desired)
 
147
    def check_1D_array(self):
 
148
        a = array([1,2]); b = array([2,3]);
 
149
        res=map(atleast_1d,[a,b])
 
150
        desired = [array([1,2]),array([2,3])]
 
151
        assert_array_equal(res,desired)
 
152
    def check_2D_array(self):
 
153
        a = array([[1,2],[1,2]]); b = array([[2,3],[2,3]]);
 
154
        res=map(atleast_1d,[a,b])
 
155
        desired = [a,b]
 
156
        assert_array_equal(res,desired)
 
157
    def check_3D_array(self):
 
158
        a = array([[1,2],[1,2]]); b = array([[2,3],[2,3]]);
 
159
        a = array([a,a]);b = array([b,b]);
 
160
        res=map(atleast_1d,[a,b])
 
161
        desired = [a,b]
 
162
        assert_array_equal(res,desired)
 
163
    def check_r1array(self):
 
164
        """ Test to make sure equivalent Travis O's r1array function
 
165
        """
 
166
        assert(atleast_1d(3).shape == (1,))
 
167
        assert(atleast_1d(3j).shape == (1,))
 
168
        assert(atleast_1d(3L).shape == (1,))
 
169
        assert(atleast_1d(3.0).shape == (1,))
 
170
        assert(atleast_1d([[2,3],[4,5]]).shape == (2,2))
 
171
 
 
172
class test_atleast_2d(ScipyTestCase):
    """Tests for atleast_2d on arrays of rank 0 through 3 and on scalars."""

    def check_0D_array(self):
        # Rank-0 inputs are expected back as 1x1 arrays.
        inputs = [array(1), array(2)]
        res = [atleast_2d(item) for item in inputs]
        desired = [array([[1]]), array([[2]])]
        assert_array_equal(res, desired)

    def check_1D_array(self):
        # 1-D inputs are expected back as a single row.
        inputs = [array([1, 2]), array([2, 3])]
        res = [atleast_2d(item) for item in inputs]
        desired = [array([[1, 2]]), array([[2, 3]])]
        assert_array_equal(res, desired)

    def check_2D_array(self):
        # 2-D inputs are expected back unchanged.
        a = array([[1, 2], [1, 2]])
        b = array([[2, 3], [2, 3]])
        res = [atleast_2d(item) for item in (a, b)]
        assert_array_equal(res, [a, b])

    def check_3D_array(self):
        # 3-D inputs, built by stacking a 2-D array on itself, are expected
        # back unchanged.
        plane_a = array([[1, 2], [1, 2]])
        plane_b = array([[2, 3], [2, 3]])
        a = array([plane_a, plane_a])
        b = array([plane_b, plane_b])
        res = [atleast_2d(item) for item in (a, b)]
        assert_array_equal(res, [a, b])

    def check_r2array(self):
        """ Test to make sure equivalent Travis O's r2array function
        """
        shape_of_scalar = atleast_2d(3).shape
        assert(shape_of_scalar == (1,1))
        shape_of_list = atleast_2d([3j,1]).shape
        assert(shape_of_list == (1,2))
        nested = [[[3,1],[4,5]],[[3,5],[1,2]]]
        assert(atleast_2d(nested).shape == (2,2,2))
 
200
 
 
201
class test_atleast_3d(ScipyTestCase):
    """Tests for atleast_3d on arrays of rank 0 through 3."""

    def check_0D_array(self):
        # Rank-0 inputs are expected back as 1x1x1 arrays.
        inputs = [array(1), array(2)]
        res = [atleast_3d(item) for item in inputs]
        desired = [array([[[1]]]), array([[[2]]])]
        assert_array_equal(res, desired)

    def check_1D_array(self):
        # A length-2 vector is expected back with shape (1, 2, 1).
        inputs = [array([1, 2]), array([2, 3])]
        res = [atleast_3d(item) for item in inputs]
        desired = [array([[[1], [2]]]), array([[[2], [3]]])]
        assert_array_equal(res, desired)

    def check_2D_array(self):
        # 2-D inputs are expected back with a new trailing axis.
        a = array([[1, 2], [1, 2]])
        b = array([[2, 3], [2, 3]])
        res = [atleast_3d(item) for item in (a, b)]
        desired = [a[:, :, newaxis], b[:, :, newaxis]]
        assert_array_equal(res, desired)

    def check_3D_array(self):
        # 3-D inputs, built by stacking a 2-D array on itself, are expected
        # back unchanged.
        plane_a = array([[1, 2], [1, 2]])
        plane_b = array([[2, 3], [2, 3]])
        a = array([plane_a, plane_a])
        b = array([plane_b, plane_b])
        res = [atleast_3d(item) for item in (a, b)]
        assert_array_equal(res, [a, b])
 
223
 
 
224
class test_hstack(ScipyTestCase):
    """Tests for hstack on pairs of 0-D, 1-D and 2-D arrays."""

    def check_0D_array(self):
        # Two rank-0 arrays are expected to join into a length-2 vector.
        first = array(1)
        second = array(2)
        stacked = hstack([first, second])
        assert_array_equal(stacked, array([1, 2]))

    def check_1D_array(self):
        # Two length-1 vectors are expected to join end to end.
        first = array([1])
        second = array([2])
        stacked = hstack([first, second])
        assert_array_equal(stacked, array([1, 2]))

    def check_2D_array(self):
        # Two 2x1 columns are expected to sit side by side as a 2x2 array.
        first = array([[1], [2]])
        second = array([[1], [2]])
        stacked = hstack([first, second])
        assert_array_equal(stacked, array([[1, 1], [2, 2]]))
 
240
 
 
241
class test_vstack(ScipyTestCase):
    """Tests for vstack on pairs of 0-D, 1-D and 2-D arrays."""

    def check_0D_array(self):
        # Two rank-0 arrays are expected to become a 2x1 column.
        first = array(1)
        second = array(2)
        stacked = vstack([first, second])
        assert_array_equal(stacked, array([[1], [2]]))

    def check_1D_array(self):
        # Two length-1 vectors are expected to become a 2x1 column.
        first = array([1])
        second = array([2])
        stacked = vstack([first, second])
        assert_array_equal(stacked, array([[1], [2]]))

    def check_2D_array(self):
        # Two 2x1 columns are expected to be placed one above the other.
        first = array([[1], [2]])
        second = array([[1], [2]])
        stacked = vstack([first, second])
        assert_array_equal(stacked, array([[1], [2], [1], [2]]))

    def check_2D_array2(self):
        # Two length-2 vectors are expected to become the rows of a 2x2 array.
        first = array([1, 2])
        second = array([1, 2])
        stacked = vstack([first, second])
        assert_array_equal(stacked, array([[1, 2], [1, 2]]))
 
262
 
 
263
class test_dstack(ScipyTestCase):
    """Tests for dstack on pairs of 0-D, 1-D and 2-D arrays."""

    def check_0D_array(self):
        # Two rank-0 arrays are expected to give a 1x1x2 result.
        first = array(1)
        second = array(2)
        stacked = dstack([first, second])
        assert_array_equal(stacked, array([[[1, 2]]]))

    def check_1D_array(self):
        # Two length-1 vectors are expected to give a 1x1x2 result.
        first = array([1])
        second = array([2])
        stacked = dstack([first, second])
        assert_array_equal(stacked, array([[[1, 2]]]))

    def check_2D_array(self):
        # Two 2x1 columns are expected to be paired along a new third axis.
        first = array([[1], [2]])
        second = array([[1], [2]])
        stacked = dstack([first, second])
        assert_array_equal(stacked, array([[[1, 1]], [[2, 2]]]))

    def check_2D_array2(self):
        # Two length-2 vectors are expected to give a 1x2x2 result.
        first = array([1, 2])
        second = array([1, 2])
        stacked = dstack([first, second])
        assert_array_equal(stacked, array([[[1, 1], [2, 2]]]))
 
284
 
 
285
""" array_split has more comprehensive tests of splitting.
 
286
    Only simple tests are done on hsplit, vsplit, and dsplit.
 
287
"""
 
288
class test_hsplit(ScipyTestCase):
    """ Integer-section splits only; index-list splits are not exercised.
    """

    def check_0D_array(self):
        # A rank-0 array is expected to be rejected with ValueError.
        scalar_arr = array(1)
        try:
            hsplit(scalar_arr, 2)
        except ValueError:
            pass
        else:
            assert(0)

    def check_1D_array(self):
        # A length-4 vector is expected to split into two halves.
        vec = array([1, 2, 3, 4])
        halves = hsplit(vec, 2)
        expected = [array([1, 2]), array([3, 4])]
        compare_results(halves, expected)

    def check_2D_array(self):
        # A 2x4 array is expected to split column-wise into two 2x2 blocks.
        row = [1, 2, 3, 4]
        grid = array([row, row])
        halves = hsplit(grid, 2)
        expected = [array([[1, 2], [1, 2]]), array([[3, 4], [3, 4]])]
        compare_results(halves, expected)
 
309
 
 
310
class test_vsplit(ScipyTestCase):
    """ Integer-section splits only; index-list splits are not exercised.
    """

    def check_1D_array(self):
        # A 1-D array is expected to be rejected with ValueError.
        vec = array([1, 2, 3, 4])
        try:
            vsplit(vec, 2)
        except ValueError:
            pass
        else:
            assert(0)

    def check_2D_array(self):
        # A 2x4 array is expected to split row-wise into two 1x4 blocks.
        row = [1, 2, 3, 4]
        grid = array([row, row])
        halves = vsplit(grid, 2)
        expected = [array([[1, 2, 3, 4]]), array([[1, 2, 3, 4]])]
        compare_results(halves, expected)
 
326
 
 
327
class test_dsplit(ScipyTestCase):
    """ Integer-section splits only; index-list splits are not exercised.
    """

    def check_2D_array(self):
        # A 2-D array is expected to be rejected with ValueError.
        row = [1, 2, 3, 4]
        grid = array([row, row])
        try:
            dsplit(grid, 2)
        except ValueError:
            pass
        else:
            assert(0)

    def check_3D_array(self):
        # A 2x2x4 array is expected to split along the last axis into two
        # 2x2x2 blocks.
        row = [1, 2, 3, 4]
        cube = array([[row, row],
                      [row, row]])
        halves = dsplit(cube, 2)
        front = array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]])
        back = array([[[3, 4], [3, 4]], [[3, 4], [3, 4]]])
        compare_results(halves, [front, back])
 
347
 
 
348
class test_squeeze(ScipyTestCase):
    """Tests that squeeze drops every length-1 axis."""

    def check_basic(self):
        # Each pair gives the shape of a random input and the shape expected
        # once its length-1 axes are removed.
        shape_pairs = [
            ((20, 10, 10, 1, 1), (20, 10, 10)),
            ((20, 1, 10, 1, 20), (20, 10, 20)),
            ((1, 1, 20, 10), (20, 10)),
        ]
        samples = [rand(*full_shape) for full_shape, _ in shape_pairs]
        for sample, (_, squeezed_shape) in zip(samples, shape_pairs):
            assert_array_equal(squeeze(sample),
                               reshape(sample, squeezed_shape))
 
356
 
 
357
class test_kron(ScipyTestCase):
    """Tests for kron: result types and rejection of non-2-D operands."""

    def check_return_type(self):
        # Expected result type for each pairing of ndarray and matrix.
        a = ones([2,2])
        m = asmatrix(a)
        assert_equal(type(kron(a,a)), ndarray)
        assert_equal(type(kron(m,m)), matrix)
        assert_equal(type(kron(a,m)), matrix)
        assert_equal(type(kron(m,a)), matrix)

        # An ndarray subclass whose __array_priority__ is 0.0, sharing a's
        # data buffer.
        class myarray(ndarray):
            __array_priority__ = 0.0
        ma = myarray(a.shape, a.dtype, a.data)
        assert_equal(type(kron(a,a)), ndarray)
        assert_equal(type(kron(ma,ma)), myarray)
        assert_equal(type(kron(a,ma)), ndarray)
        assert_equal(type(kron(ma,a)), myarray)

    def check_rank_checking(self):
        # Every operand pairing except (2-D, 2-D) is expected to raise
        # ValueError.
        one = ones([2])
        two = ones([2,2])
        three = ones([2,2,2])
        for a in [one, two, three]:
            for b in [one, two, three]:
                if a is b is two:
                    continue
                try:
                    kron(a, b)
                except ValueError:
                    continue
                except Exception:
                    # Any other error is still a failure, reported below.
                    # Catching Exception rather than using a bare except lets
                    # KeyboardInterrupt and SystemExit propagate instead of
                    # being reported as a test failure.
                    pass
                assert False, "ValueError expected"
 
387
 
 
388
 
 
389
# Utility
 
390
 
 
391
def compare_results(res,desired):
    """Assert that ``res`` and ``desired`` hold pairwise-equal arrays.

    res     -- sequence of arrays produced by the code under test
    desired -- sequence of expected arrays

    Raises AssertionError if the sequences differ in length or if any
    corresponding pair fails assert_array_equal.
    """
    # Walking only the first len(desired) items would silently accept
    # surplus pieces in res, so the piece count is checked up front.
    if len(res) != len(desired):
        raise AssertionError(
            "expected %d sub-arrays, got %d" % (len(desired), len(res)))
    for actual, expected in zip(res, desired):
        assert_array_equal(actual, expected)
 
394
 
 
395
 
 
396
if __name__ == "__main__":
    # When executed as a script, hand control to ScipyTest's run().
    runner = ScipyTest()
    runner.run()