~ubuntu-branches/ubuntu/wily/brian/wily

« back to all changes in this revision

Viewing changes to brian/experimental/morphology/spatialneuron_remy.py

  • Committer: Package Import Robot
  • Author(s): Yaroslav Halchenko
  • Date: 2014-07-30 11:29:44 UTC
  • mfrom: (6.1.2 experimental)
  • Revision ID: package-import@ubuntu.com-20140730112944-ln0ogbq0kpyyuz47
Tags: 1.4.1-2
* Forgotten upload to unstable
* debian/control
  - policy boost to 3.9.5
  - updated Vcs- fields given migration to anonscm.d.o and provided -b
    debian

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
'''
 
2
Compartmental neurons
 
3
'''
 
4
 
 
5
from brian.experimental.morphology import Morphology
 
6
from brian.stdunits import *
 
7
from brian.units import *
 
8
from brian.reset import NoReset
 
9
from brian.stateupdater import StateUpdater
 
10
from brian.inspection import *
 
11
from brian.optimiser import *
 
12
from itertools import count
 
13
from brian.neurongroup import NeuronGroup
 
14
from scipy.linalg import solve_banded
 
15
from numpy import zeros, ones, isscalar, diag_indices
 
16
from numpy.linalg import solve
 
17
from brian.clock import guess_clock
 
18
from brian.equations import Equations
 
19
import functools
 
20
import warnings
 
21
from math import ceil, log
 
22
from scipy import weave
 
23
from time import time
 
24
import trace
 
25
import numpy
 
26
 
 
27
try:
 
28
    import sympy
 
29
    use_sympy = True
 
30
except:
 
31
    warnings.warn('sympy not installed: some features in SpatialNeuron will not be available')
 
32
    use_sympy = False
 
33
 
 
34
__all__ = ['SpatialNeuron', 'CompartmentalNeuron']
 
35
 
 
36
 
 
37
class SpatialNeuron(NeuronGroup):
 
38
        """
 
39
        Compartmental model with morphology.
 
40
        """
 
41
        def __init__(self, morphology=None, model=None, threshold=None, reset=NoReset(),
 
42
                                refractory=0 * ms, level=0,
 
43
                                clock=None, unit_checking=True,
 
44
                                compile=False, freeze=False, implicit=True, Cm=0.9 * uF / cm ** 2, Ri=150 * ohm * cm,
 
45
                                bc_type = 2, diffeq_nonzero=True):
 
46
                clock = guess_clock(clock)
 
47
                N = len(morphology) # number of compartments
 
48
                
 
49
                if isinstance(model, str):
 
50
                        model = Equations(model, level=level + 1)
 
51
                
 
52
                model += Equations('''
 
53
                v:volt # membrane potential
 
54
                ''')
 
55
 
 
56
                # Process model equations (Im) to extract total conductance and the remaining current
 
57
                if use_sympy:
 
58
                        try:
 
59
                                membrane_eq=model._string['Im'] # the membrane equation
 
60
                        except:
 
61
                                raise TypeError,"The transmembrane current Im must be defined"
 
62
                        # Check conditional linearity
 
63
                        ids=get_identifiers(membrane_eq)
 
64
                        _namespace=dict.fromkeys(ids,1.) # there is a possibility of problems here (division by zero)
 
65
                        _namespace['v']=AffineFunction()
 
66
                        eval(membrane_eq,model._namespace['v'],_namespace)
 
67
                        try:
 
68
                                eval(membrane_eq,model._namespace['v'],_namespace)
 
69
                        except: # not linear
 
70
                                raise TypeError,"The membrane current must be linear with respect to v"
 
71
                        # Extracts the total conductance from Im, and the remaining current
 
72
                        z=symbolic_eval(membrane_eq)
 
73
                        symbol_v=sympy.Symbol('v')
 
74
                        b=z.subs(symbol_v,0)
 
75
                        a=-sympy.simplify(z.subs(symbol_v,1)-b)
 
76
                        gtot_str="_gtot="+str(a)+": siemens/cm**2"
 
77
                        I0_str="_I0="+str(b)+": amp/cm**2"
 
