dendrotweaks 0.4.4__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,326 @@
1
+ # Imports
2
+ import os
3
+
4
+ # DendroTweaks imports
5
+ from dendrotweaks.stimuli.populations import Population
6
+ from dendrotweaks.stimuli.iclamps import IClamp
7
+ from dendrotweaks.morphology.io import create_segment_tree
8
+ from dendrotweaks.prerun import prerun
9
+ from dendrotweaks.utils import calculate_lambda_f
10
+ from dendrotweaks.utils import POPULATIONS
11
+
12
+ # Warnings configuration
13
+ import warnings
14
+
15
def custom_warning_formatter(message, category, filename, lineno, file=None, line=None):
    """
    Render a warning as ``WARNING: <message>`` followed by a second line
    naming the base name of the source file and the line number.

    The signature mirrors ``warnings.formatwarning``; ``category``, ``file``
    and ``line`` are accepted for compatibility but are not used.
    """
    source_file = os.path.basename(filename)
    return f"WARNING: {message}\n({source_file}, line {lineno})\n"


# Install the formatter on the ``warnings`` module itself.
# NOTE(review): this is a process-wide change that affects warnings emitted
# by any code, not only this package.
warnings.formatwarning = custom_warning_formatter
19
+
20
+
21
class SimulationMixin:

    """Mixin class for model simulation functionalities.

    The host class is expected to provide the attributes ``simulator``,
    ``sec_tree``, ``seg_tree``, ``iclamps``, ``populations`` and
    ``path_manager``. It must also provide the methods ``distribute``,
    ``distribute_all``, ``export_stimuli`` and ``load_stimuli``, all of
    which are used below.
    """

    @property
    def recordings(self):
        """
        The recordings of the model. Reference to the recordings in the simulator.
        """
        return self.simulator.recordings

    @recordings.setter
    def recordings(self, recordings):
        self.simulator.recordings = recordings

    # ========================================================================
    # SEGMENTATION
    # ========================================================================

    def _temp_clear_stimuli(self):
        """
        Temporarily save and clear stimuli.

        Exports the stimuli to a temporary '_temp_stimuli' file, then removes
        all stimuli and all recordings from the model. Pair every call with
        ``_temp_reload_stimuli``.
        """
        self.export_stimuli(file_name='_temp_stimuli')
        self.remove_all_stimuli()
        self.remove_all_recordings()

    def _temp_reload_stimuli(self):
        """
        Load stimuli from a temporary file and clean up.

        Reloads the stimuli saved by ``_temp_clear_stimuli``. It then deletes
        the temporary '.json' and '.csv' files, if they exist.
        """
        self.load_stimuli(file_name='_temp_stimuli')
        for ext in ['json', 'csv']:
            temp_path = self.path_manager.get_file_path('stimuli', '_temp_stimuli', extension=ext)
            if os.path.exists(temp_path):
                os.remove(temp_path)

    def set_segmentation(self, d_lambda=0.1, f=100):
        """
        Set the number of segments in each section based on the geometry.

        For each section, ``nseg`` is set to an odd number. The number is
        chosen so that each segment is approximately no longer than
        ``d_lambda * lambda_f``, where ``lambda_f`` comes from
        ``calculate_lambda_f`` at frequency ``f``.

        Parameters
        ----------
        d_lambda : float
            Maximum segment length as a fraction of ``lambda_f``. Must be positive.
        f : float
            The frequency passed to ``calculate_lambda_f`` (presumably Hz).

        Raises
        ------
        ValueError
            If ``d_lambda`` is not positive.
        """
        # Validate before touching any model state.
        if d_lambda <= 0:
            raise ValueError("d_lambda must be positive.")
        self.d_lambda = d_lambda

        # Temporarily save and clear stimuli. The finally clause below
        # guarantees they are reloaded (and temp files removed) even if
        # re-segmentation fails part-way.
        self._temp_clear_stimuli()
        try:
            # Pre-distribute parameters needed for lambda_f calculation
            for param_name in ['cm', 'Ra']:
                self.distribute(param_name)

            # Calculate lambda_f and set nseg for each section
            for sec in self.sec_tree.sections:
                lambda_f = calculate_lambda_f(sec.distances, sec.diameters, sec.Ra, sec.cm, f)
                # The expression always yields an odd count (2k + 1).
                nseg = max(1, int((sec.L / (d_lambda * lambda_f) + 0.9) / 2) * 2 + 1)
                sec._nseg = sec._ref.nseg = nseg

            # Rebuild the segment tree and redistribute parameters
            self.seg_tree = create_segment_tree(self.sec_tree)
            self.distribute_all()
        finally:
            # Reload stimuli and clean up temporary files
            self._temp_reload_stimuli()

    # -----------------------------------------------------------------------
    # ICLAMPS
    # -----------------------------------------------------------------------

    def add_iclamp(self, sec, loc, amp=0, delay=100, dur=100):
        """
        Add an IClamp to a section, replacing any IClamp already at that segment.

        Parameters
        ----------
        sec : Section
            The section to add the IClamp to.
        loc : float
            The location of the IClamp in the section.
        amp : float, optional
            The amplitude of the IClamp. Default is 0.
        delay : float, optional
            The delay of the IClamp. Default is 100.
        dur : float, optional
            The duration of the IClamp. Default is 100.
        """
        seg = sec(loc)
        if self.iclamps.get(seg):
            self.remove_iclamp(sec, loc)
        iclamp = IClamp(sec, loc, amp, delay, dur)
        print(f'IClamp added to sec {sec} at loc {loc}.')
        self.iclamps[seg] = iclamp

    def remove_iclamp(self, sec, loc):
        """
        Remove an IClamp from a section. Does nothing if there is none at ``loc``.

        Parameters
        ----------
        sec : Section
            The section to remove the IClamp from.
        loc : float
            The location of the IClamp in the section.
        """
        self.iclamps.pop(sec(loc), None)

    def remove_all_iclamps(self):
        """
        Remove all IClamps from the model.

        Warns if any IClamp survives removal, then resets ``iclamps`` to an
        empty dict.
        """
        # Iterate over a snapshot: remove_iclamp mutates self.iclamps.
        for seg in list(self.iclamps):
            self.remove_iclamp(seg._section, seg.x)
        if self.iclamps:
            warnings.warn(f'Not all iclamps were removed: {self.iclamps}')
        self.iclamps = {}

    # -----------------------------------------------------------------------
    # SYNAPSES
    # -----------------------------------------------------------------------

    def _add_population(self, population):
        """Register ``population`` under its synapse type and name."""
        self.populations[population.syn_type][population.name] = population

    def _next_population_idx(self, syn_type):
        """
        Return an index for a new population of ``syn_type`` that does not
        clash with an existing population name.

        Using ``len(...)`` alone collides after a removal. For example, with
        ``AMPA_0`` removed and ``AMPA_1`` kept, it would return 1 and
        overwrite ``AMPA_1`` without cleaning it.
        """
        # NOTE(review): assumes Population names itself f'{syn_type}_{idx}',
        # which is the format parsed by rsplit('_', 1) elsewhere in this class.
        existing = self.populations[syn_type]
        idx = len(existing)
        while f'{syn_type}_{idx}' in existing:
            idx += 1
        return idx

    def add_population(self, segments, N, syn_type):
        """
        Add a population of synapses to the model.

        Parameters
        ----------
        segments : list[Segment]
            The segments to add the synapses to.
        N : int
            The number of synapses to add.
        syn_type : str
            The type of synapse to add.
        """
        idx = self._next_population_idx(syn_type)
        population = Population(idx, segments, N, syn_type)
        population.allocate_synapses()
        population.create_inputs()
        self._add_population(population)

    def update_population_kinetic_params(self, pop_name, **params):
        """
        Update the kinetic parameters of a population of synapses.

        Parameters
        ----------
        pop_name : str
            The name of the population, of the form '<syn_type>_<idx>'.
        params : dict
            The parameters to update.
        """
        syn_type, _ = pop_name.rsplit('_', 1)
        population = self.populations[syn_type][pop_name]
        population.update_kinetic_params(**params)
        print(population.kinetic_params)

    def update_population_input_params(self, pop_name, **params):
        """
        Update the input parameters of a population of synapses.

        Parameters
        ----------
        pop_name : str
            The name of the population, of the form '<syn_type>_<idx>'.
        params : dict
            The parameters to update.
        """
        syn_type, _ = pop_name.rsplit('_', 1)
        population = self.populations[syn_type][pop_name]
        population.update_input_params(**params)
        print(population.input_params)

    def remove_population(self, name):
        """
        Remove a population of synapses from the model and clean it.

        Parameters
        ----------
        name : str
            The name of the population, of the form '<syn_type>_<idx>'.
        """
        syn_type, _ = name.rsplit('_', 1)
        population = self.populations[syn_type].pop(name)
        population.clean()

    def remove_all_populations(self):
        """
        Remove all populations of synapses from the model.

        Warns if any population survives removal. It then resets
        ``populations`` to a fresh mapping with an empty dict per synapse
        type listed in ``POPULATIONS``.
        """
        for pops in self.populations.values():
            # Snapshot the names: remove_population pops from this dict.
            for name in list(pops):
                self.remove_population(name)
        if any(self.populations.values()):
            warnings.warn(f'Not all populations were removed: {self.populations}')
        # Build a fresh dict rather than assigning POPULATIONS itself. Aliasing
        # the shared module-level object would make later additions mutate it
        # (leaking populations between models).
        self.populations = {syn_type: {} for syn_type in POPULATIONS}

    def remove_all_stimuli(self):
        """
        Remove all stimuli (IClamps and synapse populations) from the model.
        """
        self.remove_all_iclamps()
        self.remove_all_populations()

    # ========================================================================
    # SIMULATION
    # ========================================================================

    def add_recording(self, sec, loc, var='v'):
        """
        Add a recording to the model.

        Parameters
        ----------
        sec : Section
            The section to record from.
        loc : float
            The location along the normalized section length to record from.
        var : str, optional
            The variable to record. Default is 'v'.
        """
        self.simulator.add_recording(sec, loc, var)
        print(f'Recording added to sec {sec} at loc {loc}.')

    def remove_recording(self, sec, loc, var='v'):
        """
        Remove a recording from the model.

        Parameters
        ----------
        sec : Section
            The section to remove the recording from.
        loc : float
            The location along the normalized section length to remove the recording from.
        var : str, optional
            The recorded variable. Default is 'v'.
        """
        self.simulator.remove_recording(sec, loc, var)

    def remove_all_recordings(self, var=None):
        """
        Remove all recordings from the model.

        Parameters
        ----------
        var : str, optional
            Forwarded to the simulator. Presumably restricts removal to one
            variable when given (confirm against the simulator).
        """
        self.simulator.remove_all_recordings(var=var)

    def run(self, duration=300, prerun_time=0, truncate=True):
        """
        Run the simulation for a specified duration, optionally preceded by a prerun period
        to stabilize the model.

        Parameters
        ----------
        duration : float
            Duration of the main simulation (excluding prerun). Must be positive.
        prerun_time : float
            Optional prerun period to run before the main simulation. Must be non-negative.
        truncate : bool
            Whether to truncate prerun data after the simulation.

        Raises
        ------
        ValueError
            If ``duration`` is not positive or ``prerun_time`` is negative.
        """
        if duration <= 0:
            raise ValueError("Simulation duration must be positive.")
        if prerun_time < 0:
            raise ValueError("Prerun time must be non-negative.")

        if prerun_time > 0:
            # The prerun context shifts stimuli later by prerun_time, so the
            # simulator must cover both periods.
            with prerun(self, duration=prerun_time, truncate=truncate):
                self.simulator.run(duration + prerun_time)
        else:
            self.simulator.run(duration)

    def get_traces(self):
        """Return the traces produced by the simulator."""
        return self.simulator.get_traces()

    def plot(self, *args, **kwargs):
        """Forward all arguments to the simulator's ``plot`` method."""
        self.simulator.plot(*args, **kwargs)
