dendrotweaks 0.4.5__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,326 @@
1
+ # Imports
2
+ import os
3
+
4
+ # DendroTweaks imports
5
+ from dendrotweaks.stimuli.populations import Population
6
+ from dendrotweaks.stimuli.iclamps import IClamp
7
+ from dendrotweaks.morphology.io import create_segment_tree
8
+ from dendrotweaks.prerun import prerun
9
+ from dendrotweaks.utils import calculate_lambda_f
10
+ from dendrotweaks.utils import POPULATIONS
11
+
12
+ # Warnings configuration
13
+ import warnings
14
+
15
def custom_warning_formatter(message, category, filename, lineno, file=None, line=None):
    """
    Render a warning as a compact two-line message.

    Only the message, the base name of the originating file and the line
    number are shown. ``category``, ``file`` and ``line`` are accepted so the
    signature matches ``warnings.formatwarning``, but they are not used.
    """
    short_name = os.path.basename(filename)
    return "WARNING: {}\n({}, line {})\n".format(message, short_name, lineno)


# Replace the process-wide warning formatter with the compact one above.
warnings.formatwarning = custom_warning_formatter
19
+
20
+
21
class SimulationMixin:
    """
    Mixin class for model simulation functionalities.

    Covers segmentation (d_lambda rule), current clamps, synaptic
    populations, recordings and running the simulation. The host class must
    provide the attributes ``simulator``, ``sec_tree``, ``seg_tree``,
    ``iclamps``, ``populations`` and ``path_manager``. It must also provide
    the methods ``distribute``, ``distribute_all``, ``export_stimuli`` and
    ``load_stimuli``. These are used here but defined elsewhere.
    """

    @property
    def recordings(self):
        """
        The recordings of the model. Reference to the recordings in the simulator.
        """
        return self.simulator.recordings

    @recordings.setter
    def recordings(self, recordings):
        self.simulator.recordings = recordings

    # ========================================================================
    # SEGMENTATION
    # ========================================================================

    # TODO Make a context manager for this
    def _temp_clear_stimuli(self):
        """
        Temporarily save and clear stimuli.

        The stimuli are exported under the name '_temp_stimuli'. Then all
        stimuli and all recordings are removed from the model.
        """
        self.export_stimuli(file_name='_temp_stimuli')
        self.remove_all_stimuli()
        # NOTE(review): recordings are removed too; presumably the exported
        # stimuli file also stores them so load_stimuli restores them -- confirm.
        self.remove_all_recordings()

    def _temp_reload_stimuli(self):
        """
        Load stimuli from a temporary file and clean up.

        Reloads '_temp_stimuli' via load_stimuli, then deletes its .json and
        .csv files if they exist.
        """
        self.load_stimuli(file_name='_temp_stimuli')
        for ext in ['json', 'csv']:
            temp_path = self.path_manager.get_file_path('stimuli', '_temp_stimuli', extension=ext)
            if os.path.exists(temp_path):
                os.remove(temp_path)

    def set_segmentation(self, d_lambda=0.1, f=100):
        """
        Set the number of segments in each section based on the geometry.

        Applies NEURON's d_lambda rule. Each section gets an odd number of
        segments (at least 1) so that segments are roughly no longer than
        ``d_lambda`` times the frequency-dependent length constant at ``f``.

        Stimuli are saved and cleared first. The segment tree is then rebuilt
        and all parameters redistributed, after which the stimuli are reloaded.

        Parameters
        ----------
        d_lambda : float
            Target segment length as a fraction of the length constant.
        f : float
            The frequency used to compute the length constant.
        """
        self.d_lambda = d_lambda

        # Temporarily save and clear stimuli
        self._temp_clear_stimuli()

        # Pre-distribute parameters needed for lambda_f calculation
        for param_name in ['cm', 'Ra']:
            self.distribute(param_name)

        # Calculate lambda_f and set nseg for each section
        for sec in self.sec_tree.sections:
            lambda_f = calculate_lambda_f(sec.distances, sec.diameters, sec.Ra, sec.cm, f)
            # Odd nseg (d_lambda rule) keeps a segment centre at x = 0.5.
            nseg = max(1, int((sec.L / (d_lambda * lambda_f) + 0.9) / 2) * 2 + 1)
            sec._nseg = sec._ref.nseg = nseg

        # Rebuild the segment tree and redistribute parameters
        self.seg_tree = create_segment_tree(self.sec_tree)
        self.distribute_all()

        # Reload stimuli and clean up temporary files
        self._temp_reload_stimuli()

    # -----------------------------------------------------------------------
    # ICLAMPS
    # -----------------------------------------------------------------------

    def add_iclamp(self, sec, loc, amp=0, delay=100, dur=100):
        """
        Add an IClamp to a section.

        Any IClamp already present at the same segment is removed first, so
        each segment holds at most one IClamp.

        Parameters
        ----------
        sec : Section
            The section to add the IClamp to.
        loc : float
            The location of the IClamp in the section.
        amp : float, optional
            The amplitude of the IClamp. Default is 0.
        delay : float, optional
            The delay of the IClamp. Default is 100.
        dur : float, optional
            The duration of the IClamp. Default is 100.
        """
        seg = sec(loc)
        if self.iclamps.get(seg):
            self.remove_iclamp(sec, loc)
        iclamp = IClamp(sec, loc, amp, delay, dur)
        print(f'IClamp added to sec {sec} at loc {loc}.')
        self.iclamps[seg] = iclamp

    def remove_iclamp(self, sec, loc):
        """
        Remove an IClamp from a section.

        Does nothing if there is no IClamp at that location.

        Parameters
        ----------
        sec : Section
            The section to remove the IClamp from.
        loc : float
            The location of the IClamp in the section.
        """
        seg = sec(loc)
        if self.iclamps.get(seg):
            self.iclamps.pop(seg)

    def remove_all_iclamps(self):
        """
        Remove all IClamps from the model.

        Each IClamp is removed through remove_iclamp. A warning is issued if
        any remain, and the registry is then reset to an empty dict.
        """
        for seg in list(self.iclamps.keys()):
            sec, loc = seg._section, seg.x
            self.remove_iclamp(sec, loc)
        if self.iclamps:
            warnings.warn(f'Not all iclamps were removed: {self.iclamps}')
        self.iclamps = {}

    # -----------------------------------------------------------------------
    # SYNAPSES
    # -----------------------------------------------------------------------

    def _add_population(self, population):
        """Register a population under its synapse type and name."""
        self.populations[population.syn_type][population.name] = population

    def add_population(self, segments, N, syn_type):
        """
        Add a population of synapses to the model.

        Parameters
        ----------
        segments : list[Segment]
            The segments to add the synapses to.
        N : int
            The number of synapses to add.
        syn_type : str
            The type of synapse to add. It must be a key of ``self.populations``.
        """
        existing = self.populations[syn_type]
        idx = len(existing)
        # After a removal, len() can equal the index of a surviving population.
        # Registering the new one under that name would silently replace the
        # old one without cleaning it, so skip indices that are in use.
        # Assumes names have the form '<syn_type>_<idx>', as the rsplit('_', 1)
        # in remove_population implies. If they do not, this loop never runs
        # and behaviour is unchanged.
        while f'{syn_type}_{idx}' in existing:
            idx += 1
        population = Population(idx, segments, N, syn_type)
        population.allocate_synapses()
        population.create_inputs()
        self._add_population(population)

    def update_population_kinetic_params(self, pop_name, **params):
        """
        Update the kinetic parameters of a population of synapses.

        Parameters
        ----------
        pop_name : str
            The name of the population, in the form '<syn_type>_<idx>'.
        params : dict
            The parameters to update.
        """
        syn_type, _ = pop_name.rsplit('_', 1)
        population = self.populations[syn_type][pop_name]
        population.update_kinetic_params(**params)
        print(population.kinetic_params)

    def update_population_input_params(self, pop_name, **params):
        """
        Update the input parameters of a population of synapses.

        Parameters
        ----------
        pop_name : str
            The name of the population, in the form '<syn_type>_<idx>'.
        params : dict
            The parameters to update.
        """
        syn_type, _ = pop_name.rsplit('_', 1)
        population = self.populations[syn_type][pop_name]
        population.update_input_params(**params)
        print(population.input_params)

    def remove_population(self, name):
        """
        Remove a population of synapses from the model and clean it up.

        Parameters
        ----------
        name : str
            The name of the population, in the form '<syn_type>_<idx>'.
        """
        syn_type, _ = name.rsplit('_', 1)
        population = self.populations[syn_type].pop(name)
        population.clean()

    def remove_all_populations(self):
        """
        Remove all populations of synapses from the model.

        Every population is removed and cleaned. The model is then left with
        the same synapse types, each mapped to a new empty dict.
        """
        for syn_type in self.populations:
            for name in list(self.populations[syn_type].keys()):
                self.remove_population(name)
        if any(self.populations.values()):
            warnings.warn(f'Not all populations were removed: {self.populations}')
        # Build fresh dicts instead of assigning the shared module-level
        # POPULATIONS template. Otherwise later add_population calls would
        # mutate that global and leak populations across models.
        self.populations = {syn_type: {} for syn_type in self.populations}

    def remove_all_stimuli(self):
        """
        Remove all stimuli (IClamps and synaptic populations) from the model.
        """
        self.remove_all_iclamps()
        self.remove_all_populations()

    # ========================================================================
    # SIMULATION
    # ========================================================================

    def add_recording(self, sec, loc, var='v'):
        """
        Add a recording to the model.

        Parameters
        ----------
        sec : Section
            The section to record from.
        loc : float
            The location along the normalized section length to record from.
        var : str, optional
            The variable to record. Default is 'v'.
        """
        self.simulator.add_recording(sec, loc, var)
        print(f'Recording added to sec {sec} at loc {loc}.')

    def remove_recording(self, sec, loc, var='v'):
        """
        Remove a recording from the model.

        Parameters
        ----------
        sec : Section
            The section to remove the recording from.
        loc : float
            The location along the normalized section length to remove the recording from.
        var : str, optional
            The recorded variable to remove. Default is 'v'.
        """
        self.simulator.remove_recording(sec, loc, var)

    def remove_all_recordings(self, var=None):
        """
        Remove all recordings from the model.

        Parameters
        ----------
        var : str, optional
            Passed through to the simulator's remove_all_recordings. Default is None.
        """
        self.simulator.remove_all_recordings(var=var)

    def run(self, duration=300, prerun_time=0, truncate=True):
        """
        Run the simulation for a specified duration, optionally preceded by a prerun period
        to stabilize the model.

        Parameters
        ----------
        duration : float
            Duration of the main simulation (excluding prerun).
        prerun_time : float
            Optional prerun period to run before the main simulation.
        truncate : bool
            Whether to truncate prerun data after the simulation.

        Raises
        ------
        ValueError
            If ``duration`` is not positive or ``prerun_time`` is negative.
        """
        if duration <= 0:
            raise ValueError("Simulation duration must be positive.")
        if prerun_time < 0:
            raise ValueError("Prerun time must be non-negative.")

        if prerun_time > 0:
            # Inside the context all stimuli are shifted by prerun_time, so
            # the simulator must cover the prerun plus the main duration.
            with prerun(self, duration=prerun_time, truncate=truncate):
                self.simulator.run(duration + prerun_time)
        else:
            self.simulator.run(duration)

    def get_traces(self):
        """Return the simulator's traces."""
        return self.simulator.get_traces()

    def plot(self, *args, **kwargs):
        """Forward all arguments to the simulator's plot method."""
        self.simulator.plot(*args, **kwargs)