78
                        model+=Equations(gtot_str+"\n"+I0_str,level=level+1) # better: explicit insertion with namespace of v
 
79
                else:
 
80
                        raise TypeError,"The Sympy package must be installed for SpatialNeuron"
 
81
                
 
82
                # Equations for morphology (isn't it a duplicate??)
 
83
                eqs_morphology = Equations("""
 
84
                diameter : um
 
85
                length : um
 
86
                x : um
 
87
                y : um
 
88
                z : um
 
89
                area : um**2
 
90
                """)
 
91
                
 
92
                full_model = model + eqs_morphology
 
93
                
 
94
                NeuronGroup.__init__(self, N, model=full_model, threshold=threshold, reset=reset, refractory=refractory,
 
95
                                                        level=level + 1, clock=clock, unit_checking=unit_checking, implicit=implicit)
 
96
                self.model_with_diffeq_nonzero = diffeq_nonzero
 
97
                self._state_updater = SpatialStateUpdater(self, clock)
 
98
                self.Cm = ones(len(self))*Cm
 
99
                self.Ri = Ri
 
100
                self.bc_type = bc_type #default boundary condition on leaves
 
101
                self.bc = ones(len(self)) # boundary conditions on branch points
 
102
                self.changed = True
 
103
                
 
104
                # Insert morphology
 
105
                self.morphology = morphology
 
106
                self.morphology.compress(diameter=self.diameter, length=self.length, x=self.x, y=self.y, z=self.z, area=self.area)
 
107
 
 
108
        def subgroup(self, N): # Subgrouping cannot be done in this way
 
109
                raise NotImplementedError
 
110
 
 
111
        def __getitem__(self, x):
 
112
                '''
 
113
                Subgrouping mechanism.
 
114
                self['axon'] returns the subtree named "axon".
 
115
        
 
116
                TODO:
 
117
                self[:] returns the full branch.
 
118
                '''
 
119
                morpho = self.morphology[x]
 
120
                N = self[morpho._origin:morpho._origin + len(morpho)]
 
121
                N.morphology = morpho
 
122
                return N
 
123
 
 
124
        def __getattr__(self, x):
 
125
                if (x != 'morphology') and ((x in self.morphology._namedkid) or all([c in 'LR123456789' for c in x])): # subtree
 
126
                        return self[x]
 
127
                else:
 
128
                        return NeuronGroup.__getattr__(self, x)
 
129
 
 
130
class SpatialStateUpdater(StateUpdater):
 
131
        """
 
132
        State updater for compartmental models.
 
133
 
 
134
        """
 
135
        def __init__(self, neuron, clock=None):
 
136
                self.eqs = neuron._eqs
 
137
                self.neuron = neuron
 
138
                self._isprepared = False
 
139
                self._state_updater=neuron._state_updater # to update the currents
 
140
                self.first_test_gtot=True
 
141
                self.callcount=0
 
142
                
 
143
 
 
144
        def prepare_branch(self, morphology, mid_diameter,ante=0):
 
145
                '''
 
146
                1) fill neuron.branches and neuron.index with information about the morphology of the neuron
 
147
                2) change some wrong values in Aplus and Aminus. Indeed these were correct only if the neuron is a linear cable.
 
148
                        Knowledge of the morphology gives correct values.
 
149
                3) fill neuron.bc (boundary conditions)
 
150
                '''
 
151
                branch = morphology.branch()
 
152
                i=branch._origin
 
153
                j= i + len(branch) - 2
 
154
                endpoint = j + 1
 
155
                self.neuron.index[i:endpoint+1] = self.neuron.BPcount
 
156
                children_number = 0
 
157
                
 
158
                
 
159
                #connections between branches
 
160
                for x in (morphology.children):#parent of segment n isn't always n-1 at branch points. We need to change Aplus and Aminus
 
161
                        gc = 2 * msiemens/cm**2
 
162
                        startpoint = x._origin
 
163
                        mid_diameter[startpoint] = .5*(self.neuron.diameter[endpoint]+self.neuron.diameter[startpoint])
 
164
                        self.Aminus[startpoint]=mid_diameter[startpoint]**2/(4*self.neuron.diameter[startpoint]*self.neuron.length[startpoint]**2*self.neuron.Ri)
 
