~ubuntu-branches/ubuntu/trusty/brian/trusty

« back to all changes in this revision

Viewing changes to brian/hears/hrtf/hrtf.py

  • Committer: Bazaar Package Importer
  • Author(s): Yaroslav Halchenko
  • Date: 2011-02-08 19:28:34 UTC
  • mfrom: (1.1.1 upstream)
  • Revision ID: james.westby@ubuntu.com-20110208192834-uexkhylennhz4qzp
Tags: 1.2.2~svn2469-1
* Upstream pre-release snapshot
* debian/copyright - fixed Maintainer and Source entries
* debian/rules:
  - set both HOME and MPLCONFIGDIR to point to build/.
    Should resolve issues of building when $HOME is read-only
    (Closes: #612548)
  - assure Agg matplotlib backend to avoid possible complications
    during off-screen operations
  - un-disabled hears unittests (model-fitting remains disabled due
    to absent dependency -- playdoh)
* debian/control:
  - graphviz (for graphs rendering) and texlive-latex-base,
    texlive-latex-extra (for equations rendering) into build-depends

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
from brian import *
 
2
from ..sounds import *
 
3
from ..filtering import FIRFilterbank
 
4
from copy import copy
 
5
 
 
6
__all__ = ['HRTF', 'HRTFSet', 'HRTFDatabase',
 
7
           'make_coordinates']
 
8
 
 
9
class HRTF(object):
 
10
    '''
 
11
    Head related transfer function.
 
12
    
 
13
    **Attributes**
 
14
 
 
15
    ``impulse_response``
 
16
        The pair of impulse responses (as stereo :class:`Sound` objects)
 
17
    ``fir``
 
18
        The impulse responses in a format suitable for using with
 
19
        :class:`FIRFilterbank` (the transpose of ``impulse_response``).
 
20
    ``left``, ``right``
 
21
        The two HRTFs (mono :class:`Sound` objects)
 
22
    ``samplerate``
 
23
        The sample rate of the HRTFs.
 
24
        
 
25
    **Methods**
 
26
    
 
27
    .. automethod:: apply
 
28
    .. automethod:: filterbank
 
29
    
 
30
    You can get the number of samples in the impulse response with ``len(hrtf)``.        
 
31
    '''
 
32
    def __init__(self, hrir_l, hrir_r=None):
 
33
        if hrir_r is None:
 
34
            hrir = hrir_l
 
35
        else:
 
36
            hrir = Sound((hrir_l, hrir_r), samplerate=hrir_l.samplerate)
 
37
        self.samplerate = hrir.samplerate
 
38
        self.impulse_response = hrir
 
39
        self.left = hrir.left
 
40
        self.right = hrir.right
 
41
 
 
42
    def apply(self, sound):
 
43
        '''
 
44
        Returns a stereo :class:`Sound` object formed by applying the pair of
 
45
        HRTFs to the mono ``sound`` input. Equivalently, you can write
 
46
        ``hrtf(sound)`` for ``hrtf`` an :class:`HRTF` object.
 
47
        '''
 
48
        # Note we use an FFT based method for applying HRTFs that is
 
49
        # mathematically equivalent to using convolution (accurate to 1e-15
 
50
        # in practice) and around 100x faster.
 
51
        if not sound.nchannels==1:
 
52
            raise ValueError('HRTF can only be applied to mono sounds')
 
53
        if len(unique(array([self.samplerate, sound.samplerate], dtype=int)))>1:
 
54
            raise ValueError('HRTF and sound samplerates do not match.')
 
55
        sound = asarray(sound).flatten()
 
56
        # Pad left/right/sound with zeros of length max(impulse response length)
 
57
        # at the beginning, and at the end so that they are all the same length
 
58
        # which should be a power of 2 for efficiency. The reason to pad at
 
59
        # the beginning is that the first output samples are not guaranteed to
 
60
        # be equal because of the delays in the impulse response, but they
 
61
        # exactly equalise after the length of the impulse response, so we just
 
62
        # zero pad. The reason for padding at the end is so that for the FFT we
 
63
        # can just multiply the arrays, which should have the same shape.
 
64
        left = asarray(self.left).flatten()
 
65
        right = asarray(self.right).flatten()
 
66
        ir_nmax = max(len(left), len(right))
 
67
        nmax = max(ir_nmax, len(sound))+ir_nmax
 
68
        nmax = 2**int(ceil(log2(nmax)))
 
69
        leftpad = hstack((left, zeros(nmax-len(left))))
 
70
        rightpad = hstack((right, zeros(nmax-len(right))))
 
71
        soundpad = hstack((zeros(ir_nmax), sound, zeros(nmax-ir_nmax-len(sound))))
 
72
        # Compute FFTs, multiply and compute IFFT
 
73
        left_fft = fft(leftpad, n=nmax)
 
74
        right_fft = fft(rightpad, n=nmax)
 
75
        sound_fft = fft(soundpad, n=nmax)
 
76
        left_sound_fft = left_fft*sound_fft
 
77
        right_sound_fft = right_fft*sound_fft
 
78
        left_sound = ifft(left_sound_fft).real
 
79
        right_sound = ifft(right_sound_fft).real
 
80
        # finally, we take only the unpadded parts of these
 
81
        left_sound = left_sound[ir_nmax:ir_nmax+len(sound)]
 
82
        right_sound = right_sound[ir_nmax:ir_nmax+len(sound)]
 
83
        return Sound((left_sound, right_sound), samplerate=self.samplerate)        
 
84
    __call__ = apply
 
85
 
 
86
    def get_fir(self):
 
87
        return array(self.impulse_response.T, copy=True)
 
88
    fir = property(fget=get_fir)
 
89
 
 
90
    def filterbank(self, source, **kwds):
 
91
        '''
 
92
        Returns an :class:`FIRFilterbank` object that can be used to apply
 
93
        the HRTF as part of a chain of filterbanks.
 
94
        '''
 
95
        return FIRFilterbank(source, self.fir, **kwds)
 
96
    
 
97
    def __len__(self):
 
98
        return self.impulse_response.shape[0]
 
99
 
 
100
def make_coordinates(**kwds):
 
101
    '''
 
102
    Creates a numpy record array from the keywords passed to the function.
 
103
    Each keyword/value pair should be the name of the coordinate the array of
 
104
    values of that coordinate for each location.
 
105
    Returns a numpy record array. For example::
 
106
    
 
107
        coords = make_coordinates(azimuth=[0, 30, 60, 0, 30, 60],
 
108
                                  elevation=[0, 0, 0, 30, 30, 30])
 
109
        print coords['azimuth']
 
110
    '''
 
111
    dtype = [(name, float) for name in kwds.keys()]
 
112
    n = len(kwds.values()[0])
 
113
    x = zeros(n, dtype=dtype)
 
114
    for name, values in kwds.items():
 
115
        x[name] = values
 
116
    return x
 
117
 
 
118
class HRTFSet(object):
 
119
    '''
 
120
    A collection of HRTFs, typically for a single individual.
 
121
    
 
122
    Normally this object is created automatically by an :class:`HRTFDatabase`.
 
123
        
 
124
    **Attributes**
 
125
    
 
126
    ``hrtf``
 
127
        A list of ``HRTF`` objects for each index.
 
128
    ``num_indices``
 
129
        The number of HRTF locations. You can also use ``len(hrtfset)``.
 
130
    ``num_samples``
 
131
        The sample length of each HRTF.
 
132
    ``fir_serial``, ``fir_interleaved``
 
133
        The impulse responses in a format suitable for using with
 
134
        :class:`FIRFilterbank`, in serial (LLLLL...RRRRR....) or interleaved
 
135
        (LRLRLR...).
 
136
    
 
137
    **Methods**
 
138
    
 
139
    .. automethod:: subset
 
140
    .. automethod:: filterbank
 
141
    
 
142
    You can access an HRTF by index via ``hrtfset[index]``, or
 
143
    by its coordinates via ``hrtfset(coord1=val1, coord2=val2)``.
 
144
    
 
145
    **Initialisation**
 
146
    
 
147
    ``data``
 
148
        An array of shape (2, num_indices, num_samples) where data[0,:,:] is
 
149
        the left ear and data[1,:,:] is the right ear, num_indices is the number
 
150
        of HRTFs for each ear, and num_samples is the length of the HRTF.
 
151
    ``samplerate``
 
152
        The sample rate for the HRTFs (should have units of Hz).
 
153
    ``coordinates``
 
154
        A record array of length ``num_indices`` giving the coordinates of each
 
155
        HRTF. You can use :func:`make_coordinates` to help with this.
 
156
    '''
 
157
    def __init__(self, data, samplerate, coordinates):
 
158
        self.data = data
 
159
        self.samplerate = samplerate
 
160
        self.coordinates = coordinates
 
161
        self.hrtf = []
 
162
        for i in xrange(self.num_indices):
 
163
            l = Sound(self.data[0, i, :], samplerate=self.samplerate)
 
164
            r = Sound(self.data[1, i, :], samplerate=self.samplerate)
 
165
            self.hrtf.append(HRTF(l, r))
 
166
            
 
167
    def __getitem__(self, key):
 
168
        return self.hrtf[key]
 
169
    
 
170
    def __call__(self, **kwds):
 
171
        I = ones(self.num_indices, dtype=bool)
 
172
        for key, value in kwds.items():
 
173
            I = logical_and(I, abs(self.coordinates[key]-value)<1e-10)
 
174
        indices = I.nonzero()[0]
 
175
        if len(indices)==0:
 
176
            raise IndexError('No HRTF exists with those coordinates')
 
177
        if len(indices)>1:
 
178
            raise IndexError('More than one HRTF exists with those coordinates')
 
179
        return self.hrtf[indices[0]]
 
180
 
 
181
    def subset(self, condition):
 
182
        '''
 
183
        Generates the subset of the set of HRTFs whose coordinates satisfy
 
184
        the ``condition``. This should be one of: a boolean array of
 
185
        length the number of HRTFs in the set, with values
 
186
        of True/False to indicate if the corresponding HRTF should be included
 
187
        or not; an integer array with the indices of the HRTFs to keep; or a
 
188
        function whose argument names are
 
189
        names of the parameters of the coordinate system, e.g.
 
190
        ``condition=lambda azim:azim<pi/2``.
 
191
        '''
 
192
        if callable(condition):
 
193
            ns = dict((name, self.coordinates[name]) for name in condition.func_code.co_varnames)
 
194
            try:
 
195
                I = condition(**ns)
 
196
                I = I.nonzero()[0]
 
197
            except:
 
198
                I = False
 
199
            if isinstance(I, bool): # vector-based calculation doesn't work
 
200
                n = len(ns[condition.func_code.co_varnames[0]])
 
201
                I = array([condition(**dict((name, ns[name][j]) for name in condition.func_code.co_varnames)) for j in range(n)])
 
202
                I = I.nonzero()[0]
 
203
        else:
 
204
            if condition.dtype==bool:
 
205
                I = condition.nonzero()[0]
 
206
            else:
 
207
                I = condition
 
208
        hrtf = [self.hrtf[i] for i in I]
 
209
        coords = self.coordinates[I]
 
210
        data = self.data[:, I, :]
 
211
        obj = copy(self)
 
212
        obj.hrtf = hrtf
 
213
        obj.coordinates = coords
 
214
        obj.data = data
 
215
        return obj
 
216
    
 
217
    def __len__(self):
 
218
        return self.num_indices
 
219
    
 
220
    @property
 
221
    def num_indices(self):
 
222
        return self.data.shape[1]
 
223
    
 
224
    @property
 
225
    def num_samples(self):
 
226
        return self.data.shape[2]
 
227
    
 
228
    @property
 
229
    def fir_serial(self):
 
230
        return reshape(self.data, (self.num_indices*2, self.num_samples))
 
231
    
 
232
    @property
 
233
    def fir_interleaved(self):
 
234
        fir = empty((self.num_indices*2, self.num_samples))
 
235
        fir[::2, :] = self.data[0, :, :]
 
236
        fir[1::2, :] = self.data[1, :, :]
 
237
        return fir
 
238
    
 
239
    def filterbank(self, source, interleaved=False, **kwds):
 
240
        '''
 
241
        Returns an :class:`FIRFilterbank` object which applies all of the HRTFs
 
242
        in the set. If ``interleaved=False`` then
 
243
        the channels are arranged in the order LLLL...RRRR..., otherwise they
 
244
        are arranged in the order LRLRLR....
 
245
        '''
 
246
        if interleaved:
 
247
            fir = self.fir_interleaved
 
248
        else:
 
249
            fir = self.fir_serial
 
250
        return FIRFilterbank(source, fir, **kwds)
 
251
 
 
252
 
 
253
class HRTFDatabase(object):
 
254
    '''
 
255
    Base class for databases of HRTFs
 
256
    
 
257
    Should have an attribute 'subjects' giving a list of available subjects,
 
258
    and a method ``load_subject(subject)`` which returns an ``HRTFSet`` for that
 
259
    subject.
 
260
    
 
261
    The initialiser should take (optional) keywords:
 
262
    
 
263
    ``samplerate``
 
264
        The intended samplerate (resampling will be used if it is wrong). If
 
265
        left unset, the natural samplerate of the data set will be used.    
 
266
    '''
 
267
    def __init__(self, samplerate=None):
 
268
        raise NotImplementedError
 
269
 
 
270
    def load_subject(self, subject):
 
271
        raise NotImplementedError