~bertrand-nouvel/pycvf-sash/trunk

« back to all changes in this revision

Viewing changes to indexes/sashindex.py

  • Committer: tranx
  • Date: 2010-10-01 17:30:16 UTC
  • Revision ID: tranx@havane-20101001173016-8m0uknapnvhmpgng
initial commit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#!/usr/bin/env python
 
2
# -*- coding: utf-8 -*-
 
3
## ##########################################################################################################
 
4
## 
 
5
## This file is released under GNU Public License v3
 
6
## See the LICENSE file at the top of the pycvf tree.
 
7
##
 
8
## Author : Bertrand NOUVEL / CNRS (2009)
 
9
##
 
10
## Revision FILE: $Id$
 
11
##
 
12
## ###########################################################################################################
 
13
## copyright $Copyright$
 
14
## @version $Revision$
 
15
## @lastrevision $Date$
 
16
## @modifiedby $LastChangedBy$
 
17
## @lastmodified $LastChangedDate$
 
18
#############################################################################################################
 
19
 
 
20
 
 
21
# -*- coding: utf-8 -*-
 
22
 
 
23
from pycvf.core.errors import pycvf_debug, pycvf_warning, pycvf_error
 
24
from pycvf.core.distribution import *
 
25
pycvf_dist(PYCVFD_REQUIRE_PACKAGE,'pysash')
 
26
pycvf_dist(PYCVFD_SPECIFIC_LICENSE,"Sash is patented and is not under LGPL-3. Please refer at the Sash documentation. Do not use in any commercial application.")
 
27
 
 
28
from pysash import *
 
29
 
 
30
from pycvfext.indexes.lib.filearray import *
 
31
 
 
32
import os, marshal,cPickle
 
33
try:
 
34
  import scipy.io 
 
35
except:
 
36
  pass
 
37
import itertools, time
 
38
 
 
39
def geometric(fromv, tov, factor):
    """Generate the geometric sequence fromv, fromv*factor, fromv*factor**2, ...

    Values are produced for as long as they do not exceed ``tov``.  The
    sequence never ends when it cannot grow past ``tov`` (e.g. a positive
    ``fromv`` with ``factor <= 1``), so callers must pick ``factor > 1``.
    """
    value = fromv
    while True:
        if not value <= tov:
            return
        yield value
        value *= factor
 
44
    
 
45
 
 
46
class CachedSashIndex(object):
    """Disk-backed SASH index whose keys and values live in FileArray tables.

    The directory ``filename`` holds ``keys.tbl`` and ``values.tbl`` (both
    FileArray tables) plus the serialized SASH structure stored under
    ``filename + "/index"``.  If the directory exists, the tables and the
    index are reloaded from it.  Otherwise (or if reloading fails) the
    directory is created and an empty index is started.

    :param filename: directory that holds the index files.
    :param keysz: passed straight to the keys FileArray (presumably the
        per-record size -- TODO confirm against FileArray).
    :param valuesz: passed straight to the values FileArray.
    :param distance: optional distance handed to ``GenericSash``; when falsy
        a plain ``Sash`` is used.
    """

    def __init__(self, filename, keysz, valuesz, distance=None):
        self.filename = filename
        self.distance = distance
        # Remembered so that reset() can recreate the tables later on.
        self.keysz = keysz
        self.valuesz = valuesz
        try:
            os.stat(filename)
            self._values = FileArray(filename + "/values.tbl", valuesz)
            self._keys = FileArray(filename + "/keys.tbl", keysz)
            self.sashindex = Sash(filename + "/index", self._keys)
            print(filename + " sash successfully loaded")
        except Exception:
            # Missing directory or unreadable index: start a fresh one.
            self._create_empty()

    def _new_sash(self):
        """Return an empty SASH matching ``self.distance``."""
        if self.distance:
            return GenericSash(self.distance)
        return Sash()

    def _create_empty(self):
        """Ensure the index directory exists and open fresh tables and an empty SASH."""
        try:
            os.mkdir(self.filename)
        except OSError:
            # Usually the directory already exists; FileArray fails below otherwise.
            pass
        self._values = FileArray(self.filename + "/values.tbl", self.valuesz)
        self._keys = FileArray(self.filename + "/keys.tbl", self.keysz)
        self.sashindex = self._new_sash()

    def reset(self):
        """Reopen the tables and replace the SASH with an empty one."""
        self._create_empty()

    def __del__(self):
        # __init__ may have failed before the index existed; nothing to save then.
        if getattr(self, "sashindex", None) is not None:
            self.save()

    def add_many(self, keys, values):
        """Append ``keys``/``values`` to the tables and build the SASH on ``keys``.

        NOTE(review): SASH is not incremental; ``build`` is only given the
        keys passed in this call, not the ones already stored in the table.
        """
        self._values.appenditems(values)
        self._keys.appenditems(keys)
        self.sashindex.build(keys.astype(numpy.float32))

    def __getitem__(self, query):
        """Return the value of the (approximate) nearest neighbour of ``query``."""
        return self.getitem(query, 1)[0][0]

    def getitem(self, query, numelem=1):
        """Approximate ``numelem``-nearest-neighbour search for one query.

        :returns: list of ``(value, distance)`` pairs.
        """
        nquery = numpy.asarray(query).squeeze().astype(numpy.float32)
        self.sashindex.findNear(nquery, numelem)
        indices = self.sashindex.getResultIndices(numelem).tolist()
        dists = self.sashindex.getResultDists(numelem).tolist()
        return [(self._values[i], d) for i, d in zip(indices, dists)]

    def save(self):
        """Write the SASH structure where ``__init__`` will reload it from."""
        self.sashindex.save(self.filename + "/index", self._keys)

    def keys(self):
        """Iterate over the stored keys."""
        return iter(self._keys)

    def __len__(self):
        return self.sashindex.getNumItems()

    def values(self):
        """Iterate over the stored values."""
        return iter(self._values)

    @staticmethod
    def __tex__(o):
        """Register the SASH bibliography with ``o`` and return the LaTeX citation."""
        o.put_bib(os.environ["JFLIPATH"] + "/pycvf.indexes/sashindex.bib")
        return "\\cite{Houle}"
 