165
                        if endpoint>0:
 
166
                                self.Aplus[startpoint]=mid_diameter[startpoint]**2/(4*self.neuron.diameter[endpoint]*self.neuron.length[endpoint]**2*self.neuron.Ri)
 
167
                        else :
 
168
                                self.Aplus[startpoint]=gc
 
169
                                self.Aminus[startpoint]=gc
 
170
                        children_number+=1
 
171
                
 
172
                #boundary conditions
 
173
                pointType = self.neuron.bc[endpoint]
 
174
                hasChild = (children_number>0)
 
175
                if (not hasChild) and (pointType == 1): #if the branch point is a leaf of the tree : apply default boundary condition
 
176
                        self.neuron.bc[endpoint] = self.neuron.bc_type  
 
177
                
 
178
                
 
179
                #extract informations about the branches
 
180
                index_ante = self.neuron.index[ante]
 
181
                bp = endpoint
 
182
                index = self.neuron.BPcount
 
183
                self.i_list.append(i)
 
184
                self.j_list.append(j)
 
185
                self.bp_list.append(bp)
 
186
                self.pointType_list.append(max(1,pointType))
 
187
                self.pointTypeAnte_list.append(max(1,self.neuron.bc[ante]))
 
188
                self.temp[index] = index_ante
 
189
                self.id.append(index)
 
190
                self.test_list.append((j-i+2)>1)
 
191
                for x in xrange(j-i+2):
 
192
                        self.ante_list.append(ante)
 
193
                        self.post_list.append(bp)
 
194
                if index_ante == 0:
 
195
                        self.ind0.append(index)
 
196
                if pointType==0 :
 
197
                        self.ind_bctype_0.append(bp)
 
198
                
 
199
                
 
200
                #initialize the parts of the linear systems that will not change
 
201
                if (j-i+2)>1:   #j-i+2 = len(branch)
 
202
                        #initialize ab
 
203
                        self.ab[0,i:j]= self.Aplus[i:j]
 
204
                        self.ab[2,i:j]= self.Aminus[i+1:j+1]
 
205
                        
 
206
                        #initialize bL
 
207
                        VL0 = 1 * volt
 
208
                        self.bL[i] = (- VL0 * self.Aminus[i])
 
209
                        
 
210
                        #initialize bR
 
211
                        VR0 = 1 * volt
 
212
                        self.bR[j] = (- VR0 * self.Aplus[j+1])
 
213
                
 
214
                self.neuron.BPcount += 1
 
215
                for x in (morphology.children):
 
216
                        self.prepare_branch(x,mid_diameter,endpoint)
 
217
                
 
218
        def prepare(self):
 
219
                '''
 
220
                From Hines 1984 paper, discrete formula is:
 
221
                A_plus*V(i+1)-(A_plus+A_minus)*V(i)+A_minus*V(i-1)=Cm/dt*(V(i,t+dt)-V(i,t))+gtot(i)*V(i)-I0(i)
 
222
       
 
223
                A_plus: i->i+1
 
224
                A_minus: i->i-1
 
225
                
 
226
        This gives the following tridiagonal system:
 
227
        A_plus*V(i+1)-(Cm/dt+gtot(i)+A_plus+A_minus)*V(i)+A_minus*V(i-1)=-Cm/dt*V(i,t)-I0(i)
 
228
        
 
229
        '''
 
230
                mid_diameter = zeros(len(self.neuron)) # mid(i) : (i-1) <-> i
 
231
                mid_diameter[1:] = .5*(self.neuron.diameter[:-1]+self.neuron.diameter[1:])
 
232
                
 
233
                self.Aplus = zeros(len(self.neuron)) # A+ i -> j = Aplus(j)
 
234
                self.Aminus = zeros(len(self.neuron)) # A- i <- j = Aminus(j)
 
235
                self.Aplus[1]= mid_diameter[1]**2/(4*self.neuron.diameter[1]*self.neuron.length[1]**2*self.neuron.Ri)
 
236
                self.Aplus[2:]=mid_diameter[2:]**2/(4*self.neuron.diameter[1:-1]*self.neuron.length[1:-1]**2*self.neuron.Ri)
 