@@ -36,8 +36,8 @@ def create_point_tree(source: Union[str, DataFrame]) -> PointTree:
36
36
  raise ValueError("Source must be a file path (str) or a DataFrame.")
37
37
 
38
38
  nodes = [
39
- Point(row['Index'], row['Type'], row['X'], row['Y'], row['Z'], row['R'], row['Parent'])
40
- for _, row in df.iterrows()
39
+ Point(row.Index, row.Type, row.X, row.Y, row.Z, row.R, row.Parent)
40
+ for row in df.itertuples(index=False)
41
41
  ]
42
42
  point_tree = PointTree(nodes)
43
43
  point_tree.remove_overlaps()
@@ -35,7 +35,16 @@ class SWCReader():
35
35
  header=None,
36
36
  comment='#',
37
37
  names=['Index', 'Type', 'X', 'Y', 'Z', 'R', 'Parent'],
38
- index_col=False
38
+ index_col=False,
39
+ dtype={
40
+ 'Index': int,
41
+ 'Type': int,
42
+ 'X': float,
43
+ 'Y': float,
44
+ 'Z': float,
45
+ 'R': float,
46
+ 'Parent': int
47
+ }
39
48
  )
40
49
 
41
50
  if (df['R'] == 0).all():
@@ -61,6 +70,6 @@ class SWCReader():
61
70
  for t in df['Type'].unique():
