bmtool 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bmtool/connectors.py CHANGED
@@ -1082,7 +1082,7 @@ class ReciprocalConnector(AbstractConnector):
1082
1082
 
1083
1083
  class UnidirectionConnector(AbstractConnector):
1084
1084
  """
1085
- Object for buiilding unidirectional connections in bmtk network model with
1085
+ Object for building unidirectional connections in bmtk network model with
1086
1086
  given probability within a single population (or between two populations).
1087
1087
 
1088
1088
  Parameters:
bmtool/singlecell.py CHANGED
@@ -7,6 +7,7 @@ import matplotlib.pyplot as plt
7
7
  from scipy.optimize import curve_fit
8
8
  import neuron
9
9
  from neuron import h
10
+ import pandas as pd
10
11
 
11
12
 
12
13
  def load_biophys1():
@@ -356,8 +357,12 @@ class FI(object):
356
357
  self.nspks = [len(v) for v in self.tspk_vecs]
357
358
  print()
358
359
  print("Results")
359
- print(f'Injection (nA): ' + ', '.join(f'{x:g}' for x in self.amps))
360
- print(f'Number of spikes: ' + ', '.join(f'{x:d}' for x in self.nspks))
360
+ # lets make a df so the results line up nice
361
+ data = {'Injection (nA):':self.amps,'number of spikes':self.nspks}
362
+ df = pd.DataFrame(data)
363
+ print(df)
364
+ #print(f'Injection (nA): ' + ', '.join(f'{x:g}' for x in self.amps))
365
+ #print(f'Number of spikes: ' + ', '.join(f'{x:d}' for x in self.nspks))
361
366
  print()
362
367
 
363
368
  return self.amps, self.nspks
bmtool/synapses.py CHANGED
@@ -1,13 +1,17 @@
1
1
  import os
2
2
  import json
3
- import numpy as np
4
3
  import neuron
4
+ import numpy as np
5
5
  from neuron import h
6
- from neuron.units import ms, mV
6
+ from typing import List, Dict, Callable, Optional,Tuple
7
+ from tqdm.notebook import tqdm
7
8
  import matplotlib.pyplot as plt
9
+ from neuron.units import ms, mV
10
+ from dataclasses import dataclass
11
+ # scipy
8
12
  from scipy.signal import find_peaks
9
- from scipy.optimize import curve_fit
10
-
13
+ from scipy.optimize import curve_fit,minimize_scalar,minimize
14
+ # widgets
11
15
  import ipywidgets as widgets
12
16
  from IPython.display import display, clear_output
13
17
  from ipywidgets import HBox, VBox
@@ -174,6 +178,11 @@ class SynapseTuner:
174
178
  """
175
179
  self._set_up_cell()
176
180
  self._set_up_synapse()
181
+
182
+ # user slider values if the sliders are set up
183
+ if hasattr(self, 'dynamic_sliders'):
184
+ syn_props = {var: slider.value for var, slider in self.dynamic_sliders.items()}
185
+ self._set_syn_prop(**syn_props)
177
186
 
178
187
  # Set up the stimulus
179
188
  self.nstim = h.NetStim()
@@ -200,10 +209,17 @@ class SynapseTuner:
200
209
  self.nstim.number = 1
201
210
  self.nstim2.start = h.tstop
202
211
  h.run()
212
+
213
+ current = np.array(self.rec_vectors[self.current_name])
214
+ syn_prop = self._get_syn_prop(short=True)
215
+ current = (current - syn_prop['baseline']) * 1000 # Convert to pA
216
+ current_integral = np.trapz(current, dx=h.dt) # pA·ms
217
+
203
218
  self._plot_model([self.general_settings['tstart'] - 5, self.general_settings['tstart'] + self.general_settings['tdur']])
204
219
  syn_props = self._get_syn_prop(rise_interval=self.general_settings['rise_interval'])
205
220
  for prop in syn_props.items():
206
221
  print(prop)
222
+ print(f'Current Integral in pA: {current_integral:.2f}')
207
223
 
208
224
 
209
225
  def _find_first(self, x):
@@ -319,10 +335,14 @@ class SynapseTuner:
319
335
  axs = axs.ravel()
320
336
 
321
337
  # Plot synaptic current (always included)
322
- axs[0].plot(self.t, 1000 * self.rec_vectors[self.current_name])
338
+ current = self.rec_vectors[self.current_name]
339
+ syn_prop = self._get_syn_prop(short=True)
340
+ current = (current - syn_prop['baseline'])
341
+ current = current * 1000
342
+
343
+ axs[0].plot(self.t, current)
323
344
  if self.ispk !=None:
324
345
  for num in range(len(self.ispk)):
325
- current = 1000 * np.array(self.rec_vectors[self.current_name].to_python())
326
346
  axs[0].text(self.t[self.ispk[num]],current[self.ispk[num]],f"{str(num+1)}")
327
347
 
328
348
  axs[0].set_ylabel('Synaptic Current (pA)')
@@ -407,7 +427,7 @@ class SynapseTuner:
407
427
  A list containing the peak amplitudes for each segment of the recorded synaptic current.
408
428
 
