~ubuntu-branches/ubuntu/precise/brian/precise

« back to all changes in this revision

Viewing changes to brian/experimental/neuromorphic/AER.py

  • Committer: Package Import Robot
  • Author(s): Yaroslav Halchenko
  • Date: 2012-01-02 12:49:11 UTC
  • mfrom: (1.1.3)
  • Revision ID: package-import@ubuntu.com-20120102124911-6r1rmqgt5vr22ro3
Tags: 1.3.1-1
* Fresh upstream release
* Boosted policy compliance to 3.9.2 (no changes)
* Added up_skip_tests_with_paths patch to avoid test failures on custom 
  test scripts with hardcoded paths

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
"""
2
 
Module to deal with the AER (Address Event Representation) format.
3
 
"""
4
 
#from struct import *
5
 
from numpy import *
6
 
 
7
 
__all__=['load_AER','extract_DVS_event']
8
 
 
9
 
def load_AER(filename):
10
 
    '''
11
 
    Loads an AER .dat file and returns
12
 
    a vector of addresses and a vector of timestamps (ints)
13
 
    
14
 
    timestamp is (probably) in microseconds
15
 
    '''
16
 
    f=open(filename,'rb')
17
 
    version=1 # default
18
 
    
19
 
    # Skip header and look for version number
20
 
    line=f.readline()
21
 
    while line[0]=='#':
22
 
        if line[:9]=="#!AER-DAT":
23
 
            version=int(float(line[9:-1]))
24
 
        line=f.readline()
25
 
    line+=f.read()
26
 
    f.close()
27
 
    
28
 
    if version==1:
29
 
        #nevents=len(line)/6
30
 
        #for n in range(nevents):
31
 
        #    events.append(unpack('>HI',line[n*6:(n+1)*6])) # address,timestamp
32
 
        x=fromstring(line,dtype=int16) # or uint16?
33
 
        x=x.reshape((len(x)/3,3))
34
 
        addr=x[:,0].newbyteorder('>')
35
 
        timestamp=x[:,1:].copy()
36
 
        timestamp.dtype=int32
37
 
        timestamp=timestamp.newbyteorder('>').flatten()
38
 
    else: # version==2
39
 
        #nevents=len(line)/8
40
 
        #for n in range(nevents):
41
 
        #    events.append(unpack('>II',line[n*8:(n+1)*8])) # address,timestamp
42
 
        x=fromstring(line,dtype=int32).newbyteorder('>')
43
 
        addr=x[:,0]
44
 
        timestamp=x[:,1]
45
 
 
46
 
    return addr,timestamp
47
 
 
48
 
def extract_DVS_event(addr):
49
 
    '''
50
 
    Extracts retina event from an address
51
 
    
52
 
    Chip: Digital Vision Sensor (DVS)
53
 
    http://siliconretina.ini.uzh.ch/wiki/index.php
54
 
    
55
 
    Returns: x, y, polarity (ON/OFF: 1/-1)
56
 
    
57
 
    TODO:
58
 
    * maybe this should in a "chip" module?
59
 
    '''
60
 
    retina_size=128
61
 
    xmask = 0xfE # x are 7 bits (64 cols) ranging from bit 1-8
62
 
    ymask = 0x7f00 # y are also 7 bits
63
 
    xshift=1 # bits to shift x to right
64
 
    yshift=8 # bits to shift y to right
65
 
    polmask=1 # polarity bit is LSB
66
 
 
67
 
    x=retina_size-1-((addr & xmask) >> xshift)
68
 
    y=(addr & ymask) >> yshift
69
 
    pol=1-2*(addr & polmask) # 1 for ON, -1 for OFF
70
 
    return x,y,pol
71
 
 
72
 
if __name__=='__main__':
73
 
    path=r'C:\Users\Romain\Desktop\jaerSampleData\DVS128'
74
 
    filename=r'\Tmpdiff128-2006-02-03T14-39-45-0800-0 tobi eye.dat'
75
 
 
76
 
    addr,timestamp=load_AER(path+filename)
 
