~ubuntu-branches/ubuntu/feisty/python-numpy/feisty

« back to all changes in this revision

Viewing changes to numpy/lib/tests/test_function_base.py

  • Committer: Bazaar Package Importer
  • Author(s): Matthias Klose
  • Date: 2006-07-12 10:00:24 UTC
  • Revision ID: james.westby@ubuntu.com-20060712100024-5lw9q2yczlisqcrt
Tags: upstream-0.9.8
Import upstream version 0.9.8

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
 
 
2
import sys
 
3
 
 
4
from numpy.testing import *
 
5
set_package_path()
 
6
import numpy.lib;reload(numpy.lib)
 
7
from numpy.lib import *
 
8
from numpy.core import *
 
9
del sys.path[0]
 
10
 
 
11
class test_any(ScipyTestCase):
    """Checks for ``any`` and its axis-aware alias ``sometrue``."""

    def check_basic(self):
        # a single nonzero entry is enough; an all-zero list must be false
        single_hit = [0, 0, 1, 0]
        none_set = [0, 0, 0, 0]
        two_hits = [1, 0, 1, 0]
        assert any(single_hit)
        assert any(two_hits)
        assert not any(none_set)

    def check_nd(self):
        grid = [[0, 0, 0],
                [0, 1, 0],
                [1, 1, 0]]
        assert any(grid)
        # with no axis given the expected result is the per-column reduction
        assert_array_equal(sometrue(grid), [1, 1, 0])
        # axis=1 reduces each row
        assert_array_equal(sometrue(grid, axis=1), [0, 1, 1])
 
25
 
 
26
class test_all(ScipyTestCase):
    """Checks for ``all`` and its axis-aware alias ``alltrue``."""

    def check_basic(self):
        mixed = [0, 1, 1, 0]
        none_set = [0, 0, 0, 0]
        all_set = [1, 1, 1, 1]
        assert not all(mixed)
        assert all(all_set)
        assert not all(none_set)
        # bitwise inversion of an integer array of zeros yields -1 in every
        # slot, which is nonzero, so all() must succeed on it
        assert all(~array(none_set))

    def check_nd(self):
        grid = [[0, 0, 1],
                [0, 1, 1],
                [1, 1, 1]]
        assert not all(grid)
        assert_array_equal(alltrue(grid), [0, 0, 1])
        assert_array_equal(alltrue(grid, axis=1), [0, 0, 1])
 
41
 
 
42
class test_average(ScipyTestCase):
    """Checks that unweighted ``average`` agrees with the arithmetic mean."""

    def check_basic(self):
        # integer, float and plain-list inputs whose mean is exactly 2 or 0
        y1 = array([1,2,3])
        assert(average(y1) == 2.)
        y2 = array([1.,2.,3.])
        assert(average(y2) == 2.)
        y3 = [0.,0.,0.]
        assert(average(y3) == 0.)

        # small matrix of integer values: both axes give exact results
        y4 = ones((4,4))
        y4[0,1] = 0
        y4[1,0] = 2
        assert_array_equal(y4.mean(0), average(y4, 0))
        assert_array_equal(y4.mean(1), average(y4, 1))

        # random floats: the mean() method and the average() function are
        # separate routines that need not round identically, so compare to
        # within floating-point tolerance rather than bit-for-bit
        y5 = rand(5,5)
        assert_array_almost_equal(y5.mean(0), average(y5, 0))
        assert_array_almost_equal(y5.mean(1), average(y5, 1))
 
60
 
 
61
class test_logspace(ScipyTestCase):
    """Checks for ``logspace``: default length, endpoint handling, decades."""

    def check_basic(self):
        assert len(logspace(0, 6)) == 50

        dense = logspace(0, 6, num=100)
        assert dense[-1] == 10**6

        # endpoint=0 excludes the upper bound
        open_end = logspace(0, 6, endpoint=0)
        assert open_end[-1] < 10**6

        decades = logspace(0, 6, num=7)
        assert_array_equal(decades, [1, 10, 100, 1e3, 1e4, 1e5, 1e6])
 
