dendrotweaks 0.4.4__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dendrotweaks/model.py CHANGED
@@ -1,63 +1,46 @@
1
+ # Imports
1
2
  from typing import List, Union, Callable
2
3
  import os
3
- import json
4
+ from collections import defaultdict
5
+
4
6
  import matplotlib.pyplot as plt
5
7
  import numpy as np
8
+ from numpy import nan
9
+ import pandas as pd
6
10
  import quantities as pq
7
11
 
8
- from dendrotweaks.morphology.point_trees import PointTree
9
- from dendrotweaks.morphology.sec_trees import NeuronSection, Section, SectionTree, Domain
10
- from dendrotweaks.morphology.seg_trees import NeuronSegment, Segment, SegmentTree
12
+ # DendroTweaks imports
11
13
  from dendrotweaks.simulators import NeuronSimulator
12
- from dendrotweaks.biophys.groups import SegmentGroup
13
- from dendrotweaks.biophys.mechanisms import Mechanism, LeakChannel, CaDynamics
14
- from dendrotweaks.biophys.io import create_channel, standardize_channel, create_standard_channel
15
14
  from dendrotweaks.biophys.io import MODFileLoader
16
- from dendrotweaks.morphology.io import create_point_tree, create_section_tree, create_segment_tree
17
- from dendrotweaks.stimuli.iclamps import IClamp
15
+ from dendrotweaks.morphology import Domain
16
+ from dendrotweaks.biophys.groups import SegmentGroup
18
17
  from dendrotweaks.biophys.distributions import Distribution
19
- from dendrotweaks.stimuli.populations import Population
20
- from dendrotweaks.utils import calculate_lambda_f, dynamic_import
21
- from dendrotweaks.utils import get_domain_color, timeit
22
-
23
- from collections import OrderedDict, defaultdict
24
- from numpy import nan
25
- # from .logger import logger
26
-
27
18
  from dendrotweaks.path_manager import PathManager
28
19
  import dendrotweaks.morphology.reduce as rdc
20
+ from dendrotweaks.utils import INDEPENDENT_PARAMS, DOMAIN_TO_GROUP, POPULATIONS
21
+ from dendrotweaks.utils import DEFAULT_FIT_MODELS
29
22
 
30
- import pandas as pd
23
+ # Mixins
24
+ from dendrotweaks.model_io import IOMixin
25
+ from dendrotweaks.model_simulation import SimulationMixin
31
26
 
27
+ # Warnings configuration
32
28
  import warnings
33
29
 
34
- POPULATIONS = {'AMPA': {}, 'NMDA': {}, 'AMPA_NMDA': {}, 'GABAa': {}}
35
-
36
30
  def custom_warning_formatter(message, category, filename, lineno, file=None, line=None):
37
31
  return f"WARNING: {message}\n({os.path.basename(filename)}, line {lineno})\n"
38
32
 
39
33
  warnings.formatwarning = custom_warning_formatter
40
34
 
41
- INDEPENDENT_PARAMS = {
42
- 'cm': 1, # uF/cm2
43
- 'Ra': 100, # Ohm cm
44
- 'ena': 50, # mV
45
- 'ek': -77, # mV
46
- 'eca': 140 # mV
47
- }
48
35
 
49
- DOMAIN_TO_GROUP = {
50
- 'soma': 'somatic',
51
- 'axon': 'axonal',
52
- 'dend': 'dendritic',
53
- 'apic': 'apical',
54
- }
55
36
 
56
-
57
- class Model():
37
+ class Model(IOMixin, SimulationMixin):
58
38
  """
59
39
  A model object that represents a neuron model.
60
40
 
41
+ The class incorporates various mixins to separate concerns while
42
+ maintaining a flat interface.
43
+
61
44
  Parameters
62
45
  ----------
63
46
  name : str
@@ -168,6 +151,7 @@ class Model():
168
151
  """
169
152
  return self._name
170
153
 
154
+
171
155
  @property
172
156
  def verbose(self):
173
157
  """
@@ -175,6 +159,7 @@ class Model():
175
159
  """
176
160
  return self._verbose
177
161
 
162
+
178
163
  @verbose.setter
179
164
  def verbose(self, value):
180
165
  self._verbose = value
@@ -191,18 +176,17 @@ class Model():
191
176
 
192
177
 
193
178
  @property
194
- def recordings(self):
179
+ def mechs_to_domains(self):
195
180
  """
196
- The recordings of the model. Reference to the recordings in the simulator.
181
+ The dictionary mapping mechanisms to domains where they are inserted.
197
182
  """
198
- return self.simulator.recordings
199
-
200
-
201
- @recordings.setter
202
- def recordings(self, recordings):
203
- self.simulator.recordings = recordings
204
-
183
+ mechs_to_domains = defaultdict(set)
184
+ for domain_name, mech_names in self.domains_to_mechs.items():
185
+ for mech_name in mech_names:
186
+ mechs_to_domains[mech_name].add(domain_name)
187
+ return dict(mechs_to_domains)
205
188
 
189
+
206
190
  @property
207
191
  def groups(self):
208
192
  """
@@ -225,17 +209,6 @@ class Model():
225
209
  groups_to_parameters[group.name] = params
226
210
  return groups_to_parameters
227
211
 
228
- @property
229
- def mechs_to_domains(self):
230
- """
231
- The dictionary mapping mechanisms to domains where they are inserted.
232
- """
233
- mechs_to_domains = defaultdict(set)
234
- for domain_name, mech_names in self.domains_to_mechs.items():
235
- for mech_name in mech_names:
236
- mechs_to_domains[mech_name].add(domain_name)
237
- return dict(mechs_to_domains)
238
-
239
212
 
240
213
  @property
241
214
  def parameters_to_groups(self):
@@ -291,25 +264,6 @@ class Model():
291
264
  """
292
265
  return {param: value for param, value in self.params.items()
293
266
  if param.startswith('gbar')}
294
- # -----------------------------------------------------------------------
295
- # METADATA
296
- # -----------------------------------------------------------------------
297
-
298
- def info(self):
299
- """
300
- Print information about the model.
301
- """
302
- info_str = (
303
- f"Model: {self.name}\n"
304
- f"Path to data: {self.path_manager.path_to_data}\n"
305
- f"Simulator: {self.simulator_name}\n"
306
- f"Groups: {len(self.groups)}\n"
307
- f"Avaliable mechanisms: {len(self.mechanisms)}\n"
308
- f"Inserted mechanisms: {len(self.mechs_to_params) - 1}\n"
309
- # f"Parameters: {len(self.parameters)}\n"
310
- f"IClamps: {len(self.iclamps)}\n"
311
- )
312
- print(info_str)
313
267
 
314
268
 
315
269
  @property
@@ -331,365 +285,6 @@ class Model():
331
285
  df = pd.DataFrame(data)
332
286
  return df
333
287
 