@@ -70,6 +70,6 @@ class SWCReader():
70
70
  for t in df['Type'].unique():
71
71
  color = types_to_colors.get(t, 'k')
72
72
  mask = df['Type'] == t
73
- ax.scatter(df[mask]['X'], df[mask]['Y'], df[mask]['Z'],
74
- c=color, s=1, label=f'Type {t}')
73
+ ax.scatter(xs=df[mask]['X'], ys=df[mask]['Y'], zs=df[mask]['Z'],
74
+ c=color, s=1, label=f'Type {t}')
75
75
  ax.legend()
@@ -50,9 +50,9 @@ class PathManager:
50
50
  @property
51
51
  def path_to_data(self):
52
52
  """
53
- The path to the data directory.
53
+ The path to the data directory, which is always the parent directory of path_to_model.
54
54
  """
55
- return os.path.dirname(self.path_to_model)
55
+ return os.path.abspath(os.path.join(self.path_to_model, os.pardir))
56
56
 
57
57
  def __repr__(self):
58
58
  return f"PathManager({self.path_to_model})"
dendrotweaks/prerun.py ADDED
@@ -0,0 +1,63 @@
1
class prerun:
    """
    Context manager to prerun a simulation for a specified duration.
    This is useful to stabilize the model before running the main simulation.

    On entry, every IClamp delay and every synaptic population's input
    window (``start`` / ``end``) is shifted later by ``duration``, so no
    stimulus falls inside the prerun period. On exit the original timings
    are restored. If ``truncate`` is True and the block finished without an
    exception, the first ``duration`` of the time vector and of every
    recording is also dropped, so the traces start at t = 0.

    Parameters
    ----------
    model : object
        Model exposing ``iclamps``, ``populations``, ``recordings`` and
        ``simulator`` (with ``t``, ``dt`` and ``_duration``).
    duration : float, optional
        Length of the prerun period, in the simulator's time units.
        Must be positive. Default is 300.
    truncate : bool, optional
        Whether to discard the prerun part of the recorded data on exit.
        Default is True.

    Raises
    ------
    ValueError
        If ``duration`` is not positive.
    """

    def __init__(self, model, duration=300, truncate=True):
        self.model = model
        if duration <= 0:
            raise ValueError("Duration must be a positive number.")
        self.duration = duration
        self.truncate = truncate
        # Snapshots taken in __enter__ and restored in __exit__.
        self._original_iclamp_delays = {}
        self._original_input_params = {}

    def __enter__(self):
        model = self.model
        # Snapshot current timings so __exit__ can restore them exactly.
        self._original_iclamp_delays = {
            seg: iclamp.delay for seg, iclamp in model.iclamps.items()
        }
        self._original_input_params = {
            (syn_type, name): (pop.input_params['start'], pop.input_params['end'])
            for syn_type, pops in model.populations.items()
            for name, pop in pops.items()
        }

        # Push every stimulus past the prerun window.
        for iclamp in model.iclamps.values():
            iclamp.delay += self.duration

        for pops in model.populations.values():
            for pop in pops.values():
                pop.update_input_params(
                    start=pop.input_params['start'] + self.duration,
                    end=pop.input_params['end'] + self.duration,
                )

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore IClamp delays. Skip an IClamp removed inside the block
        # rather than raising KeyError and masking the block's own outcome.
        for seg, delay in self._original_iclamp_delays.items():
            iclamp = self.model.iclamps.get(seg)
            if iclamp is not None:
                iclamp.delay = delay

        # Restore input timings, skipping populations removed inside the block.
        for (syn_type, name), (start, end) in self._original_input_params.items():
            population = self.model.populations.get(syn_type, {}).get(name)
            if population is not None:
                population.update_input_params(start=start, end=end)

        # Only trim data from a run that completed normally. Exceptions are
        # never suppressed (the method returns None).
        if exc_type is None and self.truncate:
            if len(self.model.simulator.t) > self._onset_index():
                self._truncate()

    def _onset_index(self):
        """
        Index of the sample at t = duration, assuming ``t`` starts at 0 and
        is sampled every ``dt``.
        """
        # round() rather than int(): float error can put the quotient just
        # below an integer, e.g. 0.3 / 0.1 == 2.9999999999999996.
        return int(round(self.duration / self.model.simulator.dt))

    def _truncate(self):
        """Truncate the simulation time and recordings to the specified duration."""
        onset = self._onset_index()
        simulator = self.model.simulator

        # Drop the prerun samples and shift time so the main run starts at 0.
        simulator.t = [t - self.duration for t in simulator.t[onset:]]

        for recs in self.model.recordings.values():
            for seg, rec in list(recs.items()):
                recs[seg] = rec[onset:]

        simulator._duration -= self.duration