71
 
 
72
class test_linspace(ScipyTestCase):
    """Checks for ``linspace``: sizes, endpoints, step and dtype stability."""

    def check_basic(self):
        assert len(linspace(0, 10)) == 50
        assert linspace(2, 10, num=100)[-1] == 10
        assert linspace(2, 10, endpoint=0)[-1] < 10
        # retstep=1 returns the samples together with their spacing
        samples, step = linspace(2, 10, retstep=1)
        assert_almost_equal(step, 8 / 49.0)
        assert_array_almost_equal(samples, mgrid[2:10:50j], 13)

    def check_corner(self):
        single = list(linspace(0, 1, 1))
        assert single == [0.0], single
        # a fractional sample count is accepted; here it yields two samples
        fractional = list(linspace(0, 1, 2.5))
        assert fractional == [0.0, 1.0]

    def check_type(self):
        # the result dtype must not depend on how many samples are requested
        dtypes = [linspace(0, 1, count).dtype for count in (0, 1, 2)]
        assert_equal(dtypes[0], dtypes[1])
        assert_equal(dtypes[1], dtypes[2])
 
96
 
 
97
class test_amax(ScipyTestCase):
    """Checks ``amax`` on a flat list and along both axes of a 3x3 table."""

    def check_basic(self):
        flat = [3, 4, 5, 10, -3, -5, 6.0]
        assert_equal(amax(flat), 10.0)
        table = [[3, 6.0, 9.0],
                 [4, 10.0, 5.0],
                 [8, 3.0, 2.0]]
        # axis=0 gives the column maxima, axis=1 the row maxima
        assert_equal(amax(table, axis=0), [8.0, 10.0, 9.0])
        assert_equal(amax(table, axis=1), [9.0, 10.0, 8.0])
 
106
 
 
107
class test_amin(ScipyTestCase):
    """Checks ``amin`` on a flat list and along both axes of a 3x3 table."""

    def check_basic(self):
        flat = [3, 4, 5, 10, -3, -5, 6.0]
        assert_equal(amin(flat), -5.0)
        table = [[3, 6.0, 9.0],
                 [4, 10.0, 5.0],
                 [8, 3.0, 2.0]]
        # axis=0 gives the column minima, axis=1 the row minima
        assert_equal(amin(table, axis=0), [3.0, 3.0, 2.0])
        assert_equal(amin(table, axis=1), [3.0, 4.0, 2.0])
 
116
 
 
117
class test_ptp(ScipyTestCase):
    """Checks ``ptp`` (max minus min) on flat and 2-D inputs."""

    def check_basic(self):
        flat = [3, 4, 5, 10, -3, -5, 6.0]
        assert_equal(ptp(flat), 15.0)
        table = [[3, 6.0, 9.0],
                 [4, 10.0, 5.0],
                 [8, 3.0, 2.0]]
        # per-column ranges, then per-row ranges via the last axis
        assert_equal(ptp(table, axis=0), [5.0, 7.0, 7.0])
        assert_equal(ptp(table, axis=-1), [6.0, 6.0, 6.0])
 
126
 
 
127
class test_cumsum(ScipyTestCase):
    """Checks ``cumsum`` (flat, axis 0, axis 1) across many numeric dtypes."""

    def check_basic(self):
        flat_data = [1, 2, 10, 11, 6, 5, 4]
        grid_data = [[1, 2, 3, 4], [5, 6, 7, 9], [10, 3, 4, 5]]
        expected_flat = [1, 3, 13, 24, 30, 35, 39]
        expected_down = [[1, 2, 3, 4],
                         [6, 8, 10, 13],
                         [16, 11, 14, 18]]
        expected_across = [[1, 3, 6, 10],
                           [5, 11, 18, 27],
                           [10, 13, 17, 22]]
        dtypes = (int8, uint8, int16, uint16, int32, uint32,
                  float32, float64, complex64, complex128)
        for dt in dtypes:
            flat = array(flat_data, dt)
            grid = array(grid_data, dt)
            assert_array_equal(cumsum(flat), array(expected_flat, dt))
            assert_array_equal(cumsum(grid, axis=0),
                               array(expected_down, dt))
            assert_array_equal(cumsum(grid, axis=1),
                               array(expected_across, dt))
 
142
 
 
143
class test_prod(ScipyTestCase):
    """Checks ``prod`` over a whole array and along each axis of a 2-D one."""

    def check_basic(self):
        ba = [1,2,10,11,6,5,4]
        ba2 = [[1,2,3,4],[5,6,7,9],[10,3,4,5]]
        # ctype iterates over numpy scalar *types*, so every entry reaches the
        # value checks below.  The largest product (26400) does not fit in
        # 8 bits -- presumably why int8/uint8 are not in this list.
        for ctype in [int16,uint16,int32,uint32,
                      float32,float64,complex64,complex128]:
            a = array(ba,ctype)
            a2 = array(ba2,ctype)
            assert_equal(prod(a),26400)
            assert_array_equal(prod(a2,axis=0),
                               array([50,36,84,180],ctype))
            assert_array_equal(prod(a2,axis=-1),array([24, 1890, 600],ctype))
 
