dendrotweaks 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dendrotweaks/__init__.py +10 -0
- dendrotweaks/analysis/__init__.py +11 -0
- dendrotweaks/analysis/ephys_analysis.py +482 -0
- dendrotweaks/analysis/morphometric_analysis.py +106 -0
- dendrotweaks/membrane/__init__.py +6 -0
- dendrotweaks/membrane/default_mod/AMPA.mod +65 -0
- dendrotweaks/membrane/default_mod/AMPA_NMDA.mod +100 -0
- dendrotweaks/membrane/default_mod/CaDyn.mod +54 -0
- dendrotweaks/membrane/default_mod/GABAa.mod +65 -0
- dendrotweaks/membrane/default_mod/Leak.mod +27 -0
- dendrotweaks/membrane/default_mod/NMDA.mod +72 -0
- dendrotweaks/membrane/default_mod/vecstim.mod +76 -0
- dendrotweaks/membrane/default_templates/NEURON_template.py +354 -0
- dendrotweaks/membrane/default_templates/default.py +73 -0
- dendrotweaks/membrane/default_templates/standard_channel.mod +87 -0
- dendrotweaks/membrane/default_templates/template_jaxley.py +108 -0
- dendrotweaks/membrane/default_templates/template_jaxley_new.py +108 -0
- dendrotweaks/membrane/distributions.py +324 -0
- dendrotweaks/membrane/groups.py +103 -0
- dendrotweaks/membrane/io/__init__.py +11 -0
- dendrotweaks/membrane/io/ast.py +201 -0
- dendrotweaks/membrane/io/code_generators.py +312 -0
- dendrotweaks/membrane/io/converter.py +108 -0
- dendrotweaks/membrane/io/factories.py +144 -0
- dendrotweaks/membrane/io/grammar.py +417 -0
- dendrotweaks/membrane/io/loader.py +90 -0
- dendrotweaks/membrane/io/parser.py +499 -0
- dendrotweaks/membrane/io/reader.py +212 -0
- dendrotweaks/membrane/mechanisms.py +574 -0
- dendrotweaks/model.py +1916 -0
- dendrotweaks/model_io.py +75 -0
- dendrotweaks/morphology/__init__.py +5 -0
- dendrotweaks/morphology/domains.py +100 -0
- dendrotweaks/morphology/io/__init__.py +5 -0
- dendrotweaks/morphology/io/factories.py +212 -0
- dendrotweaks/morphology/io/reader.py +66 -0
- dendrotweaks/morphology/io/validation.py +212 -0
- dendrotweaks/morphology/point_trees.py +681 -0
- dendrotweaks/morphology/reduce/__init__.py +16 -0
- dendrotweaks/morphology/reduce/reduce.py +155 -0
- dendrotweaks/morphology/reduce/reduced_cylinder.py +129 -0
- dendrotweaks/morphology/sec_trees.py +1112 -0
- dendrotweaks/morphology/seg_trees.py +157 -0
- dendrotweaks/morphology/trees.py +567 -0
- dendrotweaks/path_manager.py +261 -0
- dendrotweaks/simulators.py +235 -0
- dendrotweaks/stimuli/__init__.py +3 -0
- dendrotweaks/stimuli/iclamps.py +73 -0
- dendrotweaks/stimuli/populations.py +265 -0
- dendrotweaks/stimuli/synapses.py +203 -0
- dendrotweaks/utils.py +239 -0
- dendrotweaks-0.3.1.dist-info/METADATA +70 -0
- dendrotweaks-0.3.1.dist-info/RECORD +56 -0
- dendrotweaks-0.3.1.dist-info/WHEEL +5 -0
- dendrotweaks-0.3.1.dist-info/licenses/LICENSE +674 -0
- dendrotweaks-0.3.1.dist-info/top_level.txt +1 -0
dendrotweaks/model.py
ADDED
@@ -0,0 +1,1916 @@
|
|
1
|
+
from typing import List, Union, Callable
|
2
|
+
import os
|
3
|
+
import json
|
4
|
+
import matplotlib.pyplot as plt
|
5
|
+
import numpy as np
|
6
|
+
import quantities as pq
|
7
|
+
|
8
|
+
from dendrotweaks.morphology.point_trees import PointTree
|
9
|
+
from dendrotweaks.morphology.sec_trees import Section, SectionTree, Domain
|
10
|
+
from dendrotweaks.morphology.seg_trees import Segment, SegmentTree
|
11
|
+
from dendrotweaks.simulators import NEURONSimulator
|
12
|
+
from dendrotweaks.membrane.groups import SegmentGroup
|
13
|
+
from dendrotweaks.membrane.mechanisms import Mechanism, LeakChannel, CaDynamics
|
14
|
+
from dendrotweaks.membrane.io import create_channel, standardize_channel, create_standard_channel
|
15
|
+
from dendrotweaks.membrane.io import MODFileLoader
|
16
|
+
from dendrotweaks.morphology.io import create_point_tree, create_section_tree, create_segment_tree
|
17
|
+
from dendrotweaks.stimuli.iclamps import IClamp
|
18
|
+
from dendrotweaks.membrane.distributions import Distribution
|
19
|
+
from dendrotweaks.stimuli.populations import Population
|
20
|
+
from dendrotweaks.utils import calculate_lambda_f, dynamic_import
|
21
|
+
from dendrotweaks.utils import get_domain_color, timeit
|
22
|
+
|
23
|
+
from collections import OrderedDict, defaultdict
|
24
|
+
from numpy import nan
|
25
|
+
# from .logger import logger
|
26
|
+
|
27
|
+
from dendrotweaks.path_manager import PathManager
|
28
|
+
import dendrotweaks.morphology.reduce as rdc
|
29
|
+
|
30
|
+
import pandas as pd
|
31
|
+
|
32
|
+
import warnings
|
33
|
+
|
34
|
+
# Supported synapse types, each mapped to an initially empty registry of
# populations.
POPULATIONS = {
    syn_type: {}
    for syn_type in ('AMPA', 'NMDA', 'AMPA_NMDA', 'GABAa')
}
|
35
|
+
|
36
|
+
def custom_warning_formatter(message, category, filename, lineno, line=None, file=None):
    """
    Format a warning as ``WARNING: <message>`` followed by the base name of
    the originating file and the line number.

    The positional parameters mirror :func:`warnings.formatwarning`
    (``message, category, filename, lineno, line=None``). The warnings
    machinery passes the source line as the fifth positional argument, so
    that slot is named ``line``. ``file`` is still accepted for backward
    compatibility. ``category``, ``line`` and ``file`` do not affect the
    output.
    """
    return f"WARNING: {message}\n({os.path.basename(filename)}, line {lineno})\n"


# Install the formatter process-wide: replacing the module attribute changes
# how every subsequently formatted warning is rendered.
warnings.formatwarning = custom_warning_formatter
|
40
|
+
|
41
|
+
# Default values of parameters that are not tied to a specific mechanism;
# units are noted per entry.
INDEPENDENT_PARAMS = dict(
    cm=1,     # uF/cm2
    Ra=100,   # Ohm cm
    ena=50,   # mV
    ek=-77,   # mV
    eca=140,  # mV
)
|
48
|
+
|
49
|
+
# Default segment-group name for each standard domain name. Lookups in this
# module use `.get(domain_name, domain_name)`, so unlisted domains name their
# group after themselves.
DOMAIN_TO_GROUP = dict(
    soma='somatic',
    axon='axonal',
    dend='dendritic',
    apic='apical',
)
|
55
|
+
|
56
|
+
|
57
|
+
class Model():
|
58
|
+
"""
|
59
|
+
A model object that represents a neuron model.
|
60
|
+
|
61
|
+
Parameters
|
62
|
+
----------
|
63
|
+
path_to_model : str
The path to the model directory. It is passed to the PathManager, and its base name is used as the model name.
simulator_name : str
The name of the simulator to use (either 'NEURON' or 'Jaxley').
|
69
|
+
|
70
|
+
Attributes
|
71
|
+
----------
|
72
|
+
path_to_model : str
|
73
|
+
The path to the model directory.
|
74
|
+
path_manager : PathManager
|
75
|
+
The path manager for the model.
|
76
|
+
mod_loader : MODFileLoader
|
77
|
+
The MOD file loader.
|
78
|
+
simulator_name : str
|
79
|
+
The name of the simulator to use. Default is 'NEURON'.
|
80
|
+
point_tree : PointTree
|
81
|
+
The point tree representing the morphological reconstruction.
|
82
|
+
sec_tree : SectionTree
|
83
|
+
The section tree representing the morphology on the section level.
|
84
|
+
mechanisms : dict
|
85
|
+
A dictionary of mechanisms available for the model.
|
86
|
+
domains_to_mechs : dict
|
87
|
+
A dictionary mapping domains to mechanisms inserted in them.
|
88
|
+
params : dict
|
89
|
+
A dictionary mapping parameters to their distributions.
|
90
|
+
d_lambda : float
|
91
|
+
The spatial discretization parameter.
|
92
|
+
seg_tree : SegmentTree
|
93
|
+
The segment tree representing the morphology on the segment level.
|
94
|
+
iclamps : dict
|
95
|
+
A dictionary of current clamps in the model.
|
96
|
+
populations : dict
|
97
|
+
A dictionary of "virtual" populations forming synapses on the model.
|
98
|
+
simulator : Simulator
|
99
|
+
The simulator object to use.
|
100
|
+
"""
|
101
|
+
|
102
|
+
def __init__(self, path_to_model,
             simulator_name='NEURON',) -> None:
    """
    Initialize an empty model rooted at ``path_to_model``.

    Parameters
    ----------
    path_to_model : str
        Path to the model directory. It is handed to the PathManager and
        its base name becomes the model name.
    simulator_name : str, optional
        Either 'NEURON' (default) or 'Jaxley'.

    Raises
    ------
    ValueError
        If ``simulator_name`` is neither 'NEURON' nor 'Jaxley'.
    """
    # Metadata
    self.path_to_model = path_to_model
    # normpath drops a trailing separator so that basename is not empty.
    self._name = os.path.basename(os.path.normpath(path_to_model))
    self.morphology_name = ''
    self.version = ''
    self.path_manager = PathManager(path_to_model)
    self.simulator_name = simulator_name
    self._verbose = False

    # File managers
    self.mod_loader = MODFileLoader()

    # Morphology
    self.point_tree = None
    self.sec_tree = None

    # Mechanisms
    self.mechanisms = {}
    self.domains_to_mechs = {}

    # Parameters: each maps group name -> Distribution
    self.params = {
        'cm': {'all': Distribution('constant', value=1)},  # uF/cm2
        'Ra': {'all': Distribution('constant', value=35.4)},  # Ohm cm
    }

    self.params_to_units = {
        'cm': pq.uF/pq.cm**2,
        'Ra': pq.ohm*pq.cm,
    }

    # Groups
    self._groups = []

    # Segmentation
    self.d_lambda = 0.1
    self.seg_tree = None

    # Stimuli
    self.iclamps = {}
    # Each model gets its own population registry. Assigning the
    # module-level POPULATIONS dict directly would make every Model
    # instance share (and mutate) the same nested dicts.
    self.populations = {syn_type: {} for syn_type in POPULATIONS}

    # Simulator
    if simulator_name == 'NEURON':
        self.simulator = NEURONSimulator()
    elif simulator_name == 'Jaxley':
        # NOTE(review): JaxleySimulator does not appear among this module's
        # imports, so this branch presumably raises NameError -- confirm
        # and import it from its defining module.
        self.simulator = JaxleySimulator()
    else:
        raise ValueError(
            'Simulator name not recognized. Use NEURON or Jaxley.')
