dendrotweaks 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. dendrotweaks/__init__.py +10 -0
  2. dendrotweaks/analysis/__init__.py +11 -0
  3. dendrotweaks/analysis/ephys_analysis.py +482 -0
  4. dendrotweaks/analysis/morphometric_analysis.py +106 -0
  5. dendrotweaks/membrane/__init__.py +6 -0
  6. dendrotweaks/membrane/default_mod/AMPA.mod +65 -0
  7. dendrotweaks/membrane/default_mod/AMPA_NMDA.mod +100 -0
  8. dendrotweaks/membrane/default_mod/CaDyn.mod +54 -0
  9. dendrotweaks/membrane/default_mod/GABAa.mod +65 -0
  10. dendrotweaks/membrane/default_mod/Leak.mod +27 -0
  11. dendrotweaks/membrane/default_mod/NMDA.mod +72 -0
  12. dendrotweaks/membrane/default_mod/vecstim.mod +76 -0
  13. dendrotweaks/membrane/default_templates/NEURON_template.py +354 -0
  14. dendrotweaks/membrane/default_templates/default.py +73 -0
  15. dendrotweaks/membrane/default_templates/standard_channel.mod +87 -0
  16. dendrotweaks/membrane/default_templates/template_jaxley.py +108 -0
  17. dendrotweaks/membrane/default_templates/template_jaxley_new.py +108 -0
  18. dendrotweaks/membrane/distributions.py +324 -0
  19. dendrotweaks/membrane/groups.py +103 -0
  20. dendrotweaks/membrane/io/__init__.py +11 -0
  21. dendrotweaks/membrane/io/ast.py +201 -0
  22. dendrotweaks/membrane/io/code_generators.py +312 -0
  23. dendrotweaks/membrane/io/converter.py +108 -0
  24. dendrotweaks/membrane/io/factories.py +144 -0
  25. dendrotweaks/membrane/io/grammar.py +417 -0
  26. dendrotweaks/membrane/io/loader.py +90 -0
  27. dendrotweaks/membrane/io/parser.py +499 -0
  28. dendrotweaks/membrane/io/reader.py +212 -0
  29. dendrotweaks/membrane/mechanisms.py +574 -0
  30. dendrotweaks/model.py +1916 -0
  31. dendrotweaks/model_io.py +75 -0
  32. dendrotweaks/morphology/__init__.py +5 -0
  33. dendrotweaks/morphology/domains.py +100 -0
  34. dendrotweaks/morphology/io/__init__.py +5 -0
  35. dendrotweaks/morphology/io/factories.py +212 -0
  36. dendrotweaks/morphology/io/reader.py +66 -0
  37. dendrotweaks/morphology/io/validation.py +212 -0
  38. dendrotweaks/morphology/point_trees.py +681 -0
  39. dendrotweaks/morphology/reduce/__init__.py +16 -0
  40. dendrotweaks/morphology/reduce/reduce.py +155 -0
  41. dendrotweaks/morphology/reduce/reduced_cylinder.py +129 -0
  42. dendrotweaks/morphology/sec_trees.py +1112 -0
  43. dendrotweaks/morphology/seg_trees.py +157 -0
  44. dendrotweaks/morphology/trees.py +567 -0
  45. dendrotweaks/path_manager.py +261 -0
  46. dendrotweaks/simulators.py +235 -0
  47. dendrotweaks/stimuli/__init__.py +3 -0
  48. dendrotweaks/stimuli/iclamps.py +73 -0
  49. dendrotweaks/stimuli/populations.py +265 -0
  50. dendrotweaks/stimuli/synapses.py +203 -0
  51. dendrotweaks/utils.py +239 -0
  52. dendrotweaks-0.3.1.dist-info/METADATA +70 -0
  53. dendrotweaks-0.3.1.dist-info/RECORD +56 -0
  54. dendrotweaks-0.3.1.dist-info/WHEEL +5 -0
  55. dendrotweaks-0.3.1.dist-info/licenses/LICENSE +674 -0
  56. dendrotweaks-0.3.1.dist-info/top_level.txt +1 -0
dendrotweaks/model.py ADDED
@@ -0,0 +1,1916 @@
1
+ from typing import List, Union, Callable
2
+ import os
3
+ import json
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import quantities as pq
7
+
8
+ from dendrotweaks.morphology.point_trees import PointTree
9
+ from dendrotweaks.morphology.sec_trees import Section, SectionTree, Domain
10
+ from dendrotweaks.morphology.seg_trees import Segment, SegmentTree
11
+ from dendrotweaks.simulators import NEURONSimulator
12
+ from dendrotweaks.membrane.groups import SegmentGroup
13
+ from dendrotweaks.membrane.mechanisms import Mechanism, LeakChannel, CaDynamics
14
+ from dendrotweaks.membrane.io import create_channel, standardize_channel, create_standard_channel
15
+ from dendrotweaks.membrane.io import MODFileLoader
16
+ from dendrotweaks.morphology.io import create_point_tree, create_section_tree, create_segment_tree
17
+ from dendrotweaks.stimuli.iclamps import IClamp
18
+ from dendrotweaks.membrane.distributions import Distribution
19
+ from dendrotweaks.stimuli.populations import Population
20
+ from dendrotweaks.utils import calculate_lambda_f, dynamic_import
21
+ from dendrotweaks.utils import get_domain_color, timeit
22
+
23
+ from collections import OrderedDict, defaultdict
24
+ from numpy import nan
25
+ # from .logger import logger
26
+
27
+ from dendrotweaks.path_manager import PathManager
28
+ import dendrotweaks.morphology.reduce as rdc
29
+
30
+ import pandas as pd
31
+
32
+ import warnings
33
+
34
# Template of synapse-population containers, one empty dict per synapse
# type. This is a shared module-level object: consumers should build their
# own copy rather than mutate it in place.
POPULATIONS = {
    syn_type: {} for syn_type in ('AMPA', 'NMDA', 'AMPA_NMDA', 'GABAa')
}
35
+
36
def custom_warning_formatter(message, category, filename, lineno, file=None, line=None):
    """
    Render a warning as two lines: the message, then its source location.

    The signature matches ``warnings.formatwarning``. ``category``, ``file``
    and ``line`` are accepted for compatibility but not used.
    """
    location = f"({os.path.basename(filename)}, line {lineno})"
    return f"WARNING: {message}\n{location}\n"


# Install the formatter process-wide for every warning that gets displayed.
warnings.formatwarning = custom_warning_formatter
40
+
41
# Reference values for parameters not owned by any mechanism.
# NOTE(review): Model.__init__ initialises Ra to 35.4, not the 100 listed
# here — confirm which default is intended.
INDEPENDENT_PARAMS = dict(
    cm=1,      # uF/cm2
    Ra=100,    # Ohm cm
    ena=50,    # mV
    ek=-77,    # mV
    eca=140,   # mV
)

# Default segment-group name for each conventional SWC domain name; domains
# not listed here keep their own name as the group name.
DOMAIN_TO_GROUP = dict(
    soma='somatic',
    axon='axonal',
    dend='dendritic',
    apic='apical',
)
55
+
56
+
57
+ class Model():
58
+ """
59
+ A model object that represents a neuron model.
60
+
61
+ Parameters
62
+ ----------
63
+ name : str
64
+ The name of the model.
65
+ simulator_name : str
66
+ The name of the simulator to use (either 'NEURON' or 'Jaxley').
67
+ path_to_data : str
68
+ The path to the data files where swc and mod files are stored.
69
+
70
+ Attributes
71
+ ----------
72
+ path_to_model : str
73
+ The path to the model directory.
74
+ path_manager : PathManager
75
+ The path manager for the model.
76
+ mod_loader : MODFileLoader
77
+ The MOD file loader.
78
+ simulator_name : str
79
+ The name of the simulator to use. Default is 'NEURON'.
80
+ point_tree : PointTree
81
+ The point tree representing the morphological reconstruction.
82
+ sec_tree : SectionTree
83
+ The section tree representing the morphology on the section level.
84
+ mechanisms : dict
85
+ A dictionary of mechanisms available for the model.
86
+ domains_to_mechs : dict
87
+ A dictionary mapping domains to mechanisms inserted in them.
88
+ params : dict
89
+ A dictionary mapping parameters to their distributions.
90
+ d_lambda : float
91
+ The spatial discretization parameter.
92
+ seg_tree : SegmentTree
93
+ The segment tree representing the morphology on the segment level.
94
+ iclamps : dict
95
+ A dictionary of current clamps in the model.
96
+ populations : dict
97
+ A dictionary of "virtual" populations forming synapses on the model.
98
+ simulator : Simulator
99
+ The simulator object to use.
100
+ """
101
+
102
def __init__(self, path_to_model,
             simulator_name='NEURON',) -> None:
    """
    Create an empty model bound to a model directory.

    Parameters
    ----------
    path_to_model : str
        Path to the model directory. Its base name is used as the model
        name, and the path is handed to the PathManager.
    simulator_name : str, optional
        The simulator to use, either 'NEURON' (default) or 'Jaxley'.

    Raises
    ------
    ValueError
        If ``simulator_name`` is not recognized.
    NotImplementedError
        If 'Jaxley' is requested but ``JaxleySimulator`` cannot be
        imported from ``dendrotweaks.simulators``.
    """
    # Metadata
    self.path_to_model = path_to_model
    self._name = os.path.basename(os.path.normpath(path_to_model))
    self.morphology_name = ''
    self.version = ''
    self.path_manager = PathManager(path_to_model)
    self.simulator_name = simulator_name
    self._verbose = False

    # File managers
    self.mod_loader = MODFileLoader()

    # Morphology
    self.point_tree = None
    self.sec_tree = None

    # Mechanisms
    self.mechanisms = {}
    self.domains_to_mechs = {}

    # Parameters: parameter name -> {group name: Distribution}
    self.params = {
        'cm': {'all': Distribution('constant', value=1)},  # uF/cm2
        'Ra': {'all': Distribution('constant', value=35.4)},  # Ohm cm
    }

    self.params_to_units = {
        'cm': pq.uF/pq.cm**2,
        'Ra': pq.ohm*pq.cm,
    }

    # Groups
    self._groups = []

    # Segmentation
    self.d_lambda = 0.1
    self.seg_tree = None

    # Stimuli. Build fresh containers for each model: assigning the
    # module-level POPULATIONS dict directly would share (and leak)
    # populations between every Model instance.
    self.iclamps = {}
    self.populations = {syn_type: {} for syn_type in POPULATIONS}

    # Simulator
    if simulator_name == 'NEURON':
        self.simulator = NEURONSimulator()
    elif simulator_name == 'Jaxley':
        # JaxleySimulator is not imported at module level. Resolve it lazily
        # so a missing backend produces a clear error instead of NameError.
        try:
            from dendrotweaks.simulators import JaxleySimulator
        except ImportError as err:
            raise NotImplementedError(
                'Jaxley simulator is not available.') from err
        self.simulator = JaxleySimulator()
    else:
        raise ValueError(
            'Simulator name not recognized. Use NEURON or Jaxley.')
