bmtool 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/connectors.py +1 -1
- bmtool/singlecell.py +7 -2
- bmtool/synapses.py +495 -67
- {bmtool-0.6.3.dist-info → bmtool-0.6.5.dist-info}/METADATA +6 -4
- {bmtool-0.6.3.dist-info → bmtool-0.6.5.dist-info}/RECORD +9 -9
- {bmtool-0.6.3.dist-info → bmtool-0.6.5.dist-info}/LICENSE +0 -0
- {bmtool-0.6.3.dist-info → bmtool-0.6.5.dist-info}/WHEEL +0 -0
- {bmtool-0.6.3.dist-info → bmtool-0.6.5.dist-info}/entry_points.txt +0 -0
- {bmtool-0.6.3.dist-info → bmtool-0.6.5.dist-info}/top_level.txt +0 -0
bmtool/connectors.py
CHANGED
@@ -1082,7 +1082,7 @@ class ReciprocalConnector(AbstractConnector):
 
 class UnidirectionConnector(AbstractConnector):
     """
-    Object for
+    Object for building unidirectional connections in bmtk network model with
     given probability within a single population (or between two populations).
 
     Parameters:
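The restored docstring above describes what UnidirectionConnector is for. For orientation, a minimal sketch of how it is typically wired into a bmtk NetworkBuilder script; net.add_edges(**connector.edge_params()) is the pattern shown in the README excerpt later in this diff, while the constructor keywords, the setup_nodes() call, and the population names are illustrative assumptions rather than values taken from this diff:

    from bmtk.builder import NetworkBuilder
    from bmtool.connectors import UnidirectionConnector

    net = NetworkBuilder('example_net')
    net.add_nodes(N=100, pop_name='PopA')   # hypothetical populations
    net.add_nodes(N=100, pop_name='PopB')

    # One-way connections from PopA to PopB with a fixed connection probability
    connector = UnidirectionConnector(p=0.15, n_syn=1)   # assumed keyword names and values
    connector.setup_nodes(source=net.nodes(pop_name='PopA'), target=net.nodes(pop_name='PopB'))
    net.add_edges(**connector.edge_params())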
bmtool/singlecell.py
CHANGED
@@ -7,6 +7,7 @@ import matplotlib.pyplot as plt
 from scipy.optimize import curve_fit
 import neuron
 from neuron import h
+import pandas as pd
 
 
 def load_biophys1():
@@ -356,8 +357,12 @@ class FI(object):
         self.nspks = [len(v) for v in self.tspk_vecs]
         print()
         print("Results")
-
-
+        # lets make a df so the results line up nice
+        data = {'Injection (nA):':self.amps,'number of spikes':self.nspks}
+        df = pd.DataFrame(data)
+        print(df)
+        #print(f'Injection (nA): ' + ', '.join(f'{x:g}' for x in self.amps))
+        #print(f'Number of spikes: ' + ', '.join(f'{x:d}' for x in self.nspks))
         print()
 
         return self.amps, self.nspks
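The change above swaps two comma-joined print statements for a pandas DataFrame, so each injection level and its spike count land on the same row. A minimal standalone sketch of the same pattern with made-up numbers (not output from any bmtool run):

    import pandas as pd

    amps = [0.0, 0.1, 0.2, 0.3]    # injection amplitudes in nA (example values)
    nspks = [0, 3, 7, 12]          # spike counts per sweep (example values)
    df = pd.DataFrame({'Injection (nA):': amps, 'number of spikes': nspks})
    print(df)                      # one injection level per row, columns aligned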
bmtool/synapses.py
CHANGED
@@ -1,13 +1,17 @@
 import os
 import json
-import numpy as np
 import neuron
+import numpy as np
 from neuron import h
-from
+from typing import List, Dict, Callable, Optional,Tuple
+from tqdm.notebook import tqdm
 import matplotlib.pyplot as plt
+from neuron.units import ms, mV
+from dataclasses import dataclass
+# scipy
 from scipy.signal import find_peaks
-from scipy.optimize import curve_fit
-
+from scipy.optimize import curve_fit,minimize_scalar,minimize
+# widgets
 import ipywidgets as widgets
 from IPython.display import display, clear_output
 from ipywidgets import HBox, VBox
@@ -174,6 +178,11 @@ class SynapseTuner:
         """
         self._set_up_cell()
         self._set_up_synapse()
+
+        # user slider values if the sliders are set up
+        if hasattr(self, 'dynamic_sliders'):
+            syn_props = {var: slider.value for var, slider in self.dynamic_sliders.items()}
+            self._set_syn_prop(**syn_props)
 
         # Set up the stimulus
         self.nstim = h.NetStim()
@@ -200,10 +209,17 @@ class SynapseTuner:
         self.nstim.number = 1
         self.nstim2.start = h.tstop
         h.run()
+
+        current = np.array(self.rec_vectors[self.current_name])
+        syn_prop = self._get_syn_prop(short=True)
+        current = (current - syn_prop['baseline']) * 1000 # Convert to pA
+        current_integral = np.trapz(current, dx=h.dt) # pA·ms
+
         self._plot_model([self.general_settings['tstart'] - 5, self.general_settings['tstart'] + self.general_settings['tdur']])
         syn_props = self._get_syn_prop(rise_interval=self.general_settings['rise_interval'])
         for prop in syn_props.items():
             print(prop)
+        print(f'Current Integral in pA: {current_integral:.2f}')
 
 
     def _find_first(self, x):
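The new SingleEvent code above subtracts the baseline, converts the recorded synaptic current to pA, and integrates it over time with np.trapz using h.dt (in ms) as the spacing, so the reported "Current Integral" is in pA·ms, which is numerically the transferred charge in femtocoulombs. A self-contained sketch of the same calculation on a synthetic trace (values are illustrative):

    import numpy as np

    dt = 0.025                              # ms, a typical NEURON time step
    t = np.arange(0, 50, dt)                # ms
    current = -20.0 * np.exp(-t / 5.0)      # synthetic EPSC-like trace in pA, baseline already removed
    charge = np.trapz(current, dx=dt)       # pA*ms, numerically equal to fC
    print(f'Current Integral in pA: {charge:.2f}')   # about -100 pA*ms for this trace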
@@ -319,10 +335,14 @@ class SynapseTuner:
         axs = axs.ravel()
 
         # Plot synaptic current (always included)
-
+        current = self.rec_vectors[self.current_name]
+        syn_prop = self._get_syn_prop(short=True)
+        current = (current - syn_prop['baseline'])
+        current = current * 1000
+
+        axs[0].plot(self.t, current)
         if self.ispk !=None:
             for num in range(len(self.ispk)):
-                current = 1000 * np.array(self.rec_vectors[self.current_name].to_python())
                 axs[0].text(self.t[self.ispk[num]],current[self.ispk[num]],f"{str(num+1)}")
 
         axs[0].set_ylabel('Synaptic Current (pA)')
@@ -407,7 +427,7 @@ class SynapseTuner:
         A list containing the peak amplitudes for each segment of the recorded synaptic current.
 
         """
-        isyn = np.
+        isyn = np.array(self.rec_vectors[self.current_name].to_python())
         tspk = np.append(np.asarray(self.tspk), h.tstop)
         syn_prop = self._get_syn_prop(short=True)
         # print("syn_prp[sign] = " + str(syn_prop['sign']))
@@ -415,6 +435,7 @@ class SynapseTuner:
         isyn *= syn_prop['sign']
         ispk = np.floor((tspk + self.general_settings['delay']) / h.dt).astype(int)
 
+
         try:
             amp = [isyn[ispk[i]:ispk[i + 1]].max() for i in range(ispk.size - 1)]
             # indexs of where the max of the synaptic current is at. This is then plotted
@@ -423,34 +444,32 @@ class SynapseTuner:
         except:
             amp = [isyn[ispk[i]:ispk[i + 1]].max() for i in range(ispk.size - 2)]
             self.ispk = [np.argmax(isyn[ispk[i]:ispk[i + 1]]) + ispk[i] for i in range(ispk.size - 2)]
-
+
         return amp
 
 
-    def _find_max_amp(self, amp
+    def _find_max_amp(self, amp):
         """
-        Determines the maximum amplitude from the response data
+        Determines the maximum amplitude from the response data and returns the max in pA
 
         Parameters:
         -----------
         amp : array-like
             Array containing the amplitudes of synaptic responses.
-        normalize_by_trial : bool, optional
-            If True, normalize the maximum amplitude within each trial. Default is True.
 
         Returns:
         --------
         max_amp : float
             The maximum or minimum amplitude based on the sign of the response.
         """
-        max_amp =
-        min_amp =
+        max_amp = max(amp)
+        min_amp = min(amp)
         if(abs(min_amp) > max_amp):
-            return min_amp
-        return max_amp
+            return min_amp * 1000 # scale unit
+        return max_amp * 1000 # scale unit
 
 
-    def
+    def _calc_ppr_induction_recovery(self,amp, normalize_by_trial=True,print_math=True):
         """
         Calculates induction and recovery metrics from the synaptic response amplitudes.
 
@@ -471,46 +490,51 @@ class SynapseTuner:
             The maximum amplitude in the response.
         """
         amp = np.array(amp)
+        amp = (amp * 1000) # scale up
         amp = amp.reshape(-1, amp.shape[-1])
         maxamp = amp.max(axis=1 if normalize_by_trial else None)
 
-        # functions used to round array to 2 sig figs
-        def format_value(x):
-            return f"{x:.2g}"
-
-        # Function to apply format_value to an entire array
         def format_array(arr):
-
-            return
+            """Format an array to 2 significant figures for cleaner output."""
+            return np.array2string(arr, precision=2, separator=', ', suppress_small=True)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if print_math:
+            print("\n" + "="*40)
+            print("Short Term Plasticity Results")
+            print("="*40)
+            print("PPR: Above 1 is facilitating, below 1 is depressing.")
+            print("Induction: Above 0 is facilitating, below 0 is depressing.")
+            print("Recovery: A measure of how fast STP decays.\n")
+
+            # PPR Calculation
+            ppr = amp[:, 1:2] / amp[:, 0:1]
+            print("Paired Pulse Response (PPR)")
+            print("Calculation: 2nd pulse / 1st pulse")
+            print(f"Values: ({format_array(amp[:, 1:2])}) / ({format_array(amp[:, 0:1])}) = {format_array(ppr)}\n")
+
+            # Induction Calculation
+            induction = np.mean((amp[:, 5:8].mean(axis=1) - amp[:, :1].mean(axis=1)) / maxamp)
+            print("Induction")
+            print("Calculation: (avg(6th, 7th, 8th pulses) - 1st pulse) / max amps")
+            print(f"Values: avg({format_array(amp[:, 5:8])}) - {format_array(amp[:, :1])} / {format_array(maxamp)}")
+            print(f"({format_array(amp[:, 5:8].mean(axis=1))}) - ({format_array(amp[:, :1].mean(axis=1))}) / {format_array(maxamp)} = {induction:.3f}\n")
+
+            # Recovery Calculation
+            recovery = np.mean((amp[:, 8:12].mean(axis=1) - amp[:, :4].mean(axis=1)) / maxamp)
+            print("Recovery")
+            print("Calculation: (avg(9th, 10th, 11th, 12th pulses) - avg(1st to 4th pulses)) / max amps")
+            print(f"Values: avg({format_array(amp[:, 8:12])}) - avg({format_array(amp[:, :4])}) / {format_array(maxamp)}")
+            print(f"({format_array(amp[:, 8:12].mean(axis=1))}) - ({format_array(amp[:, :4].mean(axis=1))}) / {format_array(maxamp)} = {recovery:.3f}\n")
+
+            print("="*40 + "\n")
 
         recovery = np.mean((amp[:, 8:12].mean(axis=1) - amp[:, :4].mean(axis=1)) / maxamp)
-
-
-        print(f"{format_array(amp[:, 8:12].mean(axis=1))} - {format_array(amp[:, :4].mean(axis=1))} / {format_array(maxamp)} = {format_array(recovery)}")
-        print("")
-
-
+        induction = np.mean((amp[:, 5:8].mean(axis=1) - amp[:, :1].mean(axis=1)) / maxamp)
+        ppr = amp[:, 1:2] / amp[:, 0:1]
         # maxamp = max(amp, key=lambda x: abs(x[0]))
         maxamp = maxamp.max()
-
+
+        return ppr, induction, recovery
 
 
     def _set_syn_prop(self, **kwargs):
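The three short-term plasticity metrics printed by _calc_ppr_induction_recovery follow directly from the amplitude train: PPR is the second pulse over the first, induction compares the 6th to 8th pulses against the 1st, and recovery compares the 9th to 12th pulses against the first four, both normalized by the maximum amplitude. A worked example with a made-up 12-pulse train (illustrative values only, using the same slicing as the diff):

    import numpy as np

    # Hypothetical facilitating train, amplitudes in pA (not a bmtool result)
    amp = np.array([[100., 130., 150., 160., 165., 170., 172., 171., 120., 110., 105., 102.]])
    maxamp = amp.max(axis=1)                                      # 172.

    ppr = amp[:, 1:2] / amp[:, 0:1]                               # 130/100 = 1.30 -> facilitating
    induction = np.mean((amp[:, 5:8].mean(axis=1) - amp[:, :1].mean(axis=1)) / maxamp)   # (171-100)/172 ~ 0.41
    recovery = np.mean((amp[:, 8:12].mean(axis=1) - amp[:, :4].mean(axis=1)) / maxamp)   # (109.25-135)/172 ~ -0.15
    print(ppr, induction, recovery)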
@@ -575,17 +599,19 @@ class SynapseTuner:
         duration0 = 300
         vlamp_status = self.vclamp
 
-        w_run = widgets.Button(description='Run', icon='history', button_style='primary')
+        w_run = widgets.Button(description='Run Train', icon='history', button_style='primary')
+        w_single = widgets.Button(description='Single Event', icon='check', button_style='success')
         w_vclamp = widgets.ToggleButton(value=vlamp_status, description='Voltage Clamp', icon='fast-backward', button_style='warning')
         w_input_mode = widgets.ToggleButton(value=False, description='Continuous input', icon='eject', button_style='info')
         w_input_freq = widgets.SelectionSlider(options=freqs, value=freq0, description='Input Freq')
 
+
         # Sliders for delay and duration
         self.w_delay = widgets.SelectionSlider(options=delays, value=delay0, description='Delay')
         self.w_duration = widgets.SelectionSlider(options=durations, value=duration0, description='Duration')
 
         # Generate sliders dynamically based on valid numeric entries in self.slider_vars
-        dynamic_sliders = {}
+        self.dynamic_sliders = {}
         print("Setting up slider! The sliders ranges are set by their init value so try changing that if you dont like the slider range!")
         for key, value in self.slider_vars.items():
             if isinstance(value, (int, float)): # Only create sliders for numeric values
@@ -595,18 +621,25 @@ class SynapseTuner:
                     slider = widgets.FloatSlider(value=value, min=0, max=1000, step=1, description=key)
                 else:
                     slider = widgets.FloatSlider(value=value, min=0, max=value*20, step=value/5, description=key)
-                dynamic_sliders[key] = slider
+                self.dynamic_sliders[key] = slider
             else:
                 print(f"skipping slider for {key} due to not being a synaptic variable")
 
+        def run_single_event(*args):
+            clear_output()
+            display(ui)
+            self.vclamp = w_vclamp.value
+            # Update synaptic properties based on slider values
+            self.ispk=None
+            self.SingleEvent()
+
         # Function to update UI based on input mode
         def update_ui(*args):
             clear_output()
             display(ui)
             self.vclamp = w_vclamp.value
             self.input_mode = w_input_mode.value
-
-            syn_props = {var: slider.value for var, slider in dynamic_sliders.items()}
+            syn_props = {var: slider.value for var, slider in self.dynamic_sliders.items()}
             self._set_syn_prop(**syn_props)
             if self.input_mode == False:
                 self._simulate_model(w_input_freq.value, self.w_delay.value, w_vclamp.value)
@@ -614,10 +647,7 @@ class SynapseTuner:
                 self._simulate_model(w_input_freq.value, self.w_duration.value, w_vclamp.value)
             amp = self._response_amplitude()
             self._plot_model([self.general_settings['tstart'] - self.nstim.interval / 3, self.tstop])
-            self.
-            # print('Single trial ' + ('PSC' if self.vclamp else 'PSP'))
-            # print(f'Induction: {induction_single:.2f}; Recovery: {recovery:.2f}')
-            #print(f'Rest Amp: {amp[0]:.2f}; Maximum Amp: {maxamp:.2f}')
+            _ = self._calc_ppr_induction_recovery(amp)
 
         # Function to switch between delay and duration sliders
         def switch_slider(*args):
@@ -634,23 +664,23 @@ class SynapseTuner:
         # Hide the duration slider initially until the user selects it
         self.w_duration.layout.display = 'none' # Hide duration slider
 
+        w_single.on_click(run_single_event)
         w_run.on_click(update_ui)
 
         # Add the dynamic sliders to the UI
-        slider_widgets = [slider for slider in dynamic_sliders.values()]
+        slider_widgets = [slider for slider in self.dynamic_sliders.values()]
 
-
-
-        col1 = VBox(slider_widgets[:half]) # First half of sliders
-        col2 = VBox(slider_widgets[half:]) # Second half of sliders
+        button_row = HBox([w_run, w_single, w_vclamp, w_input_mode])
+        slider_row = HBox([w_input_freq, self.w_delay, self.w_duration])
 
-
+        half = len(slider_widgets) // 2
+        col1 = VBox(slider_widgets[:half])
+        col2 = VBox(slider_widgets[half:])
         slider_columns = HBox([col1, col2])
 
-        ui = VBox([
+        ui = VBox([button_row, slider_row, slider_columns])
 
         display(ui)
-        # run model with default parameters
         update_ui()
 
 class GapJunctionTuner:
@@ -738,7 +768,7 @@ class GapJunctionTuner:
         return (v2[idx2] - v2[idx1]) / (v1[idx2] - v1[idx1])
 
 
-    def
+    def InteractiveTuner(self):
         w_run = widgets.Button(description='Run', icon='history', button_style='primary')
         values = [i * 10**-4 for i in range(1, 101)] # From 1e-4 to 1e-2
@@ -762,4 +792,402 @@ class GapJunctionTuner:
             print(f"coupling_coefficient is {cc:0.4f}")
 
         on_button()
-        w_run.on_click(on_button)
+        w_run.on_click(on_button)
+
+
+# optimizers!
+
+@dataclass
+class SynapseOptimizationResult:
+    """Container for synaptic parameter optimization results"""
+    optimal_params: Dict[str, float]
+    achieved_metrics: Dict[str, float]
+    target_metrics: Dict[str, float]
+    error: float
+    optimization_path: List[Dict[str, float]]
+
+class SynapseOptimizer:
+    def __init__(self, tuner):
+        """
+        Initialize the synapse optimizer with parameter scaling
+
+        Parameters:
+        -----------
+        tuner : SynapseTuner
+            Instance of the SynapseTuner class
+        """
+        self.tuner = tuner
+        self.optimization_history = []
+        self.param_scales = {}
+
+    def _normalize_params(self, params: np.ndarray, param_names: List[str]) -> np.ndarray:
+        """Normalize parameters to similar scales"""
+        return np.array([params[i] / self.param_scales[name] for i, name in enumerate(param_names)])
+
+    def _denormalize_params(self, normalized_params: np.ndarray, param_names: List[str]) -> np.ndarray:
+        """Convert normalized parameters back to original scale"""
+        return np.array([normalized_params[i] * self.param_scales[name] for i, name in enumerate(param_names)])
+
+    def _calculate_metrics(self) -> Dict[str, float]:
+        """Calculate standard metrics from the current simulation"""
+        self.tuner._simulate_model(50, 250) # 50 Hz with 250ms Delay
+        amp = self.tuner._response_amplitude()
+        ppr, induction, recovery = self.tuner._calc_ppr_induction_recovery(amp, print_math=False)
+        amp = self.tuner._find_max_amp(amp)
+        return {
+            'induction': float(induction), # Ensure these are scalar values
+            'ppr': float(ppr),
+            'recovery': float(recovery),
+            'max_amplitude': float(amp)
+        }
+
+    def _default_cost_function(self, metrics: Dict[str, float], target_metrics: Dict[str, float]) -> float:
+        """Default cost function that targets induction"""
+        return float((metrics['induction'] - target_metrics['induction']) ** 2)
+
+    def _objective_function(self,
+                            normalized_params: np.ndarray,
+                            param_names: List[str],
+                            cost_function: Callable,
+                            target_metrics: Dict[str, float]) -> float:
+        """
+        Calculate error using provided cost function
+        """
+        # Denormalize parameters
+        params = self._denormalize_params(normalized_params, param_names)
+
+        # Set parameters
+        for name, value in zip(param_names, params):
+            setattr(self.tuner.syn, name, value)
+
+        # Calculate metrics and error
+        metrics = self._calculate_metrics()
+        error = float(cost_function(metrics, target_metrics)) # Ensure error is scalar
+
+        # Store history with denormalized values
+        history_entry = {
+            'params': dict(zip(param_names, params)),
+            'metrics': metrics,
+            'error': error
+        }
+        self.optimization_history.append(history_entry)
+
+        return error
+
+    def optimize_parameters(self,
+                            target_metrics: Dict[str, float],
+                            param_bounds: Dict[str, Tuple[float, float]],
+                            cost_function: Optional[Callable] = None,
+                            method: str = 'SLSQP',init_guess='random') -> SynapseOptimizationResult:
+        """
+        Optimize synaptic parameters using custom cost function
+        """
+        self.optimization_history = []
+        param_names = list(param_bounds.keys())
+        bounds = [param_bounds[name] for name in param_names]
+
+        if cost_function is None:
+            cost_function = self._default_cost_function
+
+        # Calculate scaling factors
+        self.param_scales = {
+            name: max(abs(bounds[i][0]), abs(bounds[i][1]))
+            for i, name in enumerate(param_names)
+        }
+
+        # Normalize bounds
+        normalized_bounds = [
+            (b[0]/self.param_scales[name], b[1]/self.param_scales[name])
+            for name, b in zip(param_names, bounds)
+        ]
+
+        # picks with method of init value we want to use
+        if init_guess=='random':
+            x0 = np.array([np.random.uniform(b[0], b[1]) for b in bounds])
+        elif init_guess=='middle_guess':
+            x0 = [(b[0] + b[1])/2 for b in bounds]
+        else:
+            raise Exception("Pick a vaid init guess method either random or midde_guess")
+        normalized_x0 = self._normalize_params(np.array(x0), param_names)
+
+
+        # Run optimization
+        result = minimize(
+            self._objective_function,
+            normalized_x0,
+            args=(param_names, cost_function, target_metrics),
+            method=method,
+            bounds=normalized_bounds
+        )
+
+        # Get final parameters and metrics
+        final_params = dict(zip(param_names, self._denormalize_params(result.x, param_names)))
+        for name, value in final_params.items():
+            setattr(self.tuner.syn, name, value)
+        final_metrics = self._calculate_metrics()
+
+        return SynapseOptimizationResult(
+            optimal_params=final_params,
+            achieved_metrics=final_metrics,
+            target_metrics=target_metrics,
+            error=result.fun,
+            optimization_path=self.optimization_history
+        )
+
+    def plot_optimization_results(self, result: SynapseOptimizationResult):
+        """Plot optimization results including convergence and final traces."""
+        # Ensure errors are properly shaped for plotting
+        iterations = range(len(result.optimization_path))
+        errors = np.array([float(h['error']) for h in result.optimization_path]).flatten()
+
+        # Plot error convergence
+        fig1, ax1 = plt.subplots(figsize=(8, 5))
+        ax1.plot(iterations, errors, label='Error')
+        ax1.set_xlabel('Iteration')
+        ax1.set_ylabel('Error')
+        ax1.set_title('Error Convergence')
+        ax1.set_yscale('log')
+        ax1.legend()
+        plt.tight_layout()
+        plt.show()
+
+        # Plot parameter convergence
+        param_names = list(result.optimal_params.keys())
+        num_params = len(param_names)
+        fig2, axs = plt.subplots(nrows=num_params, ncols=1, figsize=(8, 5 * num_params))
+
+        if num_params == 1:
+            axs = [axs]
+
+        for ax, param in zip(axs, param_names):
+            values = [float(h['params'][param]) for h in result.optimization_path]
+            ax.plot(iterations, values, label=f'{param}')
+            ax.set_xlabel('Iteration')
+            ax.set_ylabel('Parameter Value')
+            ax.set_title(f'Convergence of {param}')
+            ax.legend()
+
+        plt.tight_layout()
+        plt.show()
+
+        # Print final results
+        print("Optimization Results:")
+        print(f"Final Error: {float(result.error):.2e}\n")
+        print("Target Metrics:")
+        for metric, value in result.target_metrics.items():
+            achieved = result.achieved_metrics.get(metric)
+            if achieved is not None and metric != 'amplitudes': # Skip amplitude array
+                print(f"{metric}: {float(achieved):.3f} (target: {float(value):.3f})")
+
+        print("\nOptimal Parameters:")
+        for param, value in result.optimal_params.items():
+            print(f"{param}: {float(value):.3f}")
+
+        # Plot final model response
+        self.tuner._plot_model([self.tuner.general_settings['tstart'] - self.tuner.nstim.interval / 3, self.tuner.tstop])
+        amp = self.tuner._response_amplitude()
+        self.tuner._calc_ppr_induction_recovery(amp)
+
+
+# dataclass means just init the typehints as self.typehint. looks a bit cleaner
+@dataclass
+class GapOptimizationResult:
+    """Container for gap junction optimization results"""
+    optimal_resistance: float
+    achieved_cc: float
+    target_cc: float
+    error: float
+    optimization_path: List[Dict[str, float]]
+
+class GapJunctionOptimizer:
+    def __init__(self, tuner):
+        """
+        Initialize the gap junction optimizer
+
+        Parameters:
+        -----------
+        tuner : GapJunctionTuner
+            Instance of the GapJunctionTuner class
+        """
+        self.tuner = tuner
+        self.optimization_history = []
+
+    def _objective_function(self, resistance: float, target_cc: float) -> float:
+        """
+        Calculate error between achieved and target coupling coefficient
+
+        Parameters:
+        -----------
+        resistance : float
+            Gap junction resistance to try
+        target_cc : float
+            Target coupling coefficient to match
+
+        Returns:
+        --------
+        float : Error between achieved and target coupling coefficient
+        """
+        # Run model with current resistance
+        self.tuner.model(resistance)
+
+        # Calculate coupling coefficient
+        achieved_cc = self.tuner.coupling_coefficient(
+            self.tuner.t_vec,
+            self.tuner.soma_v_1,
+            self.tuner.soma_v_2,
+            self.tuner.general_settings['tstart'],
+            self.tuner.general_settings['tstart'] + self.tuner.general_settings['tdur']
+        )
+
+        # Calculate error
+        error = (achieved_cc - target_cc) ** 2 #MSE
+
+        # Store history
+        self.optimization_history.append({
+            'resistance': resistance,
+            'achieved_cc': achieved_cc,
+            'error': error
+        })
+
+        return error
+
+    def optimize_resistance(self, target_cc: float,
+                            resistance_bounds: tuple = (1e-4, 1e-2),
+                            method: str = 'bounded') -> GapOptimizationResult:
+        """
+        Optimize gap junction resistance to achieve target coupling coefficient
+
+        Parameters:
+        -----------
+        target_cc : float
+            Target coupling coefficient to achieve
+        resistance_bounds : tuple, optional
+            (min, max) bounds for resistance search
+        method : str, optional
+            Optimization method to use (default: 'bounded')
+
+        Returns:
+        --------
+        GapOptimizationResult
+            Container with optimization results
+        """
+        self.optimization_history = []
+
+        # Run optimization
+        result = minimize_scalar(
+            self._objective_function,
+            args=(target_cc,),
+            bounds=resistance_bounds,
+            method=method
+        )
+
+        # Run final model with optimal resistance
+        self.tuner.model(result.x)
+        final_cc = self.tuner.coupling_coefficient(
+            self.tuner.t_vec,
+            self.tuner.soma_v_1,
+            self.tuner.soma_v_2,
+            self.tuner.general_settings['tstart'],
+            self.tuner.general_settings['tstart'] + self.tuner.general_settings['tdur']
+        )
+
+        # Package up our results
+        optimization_result = GapOptimizationResult(
+            optimal_resistance=result.x,
+            achieved_cc=final_cc,
+            target_cc=target_cc,
+            error=result.fun,
+            optimization_path=self.optimization_history
+        )
+
+        return optimization_result
+
+    def plot_optimization_results(self, result: GapOptimizationResult):
+        """
+        Plot optimization results including convergence and final voltage traces
+
+        Parameters:
+        -----------
+        result : GapOptimizationResult
+            Results from optimization
+        """
+        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
+
+        # Plot voltage traces
+        t_range = [
+            self.tuner.general_settings['tstart'] - 100.,
+            self.tuner.general_settings['tstart'] + self.tuner.general_settings['tdur'] + 100.
+        ]
+        t = np.array(self.tuner.t_vec)
+        v1 = np.array(self.tuner.soma_v_1)
+        v2 = np.array(self.tuner.soma_v_2)
+        tidx = (t >= t_range[0]) & (t <= t_range[1])
+
+        ax1.plot(t[tidx], v1[tidx], 'b', label=f'{self.tuner.cell_name} 1')
+        ax1.plot(t[tidx], v2[tidx], 'r', label=f'{self.tuner.cell_name} 2')
+        ax1.set_xlabel('Time (ms)')
+        ax1.set_ylabel('Membrane Voltage (mV)')
+        ax1.legend()
+        ax1.set_title('Optimized Voltage Traces')
+
+        # Plot error convergence
+        errors = [h['error'] for h in result.optimization_path]
+        ax2.plot(errors)
+        ax2.set_xlabel('Iteration')
+        ax2.set_ylabel('Error')
+        ax2.set_title('Error Convergence')
+        ax2.set_yscale('log')
+
+        # Plot resistance convergence
+        resistances = [h['resistance'] for h in result.optimization_path]
+        ax3.plot(resistances)
+        ax3.set_xlabel('Iteration')
+        ax3.set_ylabel('Resistance')
+        ax3.set_title('Resistance Convergence')
+        ax3.set_yscale('log')
+
+        # Print final results
+        result_text = (
+            f'Optimal Resistance: {result.optimal_resistance:.2e}\n'
+            f'Target CC: {result.target_cc:.3f}\n'
+            f'Achieved CC: {result.achieved_cc:.3f}\n'
+            f'Final Error: {result.error:.2e}'
+        )
+        ax4.text(0.1, 0.7, result_text, transform=ax4.transAxes, fontsize=10)
+        ax4.axis('off')
+
+        plt.tight_layout()
+        plt.show()
+
+    def parameter_sweep(self, resistance_range: np.ndarray) -> dict:
+        """
+        Perform a parameter sweep across different resistance values
+
+        Parameters:
+        -----------
+        resistance_range : np.ndarray
+            Array of resistance values to test
+
+        Returns:
+        --------
+        dict : Results of parameter sweep including coupling coefficients
+        """
+        results = {
+            'resistance': [],
+            'coupling_coefficient': []
+        }
+
+        for resistance in tqdm(resistance_range, desc="Sweeping resistance values"):
+            self.tuner.model(resistance)
+            cc = self.tuner.coupling_coefficient(
+                self.tuner.t_vec,
+                self.tuner.soma_v_1,
+                self.tuner.soma_v_2,
+                self.tuner.general_settings['tstart'],
+                self.tuner.general_settings['tstart'] + self.tuner.general_settings['tdur']
+            )
+
+            results['resistance'].append(resistance)
+            results['coupling_coefficient'].append(cc)
+
+        return results
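The two optimizer classes added above wrap an already-configured tuner and drive scipy.optimize (minimize for synapses, minimize_scalar for gap junctions). A hedged usage sketch: syn_tuner and gap_tuner stand for SynapseTuner and GapJunctionTuner instances built earlier in a notebook, and the synaptic parameter names and bounds in param_bounds are placeholders, not names confirmed by this diff:

    from bmtool.synapses import SynapseOptimizer, GapJunctionOptimizer

    # Tune synapse parameters toward a target induction value (the default cost function targets induction)
    syn_opt = SynapseOptimizer(syn_tuner)
    syn_result = syn_opt.optimize_parameters(
        target_metrics={'induction': 0.2},
        param_bounds={'tau_d': (2.0, 20.0), 'initW': (0.1, 5.0)},   # hypothetical parameter names
        init_guess='middle_guess',
    )
    syn_opt.plot_optimization_results(syn_result)

    # Find the gap junction resistance that gives a desired coupling coefficient
    gap_opt = GapJunctionOptimizer(gap_tuner)
    gap_result = gap_opt.optimize_resistance(target_cc=0.05, resistance_bounds=(1e-4, 1e-2))
    gap_opt.plot_optimization_results(gap_result)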
{bmtool-0.6.3.dist-info → bmtool-0.6.5.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: bmtool
-Version: 0.6.3
+Version: 0.6.5
 Summary: BMTool
 Home-page: https://github.com/cyneuro/bmtool
 Download-URL:
@@ -114,7 +114,9 @@ Commands:
 - [ZAP](#zap)
 - [Tuner](#single-cell-tuning)
 - [VHalf Segregation](#vhalf-segregation-module)
-####
+#### Jupyter Notebook for how to use passive properties, current injection, FI curve, and ZAP can be found [here](examples/single_cell/). There are versions with example how to use single cells in HOC format and in the Allen Database format.
+
+#### The single cell module can take any neuron HOC object and calculate passive properties, run a current clamp, calculate FI curve, or run a ZAP. The module is designed to work with HOC template files and can also turn Allen database SWC and json files into HOC objects and use those. The examples below uses "Cell_Cf" which is the name of a HOC templated loaded by the profiler. E
 
 #### First step is it initialize the profiler.
 
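The README text above points to the single-cell notebooks; the next step it mentions, initializing the profiler, is not shown in this hunk. A hedged sketch of what that initialization typically looks like; the Profiler keyword arguments and directory names here are assumptions for illustration, not taken from this diff:

    from bmtool.singlecell import Profiler

    # Assumed project layout: HOC templates in ./templates, compiled NEURON mechanisms in ./mechanisms
    profiler = Profiler(template_dir='templates', mechanism_dir='mechanisms', dt=0.1)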
@@ -351,7 +353,7 @@ ex: [https://github.com/tjbanks/two-cell-hco](https://github.com/tjbanks/two-cel
 -Gap Junction tuner
 
 #### SynapticTuner - Aids in the tuning of synapses by printing out synaptic properties and giving the user sliders in a Jupyter notebook to tune the synapse. For more info view the example [here](examples/synapses/synaptic_tuner.ipynb)
-#### GapJunctionTuner - Provides jupyter sliders to tune for coupling coefficient in a similar style to the SynapticTuner an example can be viewed [here](examples/synapses/gap_junction_tuner.ipynb)
+#### GapJunctionTuner - Provides jupyter sliders to tune for coupling coefficient in a similar style to the SynapticTuner. The Gap junction tuner also has an optimizer which can find the best resistance for the desired coupling coefficient. an example can be viewed [here](examples/synapses/gap_junction_tuner.ipynb)
 
 ### Connectors Module
 - [UnidirectionConnector](#unidirectional-connector---unidirectional-connections-in-bmtk-network-model-with-given-probability-within-a-single-population-or-between-two-populations)
@@ -409,7 +411,6 @@ net.add_edges(**connector.edge_params())
 ```
 
 ## Bmplot Module
-### for a demo please see the notebook [here](examples/bmplot/bmplot.ipynb)
 - [total_connection_matrix](#total_connection_matrix)
 - [percent_connection_matrix](#percent_connection_matrix)
 - [connector_percent_matrix](#connector_percent_matrix)
@@ -420,6 +421,7 @@ net.add_edges(**connector.edge_params())
 - [connection_histogram](#connection_histogram)
 - [plot_3d_positions](#plot_3d_positions)
 - [plot_3d_cell_rotation](#plot_3d_cell_rotation)
+### for a demo please see the notebook [here](examples/bmplot/bmplot.ipynb)
 
 ### total_connection_matrix
 #### Generates a table of total number of connections each neuron population recieves
{bmtool-0.6.3.dist-info → bmtool-0.6.5.dist-info}/RECORD
CHANGED
@@ -2,12 +2,12 @@ bmtool/SLURM.py,sha256=4KvtrPofaHv5iairetgrlXdhAdowJfg_aeTx9W59JDM,16618
 bmtool/__init__.py,sha256=ZStTNkAJHJxG7Pwiy5UgCzC4KlhMS5pUNPtUJZVwL_Y,136
 bmtool/__main__.py,sha256=TmFkmDxjZ6250nYD4cgGhn-tbJeEm0u-EMz2ajAN9vE,650
 bmtool/bmplot.py,sha256=Im-Jrv8TK3CmTtksFzHrVogAve0l9ZwRrCW4q2MFRiA,53966
-bmtool/connectors.py,sha256=
+bmtool/connectors.py,sha256=hWkUUcJ4tmas8NDOFPPjQT-TgTlPcpjuZsYyAW2WkPA,72242
 bmtool/graphs.py,sha256=K8BiughRUeXFVvAgo8UzrwpSClIVg7UfmIcvtEsEsk0,6020
 bmtool/manage.py,sha256=_lCU0qBQZ4jSxjzAJUd09JEetb--cud7KZgxQFbLGSY,657
 bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298
-bmtool/singlecell.py,sha256=
-bmtool/synapses.py,sha256=
+bmtool/singlecell.py,sha256=MQiLucsI6OBIjtcJra3Z9PTFQOE-Zn5ST-R9SmFvrbQ,27049
+bmtool/synapses.py,sha256=jQFOpi9hzzBEijDQ7dsWfcxW-DtxN9v0UWxCqDSlcTs,48466
 bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 bmtool/debug/commands.py,sha256=AwtcR7BUUheM0NxvU1Nu234zCdpobhJv5noX8x5K2vY,583
 bmtool/debug/debug.py,sha256=xqnkzLiH3s-tS26Y5lZZL62qR2evJdi46Gud-HzxEN4,207
@@ -16,9 +16,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396
 bmtool/util/util.py,sha256=00vOAwTVIifCqouBoFoT0lBashl4fCalrk8fhg_Uq4c,56654
 bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
-bmtool-0.6.
-bmtool-0.6.
-bmtool-0.6.
-bmtool-0.6.
-bmtool-0.6.
-bmtool-0.6.
+bmtool-0.6.5.dist-info/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
+bmtool-0.6.5.dist-info/METADATA,sha256=-67VVvOyPqiGX7qoidn8SLVHuHRxSZnOIPKPxqZ8TNI,20224
+bmtool-0.6.5.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+bmtool-0.6.5.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
+bmtool-0.6.5.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
+bmtool-0.6.5.dist-info/RECORD,,
{bmtool-0.6.3.dist-info → bmtool-0.6.5.dist-info}/LICENSE: File without changes
{bmtool-0.6.3.dist-info → bmtool-0.6.5.dist-info}/WHEEL: File without changes
{bmtool-0.6.3.dist-info → bmtool-0.6.5.dist-info}/entry_points.txt: File without changes
{bmtool-0.6.3.dist-info → bmtool-0.6.5.dist-info}/top_level.txt: File without changes