409
429
  """
410
- isyn = np.asarray(self.rec_vectors['i'])
430
+ isyn = np.array(self.rec_vectors[self.current_name].to_python())
411
431
  tspk = np.append(np.asarray(self.tspk), h.tstop)
412
432
  syn_prop = self._get_syn_prop(short=True)
413
433
  # print("syn_prp[sign] = " + str(syn_prop['sign']))
@@ -415,6 +435,7 @@ class SynapseTuner:
415
435
  isyn *= syn_prop['sign']
416
436
  ispk = np.floor((tspk + self.general_settings['delay']) / h.dt).astype(int)
417
437
 
438
+
418
439
  try:
419
440
  amp = [isyn[ispk[i]:ispk[i + 1]].max() for i in range(ispk.size - 1)]
420
441
  # indexs of where the max of the synaptic current is at. This is then plotted
@@ -423,34 +444,32 @@ class SynapseTuner:
423
444
  except:
424
445
  amp = [isyn[ispk[i]:ispk[i + 1]].max() for i in range(ispk.size - 2)]
425
446
  self.ispk = [np.argmax(isyn[ispk[i]:ispk[i + 1]]) + ispk[i] for i in range(ispk.size - 2)]
426
-
447
+
427
448
  return amp
428
449
 
429
450
 
430
- def _find_max_amp(self, amp, normalize_by_trial=True):
451
+ def _find_max_amp(self, amp):
431
452
  """
432
- Determines the maximum amplitude from the response data.
453
+ Determines the maximum amplitude from the response data and returns the max in pA
433
454
 
434
455
  Parameters:
435
456
  -----------
436
457
  amp : array-like
437
458
  Array containing the amplitudes of synaptic responses.
438
- normalize_by_trial : bool, optional
439
- If True, normalize the maximum amplitude within each trial. Default is True.
440
459
 
441
460
  Returns:
442
461
  --------
443
462
  max_amp : float
444
463
  The maximum or minimum amplitude based on the sign of the response.
445
464
  """
446
- max_amp = amp.max(axis=1 if normalize_by_trial else None)
447
- min_amp = amp.min(axis=1 if normalize_by_trial else None)
465
+ max_amp = max(amp)
466
+ min_amp = min(amp)
448
467
  if(abs(min_amp) > max_amp):
449
- return min_amp
450
- return max_amp
468
+ return min_amp * 1000 # scale unit
469
+ return max_amp * 1000 # scale unit
451
470
 
452
471
 
453
- def _print_ppr_induction_recovery(self,amp, normalize_by_trial=True):
472
+ def _calc_ppr_induction_recovery(self,amp, normalize_by_trial=True,print_math=True):
454
473
  """
455
474
  Calculates induction and recovery metrics from the synaptic response amplitudes.
456
475
 
@@ -471,46 +490,51 @@ class SynapseTuner:
471
490
  The maximum amplitude in the response.