237
                self.Aminus[1:]=mid_diameter[1:]**2/(4*self.neuron.diameter[1:]*self.neuron.length[1:]**2*self.neuron.Ri) 
 
238
                
 
239
                self.neuron.index = zeros(len(self.neuron),int) # gives the index of the branch containing the current compartment
 
240
                
 
241
                self.neuron.BPcount = 0 # number of branch points (or branches). = len(self.neuron.branches)
 
242
                
 
243
                #the three solutions for V on a branch
 
244
                self.vL = zeros((len(self.neuron)),numpy.float64)
 
245
                self.vR = zeros((len(self.neuron)),numpy.float64)
 
246
                self.d = zeros((len(self.neuron)),numpy.float64)
 
247
                
 
248
                #matrix and right hand in the tridiagonal systems that we solve to find vL, vR and d.
 
249
                self.bL = zeros((len(self.neuron)),numpy.float64)
 
250
                self.bR = zeros((len(self.neuron)),numpy.float64)
 
251
                self.bd = zeros((len(self.neuron)),numpy.float64)
 
252
                self.ab = zeros((3,len(self.neuron)))
 
253
                self.ab1_base = zeros(len(self.neuron))
 
254
                
 
255
                
 
256
                self.gtot = zeros(len(self.neuron))
 
257
                self.I0 = zeros(len(self.neuron))
 
258
                
 
259
                self.i_list = [] #the indexes of the first points of the branches in the neuron. len = neuron.BPcount
 
260
                self.j_list = [] #the indexes of the last points of the branches in the neuron. len = neuron.BPcount
 
261
                self.bp_list = [] #the indexes of the branch points in the neuron. len = neuron.BPcount
 
262
                self.pointType_list = [] #boundary condition on bp. len = neuron.BPcount
 
263
                self.pointTypeAnte_list = [] #boundary condition on ante. len = neuron.BPcount
 
264
                self.index_ante_list1 = [] #index of the parent branch of the current branch. index is in [0,neuron.BPcount]
 
265
                self.index_ante_list2 = []
 
266
                self.ante_list = [] #the indexes in the neuron of the branch points connected to i, for every compartment. len = len(self.neuron)
 
267
                self.post_list = [] #for every compartment, contains the index of the branch point. len = len(self.neuron)
 
268
                self.test_list = [] #for each branch : 1 if the branch has more than 3 compartments, else 0
 
269
                
 
270
                self.id = [] #list of every integer in [0,neuron.BPcount]. used in step to change some values in a matrix
 
271
                
 
272
                self.temp = zeros(len(self.neuron)) #used to construct index_ante_list0, 1, 2.
 
273
                self.ind0 = [] #indexes (in [0,neuron.BPcount]) of the branches connected to compartment 0
 
274
                self.ind_bctype_0 = [] #indexes of the branch point with boundary condition 0 (constant V)
 
275
                
 
276
                # prepare_branch : fill the lists, changes Aplus & Aminus
 
277
                self.prepare_branch(self.neuron.morphology, mid_diameter,0)
 
278
                
 
279
                
 
280
                self.index_ante_list1, self.ind1 = numpy.unique(numpy.array(self.temp,int),return_index=True)
 
281
                self.ind1 = numpy.sort(self.ind1)
 
282
                self.index_ante_list1 = self.temp[self.ind1]
 
283
                self.index_ante_list1 = list(self.index_ante_list1)
 
284
                self.ind2 = []
 
285
                for x in xrange(self.neuron.BPcount):
 
286
                        self.ind2.append(x)
 
287
                self.ind2 = numpy.delete(self.ind2,self.ind1,None) 
 
288
                self.ind2 = numpy.setdiff1d(self.ind2, self.ind0, assume_unique=True)
 
289
                self.index_ante_list2 = self.temp[self.ind2]
 
290
                self.index_ante_list2 = list(self.index_ante_list2)
 
291
                
 
292
                self.index_ante_list = []
 
293
                for idx in xrange(self.neuron.BPcount):
 
294
                        self.index_ante_list.append(self.temp[idx])
 