334
- def print_directory_tree(self, *args, **kwargs):
335
- """
336
- Print the directory tree.
337
- """
338
- return self.path_manager.print_directory_tree(*args, **kwargs)
339
-
340
- def list_morphologies(self, extension='swc'):
341
- """
342
- List the morphologies available for the model.
343
- """
344
- return self.path_manager.list_files('morphology', extension=extension)
345
-
346
- def list_biophys(self, extension='json'):
347
- """
348
- List the biophysical configurations available for the model.
349
- """
350
- return self.path_manager.list_files('biophys', extension=extension)
351
-
352
- def list_mechanisms(self, extension='mod'):
353
- """
354
- List the mechanisms available for the model.
355
- """
356
- return self.path_manager.list_files('mod', extension=extension)
357
-
358
- def list_stimuli(self, extension='json'):
359
- """
360
- List the stimuli configurations available for the model.
361
- """
362
- return self.path_manager.list_files('stimuli', extension=extension)
363
-
364
- # ========================================================================
365
- # MORPHOLOGY
366
- # ========================================================================
367
-
368
- def load_morphology(self, file_name, soma_notation='3PS',
369
- align=True, sort_children=True, force=False) -> None:
370
- """
371
- Read an SWC file and build the SWC and section trees.
372
-
373
- Parameters
374
- ----------
375
- file_name : str
376
- The name of the SWC file to read.
377
- soma_notation : str, optional
378
- The notation of the soma in the SWC file. Can be '3PS' (three-point soma) or '1PS'. Default is '3PS'.
379
- align : bool, optional
380
- Whether to align the morphology to the soma center and align the apical dendrite (if present).
381
- sort_children : bool, optional
382
- Whether to sort the children of each node by increasing subtree size
383
- in the tree sorting algorithms. If True, the traversal visits
384
- children with shorter subtrees first and assigns them lower indices. If False, children
385
- are visited in their original SWC file order (matching NEURON's behavior).
386
- """
387
- # self.name = file_name.split('.')[0]
388
- self.morphology_name = file_name.replace('.swc', '')
389
- path_to_swc_file = self.path_manager.get_file_path('morphology', file_name, extension='swc')
390
- point_tree = create_point_tree(path_to_swc_file)
391
- # point_tree.remove_overlaps()
392
- point_tree.change_soma_notation(soma_notation)
393
- point_tree.sort(sort_children=sort_children, force=force)
394
- if align:
395
- point_tree.shift_coordinates_to_soma_center()
396
- point_tree.align_apical_dendrite()
397
- point_tree.round_coordinates(8)
398
- self.point_tree = point_tree
399
-
400
- sec_tree = create_section_tree(point_tree)
401
- sec_tree.sort(sort_children=sort_children, force=force)
402
- self.sec_tree = sec_tree
403
-
404
- self.create_and_reference_sections_in_simulator()
405
- seg_tree = create_segment_tree(sec_tree)
406
- self.seg_tree = seg_tree
407
-
408
- self._add_default_segment_groups()
409
- self._initialize_domains_to_mechs()
410
-
411
- d_lambda = self.d_lambda
412
- self.set_segmentation(d_lambda=d_lambda)
413
-
414
-
415
- def create_and_reference_sections_in_simulator(self):
416
- """
417
- Create and reference sections in the simulator.
418
- """
419
- if self.verbose: print(f'Building sections in {self.simulator_name}...')
420
- for sec in self.sec_tree.sections:
421
- sec.create_and_reference()
422
- n_sec = len([sec._ref for sec in self.sec_tree.sections
423
- if sec._ref is not None])
424
- if self.verbose: print(f'{n_sec} sections created.')
425
-
426
-
427
-
428
-
429
- def _add_default_segment_groups(self):
430
- self.add_group('all', list(self.domains.keys()))
431
- for domain_name in self.domains:
432
- group_name = DOMAIN_TO_GROUP.get(domain_name, domain_name)
433
- self.add_group(group_name, [domain_name])
434
-
435
-
436
- def _initialize_domains_to_mechs(self):
437
- for domain_name in self.domains:
438
- # Only if haven't been defined for the previous morphology
439
- # TODO: Check that domains match
440
- if not domain_name in self.domains_to_mechs:
441
- self.domains_to_mechs[domain_name] = set()
442
- for domain_name, mech_names in self.domains_to_mechs.items():
443
- for mech_name in mech_names:
444
- mech = self.mechanisms[mech_name]
445
- self.insert_mechanism(mech, domain_name)
446
-
447
-
448
- def get_sections(self, filter_function):
449
- """Filter sections using a lambda function.
450
-
451
- Parameters
452
- ----------
453
- filter_function : Callable
454
- The lambda function to filter sections.
455
- """
456
- return [sec for sec in self.sec_tree.sections if filter_function(sec)]
457
-
458
-
459
- def get_segments(self, group_names=None):
460
- """
461
- Get the segments in specified groups.
462
-
463
- Parameters
464
- ----------
465
- group_names : List[str]
466
- The names of the groups to get segments from.
467
- """
468
- if not isinstance(group_names, list):
469
- raise ValueError('Group names must be a list.')
470
- return [seg for group_name in group_names for seg in self.seg_tree.segments if seg in self.groups[group_name]]
471
-
472
- # ========================================================================
473
- # SEGMENTATION
474
- # ========================================================================
475
-
476
- # TODO Make a context manager for this
477
- def _temp_clear_stimuli(self):
478
- """
479
- Temporarily save and clear stimuli.
480
- """
481
- self.export_stimuli(file_name='_temp_stimuli')
482
- self.remove_all_stimuli()
483
- self.remove_all_recordings()
484
-
485
- def _temp_reload_stimuli(self):
486
- """
487
- Load stimuli from a temporary file and clean up.
488
- """
489
- self.load_stimuli(file_name='_temp_stimuli')
490
- for ext in ['json', 'csv']:
491
- temp_path = self.path_manager.get_file_path('stimuli', '_temp_stimuli', extension=ext)
492
- if os.path.exists(temp_path):
493
- os.remove(temp_path)
494
-
495
- def set_segmentation(self, d_lambda=0.1, f=100):
496
- """
497
- Set the number of segments in each section based on the geometry.
498
-
499
- Parameters
500
- ----------
501
- d_lambda : float
502
- The lambda value to use.
503
- f : float
504
- The frequency value to use.
505
- """
506
- self.d_lambda = d_lambda
507
-
508
- # Temporarily save and clear stimuli
509
- self._temp_clear_stimuli()
510
-
511
- # Pre-distribute parameters needed for lambda_f calculation
512
- for param_name in ['cm', 'Ra']:
513
- self.distribute(param_name)
514
-
515
- # Calculate lambda_f and set nseg for each section
516
- for sec in self.sec_tree.sections:
517
- lambda_f = calculate_lambda_f(sec.distances, sec.diameters, sec.Ra, sec.cm, f)
518
- nseg = max(1, int((sec.L / (d_lambda * lambda_f) + 0.9) / 2) * 2 + 1)
519
- sec._nseg = sec._ref.nseg = nseg
520
-
521
- # Rebuild the segment tree and redistribute parameters
522
- self.seg_tree = create_segment_tree(self.sec_tree)
523
- self.distribute_all()
524
-
525
- # Reload stimuli and clean up temporary files
526
- self._temp_reload_stimuli()
527
-
528
-
529
- # ========================================================================
530
- # MECHANISMS
531
- # ========================================================================
532
-
533
- def add_default_mechanisms(self, recompile=False):
534
- """
535
- Add default mechanisms to the model.
536
-
537
- Parameters
538
- ----------
539
- recompile : bool, optional
540
- Whether to recompile the mechanisms.
541
- """
542
- leak = LeakChannel()
543
- self.mechanisms[leak.name] = leak
544
-
545
- cadyn = CaDynamics()
546
- self.mechanisms[cadyn.name] = cadyn
547
-
548
- self.load_mechanisms('default_mod', recompile=recompile)
549
-
550
-
551
- def add_mechanisms(self, dir_name:str = 'mod', recompile=True) -> None:
552
- """
553
- Add a set of mechanisms from an archive to the model.
554
-
555
- Parameters
556
- ----------
557
- dir_name : str, optional
558
- The name of the archive to load mechanisms from. Default is 'mod'.
559
- recompile : bool, optional
560
- Whether to recompile the mechanisms.
561
- """
562
- # Create Mechanism objects and add them to the model
563
- for mechanism_name in self.path_manager.list_files(dir_name, extension='mod'):
564
- self.add_mechanism(mechanism_name,
565
- load=True,
566
- dir_name=dir_name,
567
- recompile=recompile)
568
-
569
-
570
-
571
- def add_mechanism(self, mechanism_name: str,
572
- python_template_name: str = 'default',
573
- load=True, dir_name: str = 'mod', recompile=True
574
- ) -> None:
575
- """
576
- Create a Mechanism object from the MOD file (or LeakChannel).
577
-
578
- Parameters
579
- ----------
580
- mechanism_name : str
581
- The name of the mechanism to add.
582
- python_template_name : str, optional
583
- The name of the Python template to use. Default is 'default'.
584
- load : bool, optional
585
- Whether to load the mechanism using neuron.load_mechanisms.
586
- """
587
- paths = self.path_manager.get_channel_paths(
588
- mechanism_name,
589
- python_template_name=python_template_name
590
- )
591
- mech = create_channel(**paths)
592
- # Add the mechanism to the model
593
- self.mechanisms[mech.name] = mech
594
- # Update the global parameters
595
-
596
- if load:
597
- self.load_mechanism(mechanism_name, dir_name, recompile)
598
-
599
-
600
-
601
- def load_mechanisms(self, dir_name: str = 'mod', recompile=True) -> None:
602
- """
603
- Load mechanisms from an archive.
604
-
605
- Parameters
606
- ----------
607
- dir_name : str, optional
608
- The name of the archive to load mechanisms from.
609
- recompile : bool, optional
610
- Whether to recompile the mechanisms.
611
- """
612
- mod_files = self.path_manager.list_files(dir_name, extension='mod')
613
- for mechanism_name in mod_files:
614
- self.load_mechanism(mechanism_name, dir_name, recompile)
615
-
616
-
617
- def load_mechanism(self, mechanism_name, dir_name='mod', recompile=False) -> None:
618
- """
619
- Load a mechanism from the specified archive.
620
-
621
- Parameters
622
- ----------
623
- mechanism_name : str
624
- The name of the mechanism to load.
625
- dir_name : str, optional
626
- The name of the directory to load the mechanism from. Default is 'mod'.
627
- recompile : bool, optional
628
- Whether to recompile the mechanism.
629
- """
630
- path_to_mod_file = self.path_manager.get_file_path(
631
- dir_name, mechanism_name, extension='mod'
632
- )
633
- self.mod_loader.load_mechanism(
634
- path_to_mod_file=path_to_mod_file, recompile=recompile
635
- )
636
-
637
-
638
- def standardize_channel(self, channel_name,
639
- python_template_name=None, mod_template_name=None, remove_old=True):
640
- """
641
- Standardize a channel by creating a new channel with the same kinetic
642
- properties using the standard equations.
643
-
644
- Parameters
645
- ----------
646
- channel_name : str
647
- The name of the channel to standardize.
648
- python_template_name : str, optional
649
- The name of the Python template to use.
650
- mod_template_name : str, optional
651
- The name of the MOD template to use.
652
- remove_old : bool, optional
653
- Whether to remove the old channel from the model. Default is True.
654
- """
655
-
656
- # Get data to transfer
657
- channel = self.mechanisms[channel_name]
658
- channel_domain_names = [domain_name for domain_name, mech_names
659
- in self.domains_to_mechs.items() if channel_name in mech_names]
660
- gbar_name = f'gbar_{channel_name}'
661
- gbar_distributions = self.params[gbar_name]
662
- # Kinetic variables cannot be transferred
663
-
664
- # Uninsert the old channel
665
- for domain_name in self.domains:
666
- if channel_name in self.domains_to_mechs[domain_name]:
667
- self.uninsert_mechanism(channel_name, domain_name)
668
-
669
- # Remove the old channel
670
- if remove_old:
671
- self.mechanisms.pop(channel_name)
672
-
673
- # Create, add and load a new channel
674
- paths = self.path_manager.get_standard_channel_paths(
675
- channel_name,
676
- mod_template_name=mod_template_name
677
- )
678
- standard_channel = standardize_channel(channel, **paths)
679
-
680
- self.mechanisms[standard_channel.name] = standard_channel
681
- self.load_mechanism(standard_channel.name, recompile=True)
682
-
683
- # Insert the new channel
684
- for domain_name in channel_domain_names:
685
- self.insert_mechanism(standard_channel.name, domain_name)
686
-
687
- # Transfer data
688
- gbar_name = f'gbar_{standard_channel.name}'
689
- for group_name, distribution in gbar_distributions.items():
690
- self.set_param(gbar_name, group_name,
691
- distribution.function_name, **distribution.parameters)
692
-
693
288
 