158
+
159
+
160
+ # -----------------------------------------------------------------------
161
+ # PROPERTIES
162
+ # -----------------------------------------------------------------------
163
+
164
@property
def name(self):
    """
    Model name: the base name of the model directory given at construction.
    """
    model_name = self._name
    return model_name
170
+
171
@property
def verbose(self):
    """
    Flag enabling progress messages (e.g. when building sections or
    adding groups).
    """
    flag = self._verbose
    return flag
177
+
178
@verbose.setter
def verbose(self, value):
    # Store the flag and propagate it to the MOD file loader so that both
    # report (or stay silent) consistently.
    self._verbose = value
    self.mod_loader.verbose = value
182
+
183
+
184
@property
def domains(self):
    """
    Domains of the model keyed by domain name.

    This is the section tree's own ``domains`` mapping (not a copy), so
    adding or removing entries here also changes the section tree.
    """
    tree = self.sec_tree
    return tree.domains
191
+
192
+
193
@property
def recordings(self):
    """
    Recordings configured on the simulator (the simulator's own object,
    not a copy).
    """
    sim = self.simulator
    return sim.recordings
199
+
200
+
201
@recordings.setter
def recordings(self, recordings):
    # Assigning on the model replaces the simulator's recordings object.
    simulator = self.simulator
    simulator.recordings = recordings
204
+
205
+
206
@property
def groups(self):
    """
    Segment groups of the model keyed by group name.

    A new dict is built from ``self._groups`` on every access; if two
    groups share a name, the later one wins.
    """
    by_name = {}
    for group in self._groups:
        by_name[group.name] = group
    return by_name
212
+
213
+
214
@property
def groups_to_parameters(self):
    """
    Map each segment group name to the parameters of its mechanisms.

    For every group, the parameters of all mechanisms listed in
    ``group.mechanisms`` are collected in ``mechs_to_params`` order.
    A group with no matching mechanism maps to an empty list.
    """
    # mechs_to_params is a computed property; evaluate it only once.
    mechs_to_params = self.mechs_to_params
    groups_to_parameters = {}
    for group in self._groups:
        params = []
        for mech_name, mech_params in mechs_to_params.items():
            if mech_name in group.mechanisms:
                # Accumulate: a group may hold several mechanisms.
                params.extend(mech_params)
        groups_to_parameters[group.name] = params
    return groups_to_parameters
227
+
228
@property
def mechs_to_domains(self):
    """
    Map each inserted mechanism name to the set of domains that contain it.

    This inverts ``domains_to_mechs``. Mechanisms not inserted anywhere
    are absent from the result.
    """
    inverse = {}
    for domain_name, mech_names in self.domains_to_mechs.items():
        for mech_name in mech_names:
            inverse.setdefault(mech_name, set()).add(domain_name)
    return inverse
238
+
239
+
240
@property
def parameters_to_groups(self):
    """
    Map each parameter name to the names of the groups whose mechanisms
    own it.

    Group names are listed in ``self._groups`` order. Parameters of
    mechanisms that belong to no group are absent.
    """
    mechs_to_params = self.mechs_to_params
    mapping = {}
    for group in self._groups:
        owned = [params for mech, params in mechs_to_params.items()
                 if mech in group.mechanisms]
        for params in owned:
            for param in params:
                mapping.setdefault(param, []).append(group.name)
    return mapping
253
+
254
+
255
@property
def params_to_mechs(self):
    """
    Map each parameter name to the mechanism it belongs to.

    A parameter belongs to mechanism ``m`` when its name ends with ``_m``.
    Longer mechanism names are tried first so that the most specific
    suffix wins. Parameters matching no mechanism map to "Independent".
    """
    candidates = sorted(self.mechanisms, key=len, reverse=True)
    mapping = {}
    for param in self.params:
        mapping[param] = next(
            (mech for mech in candidates if param.endswith(f"_{mech}")),
            "Independent",
        )
    return mapping
274
+
275
+
276
@property
def mechs_to_params(self):
    """
    Map each mechanism name (including the "Independent" pseudo-mechanism)
    to the list of its parameter names, inverting ``params_to_mechs``.
    """
    grouped = {}
    for param, mech_name in self.params_to_mechs.items():
        grouped.setdefault(mech_name, []).append(param)
    return grouped
285
+
286
+
287
@property
def conductances(self):
    """
    Subset of ``params`` whose names start with 'gbar', i.e. the
    conductance parameters and their distributions.
    """
    gbar_params = {}
    for param_name, distributions in self.params.items():
        if param_name.startswith('gbar'):
            gbar_params[param_name] = distributions
    return gbar_params
294
+ # -----------------------------------------------------------------------
295
+ # METADATA
296
+ # -----------------------------------------------------------------------
297
+
298
def info(self):
    """
    Print a short summary of the model.

    The summary shows the model name, data path, simulator, number of
    segment groups, available and inserted mechanisms, and current clamps.
    """
    # Parameters not owned by any mechanism are grouped under the
    # 'Independent' pseudo-mechanism. Exclude it explicitly instead of
    # assuming it is always present.
    n_inserted = sum(1 for mech_name in self.mechs_to_params
                     if mech_name != 'Independent')
    info_str = (
        f"Model: {self.name}\n"
        f"Path to data: {self.path_manager.path_to_data}\n"
        f"Simulator: {self.simulator_name}\n"
        f"Groups: {len(self.groups)}\n"
        f"Available mechanisms: {len(self.mechanisms)}\n"
        f"Inserted mechanisms: {n_inserted}\n"
        f"IClamps: {len(self.iclamps)}\n"
    )
    print(info_str)
313
+
314
+
315
@property
def df_params(self):
    """
    All parameter distributions as a pandas DataFrame.

    There is one row per (parameter, group) pair, with the columns
    'Mechanism', 'Parameter', 'Group', 'Distribution' and
    'Distribution params'. A distribution stored as a plain string is
    shown as-is, with empty parameters.
    """
    rows = []
    for mech_name, param_names in self.mechs_to_params.items():
        for param_name in param_names:
            for group_name, distribution in self.params[param_name].items():
                if isinstance(distribution, str):
                    function_name, dist_params = distribution, {}
                else:
                    function_name = distribution.function_name
                    dist_params = distribution.parameters
                rows.append({
                    'Mechanism': mech_name,
                    'Parameter': param_name,
                    'Group': group_name,
                    'Distribution': function_name,
                    'Distribution params': dist_params,
                })
    return pd.DataFrame(rows)
333
+
334
def print_directory_tree(self, *args, **kwargs):
    """
    Print the model directory tree.

    All arguments are forwarded unchanged to the path manager's
    ``print_directory_tree``, and its return value is passed back.
    """
    path_manager = self.path_manager
    return path_manager.print_directory_tree(*args, **kwargs)
339
+
340
def list_morphologies(self, extension='swc'):
    """
    List the morphology files of the model.

    Parameters
    ----------
    extension : str, optional
        File extension to filter by. Default is 'swc'.
    """
    list_files = self.path_manager.list_files
    return list_files('morphology', extension=extension)
345
+
346
def list_membrane_configs(self, extension='json'):
    """
    List the membrane configuration files of the model.

    Parameters
    ----------
    extension : str, optional
        File extension to filter by. Default is 'json'.
    """
    list_files = self.path_manager.list_files
    return list_files('membrane', extension=extension)
351
+
352
def list_mechanisms(self, extension='mod'):
    """
    List the mechanism (MOD) files of the model.

    Parameters
    ----------
    extension : str, optional
        File extension to filter by. Default is 'mod'.
    """
    list_files = self.path_manager.list_files
    return list_files('mod', extension=extension)
357
+
358
def list_stimuli_configs(self, extension='json'):
    """
    List the stimuli configuration files of the model.

    Parameters
    ----------
    extension : str, optional
        File extension to filter by. Default is 'json'.
    """
    list_files = self.path_manager.list_files
    return list_files('stimuli', extension=extension)
363
+
364
+ # ========================================================================
365
+ # MORPHOLOGY
366
+ # ========================================================================
367
+
368
def load_morphology(self, file_name, soma_notation='3PS',
                    align=True, sort_children=True, force=False) -> None:
    """
    Read an SWC file and build the point, section and segment trees.

    After the trees are built, sections are created in the simulator,
    default segment groups are added, recorded mechanisms are re-inserted
    and the segmentation is recomputed with the current ``d_lambda``.

    Parameters
    ----------
    file_name : str
        The name of the SWC file to read. A trailing '.swc' is stripped to
        obtain ``morphology_name``.
    soma_notation : str, optional
        The notation of the soma in the SWC file. Can be '3PS' (three-point soma) or '1PS'. Default is '3PS'.
    align : bool, optional
        Whether to align the morphology to the soma center and align the apical dendrite (if present).
    sort_children : bool, optional
        Whether to sort the children of each node by increasing subtree size
        in the tree sorting algorithms. If True, the traversal visits
        children with shorter subtrees first and assigns them lower indices. If False, children
        are visited in their original SWC file order (matching NEURON's behavior).
    force : bool, optional
        Forwarded to the ``sort`` calls of the point and section trees.
        Default is False.
    """
    # Strip only a trailing extension; str.replace would also remove a
    # '.swc' occurring in the middle of the name.
    suffix = '.swc'
    if file_name.endswith(suffix):
        self.morphology_name = file_name[:-len(suffix)]
    else:
        self.morphology_name = file_name

    path_to_swc_file = self.path_manager.get_file_path('morphology', file_name, extension='swc')
    point_tree = create_point_tree(path_to_swc_file)
    point_tree.change_soma_notation(soma_notation)
    point_tree.sort(sort_children=sort_children, force=force)
    if align:
        point_tree.shift_coordinates_to_soma_center()
        point_tree.align_apical_dendrite()
        point_tree.round_coordinates(8)
    self.point_tree = point_tree

    sec_tree = create_section_tree(point_tree)
    sec_tree.sort(sort_children=sort_children, force=force)
    self.sec_tree = sec_tree

    self.create_and_reference_sections_in_simulator()
    self.seg_tree = create_segment_tree(sec_tree)

    self._add_default_segment_groups()
    self._initialize_domains_to_mechs()

    self.set_segmentation(d_lambda=self.d_lambda)