|
158
|
+
|
159
|
+
|
160
|
+
# -----------------------------------------------------------------------
|
161
|
+
# PROPERTIES
|
162
|
+
# -----------------------------------------------------------------------
|
163
|
+
|
164
|
+
@property
def name(self):
    """
    Name of the model: the base name of the model directory, fixed when
    the model is constructed.
    """
    model_name = self._name
    return model_name
|
170
|
+
|
171
|
+
@property
def verbose(self):
    """
    Flag controlling whether the model prints progress messages.
    """
    flag = self._verbose
    return flag

@verbose.setter
def verbose(self, value):
    # Store the flag on the model first, then mirror it onto the MOD file
    # loader so both report consistently.
    self._verbose = value
    loader = self.mod_loader
    loader.verbose = value
|
182
|
+
|
183
|
+
|
184
|
+
@property
def domains(self):
    """
    The morphological or functional domains of the model.

    This returns the section tree's own ``domains`` object (not a copy),
    so changes made through it are seen by the section tree as well.
    """
    tree = self.sec_tree
    return tree.domains
|
191
|
+
|
192
|
+
|
193
|
+
@property
def recordings(self):
    """
    The recordings of the model, i.e. the simulator's own ``recordings``
    object (not a copy).
    """
    sim = self.simulator
    return sim.recordings

@recordings.setter
def recordings(self, recordings):
    # Storage lives on the simulator; the model only forwards to it.
    sim = self.simulator
    sim.recordings = recordings
|
204
|
+
|
205
|
+
|
206
|
+
@property
def groups(self):
    """
    The segment groups of the model keyed by group name.

    A new dict is built from ``self._groups`` on every access; if two
    groups share a name, the later one wins.
    """
    by_name = {}
    for segment_group in self._groups:
        by_name[segment_group.name] = segment_group
    return by_name
|
212
|
+
|
213
|
+
|
214
|
+
@property
def groups_to_parameters(self):
    """
    The dictionary mapping each segment group name to the parameters of
    all mechanisms contained in that group.

    Every group in ``self._groups`` appears as a key. The value is a new
    list holding the parameters of every matching mechanism, in
    ``mechs_to_params`` order; it is empty if no mechanism matches.
    """
    # mechs_to_params is derived on every access; compute it once.
    mechs_to_params = self.mechs_to_params
    groups_to_parameters = {}
    for group in self._groups:
        # Accumulate over *all* matching mechanisms (mirroring
        # parameters_to_groups) instead of keeping only the last match.
        # Building a fresh list also avoids aliasing mechs_to_params' lists.
        params = []
        for mech_name, mech_params in mechs_to_params.items():
            if mech_name in group.mechanisms:
                params.extend(mech_params)
        groups_to_parameters[group.name] = params
    return groups_to_parameters
|
227
|
+
|
228
|
+
@property
def mechs_to_domains(self):
    """
    Inverse of ``domains_to_mechs``: maps each mechanism name to the set
    of domain names it is inserted in. Mechanisms inserted nowhere are
    absent from the result.
    """
    inverse = {}
    for domain_name, mech_names in self.domains_to_mechs.items():
        for mech_name in mech_names:
            inverse.setdefault(mech_name, set()).add(domain_name)
    return inverse
|
238
|
+
|
239
|
+
|
240
|
+
@property
def parameters_to_groups(self):
    """
    The dictionary mapping each parameter name to the list of names of
    the groups containing the parameter's mechanism, in ``self._groups``
    order.
    """
    param_groups = {}
    for group in self._groups:
        owned = [
            param
            for mech_name, mech_params in self.mechs_to_params.items()
            if mech_name in group.mechanisms
            for param in mech_params
        ]
        for param in owned:
            param_groups.setdefault(param, []).append(group.name)
    return param_groups
|
253
|
+
|
254
|
+
|
255
|
+
@property
def params_to_mechs(self):
    """
    The dictionary mapping each parameter name to the mechanism that owns it.

    A parameter belongs to mechanism ``m`` when its name ends with
    ``"_" + m``. Longer mechanism names are tried first so the most
    specific suffix wins. Parameters matching no mechanism map to
    ``"Independent"``.
    """
    # Longest names first; sorted() stays stable with reverse=True, so
    # equally long names keep their original order.
    candidates = sorted(self.mechanisms, key=len, reverse=True)
    owners = {}
    for param_name in self.params:
        for mech_name in candidates:
            if param_name.endswith('_' + mech_name):
                owners[param_name] = mech_name
                break
        else:
            owners[param_name] = "Independent"
    return owners
|
274
|
+
|
275
|
+
|
276
|
+
@property
def mechs_to_params(self):
    """
    Inverse of ``params_to_mechs``: maps each mechanism name (including
    ``"Independent"`` when any parameter is unowned) to the list of its
    parameter names.
    """
    grouped = {}
    for param_name, owner in self.params_to_mechs.items():
        grouped.setdefault(owner, []).append(param_name)
    return grouped
|
285
|
+
|
286
|
+
|
287
|
+
@property
def conductances(self):
    """
    The subset of ``self.params`` whose names start with ``'gbar'``,
    with the same values.
    """
    return dict(
        (param_name, per_group)
        for param_name, per_group in self.params.items()
        if param_name.startswith('gbar')
    )
|
294
|
+
# -----------------------------------------------------------------------
|
295
|
+
# METADATA
|
296
|
+
# -----------------------------------------------------------------------
|
297
|
+
|
298
|
+
def info(self):
    """
    Print a short summary of the model: its name, data path, simulator,
    and the numbers of groups, available mechanisms, inserted mechanisms
    and current clamps.
    """
    # Mechanisms owning parameters, excluding the "Independent"
    # pseudo-mechanism explicitly. The previous `len(...) - 1` assumed it
    # was always present and reported -1 when it was not.
    n_inserted = sum(1 for mech_name in self.mechs_to_params
                     if mech_name != 'Independent')
    info_str = (
        f"Model: {self.name}\n"
        f"Path to data: {self.path_manager.path_to_data}\n"
        f"Simulator: {self.simulator_name}\n"
        f"Groups: {len(self.groups)}\n"
        f"Available mechanisms: {len(self.mechanisms)}\n"
        f"Inserted mechanisms: {n_inserted}\n"
        f"IClamps: {len(self.iclamps)}\n"
    )
    print(info_str)
|
313
|
+
|
314
|
+
|
315
|
+
@property
def df_params(self):
    """
    A DataFrame with one row per (mechanism, parameter, group) entry.

    Columns are 'Mechanism', 'Parameter', 'Group', 'Distribution' and
    'Distribution params'. 'Distribution' holds the distribution's
    ``function_name``, or the entry itself if it is a string.
    'Distribution params' holds the distribution's ``parameters``, or ``{}``
    for string entries.
    """
    def _row(mech_name, param, group_name, distribution):
        is_str = isinstance(distribution, str)
        return {
            'Mechanism': mech_name,
            'Parameter': param,
            'Group': group_name,
            'Distribution': distribution if is_str else distribution.function_name,
            'Distribution params': {} if is_str else distribution.parameters,
        }

    rows = [
        _row(mech_name, param, group_name, distribution)
        for mech_name, params in self.mechs_to_params.items()
        for param in params
        for group_name, distribution in self.params[param].items()
    ]
    return pd.DataFrame(rows)
|
333
|
+
|
334
|
+
def print_directory_tree(self, *args, **kwargs):
    """
    Print the model's directory tree.

    All arguments are forwarded unchanged to the path manager's
    ``print_directory_tree``, whose result is returned.
    """
    manager = self.path_manager
    return manager.print_directory_tree(*args, **kwargs)
|
339
|
+
|
340
|
+
def list_morphologies(self, extension='swc'):
    """
    List the morphology files of the model, i.e. the path manager's
    listing of the 'morphology' directory for ``extension``.
    """
    folder = 'morphology'
    return self.path_manager.list_files(folder, extension=extension)

def list_membrane_configs(self, extension='json'):
    """
    List the membrane configurations of the model, i.e. the path
    manager's listing of the 'membrane' directory for ``extension``.
    """
    folder = 'membrane'
    return self.path_manager.list_files(folder, extension=extension)

def list_mechanisms(self, extension='mod'):
    """
    List the mechanism files of the model, i.e. the path manager's
    listing of the 'mod' directory for ``extension``.
    """
    folder = 'mod'
    return self.path_manager.list_files(folder, extension=extension)

def list_stimuli_configs(self, extension='json'):
    """
    List the stimuli configurations of the model, i.e. the path manager's
    listing of the 'stimuli' directory for ``extension``.
    """
    folder = 'stimuli'
    return self.path_manager.list_files(folder, extension=extension)
|
363
|
+
|
364
|
+
# ========================================================================
|
365
|
+
# MORPHOLOGY
|
366
|
+
# ========================================================================
|
367
|
+
|
368
|
+
def load_morphology(self, file_name, soma_notation='3PS',
                    align=True, sort_children=True, force=False) -> None:
    """
    Read an SWC file and build the point, section and segment trees.

    Parameters
    ----------
    file_name : str
        The name of the SWC file to read. The trailing '.swc' extension is
        optional.
    soma_notation : str, optional
        The notation of the soma in the SWC file. Can be '3PS' (three-point
        soma) or '1PS'. Default is '3PS'.
    align : bool, optional
        Whether to align the morphology to the soma center and align the
        apical dendrite (if present). Default is True.
    sort_children : bool, optional
        Whether to sort the children of each node by increasing subtree size
        in the tree sorting algorithms. If True, the traversal visits
        children with shorter subtrees first and assigns them lower indices.
        If False, children are visited in their original SWC file order
        (matching NEURON's behavior).
    force : bool, optional
        Forwarded as ``force`` to the ``sort`` methods of the point tree
        and the section tree. Default is False.
    """
    # Strip only a trailing extension; str.replace('.swc', '') would also
    # delete '.swc' occurring elsewhere in the name.
    suffix = '.swc'
    if file_name.endswith(suffix):
        self.morphology_name = file_name[:-len(suffix)]
    else:
        self.morphology_name = file_name
    path_to_swc_file = self.path_manager.get_file_path('morphology', file_name, extension='swc')

    # Point level
    point_tree = create_point_tree(path_to_swc_file)
    point_tree.change_soma_notation(soma_notation)
    point_tree.sort(sort_children=sort_children, force=force)
    if align:
        point_tree.shift_coordinates_to_soma_center()
        point_tree.align_apical_dendrite()
        point_tree.round_coordinates(8)
    self.point_tree = point_tree

    # Section level
    sec_tree = create_section_tree(point_tree)
    sec_tree.sort(sort_children=sort_children, force=force)
    self.sec_tree = sec_tree

    # Simulator sections and an initial segment tree
    self.create_and_reference_sections_in_simulator()
    seg_tree = create_segment_tree(sec_tree)
    self.seg_tree = seg_tree

    self._add_default_segment_groups()
    self._initialize_domains_to_mechs()

    # set_segmentation recomputes nseg, rebuilds the segment tree and
    # redistributes parameters using the model's current d_lambda.
    d_lambda = self.d_lambda
    self.set_segmentation(d_lambda=d_lambda)