1
"""
 
2
Module to deal with the AER (Address Event Representation) format.
 
3
 
 
4
Current state:
 
5
* load_AER seems fine
 
6
* extract_DVS_event is probably fine too, but maybe should be in a "chip" module?
 
7
 
 
8
"""
 
9
 
 
10
from numpy import *
 
11
from brian.directcontrol import SpikeGeneratorGroup
 
12
from brian.units import *
 
13
from brian.neurongroup import *
 
14
from brian.directcontrol import SpikeGeneratorThreshold
 
15
from brian.monitor import SpikeMonitor, FileSpikeMonitor
 
16
from brian.clock import guess_clock
 
17
from brian.stateupdater import *
 
18
 
 
19
import os, datetime, struct
 
20
__all__=['load_AER','save_AER',
 
21
         'extract_DVS_event', 'extract_AMS_event',
 
22
         'AERSpikeGeneratorGroup', 'AERSpikeMonitor']
 
23
 
 
24
################# Fast Direct Control ###############################
 
25
# TODO: this could probably be moved to the main directcontrol module
 
26
 
 
27
class AERSpikeGeneratorGroup(NeuronGroup):
 
28
    '''
 
29
    This class loads AER data files and puts them in a SpikeGeneratorGroup for use in Brian.
 
30
    (one can find sample data files in http://sourceforge.net/apps/trac/jaer/wiki/AER%20data)
 
31
    
 
32
    This can load any AER files that is supported by load_AER, apart from index (.aeidx) files that point to multiple data files. Check the documentation for load_AER for that.
 
33
    
 
34
    Sample usage:
 
35
    Gin = AERSpikeGeneratorGroup('/path/to/file/samplefile.dat')
 
36
    or
 
37
    Gin = AERSpikeGeneratorGroup((addr,timestamps))
 
38
    or
 
39
    Gin = AERSpikeGeneratorGroup(pickled_spike_monitor)
 
40
 
 
41
    Attributes:
 
42
    maxtime : is the timing of the last spike of the object
 
43
    '''
 
44
    def __init__(self, data, clock = None, timeunit = 1*usecond, relative_time = True):
 
45
        if isinstance(data, str):
 
46
            l = data.split('.')
 
47
            ext = l[-1].strip('\n')
 
48
            if ext == 'aeidx':
 
49
                raise ValueError('Cannot create a single AERSpikeGeneratorGroup with aeidx files. Consider using load_AER first and manually create multiple AERSpikeGeneratorGroups.')
 
50
            else:
 
51
                data = load_AER(data, relative_time = relative_time, check_sorted = True)
 
52
        if isinstance(data, SpikeMonitor):
 
53
            addr, time = zip(*data.spikes)
 
54
            addr = np.array(list(addr))
 
55
            timestamps = np.array(list(time))
 
56
        elif isinstance(data, tuple):
 
57
            addr, timestamps = data
 
58
            
 
59
        self.tmax = max(timestamps)*timeunit
 
60
        self._nspikes = len(addr)
 
61
        N = max(addr) + 1
 
62
        clock = guess_clock(clock)
 
63
        threshold = FastDCThreshold(addr, timestamps*timeunit, dt = clock.dt)
 
64
        NeuronGroup.__init__(self, N, model = LazyStateUpdater(), threshold = threshold, clock = clock)
 
65
    
 
66
    @property
 
67
    def maxtime(self):
 
68
        # this should be kept for AER generated groups, because then one can use run(group.maxtime)
 
69
        if not isinstance(self.tmax, Quantity):
 
70
            return self.tmax*second
 
71
        return self.tmax
 
72
    
 
73
    @property
 
74
    def nspikes(self):
 
75
        return self._nspikes
 
76
        
 
77
class FastDCThreshold(SpikeGeneratorThreshold):
 
78
    '''
 
79
    Implementing dan's idea for fast Direct Control Threshold, works like a charm.
 
80
    '''
 
81
    def __init__(self, addr, timestamps, dt = None):
 
82
        self.set_offsets(addr, timestamps, dt = dt)
 
83
        
 
84
    def set_offsets(self, I, T, dt = 1000):
 
85
        # Convert times into integers
 
86
        T = array(T/dt, dtype=int)
 
87
        # Put them into order
 
88
        # We use a field array to sort first by time and then by neuron index
 
