2
# -*- coding: utf-8 -*-
3
## ##########################################################################################################
5
## This file is released under GNU Public License v3
6
## See LICENSE File at the top of the pycvf tree.
8
## Author : Bertrand NOUVEL / CNRS (2009)
10
## Revision FILE: $Id$
12
## ###########################################################################################################
13
## copyright $Copyright$
14
## @version $Revision$
15
## @lastrevision $Date$
16
## @modifiedby $LastChangedBy$
17
## @lastmodified $LastChangedDate$
18
#############################################################################################################
21
# -*- coding: utf-8 -*-
23
from pycvf.core.errors import pycvf_debug, pycvf_warning, pycvf_error
24
from pycvf.core.distribution import *
25
pycvf_dist(PYCVFD_REQUIRE_PACKAGE,'pysash')
26
pycvf_dist(PYCVFD_SPECIFIC_LICENSE,"Sash is patented and is not under LGPL-3. Please refer at the Sash documentation. Do not use in any commercial application.")
30
from pycvfext.indexes.lib.filearray import *
32
import os, marshal,cPickle
37
import itertools, time
39
# Helper taking a start value, an end value and a factor; only its signature
# is visible in this listing (the body is not shown).
# NOTE(review): presumably yields a geometric progression from `fromv` to
# `tov`, multiplying by `factor` at each step -- it is iterated over with
# (16**-1, 8, 2**.5) in get_best_scalefactor. Confirm against the full file.
def geometric(fromv,tov,factor):
46
# Index whose keys and values are held in FileArray objects opened on
# "<filename>/keys.tbl" and "<filename>/values.tbl", next to a sash object.
class CachedSashIndex():
47
# Constructor: stores `filename` and `distance`, opens the two FileArray
# tables (`keysz` / `valuesz` are passed as their second argument) and
# creates the underlying sash object.
# NOTE(review): three alternative set-ups follow; the branching (if / try /
# except) that selects between them is not visible in this listing -- confirm
# against the full file before relying on the order below.
def __init__(self,filename,keysz,valuesz,distance=None):
48
self.filename=filename
49
self.distance=distance
52
# Set-up 1: construct a Sash from "<filename>/index" together with the key
# table, then report success on stdout.
self._values=FileArray(filename+"/values.tbl",valuesz)
53
self._keys=FileArray(filename+"/keys.tbl",keysz)
54
self.sashindex=Sash(filename+"/index",self._keys)
55
print filename+" sash successfully loaded"
61
# Set-up 2: same tables, but the sash is a GenericSash created from the
# user-supplied distance.
self._values=FileArray(filename+"/values.tbl",valuesz)
62
self._keys=FileArray(filename+"/keys.tbl",keysz)
64
self.sashindex=GenericSash(self.distance)
72
# Set-up 3: only the two tables are opened in the visible lines.
self._values=FileArray(filename+"/values.tbl",valuesz)
73
self._keys=FileArray(filename+"/keys.tbl",keysz)
77
def add_many(self,keys,values):
    """Append ``keys`` / ``values`` to the backing tables and build the sash.

    :param keys: key vectors (array or nested sequence); they are appended
        to ``self._keys`` unchanged and handed to the sash as ``float32``.
    :param values: items associated with the keys; appended to
        ``self._values``.

    The sash is built from the ``keys`` of this call only.
    """
    # The original carried two disabled assertions that ``self._values`` and
    # ``self._keys`` are None, tagged "SASH INDEX IS NOT INCREMENTAL".
    self._values.appenditems(values)
    self._keys.appenditems(keys)
    # asarray() lets plain sequences through; astype() still hands the sash
    # its own float32 copy, exactly as before for array input.
    self.sashindex.build(numpy.asarray(keys).astype(numpy.float32))
83
def __getitem__(self,query):
    """Return the first entry of ``self.sashindex.getitem(query)``.

    :param query: query passed through unchanged to the sash object.
    :returns: element 0 of whatever the sash object's ``getitem`` returns.

    NOTE(review): every other call on ``self.sashindex`` in this file uses
    ``findNear`` / ``findNearest`` / ``getResult*``; this may have been meant
    to call ``self.getitem(query)[0]`` instead. Behaviour is left unchanged
    until that is confirmed.
    """
    return self.sashindex.getitem(query)[0]