|
413
|
+
|
414
|
+
|
415
|
+
def create_and_reference_sections_in_simulator(self):
    """
    Call ``create_and_reference(simulator_name)`` on every section of the
    section tree. When verbose, report the number of sections whose
    ``_ref`` is set afterwards.
    """
    if self.verbose:
        print(f'Building sections in {self.simulator_name}...')
    for section in self.sec_tree.sections:
        section.create_and_reference(self.simulator_name)
    n_referenced = sum(1 for section in self.sec_tree.sections
                       if section._ref is not None)
    if self.verbose:
        print(f'{n_referenced} sections created.')
|
425
|
+
|
426
|
+
|
427
|
+
|
428
|
+
|
429
|
+
def _add_default_segment_groups(self):
    """
    Create the default segment groups: an 'all' group spanning every
    domain, then one group per domain named via DOMAIN_TO_GROUP (or after
    the domain itself when it is not listed there).
    """
    every_domain = list(self.domains.keys())
    self.add_group('all', every_domain)
    for domain_name in self.domains:
        per_domain_group = DOMAIN_TO_GROUP.get(domain_name, domain_name)
        self.add_group(per_domain_group, [domain_name])
|
434
|
+
|
435
|
+
|
436
|
+
def _initialize_domains_to_mechs(self):
    """
    Give every current domain an entry in ``domains_to_mechs`` and then
    (re-)insert each recorded mechanism into its domain.

    Existing entries, e.g. ones recorded for a previously loaded
    morphology, are kept and re-inserted.
    """
    for domain_name in self.domains:
        # Do not overwrite mechanisms recorded for a previous morphology.
        # TODO: Check that domains match
        self.domains_to_mechs.setdefault(domain_name, set())
    # NOTE(review): an entry for a domain absent from the current morphology
    # would presumably make insert_mechanism fail on self.domains[...]
    # (KeyError) -- confirm.
    for domain_name, mech_names in self.domains_to_mechs.items():
        for mech_name in mech_names:
            self.insert_mechanism(mech_name, domain_name)
|
445
|
+
|
446
|
+
|
447
|
+
def get_sections(self, filter_function):
    """
    Return the sections for which ``filter_function(section)`` is truthy,
    in section-tree order.

    Parameters
    ----------
    filter_function : Callable
        Predicate applied to each section.
    """
    selected = []
    for section in self.sec_tree.sections:
        if filter_function(section):
            selected.append(section)
    return selected
|
456
|
+
|
457
|
+
|
458
|
+
def get_segments(self, group_names=None):
    """
    Get the segments in specified groups.

    Parameters
    ----------
    group_names : list[str] or str or None, optional
        The names of the groups to get segments from. A single group name
        may be passed as a string. If None (the default), all segments of
        the segment tree are returned.

    Returns
    -------
    list
        For each requested group in turn, its segments in segment-tree
        order. A segment in several requested groups appears once per group.

    Raises
    ------
    ValueError
        If ``group_names`` is not None, a string, a list or a tuple.
    KeyError
        If a requested group does not exist (checked while scanning
        segments).
    """
    # The None default used to raise immediately; treat it as "everything".
    if group_names is None:
        return list(self.seg_tree.segments)
    if isinstance(group_names, str):
        group_names = [group_names]
    elif not isinstance(group_names, (list, tuple)):
        raise ValueError('Group names must be a list.')
    # `self.groups` builds a new dict on each access; look it up once
    # instead of once per segment.
    groups = self.groups
    return [seg for group_name in group_names
            for seg in self.seg_tree.segments
            if seg in groups[group_name]]
|
470
|
+
|
471
|
+
# ========================================================================
|
472
|
+
# SEGMENTATION
|
473
|
+
# ========================================================================
|
474
|
+
|
475
|
+
def set_segmentation(self, d_lambda=0.1, f=100):
    """
    Set the number of segments (nseg) of every section from its geometry,
    then rebuild the segment tree and redistribute all parameters.

    For each section, nseg is the odd integer
    ``int((L / (d_lambda * lambda_f) + 0.9) / 2) * 2 + 1``, where
    ``lambda_f`` comes from ``calculate_lambda_f`` at frequency ``f``.

    Parameters
    ----------
    d_lambda : float
        Target segment length as a fraction of ``lambda_f``. It is stored
        on the model as ``self.d_lambda``.
    f : float
        The frequency passed to ``calculate_lambda_f`` (presumably in Hz).
    """
    self.d_lambda = d_lambda

    # lambda_f depends on each section's cm and Ra, so distribute them first.
    self.distribute('cm')
    self.distribute('Ra')

    for sec in self.sec_tree.sections:
        lambda_f = calculate_lambda_f(sec.distances, sec.diameters, sec.Ra, sec.cm, f)
        n_lambda = sec.L / (d_lambda * lambda_f)
        # TODO: Set sec._nseg instead
        sec._ref.nseg = 2 * int((n_lambda + 0.9) / 2) + 1

    # Segments are derived from nseg: rebuild the tree, then re-apply every
    # parameter to the new segments.
    self.seg_tree = create_segment_tree(self.sec_tree)
    self.distribute_all()
|
503
|
+
|
504
|
+
|
505
|
+
# ========================================================================
|
506
|
+
# MECHANISMS
|
507
|
+
# ========================================================================
|
508
|
+
|
509
|
+
def add_default_mechanisms(self, recompile=False):
    """
    Register the built-in LeakChannel and CaDynamics mechanisms, then load
    the MOD files of the 'default_mod' directory.

    Parameters
    ----------
    recompile : bool, optional
        Whether to recompile the mechanisms. Default is False.
    """
    # Instantiate and register one at a time, in this order.
    for factory in (LeakChannel, CaDynamics):
        mech = factory()
        self.mechanisms[mech.name] = mech

    self.load_mechanisms('default_mod', recompile=recompile)
|
525
|
+
|
526
|
+
|
527
|
+
def add_mechanisms(self, dir_name:str = 'mod', recompile=True) -> None:
    """
    Create, register and load a mechanism for every MOD file listed in a
    directory.

    Parameters
    ----------
    dir_name : str, optional
        The directory to take MOD files from. Default is 'mod'.
    recompile : bool, optional
        Whether to recompile the mechanisms. Default is True.
    """
    mod_names = self.path_manager.list_files(dir_name, extension='mod')
    for mod_name in mod_names:
        self.add_mechanism(mod_name, load=True, dir_name=dir_name,
                           recompile=recompile)
|
544
|
+
|
545
|
+
|
546
|
+
|
547
|
+
def add_mechanism(self, mechanism_name: str,
                  python_template_name: str = 'default',
                  load=True, dir_name: str = 'mod', recompile=True
                  ) -> None:
    """
    Create a Mechanism object from a MOD file and register it in the model.

    Parameters
    ----------
    mechanism_name : str
        The name of the mechanism (MOD file) to add.
    python_template_name : str, optional
        The name of the Python template to use. Default is 'default'.
    load : bool, optional
        Whether to also load the MOD file via ``load_mechanism``.
        Default is True.
    dir_name : str, optional
        The directory to load the MOD file from. Default is 'mod'.
    recompile : bool, optional
        Whether to recompile the mechanism when loading. Default is True.
    """
    channel_paths = self.path_manager.get_channel_paths(
        mechanism_name,
        python_template_name=python_template_name
    )
    mech = create_channel(**channel_paths)
    # Registered under the created object's own name; loading below uses
    # `mechanism_name`.
    self.mechanisms[mech.name] = mech

    if not load:
        return
    self.load_mechanism(mechanism_name, dir_name, recompile)
|
574
|
+
|
575
|
+
|
576
|
+
|
577
|
+
def load_mechanisms(self, dir_name: str = 'mod', recompile=True) -> None:
    """
    Load every MOD file listed in a directory via ``load_mechanism``.

    Parameters
    ----------
    dir_name : str, optional
        The directory to load MOD files from. Default is 'mod'.
    recompile : bool, optional
        Whether to recompile the mechanisms. Default is True.
    """
    for mod_name in self.path_manager.list_files(dir_name, extension='mod'):
        self.load_mechanism(mod_name, dir_name, recompile)
|
591
|
+
|
592
|
+
|
593
|
+
def load_mechanism(self, mechanism_name, dir_name='mod', recompile=False) -> None:
    """
    Resolve the MOD file of one mechanism and hand it to the MOD loader.

    Parameters
    ----------
    mechanism_name : str
        The name of the mechanism to load.
    dir_name : str, optional
        The directory containing the MOD file. Default is 'mod'.
    recompile : bool, optional
        Whether to recompile the mechanism. Default is False.
    """
    mod_path = self.path_manager.get_file_path(dir_name, mechanism_name, extension='mod')
    self.mod_loader.load_mechanism(path_to_mod_file=mod_path, recompile=recompile)
|
612
|
+
|
613
|
+
|
614
|
+
def standardize_channel(self, channel_name,
    python_template_name=None, mod_template_name=None, remove_old=True):
    """
    Standardize a channel by creating a new channel with the same kinetic
    properties using the standard equations.

    Steps:
    1. Uninsert the old channel from every domain.
    2. Optionally remove it from ``self.mechanisms``.
    3. Create the standardized channel and load it with ``recompile=True``.
    4. Insert it into the domains that held the old channel.
    5. Copy the old channel's ``gbar`` distributions onto the new one.

    Only ``gbar`` is transferred; other parameters are not.

    Parameters
    ----------
    channel_name : str
        The name of the channel to standardize.
    python_template_name : str, optional
        The name of the Python template to use.
        NOTE(review): this argument is not used in the body below --
        confirm whether it should be forwarded.
    mod_template_name : str, optional
        The name of the MOD template to use.
    remove_old : bool, optional
        Whether to remove the old channel from the model. Default is True.

    Raises
    ------
    KeyError
        If ``channel_name`` is not in ``self.mechanisms`` or
        ``gbar_<channel_name>`` is not in ``self.params``.
    """

    # Get data to transfer
    channel = self.mechanisms[channel_name]
    # Domains are captured before uninserting so the new channel can be
    # inserted into the same places afterwards.
    channel_domain_names = [domain_name for domain_name, mechs
        in self.domains_to_mechs.items() if channel_name in mechs]
    gbar_name = f'gbar_{channel_name}'
    # Keep a reference to the per-group distributions; it stays usable
    # even if uninserting later drops the entry from self.params.
    gbar_distributions = self.params[gbar_name]
    # Kinetic variables cannot be transferred

    # Uninsert the old channel
    for domain_name in self.domains:
        if channel_name in self.domains_to_mechs[domain_name]:
            self.uninsert_mechanism(channel_name, domain_name)

    # Remove the old channel
    if remove_old:
        self.mechanisms.pop(channel_name)

    # Create, add and load a new channel
    paths = self.path_manager.get_standard_channel_paths(
        channel_name,
        mod_template_name=mod_template_name
    )
    # Module-level standardize_channel (imported from
    # dendrotweaks.membrane.io), not this method.
    standard_channel = standardize_channel(channel, **paths)

    self.mechanisms[standard_channel.name] = standard_channel
    self.load_mechanism(standard_channel.name, recompile=True)

    # Insert the new channel
    for domain_name in channel_domain_names:
        self.insert_mechanism(standard_channel.name, domain_name)

    # Transfer data
    gbar_name = f'gbar_{standard_channel.name}'
    for group_name, distribution in gbar_distributions.items():
        self.set_param(gbar_name, group_name,
            distribution.function_name, **distribution.parameters)