89
        spikes = zeros(len(I), dtype=[('t', int), ('i', int)])
 
90
        spikes['t'] = T
 
91
        spikes['i'] = I
 
92
        spikes.sort(order=('t', 'i'))
 
93
        T = spikes['t']
 
94
        self.I = spikes['i']
 
95
        # Now for each timestep, we find the corresponding segment of I with
 
96
        # the spike indices for that timestep.
 
97
        # The idea of offsets is that the segment offsets[t]:offsets[t+1]
 
98
        # should give the spikes with time t, i.e. T[offsets[t]:offsets[t+1]]
 
99
        # should all be equal to t, and so then later we can return
 
100
        # I[offsets[t]:offsets[t+1]] at time t. It might take a bit of thinking
 
101
        # to see why this works. Since T is sorted, and bincount[i] returns the
 
102
        # number of elements of T equal to i, then j=cumsum(bincount(T))[t]
 
103
        # gives the first index in T where T[j]=t.
 
104
        self.offsets = hstack((0, cumsum(bincount(T))))
 
105
    
 
106
    def __call__(self, P):
 
107
        t = P.clock.t
 
108
        dt = P.clock.dt
 
109
        t = int(round(t/dt))
 
110
        if t+1>=len(self.offsets):
 
111
            return array([], dtype=int)
 
112
        return self.I[self.offsets[t]:self.offsets[t+1]]
 
113
 
 
114
########### AER loading stuff ######################
 
115
 
 
116
def load_multiple_AER(filename, check_sorted = False, relative_time = False, directory = '.'):
 
117
    f=open(filename,'rb')
 
118
    line = f.readline()
 
119
    res = []
 
120
    line = line.strip('\n')
 
121
    while not line == '':
 
122
        res.append(load_AER(os.path.join(directory, line), check_sorted = check_sorted, relative_time = relative_time))
 
123
        line = f.readline()
 
124
    f.close()
 
125
    return res
 
126
 
 
127
def load_AER(filename, check_sorted = False, relative_time = True):
 
128
    '''
 
129
    Loads AER data files for use in Brian.
 
130
    Returns a list containing tuples with a vector of addresses and a vector of timestamps (ints, unit is usually usecond).
 
131
 
 
132
    It can load any kind of .dat, or .aedat files.
 
133
    Note: For index files (that point to multiple .(ae)dat files) it will return a list containing tuples as for single files.
 
134
    
 
135
    Keyword Arguments:
 
136
    If check_sorted is True, checks if timestamps are sorted,
 
137
    and sort them if necessary.
 
138
    If relative_time is True, it will set the first spike time to zero and all others relatively to that precise time (avoid negative timestamps, is definitely a good idea).
 
139
    
 
140
    Hence to use those data files in Brian, one should do:
 
141
 
 
142
    addr, timestamp =  load_AER(filename, relative_time = True)
 
143
    G = AERSpikeGeneratorGroup((addr, timestamps))
 
144
    '''
 
145
    l = filename.split('.')
 
146
    ext = l[-1].strip('\n')
 
147
    filename = filename.strip('\n')
 
148
    directory = os.path.dirname(filename)
 
149
    if ext == 'aeidx':
 
150
        #AER data points to different AER files
 
151
        return load_multiple_AER(filename, check_sorted = check_sorted, relative_time = relative_time, directory = directory)
 
152
    elif not (ext == 'dat' or ext == 'aedat'):
 
153
        raise ValueError('Wrong extension for AER data, should be dat, or aedat, it was '+ext)
 
154
    
 
155
    # This is inspired by the following Matlab script:
 
156
    # http://jaer.svn.sourceforge.net/viewvc/jaer/trunk/host/matlab/loadaerdat.m?revision=2001&content-type=text%2Fplain
 
157
    f=open(filename,'rb')
 
158
    version=1 # default (if not found in the file)
 
159
    
 
160
    # Skip header and look for version number
 
161
    line = f.readline()
 
162
    while line[0] == '#':
 
163
        if line[:9] == "#!AER-DAT":
 
164
            version = int(float(line[9:-1]))
 
165
        line = f.readline()
 
166
    line += f.read()
 