295
                
 
296
                
 
297
                # linear system P V = B used to deal with the voltage at branch points and take boundary conditions into account.
 
298
                self.P = zeros((self.neuron.BPcount,self.neuron.BPcount))
 
299
                self.B = zeros(self.neuron.BPcount)
 
300
                self.solution_bp = zeros(self.neuron.BPcount)
 
301
                
 
302
                #in case of a sealed end, Aminus and Aplus are doubled :
 
303
                self.Aminus_bp = self.Aminus[self.bp_list]
 
304
                self.Aminus_bp [:] *= self.pointType_list[:]
 
305
                self.Aplus_i = self.Aplus[self.i_list]
 
306
                self.Aplus_i[:] *= self.pointTypeAnte_list[:]
 
307
                
 
308
                
 
309
        def step(self, neuron):
 
310
                
 
311
                if self.first_test_gtot and isscalar(neuron._gtot):
 
312
                        self.first_test_gtot=False
 
313
                        #neuron._gtot = ones(len(neuron)) * neuron._gtot
 
314
                        
 
315
                self.gtot[:] = neuron._gtot #this compute the value of neuron._gtot.
 
316
                                                        #if we call neuron._gtot[1] and then neuron._gtot[2] it does 2 computations
 
317
                                                        #here we call it only one time on the whole array. this is much faster
 
318
                self.I0 = neuron._I0
 
319
                
 
320
                #------------------------------------solve tridiagonal systems on the branchs-------------------------
 
321
                #ab is the matrix in the tridiagonal systems describing the branches.
 
322
                #bd is a right hand in one of these tridiagonal systems.
 
323
                if self.neuron.changed : # neuron.changed = True <=> there was a new input somewhere. example : the user does  neuron.I[x] = y
 
324
                        self.update_ab_base() 
 
325
                self.update_ab_gtot()
 
326
                self.update_bd()
 
327
 
 
328
                self.calculate_vd_vL_vR()
 
329
                self.neuron.changed = False
 
330
                
 
331
                #-----------fill P and B, matrix and right hand used to find the voltage at the branch points-----------------
 
332
                
 
333
                self.P[:,:] = 0
 
334
                self.B[:] = 0
 
335
                
 
336
                Cm = neuron.Cm[self.bp_list]
 
337
                dt = neuron.clock.dt
 
338
                gtot = self.gtot[self.bp_list]
 
339
                I0 = self.I0[self.bp_list]
 
340
                v_bp = neuron.v[self.bp_list]
 
341
                vLleft = self.vL[self.i_list]
 
342
                vLright = self.vL[self.j_list]
 
343
                vRleft = self.vR[self.i_list]
 
344
                vRright = self.vR[self.j_list]
 
345
                dleft = self.d[self.i_list]
 
346
                dright = self.d[self.j_list]
 
347
                
 
348
                vLleft[:] *= self.test_list[:] #if a branch has less than 3 compartments, this equals 0.
 
349
                                                                                #thus we can do the same work on every branch point.
 
350
                vLright[:] *= self.test_list[:]
 
351
                vRleft[:] *= self.test_list[:]
 
352
                vRright[:] *= self.test_list[:]
 
353
                dleft[:] *= self.test_list[:]
 
354
                dright[:] *= self.test_list[:]
 
355
                
 
356
                self.B[self.index_ante_list1] += - self.Aplus_i[self.ind1[:]] * dleft[self.ind1[:]]
 
357
                self.B[self.index_ante_list2] += - self.Aplus_i[self.ind2[:]] * dleft[self.ind2[:]]
 
358
                self.B[0] += sum(- self.Aplus_i[self.ind0[:]] * dleft[self.ind0[:]])
 
359
                
 
360
                self.P[(self.index_ante_list1,self.index_ante_list1)] += self.Aplus_i[self.ind1[:]] * (vLleft[self.ind1[:]] - 1)
 
361
                self.P[(self.index_ante_list2,self.index_ante_list2)] += self.Aplus_i[self.ind2[:]] * (vLleft[self.ind2[:]] - 1)
 
362
                self.P[0,0] += sum(self.Aplus_i[self.ind0[:]] * (vLleft[self.ind0[:]] - 1))
 
