5
from brian.experimental.morphology import Morphology
6
from brian.stdunits import *
7
from brian.units import *
8
from brian.reset import NoReset
9
from brian.stateupdater import StateUpdater
10
from brian.inspection import *
11
from brian.optimiser import *
12
from itertools import count
13
from brian.neurongroup import NeuronGroup
14
from scipy.linalg import solve_banded
15
from numpy import zeros, ones, isscalar, diag_indices
16
from numpy.linalg import solve
17
from brian.clock import guess_clock
18
from brian.equations import Equations
21
from math import ceil, log
22
from scipy import weave
31
warnings.warn('sympy not installed: some features in SpatialNeuron will not be available')
34
__all__ = ['SpatialNeuron', 'CompartmentalNeuron']
37
class SpatialNeuron(NeuronGroup):
39
Compartmental model with morphology.
41
def __init__(self, morphology=None, model=None, threshold=None, reset=NoReset(),
42
refractory=0 * ms, level=0,
43
clock=None, unit_checking=True,
44
compile=False, freeze=False, implicit=True, Cm=0.9 * uF / cm ** 2, Ri=150 * ohm * cm,
45
bc_type = 2, diffeq_nonzero=True):
46
clock = guess_clock(clock)
47
N = len(morphology) # number of compartments
49
if isinstance(model, str):
50
model = Equations(model, level=level + 1)
52
model += Equations('''
53
v:volt # membrane potential
56
# Process model equations (Im) to extract total conductance and the remaining current
59
membrane_eq=model._string['Im'] # the membrane equation
61
raise TypeError,"The transmembrane current Im must be defined"
62
# Check conditional linearity
63
ids=get_identifiers(membrane_eq)
64
_namespace=dict.fromkeys(ids,1.) # there is a possibility of problems here (division by zero)
65
_namespace['v']=AffineFunction()
66
eval(membrane_eq,model._namespace['v'],_namespace)
68
eval(membrane_eq,model._namespace['v'],_namespace)
70
raise TypeError,"The membrane current must be linear with respect to v"
71
# Extracts the total conductance from Im, and the remaining current
72
z=symbolic_eval(membrane_eq)
73
symbol_v=sympy.Symbol('v')
75
a=-sympy.simplify(z.subs(symbol_v,1)-b)
76
gtot_str="_gtot="+str(a)+": siemens/cm**2"
77
I0_str="_I0="+str(b)+": amp/cm**2"
78
model+=Equations(gtot_str+"\n"+I0_str,level=level+1) # better: explicit insertion with namespace of v
80
raise TypeError,"The Sympy package must be installed for SpatialNeuron"
82
# Equations for morphology (isn't it a duplicate??)
83
eqs_morphology = Equations("""
92
full_model = model + eqs_morphology
94
NeuronGroup.__init__(self, N, model=full_model, threshold=threshold, reset=reset, refractory=refractory,
95
level=level + 1, clock=clock, unit_checking=unit_checking, implicit=implicit)
96
self.model_with_diffeq_nonzero = diffeq_nonzero
97
self._state_updater = SpatialStateUpdater(self, clock)
98
self.Cm = ones(len(self))*Cm
100
self.bc_type = bc_type #default boundary condition on leaves
101
self.bc = ones(len(self)) # boundary conditions on branch points
105
self.morphology = morphology
106
self.morphology.compress(diameter=self.diameter, length=self.length, x=self.x, y=self.y, z=self.z, area=self.area)
108
def subgroup(self, N):
    """Refuse to carve out a subgroup of ``N`` compartments.

    Subgrouping cannot be done in this way for this class, so the call
    always fails. Indexing the neuron is the subgrouping mechanism
    (``neuron['axon']`` for a named subtree, ``neuron[:]`` for the full
    branch).

    Raises
    ------
    NotImplementedError
        Always.
    """
    # Same exception type as before; the message tells the caller what to use.
    raise NotImplementedError(
        "SpatialNeuron does not support subgroup(N); "
        "index the neuron instead (e.g. neuron['axon'] or neuron[:])")
111
def __getitem__(self, x):
113
Subgrouping mechanism.
114
self['axon'] returns the subtree named "axon".
117
self[:] returns the full branch.
119
morpho = self.morphology[x]
120
N = self[morpho._origin:morpho._origin + len(morpho)]
121
N.morphology = morpho
124
def __getattr__(self, x):
125
if (x != 'morphology') and ((x in self.morphology._namedkid) or all([c in 'LR123456789' for c in x])): # subtree
128
return NeuronGroup.__getattr__(self, x)
130
class SpatialStateUpdater(StateUpdater):
132
State updater for compartmental models.
135
def __init__(self, neuron, clock=None):
136
self.eqs = neuron._eqs
138
self._isprepared = False
139
self._state_updater=neuron._state_updater # to update the currents
140
self.first_test_gtot=True
144
def prepare_branch(self, morphology, mid_diameter,ante=0):
146
1) fill neuron.branches and neuron.index with information about the morphology of the neuron
147
2) change some wrong values in Aplus and Aminus. Indeed these were correct only if the neuron is a linear cable.
148
Knowledge of the morphology gives correct values.
149
3) fill neuron.bc (boundary conditions)
151
branch = morphology.branch()
153
j= i + len(branch) - 2
155
self.neuron.index[i:endpoint+1] = self.neuron.BPcount
159
#connections between branches
160
for x in (morphology.children):#parent of segment n isn't always n-1 at branch points. We need to change Aplus and Aminus
161
gc = 2 * msiemens/cm**2
162
startpoint = x._origin
163
mid_diameter[startpoint] = .5*(self.neuron.diameter[endpoint]+self.neuron.diameter[startpoint])
164
self.Aminus[startpoint]=mid_diameter[startpoint]**2/(4*self.neuron.diameter[startpoint]*self.neuron.length[startpoint]**2*self.neuron.Ri)
166
self.Aplus[startpoint]=mid_diameter[startpoint]**2/(4*self.neuron.diameter[endpoint]*self.neuron.length[endpoint]**2*self.neuron.Ri)
168
self.Aplus[startpoint]=gc
169
self.Aminus[startpoint]=gc
173
pointType = self.neuron.bc[endpoint]
174
hasChild = (children_number>0)
175
if (not hasChild) and (pointType == 1): #if the branch point is a leaf of the tree : apply default boundary condition
176
self.neuron.bc[endpoint] = self.neuron.bc_type
179
#extract information about the branches
180
index_ante = self.neuron.index[ante]
182
index = self.neuron.BPcount
183
self.i_list.append(i)
184
self.j_list.append(j)
185
self.bp_list.append(bp)
186
self.pointType_list.append(max(1,pointType))
187
self.pointTypeAnte_list.append(max(1,self.neuron.bc[ante]))
188
self.temp[index] = index_ante
189
self.id.append(index)
190
self.test_list.append((j-i+2)>1)
191
for x in xrange(j-i+2):
192
self.ante_list.append(ante)
193
self.post_list.append(bp)
195
self.ind0.append(index)
197
self.ind_bctype_0.append(bp)
200
#initialize the parts of the linear systems that will not change
201
if (j-i+2)>1: #j-i+2 = len(branch)
203
self.ab[0,i:j]= self.Aplus[i:j]
204
self.ab[2,i:j]= self.Aminus[i+1:j+1]
208
self.bL[i] = (- VL0 * self.Aminus[i])
212
self.bR[j] = (- VR0 * self.Aplus[j+1])
214
self.neuron.BPcount += 1
215
for x in (morphology.children):
216
self.prepare_branch(x,mid_diameter,endpoint)
220
From Hines 1984 paper, discrete formula is:
221
A_plus*V(i+1)-(A_plus+A_minus)*V(i)+A_minus*V(i-1)=Cm/dt*(V(i,t+dt)-V(i,t))+gtot(i)*V(i)-I0(i)
226
This gives the following tridiagonal system:
227
A_plus*V(i+1)-(Cm/dt+gtot(i)+A_plus+A_minus)*V(i)+A_minus*V(i-1)=-Cm/dt*V(i,t)-I0(i)
230
mid_diameter = zeros(len(self.neuron)) # mid(i) : (i-1) <-> i
231
mid_diameter[1:] = .5*(self.neuron.diameter[:-1]+self.neuron.diameter[1:])
233
self.Aplus = zeros(len(self.neuron)) # A+ i -> j = Aplus(j)
234
self.Aminus = zeros(len(self.neuron)) # A- i <- j = Aminus(j)
235
self.Aplus[1]= mid_diameter[1]**2/(4*self.neuron.diameter[1]*self.neuron.length[1]**2*self.neuron.Ri)
236
self.Aplus[2:]=mid_diameter[2:]**2/(4*self.neuron.diameter[1:-1]*self.neuron.length[1:-1]**2*self.neuron.Ri)
237
self.Aminus[1:]=mid_diameter[1:]**2/(4*self.neuron.diameter[1:]*self.neuron.length[1:]**2*self.neuron.Ri)
239
self.neuron.index = zeros(len(self.neuron),int) # gives the index of the branch containing the current compartment
241
self.neuron.BPcount = 0 # number of branch points (or branches). = len(self.neuron.branches)
243
#the three solutions for V on a branch
244
self.vL = zeros((len(self.neuron)),numpy.float64)
245
self.vR = zeros((len(self.neuron)),numpy.float64)
246
self.d = zeros((len(self.neuron)),numpy.float64)
248
#matrix and right hand in the tridiagonal systems that we solve to find vL, vR and d.
249
self.bL = zeros((len(self.neuron)),numpy.float64)
250
self.bR = zeros((len(self.neuron)),numpy.float64)
251
self.bd = zeros((len(self.neuron)),numpy.float64)
252
self.ab = zeros((3,len(self.neuron)))
253
self.ab1_base = zeros(len(self.neuron))
256
self.gtot = zeros(len(self.neuron))
257
self.I0 = zeros(len(self.neuron))
259
self.i_list = [] #the indexes of the first points of the branches in the neuron. len = neuron.BPcount
260
self.j_list = [] #the indexes of the last points of the branches in the neuron. len = neuron.BPcount
261
self.bp_list = [] #the indexes of the branch points in the neuron. len = neuron.BPcount
262
self.pointType_list = [] #boundary condition on bp. len = neuron.BPcount
263
self.pointTypeAnte_list = [] #boundary condition on ante. len = neuron.BPcount
264
self.index_ante_list1 = [] #index of the parent branch of the current branch. index is in [0,neuron.BPcount]
265
self.index_ante_list2 = []
266
self.ante_list = [] #the indexes in the neuron of the branch points connected to i, for every compartment. len = len(self.neuron)
267
self.post_list = [] #for every compartment, contains the index of the branch point. len = len(self.neuron)
268
self.test_list = [] #for each branch : 1 if the branch has more than 3 compartments, else 0
270
self.id = [] #list of every integer in [0,neuron.BPcount]. used in step to change some values in a matrix
272
self.temp = zeros(len(self.neuron)) #used to construct index_ante_list0, 1, 2.
273
self.ind0 = [] #indexes (in [0,neuron.BPcount]) of the branches connected to compartment 0
274
self.ind_bctype_0 = [] #indexes of the branch point with boundary condition 0 (constant V)
276
# prepare_branch : fill the lists, changes Aplus & Aminus
277
self.prepare_branch(self.neuron.morphology, mid_diameter,0)
280
self.index_ante_list1, self.ind1 = numpy.unique(numpy.array(self.temp,int),return_index=True)
281
self.ind1 = numpy.sort(self.ind1)
282
self.index_ante_list1 = self.temp[self.ind1]
283
self.index_ante_list1 = list(self.index_ante_list1)
285
for x in xrange(self.neuron.BPcount):
287
self.ind2 = numpy.delete(self.ind2,self.ind1,None)
288
self.ind2 = numpy.setdiff1d(self.ind2, self.ind0, assume_unique=True)
289
self.index_ante_list2 = self.temp[self.ind2]
290
self.index_ante_list2 = list(self.index_ante_list2)
292
self.index_ante_list = []
293
for idx in xrange(self.neuron.BPcount):
294
self.index_ante_list.append(self.temp[idx])
297
# linear system P V = B used to deal with the voltage at branch points and take boundary conditions into account.
298
self.P = zeros((self.neuron.BPcount,self.neuron.BPcount))
299
self.B = zeros(self.neuron.BPcount)
300
self.solution_bp = zeros(self.neuron.BPcount)
302
#in case of a sealed end, Aminus and Aplus are doubled :
303
self.Aminus_bp = self.Aminus[self.bp_list]
304
self.Aminus_bp [:] *= self.pointType_list[:]
305
self.Aplus_i = self.Aplus[self.i_list]
306
self.Aplus_i[:] *= self.pointTypeAnte_list[:]
309
def step(self, neuron):
311
if self.first_test_gtot and isscalar(neuron._gtot):
312
self.first_test_gtot=False
313
#neuron._gtot = ones(len(neuron)) * neuron._gtot
315
self.gtot[:] = neuron._gtot #this compute the value of neuron._gtot.
316
#if we call neuron._gtot[1] and then neuron._gtot[2] it does 2 computations
317
#here we call it only one time on the whole array. this is much faster
320
#------------------------------------solve tridiagonal systems on the branches-------------------------
321
#ab is the matrix in the tridiagonal systems describing the branches.
322
#bd is a right hand in one of these tridiagonal systems.
323
if self.neuron.changed : # neuron.changed = True <=> there was a new input somewhere. example : the user does neuron.I[x] = y
324
self.update_ab_base()
325
self.update_ab_gtot()
328
self.calculate_vd_vL_vR()
329
self.neuron.changed = False
331
#-----------fill P and B, matrix and right hand used to find the voltage at the branch points-----------------
336
Cm = neuron.Cm[self.bp_list]
338
gtot = self.gtot[self.bp_list]
339
I0 = self.I0[self.bp_list]
340
v_bp = neuron.v[self.bp_list]
341
vLleft = self.vL[self.i_list]
342
vLright = self.vL[self.j_list]
343
vRleft = self.vR[self.i_list]
344
vRright = self.vR[self.j_list]
345
dleft = self.d[self.i_list]
346
dright = self.d[self.j_list]
348
vLleft[:] *= self.test_list[:] #if a branch has less than 3 compartments, this equals 0.
349
#thus we can do the same work on every branch point.
350
vLright[:] *= self.test_list[:]
351
vRleft[:] *= self.test_list[:]
352
vRright[:] *= self.test_list[:]
353
dleft[:] *= self.test_list[:]
354
dright[:] *= self.test_list[:]
356
self.B[self.index_ante_list1] += - self.Aplus_i[self.ind1[:]] * dleft[self.ind1[:]]
357
self.B[self.index_ante_list2] += - self.Aplus_i[self.ind2[:]] * dleft[self.ind2[:]]
358
self.B[0] += sum(- self.Aplus_i[self.ind0[:]] * dleft[self.ind0[:]])
360
self.P[(self.index_ante_list1,self.index_ante_list1)] += self.Aplus_i[self.ind1[:]] * (vLleft[self.ind1[:]] - 1)
361
self.P[(self.index_ante_list2,self.index_ante_list2)] += self.Aplus_i[self.ind2[:]] * (vLleft[self.ind2[:]] - 1)
362
self.P[0,0] += sum(self.Aplus_i[self.ind0[:]] * (vLleft[self.ind0[:]] - 1))
364
di = diag_indices(neuron.BPcount)
366
self.B[:] += - Cm[:]/dt * second * v_bp[:] - I0[:] - self.Aminus_bp[:] * dright[:]
367
self.P[di] += - Cm[:]/dt * second - gtot[:] + self.Aminus_bp[:] * (vRright[:] - 1)
368
self.P[(self.id,self.index_ante_list)] += self.Aminus_bp[:] *vLright[:]
369
self.P[(self.index_ante_list,self.id)] += self.Aplus_i[:] *vRleft[:]
371
self.P[self.ind_bctype_0,:] = 0
372
self.P[(self.ind_bctype_0,self.ind_bctype_0)] = 1
373
self.B[self.ind_bctype_0] = neuron.v[self.ind_bctype_0]
375
#------------------------------------------------------solve PV=B-----------------------------------
377
self.solution_bp = solve(self.P,self.B)
378
neuron.v[self.bp_list] = self.solution_bp[:]
380
#-------------------------------------------------------update v-------------------------------------
382
self.finalize_v_global()
385
def update_ab_base(self):
    """Rebuild ``ab1_base``, the part of ab that doesn't change if there
    is no prompt from the operator.

    Every entry is ``-Cm/dt*second - Aminus[k]``; all entries except the
    last additionally subtract ``Aplus[k+1]``.
    """
    cell = self.neuron
    # Capacitive contribution for every compartment, computed once.
    neg_cm_over_dt = -cell.Cm / cell.clock.dt * second
    # All compartments but the last also subtract the Aplus of the next one.
    self.ab1_base[:-1] = neg_cm_over_dt[:-1] - self.Aminus[:-1] - self.Aplus[1:]
    # The last compartment has no successor, hence no Aplus term.
    self.ab1_base[-1] = neg_cm_over_dt[-1] - self.Aminus[-1]
389
def update_ab_gtot(self):
    """Refresh the changing part of ``ab``; this is called every step.

    Row 1 of ``ab`` is set to ``ab1_base`` minus the neuron's current
    ``_gtot``; rows 0 and 2 are left untouched.
    """
    static_part = self.ab1_base
    self.ab[1] = static_part - self.neuron._gtot
392
def update_bd(self):
    """Fill ``bd``, a right hand side in a tridiagonal system.

    Each entry becomes ``-Cm/dt * v - _I0`` using the neuron's current
    ``Cm``, ``v`` and ``_I0`` and its clock's ``dt``.
    """
    cell = self.neuron
    storage_term = -cell.Cm[:] / cell.clock.dt * cell.v[:]
    self.bd[:] = storage_term - cell._I0[:]
395
def calculate_vd_vL_vR(self):
    """Solve the tridiagonal system of every branch for ``vL``, ``vR`` and ``d``.

    For each branch index below ``neuron.BPcount`` whose ``test_list``
    entry is true, the banded matrix ``ab[:, i:j+1]`` (with ``i``/``j``
    taken from ``i_list``/``j_list``) is solved against the right hand
    sides ``bL``, ``bR`` and ``bd``; the results are written to the same
    slice of ``vL``, ``vR`` and ``d``. Branches whose ``test_list`` entry
    is false are left untouched.

    The three right hand sides share one matrix, so they are stacked as
    columns and handed to a single ``solve_banded`` call instead of
    eliminating the same matrix three times.
    """
    for index in range(self.neuron.BPcount):
        if not self.test_list[index]:
            continue
        first = self.i_list[index]
        last = self.j_list[index]
        span = slice(first, last + 1)
        # One column per right hand side: bL, bR, bd.
        rhs = zeros((last + 1 - first, 3))
        rhs[:, 0] = self.bL[span]
        rhs[:, 1] = self.bR[span]
        rhs[:, 2] = self.bd[span]
        # rhs is a scratch copy, so ab, bL, bR and bd are never overwritten.
        solution = solve_banded((1, 1), self.ab[:, span], rhs)
        self.vL[span] = solution[:, 0]
        self.vR[span] = solution[:, 1]
        self.d[span] = solution[:, 2]
404
def finalize_v_global(self):
    """Write the new membrane potential of every compartment.

    V(x) = V(left) * vL(x) + V(right) * vR(x) + d(x), where V(left) and
    V(right) are read from ``neuron.v`` at the positions given by
    ``ante_list`` and ``post_list``. The entries listed in ``bp_list``
    are then overwritten with ``solution_bp``.
    """
    v = self.neuron.v
    # List indexing copies, so both lookups see the values from before the update.
    v_left = v[self.ante_list]
    v_right = v[self.post_list]
    v[:] = self.vL * v_left + self.vR * v_right + self.d
    v[self.bp_list] = self.solution_bp
408
def __call__(self, neuron):
410
Updates the state variables.
412
if not self._isprepared:
414
self._isprepared=True
415
print "state updater prepared"
420
self._state_updater.changed = True
421
self._state_updater(neuron) #update the currents
422
self.step(neuron) #update V
426
Number of state variables
430
# Alias: CompartmentalNeuron is bound to the same class object as SpatialNeuron;
# both names are listed in __all__.
CompartmentalNeuron = SpatialNeuron