413
+
414
+
415
def create_and_reference_sections_in_simulator(self):
    """
    Create every section in the active simulator and keep references.

    Calls ``create_and_reference(simulator_name)`` on each section of the
    section tree. When verbose, the method then reports how many sections
    hold a non-None ``_ref``.
    """
    sim_name = self.simulator_name
    if self.verbose:
        print(f'Building sections in {sim_name}...')
    for section in self.sec_tree.sections:
        section.create_and_reference(sim_name)
    n_sec = sum(1 for section in self.sec_tree.sections
                if section._ref is not None)
    if self.verbose:
        print(f'{n_sec} sections created.')
425
+
426
+
427
+
428
+
429
def _add_default_segment_groups(self):
    """
    Create the default segment groups: 'all', spanning every domain, plus
    one group per domain. Each per-domain group is named via
    DOMAIN_TO_GROUP, falling back to the domain name itself.
    """
    self.add_group('all', list(self.domains.keys()))
    for domain_name in self.domains:
        self.add_group(DOMAIN_TO_GROUP.get(domain_name, domain_name),
                       [domain_name])
434
+
435
+
436
def _initialize_domains_to_mechs(self):
    """
    Ensure every current domain has an entry in ``domains_to_mechs`` and
    (re)insert all recorded mechanisms into their domains.

    Existing entries, e.g. kept from a previously loaded morphology, are
    preserved, so their mechanisms are inserted into the new sections.
    NOTE(review): an entry for a domain absent from the new morphology
    reaches ``insert_mechanism``, whose ``self.domains`` lookup would raise
    KeyError. TODO: check that domains match.
    """
    for domain_name in self.domains:
        self.domains_to_mechs.setdefault(domain_name, set())
    for domain_name, mech_names in self.domains_to_mechs.items():
        for mech_name in mech_names:
            self.insert_mechanism(mech_name, domain_name)
445
+
446
+
447
def get_sections(self, filter_function):
    """
    Return the sections for which ``filter_function(section)`` is truthy.

    Parameters
    ----------
    filter_function : Callable
        Predicate applied to every section, in section-tree order.
    """
    return list(filter(filter_function, self.sec_tree.sections))
456
+
457
+
458
def get_segments(self, group_names=None):
    """
    Get the segments belonging to the specified groups.

    Parameters
    ----------
    group_names : List[str], optional
        Names of the groups to collect segments from. If None (default),
        every segment of the segment tree is returned.

    Returns
    -------
    list
        For each group, in the given order, the segments of the segment
        tree contained in that group. A segment that belongs to several
        requested groups appears once per group.

    Raises
    ------
    ValueError
        If ``group_names`` is neither None nor a list.
    KeyError
        If a group name is unknown.
    """
    if group_names is None:
        # Previously the documented default always raised; treat it as
        # "no filtering".
        return list(self.seg_tree.segments)
    if not isinstance(group_names, list):
        raise ValueError('Group names must be a list.')
    # Resolve the groups once: ``self.groups`` rebuilds its dict on every
    # access, and the old code looked it up for every (group, segment) pair.
    groups = self.groups
    selected = [groups[name] for name in group_names]
    segments = self.seg_tree.segments
    return [seg for group in selected for seg in segments if seg in group]
470
+
471
+ # ========================================================================
472
+ # SEGMENTATION
473
+ # ========================================================================
474
+
475
def set_segmentation(self, d_lambda=0.1, f=100):
    """
    Set the number of segments of every section from its geometry.

    For each section, ``lambda_f`` is obtained from ``calculate_lambda_f``
    (presumably the AC length constant at frequency ``f``). ``nseg`` is set
    to the odd number ``int((L / (d_lambda * lambda_f) + 0.9) / 2) * 2 + 1``,
    as in NEURON's d_lambda rule. The segment tree is then rebuilt and all
    parameters are redistributed.

    Parameters
    ----------
    d_lambda : float
        Fraction of ``lambda_f`` used as the target segment length.
    f : float
        The frequency passed to ``calculate_lambda_f``.
    """
    self.d_lambda = d_lambda

    # cm and Ra must be on the sections before lambda_f can be computed.
    for param_name in ('cm', 'Ra'):
        self.distribute(param_name)

    for sec in self.sec_tree.sections:
        lambda_f = calculate_lambda_f(sec.distances, sec.diameters, sec.Ra, sec.cm, f)
        n_lambda = sec.L / (d_lambda * lambda_f)
        # TODO: Set sec._nseg instead
        sec._ref.nseg = int((n_lambda + 0.9) / 2) * 2 + 1

    # Segments changed, so the segment tree and distributions are stale.
    self.seg_tree = create_segment_tree(self.sec_tree)
    self.distribute_all()
503
+
504
+
505
+ # ========================================================================
506
+ # MECHANISMS
507
+ # ========================================================================
508
+
509
def add_default_mechanisms(self, recompile=False):
    """
    Register the built-in mechanisms and load the bundled MOD files.

    A LeakChannel and a CaDynamics object are added to ``self.mechanisms``.
    Then every MOD file in the 'default_mod' directory is loaded.

    Parameters
    ----------
    recompile : bool, optional
        Whether to recompile the mechanisms. Default is False.
    """
    for mech_class in (LeakChannel, CaDynamics):
        mech = mech_class()
        self.mechanisms[mech.name] = mech
    self.load_mechanisms('default_mod', recompile=recompile)
525
+
526
+
527
def add_mechanisms(self, dir_name:str = 'mod', recompile=True) -> None:
    """
    Create, register and load a Mechanism for every MOD file in a directory.

    Parameters
    ----------
    dir_name : str, optional
        The directory to read MOD files from. Default is 'mod'.
    recompile : bool, optional
        Whether to recompile the mechanisms. Default is True.
    """
    mod_names = self.path_manager.list_files(dir_name, extension='mod')
    for mod_name in mod_names:
        self.add_mechanism(mod_name, load=True, dir_name=dir_name,
                           recompile=recompile)
544
+
545
+
546
+
547
def add_mechanism(self, mechanism_name: str,
                  python_template_name: str = 'default',
                  load=True, dir_name: str = 'mod', recompile=True
                  ) -> None:
    """
    Create a Mechanism object from a MOD file and register it on the model.

    Parameters
    ----------
    mechanism_name : str
        The name of the mechanism to add.
    python_template_name : str, optional
        The name of the Python template to use. Default is 'default'.
    load : bool, optional
        Whether to also load the MOD file via ``self.load_mechanism``.
    dir_name : str, optional
        Directory of the MOD file when loading. Default is 'mod'.
    recompile : bool, optional
        Whether to recompile when loading. Default is True.
    """
    channel_paths = self.path_manager.get_channel_paths(
        mechanism_name,
        python_template_name=python_template_name,
    )
    channel = create_channel(**channel_paths)
    self.mechanisms[channel.name] = channel
    if not load:
        return
    self.load_mechanism(mechanism_name, dir_name, recompile)
574
+
575
+
576
+
577
def load_mechanisms(self, dir_name: str = 'mod', recompile=True) -> None:
    """
    Load every MOD file found in a directory.

    Parameters
    ----------
    dir_name : str, optional
        The directory to read MOD files from. Default is 'mod'.
    recompile : bool, optional
        Whether to recompile the mechanisms. Default is True.
    """
    for mod_name in self.path_manager.list_files(dir_name, extension='mod'):
        self.load_mechanism(mod_name, dir_name, recompile)
591
+
592
+
593
def load_mechanism(self, mechanism_name, dir_name='mod', recompile=False) -> None:
    """
    Load a single mechanism's MOD file through the model's MOD loader.

    Parameters
    ----------
    mechanism_name : str
        The name of the mechanism (MOD file without extension) to load.
    dir_name : str, optional
        The directory containing the MOD file. Default is 'mod'.
    recompile : bool, optional
        Whether to recompile the mechanism. Default is False.
    """
    mod_path = self.path_manager.get_file_path(dir_name, mechanism_name,
                                               extension='mod')
    loader = self.mod_loader
    loader.load_mechanism(path_to_mod_file=mod_path, recompile=recompile)
612
+
613
+
614
def standardize_channel(self, channel_name,
                        python_template_name=None, mod_template_name=None, remove_old=True):
    """
    Replace a channel with an equivalent channel built from the standard
    equations.

    The old channel is uninserted from every domain and (optionally)
    dropped. The standardized channel is created, registered, compiled and
    inserted into the same domains, and the old channel's gbar
    distributions are copied onto it. Kinetic parameters are not
    transferred.

    Parameters
    ----------
    channel_name : str
        The name of the channel to standardize.
    python_template_name : str, optional
        Accepted for API compatibility; not used by this method.
    mod_template_name : str, optional
        The name of the MOD template to use.
    remove_old : bool, optional
        Whether to remove the old channel from the model. Default is True.
    """
    # Capture everything that must survive the swap.
    old_channel = self.mechanisms[channel_name]
    inserted_in = [name for name, mechs in self.domains_to_mechs.items()
                   if channel_name in mechs]
    old_gbar_distributions = self.params[f'gbar_{channel_name}']

    # Take the old channel out of every domain that contains it.
    for domain_name in self.domains:
        if channel_name in self.domains_to_mechs[domain_name]:
            self.uninsert_mechanism(channel_name, domain_name)

    if remove_old:
        self.mechanisms.pop(channel_name)

    # Build, register and compile the standardized channel. Inside this
    # method, ``standardize_channel`` is the module-level factory, not the
    # method itself.
    paths = self.path_manager.get_standard_channel_paths(
        channel_name,
        mod_template_name=mod_template_name,
    )
    new_channel = standardize_channel(old_channel, **paths)
    self.mechanisms[new_channel.name] = new_channel
    self.load_mechanism(new_channel.name, recompile=True)

    # Put it back where the old channel was, then restore gbar.
    for domain_name in inserted_in:
        self.insert_mechanism(new_channel.name, domain_name)
    new_gbar_name = f'gbar_{new_channel.name}'
    for group_name, distribution in old_gbar_distributions.items():
        self.set_param(new_gbar_name, group_name,
                       distribution.function_name, **distribution.parameters)