dendrotweaks/utils.py CHANGED
@@ -26,24 +26,41 @@ SWC_ID_TO_DOMAIN = {
26
26
  8: 'reduced',
27
27
  }
28
28
 
29
# Template of the supported synapse types. Each type maps to an initially
# empty dict of populations keyed by population name. This is shared module
# state: copy it (e.g. {k: {} for k in POPULATIONS}) instead of assigning or
# mutating it directly.
POPULATIONS = {syn_type: {} for syn_type in ('AMPA', 'NMDA', 'AMPA_NMDA', 'GABAa')}

# Default values of the passive and reversal-potential parameters listed below.
INDEPENDENT_PARAMS = dict(
    cm=1,      # uF/cm2
    Ra=100,    # Ohm cm
    ena=50,    # mV
    ek=-77,    # mV
    eca=140,   # mV
)

# Base domain name -> adjective used for the corresponding group.
DOMAIN_TO_GROUP = dict(
    soma='somatic',
    axon='axonal',
    dend='dendritic',
    apic='apical',
)
45
+
29
46
# Inverse of SWC_ID_TO_DOMAIN (domain name -> SWC type id). If two ids share
# a domain name, the id that appears later in SWC_ID_TO_DOMAIN wins.
DOMAIN_TO_SWC_ID = {domain: swc_id for swc_id, domain in SWC_ID_TO_DOMAIN.items()}
32
49
 
