dendrotweaks 0.4.5__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dendrotweaks/model.py CHANGED
@@ -1,63 +1,46 @@
1
+ # Imports
1
2
  from typing import List, Union, Callable
2
3
  import os
3
- import json
4
+ from collections import defaultdict
5
+
4
6
  import matplotlib.pyplot as plt
5
7
  import numpy as np
8
+ from numpy import nan
9
+ import pandas as pd
6
10
  import quantities as pq
7
11
 
8
- from dendrotweaks.morphology.point_trees import PointTree
9
- from dendrotweaks.morphology.sec_trees import NeuronSection, Section, SectionTree, Domain
10
- from dendrotweaks.morphology.seg_trees import NeuronSegment, Segment, SegmentTree
12
+ # DendroTweaks imports
11
13
  from dendrotweaks.simulators import NeuronSimulator
12
- from dendrotweaks.biophys.groups import SegmentGroup
13
- from dendrotweaks.biophys.mechanisms import Mechanism, LeakChannel, CaDynamics
14
- from dendrotweaks.biophys.io import create_channel, standardize_channel, create_standard_channel
15
14
  from dendrotweaks.biophys.io import MODFileLoader
16
- from dendrotweaks.morphology.io import create_point_tree, create_section_tree, create_segment_tree
17
- from dendrotweaks.stimuli.iclamps import IClamp
15
+ from dendrotweaks.morphology import Domain
16
+ from dendrotweaks.biophys.groups import SegmentGroup
18
17
  from dendrotweaks.biophys.distributions import Distribution
19
- from dendrotweaks.stimuli.populations import Population
20
- from dendrotweaks.utils import calculate_lambda_f, dynamic_import
21
- from dendrotweaks.utils import get_domain_color, timeit
22
-
23
- from collections import OrderedDict, defaultdict
24
- from numpy import nan
25
- # from .logger import logger
26
-
27
18
  from dendrotweaks.path_manager import PathManager
28
19
  import dendrotweaks.morphology.reduce as rdc
20
+ from dendrotweaks.utils import INDEPENDENT_PARAMS, DOMAIN_TO_GROUP, POPULATIONS
21
+ from dendrotweaks.utils import DEFAULT_FIT_MODELS
29
22
 
30
- import pandas as pd
23
+ # Mixins
24
+ from dendrotweaks.model_io import IOMixin
25
+ from dendrotweaks.model_simulation import SimulationMixin
31
26
 
27
+ # Warnings configuration
32
28
  import warnings
33
29
 
34
- POPULATIONS = {'AMPA': {}, 'NMDA': {}, 'AMPA_NMDA': {}, 'GABAa': {}}
35
-
36
30
  def custom_warning_formatter(message, category, filename, lineno, file=None, line=None):
37
31
  return f"WARNING: {message}\n({os.path.basename(filename)}, line {lineno})\n"
38
32
 
39
33
  warnings.formatwarning = custom_warning_formatter
40
34
 
41
- INDEPENDENT_PARAMS = {
42
- 'cm': 1, # uF/cm2
43
- 'Ra': 100, # Ohm cm
44
- 'ena': 50, # mV
45
- 'ek': -77, # mV
46
- 'eca': 140 # mV
47
- }
48
35
 
49
- DOMAIN_TO_GROUP = {
50
- 'soma': 'somatic',
51
- 'axon': 'axonal',
52
- 'dend': 'dendritic',
53
- 'apic': 'apical',
54
- }
55
36
 
56
-
57
- class Model():
37
+ class Model(IOMixin, SimulationMixin):
58
38
  """
59
39
  A model object that represents a neuron model.
60
40
 
41
+ The class incorporates various mixins to separate concerns while
42
+ maintaining a flat interface.
43
+
61
44
  Parameters
62
45
  ----------
63
46
  name : str
@@ -168,6 +151,7 @@ class Model():
168
151
  """
169
152
  return self._name
170
153
 
154
+
171
155
  @property
172
156
  def verbose(self):
173
157
  """
@@ -175,6 +159,7 @@ class Model():
175
159
  """
176
160
  return self._verbose
177
161
 
162
+
178
163
  @verbose.setter
179
164
  def verbose(self, value):
180
165
  self._verbose = value
@@ -191,18 +176,17 @@ class Model():
191
176
 
192
177
 
193
178
  @property
194
- def recordings(self):
179
+ def mechs_to_domains(self):
195
180
  """
196
- The recordings of the model. Reference to the recordings in the simulator.
181
+ The dictionary mapping mechanisms to domains where they are inserted.
197
182
  """
198
- return self.simulator.recordings
199
-
200
-
201
- @recordings.setter
202
- def recordings(self, recordings):
203
- self.simulator.recordings = recordings
204
-
183
+ mechs_to_domains = defaultdict(set)
184
+ for domain_name, mech_names in self.domains_to_mechs.items():
185
+ for mech_name in mech_names:
186
+ mechs_to_domains[mech_name].add(domain_name)
187
+ return dict(mechs_to_domains)
205
188
 
189
+
206
190
  @property
207
191
  def groups(self):
208
192
  """
@@ -225,17 +209,6 @@ class Model():
225
209
  groups_to_parameters[group.name] = params
226
210
  return groups_to_parameters
227
211
 
228
- @property
229
- def mechs_to_domains(self):
230
- """
231
- The dictionary mapping mechanisms to domains where they are inserted.
232
- """
233
- mechs_to_domains = defaultdict(set)
234
- for domain_name, mech_names in self.domains_to_mechs.items():
235
- for mech_name in mech_names:
236
- mechs_to_domains[mech_name].add(domain_name)
237
- return dict(mechs_to_domains)
238
-
239
212
 
240
213
  @property
241
214
  def parameters_to_groups(self):
@@ -291,25 +264,6 @@ class Model():
291
264
  """
292
265
  return {param: value for param, value in self.params.items()
293
266
  if param.startswith('gbar')}
294
- # -----------------------------------------------------------------------
295
- # METADATA
296
- # -----------------------------------------------------------------------
297
-
298
- def info(self):
299
- """
300
- Print information about the model.
301
- """
302
- info_str = (
303
- f"Model: {self.name}\n"
304
- f"Path to data: {self.path_manager.path_to_data}\n"
305
- f"Simulator: {self.simulator_name}\n"
306
- f"Groups: {len(self.groups)}\n"
307
- f"Avaliable mechanisms: {len(self.mechanisms)}\n"
308
- f"Inserted mechanisms: {len(self.mechs_to_params) - 1}\n"
309
- # f"Parameters: {len(self.parameters)}\n"
310
- f"IClamps: {len(self.iclamps)}\n"
311
- )
312
- print(info_str)
313
267
 
314
268
 
315
269
  @property
@@ -331,365 +285,6 @@ class Model():
331
285
  df = pd.DataFrame(data)
332
286
  return df
333
287
 