668
+
669
+
670
+ # ========================================================================
671
+ # DOMAINS
672
+ # ========================================================================
673
+
674
def define_domain(self, domain_name: str, sections, distribute=True):
    """
    Move sections into a (possibly new) domain, keeping the section tree
    partitioned.

    Moved sections lose the mechanisms of their old domain and gain those
    of the target domain. Afterwards, empty domains, uninserted mechanisms
    and empty groups are cleaned up.

    Parameters
    ----------
    domain_name : str
        The name of the domain.
    sections : list[Section] or Callable
        The sections to include in the domain. A callable is used as a
        filter over all sections of the cell.
    distribute : bool, optional
        Forwarded to ``Section.insert_mechanism`` for each moved section.

    Warns
    -----
    UserWarning
        If all given sections are already in the domain; nothing is moved.
    """
    if isinstance(sections, Callable):
        sections = self.get_sections(sections)

    if domain_name in self.domains:
        domain = self.domains[domain_name]
    else:
        domain = Domain(domain_name)
        self._add_domain_groups(domain.name)
        self.domains[domain_name] = domain
        self.domains_to_mechs[domain_name] = set()

    to_move = [sec for sec in sections if sec.domain != domain_name]
    if not to_move:
        warnings.warn(f'Sections already in domain {domain_name}.')
        return

    # Detach every section from its old domain first...
    for sec in to_move:
        old_domain = self.domains[sec.domain]
        old_domain.remove_section(sec)
        for mech_name in self.domains_to_mechs[old_domain.name]:
            sec.uninsert_mechanism(mech_name)

    # ...then attach it to the target domain with the target's mechanisms.
    for sec in to_move:
        domain.add_section(sec)
        for mech_name in self.domains_to_mechs.get(domain.name, set()):
            sec.insert_mechanism(mech_name, distribute=distribute)

    self._remove_empty()
721
+
722
+
723
def _add_domain_groups(self, domain_name):
    """
    Register a newly added domain in the segment groups.

    The domain is appended to the 'all' group when that group exists (and
    is truthy). A dedicated group is then created for it, named via
    DOMAIN_TO_GROUP or after the domain itself.
    """
    all_group = self.groups.get('all')
    if all_group:
        all_group.domains.append(domain_name)
    self.add_group(DOMAIN_TO_GROUP.get(domain_name, domain_name),
                   domains=[domain_name])
733
+
734
+
735
def _remove_empty(self):
    """
    Clean up after a domain change. Empty domains are dropped first, then
    mechanisms no longer inserted anywhere, then empty segment groups.
    """
    for cleanup in (self._remove_empty_domains,
                    self._remove_uninserted_mechanisms,
                    self._remove_empty_groups):
        cleanup()
739
+
740
+
741
def _remove_empty_domains(self):
    """
    Remove every domain whose ``is_empty()`` is true, warning about each.

    The domain is dropped from ``domains`` and ``domains_to_mechs``, and its
    name is removed from every segment group that lists it.
    """
    # Snapshot first: the loop pops from self.domains.
    for domain in [d for d in self.domains.values() if d.is_empty()]:
        warnings.warn(f'Domain {domain.name} is empty and will be removed.')
        self.domains.pop(domain.name)
        self.domains_to_mechs.pop(domain.name)
        for group in self._groups:
            if domain.name in group.domains:
                group.domains.remove(domain.name)
754
+
755
+
756
def _remove_uninserted_mechanisms(self):
    """
    Drop the parameters of every mechanism that still has parameters but
    is not inserted in any domain, warning about each one.
    """
    candidates = [self.mechanisms[name] for name in self.mechs_to_params
                  if name != 'Independent']
    inserted = self.mechs_to_domains
    orphans = [mech for mech in candidates if mech.name not in inserted]
    for mech in orphans:
        warnings.warn(f'Mechanism {mech.name} is not inserted in any domain and will be removed.')
        self._remove_mechanism_params(mech)
765
+
766
+
767
def _remove_empty_groups(self):
    """
    Remove every segment group that contains no segment of the segment
    tree, warning about each one.
    """
    empty_groups = []
    for group in self._groups:
        if not any(seg in group for seg in self.seg_tree):
            empty_groups.append(group)
    for group in empty_groups:
        warnings.warn(f'Group {group.name} is empty and will be removed.')
        self.remove_group(group.name)
774
+
775
+
776
+ # -----------------------------------------------------------------------
777
+ # INSERT / UNINSERT MECHANISMS
778
+ # -----------------------------------------------------------------------
779
+
780
def insert_mechanism(self, mechanism_name: str,
                     domain_name: str, distribute=True):
    """
    Insert a mechanism into all sections of a domain.

    The mechanism is recorded in ``domains_to_mechs``, inserted in each
    section, and its parameters are registered. With ``distribute`` set,
    every model parameter (not only this mechanism's) is then
    redistributed.

    Parameters
    ----------
    mechanism_name : str
        The name of the mechanism to insert.
    domain_name : str
        The name of the domain to insert the mechanism into.
    distribute : bool, optional
        Whether to distribute the parameters after inserting the mechanism.
    """
    mech = self.mechanisms[mechanism_name]
    domain = self.domains[domain_name]

    self.domains_to_mechs[domain_name].add(mech.name)
    for sec in domain.sections:
        sec.insert_mechanism(mech.name)
    self._add_mechanism_params(mech)

    if not distribute:
        return
    # TODO: Redistribute only parameters of groups containing this domain.
    for param_name in self.params:
        self.distribute(param_name)
807
+
808
+
809
def _add_mechanism_params(self, mech):
    """
    Register the parameters of a mechanism that has just been inserted.

    Each range parameter not yet known to the model gets a constant
    distribution over the 'all' group, using the mechanism's value.
    Parameters that already exist keep their distributions. This happens,
    for example, when the mechanism is inserted into an additional domain
    or re-inserted after a new morphology is loaded. The equilibrium
    potential of the mechanism's ion ('na', 'k' or 'ca') is added if it
    is missing.
    """
    for param_name, value in mech.range_params_with_suffix.items():
        # Do not overwrite distributions configured earlier (consistent
        # with how equilibrium potentials are only added when missing).
        if param_name not in self.params:
            self.params[param_name] = {'all': Distribution('constant', value=value)}

    if hasattr(mech, 'ion') and mech.ion in ['na', 'k', 'ca']:
        self._add_equilibrium_potentials_on_mech_insert(mech.ion)
820
+
821
+
822
def _add_equilibrium_potentials_on_mech_insert(self, ion: str) -> None:
    """
    Add the equilibrium-potential parameter for ``ion`` if it is missing.

    'na' -> 'ena' (50 mV), 'k' -> 'ek' (-77 mV) and 'ca' -> 'eca' (140 mV),
    each as a constant over the 'all' group. Any other ion, or an existing
    (truthy) entry, leaves ``params`` untouched.
    """
    defaults = {'na': ('ena', 50), 'k': ('ek', -77), 'ca': ('eca', 140)}
    if ion not in defaults:
        return
    param_name, value = defaults[ion]
    if self.params.get(param_name):
        return
    self.params[param_name] = {'all': Distribution('constant', value=value)}
831
+
832
+
833
def uninsert_mechanism(self, mechanism_name: str,
                       domain_name: str):
    """
    Uninsert a mechanism from all sections of a domain.

    If the mechanism is then inserted in no domain at all, a warning is
    issued and its parameters are removed from the model.

    Parameters
    ----------
    mechanism_name : str
        The name of the mechanism to uninsert.
    domain_name : str
        The name of the domain to uninsert the mechanism from.

    Raises
    ------
    KeyError
        If the mechanism is not recorded for that domain.
    """
    mech = self.mechanisms[mechanism_name]
    domain = self.domains[domain_name]

    for sec in domain.sections:
        sec.uninsert_mechanism(mech.name)
    self.domains_to_mechs[domain_name].remove(mech.name)

    if self.mechs_to_domains.get(mech.name):
        return
    warnings.warn(f'Mechanism {mech.name} is not inserted in any domain and will be removed.')
    self._remove_mechanism_params(mech)
856
+
857
+
858
+ def _remove_mechanism_params(self, mech):
859
+ for param_name in self.mechs_to_params.get(mech.name, []):
860
+ self.params.pop(param_name)
861
+
862
+ if hasattr(mech, 'ion') and mech.ion in ['na', 'k', 'ca']:
863
+ self._remove_equilibrium_potentials_on_mech_uninsert(mech.ion)
864
+
865
+
866
+ def _remove_equilibrium_potentials_on_mech_uninsert(self, ion: str) -> None:
867
+ """
868
+ """
869
+ for mech_name, mech in self.mechanisms.items():
870
+ if hasattr(mech, 'ion'):
871
+ if mech.ion == mech.ion: return
872
+
873
+ if ion == 'na':
874
+ self.params.pop('ena', None)
875
+ elif ion == 'k':
876
+ self.params.pop('ek', None)
877
+ elif ion == 'ca':
878
+ self.params.pop('eca', None)
879
+
880
+
881
+ # ========================================================================
882
+ # SET PARAMETERS
883
+ # ========================================================================
884
+
885
+ # -----------------------------------------------------------------------
886
+ # GROUPS
887
+ # -----------------------------------------------------------------------
888
+
889
def add_group(self, name, domains, select_by=None, min_value=None, max_value=None):
    """
    Add a group of segments to the model.

    A ``SegmentGroup`` is built from the arguments and appended to the
    model's list of groups.

    Parameters
    ----------
    name : str
        The name of the group.
    domains : list[str]
        The domains to include in the group.
    select_by : str, optional
        The parameter to select the segments by. Can be 'diam', 'distance',
        'domain_distance'.
    min_value : float, optional
        The minimum value of the selection parameter.
    max_value : float, optional
        The maximum value of the selection parameter.
    """
    if self.verbose: print(f'Adding group {name}...')
    group = SegmentGroup(name, domains, select_by, min_value, max_value)
    self._groups.append(group)