167
    f.close()
 
168
    
 
169
    if version==1:
 
170
        print 'Loading version 1 file '+filename
 
171
        '''
 
172
        Format is: sequence of (addr = 2 bytes,timestamp = 4 bytes)
 
173
        Number format is big endian ('>')
 
174
        '''
 
175
        ## This commented paragraph is the non-vectorized version
 
176
        #nevents=len(line)/6
 
177
        #for n in range(nevents):
 
178
        #    events.append(unpack('>HI',line[n*6:(n+1)*6])) # address,timestamp
 
179
        x=fromstring(line, dtype=int16) # or uint16?
 
180
        x=x.reshape((len(x)/3,3))
 
181
        addr=x[:,0].newbyteorder('>')
 
182
        timestamp=x[:,1:].copy()
 
183
        timestamp.dtype=int32
 
184
        timestamp=timestamp.newbyteorder('>').flatten()
 
185
    else: # version==2
 
186
        print 'Loading version 2 file '+filename
 
187
        '''
 
188
        Format is: sequence of (addr = 4 bytes,timestamp = 4 bytes)
 
189
        Number format is big endian ('>')
 
190
        '''
 
191
        ## This commented paragraph is the non-vectorized version
 
192
        #nevents=len(line)/8
 
193
        #for n in range(nevents):
 
194
        #    events.append(unpack('>II',line[n*8:(n+1)*8])) # address,timestamp
 
195
        x = fromstring(line, dtype=int32).newbyteorder('>')
 
196
        addr = x[::2]
 
197
        if len(addr) == len(x[1::2]):
 
198
            timestamp = x[1::2]
 
199
        else:
 
200
            print """It seems there was a problem with the AER file, timestamps and addr don't have the same length!"""
 
201
            timestamp = x[1::2]
 
202
 
 
203
    if check_sorted: # Sorts the events if necessary
 
204
        if any(diff(timestamp)<0): # not sorted
 
205
            ind = argsort(timestamp)
 
206
            addr,timestamp = addr[ind],timestamp[ind]
 
207
    if (timestamp<0).all():
 
208
        print 'Negative timestamps'
 
209
    
 
210
    if relative_time:
 
211
        t0 = min(timestamp)
 
212
        timestamp -= t0
 
213
    
 
214
    return addr,timestamp
 
215
 
 
216
HEADER = """#!AER-DAT2.0\n# This is a raw AE data file - do not edit\n# Data format is int32 address, int32 timestamp (8 bytes total), repeated for each event\n# Timestamps tick is 1 us\n# created with the Brian simulator on """
 
217
 
 
218
def save_AER(spikemonitor, f):
 
219
    '''
 
220
    Saves the SpikeMonitor's contents to a file in aedat format.
 
221
    File should have 'aedat' extension.
 
222
    One can specify an open file, or, alternatively the filename as a string.
 
223
 
 
224
    Usage:
 
225
    save_AER(spikemonitor, file)
 
226
    '''
 
227
    if isinstance(spikemonitor, SpikeMonitor):
 
228
        spikes = spikemonitor.spikes
 
229
    else:
 
230
        spikes = spikemonitor
 
231
    if isinstance(f, str):
 
232
        strinput = True
 
233
        f = open(f, 'wb')
 
234
    l = f.name.split('.')
 
235
    if not l[-1] == 'aedat':
 
236
        raise ValueError('File should have aedat extension')
 
237
    header = HEADER
 
238
    header += str(datetime.datetime.now()) + '\n'
 
239
    f.write(header)
 
240
    # i,t=zip(*spikes)
 
241
    for (i,t) in spikes:
 
242
        addr = struct.pack('>i', i)
 
243
        f.write(addr)
 
244
        time = struct.pack('>i', int(ceil(float(t/usecond))))
 
245
        f.write(time)
 
246
    if strinput:
 
247
        f.close()
 
248
    
 
249
class AERSpikeMonitor(FileSpikeMonitor):
 