334
- def print_directory_tree(self, *args, **kwargs):
335
- """
336
- Print the directory tree.
337
- """
338
- return self.path_manager.print_directory_tree(*args, **kwargs)
339
-
340
- def list_morphologies(self, extension='swc'):
341
- """
342
- List the morphologies available for the model.
343
- """
344
- return self.path_manager.list_files('morphology', extension=extension)
345
-
346
- def list_biophys(self, extension='json'):
347
- """
348
- List the biophysical configurations available for the model.
349
- """
350
- return self.path_manager.list_files('biophys', extension=extension)
351
-
352
- def list_mechanisms(self, extension='mod'):
353
- """
354
- List the mechanisms available for the model.
355
- """
356
- return self.path_manager.list_files('mod', extension=extension)
357
-
358
- def list_stimuli(self, extension='json'):
359
- """
360
- List the stimuli configurations available for the model.
361
- """
362
- return self.path_manager.list_files('stimuli', extension=extension)
363
-
364
- # ========================================================================
365
- # MORPHOLOGY
366
- # ========================================================================
367
-
368
- def load_morphology(self, file_name, soma_notation='3PS',
369
- align=True, sort_children=True, force=False) -> None:
370
- """
371
- Read an SWC file and build the SWC and section trees.
372
-
373
- Parameters
374
- ----------
375
- file_name : str
376
- The name of the SWC file to read.
377
- soma_notation : str, optional
378
- The notation of the soma in the SWC file. Can be '3PS' (three-point soma) or '1PS'. Default is '3PS'.
379
- align : bool, optional
380
- Whether to align the morphology to the soma center and align the apical dendrite (if present).
381
- sort_children : bool, optional
382
- Whether to sort the children of each node by increasing subtree size
383
- in the tree sorting algorithms. If True, the traversal visits
384
- children with shorter subtrees first and assigns them lower indices. If False, children
385
- are visited in their original SWC file order (matching NEURON's behavior).
386
- """
387
- # self.name = file_name.split('.')[0]
388
- self.morphology_name = file_name.replace('.swc', '')
389
- path_to_swc_file = self.path_manager.get_file_path('morphology', file_name, extension='swc')
390
- point_tree = create_point_tree(path_to_swc_file)
391
- # point_tree.remove_overlaps()
392
- point_tree.change_soma_notation(soma_notation)
393
- point_tree.sort(sort_children=sort_children, force=force)
394
- if align:
395
- point_tree.shift_coordinates_to_soma_center()
396
- point_tree.align_apical_dendrite()
397
- point_tree.round_coordinates(8)
398
- self.point_tree = point_tree
399
-
400
- sec_tree = create_section_tree(point_tree)
401
- sec_tree.sort(sort_children=sort_children, force=force)
402
- self.sec_tree = sec_tree
403
-
404
- self.create_and_reference_sections_in_simulator()
405
- seg_tree = create_segment_tree(sec_tree)
406
- self.seg_tree = seg_tree
407
-
408
- self._add_default_segment_groups()
409
- self._initialize_domains_to_mechs()
410
-
411
- d_lambda = self.d_lambda
412
- self.set_segmentation(d_lambda=d_lambda)
413
-
414
-
415
- def create_and_reference_sections_in_simulator(self):
416
- """
417
- Create and reference sections in the simulator.
418
- """
419
- if self.verbose: print(f'Building sections in {self.simulator_name}...')
420
- for sec in self.sec_tree.sections:
421
- sec.create_and_reference()
422
- n_sec = len([sec._ref for sec in self.sec_tree.sections
423
- if sec._ref is not None])
424
- if self.verbose: print(f'{n_sec} sections created.')
425
-
426
-
427
-
428
-
429
- def _add_default_segment_groups(self):
430
- self.add_group('all', list(self.domains.keys()))
431
- for domain_name in self.domains:
432
- group_name = DOMAIN_TO_GROUP.get(domain_name, domain_name)
433
- self.add_group(group_name, [domain_name])
434
-
435
-
436
- def _initialize_domains_to_mechs(self):
437
- for domain_name in self.domains:
438
- # Only if haven't been defined for the previous morphology
439
- # TODO: Check that domains match
440
- if not domain_name in self.domains_to_mechs:
441
- self.domains_to_mechs[domain_name] = set()
442
- for domain_name, mech_names in self.domains_to_mechs.items():
443
- for mech_name in mech_names:
444
- mech = self.mechanisms[mech_name]
445
- self.insert_mechanism(mech, domain_name)
446
-
447
-
448
- def get_sections(self, filter_function):
449
- """Filter sections using a lambda function.
450
-
451
- Parameters
452
- ----------
453
- filter_function : Callable
454
- The lambda function to filter sections.
455
- """
456
- return [sec for sec in self.sec_tree.sections if filter_function(sec)]
457
-
458
-
459
- def get_segments(self, group_names=None):
460
- """
461
- Get the segments in specified groups.
462
-
463
- Parameters
464
- ----------
465
- group_names : List[str]
466
- The names of the groups to get segments from.
467
- """
468
- if not isinstance(group_names, list):
469
- raise ValueError('Group names must be a list.')
470
- return [seg for group_name in group_names for seg in self.seg_tree.segments if seg in self.groups[group_name]]
471
-
472
- # ========================================================================
473
- # SEGMENTATION
474
- # ========================================================================
475
-
476
- # TODO Make a context manager for this
477
- def _temp_clear_stimuli(self):
478
- """
479
- Temporarily save and clear stimuli.
480
- """
481
- self.export_stimuli(file_name='_temp_stimuli')
482
- self.remove_all_stimuli()
483
- self.remove_all_recordings()
484
-
485
- def _temp_reload_stimuli(self):
486
- """
487
- Load stimuli from a temporary file and clean up.
488
- """
489
- self.load_stimuli(file_name='_temp_stimuli')
490
- for ext in ['json', 'csv']:
491
- temp_path = self.path_manager.get_file_path('stimuli', '_temp_stimuli', extension=ext)
492
- if os.path.exists(temp_path):
493
- os.remove(temp_path)
494
-
495
- def set_segmentation(self, d_lambda=0.1, f=100):
496
- """
497
- Set the number of segments in each section based on the geometry.
498
-
499
- Parameters
500
- ----------
501
- d_lambda : float
502
- The lambda value to use.
503
- f : float
504
- The frequency value to use.
505
- """
506
- self.d_lambda = d_lambda
507
-
508
- # Temporarily save and clear stimuli
509
- self._temp_clear_stimuli()
510
-
511
- # Pre-distribute parameters needed for lambda_f calculation
512
- for param_name in ['cm', 'Ra']:
513
- self.distribute(param_name)
514
-
515
- # Calculate lambda_f and set nseg for each section
516
- for sec in self.sec_tree.sections:
517
- lambda_f = calculate_lambda_f(sec.distances, sec.diameters, sec.Ra, sec.cm, f)
518
- nseg = max(1, int((sec.L / (d_lambda * lambda_f) + 0.9) / 2) * 2 + 1)
519
- sec._nseg = sec._ref.nseg = nseg
520
-
521
- # Rebuild the segment tree and redistribute parameters
522
- self.seg_tree = create_segment_tree(self.sec_tree)
523
- self.distribute_all()
524
-
525
- # Reload stimuli and clean up temporary files
526
- self._temp_reload_stimuli()
527
-
528
-
529
- # ========================================================================
530
- # MECHANISMS
531
- # ========================================================================
532
-
533
- def add_default_mechanisms(self, recompile=False):
534
- """
535
- Add default mechanisms to the model.
536
-
537
- Parameters
538
- ----------
539
- recompile : bool, optional
540
- Whether to recompile the mechanisms.
541
- """
542
- leak = LeakChannel()
543
- self.mechanisms[leak.name] = leak
544
-
545
- cadyn = CaDynamics()
546
- self.mechanisms[cadyn.name] = cadyn
547
-
548
- self.load_mechanisms('default_mod', recompile=recompile)
549
-
550
-
551
- def add_mechanisms(self, dir_name:str = 'mod', recompile=True) -> None:
552
- """
553
- Add a set of mechanisms from an archive to the model.
554
-
555
- Parameters
556
- ----------
557
- dir_name : str, optional
558
- The name of the archive to load mechanisms from. Default is 'mod'.
559
- recompile : bool, optional
560
- Whether to recompile the mechanisms.
561
- """
562
- # Create Mechanism objects and add them to the model
563
- for mechanism_name in self.path_manager.list_files(dir_name, extension='mod'):
564
- self.add_mechanism(mechanism_name,
565
- load=True,
566
- dir_name=dir_name,
567
- recompile=recompile)
568
-
569
-
570
-
571
- def add_mechanism(self, mechanism_name: str,
572
- python_template_name: str = 'default',
573
- load=True, dir_name: str = 'mod', recompile=True
574
- ) -> None:
575
- """
576
- Create a Mechanism object from the MOD file (or LeakChannel).
577
-
578
- Parameters
579
- ----------
580
- mechanism_name : str
581
- The name of the mechanism to add.
582
- python_template_name : str, optional
583
- The name of the Python template to use. Default is 'default'.
584
- load : bool, optional
585
- Whether to load the mechanism using neuron.load_mechanisms.
586
- """
587
- paths = self.path_manager.get_channel_paths(
588
- mechanism_name,
589
- python_template_name=python_template_name
590
- )
591
- mech = create_channel(**paths)
592
- # Add the mechanism to the model
593
- self.mechanisms[mech.name] = mech
594
- # Update the global parameters
595
-
596
- if load:
597
- self.load_mechanism(mechanism_name, dir_name, recompile)
598
-
599
-
600
-
601
- def load_mechanisms(self, dir_name: str = 'mod', recompile=True) -> None:
602
- """
603
- Load mechanisms from an archive.
604
-
605
- Parameters
606
- ----------
607
- dir_name : str, optional
608
- The name of the archive to load mechanisms from.
609
- recompile : bool, optional
610
- Whether to recompile the mechanisms.
611
- """
612
- mod_files = self.path_manager.list_files(dir_name, extension='mod')
613
- for mechanism_name in mod_files:
614
- self.load_mechanism(mechanism_name, dir_name, recompile)
615
-
616
-
617
- def load_mechanism(self, mechanism_name, dir_name='mod', recompile=False) -> None:
618
- """
619
- Load a mechanism from the specified archive.
620
-
621
- Parameters
622
- ----------
623
- mechanism_name : str
624
- The name of the mechanism to load.
625
- dir_name : str, optional
626
- The name of the directory to load the mechanism from. Default is 'mod'.
627
- recompile : bool, optional
628
- Whether to recompile the mechanism.
629
- """
630
- path_to_mod_file = self.path_manager.get_file_path(
631
- dir_name, mechanism_name, extension='mod'
632
- )
633
- self.mod_loader.load_mechanism(
634
- path_to_mod_file=path_to_mod_file, recompile=recompile
635
- )
636
-
637
-
638
- def standardize_channel(self, channel_name,
639
- python_template_name=None, mod_template_name=None, remove_old=True):
640
- """
641
- Standardize a channel by creating a new channel with the same kinetic
642
- properties using the standard equations.
643
-
644
- Parameters
645
- ----------
646
- channel_name : str
647
- The name of the channel to standardize.
648
- python_template_name : str, optional
649
- The name of the Python template to use.
650
- mod_template_name : str, optional
651
- The name of the MOD template to use.
652
- remove_old : bool, optional
653
- Whether to remove the old channel from the model. Default is True.
654
- """
655
-
656
- # Get data to transfer
657
- channel = self.mechanisms[channel_name]
658
- channel_domain_names = [domain_name for domain_name, mech_names
659
- in self.domains_to_mechs.items() if channel_name in mech_names]
660
- gbar_name = f'gbar_{channel_name}'
661
- gbar_distributions = self.params[gbar_name]
662
- # Kinetic variables cannot be transferred
663
-
664
- # Uninsert the old channel
665
- for domain_name in self.domains:
666
- if channel_name in self.domains_to_mechs[domain_name]:
667
- self.uninsert_mechanism(channel_name, domain_name)
668
-
669
- # Remove the old channel
670
- if remove_old:
671
- self.mechanisms.pop(channel_name)
672
-
673
- # Create, add and load a new channel
674
- paths = self.path_manager.get_standard_channel_paths(
675
- channel_name,
676
- mod_template_name=mod_template_name
677
- )
678
- standard_channel = standardize_channel(channel, **paths)
679
-
680
- self.mechanisms[standard_channel.name] = standard_channel
681
- self.load_mechanism(standard_channel.name, recompile=True)
682
-
683
- # Insert the new channel
684
- for domain_name in channel_domain_names:
685
- self.insert_mechanism(standard_channel.name, domain_name)
686
-
687
- # Transfer data
688
- gbar_name = f'gbar_{standard_channel.name}'
689
- for group_name, distribution in gbar_distributions.items():
690
- self.set_param(gbar_name, group_name,
691
- distribution.function_name, **distribution.parameters)
692
-
693
288
 
