2
from numpy.testing import *
5
from numpy.lib import *
6
from numpy.core import *
9
class test_apply_along_axis(ScipyTestCase):
    """Tests for apply_along_axis."""

    def check_simple(self):
        # len() of each column of a 2-D array equals the number of rows.
        a = ones((20, 10), 'd')
        assert_array_equal(apply_along_axis(len, 0, a),
                           len(a)*ones(shape(a)[1]))

    def check_simple101(self, level=11):
        # This test causes segmentation fault (Numeric 23.3,23.6,Python 2.3.4)
        # when enabled and shape(a)[1]>100. See Issue 202.
        # NOTE(review): level=11 presumably gates this test behind a high
        # test level in the ScipyTestCase runner -- confirm.
        a = ones((10, 101), 'd')
        assert_array_equal(apply_along_axis(len, 0, a),
                           len(a)*ones(shape(a)[1]))
class test_array_split(ScipyTestCase):
    """Tests for array_split, including uneven and out-of-range splits."""

    def check_integer_0_split(self):
        # Asking for zero sections is invalid and must raise ValueError.
        a = arange(10)
        try:
            res = array_split(a, 0)
            assert(0)  # it should have thrown a value error
        except ValueError:
            pass

    def check_integer_split(self):
        # Splitting a length-10 array into N sections gives the first 10 % N
        # pieces one extra element; for N > 10 the tail pieces are empty.
        a = arange(10)
        res = array_split(a, 1)
        desired = [arange(10)]
        compare_results(res, desired)

        res = array_split(a, 2)
        desired = [arange(5), arange(5, 10)]
        compare_results(res, desired)

        res = array_split(a, 3)
        desired = [arange(4), arange(4, 7), arange(7, 10)]
        compare_results(res, desired)

        res = array_split(a, 4)
        desired = [arange(3), arange(3, 6), arange(6, 8), arange(8, 10)]
        compare_results(res, desired)

        res = array_split(a, 5)
        desired = [arange(2), arange(2, 4), arange(4, 6), arange(6, 8),
                   arange(8, 10)]
        compare_results(res, desired)

        res = array_split(a, 6)
        desired = [arange(2), arange(2, 4), arange(4, 6), arange(6, 8),
                   arange(8, 9), arange(9, 10)]
        compare_results(res, desired)

        res = array_split(a, 7)
        desired = [arange(2), arange(2, 4), arange(4, 6), arange(6, 7),
                   arange(7, 8), arange(8, 9), arange(9, 10)]
        compare_results(res, desired)

        res = array_split(a, 8)
        desired = [arange(2), arange(2, 4), arange(4, 5), arange(5, 6),
                   arange(6, 7), arange(7, 8), arange(8, 9), arange(9, 10)]
        compare_results(res, desired)

        res = array_split(a, 9)
        desired = [arange(2), arange(2, 3), arange(3, 4), arange(4, 5),
                   arange(5, 6), arange(6, 7), arange(7, 8), arange(8, 9),
                   arange(9, 10)]
        compare_results(res, desired)

        res = array_split(a, 10)
        desired = [arange(1), arange(1, 2), arange(2, 3), arange(3, 4),
                   arange(4, 5), arange(5, 6), arange(6, 7), arange(7, 8),
                   arange(8, 9), arange(9, 10)]
        compare_results(res, desired)

        res = array_split(a, 11)
        desired = [arange(1), arange(1, 2), arange(2, 3), arange(3, 4),
                   arange(4, 5), arange(5, 6), arange(6, 7), arange(7, 8),
                   arange(8, 9), arange(9, 10), array([])]
        compare_results(res, desired)

    def check_integer_split_2D_rows(self):
        # Two rows split into three sections leave an empty third piece that
        # still has 10 columns, i.e. shape (0, 10).
        a = array([arange(10), arange(10)])
        res = array_split(a, 3, axis=0)
        desired = [array([arange(10)]), array([arange(10)]),
                   zeros((0, 10))]
        compare_results(res, desired)

    def check_integer_split_2D_cols(self):
        a = array([arange(10), arange(10)])
        res = array_split(a, 3, axis=-1)
        desired = [array([arange(4), arange(4)]),
                   array([arange(4, 7), arange(4, 7)]),
                   array([arange(7, 10), arange(7, 10)])]
        compare_results(res, desired)

    def check_integer_split_2D_default(self):
        """ This will fail if we change default axis
        """
        a = array([arange(10), arange(10)])
        res = array_split(a, 3)
        desired = [array([arange(10)]), array([arange(10)]),
                   zeros((0, 10))]
        compare_results(res, desired)
    #perhaps should check higher dimensions

    def check_index_split_simple(self):
        a = arange(10)
        indices = [1, 5, 7]
        res = array_split(a, indices, axis=-1)
        desired = [arange(0, 1), arange(1, 5), arange(5, 7), arange(7, 10)]
        compare_results(res, desired)

    def check_index_split_low_bound(self):
        # A leading index of 0 yields an empty first piece.
        a = arange(10)
        indices = [0, 5, 7]
        res = array_split(a, indices, axis=-1)
        desired = [array([]), arange(0, 5), arange(5, 7), arange(7, 10)]
        compare_results(res, desired)

    def check_index_split_high_bound(self):
        # Indices at or past len(a) yield empty trailing pieces.
        a = arange(10)
        indices = [0, 5, 7, 10, 12]
        res = array_split(a, indices, axis=-1)
        desired = [array([]), arange(0, 5), arange(5, 7), arange(7, 10),
                   array([]), array([])]
        compare_results(res, desired)
