bmtool 0.7.0.6.4__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
bmtool/synapses.py CHANGED
@@ -1,28 +1,40 @@
- import os
  import json
+ import os
+ from dataclasses import dataclass
+ from typing import Callable, Dict, List, Optional, Tuple
+
+ # widgets
+ import ipywidgets as widgets
+ import matplotlib.pyplot as plt
  import neuron
  import numpy as np
+ from IPython.display import clear_output, display
+ from ipywidgets import HBox, VBox
  from neuron import h
- from typing import List, Dict, Callable, Optional,Tuple
- from tqdm.notebook import tqdm
- import matplotlib.pyplot as plt
  from neuron.units import ms, mV
- from dataclasses import dataclass
+ from scipy.optimize import curve_fit, minimize, minimize_scalar
+
  # scipy
  from scipy.signal import find_peaks
- from scipy.optimize import curve_fit,minimize_scalar,minimize
- # widgets
- import ipywidgets as widgets
- from IPython.display import display, clear_output
- from ipywidgets import HBox, VBox
+ from tqdm.notebook import tqdm
+

  class SynapseTuner:
-     def __init__(self, mechanisms_dir: str, templates_dir: str, conn_type_settings: dict, connection: str,
-                  general_settings: dict, json_folder_path: str = None, current_name: str = 'i',
-                  other_vars_to_record: list = None, slider_vars: list = None) -> None:
+     def __init__(
+         self,
+         mechanisms_dir: str,
+         templates_dir: str,
+         conn_type_settings: dict,
+         connection: str,
+         general_settings: dict,
+         json_folder_path: str = None,
+         current_name: str = "i",
+         other_vars_to_record: list = None,
+         slider_vars: list = None,
+     ) -> None:
          """
          Initialize the SynapseModule class with connection type settings, mechanisms, and template directories.
-
+
          Parameters:
          -----------
          mechanisms_dir : str
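Aside: for orientation, here is a minimal sketch of how the reformatted constructor might be called. Every path, template name, and parameter value below is hypothetical; only the dictionary keys (`spec_settings`, `spec_syn_param`, `post_cell`, `level_of_detail`, `sec_id`, `sec_x`, `vclamp_amp`) mirror what this class reads in the hunks that follow.

```python
# Hypothetical sketch -- paths, names, and values are placeholders.
conn_type_settings = {
    "PN2PN": {  # invented connection label
        "spec_settings": {
            "post_cell": "PNCellTemplate",   # hoc template of the postsynaptic cell
            "level_of_detail": "Exp2Syn",    # point-process class used for the synapse
            "sec_id": 0,                     # index into list(cell.all)
            "sec_x": 0.5,                    # location along that section
            "vclamp_amp": -70.0,             # voltage-clamp holding level (mV)
        },
        "spec_syn_param": {"tau1": 0.5, "tau2": 5.0},  # forwarded onto the synapse
    }
}

tuner = SynapseTuner(
    mechanisms_dir="components/mechanisms",               # hypothetical paths
    templates_dir="components/templates/templates.hoc",
    conn_type_settings=conn_type_settings,
    connection="PN2PN",
    general_settings=general_settings,  # sketched after a later hunk
    slider_vars=["tau1", "tau2"],
)
```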
@@ -53,15 +65,17 @@ class SynapseTuner:
          self._update_spec_syn_param(json_folder_path)
          self.general_settings = general_settings
          self.conn = self.conn_type_settings[connection]
-         self.synaptic_props = self.conn['spec_syn_param']
-         self.vclamp = general_settings['vclamp']
+         self.synaptic_props = self.conn["spec_syn_param"]
+         self.vclamp = general_settings["vclamp"]
          self.current_name = current_name
          self.other_vars_to_record = other_vars_to_record
          self.ispk = None

          if slider_vars:
              # Start by filtering based on keys in slider_vars
-             self.slider_vars = {key: value for key, value in self.synaptic_props.items() if key in slider_vars}
+             self.slider_vars = {
+                 key: value for key, value in self.synaptic_props.items() if key in slider_vars
+             }
              # Iterate over slider_vars and check for missing keys in self.synaptic_props
              for key in slider_vars:
                  # If the key is missing from synaptic_props, get the value using getattr
@@ -70,8 +84,8 @@ class SynapseTuner:
                      # Get the alternative value from getattr dynamically
                      self._set_up_cell()
                      self._set_up_synapse()
-                     value = getattr(self.syn,key)
-                     #print(value)
+                     value = getattr(self.syn, key)
+                     # print(value)
                      self.slider_vars[key] = value
                  except AttributeError as e:
                      print(f"Error accessing '{key}' in syn {self.syn}: {e}")
@@ -79,30 +93,41 @@ class SynapseTuner:
          else:
              self.slider_vars = self.synaptic_props

-
-         h.tstop = general_settings['tstart'] + general_settings['tdur']
-         h.dt = general_settings['dt'] # Time step (resolution) of the simulation in ms
+         h.tstop = general_settings["tstart"] + general_settings["tdur"]
+         h.dt = general_settings["dt"]  # Time step (resolution) of the simulation in ms
          h.steps_per_ms = 1 / h.dt
-         h.celsius = general_settings['celsius']
-
+         h.celsius = general_settings["celsius"]
+
          # get some stuff set up we need for both SingleEvent and Interactive Tuner
          self._set_up_cell()
          self._set_up_synapse()
-
+
          self.nstim = h.NetStim()
          self.nstim2 = h.NetStim()
-
+
          self.vcl = h.VClamp(self.cell.soma[0](0.5))
-
-         self.nc = h.NetCon(self.nstim, self.syn, self.general_settings['threshold'], self.general_settings['delay'], self.general_settings['weight'])
-         self.nc2 = h.NetCon(self.nstim2, self.syn, self.general_settings['threshold'], self.general_settings['delay'], self.general_settings['weight'])
-
+
+         self.nc = h.NetCon(
+             self.nstim,
+             self.syn,
+             self.general_settings["threshold"],
+             self.general_settings["delay"],
+             self.general_settings["weight"],
+         )
+         self.nc2 = h.NetCon(
+             self.nstim2,
+             self.syn,
+             self.general_settings["threshold"],
+             self.general_settings["delay"],
+             self.general_settings["weight"],
+         )
+
          self._set_up_recorders()

      def _update_spec_syn_param(self, json_folder_path):
          """
          Update specific synaptic parameters using JSON files located in the specified folder.
-
+
          Parameters:
          -----------
          json_folder_path : str
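The block above reads a fixed set of keys out of `general_settings` (plus `threshold`/`delay`/`weight` for the two NetCons). A sketch of such a dictionary, with invented values but the key names the code actually uses:

```python
# Values are hypothetical; the keys are the ones __init__ and SingleEvent read.
general_settings = {
    "vclamp": True,           # record under voltage clamp
    "tstart": 500.0,          # ms, time of the stimulus
    "tdur": 100.0,            # ms, simulated window after tstart
    "dt": 0.025,              # ms, integration time step
    "celsius": 20.0,          # simulation temperature
    "threshold": -15.0,       # mV, NetCon spike-detection threshold
    "delay": 0.5,             # ms, NetCon conduction delay
    "weight": 1.0,            # NetCon weight
    "rise_interval": (0.2, 0.8),  # peak fractions used for rise time
}
```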
@@ -111,42 +136,45 @@ class SynapseTuner:
          for conn_type, settings in self.conn_type_settings.items():
              json_file_path = os.path.join(json_folder_path, f"{conn_type}.json")
              if os.path.exists(json_file_path):
-                 with open(json_file_path, 'r') as json_file:
+                 with open(json_file_path, "r") as json_file:
                      json_data = json.load(json_file)
-                     settings['spec_syn_param'].update(json_data)
+                     settings["spec_syn_param"].update(json_data)
              else:
                  print(f"JSON file for {conn_type} not found.")

-
      def _set_up_cell(self):
          """
          Set up the neuron cell based on the specified connection settings.
          """
-         self.cell = getattr(h, self.conn['spec_settings']['post_cell'])()
-
+         self.cell = getattr(h, self.conn["spec_settings"]["post_cell"])()

      def _set_up_synapse(self):
          """
          Set up the synapse on the target cell according to the synaptic parameters in `conn_type_settings`.
-
+
          Notes:
          ------
          - `_set_up_cell()` should be called before setting up the synapse.
          - Synapse location, type, and properties are specified within `spec_syn_param` and `spec_settings`.
          """
-         self.syn = getattr(h, self.conn['spec_settings']['level_of_detail'])(list(self.cell.all)[self.conn['spec_settings']['sec_id']](self.conn['spec_settings']['sec_x']))
-         for key, value in self.conn['spec_syn_param'].items():
+         self.syn = getattr(h, self.conn["spec_settings"]["level_of_detail"])(
+             list(self.cell.all)[self.conn["spec_settings"]["sec_id"]](
+                 self.conn["spec_settings"]["sec_x"]
+             )
+         )
+         for key, value in self.conn["spec_syn_param"].items():
              if isinstance(value, (int, float)):  # Only create sliders for numeric values
                  if hasattr(self.syn, key):
                      setattr(self.syn, key, value)
                  else:
-                     print(f"Warning: {key} cannot be assigned as it does not exist in the synapse. Check your mod file or spec_syn_param.")
-
+                     print(
+                         f"Warning: {key} cannot be assigned as it does not exist in the synapse. Check your mod file or spec_syn_param."
+                     )

      def _set_up_recorders(self):
          """
          Set up recording vectors to capture simulation data.
-
+
          The method sets up recorders for:
          - Synaptic current specified by `current_name`
          - Other specified synaptic variables (`other_vars_to_record`)
@@ -155,15 +183,17 @@ class SynapseTuner:
          self.rec_vectors = {}
          for var in self.other_vars_to_record:
              self.rec_vectors[var] = h.Vector()
-             ref_attr = f'_ref_{var}'
+             ref_attr = f"_ref_{var}"
              if hasattr(self.syn, ref_attr):
                  self.rec_vectors[var].record(getattr(self.syn, ref_attr))
              else:
-                 print(f"Warning: {ref_attr} not found in the syn object. Use vars() to inspect available attributes.")
+                 print(
+                     f"Warning: {ref_attr} not found in the syn object. Use vars() to inspect available attributes."
+                 )

          # Record synaptic current
          self.rec_vectors[self.current_name] = h.Vector()
-         ref_attr = f'_ref_{self.current_name}'
+         ref_attr = f"_ref_{self.current_name}"
          if hasattr(self.syn, ref_attr):
              self.rec_vectors[self.current_name].record(getattr(self.syn, ref_attr))
          else:
@@ -181,70 +211,75 @@ class SynapseTuner:
          self.soma_v.record(self.cell.soma[0](0.5)._ref_v)
          self.ivcl.record(self.vcl._ref_i)

-
-     def SingleEvent(self,plot_and_print=True):
+     def SingleEvent(self, plot_and_print=True):
          """
          Simulate a single synaptic event by delivering an input stimulus to the synapse.
-
-         The method sets up the neuron cell, synapse, stimulus, and voltage clamp,
+
+         The method sets up the neuron cell, synapse, stimulus, and voltage clamp,
          and then runs the NEURON simulation for a single event. The single synaptic event occurs at general_settings['tstart'].
          Displays graphs and synaptic properties; works best in a Jupyter notebook.
          """
          self.ispk = None
-
+
          # use slider values if the sliders are set up
-         if hasattr(self, 'dynamic_sliders'):
+         if hasattr(self, "dynamic_sliders"):
              syn_props = {var: slider.value for var, slider in self.dynamic_sliders.items()}
              self._set_syn_prop(**syn_props)
-
-         # sets values based on the optimizer
-         if hasattr(self,'using_optimizer'):
+
+         # sets values based on the optimizer
+         if hasattr(self, "using_optimizer"):
              for name, value in zip(self.param_names, self.params):
                  setattr(self.syn, name, value)

          # Set up the stimulus
-         self.nstim.start = self.general_settings['tstart']
+         self.nstim.start = self.general_settings["tstart"]
          self.nstim.noise = 0
          self.nstim2.start = h.tstop
          self.nstim2.noise = 0
-
+
          # Set up voltage clamp
-         vcldur = [[0, 0, 0], [self.general_settings['tstart'], h.tstop, 1e9]]
+         vcldur = [[0, 0, 0], [self.general_settings["tstart"], h.tstop, 1e9]]
          for i in range(3):
-             self.vcl.amp[i] = self.conn['spec_settings']['vclamp_amp']
+             self.vcl.amp[i] = self.conn["spec_settings"]["vclamp_amp"]
              self.vcl.dur[i] = vcldur[1][i]

          # Run simulation
-         h.tstop = self.general_settings['tstart'] + self.general_settings['tdur']
-         self.nstim.interval = self.general_settings['tdur']
+         h.tstop = self.general_settings["tstart"] + self.general_settings["tdur"]
+         self.nstim.interval = self.general_settings["tdur"]
          self.nstim.number = 1
          self.nstim2.start = h.tstop
          h.run()
-
+
          current = np.array(self.rec_vectors[self.current_name])
-         syn_props = self._get_syn_prop(rise_interval=self.general_settings['rise_interval'],dt=h.dt)
-         current = (current - syn_props['baseline']) * 1000  # Convert to pA
+         syn_props = self._get_syn_prop(
+             rise_interval=self.general_settings["rise_interval"], dt=h.dt
+         )
+         current = (current - syn_props["baseline"]) * 1000  # Convert to pA
          current_integral = np.trapz(current, dx=h.dt)  # pA·ms
-
+
          if plot_and_print:
-             self._plot_model([self.general_settings['tstart'] - 5, self.general_settings['tstart'] + self.general_settings['tdur']])
+             self._plot_model(
+                 [
+                     self.general_settings["tstart"] - 5,
+                     self.general_settings["tstart"] + self.general_settings["tdur"],
+                 ]
+             )
              for prop in syn_props.items():
                  print(prop)
-             print(f'Current Integral in pA*ms: {current_integral:.2f}')
-
-         self.rise_time = syn_props['rise_time']
-         self.decay_time = syn_props['decay_time']
+             print(f"Current Integral in pA*ms: {current_integral:.2f}")

+         self.rise_time = syn_props["rise_time"]
+         self.decay_time = syn_props["decay_time"]

      def _find_first(self, x):
          """
          Find the index of the first non-zero element in a given array.
-
+
          Parameters:
          -----------
          x : np.array
              The input array to search.
-
+
          Returns:
          --------
          int
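The baseline subtraction, pA conversion, and charge integral in `SingleEvent` are plain NumPy; a self-contained illustration on a synthetic trace (all numbers invented):

```python
import numpy as np

dt = 0.025                                   # ms, stands in for h.dt
t = np.arange(0, 50, dt)
baseline = -0.001                            # nA, holding current before the event
trace = baseline - 0.02 * (np.exp(-t / 5.0) - np.exp(-t / 0.5))  # synthetic EPSC

current = (trace - baseline) * 1000          # convert to pA, as above
integral = np.trapz(current, dx=dt)          # pA*ms, total charge transfer
print(f"Current Integral in pA*ms: {integral:.2f}")
```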
@@ -254,11 +289,10 @@ class SynapseTuner:
          idx = np.nonzero(x)[0]
          return idx[0] if idx.size else None

-
      def _get_syn_prop(self, rise_interval=(0.2, 0.8), dt=h.dt, short=False):
          """
          Calculate synaptic properties such as peak amplitude, latency, rise time, decay time, and half-width.
-
+
          Parameters:
          -----------
          rise_interval : tuple of floats, optional
@@ -267,11 +301,11 @@ class SynapseTuner:
              Time step of the simulation (default is NEURON's `h.dt`).
          short : bool, optional
              If True, only return baseline and sign without calculating full properties.
-
+
          Returns:
          --------
          dict
-             A dictionary containing the synaptic properties: baseline, sign, peak amplitude, latency, rise time,
+             A dictionary containing the synaptic properties: baseline, sign, peak amplitude, latency, rise time,
              decay time, and half-width.
          """
          if self.vclamp:
@@ -282,17 +316,17 @@ class SynapseTuner:
          tspk = np.asarray(self.tspk)
          if tspk.size:
              tspk = tspk[0]
-
+
          ispk = int(np.floor(tspk / dt))
          baseline = isyn[ispk]
          isyn = isyn[ispk:] - baseline
          # print(np.abs(isyn))
          # print(np.argmax(np.abs(isyn)))
          # print(isyn[np.argmax(np.abs(isyn))])
-         # print(np.sign(isyn[np.argmax(np.abs(isyn))]))
-         sign = np.sign(isyn[np.argmax(np.abs(isyn))])
+         # print(np.sign(isyn[np.argmax(np.abs(isyn))]))
+         sign = np.sign(isyn[np.argmax(np.abs(isyn))])
          if short:
-             return {'baseline': baseline, 'sign': sign}
+             return {"baseline": baseline, "sign": sign}
          isyn *= sign
          # print(isyn)
          # peak amplitude
@@ -300,29 +334,39 @@ class SynapseTuner:
          ipk = ipk[0]
          peak = isyn[ipk]
          # latency
-         istart = self._find_first(np.diff(isyn[:ipk + 1]) > 0)
+         istart = self._find_first(np.diff(isyn[: ipk + 1]) > 0)
          latency = dt * (istart + 1)
          # rise time
-         rt1 = self._find_first(isyn[istart:ipk + 1] > rise_interval[0] * peak)
-         rt2 = self._find_first(isyn[istart:ipk + 1] > rise_interval[1] * peak)
+         rt1 = self._find_first(isyn[istart : ipk + 1] > rise_interval[0] * peak)
+         rt2 = self._find_first(isyn[istart : ipk + 1] > rise_interval[1] * peak)
          rise_time = (rt2 - rt1) * dt
          # decay time
          iend = self._find_first(np.diff(isyn[ipk:]) > 0)
          iend = isyn.size - 1 if iend is None else iend + ipk
          decay_len = iend - ipk + 1
-         popt, _ = curve_fit(lambda t, a, tau: a * np.exp(-t / tau), dt * np.arange(decay_len),
-                             isyn[ipk:iend + 1], p0=(peak, dt * decay_len / 2))
+         popt, _ = curve_fit(
+             lambda t, a, tau: a * np.exp(-t / tau),
+             dt * np.arange(decay_len),
+             isyn[ipk : iend + 1],
+             p0=(peak, dt * decay_len / 2),
+         )
          decay_time = popt[1]
          # half-width
-         hw1 = self._find_first(isyn[istart:ipk + 1] > 0.5 * peak)
+         hw1 = self._find_first(isyn[istart : ipk + 1] > 0.5 * peak)
          hw2 = self._find_first(isyn[ipk:] < 0.5 * peak)
          hw2 = isyn.size if hw2 is None else hw2 + ipk
          half_width = dt * (hw2 - hw1)
-         output = {'baseline': baseline, 'sign': sign, 'latency': latency,
-                   'amp': peak, 'rise_time': rise_time, 'decay_time': decay_time, 'half_width': half_width}
+         output = {
+             "baseline": baseline,
+             "sign": sign,
+             "latency": latency,
+             "amp": peak,
+             "rise_time": rise_time,
+             "decay_time": decay_time,
+             "half_width": half_width,
+         }
          return output

-
      def _plot_model(self, xlim):
          """
          Plots the results of the simulation, including synaptic current, soma voltage,
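The decay-time estimate above is a single-exponential least-squares fit. The same `curve_fit` call can be exercised standalone on synthetic data (values invented) to see what `decay_time` means:

```python
import numpy as np
from scipy.optimize import curve_fit

dt, decay_len = 0.025, 400
tau_true, peak = 4.0, 35.0                      # synthetic ground truth
t = dt * np.arange(decay_len)
segment = peak * np.exp(-t / tau_true) + 0.1 * np.random.randn(decay_len)

# Same model and initial-guess convention as _get_syn_prop
popt, _ = curve_fit(
    lambda t, a, tau: a * np.exp(-t / tau),
    t,
    segment,
    p0=(segment[0], dt * decay_len / 2),
)
print(f"fitted decay_time = {popt[1]:.2f} ms")  # approximately tau_true
```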
@@ -332,7 +376,7 @@ class SynapseTuner:
          -----------
          xlim : tuple
              A tuple specifying the limits of the x-axis for the plot (start_time, end_time).
-
+
          Notes:
          ------
          - The function determines how many plots to generate based on the number of variables recorded.
@@ -342,25 +386,25 @@ class SynapseTuner:
          """
          # Determine the number of plots to generate (at least 2: current and voltage)
          num_vars_to_plot = 2 + (len(self.other_vars_to_record) if self.other_vars_to_record else 0)
-
+
          # Set up figure based on number of plots (2x2 grid max)
          num_rows = (num_vars_to_plot + 1) // 2  # This ensures we have enough rows
          fig, axs = plt.subplots(num_rows, 2, figsize=(12, 7))
          axs = axs.ravel()
-
+
          # Plot synaptic current (always included)
          current = self.rec_vectors[self.current_name]
-         syn_prop = self._get_syn_prop(short=True,dt=h.dt)
-         current = (current - syn_prop['baseline'])
+         syn_prop = self._get_syn_prop(short=True, dt=h.dt)
+         current = current - syn_prop["baseline"]
          current = current * 1000
-
+
          axs[0].plot(self.t, current)
-         if self.ispk !=None:
+         if self.ispk is not None:
              for num in range(len(self.ispk)):
-                 axs[0].text(self.t[self.ispk[num]],current[self.ispk[num]],f"{str(num+1)}")
+                 axs[0].text(self.t[self.ispk[num]], current[self.ispk[num]], f"{str(num+1)}")
+
+         axs[0].set_ylabel("Synaptic Current (pA)")

-         axs[0].set_ylabel('Synaptic Current (pA)')
-
          # Plot voltage clamp or soma voltage (always included)
          ispk = int(np.round(self.tspk[0] / h.dt))
          if self.vclamp:
@@ -368,37 +412,36 @@ class SynapseTuner:
              ivcl_plt = np.array(self.ivcl) - baseline
              ivcl_plt[:ispk] = 0
              axs[1].plot(self.t, 1000 * ivcl_plt)
-             axs[1].set_ylabel('VClamp Current (pA)')
+             axs[1].set_ylabel("VClamp Current (pA)")
          else:
              soma_v_plt = np.array(self.soma_v)
              soma_v_plt[:ispk] = soma_v_plt[ispk]

              axs[1].plot(self.t, soma_v_plt)
-             axs[1].set_ylabel('Soma Voltage (mV)')
-
+             axs[1].set_ylabel("Soma Voltage (mV)")
+
          # Plot any other variables from other_vars_to_record, if provided
          if self.other_vars_to_record:
              for i, var in enumerate(self.other_vars_to_record, start=2):
                  if var in self.rec_vectors:
                      axs[i].plot(self.t, self.rec_vectors[var])
-                     axs[i].set_ylabel(f'{var.capitalize()}')
+                     axs[i].set_ylabel(f"{var.capitalize()}")

          # Adjust the layout
          for i, ax in enumerate(axs[:num_vars_to_plot]):
              ax.set_xlim(*xlim)
              if i >= num_vars_to_plot - 2:  # Add x-label to the last row
-                 ax.set_xlabel('Time (ms)')
-
+                 ax.set_xlabel("Time (ms)")
+
          # Remove extra subplots if less than 4 plots are present
          if num_vars_to_plot < len(axs):
              for j in range(num_vars_to_plot, len(axs)):
                  fig.delaxes(axs[j])

-         #plt.tight_layout()
+         # plt.tight_layout()
          plt.show()

-
-     def _set_drive_train(self,freq=50., delay=250.):
+     def _set_drive_train(self, freq=50.0, delay=250.0):
          """
          Configures trains of 12 action potentials at a specified frequency and delay period
          between pulses 8 and 9.
@@ -414,7 +457,7 @@ class SynapseTuner:
          --------
          tstop : float
              The time at which the last pulse stops.
-
+
          Notes:
          ------
          - This function is based on experiments from the Allen Database.
@@ -422,10 +465,10 @@ class SynapseTuner:
          # lets also set the train drive and delay here
          self.train_freq = freq
          self.train_delay = delay
-
+
          n_init_pulse = 8
          n_ending_pulse = 4
-         self.nstim.start = self.general_settings['tstart']
+         self.nstim.start = self.general_settings["tstart"]
          self.nstim.interval = 1000 / freq
          self.nstim2.interval = 1000 / freq
          self.nstim.number = n_init_pulse
@@ -433,7 +476,6 @@ class SynapseTuner:
          self.nstim2.start = self.nstim.start + (n_init_pulse - 1) * self.nstim.interval + delay
          tstop = self.nstim2.start + n_ending_pulse * self.nstim2.interval
          return tstop
-

      def _response_amplitude(self):
          """
@@ -443,7 +485,7 @@ class SynapseTuner:
          --------
          amp : list
              A list containing the peak amplitudes for each pulse in the recorded synaptic current.
-
+
          Notes:
          ------
          This method:
@@ -451,28 +493,31 @@ class SynapseTuner:
          2. Identifies spike times and segments the current accordingly
          3. Calculates the peak response amplitude for each segment
          4. Records the indices of peak amplitudes for visualization
-
+
          The amplitude values are returned in the original current units (before pA conversion).
          """
          isyn = np.array(self.rec_vectors[self.current_name].to_python())
          tspk = np.append(np.asarray(self.tspk), h.tstop)
-         syn_prop = self._get_syn_prop(short=True,dt=h.dt)
+         syn_prop = self._get_syn_prop(short=True, dt=h.dt)
          # print("syn_prp[sign] = " + str(syn_prop['sign']))
-         isyn = (isyn - syn_prop['baseline'])
-         isyn *= syn_prop['sign']
-         ispk = np.floor((tspk + self.general_settings['delay']) / h.dt).astype(int)
-
-         try:
-             amp = [isyn[ispk[i]:ispk[i + 1]].max() for i in range(ispk.size - 1)]
-             # indices of where the max of the synaptic current is. This is then plotted
-             self.ispk = [np.argmax(isyn[ispk[i]:ispk[i + 1]]) + ispk[i] for i in range(ispk.size - 1)]
+         isyn = isyn - syn_prop["baseline"]
+         isyn *= syn_prop["sign"]
+         ispk = np.floor((tspk + self.general_settings["delay"]) / h.dt).astype(int)
+
+         try:
+             amp = [isyn[ispk[i] : ispk[i + 1]].max() for i in range(ispk.size - 1)]
+             # indices of where the max of the synaptic current is. This is then plotted
+             self.ispk = [
+                 np.argmax(isyn[ispk[i] : ispk[i + 1]]) + ispk[i] for i in range(ispk.size - 1)
+             ]
          # Sometimes the sim can cut off at the peak of synaptic activity, so we reduce the range by 1 and ignore that point
          except:
-             amp = [isyn[ispk[i]:ispk[i + 1]].max() for i in range(ispk.size - 2)]
-             self.ispk = [np.argmax(isyn[ispk[i]:ispk[i + 1]]) + ispk[i] for i in range(ispk.size - 2)]
-
-         return amp
+             amp = [isyn[ispk[i] : ispk[i + 1]].max() for i in range(ispk.size - 2)]
+             self.ispk = [
+                 np.argmax(isyn[ispk[i] : ispk[i + 1]]) + ispk[i] for i in range(ispk.size - 2)
+             ]

+         return amp

      def _find_max_amp(self, amp):
          """
@@ -482,7 +527,7 @@ class SynapseTuner:
          -----------
          amp : array-like
              Array containing the amplitudes of synaptic responses.
-
+
          Returns:
          --------
          max_amp : float
@@ -490,10 +535,9 @@ class SynapseTuner:
          """
          max_amp = max(amp)
          min_amp = min(amp)
-         if(abs(min_amp) > max_amp):
-             return min_amp * 1000  # scale unit
-         return max_amp * 1000  # scale unit
-
+         if abs(min_amp) > max_amp:
+             return min_amp * 1000  # scale unit
+         return max_amp * 1000  # scale unit

      def _calc_ppr_induction_recovery(self, amp, normalize_by_trial=True, print_math=True):
          """
@@ -515,7 +559,7 @@ class SynapseTuner:
          - ppr: Paired-pulse ratio (2nd pulse / 1st pulse)
          - induction: Measure of facilitation/depression during initial pulses
          - recovery: Measure of recovery after the delay period
-
+
          Notes:
          ------
          - PPR > 1 indicates facilitation, PPR < 1 indicates depression
@@ -523,18 +567,20 @@ class SynapseTuner:
          - Recovery compares the response after delay to the initial pulses
          """
          amp = np.array(amp)
-         amp = (amp * 1000)  # scale up
+         amp = amp * 1000  # scale up
          amp = amp.reshape(-1, amp.shape[-1])
          maxamp = amp.max(axis=1 if normalize_by_trial else None)

          def format_array(arr):
              """Format an array to 2 significant figures for cleaner output."""
-             return np.array2string(arr, precision=2, separator=', ', suppress_small=True)
-
+             return np.array2string(arr, precision=2, separator=", ", suppress_small=True)
+
          if print_math:
-             print("\n" + "="*40)
-             print(f"Short Term Plasticity Results for {self.train_freq}Hz with {self.train_delay} Delay")
-             print("="*40)
+             print("\n" + "=" * 40)
+             print(
+                 f"Short Term Plasticity Results for {self.train_freq}Hz with {self.train_delay} Delay"
+             )
+             print("=" * 40)
              print("PPR: Above 1 is facilitating, below 1 is depressing.")
              print("Induction: Above 0 is facilitating, below 0 is depressing.")
              print("Recovery: A measure of how fast STP decays.\n")
@@ -543,37 +589,48 @@ class SynapseTuner:
              ppr = amp[:, 1:2] / amp[:, 0:1]
              print("Paired Pulse Response (PPR)")
              print("Calculation: 2nd pulse / 1st pulse")
-             print(f"Values: ({format_array(amp[:, 1:2])}) / ({format_array(amp[:, 0:1])}) = {format_array(ppr)}\n")
+             print(
+                 f"Values: ({format_array(amp[:, 1:2])}) / ({format_array(amp[:, 0:1])}) = {format_array(ppr)}\n"
+             )

              # Induction Calculation
              induction = np.mean((amp[:, 5:8].mean(axis=1) - amp[:, :1].mean(axis=1)) / maxamp)
              print("Induction")
              print("Calculation: (avg(6th, 7th, 8th pulses) - 1st pulse) / max amps")
-             print(f"Values: avg({format_array(amp[:, 5:8])}) - {format_array(amp[:, :1])} / {format_array(maxamp)}")
-             print(f"({format_array(amp[:, 5:8].mean(axis=1))}) - ({format_array(amp[:, :1].mean(axis=1))}) / {format_array(maxamp)} = {induction:.3f}\n")
+             print(
+                 f"Values: avg({format_array(amp[:, 5:8])}) - {format_array(amp[:, :1])} / {format_array(maxamp)}"
+             )
+             print(
+                 f"({format_array(amp[:, 5:8].mean(axis=1))}) - ({format_array(amp[:, :1].mean(axis=1))}) / {format_array(maxamp)} = {induction:.3f}\n"
+             )

              # Recovery Calculation
              recovery = np.mean((amp[:, 8:12].mean(axis=1) - amp[:, :4].mean(axis=1)) / maxamp)
              print("Recovery")
-             print("Calculation: (avg(9th, 10th, 11th, 12th pulses) - avg(1st to 4th pulses)) / max amps")
-             print(f"Values: avg({format_array(amp[:, 8:12])}) - avg({format_array(amp[:, :4])}) / {format_array(maxamp)}")
-             print(f"({format_array(amp[:, 8:12].mean(axis=1))}) - ({format_array(amp[:, :4].mean(axis=1))}) / {format_array(maxamp)} = {recovery:.3f}\n")
+             print(
+                 "Calculation: (avg(9th, 10th, 11th, 12th pulses) - avg(1st to 4th pulses)) / max amps"
+             )
+             print(
+                 f"Values: avg({format_array(amp[:, 8:12])}) - avg({format_array(amp[:, :4])}) / {format_array(maxamp)}"
+             )
+             print(
+                 f"({format_array(amp[:, 8:12].mean(axis=1))}) - ({format_array(amp[:, :4].mean(axis=1))}) / {format_array(maxamp)} = {recovery:.3f}\n"
+             )

-             print("="*40 + "\n")
+             print("=" * 40 + "\n")

          recovery = np.mean((amp[:, 8:12].mean(axis=1) - amp[:, :4].mean(axis=1)) / maxamp)
          induction = np.mean((amp[:, 5:8].mean(axis=1) - amp[:, :1].mean(axis=1)) / maxamp)
          ppr = amp[:, 1:2] / amp[:, 0:1]
          # maxamp = max(amp, key=lambda x: abs(x[0]))
          maxamp = maxamp.max()
-
-         return ppr, induction, recovery

+         return ppr, induction, recovery

      def _set_syn_prop(self, **kwargs):
          """
          Sets the synaptic parameters based on user inputs from sliders.
-
+
          Parameters:
          -----------
          **kwargs : dict
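The PPR/induction/recovery definitions above reduce to a few lines of array arithmetic over the 12-pulse amplitude train; here they are applied to an invented facilitating train:

```python
import numpy as np

# Hypothetical 12-pulse amplitude train (one trial), same layout the tuner produces
amp = np.array([[1.0, 1.3, 1.5, 1.6, 1.7, 1.75, 1.8, 1.8, 1.2, 1.4, 1.5, 1.6]])
maxamp = amp.max(axis=1)

ppr = amp[:, 1:2] / amp[:, 0:1]                                       # 2nd / 1st pulse
induction = np.mean((amp[:, 5:8].mean(axis=1) - amp[:, :1].mean(axis=1)) / maxamp)
recovery = np.mean((amp[:, 8:12].mean(axis=1) - amp[:, :4].mean(axis=1)) / maxamp)

print(ppr, induction, recovery)  # PPR > 1 and induction > 0: facilitating
```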
@@ -582,8 +639,7 @@ class SynapseTuner:
          for key, value in kwargs.items():
              setattr(self.syn, key, value)

-
-     def _simulate_model(self,input_frequency, delay, vclamp=None):
+     def _simulate_model(self, input_frequency, delay, vclamp=None):
          """
          Runs the simulation with the specified input frequency, delay, and voltage clamp settings.

@@ -595,53 +651,52 @@ class SynapseTuner:
              Delay period in milliseconds between the 8th and 9th pulses.
          vclamp : bool or None, optional
              Whether to use voltage clamp. If None, the current setting is used. Default is None.
-
+
          Notes:
          ------
          This method handles two different input modes:
          - Standard train mode with 8 initial pulses followed by a delay and 4 additional pulses
          - Continuous input mode where stimulation continues for a specified duration
          """
-         if self.input_mode == False:
+         if not self.input_mode:
              self.tstop = self._set_drive_train(input_frequency, delay)
              h.tstop = self.tstop

-             vcldur = [[0, 0, 0], [self.general_settings['tstart'], self.tstop, 1e9]]
+             vcldur = [[0, 0, 0], [self.general_settings["tstart"], self.tstop, 1e9]]
              for i in range(3):
-                 self.vcl.amp[i] = self.conn['spec_settings']['vclamp_amp']
+                 self.vcl.amp[i] = self.conn["spec_settings"]["vclamp_amp"]
                  self.vcl.dur[i] = vcldur[1][i]
              h.finitialize(self.cell.Vinit * mV)
              h.continuerun(self.tstop * ms)
          else:
-             self.tstop = self.general_settings['tstart'] + self.general_settings['tdur']
+             self.tstop = self.general_settings["tstart"] + self.general_settings["tdur"]
              self.nstim.interval = 1000 / input_frequency
              self.nstim.number = np.ceil(self.w_duration.value / 1000 * input_frequency + 1)
              self.nstim2.number = 0
-             self.tstop = self.w_duration.value + self.general_settings['tstart']
-
+             self.tstop = self.w_duration.value + self.general_settings["tstart"]
+
              h.finitialize(self.cell.Vinit * mV)
              h.continuerun(self.tstop * ms)
-
-
+
      def InteractiveTuner(self):
          """
          Sets up interactive sliders for tuning short-term plasticity (STP) parameters in a Jupyter Notebook.
-
+
          This method creates an interactive UI with sliders for:
          - Input frequency
          - Delay between pulse trains
          - Duration of stimulation (for continuous input mode)
          - Synaptic parameters (e.g., Use, tau_f, tau_d) based on the syn model
-
+
          It also provides buttons for:
          - Running a single event simulation
          - Running a train input simulation
          - Toggling voltage clamp mode
          - Switching between standard and continuous input modes
-
+
          Notes:
          ------
-         Ideal for exploratory parameter tuning and interactive visualization of
+         Ideal for exploratory parameter tuning and interactive visualization of
          synapse behavior with different parameter values and stimulation protocols.
          """
          # Widgets setup (Sliders)
@@ -653,28 +708,44 @@ class SynapseTuner:
          duration0 = 300
          vlamp_status = self.vclamp

-         w_run = widgets.Button(description='Run Train', icon='history', button_style='primary')
-         w_single = widgets.Button(description='Single Event', icon='check', button_style='success')
-         w_vclamp = widgets.ToggleButton(value=vlamp_status, description='Voltage Clamp', icon='fast-backward', button_style='warning')
-         w_input_mode = widgets.ToggleButton(value=False, description='Continuous input', icon='eject', button_style='info')
-         w_input_freq = widgets.SelectionSlider(options=freqs, value=freq0, description='Input Freq')
-
+         w_run = widgets.Button(description="Run Train", icon="history", button_style="primary")
+         w_single = widgets.Button(description="Single Event", icon="check", button_style="success")
+         w_vclamp = widgets.ToggleButton(
+             value=vlamp_status,
+             description="Voltage Clamp",
+             icon="fast-backward",
+             button_style="warning",
+         )
+         w_input_mode = widgets.ToggleButton(
+             value=False, description="Continuous input", icon="eject", button_style="info"
+         )
+         w_input_freq = widgets.SelectionSlider(options=freqs, value=freq0, description="Input Freq")

          # Sliders for delay and duration
-         self.w_delay = widgets.SelectionSlider(options=delays, value=delay0, description='Delay')
-         self.w_duration = widgets.SelectionSlider(options=durations, value=duration0, description='Duration')
+         self.w_delay = widgets.SelectionSlider(options=delays, value=delay0, description="Delay")
+         self.w_duration = widgets.SelectionSlider(
+             options=durations, value=duration0, description="Duration"
+         )

          # Generate sliders dynamically based on valid numeric entries in self.slider_vars
          self.dynamic_sliders = {}
-         print("Setting up slider! The sliders ranges are set by their init value so try changing that if you dont like the slider range!")
+         print(
+             "Setting up slider! The sliders ranges are set by their init value so try changing that if you dont like the slider range!"
+         )
          for key, value in self.slider_vars.items():
              if isinstance(value, (int, float)):  # Only create sliders for numeric values
                  if hasattr(self.syn, key):
                      if value == 0:
-                         print(f'{key} was set to zero, going to try to set a range of values, try settings the {key} to a nonzero value if you dont like the range!')
-                         slider = widgets.FloatSlider(value=value, min=0, max=1000, step=1, description=key)
+                         print(
+                             f"{key} was set to zero, going to try to set a range of values, try settings the {key} to a nonzero value if you dont like the range!"
+                         )
+                         slider = widgets.FloatSlider(
+                             value=value, min=0, max=1000, step=1, description=key
+                         )
                      else:
-                         slider = widgets.FloatSlider(value=value, min=0, max=value*20, step=value/5, description=key)
+                         slider = widgets.FloatSlider(
+                             value=value, min=0, max=value * 20, step=value / 5, description=key
+                         )
                      self.dynamic_sliders[key] = slider
                  else:
                      print(f"skipping slider for {key} due to not being a synaptic variable")
@@ -684,7 +755,7 @@ class SynapseTuner:
              display(ui)
              self.vclamp = w_vclamp.value
              # Update synaptic properties based on slider values
-             self.ispk=None
+             self.ispk = None
              self.SingleEvent()

          # Function to update UI based on input mode
@@ -695,28 +766,30 @@ class SynapseTuner:
              self.input_mode = w_input_mode.value
              syn_props = {var: slider.value for var, slider in self.dynamic_sliders.items()}
              self._set_syn_prop(**syn_props)
-             if self.input_mode == False:
+             if not self.input_mode:
                  self._simulate_model(w_input_freq.value, self.w_delay.value, w_vclamp.value)
              else:
                  self._simulate_model(w_input_freq.value, self.w_duration.value, w_vclamp.value)
              amp = self._response_amplitude()
-             self._plot_model([self.general_settings['tstart'] - self.nstim.interval / 3, self.tstop])
+             self._plot_model(
+                 [self.general_settings["tstart"] - self.nstim.interval / 3, self.tstop]
+             )
              _ = self._calc_ppr_induction_recovery(amp)

          # Function to switch between delay and duration sliders
          def switch_slider(*args):
              if w_input_mode.value:
-                 self.w_delay.layout.display = 'none'  # Hide delay slider
-                 self.w_duration.layout.display = ''  # Show duration slider
+                 self.w_delay.layout.display = "none"  # Hide delay slider
+                 self.w_duration.layout.display = ""  # Show duration slider
              else:
-                 self.w_delay.layout.display = ''  # Show delay slider
-                 self.w_duration.layout.display = 'none'  # Hide duration slider
+                 self.w_delay.layout.display = ""  # Show delay slider
+                 self.w_duration.layout.display = "none"  # Hide duration slider

          # Link input mode to slider switch
-         w_input_mode.observe(switch_slider, names='value')
+         w_input_mode.observe(switch_slider, names="value")

          # Hide the duration slider initially until the user selects it
-         self.w_duration.layout.display = 'none'  # Hide duration slider
+         self.w_duration.layout.display = "none"  # Hide duration slider

          w_single.on_click(run_single_event)
          w_run.on_click(update_ui)
@@ -726,7 +799,7 @@ class SynapseTuner:
          button_row = HBox([w_run, w_single, w_vclamp, w_input_mode])
          slider_row = HBox([w_input_freq, self.w_delay, self.w_duration])
-
+
          half = len(slider_widgets) // 2
          col1 = VBox(slider_widgets[:half])
          col2 = VBox(slider_widgets[half:])
@@ -736,17 +809,21 @@ class SynapseTuner:
          display(ui)
          update_ui()
-
-
-     def stp_frequency_response(self, freqs=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 35, 50, 100, 200],
-                                delay=250, plot=True,log_plot=True):
+
+     def stp_frequency_response(
+         self,
+         freqs=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 35, 50, 100, 200],
+         delay=250,
+         plot=True,
+         log_plot=True,
+     ):
          """
          Analyze synaptic response across different stimulation frequencies.
-
-         This method systematically tests how the synapse model responds to different
+
+         This method systematically tests how the synapse model responds to different
          stimulation frequencies, calculating key short-term plasticity (STP) metrics
          for each frequency.
-
+
          Parameters:
          -----------
          freqs : list, optional
@@ -757,7 +834,7 @@ class SynapseTuner:
              Whether to plot the results. Default is True.
          log_plot : bool, optional
              Whether to use logarithmic scale for frequency axis. Default is True.
-
+
          Returns:
          --------
          dict
@@ -766,45 +843,39 @@ class SynapseTuner:
              - 'ppr': Paired-pulse ratios at each frequency
              - 'induction': Induction values at each frequency
              - 'recovery': Recovery values at each frequency
-
+
          Notes:
          ------
-         This method is particularly useful for characterizing the frequency-dependent
+         This method is particularly useful for characterizing the frequency-dependent
          behavior of synapses, such as identifying facilitating vs. depressing regimes
          or the frequency at which a synapse transitions between these behaviors.
          """
-         results = {
-             'frequencies': freqs,
-             'ppr': [],
-             'induction': [],
-             'recovery': []
-         }
-
+         results = {"frequencies": freqs, "ppr": [], "induction": [], "recovery": []}
+
          # Store original state
          original_ispk = self.ispk
-
+
          for freq in tqdm(freqs, desc="Analyzing frequencies"):
              self._simulate_model(freq, delay)
              amp = self._response_amplitude()
              ppr, induction, recovery = self._calc_ppr_induction_recovery(amp, print_math=False)
-
-             results['ppr'].append(float(ppr))
-             results['induction'].append(float(induction))
-             results['recovery'].append(float(recovery))
-
+
+             results["ppr"].append(float(ppr))
+             results["induction"].append(float(induction))
+             results["recovery"].append(float(recovery))
+
          # Restore original state
          self.ispk = original_ispk
-
+
          if plot:
-             self._plot_frequency_analysis(results,log_plot=log_plot)
-
-         return results
+             self._plot_frequency_analysis(results, log_plot=log_plot)

+         return results

      def _plot_frequency_analysis(self, results, log_plot):
          """
          Plot the frequency-dependent synaptic properties.
-
+
          Parameters:
          -----------
          results : dict
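Assuming a `tuner` constructed as sketched earlier, a frequency sweep might be driven like this (the frequency list and delay are arbitrary):

```python
# Hypothetical usage of the sweep above
results = tuner.stp_frequency_response(freqs=[5, 10, 20, 50], delay=250, plot=False)
for f, ind in zip(results["frequencies"], results["induction"]):
    print(f"{f} Hz -> induction {ind:.3f}")
```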
@@ -815,61 +886,67 @@ class SynapseTuner:
              - 'recovery': Recovery values at each frequency
          log_plot : bool
              Whether to use logarithmic scale for frequency axis
-
+
          Notes:
          ------
          Creates a figure with three subplots showing:
          1. Paired-pulse ratio vs. frequency
          2. Induction vs. frequency
          3. Recovery vs. frequency
-
+
          Each plot includes a horizontal reference line at y=0 or y=1 to indicate
          the boundary between facilitation and depression.
          """
          fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
-
-
+
          # Plot PPR
          if log_plot:
-             ax1.semilogx(results['frequencies'], results['ppr'], 'o-')
+             ax1.semilogx(results["frequencies"], results["ppr"], "o-")
          else:
-             ax1.plot(results['frequencies'], results['ppr'], 'o-')
-         ax1.axhline(y=1, color='gray', linestyle='--', alpha=0.5)
-         ax1.set_xlabel('Frequency (Hz)')
-         ax1.set_ylabel('Paired Pulse Ratio')
-         ax1.set_title('PPR vs Frequency')
+             ax1.plot(results["frequencies"], results["ppr"], "o-")
+         ax1.axhline(y=1, color="gray", linestyle="--", alpha=0.5)
+         ax1.set_xlabel("Frequency (Hz)")
+         ax1.set_ylabel("Paired Pulse Ratio")
+         ax1.set_title("PPR vs Frequency")
          ax1.grid(True)
-
+
          # Plot Induction
          if log_plot:
-             ax2.semilogx(results['frequencies'], results['induction'], 'o-')
+             ax2.semilogx(results["frequencies"], results["induction"], "o-")
          else:
-             ax2.plot(results['frequencies'], results['induction'], 'o-')
-         ax2.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
-         ax2.set_xlabel('Frequency (Hz)')
-         ax2.set_ylabel('Induction')
-         ax2.set_title('Induction vs Frequency')
+             ax2.plot(results["frequencies"], results["induction"], "o-")
+         ax2.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
+         ax2.set_xlabel("Frequency (Hz)")
+         ax2.set_ylabel("Induction")
+         ax2.set_title("Induction vs Frequency")
          ax2.grid(True)
-
+
          # Plot Recovery
          if log_plot:
-             ax3.semilogx(results['frequencies'], results['recovery'], 'o-')
+             ax3.semilogx(results["frequencies"], results["recovery"], "o-")
          else:
-             ax3.plot(results['frequencies'], results['recovery'], 'o-')
-         ax3.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
-         ax3.set_xlabel('Frequency (Hz)')
-         ax3.set_ylabel('Recovery')
-         ax3.set_title('Recovery vs Frequency')
+             ax3.plot(results["frequencies"], results["recovery"], "o-")
+         ax3.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
+         ax3.set_xlabel("Frequency (Hz)")
+         ax3.set_ylabel("Recovery")
+         ax3.set_title("Recovery vs Frequency")
          ax3.grid(True)
-
+
          plt.tight_layout()
          plt.show()

+
  class GapJunctionTuner:
-     def __init__(self, mechanisms_dir: str, templates_dir: str, general_settings: dict, conn_type_settings: dict):
+     def __init__(
+         self,
+         mechanisms_dir: str,
+         templates_dir: str,
+         general_settings: dict,
+         conn_type_settings: dict,
+     ):
          """
          Initialize the GapJunctionTuner class.
-
+
          Parameters:
          -----------
          mechanisms_dir : str
@@ -883,50 +960,50 @@ class GapJunctionTuner:
          """
          neuron.load_mechanisms(mechanisms_dir)
          h.load_file(templates_dir)
-
+
          self.general_settings = general_settings
          self.conn_type_settings = conn_type_settings
-
-         h.tstop = general_settings['tstart'] + general_settings['tdur'] + 100.
-         h.dt = general_settings['dt'] # Time step (resolution) of the simulation in ms
+
+         h.tstop = general_settings["tstart"] + general_settings["tdur"] + 100.0
+         h.dt = general_settings["dt"]  # Time step (resolution) of the simulation in ms
          h.steps_per_ms = 1 / h.dt
-         h.celsius = general_settings['celsius']
+         h.celsius = general_settings["celsius"]
+
+         self.cell_name = conn_type_settings["cell"]

-         self.cell_name = conn_type_settings['cell']
-
          # set up gap junctions
          pc = h.ParallelContext()

          self.cell1 = getattr(h, self.cell_name)()
          self.cell2 = getattr(h, self.cell_name)()
-
+
          self.icl = h.IClamp(self.cell1.soma[0](0.5))
-         self.icl.delay = self.general_settings['tstart']
-         self.icl.dur = self.general_settings['tdur']
-         self.icl.amp = self.conn_type_settings['iclamp_amp']  # nA
-
-         sec1 = list(self.cell1.all)[conn_type_settings['sec_id']]
-         sec2 = list(self.cell2.all)[conn_type_settings['sec_id']]
-
-         pc.source_var(sec1(conn_type_settings['sec_x'])._ref_v, 0, sec=sec1)
+         self.icl.delay = self.general_settings["tstart"]
+         self.icl.dur = self.general_settings["tdur"]
+         self.icl.amp = self.conn_type_settings["iclamp_amp"]  # nA
+
+         sec1 = list(self.cell1.all)[conn_type_settings["sec_id"]]
+         sec2 = list(self.cell2.all)[conn_type_settings["sec_id"]]
+
+         pc.source_var(sec1(conn_type_settings["sec_x"])._ref_v, 0, sec=sec1)
          self.gap_junc_1 = h.Gap(sec1(0.5))
-         pc.target_var(self.gap_junc_1 ._ref_vgap, 1)
+         pc.target_var(self.gap_junc_1._ref_vgap, 1)

-         pc.source_var(sec2(conn_type_settings['sec_x'])._ref_v, 1, sec=sec2)
+         pc.source_var(sec2(conn_type_settings["sec_x"])._ref_v, 1, sec=sec2)
          self.gap_junc_2 = h.Gap(sec2(0.5))
          pc.target_var(self.gap_junc_2._ref_vgap, 0)

          pc.setup_transfer()
-
-     def model(self,resistance):
+
+     def model(self, resistance):
          """
          Run a simulation with a specified gap junction resistance.
-
+
          Parameters:
          -----------
          resistance : float
              The gap junction resistance value (in MOhm) to use for the simulation.
-
+
          Notes:
          ------
          This method sets up the gap junction resistance, initializes recording vectors for time
@@ -934,49 +1011,50 @@ class GapJunctionTuner:
          """
          self.gap_junc_1.g = resistance
          self.gap_junc_2.g = resistance
-
+
          t_vec = h.Vector()
          soma_v_1 = h.Vector()
          soma_v_2 = h.Vector()
          t_vec.record(h._ref_t)
          soma_v_1.record(self.cell1.soma[0](0.5)._ref_v)
          soma_v_2.record(self.cell2.soma[0](0.5)._ref_v)
-
+
          self.t_vec = t_vec
          self.soma_v_1 = soma_v_1
          self.soma_v_2 = soma_v_2
-
+
          h.finitialize(-70 * mV)
          h.continuerun(h.tstop * ms)
-
-
+
      def plot_model(self):
          """
          Plot the voltage traces of both cells to visualize gap junction coupling.
-
+
          This method creates a plot showing the membrane potential of both cells over time,
          highlighting the effect of gap junction coupling when a current step is applied to cell 1.
          """
-         t_range = [self.general_settings['tstart'] - 100., self.general_settings['tstart']+self.general_settings['tdur'] + 100.]
+         t_range = [
+             self.general_settings["tstart"] - 100.0,
+             self.general_settings["tstart"] + self.general_settings["tdur"] + 100.0,
+         ]
          t = np.array(self.t_vec)
          v1 = np.array(self.soma_v_1)
          v2 = np.array(self.soma_v_2)
          tidx = (t >= t_range[0]) & (t <= t_range[1])

          plt.figure()
-         plt.plot(t[tidx], v1[tidx], 'b', label=f'{self.cell_name} 1')
-         plt.plot(t[tidx], v2[tidx], 'r', label=f'{self.cell_name} 2')
+         plt.plot(t[tidx], v1[tidx], "b", label=f"{self.cell_name} 1")
+         plt.plot(t[tidx], v2[tidx], "r", label=f"{self.cell_name} 2")
          plt.title(f"{self.cell_name} gap junction")
-         plt.xlabel('Time (ms)')
-         plt.ylabel('Membrane Voltage (mV)')
+         plt.xlabel("Time (ms)")
+         plt.ylabel("Membrane Voltage (mV)")
          plt.legend()
-         plt.show()
-
+         plt.show()

-     def coupling_coefficient(self,t, v1, v2, t_start, t_end, dt=h.dt):
+     def coupling_coefficient(self, t, v1, v2, t_start, t_end, dt=h.dt):
          """
          Calculate the coupling coefficient between two cells connected by a gap junction.
-
+
          Parameters:
          -----------
          t : array-like
@@ -991,11 +1069,11 @@ class GapJunctionTuner:
              End time for calculating the steady-state voltage change.
          dt : float, optional
              Time step of the simulation. Default is h.dt.
-
+
          Returns:
          --------
          float
-             The coupling coefficient, defined as the ratio of voltage change in cell 2
+             The coupling coefficient, defined as the ratio of voltage change in cell 2
              to voltage change in cell 1 (ΔV₂/ΔV₁).
          """
          t = np.asarray(t)
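On plain arrays, the coupling-coefficient calculation is just the ratio of the two steady-state deflections; a synthetic check (step times and deflections invented):

```python
import numpy as np

dt = 0.025
t = np.arange(0.0, 1200.0, dt)
# Synthetic traces: a current step deflects cell 1 by 10 mV and, via the
# gap junction, cell 2 by 2 mV between t = 500 and t = 1000 ms.
v1 = -70.0 + 10.0 * ((t > 500) & (t < 1000))
v2 = -70.0 + 2.0 * ((t > 500) & (t < 1000))

idx1 = np.nonzero(t < 500)[0][-1]     # last sample before the step
idx2 = np.nonzero(t < 1000)[0][-1]    # last sample before t_end
cc = (v2[idx2] - v2[idx1]) / (v1[idx2] - v1[idx1])
print(f"coupling_coefficient is {cc:0.4f}")  # 0.2000
```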
@@ -1005,21 +1083,21 @@ class GapJunctionTuner:
          idx2 = np.nonzero(t < t_end)[0][-1]
          return (v2[idx2] - v2[idx1]) / (v1[idx2] - v1[idx1])

-
      def InteractiveTuner(self):
-         w_run = widgets.Button(description='Run', icon='history', button_style='primary')
+         w_run = widgets.Button(description="Run", icon="history", button_style="primary")
          values = [i * 10**-4 for i in range(1, 101)]  # From 1e-4 to 1e-2

          # Create the SelectionSlider widget with appropriate formatting
          resistance = widgets.SelectionSlider(
-             options=[("%g"%i,i) for i in values],  # Use scientific notation for display
+             options=[("%g" % i, i) for i in values],  # Use scientific notation for display
              value=10**-3,  # Default value
-             description='Resistance: ',
-             continuous_update=True
-         )
+             description="Resistance: ",
+             continuous_update=True,
+         )

-         ui = VBox([w_run,resistance])
+         ui = VBox([w_run, resistance])
          display(ui)
+
          def on_button(*args):
              clear_output()
              display(ui)
@@ -1029,26 +1107,29 @@ class GapJunctionTuner:
              cc = self.coupling_coefficient(self.t_vec, self.soma_v_1, self.soma_v_2, 500, 1000)
              print(f"coupling_coefficient is {cc:0.4f}")

-         on_button()
+         on_button()
          w_run.on_click(on_button)
-
-
- # optimizers!
-
+
+
+ # optimizers!
+
+
  @dataclass
  class SynapseOptimizationResult:
      """Container for synaptic parameter optimization results"""
+
      optimal_params: Dict[str, float]
      achieved_metrics: Dict[str, float]
      target_metrics: Dict[str, float]
      error: float
      optimization_path: List[Dict[str, float]]

+
  class SynapseOptimizer:
      def __init__(self, tuner):
          """
          Initialize the synapse optimizer with parameter scaling
-
+
          Parameters:
          -----------
          tuner : SynapseTuner
@@ -1057,50 +1138,54 @@ class SynapseOptimizer:
          self.tuner = tuner
          self.optimization_history = []
          self.param_scales = {}
-
+
      def _normalize_params(self, params: np.ndarray, param_names: List[str]) -> np.ndarray:
          """
          Normalize parameters to similar scales for better optimization performance.
-
+
          Parameters:
          -----------
          params : np.ndarray
              Original parameter values.
          param_names : List[str]
              Names of the parameters corresponding to the values.
-
+
          Returns:
          --------
          np.ndarray
              Normalized parameter values.
          """
          return np.array([params[i] / self.param_scales[name] for i, name in enumerate(param_names)])
-
-     def _denormalize_params(self, normalized_params: np.ndarray, param_names: List[str]) -> np.ndarray:
+
+     def _denormalize_params(
+         self, normalized_params: np.ndarray, param_names: List[str]
+     ) -> np.ndarray:
          """
          Convert normalized parameters back to original scale.
-
+
          Parameters:
          -----------
          normalized_params : np.ndarray
              Normalized parameter values.
          param_names : List[str]
              Names of the parameters corresponding to the normalized values.
-
+
          Returns:
          --------
          np.ndarray
              Denormalized parameter values in their original scale.
          """
-         return np.array([normalized_params[i] * self.param_scales[name] for i, name in enumerate(param_names)])
-
+         return np.array(
+             [normalized_params[i] * self.param_scales[name] for i, name in enumerate(param_names)]
+         )
+
      def _calculate_metrics(self) -> Dict[str, float]:
          """
          Calculate standard metrics from the current simulation.
-
-         This method runs either a single event simulation, a train input simulation,
+
+         This method runs either a single event simulation, a train input simulation,
          or both based on configuration flags, and calculates relevant synaptic metrics.
-
+
          Returns:
          --------
          Dict[str, float]
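The normalize/denormalize pair above is a per-parameter divide/multiply round trip. A standalone check with invented parameter names and scales:

```python
import numpy as np

param_names = ["tau_d", "Use"]                 # hypothetical parameter names
param_scales = {"tau_d": 100.0, "Use": 0.1}    # characteristic magnitudes

params = np.array([250.0, 0.25])
normalized = np.array([params[i] / param_scales[n] for i, n in enumerate(param_names)])
restored = np.array([normalized[i] * param_scales[n] for i, n in enumerate(param_names)])

print(normalized)                     # [2.5 2.5] -- comparable magnitudes for the optimizer
assert np.allclose(restored, params)  # round trip is exact
```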
@@ -1112,96 +1197,108 @@ class SynapseOptimizer:
1112
1197
  - rise_time: time for synaptic response to rise from 20% to 80% of peak
1113
1198
  - decay_time: time constant of synaptic response decay
1114
1199
  """
1115
- # Set these to 0 for when we return the dict
1200
+ # Set these to 0 for when we return the dict
1116
1201
  induction = 0
1117
1202
  ppr = 0
1118
1203
  recovery = 0
1119
1204
  amp = 0
1120
1205
  rise_time = 0
1121
1206
  decay_time = 0
1122
-
1207
+
1123
1208
  if self.run_single_event:
1124
1209
  self.tuner.SingleEvent(plot_and_print=False)
1125
1210
  rise_time = self.tuner.rise_time
1126
1211
  decay_time = self.tuner.decay_time
1127
-
1212
+
1128
1213
  if self.run_train_input:
1129
1214
  self.tuner._simulate_model(self.train_frequency, self.train_delay)
1130
1215
  amp = self.tuner._response_amplitude()
1131
- ppr, induction, recovery = self.tuner._calc_ppr_induction_recovery(amp, print_math=False)
1216
+ ppr, induction, recovery = self.tuner._calc_ppr_induction_recovery(
1217
+ amp, print_math=False
1218
+ )
1132
1219
  amp = self.tuner._find_max_amp(amp)
1133
-
1220
+
1134
1221
  return {
1135
- 'induction': float(induction),
1136
- 'ppr': float(ppr),
1137
- 'recovery': float(recovery),
1138
- 'max_amplitude': float(amp),
1139
- 'rise_time': float(rise_time),
1140
- 'decay_time': float(decay_time)
1222
+ "induction": float(induction),
1223
+ "ppr": float(ppr),
1224
+ "recovery": float(recovery),
1225
+ "max_amplitude": float(amp),
1226
+ "rise_time": float(rise_time),
1227
+ "decay_time": float(decay_time),
1141
1228
  }
-
- def _default_cost_function(self, metrics: Dict[str, float], target_metrics: Dict[str, float]) -> float:
+
+ def _default_cost_function(
+ self, metrics: Dict[str, float], target_metrics: Dict[str, float]
+ ) -> float:
  """
  Default cost function that minimizes the squared difference between achieved and target induction.
-
+
  Parameters:
  -----------
  metrics : Dict[str, float]
  Dictionary of calculated metrics from the current simulation.
  target_metrics : Dict[str, float]
  Dictionary of target metrics to optimize towards.
-
+
  Returns:
  --------
  float
  The squared error between achieved and target induction.
  """
- return float((metrics['induction'] - target_metrics['induction']) ** 2)
+ return float((metrics["induction"] - target_metrics["induction"]) ** 2)
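
Because `optimize_parameters` below accepts any callable with this `(metrics, target_metrics)` signature, the default can be swapped for a weighted, multi-metric cost. A short sketch with illustrative weights (the weights and metric choices are not part of the package):

    def weighted_cost(metrics, target_metrics):
        # Hypothetical weights; only metrics present in target_metrics contribute.
        weights = {"induction": 1.0, "ppr": 0.5, "recovery": 0.25}
        return float(
            sum(
                w * (metrics[k] - target_metrics[k]) ** 2
                for k, w in weights.items()
                if k in target_metrics
            )
        )
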
 
- def _objective_function(self,
- normalized_params: np.ndarray,
- param_names: List[str],
- cost_function: Callable,
- target_metrics: Dict[str, float]) -> float:
+ def _objective_function(
+ self,
+ normalized_params: np.ndarray,
+ param_names: List[str],
+ cost_function: Callable,
+ target_metrics: Dict[str, float],
+ ) -> float:
  """
  Calculate error using provided cost function
  """
  # Denormalize parameters
  params = self._denormalize_params(normalized_params, param_names)
-
+
  # Set parameters
  for name, value in zip(param_names, params):
  setattr(self.tuner.syn, name, value)
-
- # just do this and have the SingleEvent handle it
+
+ # just do this and have the SingleEvent handle it
  if self.run_single_event:
  self.tuner.using_optimizer = True
  self.tuner.param_names = param_names
  self.tuner.params = params
-
+
  # Calculate metrics and error
  metrics = self._calculate_metrics()
  error = float(cost_function(metrics, target_metrics)) # Ensure error is scalar
-
+
  # Store history with denormalized values
  history_entry = {
- 'params': dict(zip(param_names, params)),
- 'metrics': metrics,
- 'error': error
+ "params": dict(zip(param_names, params)),
+ "metrics": metrics,
+ "error": error,
  }
  self.optimization_history.append(history_entry)
-
+
  return error
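
Since every evaluation is appended to `optimization_history`, the best trial survives even if the optimizer stalls. Assuming an instance named `optimizer` (the name is illustrative), the lowest-error entry can be pulled out directly:

    # Each history entry holds denormalized params, the metrics dict, and the scalar error.
    best = min(optimizer.optimization_history, key=lambda entry: entry["error"])
    print(best["params"], best["error"])
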
-
- def optimize_parameters(self, target_metrics: Dict[str, float],
- param_bounds: Dict[str, Tuple[float, float]],
- run_single_event:bool = False, run_train_input:bool = True,
- train_frequency: float = 50,train_delay: float = 250,
- cost_function: Optional[Callable] = None,
- method: str = 'SLSQP',init_guess='random') -> SynapseOptimizationResult:
+
+ def optimize_parameters(
+ self,
+ target_metrics: Dict[str, float],
+ param_bounds: Dict[str, Tuple[float, float]],
+ run_single_event: bool = False,
+ run_train_input: bool = True,
+ train_frequency: float = 50,
+ train_delay: float = 250,
+ cost_function: Optional[Callable] = None,
+ method: str = "SLSQP",
+ init_guess="random",
+ ) -> SynapseOptimizationResult:
  """
  Optimize synaptic parameters to achieve target metrics.
-
+
  Parameters:
  -----------
  target_metrics : Dict[str, float]
@@ -1223,13 +1320,13 @@ class SynapseOptimizer:
  Optimization method to use (default: 'SLSQP')
  init_guess : str, optional
  Method for initial parameter guess ('random' or 'middle_guess')
-
+
  Returns:
  --------
  SynapseOptimizationResult
  Results of the optimization including optimal parameters, achieved metrics,
  target metrics, final error, and optimization path.
-
+
  Notes:
  ------
  This function uses scipy.optimize.minimize to find the optimal parameter values
@@ -1240,149 +1337,154 @@ class SynapseOptimizer:
  self.train_delay = train_delay
  self.run_single_event = run_single_event
  self.run_train_input = run_train_input
-
+
  param_names = list(param_bounds.keys())
  bounds = [param_bounds[name] for name in param_names]
-
+
  if cost_function is None:
  cost_function = self._default_cost_function
-
+
  # Calculate scaling factors
  self.param_scales = {
- name: max(abs(bounds[i][0]), abs(bounds[i][1]))
- for i, name in enumerate(param_names)
+ name: max(abs(bounds[i][0]), abs(bounds[i][1])) for i, name in enumerate(param_names)
  }
-
+
  # Normalize bounds
  normalized_bounds = [
- (b[0]/self.param_scales[name], b[1]/self.param_scales[name])
+ (b[0] / self.param_scales[name], b[1] / self.param_scales[name])
  for name, b in zip(param_names, bounds)
  ]
-
+
  # pick which initial-guess method to use
- if init_guess=='random':
+ if init_guess == "random":
  x0 = np.array([np.random.uniform(b[0], b[1]) for b in bounds])
- elif init_guess=='middle_guess':
- x0 = [(b[0] + b[1])/2 for b in bounds]
+ elif init_guess == "middle_guess":
+ x0 = [(b[0] + b[1]) / 2 for b in bounds]
  else:
  raise Exception("Pick a valid init guess method: either 'random' or 'middle_guess'")
  normalized_x0 = self._normalize_params(np.array(x0), param_names)
-
-
+
  # Run optimization
  result = minimize(
  self._objective_function,
  normalized_x0,
  args=(param_names, cost_function, target_metrics),
  method=method,
- bounds=normalized_bounds
+ bounds=normalized_bounds,
  )
-
+
  # Get final parameters and metrics
  final_params = dict(zip(param_names, self._denormalize_params(result.x, param_names)))
  for name, value in final_params.items():
  setattr(self.tuner.syn, name, value)
  final_metrics = self._calculate_metrics()
-
+
  return SynapseOptimizationResult(
  optimal_params=final_params,
  achieved_metrics=final_metrics,
  target_metrics=target_metrics,
  error=result.fun,
- optimization_path=self.optimization_history
+ optimization_path=self.optimization_history,
  )
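
A typical end-to-end call looks like the sketch below. The parameter names, bounds, and target values are illustrative, and the optimizer is assumed to wrap an already-configured SynapseTuner; only the keyword arguments come from the signature above:

    # optimizer wraps an already-configured SynapseTuner (setup not shown).
    result = optimizer.optimize_parameters(
        target_metrics={"induction": 0.2},    # desired short-term plasticity
        param_bounds={                        # hypothetical mechanism ranges
            "tau_d": (2.0, 20.0),
            "Use": (0.1, 0.9),
        },
        train_frequency=50,
        train_delay=250,
        init_guess="middle_guess",
    )
    print(result.optimal_params, result.error)
    optimizer.plot_optimization_results(result)
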
-
+
  def plot_optimization_results(self, result: SynapseOptimizationResult):
  """
  Plot optimization results including convergence and final traces.
-
+
  Parameters:
  -----------
  result : SynapseOptimizationResult
  Results from optimization as returned by optimize_parameters()
-
+
  Notes:
  ------
  This method generates three plots:
  1. Error convergence plot showing how the error decreased over iterations
  2. Parameter convergence plots showing how each parameter changed
  3. Final model response with the optimal parameters
-
+
  It also prints a summary of the optimization results including target vs. achieved
  metrics and the optimal parameter values.
  """
  # Ensure errors are properly shaped for plotting
  iterations = range(len(result.optimization_path))
- errors = np.array([float(h['error']) for h in result.optimization_path]).flatten()
-
+ errors = np.array([float(h["error"]) for h in result.optimization_path]).flatten()
+
  # Plot error convergence
  fig1, ax1 = plt.subplots(figsize=(8, 5))
- ax1.plot(iterations, errors, label='Error')
- ax1.set_xlabel('Iteration')
- ax1.set_ylabel('Error')
- ax1.set_title('Error Convergence')
- ax1.set_yscale('log')
+ ax1.plot(iterations, errors, label="Error")
+ ax1.set_xlabel("Iteration")
+ ax1.set_ylabel("Error")
+ ax1.set_title("Error Convergence")
+ ax1.set_yscale("log")
  ax1.legend()
  plt.tight_layout()
  plt.show()
-
+
  # Plot parameter convergence
  param_names = list(result.optimal_params.keys())
  num_params = len(param_names)
  fig2, axs = plt.subplots(nrows=num_params, ncols=1, figsize=(8, 5 * num_params))
-
+
  if num_params == 1:
  axs = [axs]
-
+
  for ax, param in zip(axs, param_names):
- values = [float(h['params'][param]) for h in result.optimization_path]
- ax.plot(iterations, values, label=f'{param}')
- ax.set_xlabel('Iteration')
- ax.set_ylabel('Parameter Value')
- ax.set_title(f'Convergence of {param}')
+ values = [float(h["params"][param]) for h in result.optimization_path]
+ ax.plot(iterations, values, label=f"{param}")
+ ax.set_xlabel("Iteration")
+ ax.set_ylabel("Parameter Value")
+ ax.set_title(f"Convergence of {param}")
  ax.legend()
-
+
  plt.tight_layout()
  plt.show()
-
+
  # Print final results
  print("Optimization Results:")
  print(f"Final Error: {float(result.error):.2e}\n")
  print("Target Metrics:")
  for metric, value in result.target_metrics.items():
  achieved = result.achieved_metrics.get(metric)
- if achieved is not None and metric != 'amplitudes': # Skip amplitude array
+ if achieved is not None and metric != "amplitudes": # Skip amplitude array
  print(f"{metric}: {float(achieved):.3f} (target: {float(value):.3f})")
-
+
  print("\nOptimal Parameters:")
  for param, value in result.optimal_params.items():
  print(f"{param}: {float(value):.3f}")
-
+
  # Plot final model response
  if self.run_train_input:
- self.tuner._plot_model([self.tuner.general_settings['tstart'] - self.tuner.nstim.interval / 3, self.tuner.tstop])
+ self.tuner._plot_model(
+ [
+ self.tuner.general_settings["tstart"] - self.tuner.nstim.interval / 3,
+ self.tuner.tstop,
+ ]
+ )
  amp = self.tuner._response_amplitude()
  self.tuner._calc_ppr_induction_recovery(amp)
  if self.run_single_event:
- self.tuner.ispk=None
+ self.tuner.ispk = None
  self.tuner.SingleEvent(plot_and_print=True)
-
-
+
+
  # A dataclass auto-generates __init__ from the type hints below; looks a bit cleaner
  @dataclass
  class GapOptimizationResult:
  """Container for gap junction optimization results"""
+
  optimal_resistance: float
  achieved_cc: float
  target_cc: float
  error: float
  optimization_path: List[Dict[str, float]]
 
+
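
Since the decorator generates the constructor from the annotations, results are built by field name; a quick illustrative construction (all values are placeholders):

    r = GapOptimizationResult(
        optimal_resistance=1e-3,
        achieved_cc=0.05,
        target_cc=0.05,
        error=0.0,
        optimization_path=[],
    )
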
  class GapJunctionOptimizer:
  def __init__(self, tuner):
  """
  Initialize the gap junction optimizer
-
+
  Parameters:
  -----------
  tuner : GapJunctionTuner
@@ -1390,52 +1492,50 @@ class GapJunctionOptimizer:
  """
  self.tuner = tuner
  self.optimization_history = []
-
+
  def _objective_function(self, resistance: float, target_cc: float) -> float:
  """
  Calculate error between achieved and target coupling coefficient
-
+
  Parameters:
  -----------
  resistance : float
  Gap junction resistance to try
  target_cc : float
  Target coupling coefficient to match
-
+
  Returns:
  --------
  float : Error between achieved and target coupling coefficient
  """
  # Run model with current resistance
  self.tuner.model(resistance)
-
+
  # Calculate coupling coefficient
  achieved_cc = self.tuner.coupling_coefficient(
- self.tuner.t_vec,
- self.tuner.soma_v_1,
+ self.tuner.t_vec,
+ self.tuner.soma_v_1,
  self.tuner.soma_v_2,
- self.tuner.general_settings['tstart'],
- self.tuner.general_settings['tstart'] + self.tuner.general_settings['tdur']
+ self.tuner.general_settings["tstart"],
+ self.tuner.general_settings["tstart"] + self.tuner.general_settings["tdur"],
  )
-
+
  # Calculate error
- error = (achieved_cc - target_cc) ** 2 #MSE
-
+ error = (achieved_cc - target_cc) ** 2 # MSE
+
  # Store history
- self.optimization_history.append({
- 'resistance': resistance,
- 'achieved_cc': achieved_cc,
- 'error': error
- })
-
+ self.optimization_history.append(
+ {"resistance": resistance, "achieved_cc": achieved_cc, "error": error}
+ )
+
  return error
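
The coupling coefficient itself is computed by the tuner and does not appear in this diff. As an assumption about what it measures (one common definition for a step input into cell 1, not a quote of the package code), it is the ratio of the steady-state voltage deflection in the coupled cell to that in the injected cell:

    import numpy as np

    def coupling_coefficient(t, v1, v2, t_start, t_end):
        # Assumed definition: deflection of cell 2 divided by deflection of
        # cell 1, each measured against its baseline just before the step.
        t, v1, v2 = np.asarray(t), np.asarray(v1), np.asarray(v2)
        pre = t < t_start
        during = (t >= t_start) & (t <= t_end)
        dv1 = v1[during][-1] - v1[pre][-1]
        dv2 = v2[during][-1] - v2[pre][-1]
        return dv2 / dv1
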
-
- def optimize_resistance(self, target_cc: float,
- resistance_bounds: tuple = (1e-4, 1e-2),
- method: str = 'bounded') -> GapOptimizationResult:
+
+ def optimize_resistance(
+ self, target_cc: float, resistance_bounds: tuple = (1e-4, 1e-2), method: str = "bounded"
+ ) -> GapOptimizationResult:
  """
  Optimize gap junction resistance to achieve a target coupling coefficient.
-
+
  Parameters:
  -----------
  target_cc : float
@@ -1445,7 +1545,7 @@ class GapJunctionOptimizer:
  method : str, optional
  Optimization method to use. Default is 'bounded' which works well
  for single-parameter optimization.
-
+
  Returns:
  --------
  GapOptimizationResult
@@ -1455,137 +1555,131 @@ class GapJunctionOptimizer:
  - target_cc: The target coupling coefficient
  - error: The final error (squared difference between target and achieved)
  - optimization_path: List of all values tried during optimization
-
+
  Notes:
  ------
  Uses scipy.optimize.minimize_scalar with bounded method, which is
  appropriate for this single-parameter optimization problem.
  """
  self.optimization_history = []
-
+
  # Run optimization
  result = minimize_scalar(
- self._objective_function,
- args=(target_cc,),
- bounds=resistance_bounds,
- method=method
+ self._objective_function, args=(target_cc,), bounds=resistance_bounds, method=method
  )
-
+
  # Run final model with optimal resistance
  self.tuner.model(result.x)
  final_cc = self.tuner.coupling_coefficient(
  self.tuner.t_vec,
  self.tuner.soma_v_1,
  self.tuner.soma_v_2,
- self.tuner.general_settings['tstart'],
- self.tuner.general_settings['tstart'] + self.tuner.general_settings['tdur']
+ self.tuner.general_settings["tstart"],
+ self.tuner.general_settings["tstart"] + self.tuner.general_settings["tdur"],
  )
-
+
  # Package up our results
  optimization_result = GapOptimizationResult(
  optimal_resistance=result.x,
  achieved_cc=final_cc,
  target_cc=target_cc,
  error=result.fun,
- optimization_path=self.optimization_history
+ optimization_path=self.optimization_history,
  )
-
+
  return optimization_result
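
Usage mirrors the synaptic optimizer. A sketch assuming `gap_tuner` is an already-configured GapJunctionTuner and 0.05 is a placeholder target:

    gap_opt = GapJunctionOptimizer(gap_tuner)
    result = gap_opt.optimize_resistance(target_cc=0.05, resistance_bounds=(1e-4, 1e-2))
    print(f"optimal resistance: {result.optimal_resistance:.2e}")
    gap_opt.plot_optimization_results(result)
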
-
+
  def plot_optimization_results(self, result: GapOptimizationResult):
  """
  Plot optimization results including convergence and final voltage traces
-
+
  Parameters:
  -----------
  result : GapOptimizationResult
  Results from optimization
  """
  fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
-
+
  # Plot voltage traces
  t_range = [
- self.tuner.general_settings['tstart'] - 100.,
- self.tuner.general_settings['tstart'] + self.tuner.general_settings['tdur'] + 100.
+ self.tuner.general_settings["tstart"] - 100.0,
+ self.tuner.general_settings["tstart"] + self.tuner.general_settings["tdur"] + 100.0,
  ]
  t = np.array(self.tuner.t_vec)
  v1 = np.array(self.tuner.soma_v_1)
  v2 = np.array(self.tuner.soma_v_2)
  tidx = (t >= t_range[0]) & (t <= t_range[1])
-
- ax1.plot(t[tidx], v1[tidx], 'b', label=f'{self.tuner.cell_name} 1')
- ax1.plot(t[tidx], v2[tidx], 'r', label=f'{self.tuner.cell_name} 2')
- ax1.set_xlabel('Time (ms)')
- ax1.set_ylabel('Membrane Voltage (mV)')
+
+ ax1.plot(t[tidx], v1[tidx], "b", label=f"{self.tuner.cell_name} 1")
+ ax1.plot(t[tidx], v2[tidx], "r", label=f"{self.tuner.cell_name} 2")
+ ax1.set_xlabel("Time (ms)")
+ ax1.set_ylabel("Membrane Voltage (mV)")
  ax1.legend()
- ax1.set_title('Optimized Voltage Traces')
-
+ ax1.set_title("Optimized Voltage Traces")
+
  # Plot error convergence
- errors = [h['error'] for h in result.optimization_path]
+ errors = [h["error"] for h in result.optimization_path]
  ax2.plot(errors)
- ax2.set_xlabel('Iteration')
- ax2.set_ylabel('Error')
- ax2.set_title('Error Convergence')
- ax2.set_yscale('log')
-
+ ax2.set_xlabel("Iteration")
+ ax2.set_ylabel("Error")
+ ax2.set_title("Error Convergence")
+ ax2.set_yscale("log")
+
  # Plot resistance convergence
- resistances = [h['resistance'] for h in result.optimization_path]
+ resistances = [h["resistance"] for h in result.optimization_path]
  ax3.plot(resistances)
- ax3.set_xlabel('Iteration')
- ax3.set_ylabel('Resistance')
- ax3.set_title('Resistance Convergence')
- ax3.set_yscale('log')
-
+ ax3.set_xlabel("Iteration")
+ ax3.set_ylabel("Resistance")
+ ax3.set_title("Resistance Convergence")
+ ax3.set_yscale("log")
+
  # Print final results
  result_text = (
- f'Optimal Resistance: {result.optimal_resistance:.2e}\n'
- f'Target CC: {result.target_cc:.3f}\n'
- f'Achieved CC: {result.achieved_cc:.3f}\n'
- f'Final Error: {result.error:.2e}'
+ f"Optimal Resistance: {result.optimal_resistance:.2e}\n"
+ f"Target CC: {result.target_cc:.3f}\n"
+ f"Achieved CC: {result.achieved_cc:.3f}\n"
+ f"Final Error: {result.error:.2e}"
  )
  ax4.text(0.1, 0.7, result_text, transform=ax4.transAxes, fontsize=10)
- ax4.axis('off')
-
+ ax4.axis("off")
+
  plt.tight_layout()
  plt.show()
 
  def parameter_sweep(self, resistance_range: np.ndarray) -> dict:
  """
  Perform a parameter sweep across different resistance values.
-
+
  Parameters:
  -----------
  resistance_range : np.ndarray
  Array of resistance values to test.
-
+
  Returns:
  --------
  dict
  Dictionary containing the results of the parameter sweep, with keys:
  - 'resistance': List of resistance values tested
  - 'coupling_coefficient': Corresponding coupling coefficients
-
+
  Notes:
  ------
  This method is useful for understanding the relationship between gap junction
  resistance and coupling coefficient before attempting optimization.
  """
- results = {
- 'resistance': [],
- 'coupling_coefficient': []
- }
-
+ results = {"resistance": [], "coupling_coefficient": []}
+
  for resistance in tqdm(resistance_range, desc="Sweeping resistance values"):
  self.tuner.model(resistance)
  cc = self.tuner.coupling_coefficient(
  self.tuner.t_vec,
  self.tuner.soma_v_1,
  self.tuner.soma_v_2,
- self.tuner.general_settings['tstart'],
- self.tuner.general_settings['tstart'] + self.tuner.general_settings['tdur']
+ self.tuner.general_settings["tstart"],
+ self.tuner.general_settings["tstart"] + self.tuner.general_settings["tdur"],
  )
-
- results['resistance'].append(resistance)
- results['coupling_coefficient'].append(cc)
-
- return results
+
+ results["resistance"].append(resistance)
+ results["coupling_coefficient"].append(cc)
+
+ return results
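
Because the sweep returns plain lists, its output drops straight into a plot. A short sketch reusing the `gap_opt` instance from the earlier example (the log scale and axis labels are choices, not package behavior):

    import numpy as np
    import matplotlib.pyplot as plt

    sweep = gap_opt.parameter_sweep(np.logspace(-4, -2, 20))
    plt.plot(sweep["resistance"], sweep["coupling_coefficient"], marker="o")
    plt.xscale("log")
    plt.xlabel("Gap junction resistance")
    plt.ylabel("Coupling coefficient")
    plt.show()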