~bertrand-nouvel/pycvf-keypoints/trunk

« back to all changes in this revision

Viewing changes to lib/emd_sift.py

  • Committer: tranx
  • Date: 2010-10-01 16:56:14 UTC
  • Revision ID: tranx@havane-20101001165614-u938mdd1y1fgd0o5
initial commit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# -*- coding: utf-8 -*-
 
2
import numpy
 
3
import scipy.spatial
 
4
import pyfastemd
 
5
 
 
6
class SiftEMDDistanceComputer(object):
    """Earth Mover's Distance between SIFT-like histograms via ``pyfastemd``.

    The constructor precomputes a thresholded ground-distance matrix ``D``
    between every pair of histogram bins; ``distance`` then hands two
    histograms and ``D`` to ``pyfastemd.emd_hat``.

    Bins are enumerated with ``numpy.ndindex((YNBP, XNBP, NBO))``, i.e. the
    flat bin index is ``(y * XNBP + x) * NBO + o``.
    NOTE(review): presumably this matches the layout of the descriptors passed
    to ``distance`` -- confirm against the SIFT extractor in use.
    """

    def __init__(self, XNBP=4, YNBP=4, NBO=8, thresh=5, extra_mass_penalty=-1,
                 flowType=1, dt=None):
        """Build the ground-distance matrix.

        Parameters:
            XNBP: number of bins along SIFT's X dimension.
            YNBP: number of bins along SIFT's Y dimension.
            NBO: number of orientation bins.
            thresh: upper bound applied to every entry of ``D``.
            extra_mass_penalty: forwarded unchanged to ``pyfastemd.emd_hat``
                (the original author notes ``-1`` as "default of maximum
                distance").
            flowType: forwarded to ``pyfastemd.emd_hat``; when it is not 1,
                ``distance`` also allocates and returns a flow matrix.
            dt: optional dtype for ``D`` and for the arrays handed to
                ``pyfastemd``; ``None`` means "use the dtype of ``P``".
        """
        self.XNBP = XNBP  # SIFT's X-dimension
        self.YNBP = YNBP  # SIFT's Y-dimension
        self.NBO = NBO  # SIFT's Orientation-dimension
        self.N = XNBP * YNBP * NBO
        self.thresh = thresh
        self.extra_mass_penalty = extra_mass_penalty  # Default of maximum distance
        self.flowType = flowType  # Regular flows

        # Single-argument print(...) emits the same text on Python 2 and 3.
        print("computing siftdistance mat")
        # One row per bin: (y, x, orientation).  Building the list explicitly
        # (instead of a lazy map) keeps this working on Python 3.
        bins = numpy.array(list(numpy.ndindex(YNBP, XNBP, NBO)),
                           dtype=numpy.float64).reshape(-1, 3)
        # Pairwise absolute coordinate differences, shape (N, N, 3).
        gap = numpy.abs(bins[:, None, :] - bins[None, :, :])
        # Orientation wraps around, so take the shorter way round the circle.
        circular = numpy.minimum(gap[..., 2], NBO - gap[..., 2])
        # NOTE(review): spatial offsets are squared while the orientation
        # offset is not -- kept exactly as in the original metric; confirm
        # that this is intended.
        self.D = gap[..., 0] ** 2 + gap[..., 1] ** 2 + circular
        # Sanity checks: the ground distance is non-negative and symmetric.
        assert (self.D >= 0).all()
        assert (self.D == self.D.T).all()
        print("/computing siftdistance mat")

        # Saturate large ground distances at the threshold.
        self.D[self.D > self.thresh] = self.thresh
        # Identity test on purpose: comparing a numpy dtype to None with
        # ==/!= goes through dtype coercion and can give the wrong answer.
        if dt is not None:
            self.D = self.D.astype(dt)
        print(self.D.shape)
        self.dt = dt

    def distance(self, P, Q):
        """Return the EMD between histograms ``P`` and ``Q``.

        ``P``, ``Q`` and ``D`` are cast to ``self.dt`` (or to ``P.dtype`` when
        no dtype was configured) before being passed to ``pyfastemd.emd_hat``.

        Returns:
            The value returned by ``emd_hat`` when ``flowType == 1``;
            otherwise a ``(value, flow)`` tuple where ``flow`` is the
            ``(N, N)`` array that was passed to ``emd_hat`` as its last
            argument.
        """
        dt = P.dtype if self.dt is None else self.dt
        if self.flowType == 1:
            return pyfastemd.emd_hat(P.astype(dt), Q.astype(dt),
                                     self.D.astype(dt),
                                     self.extra_mass_penalty, self.flowType)
        flow = numpy.zeros((self.N, self.N), dtype=dt)
        r = pyfastemd.emd_hat(P.astype(dt), Q.astype(dt), self.D.astype(dt),
                              self.extra_mass_penalty, self.flowType, flow)
        return r, flow
 
39
 
 
40
 
 
41
# Module-wide SiftEMDDistanceComputer, created on first use by
# siftemd_distance() so the ground-distance matrix is only built once.
shared_sift_emd_dist_computer = None


def siftemd_distance(P, Q):
    """Return the EMD between ``P`` and ``Q`` using a shared default computer.

    The shared ``SiftEMDDistanceComputer`` is built on the first call with its
    default geometry and ``dt=numpy.float64``, then reused by later calls.
    """
    global shared_sift_emd_dist_computer
    # Identity test: "== None" would dispatch to __eq__ instead of simply
    # checking whether the singleton has been created yet.
    if shared_sift_emd_dist_computer is None:
        shared_sift_emd_dist_computer = SiftEMDDistanceComputer(dt=numpy.float64)
    return shared_sift_emd_dist_computer.distance(P, Q)
 
 
\ No newline at end of file