33
- def get_swc_idx(domain_name):
34
- base_domain, _, idx = domain_name.partition('_')
35
- if base_domain == 'reduced':
36
- return int(f'8{idx}')
37
- elif base_domain == 'custom':
38
- return int(f'5{idx}')
39
- return DOMAIN_TO_SWC_ID.get(base_domain, 0)
40
-
41
- def get_domain_name(swc_idx):
42
- if str(swc_idx).startswith('8'):
43
- return 'reduced_' + str(swc_idx)[1:]
44
- elif str(swc_idx).startswith('5'):
45
- return 'custom_' + str(swc_idx)[1:]
46
- return SWC_ID_TO_DOMAIN.get(swc_idx, 'undefined')
50
# Domain name -> NEURON-side section name. The numeric suffixes appear to
# mirror the SWC type ids produced by get_swc_idx (e.g. custom -> 5,
# reduced -> 8); verify before relying on that.
DOMAINS_TO_NEURON = dict([
    ('soma', 'soma'),
    ('perisomatic', 'dend_11'),
    ('axon', 'axon'),
    ('apic', 'apic'),
    ('dend', 'dend'),
    ('basal', 'dend_31'),
    ('trunk', 'dend_41'),
    ('tuft', 'dend_42'),
    ('oblique', 'dend_43'),
    ('custom', 'dend_5'),
    ('reduced', 'dend_8'),
    ('undefined', 'dend_0'),
])
47
64
 