160
 
 
161
class test_cumprod(ScipyTestCase):
    """Checks ``cumprod`` along the last axis and along axis 0."""

    def check_basic(self):
        ba = [1,2,10,11,6,5,4]
        ba2 = [[1,2,3,4],[5,6,7,9],[10,3,4,5]]
        # ctype iterates over numpy scalar *types*, so every entry reaches the
        # value checks below.  The largest partial product (26400) does not
        # fit in 8 bits -- presumably why int8/uint8 are not in this list.
        for ctype in [int16,uint16,int32,uint32,
                      float32,float64,complex64,complex128]:
            a = array(ba,ctype)
            a2 = array(ba2,ctype)
            assert_array_equal(cumprod(a,axis=-1),
                               array([1, 2, 20, 220,
                                      1320, 6600, 26400],ctype))
            assert_array_equal(cumprod(a2,axis=0),
                               array([[ 1,  2,  3,   4],
                                      [ 5, 12, 21,  36],
                                      [50, 36, 84, 180]],ctype))
            assert_array_equal(cumprod(a2,axis=-1),
                               array([[ 1,  2,   6,   24],
                                      [ 5, 30, 210, 1890],
                                      [10, 30, 120,  600]],ctype))
 
185
 
 
186
class test_diff(ScipyTestCase):
    """Checks ``diff`` for higher orders and for an explicit axis."""

    def check_basic(self):
        seq = [1, 4, 6, 7, 12]
        expected = {1: array([3, 2, 1, 5]),
                    2: array([-1, -1, 4]),
                    3: array([0, 5])}
        assert_array_equal(diff(seq), expected[1])
        for order in (2, 3):
            assert_array_equal(diff(seq, n=order), expected[order])

    def check_nd(self):
        x = 20*rand(10, 20, 30)
        # reference differences built by explicit slicing: first along the
        # last axis (what diff(x) is compared against), then along axis 0
        last1 = x[:, :, 1:] - x[:, :, :-1]
        last2 = last1[:, :, 1:] - last1[:, :, :-1]
        first1 = x[1:] - x[:-1]
        first2 = first1[1:] - first1[:-1]
        assert_array_equal(diff(x), last1)
        assert_array_equal(diff(x, n=2), last2)
        assert_array_equal(diff(x, axis=0), first1)
        assert_array_equal(diff(x, n=2, axis=0), first2)
 
206
 
 
207
class test_angle(ScipyTestCase):
    """Checks ``angle`` in radians and degrees for points in every quadrant."""

    def check_basic(self):
        points = [1+3j, sqrt(2)/2.0+1j*sqrt(2)/2, 1, 1j, -1, -1j, 1-3j, -1+3j]
        steep = arctan(3.0/1.0)
        expected_rad = [steep, arctan(1.0), 0, pi/2, pi, -pi/2.0,
                        -steep, pi-steep]
        expected_deg = array(expected_rad)*180/pi
        radians = angle(points)
        degrees = angle(points, deg=1)
        assert_array_almost_equal(radians, expected_rad, 11)
        assert_array_almost_equal(degrees, expected_deg, 11)
 
217
 
 
218
class test_trim_zeros(ScipyTestCase):
    """Checks ``trim_zeros`` on integer arrays only.

    Leading and trailing zero runs are removed; zeros between nonzero
    entries must survive.
    """

    def check_basic(self):
        trimmed = trim_zeros(array([0, 0, 1, 2, 3, 4, 0]))
        assert_array_equal(trimmed, array([1, 2, 3, 4]))

    def check_leading_skip(self):
        trimmed = trim_zeros(array([0, 0, 1, 0, 2, 3, 4, 0]))
        assert_array_equal(trimmed, array([1, 0, 2, 3, 4]))

    def check_trailing_skip(self):
        trimmed = trim_zeros(array([0, 0, 1, 0, 2, 3, 0, 4, 0]))
        assert_array_equal(trimmed, array([1, 0, 2, 3, 0, 4]))
 
233
 
 
234
 
 
235
class test_extins(ScipyTestCase):
    """Checks ``extract`` and ``insert``, and that one undoes the other."""

    def check_basic(self):
        data = array([1, 3, 2, 1, 2, 3, 3])
        picked = extract(data > 1, data)
        assert_array_equal(picked, [3, 2, 2, 3, 3])

    def check_insert(self):
        target = array([1, 4, 3, 2, 5, 8, 7])
        # the slots flagged by the nonzero mask entries are overwritten in
        # place, in order, by the supplied values
        insert(target, [0, 1, 0, 1, 0, 1, 0], [2, 4, 6])
        assert_array_equal(target, [1, 2, 3, 4, 5, 6, 7])

    def check_both(self):
        data = rand(10)
        original = data.copy()
        mask = data > 0.5
        kept = extract(mask, data)
        # blank the selected entries, then put the extracted values back
        insert(data, mask, 0)
        insert(data, mask, kept)
        assert_array_equal(data, original)
 