|
668
|
+
|
669
|
+
|
670
|
+
# ========================================================================
|
671
|
+
# DOMAINS
|
672
|
+
# ========================================================================
|
673
|
+
|
674
|
+
def define_domain(self, domain_name: str, sections, distribute=True):
    """
    Adds a new domain to the tree and ensures correct partitioning of
    the section tree graph.

    Sections not already in ``domain_name`` are moved into it. Each moved
    section has the mechanisms of its old domain uninserted and those of
    the target domain inserted. Afterwards, empty domains, uninserted
    mechanisms and empty groups are removed.

    Parameters
    ----------
    domain_name : str
        The name of the domain. If it does not exist yet, it is created
        together with its segment groups.
    sections : list[Section] or Callable
        The sections to include in the domain. If a callable is provided,
        it should be a filter function applied to the list of all sections
        of the cell.
    distribute : bool, optional
        Passed to ``Section.insert_mechanism`` for the target domain's
        mechanisms. Default is True.

    Warns
    -----
    UserWarning
        If none of the given sections needs to be moved. The method then
        returns early.
        NOTE(review): when the domain was just created, the early return
        skips ``_remove_empty()``, so the new empty domain and its group
        are left in place -- confirm this is intended.
    """
    if isinstance(sections, Callable):
        sections = self.get_sections(sections)

    if domain_name not in self.domains:
        domain = Domain(domain_name)
        self._add_domain_groups(domain.name)
        self.domains[domain_name] = domain
        self.domains_to_mechs[domain_name] = set()
    else:
        domain = self.domains[domain_name]

    sections_to_move = [sec for sec in sections
        if sec.domain != domain_name]

    if not sections_to_move:
        warnings.warn(f'Sections already in domain {domain_name}.')
        return

    # Detach each section from its old domain and strip that domain's
    # mechanisms.
    for sec in sections_to_move:
        old_domain = self.domains[sec.domain]
        old_domain.remove_section(sec)
        for mech_name in self.domains_to_mechs[old_domain.name]:
            # TODO: What if section is already in domain? Can't be as
            # we use a filtered list of sections.
            sec.uninsert_mechanism(mech_name)


    # Attach to the target domain and insert its mechanisms.
    for sec in sections_to_move:
        domain.add_section(sec)
        for mech_name in self.domains_to_mechs.get(domain.name, set()):
            sec.insert_mechanism(mech_name, distribute=distribute)

    self._remove_empty()
|
721
|
+
|
722
|
+
|
723
|
+
def _add_domain_groups(self, domain_name):
    """
    Update segment groups for a newly added domain.

    If an 'all' group is present (and truthy), the domain is appended to
    it. A dedicated group is then created for the domain, named via
    DOMAIN_TO_GROUP or after the domain itself.
    """
    all_group = self.groups.get('all')
    if all_group:
        all_group.domains.append(domain_name)
    dedicated_name = DOMAIN_TO_GROUP.get(domain_name, domain_name)
    self.add_group(dedicated_name, domains=[domain_name])
|
733
|
+
|
734
|
+
|
735
|
+
def _remove_empty(self):
    """
    Clean up after domain changes, in this order: drop empty domains,
    then mechanisms no longer inserted anywhere, then groups that contain
    no segments.
    """
    cleanups = (
        self._remove_empty_domains,
        self._remove_uninserted_mechanisms,
        self._remove_empty_groups,
    )
    for cleanup in cleanups:
        cleanup()
|
739
|
+
|
740
|
+
|
741
|
+
def _remove_empty_domains(self):
    """
    Remove every domain whose ``is_empty()`` is true, warning about each.

    Each removed domain is popped from ``self.domains`` and
    ``self.domains_to_mechs``, and its name is deleted from the ``domains``
    list of every group that contains it.
    """
    to_drop = [dom for dom in self.domains.values() if dom.is_empty()]
    for dom in to_drop:
        warnings.warn(f'Domain {dom.name} is empty and will be removed.')
        self.domains.pop(dom.name)
        self.domains_to_mechs.pop(dom.name)
        for group in self._groups:
            if dom.name in group.domains:
                group.domains.remove(dom.name)
|
754
|
+
|
755
|
+
|
756
|
+
def _remove_uninserted_mechanisms(self):
    """
    Drop the parameters of every mechanism that owns parameters but is
    inserted in no domain, warning about each. The "Independent"
    pseudo-mechanism is skipped.
    """
    owner_names = [owner for owner in self.mechs_to_params
                   if owner != 'Independent']
    owners = [self.mechanisms[owner] for owner in owner_names]
    orphaned = [mech for mech in owners
                if mech.name not in self.mechs_to_domains]
    for mech in orphaned:
        warnings.warn(f'Mechanism {mech.name} is not inserted in any domain and will be removed.')
        self._remove_mechanism_params(mech)
|
765
|
+
|
766
|
+
|
767
|
+
def _remove_empty_groups(self):
|
768
|
+
empty_groups = [group for group in self._groups
|
769
|
+
if not any(seg in group
|
770
|
+
for seg in self.seg_tree)]
|
771
|
+
for group in empty_groups:
|
772
|
+
warnings.warn(f'Group {group.name} is empty and will be removed.')
|
773
|
+
self.remove_group(group.name)
|
774
|
+
|
775
|
+
|
776
|
+
# -----------------------------------------------------------------------
|
777
|
+
# INSERT / UNINSERT MECHANISMS
|
778
|
+
# -----------------------------------------------------------------------
|
779
|
+
|
780
|
+
def insert_mechanism(self, mechanism_name: str,
                     domain_name: str, distribute=True):
    """
    Insert a mechanism into all sections in a domain.

    The mechanism is recorded in ``domains_to_mechs`` for the domain,
    inserted into every section of the domain, and its range parameters
    are registered in ``self.params`` via ``_add_mechanism_params``.

    Parameters
    ----------
    mechanism_name : str
        The name of the mechanism to insert.
    domain_name : str
        The name of the domain to insert the mechanism into.
    distribute : bool, optional
        Whether to distribute all parameters after inserting the mechanism.

    Raises
    ------
    KeyError
        If the mechanism or the domain is unknown.
    """
    mech = self.mechanisms[mechanism_name]
    domain = self.domains[domain_name]

    # A domain may not have an entry in ``domains_to_mechs`` yet (other code
    # reads it with ``.get(name, set())``), so create it on demand.
    self.domains_to_mechs.setdefault(domain_name, set()).add(mech.name)
    for sec in domain.sections:
        sec.insert_mechanism(mech.name)
    self._add_mechanism_params(mech)

    # TODO: Redistribute only parameters of groups that contain this domain.
    if distribute:
        # ``distribute_all`` evaluates group membership once and reuses it
        # for every parameter instead of recomputing it per parameter.
        self.distribute_all()
|
807
|
+
|
808
|
+
|
809
|
+
def _add_mechanism_params(self, mech):
|
810
|
+
"""
|
811
|
+
Update the parameters when a mechanism is inserted.
|
812
|
+
By default each parameter is set to a constant value
|
813
|
+
through the entire cell.
|
814
|
+
"""
|
815
|
+
for param_name, value in mech.range_params_with_suffix.items():
|
816
|
+
self.params[param_name] = {'all': Distribution('constant', value=value)}
|
817
|
+
|
818
|
+
if hasattr(mech, 'ion') and mech.ion in ['na', 'k', 'ca']:
|
819
|
+
self._add_equilibrium_potentials_on_mech_insert(mech.ion)
|
820
|
+
|
821
|
+
|
822
|
+
def _add_equilibrium_potentials_on_mech_insert(self, ion: str) -> None:
|
823
|
+
"""
|
824
|
+
"""
|
825
|
+
if ion == 'na' and not self.params.get('ena'):
|
826
|
+
self.params['ena'] = {'all': Distribution('constant', value=50)}
|
827
|
+
elif ion == 'k' and not self.params.get('ek'):
|
828
|
+
self.params['ek'] = {'all': Distribution('constant', value=-77)}
|
829
|
+
elif ion == 'ca' and not self.params.get('eca'):
|
830
|
+
self.params['eca'] = {'all': Distribution('constant', value=140)}
|
831
|
+
|
832
|
+
|
833
|
+
def uninsert_mechanism(self, mechanism_name: str,
                       domain_name: str):
    """
    Remove a mechanism from every section of a domain.

    After the mechanism is taken out of the domain, the model checks
    whether it is still inserted in any other domain. If it is not, a
    warning is issued and the mechanism's parameters are removed.

    Parameters
    ----------
    mechanism_name : str
        The name of the mechanism to uninsert.
    domain_name : str
        The name of the domain to uninsert the mechanism from.
    """
    mech = self.mechanisms[mechanism_name]
    target_domain = self.domains[domain_name]
    name = mech.name

    for section in target_domain.sections:
        section.uninsert_mechanism(name)
    self.domains_to_mechs[domain_name].remove(name)

    if self.mechs_to_domains.get(name):
        return
    warnings.warn(f'Mechanism {mech.name} is not inserted in any domain and will be removed.')
    self._remove_mechanism_params(mech)
|
856
|
+
|
857
|
+
|
858
|
+
def _remove_mechanism_params(self, mech):
|
859
|
+
for param_name in self.mechs_to_params.get(mech.name, []):
|
860
|
+
self.params.pop(param_name)
|
861
|
+
|
862
|
+
if hasattr(mech, 'ion') and mech.ion in ['na', 'k', 'ca']:
|
863
|
+
self._remove_equilibrium_potentials_on_mech_uninsert(mech.ion)
|
864
|
+
|
865
|
+
|
866
|
+
def _remove_equilibrium_potentials_on_mech_uninsert(self, ion: str) -> None:
|
867
|
+
"""
|
868
|
+
"""
|
869
|
+
for mech_name, mech in self.mechanisms.items():
|
870
|
+
if hasattr(mech, 'ion'):
|
871
|
+
if mech.ion == mech.ion: return
|
872
|
+
|
873
|
+
if ion == 'na':
|
874
|
+
self.params.pop('ena', None)
|
875
|
+
elif ion == 'k':
|
876
|
+
self.params.pop('ek', None)
|
877
|
+
elif ion == 'ca':
|
878
|
+
self.params.pop('eca', None)
|
879
|
+
|
880
|
+
|
881
|
+
# ========================================================================
|
882
|
+
# SET PARAMETERS
|
883
|
+
# ========================================================================
|
884
|
+
|
885
|
+
# -----------------------------------------------------------------------
|
886
|
+
# GROUPS
|
887
|
+
# -----------------------------------------------------------------------
|
888
|
+
|
889
|
+
def add_group(self, name, domains, select_by=None, min_value=None, max_value=None):
    """
    Create a segment group and append it to the model's list of groups.

    Parameters
    ----------
    name : str
        The name of the group.
    domains : list[str]
        The domains whose segments may belong to the group.
    select_by : str, optional
        The segment property used to filter segments. Can be 'diam',
        'distance' or 'domain_distance'.
    min_value : float, optional
        The lower bound applied to the ``select_by`` property.
    max_value : float, optional
        The upper bound applied to the ``select_by`` property.
    """
    if self.verbose:
        print(f'Adding group {name}...')
    new_group = SegmentGroup(name, domains, select_by, min_value, max_value)
    self._groups.append(new_group)