694
289
  # ========================================================================
695
290
  # DOMAINS
@@ -816,9 +411,9 @@ class Model():
816
411
  self.remove_group(group.name)
817
412
 
818
413
 
819
- # -----------------------------------------------------------------------
820
- # INSERT / UNINSERT MECHANISMS
821
- # -----------------------------------------------------------------------
414
+ # ========================================================================
415
+ # MECHANISMS
416
+ # ========================================================================
822
417
 
823
418
  def insert_mechanism(self, mechanism_name: str,
824
419
  domain_name: str, distribute=True):
@@ -922,11 +517,11 @@ class Model():
922
517
 
923
518
 
924
519
  # ========================================================================
925
- # SET PARAMETERS
520
+ # PARAMETERS
926
521
  # ========================================================================
927
522
 
928
523
  # -----------------------------------------------------------------------
929
- # GROUPS
524
+ # SEGMENT GROUPS (Where)
930
525
  # -----------------------------------------------------------------------
931
526
 
932
527
  def add_group(self, name, domains, select_by=None, min_value=None, max_value=None):
@@ -1001,7 +596,7 @@ class Model():
1001
596
 
1002
597
 
1003
598
  # -----------------------------------------------------------------------
1004
- # DISTRIBUTIONS
599
+ # DISTRIBUTIONS (How)
1005
600
  # -----------------------------------------------------------------------