62
71
  color = types_to_colors.get(t, 'k')
63
72
  mask = df['Type'] == t
64
- ax.scatter(df[mask]['X'], df[mask]['Y'], df[mask]['Z'],
65
- c=color, s=1, label=f'Type {t}')
73
+ ax.scatter(xs=df[mask]['X'], ys=df[mask]['Y'], zs=df[mask]['Z'],
74
+ c=color, s=1, label=f'Type {t}')
66
75
  ax.legend()
@@ -59,7 +59,7 @@ class Point(Node):
59
59
  x: float, y: float, z: float, r: float,
60
60
  parent_idx: str) -> None:
61
61
  super().__init__(idx, parent_idx)
62
- self.type_idx = type_idx
62
+ self.type_idx = int(type_idx)
63
63
  self.x = x
64
64
  self.y = y
65
65
  self.z = z
@@ -50,9 +50,9 @@ class PathManager:
50
50
  @property
51
51
  def path_to_data(self):
52
52
  """
53
- The path to the data directory.
53
+ The path to the data directory, which is always the parent directory of path_to_model.
54
54
  """
55
- return os.path.dirname(self.path_to_model)
55
+ return os.path.abspath(os.path.join(self.path_to_model, os.pardir))
56
56
 
57
57
  def __repr__(self):