102
 
 
103
def to2d(i):
    """Reshape array ``i`` to 2-D, keeping axis 0 and flattening all others.

    A 1-D input of length n becomes shape (n, 1).
    """
    shape = i.shape
    width = 1
    for dim in shape[1:]:
        width *= dim
    return i.reshape(shape[0], width)
 
106
 
 
107
class SashIndex(object):
    """In-memory nearest-neighbour index built on a SASH structure.

    Keys and values are stored as 2-D arrays, one row per item (see
    ``add_many``).  Without a ``distance``, the keys are converted to dense
    C-contiguous float32 rows and indexed by ``Sash``.  With a ``distance``,
    they are indexed as given by ``GenericSash(distance)``.

    The index is not incremental: ``add_many`` may only be called once
    (``reset`` starts over with an empty index).
    """

    def __init__(self, distance=None):
        self._values = None
        self._keys = None
        self.distance = distance
        self.sashindex = self._new_sash()

    def _new_sash(self):
        """Return an empty SASH matching ``self.distance``."""
        if self.distance:
            return GenericSash(self.distance)
        return Sash()

    def _as_query_matrix(self, queries):
        """Turn ``queries`` into a 2-D matrix, one query per row.

        For the default (vector) index the rows are dense C-contiguous float32.
        """
        nqueries = to2d(numpy.asarray(queries))
        if not self.distance:
            nqueries = nqueries.astype(numpy.float32).copy('C')
        return nqueries

    def _search(self, query, numelem, exact, *args, **kwargs):
        """Run an exact (findNearest) or approximate (findNear) query."""
        if exact:
            self.sashindex.findNearest(query, numelem, *args, **kwargs)
        else:
            self.sashindex.findNear(query, numelem, *args, **kwargs)

    def _results(self, numelem):
        """Return ``(indices, distances)`` lists of the last query's results."""
        indices = self.sashindex.getResultIndices(numelem).tolist()
        dists = self.sashindex.getResultDists(numelem).tolist()
        return indices, dists

    def add_many(self, keys, values):
        """Store ``keys``/``values`` (reshaped to 2-D) and build the SASH on the keys.

        :raises AssertionError: if the index already holds data (SASH is not
            incremental).
        :raises ValueError: for the vector index, when the keys cannot be
            converted to dense float32 rows.
        """
        assert self._values is None, "SashIndex is not incremental"
        assert self._keys is None, "SashIndex is not incremental"
        if not self.distance:
            try:
                self._keys = keys.reshape((keys.shape[0], -1)).astype(numpy.float32).copy('C')
            except ValueError:
                pycvf_warning("Errors SASH keys must be dense float vector in this implementation of SashIndex")
                pycvf_warning(u"here an example of your keys "  +unicode(keys[0]))
                raise
        else:
            self._keys = keys.reshape((keys.shape[0], -1))
        self._values = values.reshape((values.shape[0], -1))
        self.sashindex.build(self._keys)

    @staticmethod
    def load(filename):
        """Rebuild a (vector) SashIndex from files written by ``save``.

        ``filename`` is a path prefix: ``values.dat``, ``keys.npy`` (or the
        ``keys.mat`` fallback) and ``index`` are appended to it.

        NOTE(review): ``values.dat`` is read with pickle, which can execute
        arbitrary code -- only load index files from trusted sources.
        """
        import pickle
        pycvf_debug(10, "load " + filename + "values.dat")
        r = SashIndex()
        with open(filename + "values.dat", "rb") as f:
            r._values = pickle.load(f)
        pycvf_debug(10, "load " + filename + "keys.npy/mat")
        try:
            with open(filename + "keys.npy", "rb") as f:
                r._keys = numpy.load(f)
        except Exception:
            # Fall back to a MATLAB keys file (presumably from older save versions).
            r._keys = scipy.io.loadmat(filename + "keys.mat")["keys"].copy('C')
        r.sashindex = Sash()
        r.sashindex.build(r._keys, filename=filename + "index")
        pycvf_debug(10, "/load " + filename)
        return r

    def save(self, filename):
        """Write values, keys and the SASH under the path prefix ``filename``.

        Does nothing when no keys have been indexed yet.
        """
        import pickle
        pycvf_debug(10, "saving sash")
        if self._keys is None or len(self._keys) == 0:
            return
        with open(filename + "values.dat", "wb") as f:
            pickle.dump(self._values, f)
        with open(filename + "keys.npy", "wb") as f:
            numpy.save(f, self._keys)
        self.sashindex.save(filename + "index")

    def __getitem__(self, query):
        """Return the value of the exact nearest neighbour of ``query``."""
        if not self.distance:
            if type(query) in (int, float, long):
                query = [query]
            nquery = numpy.asarray(query).astype(numpy.float32).copy('C')
        else:
            nquery = [query]
        self.sashindex.findNearest(nquery, 1)
        indices = self.sashindex.getResultIndices(1).tolist()
        return self._values[indices[0]]

    def getitem(self, query, numelem=1, exact=False, *args, **kwargs):
        """Search the ``numelem`` neighbours of a single query.

        :param exact: use findNearest instead of the approximate findNear.
        :returns: list of ``(value, distance, key)`` triples.
        """
        if not self.distance and type(query) in (int, float, long):
            query = [query]
        nquery = self._as_query_matrix([query])
        self._search(nquery[0], numelem, exact, *args, **kwargs)
        ri, rd = self._results(numelem)
        return [(self._values[i], d, self._keys[i]) for i, d in zip(ri, rd)]

    def getitems_dists(self, queries, numelems=(1, 3, 10), exact=False, scaleFactor=1, *args, **kwargs):
        """
          A (non-standard) call to retrieve mean distances up to a certain neighbour
          (vectors only).

          For every row of ``queries`` an approximate search (findNear) for
          ``numelems[-1]`` neighbours is run, and for each k in ``numelems``
          the mean of the first k result distances is recorded.  ``exact`` is
          accepted but not used.
        """
        numelem = numelems[-1]
        res = []
        for r in range(queries.shape[0]):
            self.sashindex.findNear(queries[r], numelem, scaleFactor, *args, **kwargs)
            dists = self.sashindex.getResultDists(numelem)
            res.append([dists[:k].mean() for k in numelems])
        return res

    @staticmethod
    def _recall_at(exact_dists, approx_dists, numelem):
        """Fraction of ``numelem`` approximate distances within the exact k-th distance.

        Approximate results tied with the exact k-th distance are counted at
        most as many times as that distance occurs in the exact result.
        """
        exact_dists = numpy.asarray(exact_dists)
        approx_dists = numpy.asarray(approx_dists)
        kth = exact_dists[-1]
        nties = 1
        while nties < len(exact_dists) and exact_dists[-nties - 1] == kth:
            nties += 1
        nok = int((approx_dists < kth).sum())
        nok += min(nties, int((approx_dists == kth).sum()))
        return float(nok) / numelem

    def getaccuracy(self, queries, numelem=40, numruns=1000, scaleFactor=1, *args, **kwargs):
        """
          A (non-standard) call to estimate the accuracy and speed-up of
          approximate search (vectors only).

          Every row of ``queries`` is searched ``numruns`` times with findNear
          and once with findNearest.  Accuracy is computed per query against
          the exact result.

          :returns: ``(acc, speedup)`` -- mean accuracy over the queries and
              the ratio of exact search time to per-run approximate time.
        """
        nqueries = queries.shape[0]
        resfast = []
        reslong = []

        t0fast = time.clock()
        for _ in range(numruns):
            for r in range(nqueries):
                self.sashindex.findNear(queries[r], numelem, scaleFactor, *args, **kwargs)
                resfast.append(self.sashindex.getResultDists(numelem))
        tfast = time.clock() - t0fast

        t0long = time.clock()
        for r in range(nqueries):
            self.sashindex.findNearest(queries[r], numelem, scaleFactor, *args, **kwargs)
            reslong.append(self.sashindex.getResultDists(numelem))
        tlong = time.clock() - t0long

        # zip stops after nqueries pairs: each exact result is compared with
        # the approximate result of the first run.
        accv = [self._recall_at(exact_d, approx_d, numelem)
                for exact_d, approx_d in zip(reslong, resfast)]

        speedup = tlong / (tfast / numruns)
        acc = numpy.mean(accv)
        return acc, speedup

    def getaccuracy2(self, queries, numelem=5, numruns=10, scaleFactor=1, progressbar=True, *args, **kwargs):
        """Compare approximate and exact search query by query.

        For every query row, one exact search (findNearest) is timed.  Then
        ``numruns`` approximate searches (findNear with ``scaleFactor``) are
        timed, with resetQuery called before each one.  Two quantities are
        recorded per query:

        - the SASH's getResultAcc against the exact distances;
        - the relative excess of the mean approximate distance over the mean
          exact distance.

        A plot of the sorted relative excesses is saved to
        ``"res3_3-%f.png" % scaleFactor``.

        :returns: ``(mean acc, std acc, speedup, mean excess, std excess)``.
        """
        if progressbar:
            from pycvf.core.utilities import with_progressbar
        else:
            with_progressbar = lambda x: x
        tfast = 0
        tlong = 0
        accv = []
        accv2 = []

        for r in with_progressbar(range(queries.shape[0])):
            t0long = time.clock()
            self.sashindex.findNearest(queries[r], numelem, *args, **kwargs)
            tlong += time.clock() - t0long
            reslong = self.sashindex.getResultDists(numelem)

            t0fast = time.clock()
            for _ in range(numruns):
                self.sashindex.resetQuery()
                self.sashindex.findNear(queries[r], numelem, scaleFactor=scaleFactor, *args, **kwargs)
            tfast += time.clock() - t0fast

            resfast = self.sashindex.getResultDists(numelem)
            accv.append(self.sashindex.getResultAcc(reslong))
            accv2.append((resfast.mean() - reslong.mean()) / reslong.mean())

        speedup = tlong / (tfast / numruns)

        import pylab
        pylab.ion()
        pylab.plot(numpy.sort(accv2))
        pylab.savefig("res3_3-%f.png" % (scaleFactor,))

        return numpy.mean(accv), numpy.std(accv), speedup, numpy.mean(accv2), numpy.std(accv2)

    def get_best_scalefactor(self, queries, *args, **kwargs):
        """Run getaccuracy2 for scale factors 1/16, 1/16*sqrt(2), ... up to 8.

        Each factor and its result are printed.  Returns a dict mapping every
        tried scale factor to its getaccuracy2 result; choosing the best one
        is left to the caller.
        """
        l = []
        for x in geometric(16 ** -1, 8, 2 ** .5):
            print(x)
            c = (x, self.getaccuracy2(queries, scaleFactor=x, *args, **kwargs))
            print(c)
            l.append(c)
        return dict(l)

    def getitems(self, queries, numelem=1, exact=False, *args, **kwargs):
        """Search the ``numelem`` neighbours of every query row.

        :param exact: use findNearest instead of the approximate findNear.
        :returns: one list of ``(value, distance)`` pairs per query.
        """
        nqueries = self._as_query_matrix(queries)
        res = []
        for r in range(nqueries.shape[0]):
            self._search(nqueries[r], numelem, exact, *args, **kwargs)
            ri, rd = self._results(numelem)
            res.append([(self._values[i], d) for i, d in zip(ri, rd)])
        return res

    def keys(self):
        """Iterate over the stored key rows (empty when nothing is indexed)."""
        return iter(self._keys) if self._keys is not None else []

    def values(self):
        """Iterate over the stored value rows (empty when nothing is indexed)."""
        return iter(self._values) if self._values is not None else []

    def __len__(self):
        return self.sashindex.getNumItems()

    def reset(self):
        """Drop all indexed data and start again with an empty SASH."""
        self._values = None
        self._keys = None
        self.sashindex = self._new_sash()

    @staticmethod
    def __tex__(o):
        """Register the SASH bibliography with ``o`` and return the LaTeX citation."""
        o.put_bib(os.environ["JFLIPATH"] + "/pycvf.indexes/sashindex.bib")
        return "\\cite{Houle}"
 
330
 
 
331
# Module-level entry points: calling the module builds a SashIndex, and
# `load` rebuilds one from saved files (presumably looked up by pycvf's
# index loader -- TODO confirm).
__call__ = SashIndex
load = SashIndex.load