694
289
  # ========================================================================
695
290
  # DOMAINS
@@ -816,9 +411,9 @@ class Model():
816
411
  self.remove_group(group.name)
817
412
 
818
413
 
819
- # -----------------------------------------------------------------------
820
- # INSERT / UNINSERT MECHANISMS
821
- # -----------------------------------------------------------------------
414
+ # ========================================================================
415
+ # MECHANISMS
416
+ # ========================================================================
822
417
 
823
418
  def insert_mechanism(self, mechanism_name: str,
824
419
  domain_name: str, distribute=True):
@@ -922,11 +517,11 @@ class Model():
922
517
 
923
518
 
924
519
  # ========================================================================
925
- # SET PARAMETERS
520
+ # PARAMETERS
926
521
  # ========================================================================
927
522
 
928
523
  # -----------------------------------------------------------------------
929
- # GROUPS
524
+ # SEGMENT GROUPS (Where)
930
525
  # -----------------------------------------------------------------------
931
526
 
932
527
  def add_group(self, name, domains, select_by=None, min_value=None, max_value=None):
@@ -1001,7 +596,7 @@ class Model():
1001
596
 
1002
597
 
1003
598
  # -----------------------------------------------------------------------
1004
- # DISTRIBUTIONS
599
+ # DISTRIBUTIONS (How)
1005
600
  # -----------------------------------------------------------------------