|
909
|
+
|
910
|
+
|
911
|
+
def remove_group(self, group_name):
    """
    Delete a segment group together with every distribution bound to it.

    Parameters
    ----------
    group_name : str
        The name of the group to remove. Unknown names are ignored.
    """
    kept = []
    for group in self._groups:
        if group.name == group_name:
            continue
        kept.append(group)
    self._groups = kept

    for distributions in self.params.values():
        if group_name in distributions:
            del distributions[group_name]
|
926
|
+
|
927
|
+
|
928
|
+
def move_group_down(self, name):
    """
    Swap a group with its predecessor in the list of groups.

    The group at index ``i`` is exchanged with the one at ``i - 1``. A
    group already at index 0 stays in place. All distributed parameters
    are redistributed afterwards.

    Parameters
    ----------
    name : str
        The name of the group to move down.

    Raises
    ------
    StopIteration
        If no group with this name exists.
    """
    groups = self._groups
    position = next(i for i, group in enumerate(groups) if group.name == name)
    if position >= 1:
        groups[position], groups[position - 1] = groups[position - 1], groups[position]
    for param_name in self.distributed_params:
        self.distribute(param_name)
|
942
|
+
|
943
|
+
|
944
|
+
def move_group_up(self, name):
    """
    Swap a group with its successor in the list of groups.

    The group at index ``i`` is exchanged with the one at ``i + 1``. The
    last group stays in place. All distributed parameters are
    redistributed afterwards.

    Parameters
    ----------
    name : str
        The name of the group to move up.

    Raises
    ------
    StopIteration
        If no group with this name exists.
    """
    groups = self._groups
    position = next(i for i, group in enumerate(groups) if group.name == name)
    last_index = len(groups) - 1
    if position < last_index:
        groups[position], groups[position + 1] = groups[position + 1], groups[position]
    for param_name in self.distributed_params:
        self.distribute(param_name)
|
958
|
+
|
959
|
+
|
960
|
+
# -----------------------------------------------------------------------
|
961
|
+
# DISTRIBUTIONS
|
962
|
+
# -----------------------------------------------------------------------
|
963
|
+
|
964
|
+
def set_param(self, param_name: str,
              group_name: str = 'all',
              distr_type: str = 'constant',
              **distr_params):
    """
    Set a parameter for a group of segments.

    ``temperature`` and ``v_init`` are forwarded to the simulator, using
    ``distr_params['value']``. Any other parameter gets a distribution via
    ``set_distribution`` and is then redistributed.

    Parameters
    ----------
    param_name : str
        The name of the parameter to set.
    group_name : str, optional
        The name of the group to set the parameter for. Default is 'all'.
    distr_type : str, optional
        The type of the distribution to use. Default is 'constant'.
    **distr_params
        The parameters of the distribution. Each must be an int or a float
        and must not be NaN.

    Raises
    ------
    ValueError
        If ``group`` is passed instead of ``group_name``, or if a
        distribution parameter is not numeric or is NaN.
    """
    import math

    if 'group' in distr_params:
        raise ValueError("Did you mean 'group_name' instead of 'group'?")

    if param_name in ['temperature', 'v_init']:
        setattr(self.simulator, param_name, distr_params['value'])
        return

    for key, value in distr_params.items():
        # NaN is not a singleton, so an identity check cannot detect it;
        # test the value itself.
        is_nan = isinstance(value, float) and math.isnan(value)
        if not isinstance(value, (int, float)) or is_nan:
            raise ValueError(f"Parameter '{key}' must be a numeric value and not NaN, got {type(value).__name__} instead.")

    self.set_distribution(param_name, group_name, distr_type, **distr_params)
    self.distribute(param_name)
|
996
|
+
|
997
|
+
|
998
|
+
def set_distribution(self, param_name: str,
                     group_name: str = 'all',
                     distr_type: str = 'constant',
                     **distr_params):
    """
    Set a distribution for a parameter.

    The distribution is stored in ``self.params[param_name][group_name]``
    but is not applied to segments; call ``distribute`` for that.

    Parameters
    ----------
    param_name : str
        The name of the parameter to set.
    group_name : str, optional
        The name of the group to set the parameter for. Default is 'all'.
    distr_type : str, optional
        The type of the distribution to use. Default is 'constant'.
        ``'inherit'`` stores the string ``'inherit'`` instead of a
        ``Distribution`` object.
    **distr_params
        The parameters of the distribution.

    Raises
    ------
    KeyError
        If ``param_name`` is not a known parameter.
    """
    if distr_type == 'inherit':
        distribution = 'inherit'
    else:
        distribution = Distribution(distr_type, **distr_params)
    self.params[param_name][group_name] = distribution
|
1022
|
+
|
1023
|
+
|
1024
|
+
def distribute_all(self):
    """
    Apply the distributions of every parameter to the model's segments.

    Group membership is evaluated once and the same mapping from group
    names to segment lists is passed to ``distribute`` for every parameter.
    """
    membership = {}
    for group in self._groups:
        membership[group.name] = [seg for seg in self.seg_tree if seg in group]
    for param_name in self.params:
        self.distribute(param_name, membership)
|
1032
|
+
|
1033
|
+
|
1034
|
+
def distribute(self, param_name: str, precomputed_groups=None):
    """
    Write the values of one parameter into the model's segments.

    Axial resistance (``'Ra'``) is delegated to ``_distribute_Ra``. For
    any other parameter, each group's distribution is applied in the order
    of ``self.params[param_name]``. An ``'inherit'`` entry copies the value
    from each segment's parent. Any other entry is called with the
    segment's path distance.

    Parameters
    ----------
    param_name : str
        The name of the parameter to distribute.
    precomputed_groups : dict, optional
        Mapping from group names to lists of segments. Built from
        ``self._groups`` and ``self.seg_tree`` when omitted.
    """
    if param_name == 'Ra':
        self._distribute_Ra(precomputed_groups)
        return

    if precomputed_groups is not None:
        segments_by_group = precomputed_groups
    else:
        segments_by_group = {
            group.name: [seg for seg in self.seg_tree if seg in group]
            for group in self._groups
        }

    for group_name, distribution in self.params[param_name].items():
        members = segments_by_group[group_name]
        inherit = distribution == 'inherit'
        for seg in members:
            if inherit:
                new_value = seg.parent.get_param_value(param_name)
            else:
                new_value = distribution(seg.path_distance())
            seg.set_param_value(param_name, new_value)
|
1068
|
+
|
1069
|
+
|
1070
|
+
def _distribute_Ra(self, precomputed_groups=None):
|
1071
|
+
"""
|
1072
|
+
Distribute the axial resistance to the segments.
|
1073
|
+
"""
|
1074
|
+
|
1075
|
+
groups_to_segments = precomputed_groups
|
1076
|
+
if groups_to_segments is None:
|
1077
|
+
groups_to_segments = {group.name: [seg for seg in self.seg_tree if seg in group]
|
1078
|
+
for group in self._groups}
|
1079
|
+
|
1080
|
+
param_distributions = self.params['Ra']
|
1081
|
+
|
1082
|
+
for group_name, distribution in param_distributions.items():
|
1083
|
+
|
1084
|
+
filtered_segments = groups_to_segments[group_name]
|
1085
|
+
if distribution == 'inherit':
|
1086
|
+
raise NotImplementedError("Inheritance of Ra is not implemented.")
|
1087
|
+
else:
|
1088
|
+
for seg in filtered_segments:
|
1089
|
+
value = distribution(seg._section.path_distance(0.5))
|
1090
|
+
seg._section._ref.Ra = value
|
1091
|
+
|
1092
|
+
|
1093
|
+
def remove_distribution(self, param_name, group_name):
    """
    Detach a group's distribution from a parameter and redistribute it.

    Parameters
    ----------
    param_name : str
        The parameter whose distribution is removed.
    group_name : str
        The group whose distribution is removed. Missing entries are
        ignored.
    """
    distributions = self.params[param_name]
    if group_name in distributions:
        del distributions[group_name]
    self.distribute(param_name)
|
1106
|
+
|
1107
|
+
# def set_section_param(self, param_name, value, domains=None):
|
1108
|
+
|
1109
|
+
# domains = domains or self.domains
|
1110
|
+
# for sec in self.sec_tree.sections:
|
1111
|
+
# if sec.domain in domains:
|
1112
|
+
# setattr(sec._ref, param_name, value)
|
1113
|
+
|
1114
|
+
# ========================================================================
|
1115
|
+
# STIMULI
|
1116
|
+
# ========================================================================
|
1117
|
+
|
1118
|
+
# -----------------------------------------------------------------------
|
1119
|
+
# ICLAMPS
|
1120
|
+
# -----------------------------------------------------------------------
|
1121
|
+
|
1122
|
+
def add_iclamp(self, sec, loc, amp=0, delay=100, dur=100):
    """
    Place a current clamp on the segment of ``sec`` at ``loc``.

    Any clamp already stored for the same segment is removed first, so each
    segment holds at most one entry in ``self.iclamps``.

    Parameters
    ----------
    sec : Section
        The section to add the IClamp to.
    loc : float
        The location of the IClamp in the section.
    amp : float, optional
        The amplitude of the IClamp. Default is 0.
    delay : float, optional
        The delay of the IClamp. Default is 100.
    dur : float, optional
        The duration of the IClamp. Default is 100.
    """
    target_seg = sec(loc)
    if self.iclamps.get(target_seg):
        self.remove_iclamp(sec, loc)
    new_clamp = IClamp(sec, loc, amp, delay, dur)
    print(f'IClamp added to sec {sec} at loc {loc}.')
    self.iclamps[target_seg] = new_clamp
|
1145
|
+
|
1146
|
+
|
1147
|
+
def remove_iclamp(self, sec, loc):
    """
    Forget the current clamp placed on the segment of ``sec`` at ``loc``.

    Nothing happens when that segment has no (truthy) clamp entry.

    Parameters
    ----------
    sec : Section
        The section to remove the IClamp from.
    loc : float
        The location of the IClamp in the section.
    """
    target_seg = sec(loc)
    if not self.iclamps.get(target_seg):
        return
    del self.iclamps[target_seg]
|
1161
|
+
|
1162
|
+
|
1163
|
+
def remove_all_iclamps(self):
    """
    Remove every current clamp from the model.

    Each clamp is removed through ``remove_iclamp``. Entries that remain
    afterwards are reported in a warning, and ``self.iclamps`` is reset to
    an empty dict in any case.
    """
    for seg in tuple(self.iclamps):
        self.remove_iclamp(seg._section, seg.x)
    if self.iclamps:
        warnings.warn(f'Not all iclamps were removed: {self.iclamps}')
    self.iclamps = {}