1006
601
 
1007
602
  def set_param(self, param_name: str,
@@ -1063,6 +658,7 @@ class Model():
1063
658
  distribution = Distribution(distr_type, **distr_params)
1064
659
  self.params[param_name][group_name] = distribution
1065
660
 
661
+
1066
662
  def distribute_all(self):
1067
663
  """
1068
664
  Distribute all parameters to the segments.
@@ -1146,230 +742,152 @@ class Model():
1146
742
  self.params[param_name].pop(group_name, None)
1147
743
  self.distribute(param_name)
1148
744
 
1149
- # def set_section_param(self, param_name, value, domains=None):
1150
-
1151
- # domains = domains or self.domains
1152
- # for sec in self.sec_tree.sections:
1153
- # if sec.domain in domains:
1154
- # setattr(sec._ref, param_name, value)
1155
-
1156
- # ========================================================================
1157
- # STIMULI
1158
- # ========================================================================
1159
745
 
1160
746
  # -----------------------------------------------------------------------
1161
- # ICLAMPS
747
+ # FITTING
1162
748
  # -----------------------------------------------------------------------
1163
749
 
1164
- def add_iclamp(self, sec, loc, amp=0, delay=100, dur=100):
1165
- """
1166
- Add an IClamp to a section.
750
+ def fit_distribution(self, param_name: str, segments, candidate_models=None, plot=True):
751
+ if candidate_models is None:
752
+ candidate_models = DEFAULT_FIT_MODELS
1167
753
 
1168
- Parameters
1169
- ----------
1170
- sec : Section
1171
- The section to add the IClamp to.
1172
- loc : float
1173
- The location of the IClamp in the section.
1174
- amp : float, optional
1175
- The amplitude of the IClamp. Default is 0.
1176
- delay : float, optional
1177
- The delay of the IClamp. Default is 100.
1178
- dur : float, optional
1179
- The duration of the IClamp. Default is 100.
1180
- """
1181
- seg = sec(loc)
1182
- if self.iclamps.get(seg):
1183
- self.remove_iclamp(sec, loc)
1184
- iclamp = IClamp(sec, loc, amp, delay, dur)
1185
- print(f'IClamp added to sec {sec} at loc {loc}.')
1186
- self.iclamps[seg] = iclamp
754
+ from dendrotweaks.utils import mse
1187
755
 
756
+ values = [seg.get_param_value(param_name) for seg in segments]
757
+ if all(np.isnan(values)):
758
+ return None
1188
759
 
1189
- def remove_iclamp(self, sec, loc):
1190
- """
1191
- Remove an IClamp from a section.
760
+ distances = [seg.path_distance() for seg in segments]
761
+ distances, values = zip(*sorted(zip(distances, values)))
1192
762
 
1193
- Parameters
1194
- ----------
1195
- sec : Section
1196
- The section to remove the IClamp from.
1197
- loc : float
1198
- The location of the IClamp in the section.
1199
- """
1200
- seg = sec(loc)
1201
- if self.iclamps.get(seg):
1202
- self.iclamps.pop(seg)
763
+ best_score = float('inf')
764
+ best_model = None
765
+ best_params = None
766
+ best_pred = None
1203
767
 
768
+ results = []
1204
769
 
1205
- def remove_all_iclamps(self):
1206
- """
1207
- Remove all IClamps from the model.
1208
- """
770
+ for name, model in candidate_models.items():
771
+ try:
772
+ params, pred_values = model['fit'](distances, values)
773
+ score = model.get('score', mse)(values, pred_values)
774
+ complexity = model.get('complexity', 1)(params)
775
+ results.append((name, score, params, complexity, pred_values))
776
+ except Exception as e:
777
+ warnings.warn(f"Model {name} failed to fit: {e}")
1209
778
 
1210
- for seg in list(self.iclamps.keys()):
1211
- sec, loc = seg._section, seg.x
1212
- self.remove_iclamp(sec, loc)
1213
- if self.iclamps:
1214
- warnings.warn(f'Not all iclamps were removed: {self.iclamps}')
1215
- self.iclamps = {}
779
+ # Sort results by score and complexity
780
+ results.sort(key=lambda x: (np.round(x[1], 10), x[3]))
1216
781
 
782
+ best_model, best_score, best_params, _, best_pred = results[0]
1217
783
 
1218
- # -----------------------------------------------------------------------
1219
- # SYNAPSES
1220
- # -----------------------------------------------------------------------
784
+ if plot:
785
+ self.plot_param(param_name, show_nan=False)
786
+ plt.plot(distances, best_pred, label=f'Best Fit: {best_model}', color='red', linestyle='--')
787
+ plt.legend()
1221
788
 
1222
- def _add_population(self, population):
1223
- self.populations[population.syn_type][population.name] = population
789
+ return {'model': best_model, 'params': best_params, 'score': best_score}
1224
790
 
1225
791
 
1226
- def add_population(self, segments, N, syn_type):
1227
- """
1228
- Add a population of synapses to the model.
792
+ def _set_distribution(self, param_name, group_name, fit_result, plot=False):
793
+ if fit_result is None:
794
+ warnings.warn(f"No valid fit found for parameter {param_name}. Skipping distribution assignment.")
795
+ return
1229
796
 
1230
- Parameters
1231
- ----------
1232
- segments : list[Segment]
1233
- The segments to add the synapses to.
1234
- N : int
1235
- The number of synapses to add.
1236
- syn_type : str
1237
- The type of synapse to add.
1238
- """
1239
- idx = len(self.populations[syn_type])
1240
- population = Population(idx, segments, N, syn_type)
1241
- population.allocate_synapses()
1242
- population.create_inputs()
1243
- self._add_population(population)
797
+ model_type = fit_result['model']
798
+ params = fit_result['params']
1244
799
 
800
+ if model_type == 'poly':
801
+ coeffs = np.array(params)
802
+ coeffs = np.where(np.round(coeffs) == 0, coeffs, np.round(coeffs, 10))
803
+ if len(coeffs) == 1:
804
+ self.params[param_name][group_name] = Distribution('constant', value=coeffs[0])
805
+ elif len(coeffs) == 2:
806
+ self.params[param_name][group_name] = Distribution('linear', slope=coeffs[0], intercept=coeffs[1])
807
+ else:
808
+ self.params[param_name][group_name] = Distribution('polynomial', coeffs=coeffs.tolist())
1245
809
 
1246
- def update_population_kinetic_params(self, pop_name, **params):
1247
- """
1248
- Update the kinetic parameters of a population of synapses.
810
+ elif model_type == 'step':
811
+ start, end, min_value, max_value = params
812
+ self.params[param_name][group_name] = Distribution('step', max_value=max_value, min_value=min_value, start=start, end=end)
1249
813
 
1250
- Parameters
1251
- ----------
1252
- pop_name : str
1253
- The name of the population.
1254
- params : dict
1255
- The parameters to update.
1256
- """
1257
- syn_type, idx = pop_name.rsplit('_', 1)
1258
- population = self.populations[syn_type][pop_name]
1259
- population.update_kinetic_params(**params)
1260
- print(population.kinetic_params)
1261
814
 
1262
-
1263
- def update_population_input_params(self, pop_name, **params):
815
+ # -----------------------------------------------------------------------
816
+ # PLOTTING
817
+ # -----------------------------------------------------------------------
818
+
819
+ def plot_param(self, param_name, ax=None, show_nan=True):
1264
820
  """
1265
- Update the input parameters of a population of synapses.
821
+ Plot the distribution of a parameter in the model.
1266
822
 
1267
823
  Parameters
1268
824
  ----------
1269
- pop_name : str
1270
- The name of the population.
1271
- params : dict
1272
- The parameters to update.
825
+ param_name : str
826
+ The name of the parameter to plot.
827
+ ax : matplotlib.axes.Axes, optional
828
+ The axes to plot on. Default is None.
829
+ show_nan : bool, optional
830
+ Whether to show NaN values. Default is True.
1273
831
  """
1274
- syn_type, idx = pop_name.rsplit('_', 1)
1275
- population = self.populations[syn_type][pop_name]
1276
- population.update_input_params(**params)
1277
- print(population.input_params)
832
+ from dendrotweaks.utils import get_domain_color
1278
833
 
834
+ if ax is None:
835
+ fig, ax = plt.subplots(figsize=(10, 2))
1279
836
 
1280
- def remove_population(self, name):
1281
- """
1282
- Remove a population of synapses from the model.
1283
-
1284
- Parameters
1285
- ----------
1286
- name : str
1287
- The name of the population
1288
- """
1289
- syn_type, idx = name.rsplit('_', 1)
1290
- population = self.populations[syn_type].pop(name)
1291
- population.clean()
1292
-
1293
- def remove_all_populations(self):
1294
- """
1295
- Remove all populations of synapses from the model.
1296
- """
1297
- for syn_type in self.populations:
1298
- for name in list(self.populations[syn_type].keys()):
1299
- self.remove_population(name)
1300
- if any(self.populations.values()):
1301
- warnings.warn(f'Not all populations were removed: {self.populations}')
1302
- self.populations = POPULATIONS
837
+ if param_name not in self.params:
838
+ warnings.warn(f'Parameter {param_name} not found.')
1303
839
 
1304
- def remove_all_stimuli(self):
1305
- """
1306
- Remove all stimuli from the model.
1307
- """
1308
- self.remove_all_iclamps()
1309
- self.remove_all_populations()
840
+ values = [(seg.path_distance(), seg.get_param_value(param_name)) for seg in self.seg_tree]
841
+ colors = [get_domain_color(seg.domain) for seg in self.seg_tree]
1310
842
 
1311
- # ========================================================================
1312
- # SIMULATION
1313
- # ========================================================================
843
+ valid_values = [(x, y) for (x, y), color in zip(values, colors) if not pd.isna(y) and y != 0]
844
+ zero_values = [(x, y) for (x, y), color in zip(values, colors) if y == 0]
845
+ nan_values = [(x, 0) for (x, y), color in zip(values, colors) if pd.isna(y)]
846
+ valid_colors = [color for (x, y), color in zip(values, colors) if not pd.isna(y) and y != 0]
847
+ zero_colors = [color for (x, y), color in zip(values, colors) if y == 0]
848
+ nan_colors = [color for (x, y), color in zip(values, colors) if pd.isna(y)]
1314
849
 
1315
- def add_recording(self, sec, loc, var='v'):
1316
- """
1317
- Add a recording to the model.
850
+ if valid_values:
851
+ ax.scatter(*zip(*valid_values), c=valid_colors)
852
+ if zero_values:
853
+ ax.scatter(*zip(*zero_values), edgecolors=zero_colors, facecolors='none', marker='.')
854
+ if nan_values and show_nan:
855
+ ax.scatter(*zip(*nan_values), c=nan_colors, marker='x', alpha=0.5, zorder=0)
856
+ ax.axhline(y=0, color='k', linestyle='--')
1318
857
 
1319
- Parameters
1320
- ----------
1321
- sec : Section
1322
- The section to record from.
1323
- loc : float
1324
- The location along the normalized section length to record from.
1325
- var : str, optional
1326
- The variable to record. Default is 'v'.
1327
- """
1328
- self.simulator.add_recording(sec, loc, var)
1329
- print(f'Recording added to sec {sec} at loc {loc}.')
858
+ ax.set_xlabel('Path distance')
859
+ ax.set_ylabel(param_name)
860
+ ax.set_title(f'{param_name} distribution')
1330
861
 
1331
862
 
1332
- def remove_recording(self, sec, loc, var='v'):
1333
- """
1334
- Remove a recording from the model.
863
+ # ========================================================================
864
+ # MORPHOLOGY
865
+ # ========================================================================
1335
866
 
867
+ def get_sections(self, filter_function):
868
+ """Filter sections using a lambda function.
869
+
1336
870
  Parameters
1337
871
  ----------
1338
- sec : Section
1339
- The section to remove the recording from.
1340
- loc : float
1341
- The location along the normalized section length to remove the recording from.
1342
- """
1343
- self.simulator.remove_recording(sec, loc, var)
1344
-
1345
-
1346
- def remove_all_recordings(self, var=None):
1347
- """
1348
- Remove all recordings from the model.
872
+ filter_function : Callable
873
+ The lambda function to filter sections.
1349
874
  """
1350
- self.simulator.remove_all_recordings(var=var)
875
+ return [sec for sec in self.sec_tree.sections if filter_function(sec)]
1351
876
 
1352
877
 
1353
- def run(self, duration=300):
878
+ def get_segments(self, group_names=None):
1354
879
  """
1355
- Run the simulation for a specified duration.
880
+ Get the segments in specified groups.
1356
881
 
1357
882
  Parameters
1358
883
  ----------
1359
- duration : float, optional
1360
- The duration of the simulation. Default is 300.
884
+ group_names : List[str]
885
+ The names of the groups to get segments from.
1361
886
  """
1362
- self.simulator.run(duration)
1363
-
1364
- def get_traces(self):
1365
- return self.simulator.get_traces()
887
+ if not isinstance(group_names, list):
888
+ raise ValueError('Group names must be a list.')
889
+ return [seg for group_name in group_names for seg in self.seg_tree.segments if seg in self.groups[group_name]]
1366
890
 
1367
- def plot(self, *args, **kwargs):
1368
- self.simulator.plot(*args, **kwargs)
1369
-
1370
- # ========================================================================
1371
- # MORPHOLOGY
1372
- # ========================================================================
1373
891
 
1374
892
  def remove_subtree(self, sec):
1375
893
  """
@@ -1498,16 +1016,13 @@ class Model():
1498
1016
  return
1499
1017
 
1500
1018
  root_segs = [seg for seg in root.segments]
1501
- params_to_coeffs = {}
1019
+ params_to_fits = {}
1502
1020
  # for param_name in self.params:
1503
1021
  common_mechs.add('Independent')
1504
1022
  for mech in common_mechs:
1505
1023
  for param_name in self.mechs_to_params[mech]:
1506
- coeffs = self.fit_distribution(param_name, segments=root_segs, plot=False)
1507
- if coeffs is None:
1508
- warnings.warn(f'Cannot fit distribution for parameter {param_name}. No values found.')
1509
- continue
1510
- params_to_coeffs[param_name] = coeffs
1024
+ fit_result = self.fit_distribution(param_name, segments=root_segs, plot=False)
1025
+ params_to_fits[param_name] = fit_result
1511
1026
 
1512
1027
 
1513
1028
  # Create new domain
@@ -1527,9 +1042,9 @@ class Model():
1527
1042
 
1528
1043
 
1529
1044
  # # Fit distributions to data for the group
1530
- for param_name, coeffs in params_to_coeffs.items():
1531
- self._set_distribution(param_name, group_name, coeffs, plot=True)
1532
-
1045
+ for param_name, fit_result in params_to_fits.items():
1046
+ self._set_distribution(param_name, group_name, fit_result, plot=True)
1047
+
1533
1048
  # # Distribute parameters
1534
1049
  self.distribute_all()
1535
1050
 
@@ -1540,464 +1055,5 @@ class Model():
1540
1055
  'segs_to_locs': segs_to_locs,
1541
1056
  'segs_to_reduced_segs': segs_to_reduced_segs,
1542
1057
  'reduced_segs_to_params': reduced_segs_to_params,
1543
- 'params_to_coeffs': params_to_coeffs
1544
- }
1545
-
1546
-
1547
- def fit_distribution(self, param_name, segments, max_degree=20, tolerance=1e-7, plot=False):
1548
- from numpy import polyfit, polyval
1549
- values = [seg.get_param_value(param_name) for seg in segments]
1550
- # if all values are NaN, return None
1551
- if all(np.isnan(values)):
1552
- return None
1553
- distances = [seg.path_distance() for seg in segments]
1554
- sorted_pairs = sorted(zip(distances, values))
1555
- distances, values = zip(*sorted_pairs)
1556
- degrees = range(0, max_degree+1)
1557
- for degree in degrees:
1558
- coeffs = polyfit(distances, values, degree)
1559
- residuals = values - polyval(coeffs, distances)
1560
- if all(abs(residuals) < tolerance):
1561
- break
1562
- if not all(abs(residuals) < tolerance):
1563
- warnings.warn(f'Fitting failed for parameter {param_name} with the provided tolerance.\nUsing the last valid fit (degree={degree}). Maximum residual: {max(abs(residuals))}')
1564
- if plot and degree > 0:
1565
- self.plot_param(param_name, show_nan=False)
1566
- plt.plot(distances, polyval(coeffs, distances), label='Fitted', color='red', linestyle='--')
1567
- plt.legend()
1568
- return coeffs
1569
-
1570
-
1571
- def _set_distribution(self, param_name, group_name, coeffs, plot=False):
1572
- # Set the distribution based on the degree of the polynomial fit
1573
- coeffs = np.where(np.round(coeffs) == 0, coeffs, np.round(coeffs, 10))
1574
- if len(coeffs) == 1:
1575
- self.params[param_name][group_name] = Distribution('constant', value=coeffs[0])
1576
- elif len(coeffs) == 2:
1577
- self.params[param_name][group_name] = Distribution('linear', slope=coeffs[0], intercept=coeffs[1])
1578
- else:
1579
- self.params[param_name][group_name] = Distribution('polynomial', coeffs=coeffs.tolist())
1580
-
1581
-
1582
- # ========================================================================
1583
- # PLOTTING
1584
- # ========================================================================
1585
-
1586
- def plot_param(self, param_name, ax=None, show_nan=True):
1587
- """
1588
- Plot the distribution of a parameter in the model.
1589
-
1590
- Parameters
1591
- ----------
1592
- param_name : str
1593
- The name of the parameter to plot.
1594
- ax : matplotlib.axes.Axes, optional
1595
- The axes to plot on. Default is None.
1596
- show_nan : bool, optional
1597
- Whether to show NaN values. Default is True.
1598
- """
1599
- if ax is None:
1600
- fig, ax = plt.subplots(figsize=(10, 2))
1601
-
1602
- if param_name not in self.params:
1603
- warnings.warn(f'Parameter {param_name} not found.')
1604
-
1605
- values = [(seg.path_distance(), seg.get_param_value(param_name)) for seg in self.seg_tree]
1606
- colors = [get_domain_color(seg.domain) for seg in self.seg_tree]
1607
-
1608
- valid_values = [(x, y) for (x, y), color in zip(values, colors) if not pd.isna(y) and y != 0]
1609
- zero_values = [(x, y) for (x, y), color in zip(values, colors) if y == 0]
1610
- nan_values = [(x, 0) for (x, y), color in zip(values, colors) if pd.isna(y)]
1611
- valid_colors = [color for (x, y), color in zip(values, colors) if not pd.isna(y) and y != 0]
1612
- zero_colors = [color for (x, y), color in zip(values, colors) if y == 0]
1613
- nan_colors = [color for (x, y), color in zip(values, colors) if pd.isna(y)]
1614
-
1615
- if valid_values:
1616
- ax.scatter(*zip(*valid_values), c=valid_colors)
1617
- if zero_values:
1618
- ax.scatter(*zip(*zero_values), edgecolors=zero_colors, facecolors='none', marker='.')
1619
- if nan_values and show_nan:
1620
- ax.scatter(*zip(*nan_values), c=nan_colors, marker='x', alpha=0.5, zorder=0)
1621
- plt.axhline(y=0, color='k', linestyle='--')
1622
-
1623
- ax.set_xlabel('Path distance')
1624
- ax.set_ylabel(param_name)
1625
- ax.set_title(f'{param_name} distribution')
1626
-
1627
-
1628
- # ========================================================================
1629
- # FILE EXPORT
1630
- # ========================================================================
1631
-
1632
- def export_morphology(self, file_name):
1633
- """
1634
- Write the SWC tree to an SWC file.
1635
-
1636
- Parameters
1637
- ----------
1638
- version : str, optional
1639
- The version of the morphology appended to the morphology name.
1640
- """
1641
- path_to_file = self.path_manager.get_file_path('morphology', file_name, extension='swc')
1642
-
1643
- self.point_tree.to_swc(path_to_file)
1644
-
1645
-
1646
- def to_dict(self):
1647
- """
1648
- Return a dictionary representation of the model.
1649
-
1650
- Returns
1651
- -------
1652
- dict
1653
- The dictionary representation of the model.
1654
- """
1655
- return {
1656
- 'metadata': {
1657
- 'name': self.name,
1658
- },
1659
- 'd_lambda': self.d_lambda,
1660
- 'domains': {domain: sorted(list(mechs)) for domain, mechs in self.domains_to_mechs.items()},
1661
- 'groups': [
1662
- group.to_dict() for group in self._groups
1663
- ],
1664
- 'params': {
1665
- param_name: {
1666
- group_name: distribution if isinstance(distribution, str) else distribution.to_dict()
1667
- for group_name, distribution in distributions.items()
1668
- }
1669
- for param_name, distributions in self.params.items()
1670
- },
1671
- }
1672
-
1673
- def from_dict(self, data):
1674
- """
1675
- Load the model from a dictionary.
1676
-
1677
- Parameters
1678
- ----------
1679
- data : dict
1680
- The dictionary representation of the model.
1681
- """
1682
- if not self.name == data['metadata']['name']:
1683
- raise ValueError('Model name does not match the data.')
1684
-
1685
- self.d_lambda = data['d_lambda']
1686
-
1687
- # Domains and mechanisms
1688
- self.domains_to_mechs = {
1689
- domain: set(mechs) for domain, mechs in data['domains'].items()
1690
- }
1691
- if self.verbose: print('Inserting mechanisms...')
1692
- for domain_name, mechs in self.domains_to_mechs.items():
1693
- for mech_name in mechs:
1694
- self.insert_mechanism(mech_name, domain_name, distribute=False)
1695
- # print('Distributing parameters...')
1696
- # self.distribute_all()
1697
-
1698
- # Groups
1699
- if self.verbose: print('Adding groups...')
1700
- self._groups = [SegmentGroup.from_dict(group) for group in data['groups']]
1701
-
1702
- if self.verbose: print('Distributing parameters...')
1703
- # Parameters
1704
- self.params = {
1705
- param_name: {
1706
- group_name: distribution if isinstance(distribution, str) else Distribution.from_dict(distribution)
1707
- for group_name, distribution in distributions.items()
1708
- }
1709
- for param_name, distributions in data['params'].items()
1710
- }
1711
-
1712
- if self.verbose: print('Setting segmentation...')
1713
- if self.sec_tree is not None:
1714
- d_lambda = self.d_lambda
1715
- self.set_segmentation(d_lambda=d_lambda)
1716
-
1717
-
1718
-
1719
- def export_biophys(self, file_name, **kwargs):
1720
- """
1721
- Export the biophysical properties of the model to a JSON file.
1722
-
1723
- Parameters
1724
- ----------
1725
- file_name : str
1726
- The name of the file to write to.
1727
- **kwargs : dict
1728
- Additional keyword arguments to pass to `json.dump`.
1729
- """
1730
-
1731
- path_to_json = self.path_manager.get_file_path('biophys', file_name, extension='json')
1732
- if not kwargs.get('indent'):
1733
- kwargs['indent'] = 4
1734
-
1735
- data = self.to_dict()
1736
- with open(path_to_json, 'w') as f:
1737
- json.dump(data, f, **kwargs)
1738
-
1739
-
1740
- def load_biophys(self, file_name, recompile=True):
1741
- """
1742
- Load the biophysical properties of the model from a JSON file.
1743
-
1744
- Parameters
1745
- ----------
1746
- file_name : str
1747
- The name of the file to read from.
1748
- recompile : bool, optional
1749
- Whether to recompile the mechanisms after loading. Default is True.
1750
- """
1751
- self.add_default_mechanisms()
1752
-
1753
-
1754
- path_to_json = self.path_manager.get_file_path('biophys', file_name, extension='json')
1755
-
1756
- with open(path_to_json, 'r') as f:
1757
- data = json.load(f)
1758
-
1759
- for mech_name in {mech for mechs in data['domains'].values() for mech in mechs}:
1760
- if mech_name in ['Leak', 'CaDyn', 'Independent']:
1761
- continue
1762
- self.add_mechanism(mech_name, dir_name='mod', recompile=recompile)
1763
-
1764
- self.from_dict(data)
1765
-
1766
-
1767
- def stimuli_to_dict(self):
1768
- """
1769
- Convert the stimuli to a dictionary representation.
1770
-
1771
- Returns
1772
- -------
1773
- dict
1774
- The dictionary representation of the stimuli.
1775
- """
1776
- return {
1777
- 'metadata': {
1778
- 'name': self.name,
1779
- },
1780
- 'simulation': {
1781
- **self.simulator.to_dict(),
1782
- },
1783
- 'stimuli': {
1784
- 'recordings': [
1785
- {
1786
- 'name': f'rec_{i}',
1787
- 'var': var
1788
- }
1789
- for var, recs in self.simulator.recordings.items()
1790
- for i, _ in enumerate(recs)
1791
- ],
1792
- 'iclamps': [
1793
- {
1794
- 'name': f'iclamp_{i}',
1795
- 'amp': iclamp.amp,
1796
- 'delay': iclamp.delay,
1797
- 'dur': iclamp.dur
1798
- }
1799
- for i, (seg, iclamp) in enumerate(self.iclamps.items())
1800
- ],
1801
- 'populations': {
1802
- syn_type: [pop.to_dict() for pop in pops.values()]
1803
- for syn_type, pops in self.populations.items()
1804
- }
1805
- },
1806
- }
1807
-
1808
-
1809
- def _stimuli_to_csv(self, path_to_csv=None):
1810
- """
1811
- Write the model to a CSV file.
1812
-
1813
- Parameters
1814
- ----------
1815
- path_to_csv : str
1816
- The path to the CSV file to write.
1817
- """
1818
-
1819
- rec_data = {
1820
- 'type': [],
1821
- 'idx': [],
1822
- 'sec_idx': [],
1823
- 'loc': [],
1824
- }
1825
- for var, recs in self.simulator.recordings.items():
1826
- rec_data['type'].extend(['rec'] * len(recs))
1827
- rec_data['idx'].extend([i for i in range(len(recs))])
1828
- rec_data['sec_idx'].extend([seg._section.idx for seg in recs])
1829
- rec_data['loc'].extend([seg.x for seg in recs])
1830
-
1831
- iclamp_data = {
1832
- 'type': ['iclamp'] * len(self.iclamps),
1833
- 'idx': [i for i in range(len(self.iclamps))],
1834
- 'sec_idx': [seg._section.idx for seg in self.iclamps],
1835
- 'loc': [seg.x for seg in self.iclamps],
1836
- }
1837
-
1838
- synapses_data = {
1839
- 'type': [],
1840
- 'idx': [],
1841
- 'sec_idx': [],
1842
- 'loc': [],
1843
- }
1844
-
1845
- for syn_type, pops in self.populations.items():
1846
- for pop_name, pop in pops.items():
1847
- pop_data = pop.to_csv()
1848
- synapses_data['type'] += pop_data['syn_type']
1849
- synapses_data['idx'] += [int(name.rsplit('_', 1)[1]) for name in pop_data['name']]
1850
- synapses_data['sec_idx'] += pop_data['sec_idx']
1851
- synapses_data['loc'] += pop_data['loc']
1852
-
1853
- df = pd.concat([
1854
- pd.DataFrame(rec_data),
1855
- pd.DataFrame(iclamp_data),
1856
- pd.DataFrame(synapses_data)
1857
- ], ignore_index=True)
1858
- df['idx'] = df['idx'].astype(int)
1859
- df['sec_idx'] = df['sec_idx'].astype(int)
1860
- if path_to_csv: df.to_csv(path_to_csv, index=False)
1861
-
1862
- return df
1863
-
1864
-
1865
- def export_stimuli(self, file_name, **kwargs):
1866
- """
1867
- Export the stimuli to a JSON and CSV file.
1868
-
1869
- Parameters
1870
- ----------
1871
- file_name : str
1872
- The name of the file to write to.
1873
- **kwargs : dict
1874
- Additional keyword arguments to pass to `json.dump`.
1875
- """
1876
- path_to_json = self.path_manager.get_file_path('stimuli', file_name, extension='json')
1877
-
1878
- data = self.stimuli_to_dict()
1879
-
1880
- if not kwargs.get('indent'):
1881
- kwargs['indent'] = 4
1882
- with open(path_to_json, 'w') as f:
1883
- json.dump(data, f, **kwargs)
1884
-
1885
- path_to_stimuli_csv = self.path_manager.get_file_path('stimuli', file_name, extension='csv')
1886
- self._stimuli_to_csv(path_to_stimuli_csv)
1887
-
1888
-
1889
- def load_stimuli(self, file_name):
1890
- """
1891
- Load the stimuli from a JSON file.
1892
-
1893
- Parameters
1894
- ----------
1895
- file_name : str
1896
- The name of the file to read from.
1897
- """
1898
-
1899
- path_to_json = self.path_manager.get_file_path('stimuli', file_name, extension='json')
1900
- path_to_stimuli_csv = self.path_manager.get_file_path('stimuli', file_name, extension='csv')
1901
-
1902
- with open(path_to_json, 'r') as f:
1903
- data = json.load(f)
1904
-
1905
- if not self.name == data['metadata']['name']:
1906
- raise ValueError('Model name does not match the data.')
1907
-
1908
- df_stimuli = pd.read_csv(path_to_stimuli_csv)
1909
-
1910
- self.simulator.from_dict(data['simulation'])
1911
-
1912
- # Clear all stimuli and recordings
1913
- self.remove_all_stimuli()
1914
- self.remove_all_recordings()
1915
-
1916
- # IClamps -----------------------------------------------------------
1917
-
1918
- df_iclamps = df_stimuli[df_stimuli['type'] == 'iclamp'].reset_index(drop=True, inplace=False)
1919
-
1920
- for row in df_iclamps.itertuples(index=False):
1921
- self.add_iclamp(
1922
- self.sec_tree.sections[row.sec_idx],
1923
- row.loc,
1924
- data['stimuli']['iclamps'][row.idx]['amp'],
1925
- data['stimuli']['iclamps'][row.idx]['delay'],
1926
- data['stimuli']['iclamps'][row.idx]['dur']
1927
- )
1928
-
1929
- # Populations -------------------------------------------------------
1930
-
1931
- syn_types = ['AMPA', 'NMDA', 'AMPA_NMDA', 'GABAa']
1932
-
1933
- for syn_type in syn_types:
1934
-
1935
- df_syn = df_stimuli[df_stimuli['type'] == syn_type]
1936
-
1937
- for i, pop_data in enumerate(data['stimuli']['populations'][syn_type]):
1938
-
1939
- df_pop = df_syn[df_syn['idx'] == i]
1940
-
1941
- segments = [self.sec_tree.sections[sec_idx](loc)
1942
- for sec_idx, loc in zip(df_pop['sec_idx'], df_pop['loc'])]
1943
-
1944
- pop = Population(idx=i,
1945
- segments=segments,
1946
- N=pop_data['N'],
1947
- syn_type=syn_type)
1948
-
1949
- syn_locs = [(self.sec_tree.sections[sec_idx], loc) for sec_idx, loc in zip(df_pop['sec_idx'].tolist(), df_pop['loc'].tolist())]
1950
-
1951
- pop.allocate_synapses(syn_locs=syn_locs)
1952
- pop.update_kinetic_params(**pop_data['kinetic_params'])
1953
- pop.update_input_params(**pop_data['input_params'])
1954
- self._add_population(pop)
1955
-
1956
- # Recordings ---------------------------------------------------------
1957
-
1958
- df_recs = df_stimuli[df_stimuli['type'] == 'rec'].reset_index(drop=True, inplace=False)
1959
- for row in df_recs.itertuples(index=False):
1960
- var = data['stimuli']['recordings'][row.idx]['var']
1961
- self.add_recording(
1962
- self.sec_tree.sections[row.sec_idx], row.loc, var
1963
- )
1964
-
1965
-
1966
- def export_to_NEURON(self, file_name, include_kinetic_params=True):
1967
- """
1968
- Export the model to a python file using NEURON.
1969
-
1970
- Parameters
1971
- ----------
1972
- file_name : str
1973
- The name of the file to write to.
1974
- """
1975
- from dendrotweaks.model_io import render_template
1976
- from dendrotweaks.model_io import get_params_to_valid_domains
1977
- from dendrotweaks.model_io import filter_params
1978
- from dendrotweaks.model_io import get_neuron_domain
1979
-
1980
- params_to_valid_domains = get_params_to_valid_domains(self)
1981
- params = self.params if include_kinetic_params else filter_params(self)
1982
- path_to_template = self.path_manager.get_file_path('templates', 'NEURON_template', extension='py')
1983
-
1984
- output = render_template(path_to_template,
1985
- {
1986
- 'param_dict': params,
1987
- 'groups_dict': self.groups,
1988
- 'params_to_mechs': self.params_to_mechs,
1989
- 'domains_to_mechs': self.domains_to_mechs,
1990
- 'iclamps': self.iclamps,
1991
- 'recordings': self.simulator.recordings,
1992
- 'params_to_valid_domains': params_to_valid_domains,
1993
- 'domains_to_NEURON': {domain: get_neuron_domain(domain) for domain in self.domains},
1994
- })
1995
-
1996
- if not file_name.endswith('.py'):
1997
- file_name += '.py'
1998
- path_to_model = self.path_manager.path_to_model
1999
- output_path = os.path.join(path_to_model, file_name)
2000
- with open(output_path, 'w') as f:
2001
- f.write(output)
2002
-
2003
-
1058
+ 'params_to_fits': params_to_fits
1059
+ }