363
                
 
364
                di = diag_indices(neuron.BPcount)
 
365
                
 
366
                self.B[:] += - Cm[:]/dt * second * v_bp[:] - I0[:] - self.Aminus_bp[:] * dright[:]
 
367
                self.P[di] += - Cm[:]/dt * second - gtot[:] + self.Aminus_bp[:] * (vRright[:] - 1)
 
368
                self.P[(self.id,self.index_ante_list)] += self.Aminus_bp[:] *vLright[:]
 
369
                self.P[(self.index_ante_list,self.id)] += self.Aplus_i[:] *vRleft[:]
 
370
                
 
371
                self.P[self.ind_bctype_0,:] = 0
 
372
                self.P[(self.ind_bctype_0,self.ind_bctype_0)] = 1
 
373
                self.B[self.ind_bctype_0] = neuron.v[self.ind_bctype_0]
 
374
                
 
375
                #------------------------------------------------------solve PV=B-----------------------------------
 
376
                
 
377
                self.solution_bp = solve(self.P,self.B)
 
378
                neuron.v[self.bp_list] = self.solution_bp[:]
 
379
                
 
380
                #-------------------------------------------------------update v-------------------------------------
 
381
                
 
382
                self.finalize_v_global()
 
383
                
 
384
 
 
385
        def update_ab_base(self): #part of ab that doesn't change if there is no prompt from the operator.
 
386
                self.ab1_base[:-1] = (- self.neuron.Cm[:-1] / self.neuron.clock.dt * second - self.Aminus[:-1] - self.Aplus[1:])
 
387
                self.ab1_base[-1] = (- self.neuron.Cm[-1] / self.neuron.clock.dt * second - self.Aminus[-1])
 
388
                
 
389
        def update_ab_gtot(self): #this is called every step. changing part of ab.
 
390
                self.ab[1,:] = self.ab1_base[:] - self.neuron._gtot
 
391
                
 
392
        def update_bd(self): #bd is a right hand side in a tridiagonal system
 
393
                self.bd[:] = -self.neuron.Cm[:] / self.neuron.clock.dt * self.neuron.v[:] - self.neuron._I0[:]
 
394
        
 
395
        def calculate_vd_vL_vR(self):
 
396
                for index in xrange(self.neuron.BPcount) :
 
397
                        if self.test_list[index] :
 
398
                                i = self.i_list[index]
 
399
                                j = self.j_list[index]
 
400
                                self.vL[i:j+1] = solve_banded((1,1),self.ab[:,i:j+1],self.bL[i:j+1],overwrite_ab=False,overwrite_b=False)
 
401
                                self.vR[i:j+1] = solve_banded((1,1),self.ab[:,i:j+1],self.bR[i:j+1],overwrite_ab=False,overwrite_b=False)
 
402
                                self.d[i:j+1] = solve_banded((1,1),self.ab[:,i:j+1],self.bd[i:j+1],overwrite_ab=False,overwrite_b=False)
 
403
        
 
404
        def finalize_v_global(self): #V(x) = V(left) * vL(x) + V(right) * vR(x) + d(x)
 
405
                self.neuron.v[:] = self.vL[:] * self.neuron.v[self.ante_list[:]] + self.vR[:] * self.neuron.v[self.post_list[:]] + self.d[:]
 
406
                self.neuron.v[self.bp_list] = self.solution_bp[:]
 
407
        
 
408
        def __call__(self, neuron):
 
409
                '''
 
410
                Updates the state variables.
 
411
                '''
 
412
                if not self._isprepared:
 
413
                        self.prepare()
 
414
                        self._isprepared=True
 
415
                        print "state updater prepared"
 
416
                self.callcount+=1
 
417
                print self.callcount
 
418
                #Update I,V
 
419
                if neuron.changed :
 
420
                        self._state_updater.changed = True
 
421
                self._state_updater(neuron) #update the currents
 
422
                self.step(neuron) #update V
 
423
                
 
424
        def __len__(self):
 
425
                '''
 
426
                Number of state variables
 
427
                '''
 
428
                return len(self.eqs)
 
429
 
 
430
CompartmentalNeuron = SpatialNeuron