1
1
# Copyright Anne M. Archibald 2008
2
2
# Released under the scipy license
3
from numpy.testing import *
4
from numpy.testing import assert_equal, assert_array_equal, assert_almost_equal, \
5
assert_, run_module_suite
6
8
from scipy.spatial import KDTree, Rectangle, distance_matrix, cKDTree
12
14
d, i = self.kdtree.query(x, 1)
13
15
assert_almost_equal(d**2,np.sum((x-self.data[i])**2))
15
assert np.all(np.sum((self.data-x[np.newaxis,:])**2,axis=1)>d**2-eps)
17
assert_(np.all(np.sum((self.data-x[np.newaxis,:])**2,axis=1)>d**2-eps))
17
19
def test_m_nearest(self):
37
39
assert_almost_equal(near_d**2,np.sum((x-self.data[near_i])**2))
38
assert near_d<d+eps, "near_d=%g should be less than %g" % (near_d,d)
40
assert_(near_d<d+eps, "near_d=%g should be less than %g" % (near_d,d))
39
41
assert_equal(np.sum(np.sum((self.data-x[np.newaxis,:])**2,axis=1)<d**2+eps),hits)
41
43
def test_points_near_l1(self):
51
53
assert_almost_equal(near_d,distance(x,self.data[near_i],1))
52
assert near_d<d+eps, "near_d=%g should be less than %g" % (near_d,d)
54
assert_(near_d<d+eps, "near_d=%g should be less than %g" % (near_d,d))
53
55
assert_equal(np.sum(distance(self.data,x,1)<d+eps),hits)
54
56
def test_points_near_linf(self):
64
66
assert_almost_equal(near_d,distance(x,self.data[near_i],np.inf))
65
assert near_d<d+eps, "near_d=%g should be less than %g" % (near_d,d)
67
assert_(near_d<d+eps, "near_d=%g should be less than %g" % (near_d,d))
66
68
assert_equal(np.sum(distance(self.data,x,np.inf)<d+eps),hits)
68
70
def test_approx(self):
72
74
d_real, i_real = self.kdtree.query(x, k)
73
75
d, i = self.kdtree.query(x, k, eps=eps)
74
assert np.all(d<=d_real*(1+eps))
76
assert_(np.all(d<=d_real*(1+eps)))
77
79
class test_random(ConsistencyTests):
151
153
def test_single_query(self):
152
154
d, i = self.kdtree.query(np.array([0,0,0]))
153
assert isinstance(d,float)
154
assert np.issubdtype(i, int)
155
assert_(isinstance(d,float))
156
assert_(np.issubdtype(i, int))
156
158
def test_vectorized_query(self):
157
159
d, i = self.kdtree.query(np.zeros((2,4,3)))
164
166
d, i = self.kdtree.query(np.array([0,0,0]),k=kk)
165
167
assert_equal(np.shape(d),(kk,))
166
168
assert_equal(np.shape(i),(kk,))
167
assert np.all(~np.isfinite(d[-s:]))
168
assert np.all(i[-s:]==self.kdtree.n)
169
assert_(np.all(~np.isfinite(d[-s:])))
170
assert_(np.all(i[-s:]==self.kdtree.n))
169
172
def test_vectorized_query_multiple_neighbors(self):
171
174
kk = self.kdtree.n+s
172
175
d, i = self.kdtree.query(np.zeros((2,4,3)),k=kk)
173
176
assert_equal(np.shape(d),(2,4,kk))
174
177
assert_equal(np.shape(i),(2,4,kk))
175
assert np.all(~np.isfinite(d[:,:,-s:]))
176
assert np.all(i[:,:,-s:]==self.kdtree.n)
178
assert_(np.all(~np.isfinite(d[:,:,-s:])))
179
assert_(np.all(i[:,:,-s:]==self.kdtree.n))
177
181
def test_single_query_all_neighbors(self):
178
182
d, i = self.kdtree.query([0,0,0],k=None,distance_upper_bound=1.1)
179
assert isinstance(d,list)
180
assert isinstance(i,list)
183
assert_(isinstance(d,list))
184
assert_(isinstance(i,list))
181
186
def test_vectorized_query_all_neighbors(self):
182
187
d, i = self.kdtree.query(np.zeros((2,4,3)),k=None,distance_upper_bound=1.1)
183
188
assert_equal(np.shape(d),(2,4))
184
189
assert_equal(np.shape(i),(2,4))
186
assert isinstance(d[0,0],list)
187
assert isinstance(i[0,0],list)
191
assert_(isinstance(d[0,0],list))
192
assert_(isinstance(i[0,0],list))
189
194
class test_vectorization_compiled:
201
206
def test_single_query(self):
202
207
d, i = self.kdtree.query([0,0,0])
203
assert isinstance(d,float)
204
assert isinstance(i,int)
208
assert_(isinstance(d,float))
209
assert_(isinstance(i,int))
206
211
def test_vectorized_query(self):
207
212
d, i = self.kdtree.query(np.zeros((2,4,3)))
221
226
d, i = self.kdtree.query([0,0,0],k=kk)
222
227
assert_equal(np.shape(d),(kk,))
223
228
assert_equal(np.shape(i),(kk,))
224
assert np.all(~np.isfinite(d[-s:]))
225
assert np.all(i[-s:]==self.kdtree.n)
229
assert_(np.all(~np.isfinite(d[-s:])))
230
assert_(np.all(i[-s:]==self.kdtree.n))
226
232
def test_vectorized_query_multiple_neighbors(self):
228
234
kk = self.kdtree.n+s
229
235
d, i = self.kdtree.query(np.zeros((2,4,3)),k=kk)
230
236
assert_equal(np.shape(d),(2,4,kk))
231
237
assert_equal(np.shape(i),(2,4,kk))
232
assert np.all(~np.isfinite(d[:,:,-s:]))
233
assert np.all(i[:,:,-s:]==self.kdtree.n)
238
assert_(np.all(~np.isfinite(d[:,:,-s:])))
239
assert_(np.all(i[:,:,-s:]==self.kdtree.n))
235
241
class ball_consistency:
237
243
def test_in_ball(self):
238
244
l = self.T.query_ball_point(self.x, self.d, p=self.p, eps=self.eps)
240
assert distance(self.data[i],self.x,self.p)<=self.d*(1.+self.eps)
246
assert_(distance(self.data[i],self.x,self.p)<=self.d*(1.+self.eps))
242
248
def test_found_all(self):
243
249
c = np.ones(self.T.n,dtype=np.bool)
244
250
l = self.T.query_ball_point(self.x, self.d, p=self.p, eps=self.eps)
246
assert np.all(distance(self.data[c],self.x,self.p)>=self.d/(1.+self.eps))
252
assert_(np.all(distance(self.data[c],self.x,self.p)>=self.d/(1.+self.eps)))
248
254
class test_random_ball(ball_consistency):
290
296
r = T.query_ball_point(np.random.randn(2,3,m),1)
291
297
assert_equal(r.shape,(2,3))
292
assert isinstance(r[0,0],list)
298
assert_(isinstance(r[0,0],list))
294
300
class two_trees_consistency:
297
303
r = self.T1.query_ball_tree(self.T2, self.d, p=self.p, eps=self.eps)
298
304
for i, l in enumerate(r):
300
assert distance(self.data1[i],self.data2[j],self.p)<=self.d*(1.+self.eps)
306
assert_(distance(self.data1[i],self.data2[j],self.p)<=self.d*(1.+self.eps))
301
307
def test_found_all(self):
302
308
r = self.T1.query_ball_tree(self.T2, self.d, p=self.p, eps=self.eps)
303
309
for i, l in enumerate(r):
304
310
c = np.ones(self.T2.n,dtype=np.bool)
306
assert np.all(distance(self.data2[c],self.data1[i],self.p)>=self.d/(1.+self.eps))
312
assert_(np.all(distance(self.data2[c],self.data1[i],self.p)>=self.d/(1.+self.eps)))
308
314
class test_two_random_trees(two_trees_consistency):
389
395
def test_multiple_radius(self):
390
396
rs = np.exp(np.linspace(np.log(0.01),np.log(10),3))
391
397
results = self.T1.count_neighbors(self.T2, rs)
392
assert np.all(np.diff(results)>=0)
398
assert_(np.all(np.diff(results)>=0))
393
399
for r,result in zip(rs, results):
394
400
assert_equal(self.T1.count_neighbors(self.T2, r), result)