252
 
 
253
def _addsubtract(a, b):
    # shared element-wise kernel for the vectorize checks: subtract when the
    # first argument is larger, otherwise add
    if a > b:
        return a - b
    return a + b


class test_vectorize(ScipyTestCase):
    """Checks ``vectorize`` with two sequences, a scalar operand and a long
    input."""

    def check_simple(self):
        result = vectorize(_addsubtract)([0, 3, 6, 9], [1, 3, 5, 7])
        assert_array_equal(result, [1, 6, 1, 2])

    def check_scalar(self):
        # the scalar second operand is combined with every element
        result = vectorize(_addsubtract)([0, 3, 6, 9], 5)
        assert_array_equal(result, [5, 8, 1, 4])

    def check_large(self):
        grid = linspace(-3, 2, 10000)
        identity = vectorize(lambda v: v)
        assert_array_equal(identity(grid), grid)
 
277
    
 
278
 
 
279
class test_unwrap(ScipyTestCase):
    """Checks that ``unwrap`` removes 2*pi jumps and keeps phases continuous."""

    def check_simple(self):
        # check that unwrap removes jumps greater than 2*pi
        assert_array_equal(unwrap([1,1+2*pi]),[1,1])
        # check that unwrap maintains continuity: no step between neighbours
        # may exceed pi in magnitude.  Take abs() so that large downward
        # jumps are caught as well as upward ones.
        assert(all(abs(diff(unwrap(rand(10)*100))) <= pi))
 
285
 
 
286
 
 
287
class test_filterwindows(ScipyTestCase):
    """Checks the window functions for mirror symmetry and a known sum."""

    def _assert_window(self, window, expected_sum):
        # sample a 10-point window, require it to equal its own reversal,
        # then compare its total against a precomputed reference value
        w = window(10)
        assert_array_almost_equal(w, flipud(w), 7)
        assert_almost_equal(sum(w), expected_sum, 4)

    def check_hanning(self):
        self._assert_window(hanning, 4.500)

    def check_hamming(self):
        self._assert_window(hamming, 4.9400)

    def check_bartlett(self):
        self._assert_window(bartlett, 4.4444)

    def check_blackman(self):
        self._assert_window(blackman, 3.7800)
 
315
 
 
316
 
 
317
class test_trapz(ScipyTestCase):
    """Checks ``trapz`` by integrating the standard normal density."""

    def check_simple(self):
        # standard normal pdf sampled on [-10, 10) with spacing 0.1
        x = arange(-10,10,.1)
        r = trapz(exp(-1.0/2*x**2)/sqrt(2*pi),dx=0.1)
        # check integral of normal equals 1; r is already the integral of
        # this 1-D sample, so it is compared directly (no extra sum())
        assert_almost_equal(r,1,7)
 
322
 
 
323
class test_sinc(ScipyTestCase):
    """Checks ``sinc``: value 1 at zero and symmetry about zero."""

    def check_simple(self):
        assert sinc(0) == 1
        samples = sinc(linspace(-1, 1, 100))
        # samples on a grid symmetric about 0 must read the same reversed
        assert_array_almost_equal(samples, flipud(samples), 7)
 
329
 
 
330
class test_histogram(ScipyTestCase):
    """Checks ``histogram`` bin counts on random and evenly spaced data."""

    def check_simple(self):
        n = 100
        counts, edges = histogram(rand(n))
        # the bin counts must add up to the number of samples
        assert sum(counts) == n
        # evenly spaced samples put the same count (10) in every default bin
        counts, edges = histogram(linspace(0, 10, 100))
        assert all(counts == 10)
 
340
 
 
341
 
 
342
 
 
343
def compare_results(res,desired):
    """Assert that ``res[i]`` equals ``desired[i]`` for every index of *desired*.

    Entries of *res* beyond ``len(desired)`` are ignored; a *res* shorter
    than *desired* raises ``IndexError``.
    """
    for idx, want in enumerate(desired):
        assert_array_equal(res[idx], want)
 
346
 
 
347
if __name__ == "__main__":
    # when executed as a script, hand the module under test to the ScipyTest
    # runner and run its checks
    runner = ScipyTest('numpy.lib.function_base')
    runner.run()