|
1174
|
+
|
1175
|
+
|
1176
|
+
# -----------------------------------------------------------------------
|
1177
|
+
# SYNAPSES
|
1178
|
+
# -----------------------------------------------------------------------
|
1179
|
+
|
1180
|
+
def _add_population(self, population):
|
1181
|
+
self.populations[population.syn_type][population.name] = population
|
1182
|
+
|
1183
|
+
|
1184
|
+
def add_population(self, segments, N, syn_type):
    """
    Create a population of synapses and register it with the model.

    The population's index is the number of populations of the same type
    that already exist. Synapses are allocated and inputs created before
    the population is stored.

    Parameters
    ----------
    segments : list[Segment]
        The segments to add the synapses to.
    N : int
        The number of synapses to add.
    syn_type : str
        The type of synapse to add.
    """
    next_index = len(self.populations[syn_type])
    new_population = Population(next_index, segments, N, syn_type)
    new_population.allocate_synapses()
    new_population.create_inputs()
    self._add_population(new_population)
|
1202
|
+
|
1203
|
+
|
1204
|
+
def update_population_kinetic_params(self, pop_name, **params):
    """
    Update the kinetic parameters of a population of synapses.

    The synapse type is the part of ``pop_name`` before its last
    underscore. The population's resulting ``kinetic_params`` are printed.

    Parameters
    ----------
    pop_name : str
        The name of the population.
    **params
        The parameters to update.
    """
    syn_type, _index = pop_name.rsplit('_', 1)
    target = self.populations[syn_type][pop_name]
    target.update_kinetic_params(**params)
    print(target.kinetic_params)
|
1219
|
+
|
1220
|
+
|
1221
|
+
def update_population_input_params(self, pop_name, **params):
    """
    Update the input parameters of a population of synapses.

    The synapse type is the part of ``pop_name`` before its last
    underscore. The population's resulting ``input_params`` are printed.

    Parameters
    ----------
    pop_name : str
        The name of the population.
    **params
        The parameters to update.
    """
    syn_type, _index = pop_name.rsplit('_', 1)
    target = self.populations[syn_type][pop_name]
    target.update_input_params(**params)
    print(target.input_params)
|
1236
|
+
|
1237
|
+
|
1238
|
+
def remove_population(self, name):
    """
    Unregister a population of synapses and clean it up.

    The synapse type is the part of ``name`` before its last underscore.

    Parameters
    ----------
    name : str
        The name of the population.
    """
    syn_type, _index = name.rsplit('_', 1)
    removed = self.populations[syn_type].pop(name)
    removed.clean()
|
1250
|
+
|
1251
|
+
def remove_all_populations(self):
    """
    Remove all populations of synapses from the model.

    Every population is removed (and cleaned) through ``remove_population``.
    Populations that remain afterwards are reported in a warning. Finally,
    ``self.populations`` is reset to one fresh empty dict per synapse type
    listed in ``POPULATIONS``. Fresh dicts are built instead of assigning
    the shared module-level ``POPULATIONS`` object itself. Otherwise later
    additions would mutate that template and leak into other models and
    future resets.
    """
    for syn_type in self.populations:
        for name in list(self.populations[syn_type].keys()):
            self.remove_population(name)
    if any(self.populations.values()):
        warnings.warn(f'Not all populations were removed: {self.populations}')
    self.populations = {syn_type: {} for syn_type in POPULATIONS}
|
1261
|
+
|
1262
|
+
def remove_all_stimuli(self):
    """
    Remove every stimulus from the model.

    Current clamps are cleared first, then all synapse populations.
    """
    for clear in (self.remove_all_iclamps, self.remove_all_populations):
        clear()
|
1268
|
+
|
1269
|
+
# ========================================================================
|
1270
|
+
# SIMULATION
|
1271
|
+
# ========================================================================
|
1272
|
+
|
1273
|
+
def add_recording(self, sec, loc, var='v'):
    """
    Ask the simulator to record a variable at a location in a section.

    Parameters
    ----------
    sec : Section
        The section to record from.
    loc : float
        The location along the normalized section length to record from.
    var : str, optional
        The variable to record. Default is 'v'.
    """
    simulator = self.simulator
    simulator.add_recording(sec, loc, var)
    print(f'Recording added to sec {sec} at loc {loc}.')
|
1288
|
+
|
1289
|
+
|
1290
|
+
def remove_recording(self, sec, loc):
    """
    Ask the simulator to stop recording at a location in a section.

    Parameters
    ----------
    sec : Section
        The section to remove the recording from.
    loc : float
        The location along the normalized section length to remove the
        recording from.
    """
    simulator = self.simulator
    simulator.remove_recording(sec, loc)
|
1302
|
+
|
1303
|
+
|
1304
|
+
def remove_all_recordings(self):
    """Ask the simulator to drop every recording it holds."""
    simulator = self.simulator
    simulator.remove_all_recordings()
|
1309
|
+
|
1310
|
+
|
1311
|
+
def run(self, duration=300):
    """
    Run the simulation through the model's simulator.

    Parameters
    ----------
    duration : float, optional
        The duration of the simulation. Default is 300.
    """
    simulator = self.simulator
    simulator.run(duration)
|
1321
|
+
|
1322
|
+
def get_traces(self):
    """
    Return the recorded traces.

    This forwards to ``self.simulator.get_traces()`` and returns its result
    unchanged.
    """
    simulator = self.simulator
    traces = simulator.get_traces()
    return traces
|
1324
|
+
|
1325
|
+
def plot(self, *args, **kwargs):
    """
    Plot simulation results via the simulator.

    All positional and keyword arguments are passed unchanged to
    ``self.simulator.plot``. Its return value is discarded.
    """
    simulator = self.simulator
    simulator.plot(*args, **kwargs)
|
1327
|
+
|
1328
|
+
# ========================================================================
|
1329
|
+
# MORPHOLOGY
|
1330
|
+
# ========================================================================
|
1331
|
+
|
1332
|
+
def remove_subtree(self, sec):
    """
    Delete the subtree rooted at ``sec`` and tidy the model afterwards.

    The section tree removes the subtree and is re-sorted. Then domains,
    mechanisms and groups left empty by the removal are pruned.

    Parameters
    ----------
    sec : Section
        The root section of the subtree to remove.
    """
    self.sec_tree.remove_subtree(sec)
    self.sec_tree.sort()
    self._remove_empty()
|
1344
|
+
|
1345
|
+
|
1346
|
+
def merge_domains(self, domain_names: List[str]):
    """
    Merge several domains into the first one listed.

    Every domain after the first is merged into the first via its
    ``merge`` method. Domains, mechanisms and groups left empty by the
    merge are then pruned.

    Parameters
    ----------
    domain_names : list[str]
        Names of the domains to merge. The first name is the target.
    """
    domains = [self.domains[domain_name] for domain_name in domain_names]
    for domain in domains[1:]:
        domains[0].merge(domain)
    # The clean-up helper is the private ``_remove_empty`` (as used by
    # ``remove_subtree``); ``remove_empty`` does not exist.
    self._remove_empty()
|
1354
|
+
|
1355
|
+
|
1356
|
+
def reduce_subtree(self, root, reduction_frequency=0, total_segments_manual=-1, fit=True):
    """
    Reduce a subtree to a single section.

    The subtree rooted at ``root`` is collapsed into ``root``: the child
    subtrees of ``root`` are removed, and the ``rdc`` helpers recompute
    the reduced section's cable properties and number of segments. The
    parameter values of the original segments are averaged onto the
    reduced segments (``rdc.set_avg_params_to_reduced_segs``), and missing
    values are interpolated. When ``fit`` is True, a polynomial is fitted
    to every parameter along ``root`` and stored as the distribution of a
    newly created ``reduced_<n>`` domain and group.

    Parameters
    ----------
    root : Section
        The root section of the subtree to reduce.
    reduction_frequency : float, optional
        The frequency of the reduction. Default is 0.
    total_segments_manual : int, optional
        The number of segments in the reduced subtree. Default is -1 (automatic).
    fit : bool, optional
        Whether to create distributions for the reduced subtree by fitting
        the calculated average values. Default is True.

    Returns
    -------
    dict or None
        ``None`` when ``fit`` is False. Otherwise a dict with the keys
        ``'domain_name'``, ``'group_name'``, ``'segs_to_params'``,
        ``'segs_to_locs'``, ``'segs_to_reduced_segs'``,
        ``'reduced_segs_to_params'`` and ``'params_to_coeffs'``.

    Raises
    ------
    ValueError
        If the subtree spans several domains that have no mechanism in
        common.
    """

    domain_name = root.domain
    parent = root.parent
    domains_in_subtree = [self.domains[domain_name]
        for domain_name in set([sec.domain for sec in root.subtree])]
    if len(domains_in_subtree) > 1:
        # ensure the domains have the same mechanisms using self.domains_to_mechs
        # NOTE(review): this only fails when the domains share *no*
        # mechanism; domains with partially different mechanism sets pass,
        # although the error message asks for identical sets.
        domains_to_mechs = {domain_name: mech_names for domain_name, mech_names
            in self.domains_to_mechs.items() if domain_name in [domain.name for domain in domains_in_subtree]}
        common_mechs = set.intersection(*domains_to_mechs.values())
        if not common_mechs:
            raise ValueError(
                'The domains in the subtree have different mechanisms. '
                'Please ensure that all domains in the subtree have the same mechanisms. '
                'You may need to insert the missing mechanisms and set their conductances to 0 where they are not needed.'
            )

    # Mechanisms inserted in the root's domain (``domain_name`` is still
    # ``root.domain`` here; comprehension variables do not leak).
    inserted_mechs = {mech_name: mech for mech_name, mech
        in self.mechanisms.items()
        if mech_name in self.domains_to_mechs[domain_name]
    }

    # NOTE(review): ``subtree_without_root`` is never used below.
    subtree_without_root = [sec for sec in root.subtree if sec is not root]

    # Map original segment names to their parameters
    segs_to_params = rdc.map_segs_to_params(root, inserted_mechs)


    # Temporarily remove every inserted mechanism except 'Leak' from the
    # whole subtree before the cable properties are computed.
    for mech_name in inserted_mechs:
        if mech_name == 'Leak':
            continue
        for sec in root.subtree:
            sec.uninsert_mechanism(mech_name)

    # Disconnect
    root.disconnect_from_parent()

    # Calculate new properties of a reduced subtree
    new_cable_properties = rdc.get_unique_cable_properties(root._ref, reduction_frequency)
    new_nseg = rdc.calculate_nsegs(new_cable_properties, total_segments_manual)
    # NOTE(review): unconditional debug print; consider gating on self.verbose.
    print(new_cable_properties)


    # Map segment names to their new locations in the reduced cylinder
    segs_to_locs = rdc.map_segs_to_locs(root, reduction_frequency, new_cable_properties)


    # Reconnect
    root.connect_to_parent(parent)

    # Delete the original subtree below root (copy the list because
    # remove_subtree changes root.children).
    children = root.children[:]
    for child_sec in children:
        self.remove_subtree(child_sec)

    # Set passive mechanisms for the reduced cylinder:
    rdc.apply_params_to_section(root, new_cable_properties, new_nseg)


    # Reinsert the mechanisms removed above (all except 'Leak')
    for mech_name in inserted_mechs:
        if mech_name == 'Leak':
            continue
        for sec in root.subtree:
            sec.insert_mechanism(mech_name)

    # Replace locs with corresponding segs

    segs_to_reduced_segs = rdc.map_segs_to_reduced_segs(segs_to_locs, root)

    # Map reduced segments to lists of parameters of corresponding original segments
    reduced_segs_to_params = rdc.map_reduced_segs_to_params(segs_to_reduced_segs, segs_to_params)

    # Set new values of parameters
    rdc.set_avg_params_to_reduced_segs(reduced_segs_to_params)
    rdc.interpolate_missing_values(reduced_segs_to_params, root)

    # NOTE(review): returns None here, whereas the fitted path returns a dict.
    if not fit:
        return

    # Fit a polynomial per parameter along the reduced section.
    root_segs = [seg for seg in root.segments]
    params_to_coeffs = {}
    for param_name in self.params:
        coeffs = self.fit_distribution(param_name, segments=root_segs, plot=False)
        params_to_coeffs[param_name] = coeffs


    # Create new domain named reduced_<number of existing 'reduced*' domains>
    reduced_domains = [domain_name for domain_name in self.domains if domain_name.startswith('reduced')]
    new_reduced_domain_name = f'reduced_{len(reduced_domains)}'
    group_name = new_reduced_domain_name
    self.define_domain(new_reduced_domain_name, sections=[root], distribute=False)


    # Reinsert active mechanisms after creating the new domain
    # NOTE(review): non-Leak mechanisms were already reinserted above; this
    # inserts every mechanism (including 'Leak') on root again.
    for mech_name in inserted_mechs:
        root.insert_mechanism(mech_name)
    self.domains_to_mechs[new_reduced_domain_name] = set(inserted_mechs.keys())


    # Store the fitted distributions for the new group
    # NOTE(review): ``_set_distribution`` does not use its ``plot`` argument.
    for param_name, coeffs in params_to_coeffs.items():
        self._set_distribution(param_name, group_name, coeffs, plot=True)

    # Distribute parameters
    self.distribute_all()

    return {
        'domain_name': new_reduced_domain_name,
        'group_name': group_name,
        'segs_to_params': segs_to_params,
        'segs_to_locs': segs_to_locs,
        'segs_to_reduced_segs': segs_to_reduced_segs,
        'reduced_segs_to_params': reduced_segs_to_params,
        'params_to_coeffs': params_to_coeffs
    }