909
+
910
+
911
def remove_group(self, group_name):
    """
    Remove a group and every distribution that refers to it.

    All groups named ``group_name`` are dropped from the model's group
    list, and the matching entry is deleted from each parameter's
    group-to-distribution mapping. Parameters are not redistributed here.

    Parameters
    ----------
    group_name : str
        The name of the group to remove.
    """
    remaining = []
    for group in self._groups:
        if group.name != group_name:
            remaining.append(group)
    self._groups = remaining

    for distributions in self.params.values():
        if group_name in distributions:
            del distributions[group_name]
926
+
927
+
928
def move_group_down(self, name):
    """
    Move a group one position earlier in the list of groups.

    The named group is swapped with the group preceding it, and every
    parameter in ``self.distributed_params`` is redistributed. Nothing
    happens if the group is already first.

    Parameters
    ----------
    name : str
        The name of the group to move down.

    Raises
    ------
    StopIteration
        If no group with this name exists.
    """
    groups = self._groups
    position = next(i for i, group in enumerate(groups) if group.name == name)
    if position == 0:
        return
    groups[position - 1], groups[position] = groups[position], groups[position - 1]
    for param_name in self.distributed_params:
        self.distribute(param_name)
942
+
943
+
944
def move_group_up(self, name):
    """
    Move a group one position later in the list of groups.

    The named group is swapped with the group following it, and every
    parameter in ``self.distributed_params`` is redistributed. Nothing
    happens if the group is already last.

    Parameters
    ----------
    name : str
        The name of the group to move up.

    Raises
    ------
    StopIteration
        If no group with this name exists.
    """
    groups = self._groups
    position = next(i for i, group in enumerate(groups) if group.name == name)
    if position >= len(groups) - 1:
        return
    groups[position], groups[position + 1] = groups[position + 1], groups[position]
    for param_name in self.distributed_params:
        self.distribute(param_name)
958
+
959
+
960
+ # -----------------------------------------------------------------------
961
+ # DISTRIBUTIONS
962
+ # -----------------------------------------------------------------------
963
+
964
def set_param(self, param_name: str,
              group_name: str = 'all',
              distr_type: str = 'constant',
              **distr_params):
    """
    Set a parameter for a group of segments.

    'temperature' and 'v_init' are forwarded to the simulator using the
    ``value`` keyword. Any other parameter gets a new distribution for the
    group and is then redistributed over the segments.

    Parameters
    ----------
    param_name : str
        The name of the parameter to set.
    group_name : str, optional
        The name of the group to set the parameter for. Default is 'all'.
    distr_type : str, optional
        The type of the distribution to use. Default is 'constant'.
    distr_params : dict
        The parameters of the distribution.

    Raises
    ------
    ValueError
        If 'group' is passed instead of 'group_name', or if any
        distribution parameter is not an int/float or is NaN.
    """
    import math

    if 'group' in distr_params:
        raise ValueError("Did you mean 'group_name' instead of 'group'?")

    if param_name in ['temperature', 'v_init']:
        setattr(self.simulator, param_name, distr_params['value'])
        return

    for key, value in distr_params.items():
        # ``value is nan`` only matched one specific NaN object; any other
        # NaN (e.g. float('nan') or a numpy result) slipped through.
        if not isinstance(value, (int, float)) or math.isnan(value):
            raise ValueError(f"Parameter '{key}' must be a numeric value and not NaN, got {type(value).__name__} instead.")

    self.set_distribution(param_name, group_name, distr_type, **distr_params)
    self.distribute(param_name)
996
+
997
+
998
def set_distribution(self, param_name: str,
                     group_name: str = 'all',
                     distr_type: str = 'constant',
                     **distr_params):
    """
    Set a distribution for a parameter.

    The distribution is stored in ``self.params[param_name][group_name]``
    but is not applied to the segments; call ``distribute`` for that.
    The type 'inherit' is stored as the plain string ``'inherit'``.

    Parameters
    ----------
    param_name : str
        The name of the parameter to set. Must already be a key of
        ``self.params``.
    group_name : str, optional
        The name of the group to set the parameter for. Default is 'all'.
    distr_type : str, optional
        The type of the distribution to use. Default is 'constant'.
    distr_params : dict
        The parameters of the distribution.
    """
    # The annotation used to read ``group_name: None`` with no default,
    # which contradicted the documented default of 'all'.
    if distr_type == 'inherit':
        distribution = 'inherit'
    else:
        distribution = Distribution(distr_type, **distr_params)
    self.params[param_name][group_name] = distribution
1022
+
1023
+
1024
def distribute_all(self):
    """
    Distribute all parameters to the segments.

    Group membership is computed once (segments of ``self.seg_tree`` that
    belong to each group) and reused for every parameter.
    """
    precomputed = {}
    for group in self._groups:
        precomputed[group.name] = [seg for seg in self.seg_tree if seg in group]

    for param_name in self.params:
        self.distribute(param_name, precomputed)
1032
+
1033
+
1034
def distribute(self, param_name: str, precomputed_groups=None):
    """
    Distribute a parameter to the segments.

    'Ra' is delegated to ``_distribute_Ra``. For any other parameter, each
    group's distribution is applied to the group's segments in the order
    the groups appear in ``self.params[param_name]``. An 'inherit'
    distribution copies the value of each segment's parent segment.
    Otherwise the distribution is called with the segment's path distance.

    Parameters
    ----------
    param_name : str
        The name of the parameter to distribute.
    precomputed_groups : dict, optional
        A dictionary mapping group names to segments. Default is None, in
        which case membership is computed from ``self.seg_tree``.
    """
    if param_name == 'Ra':
        self._distribute_Ra(precomputed_groups)
        return

    if precomputed_groups is None:
        groups_to_segments = {
            group.name: [seg for seg in self.seg_tree if seg in group]
            for group in self._groups
        }
    else:
        groups_to_segments = precomputed_groups

    for group_name, distribution in self.params[param_name].items():
        segments = groups_to_segments[group_name]
        inherit = distribution == 'inherit'
        for seg in segments:
            if inherit:
                new_value = seg.parent.get_param_value(param_name)
            else:
                new_value = distribution(seg.path_distance())
            seg.set_param_value(param_name, new_value)
1068
+
1069
+
1070
+ def _distribute_Ra(self, precomputed_groups=None):
1071
+ """
1072
+ Distribute the axial resistance to the segments.
1073
+ """
1074
+
1075
+ groups_to_segments = precomputed_groups
1076
+ if groups_to_segments is None:
1077
+ groups_to_segments = {group.name: [seg for seg in self.seg_tree if seg in group]
1078
+ for group in self._groups}
1079
+
1080
+ param_distributions = self.params['Ra']
1081
+
1082
+ for group_name, distribution in param_distributions.items():
1083
+
1084
+ filtered_segments = groups_to_segments[group_name]
1085
+ if distribution == 'inherit':
1086
+ raise NotImplementedError("Inheritance of Ra is not implemented.")
1087
+ else:
1088
+ for seg in filtered_segments:
1089
+ value = distribution(seg._section.path_distance(0.5))
1090
+ seg._section._ref.Ra = value
1091
+
1092
+
1093
def remove_distribution(self, param_name, group_name):
    """
    Remove a distribution for a parameter and redistribute it.

    A missing group entry is silently ignored. The parameter is
    redistributed in either case.

    Parameters
    ----------
    param_name : str
        The name of the parameter to remove the distribution for.
    group_name : str
        The name of the group to remove the distribution for.
    """
    distributions = self.params[param_name]
    if group_name in distributions:
        del distributions[group_name]
    self.distribute(param_name)
1106
+
1107
+ # def set_section_param(self, param_name, value, domains=None):
1108
+
1109
+ # domains = domains or self.domains
1110
+ # for sec in self.sec_tree.sections:
1111
+ # if sec.domain in domains:
1112
+ # setattr(sec._ref, param_name, value)
1113
+
1114
+ # ========================================================================
1115
+ # STIMULI
1116
+ # ========================================================================
1117
+
1118
+ # -----------------------------------------------------------------------
1119
+ # ICLAMPS
1120
+ # -----------------------------------------------------------------------
1121
+
1122
def add_iclamp(self, sec, loc, amp=0, delay=100, dur=100):
    """
    Add an IClamp to a section.

    Any clamp already registered for the same segment is removed first, so
    each segment holds at most one clamp. Clamps are keyed by the segment
    ``sec(loc)``.

    Parameters
    ----------
    sec : Section
        The section to add the IClamp to.
    loc : float
        The location of the IClamp in the section.
    amp : float, optional
        The amplitude of the IClamp. Default is 0.
    delay : float, optional
        The delay of the IClamp. Default is 100.
    dur : float, optional
        The duration of the IClamp. Default is 100.
    """
    target_seg = sec(loc)
    existing = self.iclamps.get(target_seg)
    if existing:
        self.remove_iclamp(sec, loc)

    new_clamp = IClamp(sec, loc, amp, delay, dur)
    print(f'IClamp added to sec {sec} at loc {loc}.')
    self.iclamps[target_seg] = new_clamp
1145
+
1146
+
1147
def remove_iclamp(self, sec, loc):
    """
    Remove an IClamp from a section.

    Does nothing if no (truthy) clamp is registered for ``sec(loc)``.

    Parameters
    ----------
    sec : Section
        The section to remove the IClamp from.
    loc : float
        The location of the IClamp in the section.
    """
    target_seg = sec(loc)
    existing = self.iclamps.get(target_seg)
    if existing:
        del self.iclamps[target_seg]
1161
+
1162
+
1163
def remove_all_iclamps(self):
    """
    Remove all IClamps from the model.

    Each clamp is removed through ``remove_iclamp`` using its segment's
    section and position. A warning is issued if any clamp survives, and
    the registry is reset to an empty dict regardless.
    """
    for clamped_seg in tuple(self.iclamps):
        self.remove_iclamp(clamped_seg._section, clamped_seg.x)

    if self.iclamps:
        warnings.warn(f'Not all iclamps were removed: {self.iclamps}')
    self.iclamps = {}