class test_split(ScipyTestCase):
    """* This function is essentially the same as array_split,
         except that it tests whether splitting results in an
         equal split.  Only test for this case.
    """

    def check_equal_split(self):
        a = arange(10)
        res = split(a, 2)
        desired = [arange(5), arange(5, 10)]
        compare_results(res, desired)

    def check_unequal_split(self):
        # 10 elements cannot be divided into 3 equal sections.
        a = arange(10)
        try:
            res = split(a, 3)
            assert(0)  # should raise an error
        except ValueError:
            pass
class test_atleast_1d(ScipyTestCase):
142
def check_0D_array(self):
143
a = array(1); b = array(2);
144
res=map(atleast_1d,[a,b])
145
desired = [array([1]),array([2])]
146
assert_array_equal(res,desired)
147
def check_1D_array(self):
148
a = array([1,2]); b = array([2,3]);
149
res=map(atleast_1d,[a,b])
150
desired = [array([1,2]),array([2,3])]
151
assert_array_equal(res,desired)
152
def check_2D_array(self):
153
a = array([[1,2],[1,2]]); b = array([[2,3],[2,3]]);
154
res=map(atleast_1d,[a,b])
156
assert_array_equal(res,desired)
157
def check_3D_array(self):
158
a = array([[1,2],[1,2]]); b = array([[2,3],[2,3]]);
159
a = array([a,a]);b = array([b,b]);
160
res=map(atleast_1d,[a,b])
162
assert_array_equal(res,desired)
163
def check_r1array(self):
164
""" Test to make sure equivalent Travis O's r1array function
166
assert(atleast_1d(3).shape == (1,))
167
assert(atleast_1d(3j).shape == (1,))
168
assert(atleast_1d(3L).shape == (1,))
169
assert(atleast_1d(3.0).shape == (1,))
170
assert(atleast_1d([[2,3],[4,5]]).shape == (2,2))
172
class test_atleast_2d(ScipyTestCase):
    """Tests for atleast_2d: results always have at least two dimensions."""

    def check_0D_array(self):
        a = array(1)
        b = array(2)
        res = [atleast_2d(x) for x in [a, b]]
        desired = [array([[1]]), array([[2]])]
        assert_array_equal(res, desired)

    def check_1D_array(self):
        # A 1-D input becomes a single row.
        a = array([1, 2])
        b = array([2, 3])
        res = [atleast_2d(x) for x in [a, b]]
        desired = [array([[1, 2]]), array([[2, 3]])]
        assert_array_equal(res, desired)

    def check_2D_array(self):
        # Inputs that already have >= 2 dimensions come back unchanged.
        a = array([[1, 2], [1, 2]])
        b = array([[2, 3], [2, 3]])
        res = [atleast_2d(x) for x in [a, b]]
        desired = [a, b]
        assert_array_equal(res, desired)

    def check_3D_array(self):
        a = array([[1, 2], [1, 2]])
        b = array([[2, 3], [2, 3]])
        a = array([a, a])
        b = array([b, b])
        res = [atleast_2d(x) for x in [a, b]]
        desired = [a, b]
        assert_array_equal(res, desired)

    def check_r2array(self):
        """ Check that atleast_2d is equivalent to Travis O's r2array function
        """
        assert(atleast_2d(3).shape == (1, 1))
        assert(atleast_2d([3j, 1]).shape == (1, 2))
        assert(atleast_2d([[[3, 1], [4, 5]], [[3, 5], [1, 2]]]).shape == (2, 2, 2))
class test_atleast_3d(ScipyTestCase):
    """Tests for atleast_3d: results always have at least three dimensions."""

    def check_0D_array(self):
        a = array(1)
        b = array(2)
        res = [atleast_3d(x) for x in [a, b]]
        desired = [array([[[1]]]), array([[[2]]])]
        assert_array_equal(res, desired)

    def check_1D_array(self):
        a = array([1, 2])
        b = array([2, 3])
        res = [atleast_3d(x) for x in [a, b]]
        desired = [array([[[1], [2]]]), array([[[2], [3]]])]
        assert_array_equal(res, desired)

    def check_2D_array(self):
        # A 2-D input gains a trailing length-1 axis.
        a = array([[1, 2], [1, 2]])
        b = array([[2, 3], [2, 3]])
        res = [atleast_3d(x) for x in [a, b]]
        desired = [a[:, :, newaxis], b[:, :, newaxis]]
        assert_array_equal(res, desired)

    def check_3D_array(self):
        # Inputs that already have >= 3 dimensions come back unchanged.
        a = array([[1, 2], [1, 2]])
        b = array([[2, 3], [2, 3]])
        a = array([a, a])
        b = array([b, b])
        res = [atleast_3d(x) for x in [a, b]]
        desired = [a, b]
        assert_array_equal(res, desired)
class test_hstack(ScipyTestCase):
    """Tests for hstack on 0-, 1- and 2-D inputs."""

    def check_0D_array(self):
        a = array(1)
        b = array(2)
        res = hstack([a, b])
        desired = array([1, 2])
        assert_array_equal(res, desired)

    def check_1D_array(self):
        a = array([1])
        b = array([2])
        res = hstack([a, b])
        desired = array([1, 2])
        assert_array_equal(res, desired)

    def check_2D_array(self):
        # 2-D inputs are joined column-wise.
        a = array([[1], [2]])
        b = array([[1], [2]])
        res = hstack([a, b])
        desired = array([[1, 1], [2, 2]])
        assert_array_equal(res, desired)
class test_vstack(ScipyTestCase):
    """Tests for vstack on 0-, 1- and 2-D inputs."""

    def check_0D_array(self):
        a = array(1)
        b = array(2)
        res = vstack([a, b])
        desired = array([[1], [2]])
        assert_array_equal(res, desired)

    def check_1D_array(self):
        a = array([1])
        b = array([2])
        res = vstack([a, b])
        desired = array([[1], [2]])
        assert_array_equal(res, desired)

    def check_2D_array(self):
        # 2-D inputs are joined row-wise.
        a = array([[1], [2]])
        b = array([[1], [2]])
        res = vstack([a, b])
        desired = array([[1], [2], [1], [2]])
        assert_array_equal(res, desired)

    def check_2D_array2(self):
        # 1-D inputs each become one row of the result.
        a = array([1, 2])
        b = array([1, 2])
        res = vstack([a, b])
        desired = array([[1, 2], [1, 2]])
        assert_array_equal(res, desired)
class test_dstack(ScipyTestCase):
    """Tests for dstack on 0-, 1- and 2-D inputs."""

    def check_0D_array(self):
        a = array(1)
        b = array(2)
        res = dstack([a, b])
        desired = array([[[1, 2]]])
        assert_array_equal(res, desired)

    def check_1D_array(self):
        a = array([1])
        b = array([2])
        res = dstack([a, b])
        desired = array([[[1, 2]]])
        assert_array_equal(res, desired)

    def check_2D_array(self):
        # Inputs are joined along a third axis.
        a = array([[1], [2]])
        b = array([[1], [2]])
        res = dstack([a, b])
        desired = array([[[1, 1]], [[2, 2]]])
        assert_array_equal(res, desired)

    def check_2D_array2(self):
        a = array([1, 2])
        b = array([1, 2])
        res = dstack([a, b])
        desired = array([[[1, 1], [2, 2]]])
        assert_array_equal(res, desired)
# array_split has more comprehensive tests of splitting;
# only do simple tests on hsplit, vsplit, and dsplit.


class test_hsplit(ScipyTestCase):
    """ only testing for integer splits.
    """

    def check_0D_array(self):
        # hsplit needs at least one dimension, so a 0-d input must raise
        # ValueError.
        a = array(1)
        try:
            hsplit(a, 2)
            assert(0)
        except ValueError:
            pass

    def check_1D_array(self):
        a = array([1, 2, 3, 4])
        res = hsplit(a, 2)
        desired = [array([1, 2]), array([3, 4])]
        compare_results(res, desired)

    def check_2D_array(self):
        # For 2-D input hsplit divides the columns.
        a = array([[1, 2, 3, 4],
                   [1, 2, 3, 4]])
        res = hsplit(a, 2)
        desired = [array([[1, 2], [1, 2]]), array([[3, 4], [3, 4]])]
        compare_results(res, desired)
class test_vsplit(ScipyTestCase):
    """ only testing for integer splits.
    """

    def check_1D_array(self):
        # vsplit needs at least two dimensions, so a 1-D input must raise
        # ValueError.
        a = array([1, 2, 3, 4])
        try:
            vsplit(a, 2)
            assert(0)
        except ValueError:
            pass

    def check_2D_array(self):
        # vsplit divides the rows.
        a = array([[1, 2, 3, 4],
                   [1, 2, 3, 4]])
        res = vsplit(a, 2)
        desired = [array([[1, 2, 3, 4]]), array([[1, 2, 3, 4]])]
        compare_results(res, desired)
class test_dsplit(ScipyTestCase):
    """ only testing for integer splits.
    """

    def check_2D_array(self):
        # dsplit needs at least three dimensions, so a 2-D input must raise
        # ValueError.
        a = array([[1, 2, 3, 4],
                   [1, 2, 3, 4]])
        try:
            dsplit(a, 2)
            assert(0)
        except ValueError:
            pass

    def check_3D_array(self):
        # dsplit divides the third axis.
        a = array([[[1, 2, 3, 4],
                    [1, 2, 3, 4]],
                   [[1, 2, 3, 4],
                    [1, 2, 3, 4]]])
        res = dsplit(a, 2)
        desired = [array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]]),
                   array([[[3, 4], [3, 4]], [[3, 4], [3, 4]]])]
        compare_results(res, desired)
class test_squeeze(ScipyTestCase):
    """Tests for squeeze: every length-1 axis is removed."""

    def check_basic(self):
        a = rand(20, 10, 10, 1, 1)
        b = rand(20, 1, 10, 1, 20)
        c = rand(1, 1, 20, 10)

        # Length-1 axes may be trailing, interior, or leading.
        assert_array_equal(squeeze(a), reshape(a, (20, 10, 10)))
        assert_array_equal(squeeze(b), reshape(b, (20, 10, 20)))
        assert_array_equal(squeeze(c), reshape(c, (20, 10)))
class test_kron(ScipyTestCase):
    """Tests for kron: result-type propagation and operand rank checking."""

    def check_return_type(self):
        # A matrix operand on either side yields a matrix result.
        a = ones([2, 2])
        m = matrix(a)
        assert_equal(type(kron(a, a)), ndarray)
        assert_equal(type(kron(m, m)), matrix)
        assert_equal(type(kron(a, m)), matrix)
        assert_equal(type(kron(m, a)), matrix)

        # A subclass with __array_priority__ = 0.0 determines the result type
        # only when it is the first operand (see the expected types below).
        class myarray(ndarray):
            __array_priority__ = 0.0
        ma = myarray(a.shape, a.dtype, a.data)
        assert_equal(type(kron(a, a)), ndarray)
        assert_equal(type(kron(ma, ma)), myarray)
        assert_equal(type(kron(a, ma)), ndarray)
        assert_equal(type(kron(ma, a)), myarray)

    def check_rank_checking(self):
        # NOTE(review): assumes kron accepts only pairs of 2-D operands and
        # raises ValueError for every other rank combination -- confirm
        # against kron's implementation.
        one = ones([2])
        two = ones([2, 2])
        three = ones([2, 2, 2])
        for a in [one, two, three]:
            for b in [one, two, three]:
                if a is two and b is two:
                    continue  # the only valid rank combination
                try:
                    kron(a, b)
                except ValueError:
                    continue
                assert False, "ValueError expected"
def compare_results(res, desired):
    """Assert that two sequences of arrays match piece by piece.

    The number of pieces must agree as well as each piece's contents, so an
    unexpected extra (or missing) piece is reported instead of being silently
    ignored.
    """
    assert_equal(len(res), len(desired))
    for actual, expected in zip(res, desired):
        assert_array_equal(actual, expected)
if __name__ == "__main__":
    # NOTE(review): assumes ScipyTest is the runner numpy.testing provides
    # alongside ScipyTestCase -- confirm for the targeted numpy version.
    ScipyTest().run()