48
65
  DOMAINS_TO_COLORS = {
49
66
  'soma': '#E69F00',
@@ -60,11 +77,24 @@ DOMAINS_TO_COLORS = {
60
77
  'undefined': '#7F7F7F',
61
78
  }
62
79
 
63
- def get_domain_color(domain_name):
80
def get_swc_idx(domain_name):
    """
    Return the integer SWC type id for a domain name.

    ``reduced_<n>`` becomes ``int('8<n>')`` and ``custom_<n>`` becomes
    ``int('5<n>')``. Any other name is looked up in ``DOMAIN_TO_SWC_ID`` by
    the part before its first underscore, falling back to 0.
    """
    base, _, suffix = domain_name.partition('_')
    leading_digit = {'reduced': '8', 'custom': '5'}
    if base in leading_digit:
        return int(leading_digit[base] + suffix)
    return DOMAIN_TO_SWC_ID.get(base, 0)
66
87
 
88
def get_domain_name(swc_idx):
    """
    Return the domain name for an SWC type id (the inverse of get_swc_idx).

    Ids whose decimal form starts with 8 map to ``reduced_<rest>`` and ids
    starting with 5 map to ``custom_<rest>``. The bare ids 8 and 5 map to
    ``'reduced'`` and ``'custom'``. Any other id is looked up in
    ``SWC_ID_TO_DOMAIN``, falling back to ``'undefined'``.
    """
    digits = str(swc_idx)
    for prefix, base in (('8', 'reduced'), ('5', 'custom')):
        if digits.startswith(prefix):
            suffix = digits[len(prefix):]
            # With no suffix, return the bare base name instead of
            # 'reduced_' / 'custom_' (SWC_ID_TO_DOMAIN maps 8 to 'reduced').
            return f'{base}_{suffix}' if suffix else base
    return SWC_ID_TO_DOMAIN.get(swc_idx, 'undefined')
67
94
 
95
def get_domain_color(domain_name):
    """
    Return the plot color for a domain.

    The lookup uses the part of the name before its first underscore.
    Unknown domains get '#7F7F7F'.
    """
    base_domain = domain_name.split('_', 1)[0]
    return DOMAINS_TO_COLORS.get(base_domain, '#7F7F7F')
68
98
 
69
99
  def timeit(func):
70
100
  def wrapper(*args, **kwargs):
@@ -75,7 +105,6 @@ def timeit(func):
75
105
  return result
76
106
  return wrapper
77
107
 
78
-
79
108
  def calculate_lambda_f(distances, diameters, Ra=35.4, Cm=1, frequency=100):
80
109
  """
81
110
  Calculate the frequency-dependent length constant (lambda_f) according to NEURON's implementation,
@@ -118,10 +147,6 @@ def calculate_lambda_f(distances, diameters, Ra=35.4, Cm=1, frequency=100):
118
147
  # Return section_L/lam (electrotonic length of the section)
119
148
  return section_L / lam
120
149
 
121
- if (__name__ == '__main__'):
122
- print('Executing as standalone script')
123
-
124
-
125
150
  def dynamic_import(module_name, class_name):
126
151
  """
127
152
  Dynamically import a class from a module.
@@ -142,20 +167,17 @@ def dynamic_import(module_name, class_name):
142
167
  module = import_module(module_name)
143
168
  return getattr(module, class_name)
144
169
 
145
-
146
170
def list_folders(path_to_folder):
    """
    Return the names of the sub-directories of ``path_to_folder``, sorted
    case-insensitively.
    """
    def is_subdirectory(entry):
        return os.path.isdir(os.path.join(path_to_folder, entry))

    return sorted(filter(is_subdirectory, os.listdir(path_to_folder)),
                  key=lambda entry: entry.lower())
151
175
 
152
-
153
176
def list_files(path_to_folder, extension):
    """
    Return the entry names in ``path_to_folder`` that end with ``extension``.

    The names come back in ``os.listdir`` order and are not restricted to
    regular files.
    """
    return list(filter(lambda entry: entry.endswith(extension),
                       os.listdir(path_to_folder)))
157
180
 
158
-
159
181
  def write_file(content: str, path_to_file: str, verbose: bool = True) -> None:
160
182
  """
161
183
  Write content to a file.
@@ -175,13 +197,11 @@ def write_file(content: str, path_to_file: str, verbose: bool = True) -> None:
175
197
  f.write(content)
176
198
  print(f"Saved content to {path_to_file}")
177
199
 
178
-
179
200
def read_file(path_to_file):
    """Return the whole content of a text file as a string."""
    with open(path_to_file) as handle:
        return handle.read()
183
204
 
184
-
185
205
  def download_example_data(path_to_destination, include_templates=True, include_modfiles=True):
186
206
  """
187
207
  Download and extract specific folders from the DendroTweaks GitHub repository:
@@ -240,8 +260,6 @@ def download_example_data(path_to_destination, include_templates=True, include_m
240
260
  os.remove(zip_path)
241
261
  print(f"Data downloaded and extracted successfully to {path_to_destination}/.")
242
262
 
243
-
244
-
245
263
  def apply_dark_theme():
246
264
  """
247
265
  Apply a dark theme to matplotlib plots.
@@ -259,4 +277,71 @@ def apply_dark_theme():
259
277
  'ytick.color': 'white',
260
278
  'text.color': 'white',
261
279
  'axes.prop_cycle': plt.cycler(color=plt.cm.tab10.colors), # use standard matplotlib colors
262
- })
280
+ })
281
+
282
def mse(y_true, y_pred):
    """Return the mean squared error between two array-likes."""
    residuals = np.asarray(y_true) - np.asarray(y_pred)
    return np.mean(np.square(residuals))
284
+
285
def poly_fit(x, y, max_degree=6, tolerance=1e-6):
    """
    Fit a polynomial to the data and return the coefficients and predicted values.

    Degrees 0, 1, ..., ``max_degree`` are tried in order. The first fit whose
    absolute residual is below ``tolerance`` at every point is returned. If
    no degree meets the tolerance, the ``max_degree`` fit is returned.

    Parameters
    ----------
    x, y : array-like
        Sample positions and values.
    max_degree : int, optional
        Highest polynomial degree to try. Must be non-negative. Default is 6.
    tolerance : float, optional
        Maximum allowed absolute residual at every point. Default is 1e-6.

    Returns
    -------
    coeffs : numpy.ndarray
        Polynomial coefficients, highest power first (as from numpy.polyfit).
    y_pred : numpy.ndarray
        Fitted values at ``x``.

    Raises
    ------
    ValueError
        If ``max_degree`` is negative. Without this check no fit would be
        attempted and the return would fail on unbound names.
    """
    if max_degree < 0:
        raise ValueError("max_degree must be non-negative.")

    # Convert once instead of on every iteration.
    y_true = np.asarray(y)
    for degree in range(max_degree + 1):
        coeffs = np.polyfit(x, y, degree)
        y_pred = np.polyval(coeffs, x)
        if np.all(np.abs(y_true - y_pred) < tolerance):
            break
    return coeffs, y_pred
295
+
296
def step_fit(x, y):
    """
    Fit a single step function with variable-width transition zone.

    Every pair ``(x[i], x[j])`` with ``i < j`` of the sorted positions is
    tried as a window ``(start, end)``. Points strictly inside the window are
    predicted by the NaN-ignoring mean of their ``y`` values (``high_val``).
    All other points are predicted by the mean of the remaining values
    (``low_val``). The window with the lowest mean squared error wins; on a
    tie, the first one found is kept. The search is O(n^3) in the number of
    points.

    Parameters
    ----------
    x, y : array-like
        Sample positions and values of equal length. ``x`` need not be sorted.

    Returns
    -------
    params : tuple or None
        ``(start, end, low_val, high_val)``, or None if no window leaves at
        least one point both inside and outside (e.g. fewer than 3 points).
        ``high_val`` is simply the inside mean and may be lower than ``low_val``.
        Any NaN in ``y`` makes every score NaN, which also yields None.
    y_pred : numpy.ndarray or None
        Predicted values aligned with the input order of ``x`` / ``y``, or
        None when ``params`` is None.
    """
    x = np.array(x)
    y = np.array(y)

    order = np.argsort(x)
    xs = x[order]
    ys = y[order]

    best_mse = float('inf')
    best_params = None
    best_pred = None

    n = len(xs)
    for i in range(n - 1):
        for j in range(i + 1, n):
            start = xs[i]
            end = xs[j]
            inside = (xs > start) & (xs < end)
            outside = ~inside

            # Both groups need at least one point for their means to exist.
            if not np.any(inside) or not np.any(outside):
                continue

            high_val = np.nanmean(ys[inside])
            low_val = np.nanmean(ys[outside])

            pred = np.where(inside, high_val, low_val)
            # Mean squared error of this candidate (same value as mse(ys, pred)).
            score = np.mean((ys - pred) ** 2)

            if score < best_mse:
                best_mse = score
                best_params = (start, end, low_val, high_val)
                best_pred = pred

    if best_pred is not None:
        # Undo the sort so predictions line up with the caller's x / y order,
        # as poly_fit's predictions do.
        unsorted_pred = np.empty_like(best_pred)
        unsorted_pred[order] = best_pred
        best_pred = unsorted_pred

    return best_params, best_pred
335
+
336
# Registry of candidate fit models. Each entry provides:
#   'fit'        -- callable(x, y) -> (params, y_pred)
#   'score'      -- callable(y_true, y_pred) -> mean squared error
#   'complexity' -- callable(params) -> number describing model complexity
DEFAULT_FIT_MODELS = dict(
    poly=dict(
        fit=poly_fit,
        score=mse,
        # Polynomial degree: one less than the number of coefficients.
        complexity=lambda coeffs: len(coeffs) - 1,
    ),
    step=dict(
        fit=step_fit,
        score=mse,
        # A step always has four parameters: (start, end, low_val, high_val).
        complexity=lambda params: 4,
    ),
)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dendrotweaks
3
- Version: 0.4.5
3
+ Version: 0.4.6
4
4
  Summary: A toolbox for exploring dendritic dynamics
5
5
  Home-page: https://dendrotweaks.dendrites.gr
6
6
  Author: Roman Makarov