1174
+
1175
+
1176
+ # -----------------------------------------------------------------------
1177
+ # SYNAPSES
1178
+ # -----------------------------------------------------------------------
1179
+
1180
+ def _add_population(self, population):
1181
+ self.populations[population.syn_type][population.name] = population
1182
+
1183
+
1184
def add_population(self, segments, N, syn_type):
    """
    Add a population of synapses to the model.

    The population index is the number of populations already registered
    for ``syn_type``. Synapses are allocated and inputs created before the
    population is registered.

    Parameters
    ----------
    segments : list[Segment]
        The segments to add the synapses to.
    N : int
        The number of synapses to add.
    syn_type : str
        The type of synapse to add.
    """
    next_idx = len(self.populations[syn_type])
    new_population = Population(next_idx, segments, N, syn_type)
    new_population.allocate_synapses()
    new_population.create_inputs()
    self._add_population(new_population)
1202
+
1203
+
1204
def update_population_kinetic_params(self, pop_name, **params):
    """
    Update the kinetic parameters of a population of synapses.

    The synapse type is the part of ``pop_name`` before its last
    underscore (e.g. 'AMPA_NMDA_0' -> 'AMPA_NMDA'). The updated kinetic
    parameters are printed afterwards.

    Parameters
    ----------
    pop_name : str
        The name of the population.
    params : dict
        The parameters to update.

    Raises
    ------
    ValueError
        If ``pop_name`` contains no underscore.
    """
    syn_type, _idx = pop_name.rsplit('_', 1)
    target = self.populations[syn_type][pop_name]
    target.update_kinetic_params(**params)
    print(target.kinetic_params)
1219
+
1220
+
1221
def update_population_input_params(self, pop_name, **params):
    """
    Update the input parameters of a population of synapses.

    The synapse type is the part of ``pop_name`` before its last
    underscore. The updated input parameters are printed afterwards.

    Parameters
    ----------
    pop_name : str
        The name of the population.
    params : dict
        The parameters to update.

    Raises
    ------
    ValueError
        If ``pop_name`` contains no underscore.
    """
    syn_type, _idx = pop_name.rsplit('_', 1)
    target = self.populations[syn_type][pop_name]
    target.update_input_params(**params)
    print(target.input_params)
1236
+
1237
+
1238
def remove_population(self, name):
    """
    Remove a population of synapses from the model.

    The population is popped from the registry of its synapse type (the
    part of ``name`` before the last underscore) and then cleaned up.

    Parameters
    ----------
    name : str
        The name of the population
    """
    syn_type, _idx = name.rsplit('_', 1)
    removed = self.populations[syn_type].pop(name)
    removed.clean()
1250
+
1251
def remove_all_populations(self):
    """
    Remove all populations of synapses from the model.

    Every population is removed through ``remove_population``. A warning
    is issued if any survives. The registry is then reset to one fresh,
    empty dict per synapse type listed in ``POPULATIONS``.
    """
    for syn_type in self.populations:
        for name in list(self.populations[syn_type].keys()):
            self.remove_population(name)
    if any(self.populations.values()):
        warnings.warn(f'Not all populations were removed: {self.populations}')
    # Build fresh inner dicts instead of assigning the module-level
    # POPULATIONS object itself. Otherwise later ``_add_population`` calls
    # would write into that shared constant.
    self.populations = {syn_type: {} for syn_type in POPULATIONS}
1261
+
1262
def remove_all_stimuli(self):
    """
    Remove all stimuli from the model: first every IClamp, then every
    synapse population.
    """
    for clear in (self.remove_all_iclamps, self.remove_all_populations):
        clear()
1268
+
1269
+ # ========================================================================
1270
+ # SIMULATION
1271
+ # ========================================================================
1272
+
1273
def add_recording(self, sec, loc, var='v'):
    """
    Add a recording to the model.

    Delegates to the simulator and prints a confirmation.

    Parameters
    ----------
    sec : Section
        The section to record from.
    loc : float
        The location along the normalized section length to record from.
    var : str, optional
        The variable to record. Default is 'v'.
    """
    simulator = self.simulator
    simulator.add_recording(sec, loc, var)
    print(f'Recording added to sec {sec} at loc {loc}.')
1288
+
1289
+
1290
def remove_recording(self, sec, loc):
    """
    Remove a recording from the model by delegating to the simulator.

    Parameters
    ----------
    sec : Section
        The section to remove the recording from.
    loc : float
        The location along the normalized section length to remove the recording from.
    """
    simulator = self.simulator
    simulator.remove_recording(sec, loc)
1302
+
1303
+
1304
def remove_all_recordings(self):
    """Remove all recordings from the model by delegating to the simulator."""
    simulator = self.simulator
    simulator.remove_all_recordings()
1309
+
1310
+
1311
def run(self, duration=300):
    """
    Run the simulation for a specified duration.

    Parameters
    ----------
    duration : float, optional
        The duration of the simulation, passed unchanged to the simulator.
        Default is 300.
    """
    simulator = self.simulator
    simulator.run(duration)
1321
+
1322
def get_traces(self):
    """Return whatever ``self.simulator.get_traces()`` returns, unchanged."""
    simulator = self.simulator
    return simulator.get_traces()
1324
+
1325
def plot(self, *args, **kwargs):
    """Forward all arguments to ``self.simulator.plot``; returns None."""
    simulator = self.simulator
    simulator.plot(*args, **kwargs)
1327
+
1328
+ # ========================================================================
1329
+ # MORPHOLOGY
1330
+ # ========================================================================
1331
+
1332
def remove_subtree(self, sec):
    """
    Remove a subtree from the model.

    The subtree is removed from the section tree, the tree is re-sorted,
    and ``_remove_empty`` is called afterwards.

    Parameters
    ----------
    sec : Section
        The root section of the subtree to remove.
    """
    tree = self.sec_tree
    tree.remove_subtree(sec)
    tree.sort()
    self._remove_empty()
1344
+
1345
+
1346
def merge_domains(self, domain_names: List[str]):
    """
    Merge several domains into the first one listed.

    Every domain after the first is merged into the first via
    ``Domain.merge``; ``remove_empty`` is called afterwards.

    Parameters
    ----------
    domain_names : list[str]
        Names of the domains to merge; the first one is the target.
    """
    resolved = [self.domains[domain_name] for domain_name in domain_names]
    target = resolved[0] if resolved else None
    for other in resolved[1:]:
        target.merge(other)
    # NOTE(review): remove_subtree calls ``self._remove_empty()`` while this
    # method calls ``self.remove_empty()``; confirm both names exist.
    self.remove_empty()
1354
+
1355
+
1356
def reduce_subtree(self, root, reduction_frequency=0, total_segments_manual=-1, fit=True):
    """
    Reduce a subtree to a single section.

    Parameters
    ----------
    root : Section
        The root section of the subtree to reduce.
    reduction_frequency : float, optional
        The frequency of the reduction. Default is 0.
    total_segments_manual : int, optional
        The number of segments in the reduced subtree. Default is -1 (automatic).
    fit : bool, optional
        Whether to create distributions for the reduced subtree by fitting
        the calculated average values. Default is True.

    Returns
    -------
    dict or None
        ``None`` when ``fit`` is False; the reduced section then keeps the
        averaged per-segment values. Otherwise a dictionary with the new
        domain and group name and the intermediate mappings used during
        the reduction.

    Raises
    ------
    ValueError
        If the subtree spans several domains that do not all have the same
        set of inserted mechanisms.
    """

    domain_name = root.domain
    parent = root.parent

    # Every domain in the subtree must carry exactly the same mechanisms.
    # Only the root domain's mechanisms are carried over below, so mechanisms
    # present in only some domains would be silently lost. The previous check
    # only rejected an empty intersection.
    subtree_domain_names = {sec.domain for sec in root.subtree}
    if len(subtree_domain_names) > 1:
        mech_sets = [self.domains_to_mechs.get(name, set())
                     for name in subtree_domain_names]
        if any(mechs != mech_sets[0] for mechs in mech_sets[1:]):
            raise ValueError(
                'The domains in the subtree have different mechanisms. '
                'Please ensure that all domains in the subtree have the same mechanisms. '
                'You may need to insert the missing mechanisms and set their conductances to 0 where they are not needed.'
            )

    inserted_mechs = {mech_name: mech for mech_name, mech
                      in self.mechanisms.items()
                      if mech_name in self.domains_to_mechs[domain_name]
                      }

    # Map original segment names to their parameters
    segs_to_params = rdc.map_segs_to_params(root, inserted_mechs)

    # Temporarily remove active mechanisms ('Leak' stays inserted)
    for mech_name in inserted_mechs:
        if mech_name == 'Leak':
            continue
        for sec in root.subtree:
            sec.uninsert_mechanism(mech_name)

    # Disconnect
    root.disconnect_from_parent()

    # Calculate new properties of a reduced subtree
    new_cable_properties = rdc.get_unique_cable_properties(root._ref, reduction_frequency)
    new_nseg = rdc.calculate_nsegs(new_cable_properties, total_segments_manual)
    if self.verbose: print(new_cable_properties)

    # Map segment names to their new locations in the reduced cylinder
    segs_to_locs = rdc.map_segs_to_locs(root, reduction_frequency, new_cable_properties)

    # Reconnect
    root.connect_to_parent(parent)

    # Delete the original subtree
    children = root.children[:]
    for child_sec in children:
        self.remove_subtree(child_sec)

    # Set passive mechanisms for the reduced cylinder:
    rdc.apply_params_to_section(root, new_cable_properties, new_nseg)

    # Reinsert active mechanisms
    for mech_name in inserted_mechs:
        if mech_name == 'Leak':
            continue
        for sec in root.subtree:
            sec.insert_mechanism(mech_name)

    # Replace locs with corresponding segs
    segs_to_reduced_segs = rdc.map_segs_to_reduced_segs(segs_to_locs, root)

    # Map reduced segments to lists of parameters of corresponding original segments
    reduced_segs_to_params = rdc.map_reduced_segs_to_params(segs_to_reduced_segs, segs_to_params)

    # Set new values of parameters
    rdc.set_avg_params_to_reduced_segs(reduced_segs_to_params)
    rdc.interpolate_missing_values(reduced_segs_to_params, root)

    if not fit:
        return

    root_segs = list(root.segments)
    params_to_coeffs = {}
    for param_name in self.params:
        coeffs = self.fit_distribution(param_name, segments=root_segs, plot=False)
        params_to_coeffs[param_name] = coeffs

    # Create new domain
    reduced_domains = [domain_name for domain_name in self.domains if domain_name.startswith('reduced')]
    new_reduced_domain_name = f'reduced_{len(reduced_domains)}'
    group_name = new_reduced_domain_name
    self.define_domain(new_reduced_domain_name, sections=[root], distribute=False)

    # Reinsert active mechanisms after creating the new domain
    for mech_name in inserted_mechs:
        root.insert_mechanism(mech_name)
    self.domains_to_mechs[new_reduced_domain_name] = set(inserted_mechs.keys())

    # Fit distributions to data for the group
    for param_name, coeffs in params_to_coeffs.items():
        self._set_distribution(param_name, group_name, coeffs, plot=True)

    # Distribute parameters
    self.distribute_all()

    return {
        'domain_name': new_reduced_domain_name,
        'group_name': group_name,
        'segs_to_params': segs_to_params,
        'segs_to_locs': segs_to_locs,
        'segs_to_reduced_segs': segs_to_reduced_segs,
        'reduced_segs_to_params': reduced_segs_to_params,
        'params_to_coeffs': params_to_coeffs
    }