85
def getitem(self,query,numelem=1):
    """Look up the ``numelem`` neighbours of ``query`` with ``findNear``.

    :param query: query vector; converted to an array, squeezed to drop
        length-1 axes and cast to ``float32`` before being given to the sash.
    :param numelem: number of results requested from the sash.
    :returns: list of ``(value, distance)`` pairs, in the order reported by
        the sash object.
    """
    nquery = numpy.asarray(query).squeeze().astype(numpy.float32)
    self.sashindex.findNear(nquery, numelem)
    # Result indices refer to positions in the value table.
    found = [self._values[i]
             for i in self.sashindex.getResultIndices(numelem).tolist()]
    dists = self.sashindex.getResultDists(numelem).tolist()
    # list(...) keeps the Python 2 return type (a list) on Python 3 as well.
    return list(zip(found, dists))
91
# NOTE(review): the `def` headers of the one-line bodies below are not
# visible in this listing; each statement appears to be the body of a
# separate method -- confirm names and signatures against the full file.
# Saves the sash structure to "<filename>/sash", passing the key table along.
self.sashindex.save(filename+"/sash",self._keys)
93
# Iterator over the key table.
return iter(self._keys)
95
# Item count as reported by the sash object.
return self.sashindex.getNumItems()
97
# Iterator over the value table.
return iter(self._values)
100
# Hands the path of sashindex.bib (located under $JFLIPATH) to `o.put_bib`;
# raises KeyError if the JFLIPATH environment variable is not set.
o.put_bib(os.environ["JFLIPATH"]+"/pycvf.indexes/sashindex.bib")
101
# LaTeX citation command (the literal text \cite{Houle}).
return "\\cite{Houle}"
105
# Reshapes `i` to two dimensions: s[0] rows and the product of s[1:] columns.
# NOTE(review): presumably the body of the `to2d` helper used further down,
# with `s` bound to `i.shape` on a line not shown here -- confirm.
return i.reshape(s[0],reduce(lambda x,y:x*y,s[1:],1))
107
# Index that keeps its keys and values as attributes (`_keys`, `_values`)
# next to a sash object.
class SashIndex(object):
108
# Constructor: stores the optional `distance` and creates the sash object.
# NOTE(review): the two `self.sashindex` assignments are alternatives;
# presumably GenericSash is used when a distance is supplied and the plain
# Sash otherwise -- the selecting condition is not visible in this listing.
# The initialisation of `_keys` / `_values` is not visible here either.
def __init__(self,distance=None):
111
self.distance=distance
113
self.sashindex=GenericSash(distance)
115
self.sashindex=Sash()
116
# Store `keys` / `values` and build the sash over the keys in one shot.
# The two asserts require that nothing has been stored yet.
# NOTE(review): the try/except/else lines that structure the statements
# below are not visible in this listing -- confirm the nesting against the
# full file.
def add_many(self,keys,values):
117
#print "checking empty sash"
118
assert(self._values==None) # SASH INDEX IS NOT INCREMENTAL
119
assert(self._keys==None) # SASH INDEX IS NOT INCREMENTAL
120
if (not self.distance):
122
# No custom distance: keys are flattened to (n, -1), cast to float32 and
# copied in C order.
self._keys=keys.reshape((keys.shape[0],-1)).astype(numpy.float32).copy('C')
124
# Warnings telling the caller that keys must be dense float vectors, with the
# first key shown as an example (presumably issued from an error handler
# around the conversion above -- confirm).
pycvf_warning("Errors SASH keys must be dense float vector in this implementation of SashIndex")
125
pycvf_warning(u"here an example of your keys " +unicode(keys[0]))
127
#print type(values), values#,values[0]
128
# Values are flattened to (n, -1) as well.
self._values=values.reshape((values.shape[0],-1))
129
#print "shapes : ", self._keys.shape, self._values.shape
131
# Alternative path (presumably when a custom distance is set): keys and
# values are only flattened, their dtype is left untouched.
self._keys=keys.reshape((keys.shape[0],-1))
132
self._values=values.reshape((values.shape[0],-1))
133
self.sashindex.build(self._keys)
137
# NOTE(review): fragment of a loader whose `def` line is not visible in this
# listing; `r` is presumably the instance being populated -- confirm.
pycvf_debug(10,"load "+filename+"value.dat")
139
# Values are unpickled from "<filename>values.dat" (the debug message above
# spells it "value.dat"). pickle.load must only be used on trusted files.
f=file(filename+"values.dat","rb")
140
r._values=pickle.load(f)
141
pycvf_debug(10,"load "+filename+"keys.npy/mat")
143
# Keys are read from "<filename>keys.npy" or from the "keys" entry of
# "<filename>keys.mat"; the logic choosing between the two is not visible
# here.
r._keys=numpy.load(file(filename+"keys.npy","rb"))
145
r._keys=scipy.io.loadmat(filename+"keys.mat")["keys"].copy('C')
148
# The sash is built over the loaded keys, with "<filename>index" passed as
# the `filename` keyword.
r.sashindex.build(r._keys,filename=filename+"index")
149
pycvf_debug(10,"/load "+filename)
151
# Write the index to disk using `filename` as a path prefix:
#   <filename>values.dat   pickled values
#   <filename>keys.npy     keys, written with numpy.save
#   <filename>index        the sash structure
def save(self,filename):
152
pycvf_debug(10, "saving sash")
153
#pycvf_debug(10, "keys : " + str(self._keys.shape) +":" + str(self._keys))
154
# Guard for an index that holds no keys (None or an empty list).
# NOTE(review): with indentation lost in this listing, it is not visible how
# many of the following statements belong to this `if` -- confirm.
# NOTE(review): `in` compares with ==, which is elementwise for numpy arrays
# on recent numpy versions -- verify this test still behaves as intended.
if (self._keys not in [None,[]]):
155
pickle.dump(self._values,file(filename+"values.dat","wb"))
156
#scipy.io.savemat(filename+"keys.mat",{"keys":self._keys})
157
numpy.save(file(filename+"keys.npy","wb"),self._keys)
158
self.sashindex.save(filename+"index")#,self._keys)
159
# Single-result lookup: queries the sash with findNearest for one neighbour
# and maps the result indices to the stored values.
# NOTE(review): several lines (branch bodies and the final return) are not
# visible in this listing.
def __getitem__(self,query):
160
if (not self.distance):
161
# Scalar queries (int / float / long) get special handling; the statement
# belonging to this test is not visible here.
if (type(query) in [ int, float, long ] ) :
163
# Without a custom distance the query is converted to a C-ordered float32
# array.
nquery=numpy.asarray(query).astype(numpy.float32).copy('C')
166
self.sashindex.findNearest(nquery,1)
167
rl=map(lambda i :self._values[i], self.sashindex.getResultIndices(1).tolist())
169
# Look up `numelem` neighbours of a single query and return
# (value, distance, key) triples. Extra positional / keyword arguments are
# forwarded to the sash search call.
# NOTE(review): the branch lines between the alternatives below are not
# visible; presumably `exact` selects findNearest over findNear -- confirm.
def getitem(self,query,numelem=1, exact=False,*args,**kwargs):
170
if (not self.distance):
171
# Scalar queries (int / float / long) get special handling; the statement
# belonging to this test is not visible here.
if (type(query) in [ int, float, long ] ) :
173
# Without a custom distance: wrap the query in a list, reshape it with
# to2d and convert to C-ordered float32.
nquery=to2d(numpy.asarray([query])).astype(numpy.float32).copy('C')
175
# Alternative: same wrapping, dtype left untouched.
nquery=to2d(numpy.asarray([query]))
177
self.sashindex.findNearest(nquery[0],numelem,*args,**kwargs)
179
self.sashindex.findNear(nquery[0],numelem,*args,**kwargs)
180
# Indices and distances of the last search, as plain lists.
ri=self.sashindex.getResultIndices(numelem).tolist()
181
rd=self.sashindex.getResultDists(numelem).tolist()
183
return zip(map(lambda i :self._values[i], ri),rd,map(lambda i :self._keys[i], ri))
184
# For every row of `queries`, run a findNear search and record, for each x in
# `numelems`, the mean of the first x result distances.
# NOTE(review): the docstring delimiters, the initialisation of `res` and
# `numelem`, and the return statement are not visible in this listing.
# NOTE: `numelems` has a mutable default; it is only read in the visible code.
def getitems_dists(self,queries,numelems=[1,3,10], exact=False, scaleFactor=1, *args, **kwargs):
186
A (non-standard) call to retriece mean distances up to a certain neighbor
191
for r in range(queries.shape[0]):
193
cscaleFactor=scaleFactor
194
# The scale factor is passed positionally as the third findNear argument.
rr=self.sashindex.findNear(queries[r],numelem,cscaleFactor,*args,**kwargs)
195
a=self.sashindex.getResultDists(numelem)
196
# One list of prefix means per query.
res.append(map(lambda x:a[:x].mean(),numelems))
198
# Compare findNear against findNearest on `queries`: both are timed, their
# result distances collected, and an accuracy value is appended per query.
# NOTE(review): the docstring delimiters, the timer / list initialisations
# (t0fast, t0long, resfast, reslong, accv, li, ld) and the return statement
# are not visible in this listing.
def getaccuracy(self,queries,numelem=40, numruns=1000, scaleFactor=1, *args, **kwargs):
200
A (non-standard) call to retriece mean distances up to a certain neighbor
206
# Timed section 1: findNear, repeated `numruns` times over all queries.
for run in range(numruns):
207
for r in range(queries.shape[0]):
209
cscaleFactor=scaleFactor
210
rr=self.sashindex.findNear(queries[r],numelem,cscaleFactor,*args,**kwargs)
211
resfast.append(self.sashindex.getResultDists(numelem))
212
tfast=time.clock()-t0fast
215
# Timed section 2: findNearest over all queries.
for r in range(queries.shape[0]):
217
cscaleFactor=scaleFactor
218
rr=self.sashindex.findNearest(queries[r],numelem,cscaleFactor,*args,**kwargs)
219
reslong.append(self.sashindex.getResultDists(numelem))
220
tlong=time.clock()-t0long
224
# Pair up the findNearest and findNear distance arrays.
for q in zip(reslong,resfast):
227
while reslong[li-1]==ld:
229
# NOTE(review): this rebinds `resfast`, the list being zipped above.
resfast=numpy.array(q[1])
230
# Count of findNear distances strictly below the threshold `ld`.
nok=(resfast<ld).sum()
234
# NOTE(review): under Python 2 this is integer division if both operands are
# integers -- confirm a fractional accuracy is not expected here.
accv.append(nok/numelem)
236
# findNear time is divided by `numruns` before being compared with the
# findNearest time.
speedup=tlong/(tfast/numruns)
241
# Per-query comparison of findNearest and findNear: accumulates the time of
# both, the accuracy reported by getResultAcc and the relative excess of the
# mean findNear distance, then returns summary statistics.
# NOTE(review): the conditional around the progress-bar set-up, the timer and
# list initialisations and the pylab import are not visible in this listing.
def getaccuracy2(self,queries,numelem=5, numruns=10, scaleFactor=1, progressbar=True, *args, **kwargs):
243
from pycvf.core.utilities import with_progressbar
245
# Identity replacement for with_progressbar (presumably used when
# `progressbar` is false -- confirm).
with_progressbar=lambda x:x
251
for r in with_progressbar(range(queries.shape[0])):
253
self.sashindex.findNearest(queries[r],numelem,*args,**kwargs)
254
tlong+=time.clock()-t0long
255
reslong=self.sashindex.getResultDists(numelem)
258
# findNear is repeated `numruns` times, resetting the query state each time.
for run in range(numruns):
259
self.sashindex.resetQuery()
260
self.sashindex.findNear(queries[r],numelem,scaleFactor=scaleFactor,*args,**kwargs)
261
tfast+=time.clock()-t0fast
263
resfast=self.sashindex.getResultDists(numelem)
264
# Accuracy of the last findNear result against the findNearest distances.
accv.append(self.sashindex.getResultAcc(reslong))
265
# Relative difference between the mean findNear and findNearest distances.
accv2.append((resfast.mean()-reslong.mean())/reslong.mean())
267
# findNear time is divided by `numruns` before being compared with the
# findNearest time.
speedup=tlong/(tfast/numruns)
271
# Side effect: plots the sorted relative differences and writes the figure to
# a PNG in the current directory, named after the scale factor.
pylab.plot(numpy.sort(accv2))
272
pylab.savefig("res3_3-%f.png"%(scaleFactor,))
275
# (mean accuracy, std accuracy, speedup, mean relative diff, std relative diff)
return numpy.mean(accv),numpy.std(accv),speedup,numpy.mean(accv2),numpy.std(accv2)
277
# Evaluate getaccuracy2 for each scale factor produced by
# geometric(16**-1, 8, 2**.5); extra arguments are forwarded to getaccuracy2.
# NOTE(review): what is done with each `c` and the return statement are not
# visible in this listing.
def get_best_scalefactor(self,queries,*args,**kwargs):
278
#return dict([ (x,self.getaccuracy2(queries,scaleFactor=x, *args, **kwargs)) for x in geometric(16**-1,8,2**.5) ])
280
for x in geometric(16**-1,8,2**.5):
282
# Pair of (scale factor, getaccuracy2 result tuple).
c=(x,self.getaccuracy2(queries,scaleFactor=x, *args, **kwargs))
286
# Batch lookup: for every row of `queries`, search `numelem` neighbours and
# collect a list of (value, distance) pairs; returns one such list per query.
# NOTE(review): the branch lines between the alternatives below and the
# initialisation of `res` are not visible in this listing; presumably `exact`
# selects findNearest over findNear -- confirm.
def getitems(self,queries,numelem=1, exact=False, *args, **kwargs):
287
#pycvf_debug(10, "query getitems:"+str( queries))
289
if (not self.distance):
290
#if (type(queries[0]) in [ int, float, long ] ) :
291
# queries=map(lambda x:[x],queries)
292
# Without a custom distance: reshape with to2d and convert to C-ordered
# float32.
nqueries=to2d(numpy.asarray(queries)).astype(numpy.float32).copy('C')
294
# Alternative: reshape only, dtype left untouched.
nqueries=to2d(numpy.asarray(queries))
296
for r in range(nqueries.shape[0]):
298
self.sashindex.findNearest(nqueries[r],numelem,*args,**kwargs)
300
self.sashindex.findNear(nqueries[r],numelem,*args,**kwargs)
301
#print self._values.shape, self.sashindex.getResultIndices(numelem).tolist(), self.sashindex.getResultDists(numelem).tolist()
303
# Indices and distances of the last search, as plain lists.
ri=self.sashindex.getResultIndices(numelem).tolist()
304
rd=self.sashindex.getResultDists(numelem).tolist()
306
#print self._values.shape, self._keys.shape
307
#print self._values, self._keys
308
#res+=[zip(map(lambda i :self._values[i][0], ri),rd,map(lambda i :self._keys[i], ri))]
309
# One entry per query: the zipped (value, distance) pairs.
res+=[zip(map(lambda i :self._values[i], ri),rd )]
311
return res#numpy.vstack(res)
313
# NOTE(review): the `def` headers of the one-line bodies below are not
# visible in this listing; each statement appears to belong to a separate
# method -- confirm names and signatures against the full file.
# Iterator over the keys, or an empty list when no keys are stored.
return iter(self._keys) if self._keys!= None else []
315
# Iterator over the values, or an empty list when no values are stored.
return iter(self._values) if self._values!= None else []
317
# Item count as reported by the sash object.
return self.sashindex.getNumItems()
325
# Replaces the sash object with a fresh, empty Sash.
self.sashindex=Sash()
328
# Hands the path of sashindex.bib (located under $JFLIPATH) to `o.put_bib`;
# raises KeyError if the JFLIPATH environment variable is not set.
o.put_bib(os.environ["JFLIPATH"]+"/pycvf.indexes/sashindex.bib")
329
# LaTeX citation command (the literal text \cite{Houle}).
return "\\cite{Houle}"