250
    """Records spikes to an AER file
 
251
    
 
252
    Initialised as::
 
253
    
 
254
        FileSpikeMonitor(source, filename[, record=False])
 
255
    
 
256
    Does everything that a :class:`SpikeMonitor` does except ONLY records
 
257
    the spikes to the named file in AER format. 
 
258
 
 
259
    
 
260
    Has one additional method:
 
261
    
 
262
    ``close_file()``
 
263
        Closes the file manually (will happen automatically when
 
264
        the program ends).
 
265
    """
 
266
    def __init__(self, source, filename, record=False, delay=0):
 
267
        super(FileSpikeMonitor, self).__init__(source, record, delay)
 
268
        self.filename = filename
 
269
        self.f = open(filename, 'w')
 
270
        header = HEADER
 
271
        header += str(datetime.datetime.now()) + '\n'
 
272
        self.f.write(header)
 
273
 
 
274
    def propagate(self, spikes):
 
275
#        super(AERSpikeMonitor, self).propagate(spikes)
 
276
        # TODO do it better, no struct.pack! check numpy doc for 
 
277
        
 
278
#        addr = array(spikes).newbyteorder('>')
 
279
        for i in spikes:
 
280
            addr = struct.pack('>i', i)
 
281
            self.f.write(addr)
 
282
            time = struct.pack('>i', int(ceil(float(self.source.clock.t/usecond))))
 
283
            self.f.write(time)
 
284
    
 
285
########### AER addressing stuff ######################
 
286
 
 
287
def extract_DVS_event(addr):
 
288
    '''
 
289
    Extracts retina event from an address or a vector of addresses.
 
290
    
 
291
    Chip: Digital Vision Sensor (DVS)
 
292
    http://siliconretina.ini.uzh.ch/wiki/index.php
 
293
    
 
294
    Returns: x, y, polarity (ON/OFF: 1/-1)
 
295
    '''
 
296
    retina_size=128
 
297
 
 
298
    xmask = 0xfE # x are 7 bits (64 cols) ranging from bit 1-8
 
299
    ymask = 0x7f00 # y are also 7 bits
 
300
    xshift=1 # bits to shift x to right
 
301
    yshift=8 # bits to shift y to right
 
302
    polmask=1 # polarity bit is LSB
 
303
 
 
304
    x = retina_size - 1 - ((addr & xmask) >> xshift)
 
305
    y = (addr & ymask) >> yshift
 
306
    pol = 1 - 2*(addr & polmask) # 1 for ON, -1 for OFF
 
307
    return x,y,pol
 
308
 
 
309
def extract_AMS_event(addr):
 
310
    '''
 
311
    Extracts cochlea event from an address or a vector of addresses
 
312
 
 
313
    Chip: Silicon Cochlea (AMS)
 
314
    
 
315
    Returns: side, channel, filternature
 
316
    
 
317
    More precisely:
 
318
    side: 0 is left, 1 is right
 
319
    channel: apex (LF) is 63, base (HF) is 0
 
320
    filternature: 0 is lowpass, 1 is bandpass
 
321
    '''
 
322
    # Reference:
 
323
    # ch.unizh.ini.jaer.chip.cochlea.CochleaAMSNoBiasgen.Extractor in the jAER package (look in the javadoc)
 
324
    # also the cochlea directory in jAER/host/matlab has interesting stuff
 
325
    # the matlab code was used to write this function. I don't understand the javadoc stuff
 
326
    #cochlea_size = 64
 
327
 
 
328
    xmask = 31 # x are 5 bits 32 channels) ranging from bit 1-5 
 
329
    ymask = 32 # y (one bit) determines left or right cochlea
 
330
    xshift=0 # bits to shift x to right
 
331
    yshift=5 # bits to shift y to right
 
332
    
 
333
    channel = 1 + ((addr & xmask) >> xshift)
 
334
    side = (addr & ymask) >> yshift
 
335
    lpfBpf = mod(addr, 2)
 
336
#    leftRight = mod(addr, 4)
 
337
    return (lpfBpf, side, channel)
 
338
 
 
339
if __name__=='__main__':
 
340
    path=r'C:Users\Romain\Desktop\jaerSampleData\DVS128'
 
341
    filename=r'\Tmpdiff128-2006-02-03T14-39-45-0800-0 tobi eye.dat'
 
342
 
 
343
    addr,timestamp=load_AER(path+filename)