58
58
  return f"PathManager({self.path_to_model})"
dendrotweaks/prerun.py ADDED
@@ -0,0 +1,63 @@
1
class prerun:
    """
    Context manager to prerun a simulation for a specified duration.
    This is useful to stabilize the model before running the main simulation.

    On entry, every IClamp ``delay`` and every population's ``'start'`` and
    ``'end'`` input times are shifted later by ``duration``. On exit, the
    original values are restored. If ``truncate`` is True and the enclosed
    block finished without an exception, the first ``duration`` of the time
    vector and of every recording is also discarded, and time is re-based so
    the main simulation starts at 0.

    Parameters
    ----------
    model : object
        Must expose ``iclamps``, ``populations``, ``recordings`` and a
        ``simulator`` with ``dt``, ``t`` and ``_duration`` attributes.
    duration : float
        Prerun length in simulator time units. Must be positive.
    truncate : bool
        Whether to drop the prerun portion of the recorded data on exit.

    Raises
    ------
    ValueError
        If ``duration`` is not positive.
    """
    def __init__(self, model, duration=300, truncate=True):
        self.model = model
        if duration <= 0:
            raise ValueError("Duration must be a positive number.")
        self.duration = duration
        self.truncate = truncate
        self._original_iclamp_delays = {}
        self._original_input_params = {}

    def __enter__(self):
        model = self.model

        # Snapshot the original timings so __exit__ can restore them.
        self._original_iclamp_delays = {seg: iclamp.delay for seg, iclamp in model.iclamps.items()}
        self._original_input_params = {
            (syn_type, name): (pop.input_params['start'], pop.input_params['end'])
            for syn_type, pops in model.populations.items()
            for name, pop in pops.items()
        }

        # Postpone every stimulus by the prerun duration.
        for iclamp in model.iclamps.values():
            iclamp.delay += self.duration

        for pops in model.populations.values():
            for pop in pops.values():
                pop.update_input_params(**{
                    'start': pop.input_params['start'] + self.duration,
                    'end': pop.input_params['end'] + self.duration,
                })

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore IClamp delays. Skip stimuli removed inside the context
        # instead of raising KeyError.
        iclamps = self.model.iclamps
        for seg, delay in self._original_iclamp_delays.items():
            iclamp = iclamps.get(seg)
            if iclamp is not None:
                iclamp.delay = delay

        # Restore input timings (same tolerance for removed populations).
        populations = self.model.populations
        for (syn_type, name), (start, end) in self._original_input_params.items():
            pop = populations.get(syn_type, {}).get(name)
            if pop is not None:
                pop.update_input_params(**{'start': start, 'end': end})

        # Only truncate after a successful run. On error the recorded data is
        # stale or partial, and a failing truncation would mask the real
        # exception. Returning None lets any exception propagate.
        if exc_type is None and self.truncate and len(self.model.simulator.t) > self._onset():
            self._truncate()

    def _onset(self):
        """Number of time steps spanned by the prerun period."""
        # round(), not int(): the float quotient can land just below the exact
        # integer (e.g. 0.3 / 0.1 == 2.9999999999999996). int() would then
        # keep one stale prerun sample.
        return int(round(self.duration / self.model.simulator.dt))

    def _truncate(self):
        """Truncate the simulation time and recordings to the specified duration."""
        onset = self._onset()
        simulator = self.model.simulator

        # Drop the prerun samples and re-base time so the main run starts at 0.
        simulator.t = [t - self.duration for t in simulator.t[onset:]]

        for recs in self.model.recordings.values():
            for seg, rec in recs.items():
                recs[seg] = rec[onset:]

        simulator._duration -= self.duration
+ self.model.simulator._duration -= self.duration