1489
+
1490
+
1491
def fit_distribution(self, param_name, segments, max_degree=6, tolerance=1e-7, plot=False):
    """
    Fit a polynomial in path distance to a parameter's values on segments.

    Degrees 0..``max_degree`` are tried in order. The first fit whose
    absolute residuals are all below ``tolerance`` is returned. If none
    qualifies, a warning is issued and the highest-degree fit is returned.

    Parameters
    ----------
    param_name : str
        The parameter whose per-segment values are fitted.
    segments : iterable of Segment
        Segments providing ``get_param_value`` and ``path_distance``.
    max_degree : int, optional
        Highest polynomial degree to try. Default is 6.
    tolerance : float, optional
        Maximum accepted absolute residual. Default is 1e-7.
    plot : bool, optional
        If True and the selected degree is > 0, plot the parameter together
        with the fitted curve. Default is False.

    Returns
    -------
    numpy.ndarray
        Polynomial coefficients, highest power first (``numpy.polyfit`` order).

    Raises
    ------
    ValueError
        If ``segments`` is empty.
    """
    import numpy as np

    segments = list(segments)
    if not segments:
        # Previously this failed with an opaque tuple-unpacking error.
        raise ValueError(f'Cannot fit parameter {param_name}: no segments provided.')

    values = [seg.get_param_value(param_name) for seg in segments]
    distances = [seg.path_distance() for seg in segments]
    sorted_pairs = sorted(zip(distances, values))
    distances = np.asarray([d for d, _ in sorted_pairs], dtype=float)
    values = np.asarray([v for _, v in sorted_pairs], dtype=float)

    for degree in range(0, max_degree + 1):
        coeffs = np.polyfit(distances, values, degree)
        residuals = values - np.polyval(coeffs, distances)
        if np.all(np.abs(residuals) < tolerance):
            break
    if not np.all(np.abs(residuals) < tolerance):
        warnings.warn(f'Fitting failed for parameter {param_name} with the provided tolerance.\nUsing the last valid fit (degree={degree}). Maximum residual: {max(abs(residuals))}')
    if plot and degree > 0:
        self.plot_param(param_name, show_nan=False)
        plt.plot(distances, np.polyval(coeffs, distances), label='Fitted', color='red', linestyle='--')
        plt.legend()
    return coeffs
1510
+
1511
+
1512
def _set_distribution(self, param_name, group_name, coeffs, plot=False):
    """
    Store a distribution built from polynomial coefficients.

    Coefficients whose integer rounding is non-zero are rounded to 10
    decimals; those that round to 0 are kept unchanged. The number of
    coefficients selects the distribution: 1 -> 'constant', 2 -> 'linear'
    (slope, intercept), more -> 'polynomial'. ``plot`` is accepted but
    not used.

    Parameters
    ----------
    param_name : str
        Parameter whose distribution is set.
    group_name : str
        Group the distribution applies to.
    coeffs : array-like
        Polynomial coefficients, highest power first.
    plot : bool, optional
        Unused.
    """
    coeffs = np.where(np.round(coeffs) == 0, coeffs, np.round(coeffs, 10))
    n_coeffs = len(coeffs)
    if n_coeffs == 1:
        distribution = Distribution('constant', value=coeffs[0])
    elif n_coeffs == 2:
        distribution = Distribution('linear', slope=coeffs[0], intercept=coeffs[1])
    else:
        distribution = Distribution('polynomial', coeffs=coeffs)
    self.params[param_name][group_name] = distribution