1006
601
 
1007
602
  def set_param(self, param_name: str,
@@ -1063,6 +658,7 @@ class Model():
1063
658
  distribution = Distribution(distr_type, **distr_params)
1064
659
  self.params[param_name][group_name] = distribution
1065
660
 
661
+
1066
662
  def distribute_all(self):
1067
663
  """
1068
664
  Distribute all parameters to the segments.
@@ -1146,230 +742,152 @@ class Model():
1146
742
  self.params[param_name].pop(group_name, None)
1147
743
  self.distribute(param_name)
1148
744
 
1149
- # def set_section_param(self, param_name, value, domains=None):
1150
-
1151
- # domains = domains or self.domains
1152
- # for sec in self.sec_tree.sections:
1153
- # if sec.domain in domains:
1154
- # setattr(sec._ref, param_name, value)
1155
-
1156
- # ========================================================================
1157
- # STIMULI
1158
- # ========================================================================
1159
745
 
1160
746
  # -----------------------------------------------------------------------
1161
- # ICLAMPS
747
+ # FITTING
1162
748
  # -----------------------------------------------------------------------
1163
749
 
1164
- def add_iclamp(self, sec, loc, amp=0, delay=100, dur=100):
1165
- """
1166
- Add an IClamp to a section.
750
+ def fit_distribution(self, param_name: str, segments, candidate_models=None, plot=True):
751
+ if candidate_models is None:
752
+ candidate_models = DEFAULT_FIT_MODELS
1167
753
 
1168
- Parameters
1169
- ----------
1170
- sec : Section
1171
- The section to add the IClamp to.
1172
- loc : float
1173
- The location of the IClamp in the section.
1174
- amp : float, optional
1175
- The amplitude of the IClamp. Default is 0.
1176
- delay : float, optional
1177
- The delay of the IClamp. Default is 100.
1178
- dur : float, optional
1179
- The duration of the IClamp. Default is 100.
1180
- """
1181
- seg = sec(loc)
1182
- if self.iclamps.get(seg):
1183
- self.remove_iclamp(sec, loc)
1184
- iclamp = IClamp(sec, loc, amp, delay, dur)
1185
- print(f'IClamp added to sec {sec} at loc {loc}.')
1186
- self.iclamps[seg] = iclamp
754
+ from dendrotweaks.utils import mse
1187
755
 
756
+ values = [seg.get_param_value(param_name) for seg in segments]
757
+ if all(np.isnan(values)):
758
+ return None
1188
759
 
1189
- def remove_iclamp(self, sec, loc):
1190
- """
1191
- Remove an IClamp from a section.
760
+ distances = [seg.path_distance() for seg in segments]
761
+ distances, values = zip(*sorted(zip(distances, values)))
1192
762
 
1193
- Parameters
1194
- ----------
1195
- sec : Section
1196
- The section to remove the IClamp from.
1197
- loc : float
1198
- The location of the IClamp in the section.
1199
- """
1200
- seg = sec(loc)
1201
- if self.iclamps.get(seg):
1202
- self.iclamps.pop(seg)
763
+ best_score = float('inf')
764
+ best_model = None
765
+ best_params = None
766
+ best_pred = None
1203
767
 
768
+ results = []
1204
769
 
1205
- def remove_all_iclamps(self):
1206
- """
1207
- Remove all IClamps from the model.
1208
- """
770
+ for name, model in candidate_models.items():
771
+ try:
772
+ params, pred_values = model['fit'](distances, values)
773
+ score = model.get('score', mse)(values, pred_values)
774
+ complexity = model.get('complexity', 1)(params)
775
+ results.append((name, score, params, complexity, pred_values))
776
+ except Exception as e:
777
+ warnings.warn(f"Model {name} failed to fit: {e}")
1209
778
 
1210
- for seg in list(self.iclamps.keys()):
1211
- sec, loc = seg._section, seg.x
1212
- self.remove_iclamp(sec, loc)
1213
- if self.iclamps:
1214
- warnings.warn(f'Not all iclamps were removed: {self.iclamps}')
1215
- self.iclamps = {}
779
+ # Sort results by score and complexity
780
+ results.sort(key=lambda x: (np.round(x[1], 10), x[3]))
1216
781
 
782
+ best_model, best_score, best_params, _, best_pred = results[0]
1217
783
 
1218
- # -----------------------------------------------------------------------
1219
- # SYNAPSES
1220
- # -----------------------------------------------------------------------
784
+ if plot:
785
+ self.plot_param(param_name, show_nan=False)
786
+ plt.plot(distances, best_pred, label=f'Best Fit: {best_model}', color='red', linestyle='--')
787
+ plt.legend()
1221
788
 
1222
- def _add_population(self, population):
1223
- self.populations[population.syn_type][population.name] = population
789
+ return {'model': best_model, 'params': best_params, 'score': best_score}
1224
790
 
1225
791
 
1226
- def add_population(self, segments, N, syn_type):
1227
- """
1228
- Add a population of synapses to the model.
792
+ def _set_distribution(self, param_name, group_name, fit_result, plot=False):
793
+ if fit_result is None:
794
+ warnings.warn(f"No valid fit found for parameter {param_name}. Skipping distribution assignment.")
795
+ return
1229
796
 
1230
- Parameters
1231
- ----------
1232
- segments : list[Segment]
1233
- The segments to add the synapses to.
1234
- N : int
1235
- The number of synapses to add.
1236
- syn_type : str
1237
- The type of synapse to add.
1238
- """
1239
- idx = len(self.populations[syn_type])
1240
- population = Population(idx, segments, N, syn_type)
1241
- population.allocate_synapses()
1242
- population.create_inputs()
1243
- self._add_population(population)
797
+ model_type = fit_result['model']
798
+ params = fit_result['params']
1244
799
 
800
+ if model_type == 'poly':
801
+ coeffs = np.array(params)
802
+ coeffs = np.where(np.round(coeffs) == 0, coeffs, np.round(coeffs, 10))
803
+ if len(coeffs) == 1:
804
+ self.params[param_name][group_name] = Distribution('constant', value=coeffs[0])
805
+ elif len(coeffs) == 2:
806
+ self.params[param_name][group_name] = Distribution('linear', slope=coeffs[0], intercept=coeffs[1])
807
+ else:
808
+ self.params[param_name][group_name] = Distribution('polynomial', coeffs=coeffs.tolist())
1245
809
 
1246
- def update_population_kinetic_params(self, pop_name, **params):
1247
- """
1248
- Update the kinetic parameters of a population of synapses.
810
+ elif model_type == 'step':
811
+ start, end, min_value, max_value = params
812
+ self.params[param_name][group_name] = Distribution('step', max_value=max_value, min_value=min_value, start=start, end=end)
1249
813
 
1250
- Parameters
1251
- ----------
1252
- pop_name : str
1253
- The name of the population.
1254
- params : dict
1255
- The parameters to update.
1256
- """
1257
- syn_type, idx = pop_name.rsplit('_', 1)
1258
- population = self.populations[syn_type][pop_name]
1259
- population.update_kinetic_params(**params)
1260
- print(population.kinetic_params)
1261
814
 
1262
-
1263
- def update_population_input_params(self, pop_name, **params):
815
+ # -----------------------------------------------------------------------
816
+ # PLOTTING
817
+ # -----------------------------------------------------------------------
818
+
819
+ def plot_param(self, param_name, ax=None, show_nan=True):
1264
820
  """
1265
- Update the input parameters of a population of synapses.
821
+ Plot the distribution of a parameter in the model.
1266
822
 
1267
823
  Parameters
1268
824
  ----------
1269
- pop_name : str
1270
- The name of the population.
1271
- params : dict
1272
- The parameters to update.
825
+ param_name : str
826
+ The name of the parameter to plot.
827
+ ax : matplotlib.axes.Axes, optional
828
+ The axes to plot on. Default is None.
829
+ show_nan : bool, optional
830
+ Whether to show NaN values. Default is True.
1273
831
  """
1274
- syn_type, idx = pop_name.rsplit('_', 1)
1275
- population = self.populations[syn_type][pop_name]
1276
- population.update_input_params(**params)
1277
- print(population.input_params)
1278
-
832
+ from dendrotweaks.utils import get_domain_color
1279
833
 
1280
- def remove_population(self, name):
1281
- """
1282
- Remove a population of synapses from the model.
834
+ if ax is None:
835
+ fig, ax = plt.subplots(figsize=(10, 2))
1283
836
 
1284
- Parameters
1285
- ----------
1286
- name : str
1287
- The name of the population
1288
- """
1289
- syn_type, idx = name.rsplit('_', 1)
1290
- population = self.populations[syn_type].pop(name)
1291
- population.clean()
1292
-
1293
- def remove_all_populations(self):
1294
- """
1295
- Remove all populations of synapses from the model.
1296
- """
1297
- for syn_type in self.populations:
1298
- for name in list(self.populations[syn_type].keys()):
1299
- self.remove_population(name)
1300
- if any(self.populations.values()):
1301
- warnings.warn(f'Not all populations were removed: {self.populations}')
1302
- self.populations = POPULATIONS
837
+ if param_name not in self.params:
838
+ warnings.warn(f'Parameter {param_name} not found.')
1303
839
 
1304
- def remove_all_stimuli(self):
1305
- """
1306
- Remove all stimuli from the model.
1307
- """
1308
- self.remove_all_iclamps()
1309
- self.remove_all_populations()
840
+ values = [(seg.path_distance(), seg.get_param_value(param_name)) for seg in self.seg_tree]
841
+ colors = [get_domain_color(seg.domain) for seg in self.seg_tree]
1310
842
 
1311
- # ========================================================================
1312
- # SIMULATION
1313
- # ========================================================================
843
+ valid_values = [(x, y) for (x, y), color in zip(values, colors) if not pd.isna(y) and y != 0]
844
+ zero_values = [(x, y) for (x, y), color in zip(values, colors) if y == 0]
845
+ nan_values = [(x, 0) for (x, y), color in zip(values, colors) if pd.isna(y)]
846
+ valid_colors = [color for (x, y), color in zip(values, colors) if not pd.isna(y) and y != 0]
847
+ zero_colors = [color for (x, y), color in zip(values, colors) if y == 0]
848
+ nan_colors = [color for (x, y), color in zip(values, colors) if pd.isna(y)]
1314
849
 
1315
- def add_recording(self, sec, loc, var='v'):
1316
- """
1317
- Add a recording to the model.
850
+ if valid_values:
851
+ ax.scatter(*zip(*valid_values), c=valid_colors)
852
+ if zero_values:
853
+ ax.scatter(*zip(*zero_values), edgecolors=zero_colors, facecolors='none', marker='.')
854
+ if nan_values and show_nan:
855
+ ax.scatter(*zip(*nan_values), c=nan_colors, marker='x', alpha=0.5, zorder=0)
856
+ ax.axhline(y=0, color='k', linestyle='--')
1318
857
 
1319
- Parameters
1320
- ----------
1321
- sec : Section
1322
- The section to record from.
1323
- loc : float
1324
- The location along the normalized section length to record from.
1325
- var : str, optional
1326
- The variable to record. Default is 'v'.
1327
- """
1328
- self.simulator.add_recording(sec, loc, var)
1329
- print(f'Recording added to sec {sec} at loc {loc}.')
858
+ ax.set_xlabel('Path distance')
859
+ ax.set_ylabel(param_name)
860
+ ax.set_title(f'{param_name} distribution')
1330
861
 
1331
862
 
1332
- def remove_recording(self, sec, loc, var='v'):
1333
- """
1334
- Remove a recording from the model.
863
+ # ========================================================================
864
+ # MORPHOLOGY
865
+ # ========================================================================
1335
866
 
867
+ def get_sections(self, filter_function):
868
+ """Filter sections using a lambda function.
869
+
1336
870
  Parameters
1337
871
  ----------
1338
- sec : Section
1339
- The section to remove the recording from.
1340
- loc : float
1341
- The location along the normalized section length to remove the recording from.
1342
- """
1343
- self.simulator.remove_recording(sec, loc, var)
1344
-
1345
-
1346
- def remove_all_recordings(self, var=None):
1347
- """
1348
- Remove all recordings from the model.
872
+ filter_function : Callable
873
+ The lambda function to filter sections.
1349
874
  """
1350
- self.simulator.remove_all_recordings(var=var)
875
+ return [sec for sec in self.sec_tree.sections if filter_function(sec)]
1351
876
 
1352
877
 
1353
- def run(self, duration=300):
878
+ def get_segments(self, group_names=None):
1354
879
  """
1355
- Run the simulation for a specified duration.
880
+ Get the segments in specified groups.
1356
881
 
1357
882
  Parameters
1358
883
  ----------
1359
- duration : float, optional
1360
- The duration of the simulation. Default is 300.
884
+ group_names : List[str]
885
+ The names of the groups to get segments from.
1361
886
  """
1362
- self.simulator.run(duration)
1363
-
1364
- def get_traces(self):
1365
- return self.simulator.get_traces()
1366
-
1367
- def plot(self, *args, **kwargs):
1368
- self.simulator.plot(*args, **kwargs)
887
+ if not isinstance(group_names, list):
888
+ raise ValueError('Group names must be a list.')
889
+ return [seg for group_name in group_names for seg in self.seg_tree.segments if seg in self.groups[group_name]]
1369
890
 
1370
- # ========================================================================
1371
- # MORPHOLOGY
1372
- # ========================================================================
1373
891
 
1374
892
  def remove_subtree(self, sec):
1375
893
  """
@@ -1421,12 +939,15 @@ class Model():
1421
939
  domains_to_mechs = {domain_name: mech_names for domain_name, mech_names
1422
940
  in self.domains_to_mechs.items() if domain_name in [domain.name for domain in domains_in_subtree]}
1423
941
  common_mechs = set.intersection(*domains_to_mechs.values())
1424
- if not common_mechs:
942
+ if not all(mech_names == common_mechs
943
+ for mech_names in domains_to_mechs.values()):
1425
944
  raise ValueError(
1426
945
  'The domains in the subtree have different mechanisms. '
1427
946
  'Please ensure that all domains in the subtree have the same mechanisms. '
1428
947
  'You may need to insert the missing mechanisms and set their conductances to 0 where they are not needed.'
1429
948
  )
949
+ elif len(domains_in_subtree) == 1:
950
+ common_mechs = self.domains_to_mechs[domain_name].copy()
1430
951
 
1431
952
  inserted_mechs = {mech_name: mech for mech_name, mech
1432
953
  in self.mechanisms.items()
@@ -1495,10 +1016,13 @@ class Model():
1495
1016
  return
1496
1017
 
1497
1018
  root_segs = [seg for seg in root.segments]
1498
- params_to_coeffs = {}
1499
- for param_name in self.params:
1500
- coeffs = self.fit_distribution(param_name, segments=root_segs, plot=False)
1501
- params_to_coeffs[param_name] = coeffs
1019
+ params_to_fits = {}
1020
+ # for param_name in self.params:
1021
+ common_mechs.add('Independent')
1022
+ for mech in common_mechs:
1023
+ for param_name in self.mechs_to_params[mech]:
1024
+ fit_result = self.fit_distribution(param_name, segments=root_segs, plot=False)
1025
+ params_to_fits[param_name] = fit_result
1502
1026
 
1503
1027
 
1504
1028
  # Create new domain
@@ -1518,9 +1042,9 @@ class Model():
1518
1042
 
1519
1043
 
1520
1044
  # # Fit distributions to data for the group
1521
- for param_name, coeffs in params_to_coeffs.items():
1522
- self._set_distribution(param_name, group_name, coeffs, plot=True)
1523
-
1045
+ for param_name, fit_result in params_to_fits.items():
1046
+ self._set_distribution(param_name, group_name, fit_result, plot=True)
1047
+
1524
1048
  # # Distribute parameters
1525
1049
  self.distribute_all()
1526
1050
 
@@ -1531,466 +1055,5 @@ class Model():
1531
1055
  'segs_to_locs': segs_to_locs,
1532
1056
  'segs_to_reduced_segs': segs_to_reduced_segs,
1533
1057
  'reduced_segs_to_params': reduced_segs_to_params,
1534
- 'params_to_coeffs': params_to_coeffs
1535
- }
1536
-
1537
-
1538
- def fit_distribution(self, param_name, segments, max_degree=6, tolerance=1e-7, plot=False):
1539
- from numpy import polyfit, polyval
1540
- values = [seg.get_param_value(param_name) for seg in segments]
1541
- distances = [seg.path_distance() for seg in segments]
1542
- sorted_pairs = sorted(zip(distances, values))
1543
- distances, values = zip(*sorted_pairs)
1544
- degrees = range(0, max_degree+1)
1545
- for degree in degrees:
1546
- coeffs = polyfit(distances, values, degree)
1547
- residuals = values - polyval(coeffs, distances)
1548
- if all(abs(residuals) < tolerance):
1549
- break
1550
- if not all(abs(residuals) < tolerance):
1551
- warnings.warn(f'Fitting failed for parameter {param_name} with the provided tolerance.\nUsing the last valid fit (degree={degree}). Maximum residual: {max(abs(residuals))}')
1552
- if plot and degree > 0:
1553
- self.plot_param(param_name, show_nan=False)
1554
- plt.plot(distances, polyval(coeffs, distances), label='Fitted', color='red', linestyle='--')
1555
- plt.legend()
1556
- return coeffs
1557
-
1558
-
1559
- def _set_distribution(self, param_name, group_name, coeffs, plot=False):
1560
- # Set the distribution based on the degree of the polynomial fit
1561
- coeffs = np.where(np.round(coeffs) == 0, coeffs, np.round(coeffs, 10))
1562
- if len(coeffs) == 1:
1563
- self.params[param_name][group_name] = Distribution('constant', value=coeffs[0])
1564
- elif len(coeffs) == 2:
1565
- self.params[param_name][group_name] = Distribution('linear', slope=coeffs[0], intercept=coeffs[1])
1566
- else:
1567
- self.params[param_name][group_name] = Distribution('polynomial', coeffs=coeffs)
1568
-
1569
-
1570
- # ========================================================================
1571
- # PLOTTING
1572
- # ========================================================================
1573
-
1574
- def plot_param(self, param_name, ax=None, show_nan=True):
1575
- """
1576
- Plot the distribution of a parameter in the model.
1577
-
1578
- Parameters
1579
- ----------
1580
- param_name : str
1581
- The name of the parameter to plot.
1582
- ax : matplotlib.axes.Axes, optional
1583
- The axes to plot on. Default is None.
1584
- show_nan : bool, optional
1585
- Whether to show NaN values. Default is True.
1586
- """
1587
- if ax is None:
1588
- fig, ax = plt.subplots(figsize=(10, 2))
1589
-
1590
- if param_name not in self.params:
1591
- warnings.warn(f'Parameter {param_name} not found.')
1592
-
1593
- values = [(seg.path_distance(), seg.get_param_value(param_name)) for seg in self.seg_tree]
1594
- colors = [get_domain_color(seg.domain) for seg in self.seg_tree]
1595
-
1596
- valid_values = [(x, y) for (x, y), color in zip(values, colors) if not pd.isna(y) and y != 0]
1597
- zero_values = [(x, y) for (x, y), color in zip(values, colors) if y == 0]
1598
- nan_values = [(x, 0) for (x, y), color in zip(values, colors) if pd.isna(y)]
1599
- valid_colors = [color for (x, y), color in zip(values, colors) if not pd.isna(y) and y != 0]
1600
- zero_colors = [color for (x, y), color in zip(values, colors) if y == 0]
1601
- nan_colors = [color for (x, y), color in zip(values, colors) if pd.isna(y)]
1602
-
1603
- if valid_values:
1604
- ax.scatter(*zip(*valid_values), c=valid_colors)
1605
- if zero_values:
1606
- ax.scatter(*zip(*zero_values), edgecolors=zero_colors, facecolors='none', marker='.')
1607
- if nan_values and show_nan:
1608
- ax.scatter(*zip(*nan_values), c=nan_colors, marker='x', alpha=0.5, zorder=0)
1609
- plt.axhline(y=0, color='k', linestyle='--')
1610
-
1611
- ax.set_xlabel('Path distance')
1612
- ax.set_ylabel(param_name)
1613
- ax.set_title(f'{param_name} distribution')
1614
-
1615
-
1616
- # ========================================================================
1617
- # FILE EXPORT
1618
- # ========================================================================
1619
-
1620
- def export_morphology(self, file_name):
1621
- """
1622
- Write the SWC tree to an SWC file.
1623
-
1624
- Parameters
1625
- ----------
1626
- version : str, optional
1627
- The version of the morphology appended to the morphology name.
1628
- """
1629
- path_to_file = self.path_manager.get_file_path('morphology', file_name, extension='swc')
1630
-
1631
- self.point_tree.to_swc(path_to_file)
1632
-
1633
-
1634
- def to_dict(self):
1635
- """
1636
- Return a dictionary representation of the model.
1637
-
1638
- Returns
1639
- -------
1640
- dict
1641
- The dictionary representation of the model.
1642
- """
1643
- return {
1644
- 'metadata': {
1645
- 'name': self.name,
1646
- },
1647
- 'd_lambda': self.d_lambda,
1648
- 'domains': {domain: sorted(list(mechs)) for domain, mechs in self.domains_to_mechs.items()},
1649
- 'groups': [
1650
- group.to_dict() for group in self._groups
1651
- ],
1652
- 'params': {
1653
- param_name: {
1654
- group_name: distribution if isinstance(distribution, str) else distribution.to_dict()
1655
- for group_name, distribution in distributions.items()
1656
- }
1657
- for param_name, distributions in self.params.items()
1658
- },
1659
- }
1660
-
1661
- def from_dict(self, data):
1662
- """
1663
- Load the model from a dictionary.
1664
-
1665
- Parameters
1666
- ----------
1667
- data : dict
1668
- The dictionary representation of the model.
1669
- """
1670
- if not self.name == data['metadata']['name']:
1671
- raise ValueError('Model name does not match the data.')
1672
-
1673
- self.d_lambda = data['d_lambda']
1674
-
1675
- # Domains and mechanisms
1676
- self.domains_to_mechs = {
1677
- domain: set(mechs) for domain, mechs in data['domains'].items()
1678
- }
1679
- if self.verbose: print('Inserting mechanisms...')
1680
- for domain_name, mechs in self.domains_to_mechs.items():
1681
- for mech_name in mechs:
1682
- self.insert_mechanism(mech_name, domain_name, distribute=False)
1683
- # print('Distributing parameters...')
1684
- # self.distribute_all()
1685
-
1686
- # Groups
1687
- if self.verbose: print('Adding groups...')
1688
- self._groups = [SegmentGroup.from_dict(group) for group in data['groups']]
1689
-
1690
- if self.verbose: print('Distributing parameters...')
1691
- # Parameters
1692
- self.params = {
1693
- param_name: {
1694
- group_name: distribution if isinstance(distribution, str) else Distribution.from_dict(distribution)
1695
- for group_name, distribution in distributions.items()
1696
- }
1697
- for param_name, distributions in data['params'].items()
1698
- }
1699
-
1700
- if self.verbose: print('Setting segmentation...')
1701
- if self.sec_tree is not None:
1702
- d_lambda = self.d_lambda
1703
- self.set_segmentation(d_lambda=d_lambda)
1704
-
1705
-
1706
-
1707
- def export_biophys(self, file_name, **kwargs):
1708
- """
1709
- Export the biophysical properties of the model to a JSON file.
1710
-
1711
- Parameters
1712
- ----------
1713
- file_name : str
1714
- The name of the file to write to.
1715
- **kwargs : dict
1716
- Additional keyword arguments to pass to `json.dump`.
1717
- """
1718
-
1719
- path_to_json = self.path_manager.get_file_path('biophys', file_name, extension='json')
1720
- if not kwargs.get('indent'):
1721
- kwargs['indent'] = 4
1722
-
1723
- data = self.to_dict()
1724
- with open(path_to_json, 'w') as f:
1725
- json.dump(data, f, **kwargs)
1726
-
1727
-
1728
- def load_biophys(self, file_name, recompile=True):
1729
- """
1730
- Load the biophysical properties of the model from a JSON file.
1731
-
1732
- Parameters
1733
- ----------
1734
- file_name : str
1735
- The name of the file to read from.
1736
- recompile : bool, optional
1737
- Whether to recompile the mechanisms after loading. Default is True.
1738
- """
1739
- self.add_default_mechanisms()
1740
-
1741
-
1742
- path_to_json = self.path_manager.get_file_path('biophys', file_name, extension='json')
1743
-
1744
- with open(path_to_json, 'r') as f:
1745
- data = json.load(f)
1746
-
1747
- for mech_name in {mech for mechs in data['domains'].values() for mech in mechs}:
1748
- if mech_name in ['Leak', 'CaDyn', 'Independent']:
1749
- continue
1750
- self.add_mechanism(mech_name, dir_name='mod', recompile=recompile)
1751
-
1752
- self.from_dict(data)
1753
-
1754
-
1755
- def stimuli_to_dict(self):
1756
- """
1757
- Convert the stimuli to a dictionary representation.
1758
-
1759
- Returns
1760
- -------
1761
- dict
1762
- The dictionary representation of the stimuli.
1763
- """
1764
- return {
1765
- 'metadata': {
1766
- 'name': self.name,
1767
- },
1768
- 'simulation': {
1769
- **self.simulator.to_dict(),
1770
- },
1771
- 'stimuli': {
1772
- 'recordings': [
1773
- {
1774
- 'name': f'rec_{i}',
1775
- 'var': var
1776
- }
1777
- for var, recs in self.simulator.recordings.items()
1778
- for i, _ in enumerate(recs)
1779
- ],
1780
- 'iclamps': [
1781
- {
1782
- 'name': f'iclamp_{i}',
1783
- 'amp': iclamp.amp,
1784
- 'delay': iclamp.delay,
1785
- 'dur': iclamp.dur
1786
- }
1787
- for i, (seg, iclamp) in enumerate(self.iclamps.items())
1788
- ],
1789
- 'populations': {
1790
- syn_type: [pop.to_dict() for pop in pops.values()]
1791
- for syn_type, pops in self.populations.items()
1792
- }
1793
- },
1794
- }
1795
-
1796
-
1797
- def _stimuli_to_csv(self, path_to_csv=None):
1798
- """
1799
- Write the model to a CSV file.
1800
-
1801
- Parameters
1802
- ----------
1803
- path_to_csv : str
1804
- The path to the CSV file to write.
1805
- """
1806
-
1807
- rec_data = {
1808
- 'type': [],
1809
- 'idx': [],
1810
- 'sec_idx': [],
1811
- 'loc': [],
1812
- }
1813
- for var, recs in self.simulator.recordings.items():
1814
- rec_data['type'].extend(['rec'] * len(recs))
1815
- rec_data['idx'].extend([i for i in range(len(recs))])
1816
- rec_data['sec_idx'].extend([seg._section.idx for seg in recs])
1817
- rec_data['loc'].extend([seg.x for seg in recs])
1818
-
1819
- iclamp_data = {
1820
- 'type': ['iclamp'] * len(self.iclamps),
1821
- 'idx': [i for i in range(len(self.iclamps))],
1822
- 'sec_idx': [seg._section.idx for seg in self.iclamps],
1823
- 'loc': [seg.x for seg in self.iclamps],
1824
- }
1825
-
1826
- synapses_data = {
1827
- 'type': [],
1828
- 'idx': [],
1829
- 'sec_idx': [],
1830
- 'loc': [],
1831
- }
1832
-
1833
- for syn_type, pops in self.populations.items():
1834
- for pop_name, pop in pops.items():
1835
- pop_data = pop.to_csv()
1836
- synapses_data['type'] += pop_data['syn_type']
1837
- synapses_data['idx'] += [int(name.rsplit('_', 1)[1]) for name in pop_data['name']]
1838
- synapses_data['sec_idx'] += pop_data['sec_idx']
1839
- synapses_data['loc'] += pop_data['loc']
1840
-
1841
- df = pd.concat([
1842
- pd.DataFrame(rec_data),
1843
- pd.DataFrame(iclamp_data),
1844
- pd.DataFrame(synapses_data)
1845
- ], ignore_index=True)
1846
- df['idx'] = df['idx'].astype(int)
1847
- df['sec_idx'] = df['sec_idx'].astype(int)
1848
- if path_to_csv: df.to_csv(path_to_csv, index=False)
1849
-
1850
- return df
1851
-
1852
-
1853
- def export_stimuli(self, file_name, **kwargs):
1854
- """
1855
- Export the stimuli to a JSON and CSV file.
1856
-
1857
- Parameters
1858
- ----------
1859
- file_name : str
1860
- The name of the file to write to.
1861
- **kwargs : dict
1862
- Additional keyword arguments to pass to `json.dump`.
1863
- """
1864
- path_to_json = self.path_manager.get_file_path('stimuli', file_name, extension='json')
1865
-
1866
- data = self.stimuli_to_dict()
1867
-
1868
- if not kwargs.get('indent'):
1869
- kwargs['indent'] = 4
1870
- with open(path_to_json, 'w') as f:
1871
- json.dump(data, f, **kwargs)
1872
-
1873
- path_to_stimuli_csv = self.path_manager.get_file_path('stimuli', file_name, extension='csv')
1874
- self._stimuli_to_csv(path_to_stimuli_csv)
1875
-
1876
-
1877
- def load_stimuli(self, file_name):
1878
- """
1879
- Load the stimuli from a JSON file.
1880
-
1881
- Parameters
1882
- ----------
1883
- file_name : str
1884
- The name of the file to read from.
1885
- """
1886
-
1887
- path_to_json = self.path_manager.get_file_path('stimuli', file_name, extension='json')
1888
- path_to_stimuli_csv = self.path_manager.get_file_path('stimuli', file_name, extension='csv')
1889
-
1890
- with open(path_to_json, 'r') as f:
1891
- data = json.load(f)
1892
-
1893
- if not self.name == data['metadata']['name']:
1894
- raise ValueError('Model name does not match the data.')
1895
-
1896
- df_stimuli = pd.read_csv(path_to_stimuli_csv)
1897
-
1898
- self.simulator.from_dict(data['simulation'])
1899
-
1900
- # Clear all stimuli and recordings
1901
- self.remove_all_stimuli()
1902
- self.remove_all_recordings()
1903
-
1904
- # IClamps -----------------------------------------------------------
1905
-
1906
- df_iclamps = df_stimuli[df_stimuli['type'] == 'iclamp'].reset_index(drop=True, inplace=False)
1907
-
1908
- for i, row in df_iclamps.iterrows():
1909
- self.add_iclamp(
1910
- self.sec_tree.sections[row['sec_idx']],
1911
- row['loc'],
1912
- data['stimuli']['iclamps'][i]['amp'],
1913
- data['stimuli']['iclamps'][i]['delay'],
1914
- data['stimuli']['iclamps'][i]['dur']
1915
- )
1916
-
1917
- # Populations -------------------------------------------------------
1918
-
1919
- syn_types = ['AMPA', 'NMDA', 'AMPA_NMDA', 'GABAa']
1920
-
1921
- for syn_type in syn_types:
1922
-
1923
- df_syn = df_stimuli[df_stimuli['type'] == syn_type]
1924
-
1925
- for i, pop_data in enumerate(data['stimuli']['populations'][syn_type]):
1926
-
1927
- df_pop = df_syn[df_syn['idx'] == i]
1928
-
1929
- segments = [self.sec_tree.sections[sec_idx](loc)
1930
- for sec_idx, loc in zip(df_pop['sec_idx'], df_pop['loc'])]
1931
-
1932
- pop = Population(idx=i,
1933
- segments=segments,
1934
- N=pop_data['N'],
1935
- syn_type=syn_type)
1936
-
1937
- syn_locs = [(self.sec_tree.sections[sec_idx], loc) for sec_idx, loc in zip(df_pop['sec_idx'].tolist(), df_pop['loc'].tolist())]
1938
-
1939
- pop.allocate_synapses(syn_locs=syn_locs)
1940
- pop.update_kinetic_params(**pop_data['kinetic_params'])
1941
- pop.update_input_params(**pop_data['input_params'])
1942
- self._add_population(pop)
1943
-
1944
- # Recordings ---------------------------------------------------------
1945
-
1946
- df_recs = df_stimuli[df_stimuli['type'] == 'rec'].reset_index(drop=True, inplace=False)
1947
- for i, row in df_recs.iterrows():
1948
- # TODO: This conditional statement is to account for a recent change
1949
- # in the JSON structure. It should be removed in the future.
1950
- if data['stimuli'].get('recordings'):
1951
- var = data['stimuli']['recordings'][i]['var']
1952
- else:
1953
- var = 'v'
1954
- self.add_recording(
1955
- self.sec_tree.sections[row['sec_idx']], row['loc'], var
1956
- )
1957
-
1958
-
1959
- def export_to_NEURON(self, file_name, include_kinetic_params=True):
1960
- """
1961
- Export the model to a python file using NEURON.
1962
-
1963
- Parameters
1964
- ----------
1965
- file_name : str
1966
- The name of the file to write to.
1967
- """
1968
- from dendrotweaks.model_io import render_template
1969
- from dendrotweaks.model_io import get_params_to_valid_domains
1970
- from dendrotweaks.model_io import filter_params
1971
- from dendrotweaks.model_io import get_neuron_domain
1972
-
1973
- params_to_valid_domains = get_params_to_valid_domains(self)
1974
- params = self.params if include_kinetic_params else filter_params(self)
1975
- path_to_template = self.path_manager.get_file_path('templates', 'NEURON_template', extension='py')
1976
-
1977
- output = render_template(path_to_template,
1978
- {
1979
- 'param_dict': params,
1980
- 'groups_dict': self.groups,
1981
- 'params_to_mechs': self.params_to_mechs,
1982
- 'domains_to_mechs': self.domains_to_mechs,
1983
- 'iclamps': self.iclamps,
1984
- 'recordings': self.simulator.recordings,
1985
- 'params_to_valid_domains': params_to_valid_domains,
1986
- 'domains_to_NEURON': {domain: get_neuron_domain(domain) for domain in self.domains},
1987
- })
1988
-
1989
- if not file_name.endswith('.py'):
1990
- file_name += '.py'
1991
- path_to_model = self.path_manager.path_to_model
1992
- output_path = os.path.join(path_to_model, file_name)
1993
- with open(output_path, 'w') as f:
1994
- f.write(output)
1995
-
1996
-
1058
+ 'params_to_fits': params_to_fits
1059
+ }