472
491
  """
473
492
  amp = np.array(amp)
493
+ amp = (amp * 1000) # scale up
474
494
  amp = amp.reshape(-1, amp.shape[-1])
475
495
  maxamp = amp.max(axis=1 if normalize_by_trial else None)
476
496
 
477
- # functions used to round array to 2 sig figs
478
- def format_value(x):
479
- return f"{x:.2g}"
480
-
481
- # Function to apply format_value to an entire array
482
497
  def format_array(arr):
483
- # Flatten the array and format each element
484
- return ' '.join([format_value(x) for x in arr.flatten()])
498
+ """Format an array to 2 significant figures for cleaner output."""
499
+ return np.array2string(arr, precision=2, separator=', ', suppress_small=True)
485
500
 
486
- print("Short Term Plasticity")
487
- print("PPR: above 1 is facilitating below 1 is depressing")
488
- print("Induction: above 0 is facilitating below 0 is depressing")
489
- print("Recovery: measure of how fast STP decays")
490
- print("")
491
-
492
- ppr = amp[:,1:2] / amp[:,0:1]
493
- print(f"Paired Pulse Response Calculation: 2nd pulse / 1st pulse ")
494
- print(f"{format_array(amp[:,1:2])} - {format_array(amp[:,0:1])} = {format_array(ppr)}")
495
- print("")
496
-
497
- induction = np.mean((amp[:, 5:8].mean(axis=1) - amp[:, :1].mean(axis=1)) / maxamp)
498
- print(f"Induction Calculation: (avg(6,7,8 pulses) - 1 pulse) / max amps")
499
- # Format and print arrays with 2 significant figures
500
- print(f"{format_array(amp[:, 5:8])} - {format_array(amp[:, :1])} / {format_array(maxamp)}")
501
- print(f"{format_array(amp[:, 5:8].mean(axis=1))} - {format_array(amp[:, :1].mean(axis=1))} / {format_array(maxamp)} = {format_array(induction)}")
502
- print("")
501
+ if print_math:
502
+ print("\n" + "="*40)
503
+ print("Short Term Plasticity Results")
504
+ print("="*40)
505
+ print("PPR: Above 1 is facilitating, below 1 is depressing.")
506
+ print("Induction: Above 0 is facilitating, below 0 is depressing.")
507
+ print("Recovery: A measure of how fast STP decays.\n")
508
+
509
+ # PPR Calculation
510
+ ppr = amp[:, 1:2] / amp[:, 0:1]
511
+ print("Paired Pulse Response (PPR)")
512
+ print("Calculation: 2nd pulse / 1st pulse")
513
+ print(f"Values: ({format_array(amp[:, 1:2])}) / ({format_array(amp[:, 0:1])}) = {format_array(ppr)}\n")
514
+
515
+ # Induction Calculation
516
+ induction = np.mean((amp[:, 5:8].mean(axis=1) - amp[:, :1].mean(axis=1)) / maxamp)
517
+ print("Induction")
518
+ print("Calculation: (avg(6th, 7th, 8th pulses) - 1st pulse) / max amps")
519
+ print(f"Values: avg({format_array(amp[:, 5:8])}) - {format_array(amp[:, :1])} / {format_array(maxamp)}")
520
+ print(f"({format_array(amp[:, 5:8].mean(axis=1))}) - ({format_array(amp[:, :1].mean(axis=1))}) / {format_array(maxamp)} = {induction:.3f}\n")
521
+
522
+ # Recovery Calculation
523
+ recovery = np.mean((amp[:, 8:12].mean(axis=1) - amp[:, :4].mean(axis=1)) / maxamp)
524
+ print("Recovery")
525
+ print("Calculation: (avg(9th, 10th, 11th, 12th pulses) - avg(1st to 4th pulses)) / max amps")
526
+ print(f"Values: avg({format_array(amp[:, 8:12])}) - avg({format_array(amp[:, :4])}) / {format_array(maxamp)}")
527
+ print(f"({format_array(amp[:, 8:12].mean(axis=1))}) - ({format_array(amp[:, :4].mean(axis=1))}) / {format_array(maxamp)} = {recovery:.3f}\n")
528
+
529
+ print("="*40 + "\n")
503
530
 
504
531
  recovery = np.mean((amp[:, 8:12].mean(axis=1) - amp[:, :4].mean(axis=1)) / maxamp)
505
- print("Recovery Calculation: avg(9,10,11,12 pulses) - avg(1,2,3,4 pulses) / max amps")
506
- print(f"{format_array(amp[:, 8:12])} - {format_array(amp[:, :4])} / {format_array(maxamp)}")
507
- print(f"{format_array(amp[:, 8:12].mean(axis=1))} - {format_array(amp[:, :4].mean(axis=1))} / {format_array(maxamp)} = {format_array(recovery)}")
508
- print("")
509
-
510
-
532
+ induction = np.mean((amp[:, 5:8].mean(axis=1) - amp[:, :1].mean(axis=1)) / maxamp)
533
+ ppr = amp[:, 1:2] / amp[:, 0:1]
511
534
  # maxamp = max(amp, key=lambda x: abs(x[0]))
512
535
  maxamp = maxamp.max()
513
- #return induction, recovery, maxamp
536
+
537
+ return ppr, induction, recovery
514
538
 
515
539
 
516
540
  def _set_syn_prop(self, **kwargs):
@@ -575,17 +599,19 @@ class SynapseTuner:
575
599
  duration0 = 300
576
600
  vlamp_status = self.vclamp
577
601
 
578
- w_run = widgets.Button(description='Run', icon='history', button_style='primary')
602
+ w_run = widgets.Button(description='Run Train', icon='history', button_style='primary')
603
+ w_single = widgets.Button(description='Single Event', icon='check', button_style='success')
579
604
  w_vclamp = widgets.ToggleButton(value=vlamp_status, description='Voltage Clamp', icon='fast-backward', button_style='warning')
580
605
  w_input_mode = widgets.ToggleButton(value=False, description='Continuous input', icon='eject', button_style='info')
581
606
  w_input_freq = widgets.SelectionSlider(options=freqs, value=freq0, description='Input Freq')
582
607
 
608
+
583
609
  # Sliders for delay and duration
584
610
  self.w_delay = widgets.SelectionSlider(options=delays, value=delay0, description='Delay')
585
611
  self.w_duration = widgets.SelectionSlider(options=durations, value=duration0, description='Duration')
586
612
 
587
613
  # Generate sliders dynamically based on valid numeric entries in self.slider_vars
588
- dynamic_sliders = {}
614
+ self.dynamic_sliders = {}
589
615
  print("Setting up slider! The sliders ranges are set by their init value so try changing that if you dont like the slider range!")
590
616
  for key, value in self.slider_vars.items():
591
617
  if isinstance(value, (int, float)): # Only create sliders for numeric values
@@ -595,18 +621,25 @@ class SynapseTuner:
595
621
  slider = widgets.FloatSlider(value=value, min=0, max=1000, step=1, description=key)
596
622
  else:
597
623
  slider = widgets.FloatSlider(value=value, min=0, max=value*20, step=value/5, description=key)
598
- dynamic_sliders[key] = slider
624
+ self.dynamic_sliders[key] = slider
599
625
  else:
600
626
  print(f"skipping slider for {key} due to not being a synaptic variable")
601
627
 
628
+ def run_single_event(*args):
629
+ clear_output()
630
+ display(ui)
631
+ self.vclamp = w_vclamp.value
632
+ # Update synaptic properties based on slider values
633
+ self.ispk=None
634
+ self.SingleEvent()
635
+
602
636
  # Function to update UI based on input mode
603
637
  def update_ui(*args):
604
638
  clear_output()
605
639
  display(ui)
606
640
  self.vclamp = w_vclamp.value
607
641
  self.input_mode = w_input_mode.value
608
- # Update synaptic properties based on slider values
609
- syn_props = {var: slider.value for var, slider in dynamic_sliders.items()}
642
+ syn_props = {var: slider.value for var, slider in self.dynamic_sliders.items()}
610
643
  self._set_syn_prop(**syn_props)
611
644
  if self.input_mode == False:
612
645
  self._simulate_model(w_input_freq.value, self.w_delay.value, w_vclamp.value)
@@ -614,10 +647,7 @@ class SynapseTuner:
614
647
  self._simulate_model(w_input_freq.value, self.w_duration.value, w_vclamp.value)
615
648
  amp = self._response_amplitude()
616
649
  self._plot_model([self.general_settings['tstart'] - self.nstim.interval / 3, self.tstop])
617
- self._print_ppr_induction_recovery(amp)
618
- # print('Single trial ' + ('PSC' if self.vclamp else 'PSP'))
619
- # print(f'Induction: {induction_single:.2f}; Recovery: {recovery:.2f}')
620
- #print(f'Rest Amp: {amp[0]:.2f}; Maximum Amp: {maxamp:.2f}')
650
+ _ = self._calc_ppr_induction_recovery(amp)
621
651
 
622
652
  # Function to switch between delay and duration sliders
623
653
  def switch_slider(*args):
@@ -634,23 +664,23 @@ class SynapseTuner:
634
664
  # Hide the duration slider initially until the user selects it
635
665
  self.w_duration.layout.display = 'none' # Hide duration slider
636
666
 
667
+ w_single.on_click(run_single_event)
637
668
  w_run.on_click(update_ui)
638
669
 
639
670
  # Add the dynamic sliders to the UI
640
- slider_widgets = [slider for slider in dynamic_sliders.values()]
671
+ slider_widgets = [slider for slider in self.dynamic_sliders.values()]
641
672
 
642
- # Divide sliders into two columns
643
- half = len(slider_widgets) // 2
644
- col1 = VBox(slider_widgets[:half]) # First half of sliders
645
- col2 = VBox(slider_widgets[half:]) # Second half of sliders
673
+ button_row = HBox([w_run, w_single, w_vclamp, w_input_mode])
674
+ slider_row = HBox([w_input_freq, self.w_delay, self.w_duration])
646
675
 
647
- # Create a two-column layout with HBox
676
+ half = len(slider_widgets) // 2
677
+ col1 = VBox(slider_widgets[:half])
678
+ col2 = VBox(slider_widgets[half:])
648
679
  slider_columns = HBox([col1, col2])
649
680
 
650
- ui = VBox([HBox([w_run, w_vclamp, w_input_mode]), HBox([w_input_freq, self.w_delay, self.w_duration]), slider_columns])
681
+ ui = VBox([button_row, slider_row, slider_columns])
651
682
 
652
683
  display(ui)
653
- # run model with default parameters
654
684
  update_ui()
655
685
 
656
686
  class GapJunctionTuner:
@@ -738,7 +768,7 @@ class GapJunctionTuner:
738
768
  return (v2[idx2] - v2[idx1]) / (v1[idx2] - v1[idx1])
739
769
 
740
770
 
741
- def run_model(self):
771
+ def InteractiveTuner(self):
742
772
  w_run = widgets.Button(description='Run', icon='history', button_style='primary')
743
773
  values = [i * 10**-4 for i in range(1, 101)] # From 1e-4 to 1e-2
744
774
 
@@ -762,4 +792,402 @@ class GapJunctionTuner:
762
792
  print(f"coupling_coefficient is {cc:0.4f}")
763
793
 
764
794
  on_button()
765
- w_run.on_click(on_button)
795
+ w_run.on_click(on_button)
796
+
797
+
798
+ # optimizers!
799
+
800
+ @dataclass
801
+ class SynapseOptimizationResult:
802
+ """Container for synaptic parameter optimization results"""
803
+ optimal_params: Dict[str, float]
804
+ achieved_metrics: Dict[str, float]
805
+ target_metrics: Dict[str, float]
806
+ error: float
807
+ optimization_path: List[Dict[str, float]]
808
+
809
+ class SynapseOptimizer:
810
+ def __init__(self, tuner):
811
+ """
812
+ Initialize the synapse optimizer with parameter scaling
813
+
814
+ Parameters:
815
+ -----------
816
+ tuner : SynapseTuner
817
+ Instance of the SynapseTuner class
818
+ """
819
+ self.tuner = tuner
820
+ self.optimization_history = []
821
+ self.param_scales = {}
822
+
823
+ def _normalize_params(self, params: np.ndarray, param_names: List[str]) -> np.ndarray:
824
+ """Normalize parameters to similar scales"""
825
+ return np.array([params[i] / self.param_scales[name] for i, name in enumerate(param_names)])
826
+
827
+ def _denormalize_params(self, normalized_params: np.ndarray, param_names: List[str]) -> np.ndarray:
828
+ """Convert normalized parameters back to original scale"""
829
+ return np.array([normalized_params[i] * self.param_scales[name] for i, name in enumerate(param_names)])
830
+
831
+ def _calculate_metrics(self) -> Dict[str, float]:
832
+ """Calculate standard metrics from the current simulation"""
833
+ self.tuner._simulate_model(50, 250) # 50 Hz with 250ms Delay
834
+ amp = self.tuner._response_amplitude()
835
+ ppr, induction, recovery = self.tuner._calc_ppr_induction_recovery(amp, print_math=False)
836
+ amp = self.tuner._find_max_amp(amp)
837
+ return {
838
+ 'induction': float(induction), # Ensure these are scalar values
839
+ 'ppr': float(ppr),
840
+ 'recovery': float(recovery),
841
+ 'max_amplitude': float(amp)
842
+ }
843
+
844
+ def _default_cost_function(self, metrics: Dict[str, float], target_metrics: Dict[str, float]) -> float:
845
+ """Default cost function that targets induction"""
846
+ return float((metrics['induction'] - target_metrics['induction']) ** 2)
847
+
848
+ def _objective_function(self,
849
+ normalized_params: np.ndarray,
850
+ param_names: List[str],
851
+ cost_function: Callable,
852
+ target_metrics: Dict[str, float]) -> float:
853
+ """
854
+ Calculate error using provided cost function
855
+ """
856
+ # Denormalize parameters
857
+ params = self._denormalize_params(normalized_params, param_names)
858
+
859
+ # Set parameters
860
+ for name, value in zip(param_names, params):
861
+ setattr(self.tuner.syn, name, value)
862
+
863
+ # Calculate metrics and error
864
+ metrics = self._calculate_metrics()
865
+ error = float(cost_function(metrics, target_metrics)) # Ensure error is scalar
866
+
867
+ # Store history with denormalized values
868
+ history_entry = {
869
+ 'params': dict(zip(param_names, params)),
870
+ 'metrics': metrics,
871
+ 'error': error
872
+ }
873
+ self.optimization_history.append(history_entry)
874
+
875
+ return error
876
+
877
+ def optimize_parameters(self,
878
+ target_metrics: Dict[str, float],
879
+ param_bounds: Dict[str, Tuple[float, float]],
880
+ cost_function: Optional[Callable] = None,
881
+ method: str = 'SLSQP',init_guess='random') -> SynapseOptimizationResult:
882
+ """
883
+ Optimize synaptic parameters using custom cost function
884
+ """
885
+ self.optimization_history = []
886
+ param_names = list(param_bounds.keys())
887
+ bounds = [param_bounds[name] for name in param_names]
888
+
889
+ if cost_function is None:
890
+ cost_function = self._default_cost_function
891
+
892
+ # Calculate scaling factors
893
+ self.param_scales = {
894
+ name: max(abs(bounds[i][0]), abs(bounds[i][1]))
895
+ for i, name in enumerate(param_names)
896
+ }
897
+
898
+ # Normalize bounds
899
+ normalized_bounds = [
900
+ (b[0]/self.param_scales[name], b[1]/self.param_scales[name])
901
+ for name, b in zip(param_names, bounds)
902
+ ]
903
+
904
+ # picks with method of init value we want to use
905
+ if init_guess=='random':
906
+ x0 = np.array([np.random.uniform(b[0], b[1]) for b in bounds])
907
+ elif init_guess=='middle_guess':
908
+ x0 = [(b[0] + b[1])/2 for b in bounds]
909
+ else:
910
+ raise Exception("Pick a vaid init guess method either random or midde_guess")
911
+ normalized_x0 = self._normalize_params(np.array(x0), param_names)
912
+
913
+
914
+ # Run optimization
915
+ result = minimize(
916
+ self._objective_function,
917
+ normalized_x0,
918
+ args=(param_names, cost_function, target_metrics),
919
+ method=method,
920
+ bounds=normalized_bounds
921
+ )
922
+
923
+ # Get final parameters and metrics
924
+ final_params = dict(zip(param_names, self._denormalize_params(result.x, param_names)))
925
+ for name, value in final_params.items():
926
+ setattr(self.tuner.syn, name, value)
927
+ final_metrics = self._calculate_metrics()
928
+
929
+ return SynapseOptimizationResult(
930
+ optimal_params=final_params,
931
+ achieved_metrics=final_metrics,
932
+ target_metrics=target_metrics,
933
+ error=result.fun,
934
+ optimization_path=self.optimization_history
935
+ )
936
+
937
+ def plot_optimization_results(self, result: SynapseOptimizationResult):
938
+ """Plot optimization results including convergence and final traces."""
939
+ # Ensure errors are properly shaped for plotting
940
+ iterations = range(len(result.optimization_path))
941
+ errors = np.array([float(h['error']) for h in result.optimization_path]).flatten()
942
+
943
+ # Plot error convergence
944
+ fig1, ax1 = plt.subplots(figsize=(8, 5))
945
+ ax1.plot(iterations, errors, label='Error')
946
+ ax1.set_xlabel('Iteration')
947
+ ax1.set_ylabel('Error')
948
+ ax1.set_title('Error Convergence')
949
+ ax1.set_yscale('log')
950
+ ax1.legend()
951
+ plt.tight_layout()
952
+ plt.show()
953
+
954
+ # Plot parameter convergence
955
+ param_names = list(result.optimal_params.keys())
956
+ num_params = len(param_names)
957
+ fig2, axs = plt.subplots(nrows=num_params, ncols=1, figsize=(8, 5 * num_params))
958
+
959
+ if num_params == 1:
960
+ axs = [axs]
961
+
962
+ for ax, param in zip(axs, param_names):
963
+ values = [float(h['params'][param]) for h in result.optimization_path]
964
+ ax.plot(iterations, values, label=f'{param}')
965
+ ax.set_xlabel('Iteration')
966
+ ax.set_ylabel('Parameter Value')
967
+ ax.set_title(f'Convergence of {param}')
968
+ ax.legend()
969
+
970
+ plt.tight_layout()
971
+ plt.show()
972
+
973
+ # Print final results
974
+ print("Optimization Results:")
975
+ print(f"Final Error: {float(result.error):.2e}\n")
976
+ print("Target Metrics:")
977
+ for metric, value in result.target_metrics.items():
978
+ achieved = result.achieved_metrics.get(metric)
979
+ if achieved is not None and metric != 'amplitudes': # Skip amplitude array
980
+ print(f"{metric}: {float(achieved):.3f} (target: {float(value):.3f})")
981
+
982
+ print("\nOptimal Parameters:")
983
+ for param, value in result.optimal_params.items():
984
+ print(f"{param}: {float(value):.3f}")
985
+
986
+ # Plot final model response
987
+ self.tuner._plot_model([self.tuner.general_settings['tstart'] - self.tuner.nstim.interval / 3, self.tuner.tstop])
988
+ amp = self.tuner._response_amplitude()
989
+ self.tuner._calc_ppr_induction_recovery(amp)
990
+
991
+
992
+ # dataclass means just init the typehints as self.typehint. looks a bit cleaner
993
+ @dataclass
994
+ class GapOptimizationResult:
995
+ """Container for gap junction optimization results"""
996
+ optimal_resistance: float
997
+ achieved_cc: float
998
+ target_cc: float
999
+ error: float
1000
+ optimization_path: List[Dict[str, float]]
1001
+
1002
+ class GapJunctionOptimizer:
1003
+ def __init__(self, tuner):
1004
+ """
1005
+ Initialize the gap junction optimizer
1006
+
1007
+ Parameters:
1008
+ -----------
1009
+ tuner : GapJunctionTuner
1010
+ Instance of the GapJunctionTuner class
1011
+ """
1012
+ self.tuner = tuner
1013
+ self.optimization_history = []
1014
+
1015
+ def _objective_function(self, resistance: float, target_cc: float) -> float:
1016
+ """
1017
+ Calculate error between achieved and target coupling coefficient
1018
+
1019
+ Parameters:
1020
+ -----------
1021
+ resistance : float
1022
+ Gap junction resistance to try
1023
+ target_cc : float
1024
+ Target coupling coefficient to match
1025
+
1026
+ Returns:
1027
+ --------
1028
+ float : Error between achieved and target coupling coefficient
1029
+ """
1030
+ # Run model with current resistance
1031
+ self.tuner.model(resistance)
1032
+
1033
+ # Calculate coupling coefficient
1034
+ achieved_cc = self.tuner.coupling_coefficient(
1035
+ self.tuner.t_vec,
1036
+ self.tuner.soma_v_1,
1037
+ self.tuner.soma_v_2,
1038
+ self.tuner.general_settings['tstart'],
1039
+ self.tuner.general_settings['tstart'] + self.tuner.general_settings['tdur']
1040
+ )
1041
+
1042
+ # Calculate error
1043
+ error = (achieved_cc - target_cc) ** 2 #MSE
1044
+
1045
+ # Store history
1046
+ self.optimization_history.append({
1047
+ 'resistance': resistance,
1048
+ 'achieved_cc': achieved_cc,
1049
+ 'error': error
1050
+ })
1051
+
1052
+ return error
1053
+
1054
+ def optimize_resistance(self, target_cc: float,
1055
+ resistance_bounds: tuple = (1e-4, 1e-2),
1056
+ method: str = 'bounded') -> GapOptimizationResult:
1057
+ """
1058
+ Optimize gap junction resistance to achieve target coupling coefficient
1059
+
1060
+ Parameters:
1061
+ -----------
1062
+ target_cc : float
1063
+ Target coupling coefficient to achieve
1064
+ resistance_bounds : tuple, optional
1065
+ (min, max) bounds for resistance search
1066
+ method : str, optional
1067
+ Optimization method to use (default: 'bounded')
1068
+
1069
+ Returns:
1070
+ --------
1071
+ GapOptimizationResult
1072
+ Container with optimization results
1073
+ """
1074
+ self.optimization_history = []
1075
+
1076
+ # Run optimization
1077
+ result = minimize_scalar(
1078
+ self._objective_function,
1079
+ args=(target_cc,),
1080
+ bounds=resistance_bounds,
1081
+ method=method
1082
+ )
1083
+
1084
+ # Run final model with optimal resistance
1085
+ self.tuner.model(result.x)
1086
+ final_cc = self.tuner.coupling_coefficient(
1087
+ self.tuner.t_vec,
1088
+ self.tuner.soma_v_1,
1089
+ self.tuner.soma_v_2,
1090
+ self.tuner.general_settings['tstart'],
1091
+ self.tuner.general_settings['tstart'] + self.tuner.general_settings['tdur']
1092
+ )
1093
+
1094
+ # Package up our results
1095
+ optimization_result = GapOptimizationResult(
1096
+ optimal_resistance=result.x,
1097
+ achieved_cc=final_cc,
1098
+ target_cc=target_cc,
1099
+ error=result.fun,
1100
+ optimization_path=self.optimization_history
1101
+ )
1102
+
1103
+ return optimization_result
1104
+
1105
+ def plot_optimization_results(self, result: GapOptimizationResult):
1106
+ """
1107
+ Plot optimization results including convergence and final voltage traces
1108
+
1109
+ Parameters:
1110
+ -----------
1111
+ result : GapOptimizationResult
1112
+ Results from optimization
1113
+ """
1114
+ fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
1115
+
1116
+ # Plot voltage traces
1117
+ t_range = [
1118
+ self.tuner.general_settings['tstart'] - 100.,
1119
+ self.tuner.general_settings['tstart'] + self.tuner.general_settings['tdur'] + 100.
1120
+ ]
1121
+ t = np.array(self.tuner.t_vec)
1122
+ v1 = np.array(self.tuner.soma_v_1)
1123
+ v2 = np.array(self.tuner.soma_v_2)
1124
+ tidx = (t >= t_range[0]) & (t <= t_range[1])
1125
+
1126
+ ax1.plot(t[tidx], v1[tidx], 'b', label=f'{self.tuner.cell_name} 1')
1127
+ ax1.plot(t[tidx], v2[tidx], 'r', label=f'{self.tuner.cell_name} 2')
1128
+ ax1.set_xlabel('Time (ms)')
1129
+ ax1.set_ylabel('Membrane Voltage (mV)')
1130
+ ax1.legend()
1131
+ ax1.set_title('Optimized Voltage Traces')
1132
+
1133
+ # Plot error convergence
1134
+ errors = [h['error'] for h in result.optimization_path]
1135
+ ax2.plot(errors)
1136
+ ax2.set_xlabel('Iteration')
1137
+ ax2.set_ylabel('Error')
1138
+ ax2.set_title('Error Convergence')
1139
+ ax2.set_yscale('log')
1140
+
1141
+ # Plot resistance convergence
1142
+ resistances = [h['resistance'] for h in result.optimization_path]
1143
+ ax3.plot(resistances)
1144
+ ax3.set_xlabel('Iteration')
1145
+ ax3.set_ylabel('Resistance')
1146
+ ax3.set_title('Resistance Convergence')
1147
+ ax3.set_yscale('log')
1148
+
1149
+ # Print final results
1150
+ result_text = (
1151
+ f'Optimal Resistance: {result.optimal_resistance:.2e}\n'
1152
+ f'Target CC: {result.target_cc:.3f}\n'
1153
+ f'Achieved CC: {result.achieved_cc:.3f}\n'
1154
+ f'Final Error: {result.error:.2e}'
1155
+ )
1156
+ ax4.text(0.1, 0.7, result_text, transform=ax4.transAxes, fontsize=10)
1157
+ ax4.axis('off')
1158
+
1159
+ plt.tight_layout()
1160
+ plt.show()
1161
+
1162
+ def parameter_sweep(self, resistance_range: np.ndarray) -> dict:
1163
+ """
1164
+ Perform a parameter sweep across different resistance values
1165
+
1166
+ Parameters:
1167
+ -----------
1168
+ resistance_range : np.ndarray
1169
+ Array of resistance values to test
1170
+
1171
+ Returns:
1172
+ --------
1173
+ dict : Results of parameter sweep including coupling coefficients
1174
+ """
1175
+ results = {
1176
+ 'resistance': [],
1177
+ 'coupling_coefficient': []
1178
+ }
1179
+
1180
+ for resistance in tqdm(resistance_range, desc="Sweeping resistance values"):
1181
+ self.tuner.model(resistance)
1182
+ cc = self.tuner.coupling_coefficient(
1183
+ self.tuner.t_vec,
1184
+ self.tuner.soma_v_1,
1185
+ self.tuner.soma_v_2,
1186
+ self.tuner.general_settings['tstart'],
1187
+ self.tuner.general_settings['tstart'] + self.tuner.general_settings['tdur']
1188
+ )
1189
+
1190
+ results['resistance'].append(resistance)
1191
+ results['coupling_coefficient'].append(cc)
1192
+
1193
+ return results
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: bmtool
3
- Version: 0.6.3
3
+ Version: 0.6.5
4
4
  Summary: BMTool
5
5
  Home-page: https://github.com/cyneuro/bmtool
6
6
  Download-URL:
@@ -114,7 +114,9 @@ Commands:
114
114
  - [ZAP](#zap)
115
115
  - [Tuner](#single-cell-tuning)
116
116
  - [VHalf Segregation](#vhalf-segregation-module)
117
- #### The single cell module can take any neuron HOC object and calculate passive properties, run a current clamp, calculate FI curve, or run a ZAP. The module is designed to work with HOC template files and can also turn Allen database SWC and json files into HOC objects and use those. The examples below uses "Cell_Cf" which is the name of a HOC templated loaded by the profiler.
117
+ #### Jupyter notebooks showing how to use passive properties, current injection, FI curve, and ZAP can be found [here](examples/single_cell/). There are versions with examples of how to use single cells in HOC format and in the Allen Database format.
118
+
119
+ #### The single cell module can take any neuron HOC object and calculate passive properties, run a current clamp, calculate FI curve, or run a ZAP. The module is designed to work with HOC template files and can also turn Allen database SWC and json files into HOC objects and use those. The examples below use "Cell_Cf", which is the name of a HOC template loaded by the profiler.
118
120
 
119
121
  #### The first step is to initialize the profiler.
120
122
 
@@ -351,7 +353,7 @@ ex: [https://github.com/tjbanks/two-cell-hco](https://github.com/tjbanks/two-cel
351
353
  -Gap Junction tuner
352
354
 
353
355
  #### SynapticTuner - Aids in the tuning of synapses by printing out synaptic properties and giving the user sliders in a Jupyter notebook to tune the synapse. For more info view the example [here](examples/synapses/synaptic_tuner.ipynb)
354
- #### GapJunctionTuner - Provides jupyter sliders to tune for coupling coefficient in a similar style to the SynapticTuner an example can be viewed [here](examples/synapses/gap_junction_tuner.ipynb)
356
+ #### GapJunctionTuner - Provides jupyter sliders to tune for coupling coefficient in a similar style to the SynapticTuner. The Gap junction tuner also has an optimizer which can find the best resistance for the desired coupling coefficient. An example can be viewed [here](examples/synapses/gap_junction_tuner.ipynb)
355
357
 
356
358
  ### Connectors Module
357
359
  - [UnidirectionConnector](#unidirectional-connector---unidirectional-connections-in-bmtk-network-model-with-given-probability-within-a-single-population-or-between-two-populations)
@@ -409,7 +411,6 @@ net.add_edges(**connector.edge_params())
409
411
  ```
410
412
 
411
413
  ## Bmplot Module
412
- ### for a demo please see the notebook [here](examples/bmplot/bmplot.ipynb)
413
414
  - [total_connection_matrix](#total_connection_matrix)
414
415
  - [percent_connection_matrix](#percent_connection_matrix)
415
416
  - [connector_percent_matrix](#connector_percent_matrix)
@@ -420,6 +421,7 @@ net.add_edges(**connector.edge_params())
420
421
  - [connection_histogram](#connection_histogram)
421
422
  - [plot_3d_positions](#plot_3d_positions)
422
423
  - [plot_3d_cell_rotation](#plot_3d_cell_rotation)
424
+ ### for a demo please see the notebook [here](examples/bmplot/bmplot.ipynb)
423
425
 
424
426
  ### total_connection_matrix
425
427
  #### Generates a table of total number of connections each neuron population receives
@@ -2,12 +2,12 @@ bmtool/SLURM.py,sha256=4KvtrPofaHv5iairetgrlXdhAdowJfg_aeTx9W59JDM,16618
2
2
  bmtool/__init__.py,sha256=ZStTNkAJHJxG7Pwiy5UgCzC4KlhMS5pUNPtUJZVwL_Y,136
3
3
  bmtool/__main__.py,sha256=TmFkmDxjZ6250nYD4cgGhn-tbJeEm0u-EMz2ajAN9vE,650
4
4
  bmtool/bmplot.py,sha256=Im-Jrv8TK3CmTtksFzHrVogAve0l9ZwRrCW4q2MFRiA,53966
5
- bmtool/connectors.py,sha256=2vVUsqYMaCuWZ-4C5eUzqwsFItFM9vm0ytZdRQdWgoc,72243
5
+ bmtool/connectors.py,sha256=hWkUUcJ4tmas8NDOFPPjQT-TgTlPcpjuZsYyAW2WkPA,72242
6
6
  bmtool/graphs.py,sha256=K8BiughRUeXFVvAgo8UzrwpSClIVg7UfmIcvtEsEsk0,6020
7
7
  bmtool/manage.py,sha256=_lCU0qBQZ4jSxjzAJUd09JEetb--cud7KZgxQFbLGSY,657
8
8
  bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298
9
- bmtool/singlecell.py,sha256=Q4poQvG9fw0jlyMmHFzbRPrpcEkPz5MKS8Guuo73Bzs,26849
10
- bmtool/synapses.py,sha256=xNq2ln9XRf5CZa6dZ_dabPTDrrufQTiVVoYYeqQmDN4,32597
9
+ bmtool/singlecell.py,sha256=MQiLucsI6OBIjtcJra3Z9PTFQOE-Zn5ST-R9SmFvrbQ,27049
10
+ bmtool/synapses.py,sha256=jQFOpi9hzzBEijDQ7dsWfcxW-DtxN9v0UWxCqDSlcTs,48466
11
11
  bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  bmtool/debug/commands.py,sha256=AwtcR7BUUheM0NxvU1Nu234zCdpobhJv5noX8x5K2vY,583
13
13
  bmtool/debug/debug.py,sha256=xqnkzLiH3s-tS26Y5lZZL62qR2evJdi46Gud-HzxEN4,207
@@ -16,9 +16,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396
16
16
  bmtool/util/util.py,sha256=00vOAwTVIifCqouBoFoT0lBashl4fCalrk8fhg_Uq4c,56654
17
17
  bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
18
  bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
19
- bmtool-0.6.3.dist-info/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
20
- bmtool-0.6.3.dist-info/METADATA,sha256=tQ9eD3BcEnX6hwKmJ92Syr2OO8cR47KolGIcWud1-TM,19859
21
- bmtool-0.6.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
22
- bmtool-0.6.3.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
23
- bmtool-0.6.3.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
24
- bmtool-0.6.3.dist-info/RECORD,,
19
+ bmtool-0.6.5.dist-info/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
20
+ bmtool-0.6.5.dist-info/METADATA,sha256=-67VVvOyPqiGX7qoidn8SLVHuHRxSZnOIPKPxqZ8TNI,20224
21
+ bmtool-0.6.5.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
22
+ bmtool-0.6.5.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
23
+ bmtool-0.6.5.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
24
+ bmtool-0.6.5.dist-info/RECORD,,
File without changes