1521
+
1522
+
1523
+ # ========================================================================
1524
+ # PLOTTING
1525
+ # ========================================================================
1526
+
1527
def plot_param(self, param_name, ax=None, show_nan=True):
    """
    Plot the distribution of a parameter in the model.

    Each segment is drawn at its path distance and colored by its domain.
    Non-zero values are filled markers and zero values are hollow dots.
    NaN values, shown only if ``show_nan``, are crosses placed at y = 0
    together with a dashed zero line.

    Parameters
    ----------
    param_name : str
        The name of the parameter to plot.
    ax : matplotlib.axes.Axes, optional
        The axes to plot on. Default is None (a new figure is created).
    show_nan : bool, optional
        Whether to show NaN values. Default is True.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 2))

    if param_name not in self.params:
        warnings.warn(f'Parameter {param_name} not found.')

    # Partition segments into valid / zero / NaN in a single pass.
    valid_xy, valid_colors = [], []
    zero_xy, zero_colors = [], []
    nan_xy, nan_colors = [], []
    for seg in self.seg_tree:
        x = seg.path_distance()
        y = seg.get_param_value(param_name)
        color = get_domain_color(seg.domain)
        if pd.isna(y):
            nan_xy.append((x, 0))
            nan_colors.append(color)
        elif y == 0:
            zero_xy.append((x, y))
            zero_colors.append(color)
        else:
            valid_xy.append((x, y))
            valid_colors.append(color)

    if valid_xy:
        ax.scatter(*zip(*valid_xy), c=valid_colors)
    if zero_xy:
        ax.scatter(*zip(*zero_xy), edgecolors=zero_colors, facecolors='none', marker='.')
    if nan_xy and show_nan:
        ax.scatter(*zip(*nan_xy), c=nan_colors, marker='x', alpha=0.5, zorder=0)
        # Draw on the target axes. plt.axhline uses pyplot's "current" axes,
        # which need not be ``ax`` when one is passed in.
        ax.axhline(y=0, color='k', linestyle='--')

    ax.set_xlabel('Path distance')
    ax.set_ylabel(param_name)
    ax.set_title(f'{param_name} distribution')
1567
+
1568
+
1569
+ # ========================================================================
1570
+ # FILE EXPORT
1571
+ # ========================================================================
1572
+
1573
def export_morphology(self, file_name):
    """
    Write the SWC tree to an SWC file.

    Parameters
    ----------
    file_name : str
        Name of the output file. The full path is resolved by the path
        manager under the 'morphology' category with the 'swc' extension.
    """
    path_to_file = self.path_manager.get_file_path('morphology', file_name, extension='swc')

    self.point_tree.to_swc(path_to_file)
1585
+
1586
+
1587
def to_dict(self):
    """
    Return a dictionary representation of the model.

    Keys are 'metadata' (model name), 'd_lambda', 'domains' (domain ->
    list of mechanism names), 'groups' (serialized groups) and 'params'
    (parameter -> group -> serialized distribution). 'inherit' markers are
    stored as plain strings.

    Returns
    -------
    dict
        The dictionary representation of the model.
    """
    def _serialize(distribution):
        return distribution if isinstance(distribution, str) else distribution.to_dict()

    params = {}
    for param_name, by_group in self.params.items():
        params[param_name] = {group: _serialize(d) for group, d in by_group.items()}

    domains = {domain: list(mechs) for domain, mechs in self.domains_to_mechs.items()}
    groups = [group.to_dict() for group in self._groups]

    return {
        'metadata': {'name': self.name},
        'd_lambda': self.d_lambda,
        'domains': domains,
        'groups': groups,
        'params': params,
    }
1613
+
1614
def from_dict(self, data):
    """
    Load the model from a dictionary.

    The steps are:

    1. Check the model name.
    2. Restore ``d_lambda`` and the domain-to-mechanism mapping.
    3. Insert every mechanism into its domain without distributing.
    4. Rebuild groups and parameter distributions.
    5. Reapply the segmentation if a section tree exists.

    Parameters
    ----------
    data : dict
        The dictionary representation of the model.

    Raises
    ------
    ValueError
        If ``data['metadata']['name']`` differs from the model name.
    """
    if self.name != data['metadata']['name']:
        raise ValueError('Model name does not match the data.')

    self.d_lambda = data['d_lambda']

    # Domains and mechanisms
    restored_mapping = {name: set(mech_names) for name, mech_names in data['domains'].items()}
    self.domains_to_mechs = restored_mapping
    if self.verbose: print('Inserting mechanisms...')
    for domain_name, mech_names in self.domains_to_mechs.items():
        for mech_name in mech_names:
            self.insert_mechanism(mech_name, domain_name, distribute=False)

    # Groups
    if self.verbose: print('Adding groups...')
    self._groups = [SegmentGroup.from_dict(raw_group) for raw_group in data['groups']]

    if self.verbose: print('Distributing parameters...')

    # Parameters: 'inherit' markers are stored as plain strings.
    def _load_distribution(raw):
        return raw if isinstance(raw, str) else Distribution.from_dict(raw)

    self.params = {
        param_name: {group: _load_distribution(raw) for group, raw in by_group.items()}
        for param_name, by_group in data['params'].items()
    }

    if self.verbose: print('Setting segmentation...')
    if self.sec_tree is not None:
        self.set_segmentation(d_lambda=self.d_lambda)
1657
+
1658
+
1659
+
1660
def export_membrane(self, file_name, **kwargs):
    """
    Export the membrane properties of the model to a JSON file.

    The file path is resolved by the path manager under the 'membrane'
    category with the 'json' extension. The content is ``self.to_dict()``.

    Parameters
    ----------
    file_name : str
        The name of the file to write to.
    **kwargs : dict
        Additional keyword arguments to pass to `json.dump`.
    """
    target_path = self.path_manager.get_file_path('membrane', file_name, extension='json')
    payload = self.to_dict()
    with open(target_path, 'w') as handle:
        json.dump(payload, handle, **kwargs)
1677
+
1678
+
1679
def load_membrane(self, file_name, recompile=True):
    """
    Load the membrane properties of the model from a JSON file.

    Default and 'mod' mechanisms are added first. The JSON file is then
    resolved via the path manager ('membrane' category, 'json' extension),
    parsed, and applied with ``from_dict``.

    Parameters
    ----------
    file_name : str
        The name of the file to read from.
    recompile : bool, optional
        Whether to recompile the mechanisms after loading. Default is True.
    """
    self.add_default_mechanisms()
    self.add_mechanisms('mod', recompile=recompile)

    source_path = self.path_manager.get_file_path('membrane', file_name, extension='json')
    with open(source_path, 'r') as handle:
        membrane_data = json.load(handle)

    self.from_dict(membrane_data)
1699
+
1700
+
1701
def stimuli_to_dict(self):
    """
    Convert the stimuli to a dictionary representation.

    Contains the model name, the simulator settings
    (``self.simulator.to_dict()``), the IClamp amplitudes/delays/durations
    named 'iclamp_<i>' in registry order, and the serialized populations
    grouped by synapse type.

    Returns
    -------
    dict
        The dictionary representation of the stimuli.
    """
    iclamps = []
    for i, iclamp in enumerate(self.iclamps.values()):
        iclamps.append({
            'name': f'iclamp_{i}',
            'amp': iclamp.amp,
            'delay': iclamp.delay,
            'dur': iclamp.dur,
        })

    populations = {}
    for syn_type, pops in self.populations.items():
        populations[syn_type] = [pop.to_dict() for pop in pops.values()]

    return {
        'metadata': {'name': self.name},
        'simulation': dict(self.simulator.to_dict()),
        'stimuli': {
            'iclamps': iclamps,
            'populations': populations,
        },
    }
1733
+
1734
+
1735
+ def _stimuli_to_csv(self, path_to_csv=None):
1736
+ """
1737
+ Write the model to a CSV file.
1738
+
1739
+ Parameters
1740
+ ----------
1741
+ path_to_csv : str
1742
+ The path to the CSV file to write.
1743
+ """
1744
+
1745
+ rec_data = {
1746
+ 'type': ['recording'] * len(self.recordings),
1747
+ 'idx': [i for i in range(len(self.recordings))],
1748
+ 'sec_idx': [seg._section.idx for seg in self.recordings],
1749
+ 'loc': [seg.x for seg in self.recordings],
1750
+ }
1751
+
1752
+ iclamp_data = {
1753
+ 'type': ['iclamp'] * len(self.iclamps),
1754
+ 'idx': [i for i in range(len(self.iclamps))],
1755
+ 'sec_idx': [seg._section.idx for seg in self.iclamps],
1756
+ 'loc': [seg.x for seg in self.iclamps],
1757
+ }
1758
+
1759
+ synapses_data = {
1760
+ 'type': [],
1761
+ 'idx': [],
1762
+ 'sec_idx': [],
1763
+ 'loc': [],
1764
+ }
1765
+
1766
+ for syn_type, pops in self.populations.items():
1767
+ for pop_name, pop in pops.items():
1768
+ pop_data = pop.to_csv()
1769
+ synapses_data['type'] += pop_data['syn_type']
1770
+ synapses_data['idx'] += [int(name.rsplit('_', 1)[1]) for name in pop_data['name']]
1771
+ synapses_data['sec_idx'] += pop_data['sec_idx']
1772
+ synapses_data['loc'] += pop_data['loc']
1773
+
1774
+ df = pd.concat([
1775
+ pd.DataFrame(rec_data),
1776
+ pd.DataFrame(iclamp_data),
1777
+ pd.DataFrame(synapses_data)
1778
+ ], ignore_index=True)
1779
+ df['sec_idx'] = df['sec_idx'].astype(int)
1780
+ if path_to_csv: df.to_csv(path_to_csv, index=False)
1781
+
1782
+ return df
1783
+
1784
+
1785
def export_stimuli(self, file_name, **kwargs):
    """
    Export the stimuli to a JSON and CSV file.

    The JSON holds ``self.stimuli_to_dict()``; the CSV is written by
    ``_stimuli_to_csv``. Both paths are resolved by the path manager under
    the 'stimuli' category.

    Parameters
    ----------
    file_name : str
        The name of the file to write to.
    **kwargs : dict
        Additional keyword arguments to pass to `json.dump`.
    """
    json_target = self.path_manager.get_file_path('stimuli', file_name, extension='json')
    payload = self.stimuli_to_dict()
    with open(json_target, 'w') as handle:
        json.dump(payload, handle, **kwargs)

    csv_target = self.path_manager.get_file_path('stimuli', file_name, extension='csv')
    self._stimuli_to_csv(csv_target)
1805
+
1806
+
1807
def load_stimuli(self, file_name):
    """
    Load the stimuli from a JSON file and its companion CSV file.

    Both files are resolved with the path manager's ``'stimuli'`` folder for
    ``file_name``. These are the same lookups ``export_stimuli`` uses to
    write them, with the ``json`` and ``csv`` extensions.

    The JSON provides the simulator settings, per-iclamp ``amp``/``delay``/
    ``dur``, and per-population ``N``, kinetic and input parameters. The CSV
    provides the section index and location of every recording, current
    clamp and synapse.

    Parameters
    ----------
    file_name : str
        The name of the files to read from.

    Raises
    ------
    ValueError
        If the model name stored in the JSON metadata differs from
        ``self.name``.
    """
    path_to_json = self.path_manager.get_file_path('stimuli', file_name, extension='json')
    path_to_stimuli_csv = self.path_manager.get_file_path('stimuli', file_name, extension='csv')

    with open(path_to_json, 'r') as f:
        data = json.load(f)

    if self.name != data['metadata']['name']:
        raise ValueError('Model name does not match the data.')

    df_stimuli = pd.read_csv(path_to_stimuli_csv)

    self.simulator.from_dict(data['simulation'])

    sections = self.sec_tree.sections

    # Recordings ---------------------------------------------------------

    df_recs = df_stimuli[df_stimuli['type'] == 'recording']
    for sec_idx, loc in zip(df_recs['sec_idx'].tolist(), df_recs['loc'].tolist()):
        self.add_recording(sections[sec_idx], loc)

    # IClamps -----------------------------------------------------------

    # CSV rows and JSON entries are matched by position: the i-th 'iclamp'
    # row of the CSV takes its amp/delay/dur from the i-th JSON entry.
    df_iclamps = df_stimuli[df_stimuli['type'] == 'iclamp']
    iclamps_params = data['stimuli']['iclamps']
    iclamp_locations = zip(df_iclamps['sec_idx'].tolist(), df_iclamps['loc'].tolist())
    for i, (sec_idx, loc) in enumerate(iclamp_locations):
        params = iclamps_params[i]
        self.add_iclamp(
            sections[sec_idx],
            loc,
            params['amp'],
            params['delay'],
            params['dur'],
        )

    # Populations -------------------------------------------------------

    syn_types = ['AMPA', 'NMDA', 'AMPA_NMDA', 'GABAa']
    populations_data = data['stimuli']['populations']

    for syn_type in syn_types:

        df_syn = df_stimuli[df_stimuli['type'] == syn_type]

        # A synapse type missing from the JSON simply has no populations.
        for i, pop_data in enumerate(populations_data.get(syn_type, [])):

            # CSV rows belonging to population i (its 'idx' column).
            df_pop = df_syn[df_syn['idx'] == i]
            pop_sections = [sections[sec_idx] for sec_idx in df_pop['sec_idx'].tolist()]
            pop_locs = df_pop['loc'].tolist()

            segments = [sec(loc) for sec, loc in zip(pop_sections, pop_locs)]

            # Rebuild the population with its own synapse type rather than
            # a hard-coded 'AMPA'.
            pop = Population(i,
                             segments,
                             pop_data['N'],
                             syn_type)

            syn_locs = list(zip(pop_sections, pop_locs))

            pop.allocate_synapses(syn_locs=syn_locs)
            pop.update_kinetic_params(**pop_data['kinetic_params'])
            pop.update_input_params(**pop_data['input_params'])
            self._add_population(pop)
+
1878
+
1879
def export_to_NEURON(self, file_name, include_kinetic_params=True):
    """
    Render the model into a python script that uses NEURON.

    The ``NEURON_template`` python template is located through the path
    manager's ``'templates'`` folder. It is rendered with the model's
    parameters, groups, parameter/domain-to-mechanism mappings, current
    clamps, the simulator's recordings, the valid domains of each parameter,
    and the NEURON name of every domain. The rendered text is written into
    ``self.path_manager.path_to_model``.

    Parameters
    ----------
    file_name : str
        The name of the file to write to. ``'.py'`` is appended when the
        name does not already end with it.
    include_kinetic_params : bool, optional
        When True (default), ``self.params`` is rendered as-is. Otherwise
        the parameters returned by ``filter_params(self)`` are rendered.
    """
    from dendrotweaks.model_io import (
        render_template,
        get_params_to_valid_domains,
        filter_params,
        get_neuron_domain,
    )

    valid_domains = get_params_to_valid_domains(self)
    if include_kinetic_params:
        param_dict = self.params
    else:
        param_dict = filter_params(self)
    template_path = self.path_manager.get_file_path(
        'templates', 'NEURON_template', extension='py'
    )

    context = {
        'param_dict': param_dict,
        'groups_dict': self.groups,
        'params_to_mechs': self.params_to_mechs,
        'domains_to_mechs': self.domains_to_mechs,
        'iclamps': self.iclamps,
        'recordings': self.simulator.recordings,
        'params_to_valid_domains': valid_domains,
        'domains_to_NEURON': {
            domain: get_neuron_domain(domain) for domain in self.domains
        },
    }
    rendered = render_template(template_path, context)

    script_name = file_name if file_name.endswith('.py') else file_name + '.py'
    script_path = os.path.join(self.path_manager.path_to_model, script_name)
    with open(script_path, 'w') as script:
        script.write(rendered)
+
1916
+