|
1489
|
+
|
1490
|
+
|
1491
|
+
def fit_distribution(self, param_name, segments, max_degree=6, tolerance=1e-7, plot=False):
    """
    Fit the lowest-degree polynomial that reproduces a parameter's values.

    The values of ``param_name`` on ``segments`` are paired with the
    segments' path distances and sorted by distance. They are then fitted
    with polynomials of increasing degree, from 0 up to ``max_degree``.
    The first fit whose absolute residuals are all below ``tolerance`` is
    returned. If none qualifies, a warning is issued and the
    ``max_degree`` fit is returned.

    Parameters
    ----------
    param_name : str
        The parameter to fit.
    segments : iterable of Segment
        The segments providing the data points.
    max_degree : int, optional
        The highest polynomial degree tried. Default is 6.
    tolerance : float, optional
        Maximum accepted absolute residual. Default is 1e-7.
    plot : bool, optional
        Whether to plot the parameter with the fitted curve (only for
        degree > 0). Default is False.

    Returns
    -------
    numpy.ndarray
        Polynomial coefficients, highest power first (``numpy.polyfit``
        convention).

    Raises
    ------
    ValueError
        If ``segments`` is empty or ``max_degree`` is negative.
    """
    import numpy as np

    segments = list(segments)
    if not segments:
        raise ValueError(f'Cannot fit a distribution for parameter {param_name}: no segments given.')
    if max_degree < 0:
        raise ValueError(f'max_degree must be non-negative, got {max_degree}.')

    values = [seg.get_param_value(param_name) for seg in segments]
    distances = [seg.path_distance() for seg in segments]
    sorted_pairs = sorted(zip(distances, values))
    distances = np.array([d for d, _ in sorted_pairs], dtype=float)
    values = np.array([v for _, v in sorted_pairs], dtype=float)

    for degree in range(max_degree + 1):
        coeffs = np.polyfit(distances, values, degree)
        residuals = np.abs(values - np.polyval(coeffs, distances))
        fitted = bool(np.all(residuals < tolerance))
        if fitted:
            break

    if not fitted:
        warnings.warn(f'Fitting failed for parameter {param_name} with the provided tolerance.\nUsing the last valid fit (degree={degree}). Maximum residual: {residuals.max()}')
    if plot and degree > 0:
        self.plot_param(param_name, show_nan=False)
        plt.plot(distances, np.polyval(coeffs, distances), label='Fitted', color='red', linestyle='--')
        plt.legend()
    return coeffs
|
1510
|
+
|
1511
|
+
|
1512
|
+
def _set_distribution(self, param_name, group_name, coeffs, plot=False):
    """
    Store polynomial fit coefficients as a parameter's distribution.

    Coefficients that round to zero at integer precision are kept as they
    are, so tiny terms are not lost. All other coefficients are rounded to
    10 decimals. One coefficient gives a ``'constant'`` distribution, two
    give ``'linear'`` (slope, intercept), and more give ``'polynomial'``.
    ``plot`` is accepted but not used.
    """
    integer_rounded = np.round(coeffs)
    coeffs = np.where(integer_rounded == 0, coeffs, np.round(coeffs, 10))

    n_coeffs = len(coeffs)
    if n_coeffs == 1:
        distribution = Distribution('constant', value=coeffs[0])
    elif n_coeffs == 2:
        distribution = Distribution('linear', slope=coeffs[0], intercept=coeffs[1])
    else:
        distribution = Distribution('polynomial', coeffs=coeffs)
    self.params[param_name][group_name] = distribution
|
1521
|
+
|
1522
|
+
|
1523
|
+
# ========================================================================
|
1524
|
+
# PLOTTING
|
1525
|
+
# ========================================================================
|
1526
|
+
|
1527
|
+
def plot_param(self, param_name, ax=None, show_nan=True):
    """
    Plot the distribution of a parameter in the model.

    Each segment is drawn at its path distance and colored by domain.
    Non-zero values are drawn as filled markers and zero values as hollow
    dots. When ``show_nan`` is True, NaN values are drawn as crosses at
    y=0 together with a dashed line at zero. All drawing happens on
    ``ax``.

    Parameters
    ----------
    param_name : str
        The name of the parameter to plot.
    ax : matplotlib.axes.Axes, optional
        The axes to plot on. A new figure is created when None.
    show_nan : bool, optional
        Whether to show NaN values. Default is True.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 2))

    if param_name not in self.params:
        warnings.warn(f'Parameter {param_name} not found.')

    valid_points, valid_colors = [], []
    zero_points, zero_colors = [], []
    nan_points, nan_colors = [], []
    # Single pass that classifies every segment into one of three buckets.
    for seg in self.seg_tree:
        x = seg.path_distance()
        y = seg.get_param_value(param_name)
        color = get_domain_color(seg.domain)
        if pd.isna(y):
            nan_points.append((x, 0))
            nan_colors.append(color)
        elif y == 0:
            zero_points.append((x, y))
            zero_colors.append(color)
        else:
            valid_points.append((x, y))
            valid_colors.append(color)

    if valid_points:
        ax.scatter(*zip(*valid_points), c=valid_colors)
    if zero_points:
        ax.scatter(*zip(*zero_points), edgecolors=zero_colors, facecolors='none', marker='.')
    if nan_points and show_nan:
        ax.scatter(*zip(*nan_points), c=nan_colors, marker='x', alpha=0.5, zorder=0)
        # Draw on the provided axes; pyplot's current axes may differ.
        ax.axhline(y=0, color='k', linestyle='--')

    ax.set_xlabel('Path distance')
    ax.set_ylabel(param_name)
    ax.set_title(f'{param_name} distribution')
|
1567
|
+
|
1568
|
+
|
1569
|
+
# ========================================================================
|
1570
|
+
# FILE EXPORT
|
1571
|
+
# ========================================================================
|
1572
|
+
|
1573
|
+
def export_morphology(self, file_name):
    """
    Write the model's point tree to an SWC file.

    Parameters
    ----------
    file_name : str
        Name of the output file. The path manager resolves the full path
        for the ``'morphology'`` category with the ``swc`` extension.
    """
    swc_path = self.path_manager.get_file_path('morphology', file_name, extension='swc')
    tree = self.point_tree
    tree.to_swc(swc_path)
|
1585
|
+
|
1586
|
+
|
1587
|
+
def to_dict(self):
    """
    Serialize the model's membrane configuration to a plain dictionary.

    Returns
    -------
    dict
        A mapping with the keys ``'metadata'`` (holding the model name),
        ``'d_lambda'``, ``'domains'`` (domain name -> list of mechanism
        names), ``'groups'`` (one dict per segment group) and ``'params'``
        (parameter name -> group name -> distribution). Distributions that
        are stored as strings are kept verbatim; all others are serialized
        through their own ``to_dict`` method.
    """
    def _serialize(distribution):
        # String distributions are already JSON-friendly.
        if isinstance(distribution, str):
            return distribution
        return distribution.to_dict()

    result = {}
    result['metadata'] = {'name': self.name}
    result['d_lambda'] = self.d_lambda

    domains = {}
    for domain, mechs in self.domains_to_mechs.items():
        domains[domain] = list(mechs)
    result['domains'] = domains

    result['groups'] = [group.to_dict() for group in self._groups]

    serialized_params = {}
    for param_name, by_group in self.params.items():
        serialized_params[param_name] = {
            group_name: _serialize(distribution)
            for group_name, distribution in by_group.items()
        }
    result['params'] = serialized_params

    return result
|
1613
|
+
|
1614
|
+
def from_dict(self, data):
    """
    Restore the model's membrane configuration from a dictionary.

    This is the inverse of ``to_dict``. The steps run in this order:
    check the model name, set ``d_lambda``, rebuild ``domains_to_mechs`` and
    insert every listed mechanism (with ``distribute=False``), rebuild the
    segment groups, rebuild the parameter distributions and, if a section
    tree exists, re-apply the segmentation with the loaded ``d_lambda``.

    Parameters
    ----------
    data : dict
        The dictionary representation of the model, as produced by
        ``to_dict``.

    Raises
    ------
    ValueError
        If ``data['metadata']['name']`` differs from the model's name.
    """
    if self.name != data['metadata']['name']:
        raise ValueError('Model name does not match the data.')

    self.d_lambda = data['d_lambda']

    # Domains and mechanisms: the mapping is assigned first and then
    # iterated through the attribute, so keep this ordering.
    self.domains_to_mechs = {
        domain_name: set(mech_names)
        for domain_name, mech_names in data['domains'].items()
    }
    if self.verbose:
        print('Inserting mechanisms...')
    for domain_name, mech_names in self.domains_to_mechs.items():
        for mech_name in mech_names:
            # Parameter distribution is skipped here (distribute=False).
            self.insert_mechanism(mech_name, domain_name, distribute=False)

    # Segment groups
    if self.verbose:
        print('Adding groups...')
    self._groups = [SegmentGroup.from_dict(entry) for entry in data['groups']]

    # Parameters: string entries are kept as-is, others are deserialized.
    if self.verbose:
        print('Distributing parameters...')
    restored_params = {}
    for param_name, by_group in data['params'].items():
        restored_params[param_name] = {
            group_name: (
                dist if isinstance(dist, str) else Distribution.from_dict(dist)
            )
            for group_name, dist in by_group.items()
        }
    self.params = restored_params

    if self.verbose:
        print('Setting segmentation...')
    if self.sec_tree is not None:
        self.set_segmentation(d_lambda=self.d_lambda)
|
1657
|
+
|
1658
|
+
|
1659
|
+
|
1660
|
+
def export_membrane(self, file_name, **kwargs):
    """
    Save the model's membrane configuration as JSON.

    The dictionary produced by ``to_dict`` is serialized to the path that
    the path manager resolves for the 'membrane' category with the 'json'
    extension.

    Parameters
    ----------
    file_name : str
        The name of the file to write to.
    **kwargs : dict
        Extra keyword arguments forwarded unchanged to `json.dump`
        (for example ``indent=4``).
    """
    target = self.path_manager.get_file_path(
        'membrane', file_name, extension='json'
    )
    payload = self.to_dict()
    with open(target, 'w') as handle:
        json.dump(payload, handle, **kwargs)
|
1677
|
+
|
1678
|
+
|
1679
|
+
def load_membrane(self, file_name, recompile=True):
    """
    Load the model's membrane configuration from a JSON file.

    The default mechanisms and the mechanisms from the 'mod' directory are
    registered first, because the loaded configuration refers to mechanisms
    by name. The JSON file is then read and applied through ``from_dict``.

    Parameters
    ----------
    file_name : str
        The name of the file to read from.
    recompile : bool, optional
        Whether to recompile the mechanisms after loading. Default is True.
    """
    self.add_default_mechanisms()
    self.add_mechanisms('mod', recompile=recompile)

    source = self.path_manager.get_file_path(
        'membrane', file_name, extension='json'
    )
    with open(source, 'r') as handle:
        config = json.load(handle)

    self.from_dict(config)
|
1699
|
+
|
1700
|
+
|
1701
|
+
def stimuli_to_dict(self):
    """
    Describe the model's stimuli as a JSON-serializable dictionary.

    Returns
    -------
    dict
        A mapping with ``'metadata'`` (the model name), ``'simulation'``
        (a shallow copy of ``self.simulator.to_dict()``) and ``'stimuli'``.
        ``'stimuli'`` holds ``'iclamps'`` (one entry per current clamp, named
        ``iclamp_<i>`` in iteration order, with its ``amp``, ``delay`` and
        ``dur``) and ``'populations'`` (synapse type -> list of population
        dictionaries).
    """
    metadata = {'name': self.name}
    simulation = dict(self.simulator.to_dict())

    iclamp_entries = []
    for index, clamp in enumerate(self.iclamps.values()):
        iclamp_entries.append({
            'name': f'iclamp_{index}',
            'amp': clamp.amp,
            'delay': clamp.delay,
            'dur': clamp.dur,
        })

    populations = {}
    for syn_type, pops in self.populations.items():
        populations[syn_type] = [population.to_dict() for population in pops.values()]

    return {
        'metadata': metadata,
        'simulation': simulation,
        'stimuli': {
            'iclamps': iclamp_entries,
            'populations': populations,
        },
    }
|
1733
|
+
|
1734
|
+
|
1735
|
+
def _stimuli_to_csv(self, path_to_csv=None):
|
1736
|
+
"""
|
1737
|
+
Write the model to a CSV file.
|
1738
|
+
|
1739
|
+
Parameters
|
1740
|
+
----------
|
1741
|
+
path_to_csv : str
|
1742
|
+
The path to the CSV file to write.
|
1743
|
+
"""
|
1744
|
+
|
1745
|
+
rec_data = {
|
1746
|
+
'type': ['recording'] * len(self.recordings),
|
1747
|
+
'idx': [i for i in range(len(self.recordings))],
|
1748
|
+
'sec_idx': [seg._section.idx for seg in self.recordings],
|
1749
|
+
'loc': [seg.x for seg in self.recordings],
|
1750
|
+
}
|
1751
|
+
|
1752
|
+
iclamp_data = {
|
1753
|
+
'type': ['iclamp'] * len(self.iclamps),
|
1754
|
+
'idx': [i for i in range(len(self.iclamps))],
|
1755
|
+
'sec_idx': [seg._section.idx for seg in self.iclamps],
|
1756
|
+
'loc': [seg.x for seg in self.iclamps],
|
1757
|
+
}
|
1758
|
+
|
1759
|
+
synapses_data = {
|
1760
|
+
'type': [],
|
1761
|
+
'idx': [],
|
1762
|
+
'sec_idx': [],
|
1763
|
+
'loc': [],
|
1764
|
+
}
|
1765
|
+
|
1766
|
+
for syn_type, pops in self.populations.items():
|
1767
|
+
for pop_name, pop in pops.items():
|
1768
|
+
pop_data = pop.to_csv()
|
1769
|
+
synapses_data['type'] += pop_data['syn_type']
|
1770
|
+
synapses_data['idx'] += [int(name.rsplit('_', 1)[1]) for name in pop_data['name']]
|
1771
|
+
synapses_data['sec_idx'] += pop_data['sec_idx']
|
1772
|
+
synapses_data['loc'] += pop_data['loc']
|
1773
|
+
|
1774
|
+
df = pd.concat([
|
1775
|
+
pd.DataFrame(rec_data),
|
1776
|
+
pd.DataFrame(iclamp_data),
|
1777
|
+
pd.DataFrame(synapses_data)
|
1778
|
+
], ignore_index=True)
|
1779
|
+
df['sec_idx'] = df['sec_idx'].astype(int)
|
1780
|
+
if path_to_csv: df.to_csv(path_to_csv, index=False)
|
1781
|
+
|
1782
|
+
return df
|
1783
|
+
|
1784
|
+
|
1785
|
+
def export_stimuli(self, file_name, **kwargs):
    """
    Save the model's stimuli as a JSON file plus a companion CSV file.

    The JSON file holds the output of ``stimuli_to_dict``; the CSV file is
    written by ``_stimuli_to_csv``. Both paths are resolved from
    ``file_name`` under the 'stimuli' category, one with the 'json'
    extension and one with 'csv'.

    Parameters
    ----------
    file_name : str
        The name of the file to write to.
    **kwargs : dict
        Extra keyword arguments forwarded unchanged to `json.dump`; they do
        not affect the CSV file.
    """
    def resolve(extension):
        return self.path_manager.get_file_path('stimuli', file_name, extension=extension)

    json_target = resolve('json')
    stimuli = self.stimuli_to_dict()
    with open(json_target, 'w') as handle:
        json.dump(stimuli, handle, **kwargs)

    self._stimuli_to_csv(resolve('csv'))
|
1805
|
+
|
1806
|
+
|
1807
|
+
def load_stimuli(self, file_name):
    """
    Load recordings, current clamps and synaptic populations from disk.

    Reads the JSON file written by ``export_stimuli`` (metadata, simulation
    settings, clamp and population parameters) together with the companion
    CSV file (one row per recording, clamp and synapse location), then
    re-creates every item on the model.

    Parameters
    ----------
    file_name : str
        The base name of the stimuli files; the JSON and CSV paths are
        resolved by the path manager under the 'stimuli' category.

    Raises
    ------
    ValueError
        If the model name stored in the JSON file differs from this model's.
    """
    path_to_json = self.path_manager.get_file_path('stimuli', file_name, extension='json')
    path_to_stimuli_csv = self.path_manager.get_file_path('stimuli', file_name, extension='csv')

    with open(path_to_json, 'r') as f:
        data = json.load(f)

    if self.name != data['metadata']['name']:
        raise ValueError('Model name does not match the data.')

    df_stimuli = pd.read_csv(path_to_stimuli_csv)

    self.simulator.from_dict(data['simulation'])
    stimuli = data['stimuli']

    # Recordings ---------------------------------------------------------

    df_recs = df_stimuli[df_stimuli['type'] == 'recording']
    for _, row in df_recs.iterrows():
        self.add_recording(
            self.sec_tree.sections[int(row['sec_idx'])], row['loc']
        )

    # IClamps ------------------------------------------------------------

    # CSV rows and JSON entries are matched by their order of appearance.
    df_iclamps = df_stimuli[df_stimuli['type'] == 'iclamp'].reset_index(drop=True)
    iclamp_params = stimuli.get('iclamps', [])
    for i, row in df_iclamps.iterrows():
        params = iclamp_params[i]
        self.add_iclamp(
            self.sec_tree.sections[int(row['sec_idx'])],
            row['loc'],
            params['amp'],
            params['delay'],
            params['dur'],
        )

    # Populations --------------------------------------------------------

    populations_data = stimuli.get('populations', {})
    for syn_type in ('AMPA', 'NMDA', 'AMPA_NMDA', 'GABAa'):

        df_syn = df_stimuli[df_stimuli['type'] == syn_type]

        # A synapse type missing from the file simply has no populations.
        for i, pop_data in enumerate(populations_data.get(syn_type, [])):

            # CSV `idx` holds the population index within its synapse type.
            df_pop = df_syn[df_syn['idx'] == i]

            syn_locs = [
                (self.sec_tree.sections[int(sec_idx)], loc)
                for sec_idx, loc in zip(df_pop['sec_idx'].tolist(), df_pop['loc'].tolist())
            ]
            segments = [sec(loc) for sec, loc in syn_locs]

            # Rebuild with the population's own synapse type, not a fixed one.
            pop = Population(i, segments, pop_data['N'], syn_type)

            pop.allocate_synapses(syn_locs=syn_locs)
            pop.update_kinetic_params(**pop_data['kinetic_params'])
            pop.update_input_params(**pop_data['input_params'])
            self._add_population(pop)
|
1877
|
+
|
1878
|
+
|
1879
|
+
def export_to_NEURON(self, file_name, include_kinetic_params=True):
    """
    Export the model to a standalone Python script that builds it in NEURON.

    The script is produced by rendering the 'NEURON_template' Python
    template (resolved under the 'templates' category of the path manager)
    with the model's parameters, groups, mechanisms, current clamps and
    recordings. It is written into ``self.path_manager.path_to_model``.

    Parameters
    ----------
    file_name : str
        The name of the file to write to. A ``.py`` extension is appended
        if it is missing.
    include_kinetic_params : bool, optional
        If True (default), every entry of ``self.params`` is passed to the
        template. If False, the parameters returned by
        ``model_io.filter_params(self)`` are used instead (presumably with
        kinetic parameters removed; see that function).
    """
    # Local import, as in the original (presumably to avoid an import cycle).
    from dendrotweaks.model_io import (
        filter_params,
        get_neuron_domain,
        get_params_to_valid_domains,
        render_template,
    )

    params_to_valid_domains = get_params_to_valid_domains(self)
    params = self.params if include_kinetic_params else filter_params(self)
    path_to_template = self.path_manager.get_file_path(
        'templates', 'NEURON_template', extension='py'
    )

    output = render_template(path_to_template, {
        'param_dict': params,
        'groups_dict': self.groups,
        'params_to_mechs': self.params_to_mechs,
        'domains_to_mechs': self.domains_to_mechs,
        'iclamps': self.iclamps,
        'recordings': self.simulator.recordings,
        'params_to_valid_domains': params_to_valid_domains,
        'domains_to_NEURON': {domain: get_neuron_domain(domain) for domain in self.domains},
    })

    if not file_name.endswith('.py'):
        file_name += '.py'
    output_path = os.path.join(self.path_manager.path_to_model, file_name)
    with open(output_path, 'w') as f:
        f.write(output)
|
1915
|
+
|
1916
|
+
|