bmtool 0.7.6__py3-none-any.whl → 0.7.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

bmtool/synapses.py CHANGED
@@ -38,6 +38,7 @@ DEFAULT_GAP_JUNCTION_GENERAL_SETTINGS = {
38
38
  "tdur": 500.0,
39
39
  "dt": 0.025,
40
40
  "celsius": 20,
41
+ "iclamp_amp": -0.01, # nA
41
42
  }
42
43
 
43
44
 
@@ -172,6 +173,7 @@ class SynapseTuner:
172
173
  self.other_vars_to_record = other_vars_to_record or []
173
174
  self.ispk = None
174
175
  self.input_mode = False # Add input_mode attribute
176
+ self.last_figure = None # Store reference to last generated figure
175
177
 
176
178
  # Store original slider_vars for connection switching
177
179
  self.original_slider_vars = slider_vars or list(self.synaptic_props.keys())
@@ -1044,7 +1046,9 @@ class SynapseTuner:
1044
1046
  for j in range(num_vars_to_plot, len(axs)):
1045
1047
  fig.delaxes(axs[j])
1046
1048
 
1047
- # plt.tight_layout()
1049
+ #plt.tight_layout()
1050
+ fig.suptitle(f"Connection: {self.current_connection}")
1051
+ self.last_figure = plt.gcf()
1048
1052
  plt.show()
1049
1053
 
1050
1054
  def _set_drive_train(self, freq=50.0, delay=250.0):
@@ -1147,7 +1151,7 @@ class SynapseTuner:
1147
1151
 
1148
1152
  def _calc_ppr_induction_recovery(self, amp, normalize_by_trial=True, print_math=True):
1149
1153
  """
1150
- Calculates paired-pulse ratio, induction, and recovery metrics from response amplitudes.
1154
+ Calculates paired-pulse ratio, induction, recovery, and simple PPR metrics from response amplitudes.
1151
1155
 
1152
1156
  Parameters:
1153
1157
  -----------
@@ -1162,13 +1166,15 @@ class SynapseTuner:
1162
1166
  --------
1163
1167
  tuple
1164
1168
  A tuple containing:
1165
- - ppr: Paired-pulse ratio (2nd pulse / 1st pulse)
1169
+ - ppr: Paired-pulse ratio (2nd pulse - 1st pulse) normalized by 90th percentile amplitude
1166
1170
  - induction: Measure of facilitation/depression during initial pulses
1167
1171
  - recovery: Measure of recovery after the delay period
1172
+ - simple_ppr: Simple paired-pulse ratio (2nd pulse / 1st pulse)
1168
1173
 
1169
1174
  Notes:
1170
1175
  ------
1171
- - PPR > 1 indicates facilitation, PPR < 1 indicates depression
1176
+ - PPR > 0 indicates facilitation, PPR < 0 indicates depression
1177
+ - Simple PPR > 1 indicates facilitation, Simple PPR < 1 indicates depression
1172
1178
  - Induction > 0 indicates facilitation, Induction < 0 indicates depression
1173
1179
  - Recovery compares the response after delay to the initial pulses
1174
1180
  """
@@ -1189,34 +1195,44 @@ class SynapseTuner:
1189
1195
  f"Short Term Plasticity Results for {self.train_freq}Hz with {self.train_delay} Delay"
1190
1196
  )
1191
1197
  print("=" * 40)
1192
- print("PPR: Above 0 is facilitating, below 0 is depressing.")
1193
- print("Induction: Above 0 is facilitating, below 0 is depressing.")
1194
- print("Recovery: A measure of how fast STP decays.\n")
1198
+ print("Simple PPR: Above 1 is facilitating, below 1 is depressing")
1199
+ print("PPR: Above 0 is facilitating, below 0 is depressing.")
1200
+ print("Induction: Above 0 is facilitating, below 0 is depressing.")
1201
+ print("Recovery: A measure of how fast STP decays.\n")
1202
+
1203
+ # Simple PPR Calculation: Avg 2nd pulse / Avg 1st pulse
1204
+ simple_ppr = np.mean(amp[:, 1:2]) / np.mean(amp[:, 0:1])
1205
+ print("Simple Paired Pulse Ratio (PPR)")
1206
+ print(" Calculation: Avg 2nd pulse / Avg 1st pulse")
1207
+ print(
1208
+ f" Values: {np.mean(amp[:, 1:2]):.3f} / {np.mean(amp[:, 0:1]):.3f} = {simple_ppr:.3f}\n"
1209
+ )
1195
1210
 
1196
1211
  # PPR Calculation: (Avg 2nd pulse - Avg 1st pulse) / 90th percentile amplitude
1197
1212
  ppr = (np.mean(amp[:, 1:2]) - np.mean(amp[:, 0:1])) / percentile_90
1198
1213
  print("Paired Pulse Response (PPR)")
1199
- print("Calculation: (Avg 2nd pulse - Avg 1st pulse) / 90th percentile amplitude")
1214
+ print(" Calculation: (Avg 2nd pulse - Avg 1st pulse) / 90th percentile amplitude")
1200
1215
  print(
1201
- f"Values: ({np.mean(amp[:, 1:2]):.3f} - {np.mean(amp[:, 0:1]):.3f}) / {percentile_90:.3f} = {ppr:.3f}\n"
1216
+ f" Values: ({np.mean(amp[:, 1:2]):.3f} - {np.mean(amp[:, 0:1]):.3f}) / {percentile_90:.3f} = {ppr:.3f}\n"
1202
1217
  )
1218
+
1203
1219
 
1204
1220
  # Induction Calculation: (Avg (6th, 7th, 8th pulses) - Avg 1st pulse) / 90th percentile amplitude
1205
1221
  induction = (np.mean(amp[:, 5:8]) - np.mean(amp[:, :1])) / percentile_90
1206
1222
  print("Induction")
1207
- print("Calculation: (Avg(6th, 7th, 8th pulses) - Avg 1st pulse) / 90th percentile amplitude")
1223
+ print(" Calculation: (Avg(6th, 7th, 8th pulses) - Avg 1st pulse) / 90th percentile amplitude")
1208
1224
  print(
1209
- f"Values: {np.mean(amp[:, 5:8]):.3f} - {np.mean(amp[:, :1]):.3f} / {percentile_90:.3f} = {induction:.3f}\n"
1225
+ f" Values: {np.mean(amp[:, 5:8]):.3f} - {np.mean(amp[:, :1]):.3f} / {percentile_90:.3f} = {induction:.3f}\n"
1210
1226
  )
1211
1227
 
1212
1228
  # Recovery Calculation: (Avg (9th, 10th, 11th, 12th pulses) - Avg (1st, 2nd, 3rd, 4th pulses)) / 90th percentile amplitude
1213
1229
  recovery = (np.mean(amp[:, 8:12]) - np.mean(amp[:, :4])) / percentile_90
1214
1230
  print("Recovery")
1215
1231
  print(
1216
- "Calculation: (Avg(9th, 10th, 11th, 12th pulses) - Avg(1st to 4th pulses)) / 90th percentile amplitude"
1232
+ " Calculation: (Avg(9th, 10th, 11th, 12th pulses) - Avg(1st to 4th pulses)) / 90th percentile amplitude"
1217
1233
  )
1218
1234
  print(
1219
- f"Values: {np.mean(amp[:, 8:12]):.3f} - {np.mean(amp[:, :4]):.3f} / {percentile_90:.3f} = {recovery:.3f}\n"
1235
+ f" Values: {np.mean(amp[:, 8:12]):.3f} - {np.mean(amp[:, :4]):.3f} / {percentile_90:.3f} = {recovery:.3f}\n"
1220
1236
  )
1221
1237
 
1222
1238
  print("=" * 40 + "\n")
@@ -1225,8 +1241,9 @@ class SynapseTuner:
1225
1241
  ppr = (np.mean(amp[:, 1:2]) - np.mean(amp[:, 0:1])) / percentile_90
1226
1242
  induction = (np.mean(amp[:, 5:8]) - np.mean(amp[:, :1])) / percentile_90
1227
1243
  recovery = (np.mean(amp[:, 8:12]) - np.mean(amp[:, :4])) / percentile_90
1244
+ simple_ppr = np.mean(amp[:, 1:2]) / np.mean(amp[:, 0:1])
1228
1245
 
1229
- return ppr, induction, recovery
1246
+ return ppr, induction, recovery, simple_ppr
1230
1247
 
1231
1248
  def _set_syn_prop(self, **kwargs):
1232
1249
  """
@@ -1267,19 +1284,19 @@ class SynapseTuner:
1267
1284
  for i in range(3):
1268
1285
  self.vcl.amp[i] = self.conn["spec_settings"]["vclamp_amp"]
1269
1286
  self.vcl.dur[i] = vcldur[1][i]
1270
- #h.finitialize(self.cell.Vinit * mV)
1271
- #h.continuerun(self.tstop * ms)
1272
- h.run()
1287
+ h.finitialize(70 * mV)
1288
+ h.continuerun(self.tstop * ms)
1289
+ #h.run()
1273
1290
  else:
1274
- self.tstop = self.general_settings["tstart"] + self.general_settings["tdur"]
1291
+ # Continuous input mode: ensure simulation runs long enough for the full stimulation duration
1292
+ self.tstop = self.general_settings["tstart"] + self.w_duration.value + 300 # 300ms buffer time
1275
1293
  self.nstim.interval = 1000 / input_frequency
1276
1294
  self.nstim.number = np.ceil(self.w_duration.value / 1000 * input_frequency + 1)
1277
1295
  self.nstim2.number = 0
1278
- self.tstop = self.w_duration.value + self.general_settings["tstart"]
1279
1296
 
1280
- #h.finitialize(self.cell.Vinit * mV)
1281
- #h.continuerun(self.tstop * ms)
1282
- h.run()
1297
+ h.finitialize(70 * mV)
1298
+ h.continuerun(self.tstop * ms)
1299
+ #h.run()
1283
1300
 
1284
1301
  def InteractiveTuner(self):
1285
1302
  """
@@ -1325,7 +1342,7 @@ class SynapseTuner:
1325
1342
  vlamp_status = self.vclamp
1326
1343
 
1327
1344
  # Connection dropdown
1328
- connection_options = list(self.conn_type_settings.keys())
1345
+ connection_options = sorted(list(self.conn_type_settings.keys()))
1329
1346
  w_connection = widgets.Dropdown(
1330
1347
  options=connection_options,
1331
1348
  value=self.current_connection,
@@ -1374,6 +1391,52 @@ class SynapseTuner:
1374
1391
  options=durations, value=duration0, description="Duration"
1375
1392
  )
1376
1393
 
1394
+ # Save functionality widgets
1395
+ save_path_text = widgets.Text(
1396
+ value="plot.png",
1397
+ description="Save path:",
1398
+ layout=widgets.Layout(width='300px')
1399
+ )
1400
+ save_button = widgets.Button(description="Save Plot", icon="save", button_style="success")
1401
+
1402
+ def save_plot(b):
1403
+ if hasattr(self, 'last_figure') and self.last_figure is not None:
1404
+ try:
1405
+ # Create a new figure with just the first subplot (synaptic current)
1406
+ fig, ax = plt.subplots(figsize=(8, 6))
1407
+
1408
+ # Get the axes from the original figure
1409
+ original_axes = self.last_figure.get_axes()
1410
+ if len(original_axes) > 0:
1411
+ first_ax = original_axes[0]
1412
+
1413
+ # Copy the data from the first subplot
1414
+ for line in first_ax.get_lines():
1415
+ ax.plot(line.get_xdata(), line.get_ydata(),
1416
+ color=line.get_color(), label=line.get_label())
1417
+
1418
+ # Copy axis labels and title
1419
+ ax.set_xlabel(first_ax.get_xlabel())
1420
+ ax.set_ylabel(first_ax.get_ylabel())
1421
+ ax.set_title(first_ax.get_title())
1422
+ ax.set_xlim(first_ax.get_xlim())
1423
+ ax.legend()
1424
+ ax.grid(True)
1425
+
1426
+ # Save the new figure
1427
+ fig.savefig(save_path_text.value)
1428
+ plt.close(fig) # Close the temporary figure
1429
+ print(f"Synaptic current plot saved to {save_path_text.value}")
1430
+ else:
1431
+ print("No subplots found in the figure")
1432
+
1433
+ except Exception as e:
1434
+ print(f"Error saving plot: {e}")
1435
+ else:
1436
+ print("No plot to save")
1437
+
1438
+ save_button.on_click(save_plot)
1439
+
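The save_plot callback above exports only the first subplot of the stored figure by copying its line data into a fresh figure before saving. The same idea can be reused outside the widget; a minimal standalone sketch, assuming src_fig is any Matplotlib figure with at least one axes:

import matplotlib.pyplot as plt

def save_first_subplot(src_fig, path):
    """Copy the first axes of src_fig into a new figure and save it to path."""
    src_ax = src_fig.get_axes()[0]
    fig, ax = plt.subplots(figsize=(8, 6))
    for line in src_ax.get_lines():
        ax.plot(line.get_xdata(), line.get_ydata(),
                color=line.get_color(), label=line.get_label())
    ax.set_xlabel(src_ax.get_xlabel())
    ax.set_ylabel(src_ax.get_ylabel())
    ax.set_title(src_ax.get_title())
    ax.set_xlim(src_ax.get_xlim())
    ax.legend()
    ax.grid(True)
    fig.savefig(path)
    plt.close(fig)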
1377
1440
  def create_dynamic_sliders():
1378
1441
  """Create sliders based on current connection's parameters"""
1379
1442
  sliders = {}
@@ -1452,7 +1515,7 @@ class SynapseTuner:
1452
1515
  the network dropdown. It coordinates the complete switching process:
1453
1516
  1. Calls _switch_network() to rebuild connections for the new network
1454
1517
  2. Updates the connection dropdown options with new network's connections
1455
- 3. Recreates dynamic sliders for the new connection parameters
1518
+ 3. Recreates dynamic sliders for new connection parameters
1456
1519
  4. Refreshes the entire UI to reflect all changes
1457
1520
  """
1458
1521
  if w_network is None:
@@ -1514,8 +1577,9 @@ class SynapseTuner:
1514
1577
  else:
1515
1578
  connection_row = HBox([w_connection])
1516
1579
  slider_row = HBox([w_input_freq, self.w_delay, self.w_duration])
1580
+ save_row = HBox([save_path_text, save_button])
1517
1581
 
1518
- ui = VBox([connection_row, button_row, slider_row, slider_columns])
1582
+ ui = VBox([connection_row, button_row, slider_row, slider_columns, save_row])
1519
1583
 
1520
1584
  # Function to update UI based on input mode
1521
1585
  def update_ui(*args):
@@ -1617,6 +1681,7 @@ class SynapseTuner:
1617
1681
  Dictionary containing frequency-dependent metrics with keys:
1618
1682
  - 'frequencies': List of tested frequencies
1619
1683
  - 'ppr': Paired-pulse ratios at each frequency
1684
+ - 'simple_ppr': Simple paired-pulse ratios (2nd/1st pulse) at each frequency
1620
1685
  - 'induction': Induction values at each frequency
1621
1686
  - 'recovery': Recovery values at each frequency
1622
1687
 
@@ -1626,7 +1691,7 @@ class SynapseTuner:
1626
1691
  behavior of synapses, such as identifying facilitating vs. depressing regimes
1627
1692
  or the frequency at which a synapse transitions between these behaviors.
1628
1693
  """
1629
- results = {"frequencies": freqs, "ppr": [], "induction": [], "recovery": []}
1694
+ results = {"frequencies": freqs, "ppr": [], "induction": [], "recovery": [], "simple_ppr": []}
1630
1695
 
1631
1696
  # Store original state
1632
1697
  original_ispk = self.ispk
@@ -1634,11 +1699,12 @@ class SynapseTuner:
1634
1699
  for freq in tqdm(freqs, desc="Analyzing frequencies"):
1635
1700
  self._simulate_model(freq, delay)
1636
1701
  amp = self._response_amplitude()
1637
- ppr, induction, recovery = self._calc_ppr_induction_recovery(amp, print_math=False)
1702
+ ppr, induction, recovery, simple_ppr = self._calc_ppr_induction_recovery(amp, print_math=False)
1638
1703
 
1639
1704
  results["ppr"].append(float(ppr))
1640
1705
  results["induction"].append(float(induction))
1641
1706
  results["recovery"].append(float(recovery))
1707
+ results["simple_ppr"].append(float(simple_ppr))
1642
1708
 
1643
1709
  # Restore original state
1644
1710
  self.ispk = original_ispk
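With simple_ppr added to the results dictionary, callers can read all four metrics per frequency. A usage sketch, assuming tuner is an already-configured SynapseTuner:

results = tuner.stp_frequency_response(freqs=[10, 20, 50, 100], delay=250, plot=False)
for f, p, sp in zip(results["frequencies"], results["ppr"], results["simple_ppr"]):
    print(f"{f} Hz: normalized PPR = {p:.3f}, simple PPR = {sp:.3f}")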
@@ -1658,6 +1724,7 @@ class SynapseTuner:
1658
1724
  Dictionary containing frequency analysis results with keys:
1659
1725
  - 'frequencies': List of tested frequencies
1660
1726
  - 'ppr': Paired-pulse ratios at each frequency
1727
+ - 'simple_ppr': Simple paired-pulse ratios at each frequency
1661
1728
  - 'induction': Induction values at each frequency
1662
1729
  - 'recovery': Recovery values at each frequency
1663
1730
  log_plot : bool
@@ -1666,24 +1733,27 @@ class SynapseTuner:
1666
1733
  Notes:
1667
1734
  ------
1668
1735
  Creates a figure with three subplots showing:
1669
- 1. Paired-pulse ratio vs. frequency
1736
+ 1. Paired-pulse ratios (both normalized and simple) vs. frequency
1670
1737
  2. Induction vs. frequency
1671
1738
  3. Recovery vs. frequency
1672
1739
 
1673
1740
  Each plot includes a horizontal reference line at y=0 or y=1 to indicate
1674
1741
  the boundary between facilitation and depression.
1675
1742
  """
1676
- fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
1743
+ fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
1677
1744
 
1678
- # Plot PPR
1745
+ # Plot both PPR measures
1679
1746
  if log_plot:
1680
- ax1.semilogx(results["frequencies"], results["ppr"], "o-")
1747
+ ax1.semilogx(results["frequencies"], results["ppr"], "o-", label="Normalized PPR")
1748
+ ax1.semilogx(results["frequencies"], results["simple_ppr"], "s-", label="Simple PPR")
1681
1749
  else:
1682
- ax1.plot(results["frequencies"], results["ppr"], "o-")
1750
+ ax1.plot(results["frequencies"], results["ppr"], "o-", label="Normalized PPR")
1751
+ ax1.plot(results["frequencies"], results["simple_ppr"], "s-", label="Simple PPR")
1683
1752
  ax1.axhline(y=1, color="gray", linestyle="--", alpha=0.5)
1684
1753
  ax1.set_xlabel("Frequency (Hz)")
1685
1754
  ax1.set_ylabel("Paired Pulse Ratio")
1686
1755
  ax1.set_title("PPR vs Frequency")
1756
+ ax1.legend()
1687
1757
  ax1.grid(True)
1688
1758
 
1689
1759
  # Plot Induction
@@ -1712,6 +1782,168 @@ class SynapseTuner:
1712
1782
  plt.show()
1713
1783
 
1714
1784
 
1785
+ def generate_synaptic_table(self, stp_frequency=50.0, stp_delay=250.0, plot=True):
1786
+ """
1787
+ Generate a comprehensive table of synaptic parameters for all connections.
1788
+
1789
+ This method iterates through all available connections, runs simulations to
1790
+ characterize each synapse, and compiles the results into a pandas DataFrame.
1791
+
1792
+ Parameters:
1793
+ -----------
1794
+ stp_frequency : float, optional
1795
+ Frequency in Hz to use for STP (short-term plasticity) analysis. Default is 50.0 Hz.
1796
+ stp_delay : float, optional
1797
+ Delay in ms between pulse trains for STP analysis. Default is 250.0 ms.
1798
+ plot : bool, optional
1799
+ Whether to display the resulting table. Default is True.
1800
+
1801
+ Returns:
1802
+ --------
1803
+ pd.DataFrame
1804
+ DataFrame containing synaptic parameters for each connection with columns:
1805
+ - connection: Connection name
1806
+ - rise_time: 20-80% rise time (ms)
1807
+ - decay_time: Decay time constant (ms)
1808
+ - latency: Response latency (ms)
1809
+ - half_width: Response half-width (ms)
1810
+ - peak_amplitude: Peak synaptic current amplitude (pA)
1811
+ - baseline: Baseline current (pA)
1812
+ - ppr: Paired-pulse ratio (normalized)
1813
+ - simple_ppr: Simple paired-pulse ratio (2nd/1st pulse)
1814
+ - induction: STP induction measure
1815
+ - recovery: STP recovery measure
1816
+
1817
+ Notes:
1818
+ ------
1819
+ This method temporarily switches between connections to characterize each one,
1820
+ then restores the original connection. The STP metrics are calculated at the
1821
+ specified frequency and delay.
1822
+ """
1823
+ # Store original connection to restore later
1824
+ original_connection = self.current_connection
1825
+
1826
+ # Initialize results list
1827
+ results = []
1828
+
1829
+ print(f"Analyzing {len(self.conn_type_settings)} connections...")
1830
+
1831
+ for conn_name in tqdm(self.conn_type_settings.keys(), desc="Analyzing connections"):
1832
+ try:
1833
+ # Switch to this connection
1834
+ self._switch_connection(conn_name)
1835
+
1836
+ # Run single event analysis
1837
+ self.SingleEvent(plot_and_print=False)
1838
+
1839
+ # Get synaptic properties from the single event
1840
+ syn_props = self._get_syn_prop()
1841
+
1842
+ # Run STP analysis at specified frequency
1843
+ stp_results = self.stp_frequency_response(
1844
+ freqs=[stp_frequency],
1845
+ delay=stp_delay,
1846
+ plot=False,
1847
+ log_plot=False
1848
+ )
1849
+
1850
+ # Extract STP metrics for this frequency
1851
+ freq_idx = 0 # Only one frequency tested
1852
+ ppr = stp_results['ppr'][freq_idx]
1853
+ induction = stp_results['induction'][freq_idx]
1854
+ recovery = stp_results['recovery'][freq_idx]
1855
+ simple_ppr = stp_results['simple_ppr'][freq_idx]
1856
+
1857
+ # Compile results for this connection
1858
+ conn_results = {
1859
+ 'connection': conn_name,
1860
+ 'rise_time': float(self.rise_time),
1861
+ 'decay_time': float(self.decay_time),
1862
+ 'latency': float(syn_props.get('latency', 0)),
1863
+ 'half_width': float(syn_props.get('half_width', 0)),
1864
+ 'peak_amplitude': float(syn_props.get('amp', 0)),
1865
+ 'baseline': float(syn_props.get('baseline', 0)),
1866
+ 'ppr': float(ppr),
1867
+ 'simple_ppr': float(simple_ppr),
1868
+ 'induction': float(induction),
1869
+ 'recovery': float(recovery)
1870
+ }
1871
+
1872
+ results.append(conn_results)
1873
+
1874
+ except Exception as e:
1875
+ print(f"Warning: Failed to analyze connection '{conn_name}': {e}")
1876
+ # Add partial results if possible
1877
+ results.append({
1878
+ 'connection': conn_name,
1879
+ 'rise_time': float('nan'),
1880
+ 'decay_time': float('nan'),
1881
+ 'latency': float('nan'),
1882
+ 'half_width': float('nan'),
1883
+ 'peak_amplitude': float('nan'),
1884
+ 'baseline': float('nan'),
1885
+ 'ppr': float('nan'),
1886
+ 'simple_ppr': float('nan'),
1887
+ 'induction': float('nan'),
1888
+ 'recovery': float('nan')
1889
+ })
1890
+
1891
+ # Restore original connection
1892
+ if original_connection in self.conn_type_settings:
1893
+ self._switch_connection(original_connection)
1894
+
1895
+ # Create DataFrame
1896
+ df = pd.DataFrame(results)
1897
+
1898
+ # Set connection as index for better display
1899
+ df = df.set_index('connection')
1900
+
1901
+ if plot:
1902
+ # Display the table
1903
+ print("\nSynaptic Parameters Table:")
1904
+ print("=" * 80)
1905
+ display(df.round(4))
1906
+
1907
+ # Optional: Create a simple bar plot for key metrics
1908
+ try:
1909
+ fig, axes = plt.subplots(2, 2, figsize=(15, 10))
1910
+ fig.suptitle(f'Synaptic Parameters Across Connections (STP at {stp_frequency}Hz)', fontsize=16)
1911
+
1912
+ # Plot rise/decay times
1913
+ df[['rise_time', 'decay_time']].plot(kind='bar', ax=axes[0,0])
1914
+ axes[0,0].set_title('Rise and Decay Times')
1915
+ axes[0,0].set_ylabel('Time (ms)')
1916
+ axes[0,0].tick_params(axis='x', rotation=45)
1917
+
1918
+ # Plot PPR metrics
1919
+ df[['ppr', 'simple_ppr']].plot(kind='bar', ax=axes[0,1])
1920
+ axes[0,1].set_title('Paired-Pulse Ratios')
1921
+ axes[0,1].axhline(y=1, color='gray', linestyle='--', alpha=0.5)
1922
+ axes[0,1].tick_params(axis='x', rotation=45)
1923
+
1924
+ # Plot induction
1925
+ df['induction'].plot(kind='bar', ax=axes[1,0], color='green')
1926
+ axes[1,0].set_title('STP Induction')
1927
+ axes[1,0].axhline(y=0, color='gray', linestyle='--', alpha=0.5)
1928
+ axes[1,0].set_ylabel('Induction')
1929
+ axes[1,0].tick_params(axis='x', rotation=45)
1930
+
1931
+ # Plot recovery
1932
+ df['recovery'].plot(kind='bar', ax=axes[1,1], color='orange')
1933
+ axes[1,1].set_title('STP Recovery')
1934
+ axes[1,1].axhline(y=0, color='gray', linestyle='--', alpha=0.5)
1935
+ axes[1,1].set_ylabel('Recovery')
1936
+ axes[1,1].tick_params(axis='x', rotation=45)
1937
+
1938
+ plt.tight_layout()
1939
+ plt.show()
1940
+
1941
+ except Exception as e:
1942
+ print(f"Warning: Could not create plots: {e}")
1943
+
1944
+ return df
1945
+
1946
+
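A usage sketch for the new method, assuming tuner is an already-configured SynapseTuner (the output path is illustrative):

df = tuner.generate_synaptic_table(stp_frequency=50.0, stp_delay=250.0, plot=False)
print(df[["ppr", "simple_ppr", "induction", "recovery"]].round(3))
df.to_csv("synaptic_parameters.csv")  # persist the table if desired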
1715
1947
  class GapJunctionTuner:
1716
1948
  def __init__(
1717
1949
  self,
@@ -1762,83 +1994,398 @@ class GapJunctionTuner:
1762
1994
  self.general_settings = {**DEFAULT_GAP_JUNCTION_GENERAL_SETTINGS, **general_settings}
1763
1995
  self.conn_type_settings = conn_type_settings
1764
1996
 
1997
+ self._syn_params_cache = {}
1998
+ self.config = config
1999
+ self.available_networks = []
2000
+ self.current_network = None
2001
+ self.last_figure = None
2002
+ if self.conn_type_settings is None and self.config is not None:
2003
+ self.conn_type_settings = self._build_conn_type_settings_from_config(self.config)
2004
+ if self.conn_type_settings is None or len(self.conn_type_settings) == 0:
2005
+ raise ValueError("conn_type_settings must be provided or config must be given to load gap junction connections from")
2006
+ self.current_connection = list(self.conn_type_settings.keys())[0]
2007
+ self.conn = self.conn_type_settings[self.current_connection]
2008
+
1765
2009
  h.tstop = self.general_settings["tstart"] + self.general_settings["tdur"] + 100.0
1766
2010
  h.dt = self.general_settings["dt"] # Time step (resolution) of the simulation in ms
1767
2011
  h.steps_per_ms = 1 / h.dt
1768
2012
  h.celsius = self.general_settings["celsius"]
1769
2013
 
2014
+ # Clean up any existing parallel context before setting up gap junctions
2015
+ try:
2016
+ pc_temp = h.ParallelContext()
2017
+ pc_temp.done() # Clean up any existing parallel context
2018
+ except:
2019
+ pass # Ignore errors if no existing context
2020
+
2021
+ # Force cleanup
2022
+ import gc
2023
+ gc.collect()
2024
+
1770
2025
  # set up gap junctions
1771
- pc = h.ParallelContext()
2026
+ self.pc = h.ParallelContext()
1772
2027
 
1773
2028
  # Use provided hoc_cell or create new cells
1774
2029
  if self.hoc_cell is not None:
1775
2030
  self.cell1 = self.hoc_cell
1776
2031
  # For gap junctions, we need two cells, so create a second one if using hoc_cell
1777
- self.cell_name = conn_type_settings["cell"]
2032
+ self.cell_name = self.conn['cell']
1778
2033
  self.cell2 = getattr(h, self.cell_name)()
1779
2034
  else:
1780
- self.cell_name = conn_type_settings["cell"]
2035
+ print(self.conn)
2036
+ self.cell_name = self.conn['cell']
1781
2037
  self.cell1 = getattr(h, self.cell_name)()
1782
2038
  self.cell2 = getattr(h, self.cell_name)()
1783
2039
 
1784
2040
  self.icl = h.IClamp(self.cell1.soma[0](0.5))
1785
2041
  self.icl.delay = self.general_settings["tstart"]
1786
2042
  self.icl.dur = self.general_settings["tdur"]
1787
- self.icl.amp = self.conn_type_settings["iclamp_amp"] # nA
2043
+ self.icl.amp = self.general_settings["iclamp_amp"] # nA
1788
2044
 
1789
- sec1 = list(self.cell1.all)[conn_type_settings["sec_id"]]
1790
- sec2 = list(self.cell2.all)[conn_type_settings["sec_id"]]
2045
+ sec1 = list(self.cell1.all)[self.conn["sec_id"]]
2046
+ sec2 = list(self.cell2.all)[self.conn["sec_id"]]
1791
2047
 
1792
- pc.source_var(sec1(conn_type_settings["sec_x"])._ref_v, 0, sec=sec1)
2048
+ # Use unique IDs to avoid conflicts with existing parallel context setups
2049
+ import time
2050
+ unique_id = int(time.time() * 1000) % 10000 # Use timestamp as unique base ID
2051
+
2052
+ self.pc.source_var(sec1(self.conn["sec_x"])._ref_v, unique_id, sec=sec1)
1793
2053
  self.gap_junc_1 = h.Gap(sec1(0.5))
1794
- pc.target_var(self.gap_junc_1._ref_vgap, 1)
2054
+ self.pc.target_var(self.gap_junc_1._ref_vgap, unique_id + 1)
1795
2055
 
1796
- pc.source_var(sec2(conn_type_settings["sec_x"])._ref_v, 1, sec=sec2)
2056
+ self.pc.source_var(sec2(self.conn["sec_x"])._ref_v, unique_id + 1, sec=sec2)
1797
2057
  self.gap_junc_2 = h.Gap(sec2(0.5))
1798
- pc.target_var(self.gap_junc_2._ref_vgap, 0)
2058
+ self.pc.target_var(self.gap_junc_2._ref_vgap, unique_id)
1799
2059
 
1800
- pc.setup_transfer()
1801
-
1802
- def model(self, resistance):
1803
- """
1804
- Run a simulation with a specified gap junction resistance.
2060
+ self.pc.setup_transfer()
2061
+
2062
+ # Now it's safe to initialize NEURON
2063
+ h.finitialize()
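The block above couples the two cells through NEURON's ParallelContext voltage transfer: each cell's membrane potential is published with source_var under a unique ID and read into the Gap point process on the other cell with target_var. Stripped of the class bookkeeping, the wiring pattern looks roughly like this (a sketch assuming a Gap mechanism with a vgap POINTER is compiled and cell_a/cell_b are instantiated template cells):

from neuron import h

pc = h.ParallelContext()
sid = 1001                                        # any pair of unused transfer IDs
sec_a, sec_b = cell_a.soma[0], cell_b.soma[0]     # sections to couple (assumed)

pc.source_var(sec_a(0.5)._ref_v, sid, sec=sec_a)  # publish V of cell A under sid
gap_a = h.Gap(sec_a(0.5))
pc.target_var(gap_a._ref_vgap, sid + 1)           # gap on A follows V of cell B

pc.source_var(sec_b(0.5)._ref_v, sid + 1, sec=sec_b)
gap_b = h.Gap(sec_b(0.5))
pc.target_var(gap_b._ref_vgap, sid)               # gap on B follows V of cell A

pc.setup_transfer()
h.finitialize()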
1805
2064
 
1806
- Parameters:
1807
- -----------
1808
- resistance : float
1809
- The gap junction resistance value (in MOhm) to use for the simulation.
2065
+ def _load_synaptic_params_from_config(self, config: dict, dynamics_params: str) -> dict:
2066
+ try:
2067
+ # Get the synaptic models directory from config
2068
+ synaptic_models_dir = config.get('components', {}).get('synaptic_models_dir', '')
2069
+ if synaptic_models_dir:
2070
+ # Handle path variables
2071
+ if synaptic_models_dir.startswith('$'):
2072
+ # This is a placeholder, try to resolve it
2073
+ config_dir = os.path.dirname(config.get('config_path', ''))
2074
+ synaptic_models_dir = synaptic_models_dir.replace('$COMPONENTS_DIR',
2075
+ os.path.join(config_dir, 'components'))
2076
+ synaptic_models_dir = synaptic_models_dir.replace('$BASE_DIR', config_dir)
2077
+
2078
+ dynamics_file = os.path.join(synaptic_models_dir, dynamics_params)
2079
+
2080
+ if os.path.exists(dynamics_file):
2081
+ with open(dynamics_file, 'r') as f:
2082
+ return json.load(f)
2083
+ else:
2084
+ print(f"Warning: Dynamics params file not found: {dynamics_file}")
2085
+ except Exception as e:
2086
+ print(f"Warning: Error loading synaptic parameters: {e}")
2087
+
2088
+ return {}
1810
2089
 
1811
- Notes:
1812
- ------
1813
- This method sets up the gap junction resistance, initializes recording vectors for time
1814
- and membrane voltages of both cells, and runs the NEURON simulation.
2090
+ def _load_available_networks(self) -> None:
1815
2091
  """
1816
- self.gap_junc_1.g = resistance
1817
- self.gap_junc_2.g = resistance
1818
-
1819
- t_vec = h.Vector()
1820
- soma_v_1 = h.Vector()
1821
- soma_v_2 = h.Vector()
1822
- t_vec.record(h._ref_t)
1823
- soma_v_1.record(self.cell1.soma[0](0.5)._ref_v)
1824
- soma_v_2.record(self.cell2.soma[0](0.5)._ref_v)
1825
-
1826
- self.t_vec = t_vec
1827
- self.soma_v_1 = soma_v_1
1828
- self.soma_v_2 = soma_v_2
1829
-
1830
- h.finitialize(-70 * mV)
1831
- h.continuerun(h.tstop * ms)
1832
-
1833
- def plot_model(self):
2092
+ Load available network names from the config file for the network dropdown feature.
2093
+
2094
+ This method is automatically called during initialization when a config file is provided.
2095
+ It populates the available_networks list which enables the network dropdown in
2096
+ InteractiveTuner when multiple networks are available.
2097
+
2098
+ Network Dropdown Behavior:
2099
+ -------------------------
2100
+ - If only one network exists: No network dropdown is shown
2101
+ - If multiple networks exist: Network dropdown appears next to connection dropdown
2102
+ - Networks are loaded from the edges data in the config file
2103
+ - Current network defaults to the first available if not specified during init
1834
2104
  """
1835
- Plot the voltage traces of both cells to visualize gap junction coupling.
2105
+ if self.config is None:
2106
+ self.available_networks = []
2107
+ return
2108
+
2109
+ try:
2110
+ edges = load_edges_from_config(self.config)
2111
+ self.available_networks = list(edges.keys())
2112
+
2113
+ # Set current network to first available if not specified
2114
+ if self.current_network is None and self.available_networks:
2115
+ self.current_network = self.available_networks[0]
2116
+ except Exception as e:
2117
+ print(f"Warning: Could not load networks from config: {e}")
2118
+ self.available_networks = []
1836
2119
 
1837
- This method creates a plot showing the membrane potential of both cells over time,
1838
- highlighting the effect of gap junction coupling when a current step is applied to cell 1.
2120
+ def _build_conn_type_settings_from_config(self, config_path: str) -> Dict[str, dict]:
2121
+ # Load configuration and get nodes and edges using util.py methods
2122
+ config = load_config(config_path)
2123
+ # Ensure the config dict knows its source path so path substitutions can be resolved
2124
+ try:
2125
+ config['config_path'] = config_path
2126
+ except Exception:
2127
+ pass
2128
+ nodes = load_nodes_from_config(config_path)
2129
+ edges = load_edges_from_config(config_path)
2130
+
2131
+ conn_type_settings = {}
2132
+
2133
+ # Process all edge datasets
2134
+ for edge_dataset_name, edge_df in edges.items():
2135
+ if edge_df.empty:
2136
+ continue
2137
+
2138
+ # Merging with node data to get model templates
2139
+ source_node_df = None
2140
+ target_node_df = None
2141
+
2142
+ # First, try to deterministically parse the edge_dataset_name for patterns like '<src>_to_<tgt>'
2143
+ if '_to_' in edge_dataset_name:
2144
+ parts = edge_dataset_name.split('_to_')
2145
+ if len(parts) == 2:
2146
+ src_name, tgt_name = parts
2147
+ if src_name in nodes:
2148
+ source_node_df = nodes[src_name].add_prefix('source_')
2149
+ if tgt_name in nodes:
2150
+ target_node_df = nodes[tgt_name].add_prefix('target_')
2151
+
2152
+ # If not found by parsing name, fall back to inspecting a sample edge row
2153
+ if source_node_df is None or target_node_df is None:
2154
+ sample_edge = edge_df.iloc[0] if len(edge_df) > 0 else None
2155
+ if sample_edge is not None:
2156
+ source_pop_name = sample_edge.get('source_population', '')
2157
+ target_pop_name = sample_edge.get('target_population', '')
2158
+ if source_pop_name in nodes:
2159
+ source_node_df = nodes[source_pop_name].add_prefix('source_')
2160
+ if target_pop_name in nodes:
2161
+ target_node_df = nodes[target_pop_name].add_prefix('target_')
2162
+
2163
+ # As a last resort, attempt to heuristically match
2164
+ if source_node_df is None or target_node_df is None:
2165
+ for pop_name, node_df in nodes.items():
2166
+ if source_node_df is None and (edge_dataset_name.startswith(pop_name) or edge_dataset_name.endswith(pop_name)):
2167
+ source_node_df = node_df.add_prefix('source_')
2168
+ if target_node_df is None and (edge_dataset_name.startswith(pop_name) or edge_dataset_name.endswith(pop_name)):
2169
+ target_node_df = node_df.add_prefix('target_')
2170
+
2171
+ if source_node_df is None or target_node_df is None:
2172
+ print(f"Warning: Could not find node data for edge dataset {edge_dataset_name}")
2173
+ continue
2174
+
2175
+ # Merge edge data with source node info
2176
+ edges_with_source = pd.merge(
2177
+ edge_df.reset_index(),
2178
+ source_node_df,
2179
+ how='left',
2180
+ left_on='source_node_id',
2181
+ right_index=True
2182
+ )
2183
+
2184
+ # Merge with target node info
2185
+ edges_with_nodes = pd.merge(
2186
+ edges_with_source,
2187
+ target_node_df,
2188
+ how='left',
2189
+ left_on='target_node_id',
2190
+ right_index=True
2191
+ )
2192
+
2193
+ # Skip edge datasets that don't have gap junction information
2194
+ if 'is_gap_junction' not in edges_with_nodes.columns:
2195
+ continue
2196
+
2197
+ # Filter to only gap junction edges
2198
+ # Handle NaN values in is_gap_junction column
2199
+ gap_junction_mask = edges_with_nodes['is_gap_junction'].fillna(False) == True
2200
+ gap_junction_edges = edges_with_nodes[gap_junction_mask]
2201
+ if gap_junction_edges.empty:
2202
+ continue
2203
+
2204
+ # Get unique edge types from the gap junction edges
2205
+ if 'edge_type_id' in gap_junction_edges.columns:
2206
+ edge_types = gap_junction_edges['edge_type_id'].unique()
2207
+ else:
2208
+ edge_types = [None] # Single edge type
2209
+
2210
+ # Process each edge type
2211
+ for edge_type_id in edge_types:
2212
+ # Filter edges for this type
2213
+ if edge_type_id is not None:
2214
+ edge_type_data = gap_junction_edges[gap_junction_edges['edge_type_id'] == edge_type_id]
2215
+ else:
2216
+ edge_type_data = gap_junction_edges
2217
+
2218
+ if len(edge_type_data) == 0:
2219
+ continue
2220
+
2221
+ # Get representative edge for this type
2222
+ edge_info = edge_type_data.iloc[0]
2223
+
2224
+ # Process gap junction
2225
+ source_model_template = edge_info.get('source_model_template', '')
2226
+ target_model_template = edge_info.get('target_model_template', '')
2227
+
2228
+ source_cell_type = source_model_template.replace('hoc:', '') if source_model_template.startswith('hoc:') else source_model_template
2229
+ target_cell_type = target_model_template.replace('hoc:', '') if target_model_template.startswith('hoc:') else target_model_template
2230
+
2231
+ if source_cell_type != target_cell_type:
2232
+ continue # Only process gap junctions between same cell types
2233
+
2234
+ source_pop = edge_info.get('source_pop_name', '')
2235
+ target_pop = edge_info.get('target_pop_name', '')
2236
+
2237
+ conn_name = f"{source_pop}2{target_pop}_gj"
2238
+ if edge_type_id is not None:
2239
+ conn_name += f"_type_{edge_type_id}"
2240
+
2241
+ conn_settings = {
2242
+ 'cell': source_cell_type,
2243
+ 'sec_id': 0,
2244
+ 'sec_x': 0.5,
2245
+ 'iclamp_amp': -0.01,
2246
+ 'spec_syn_param': {}
2247
+ }
2248
+
2249
+ # Load dynamics params
2250
+ dynamics_file_name = edge_info.get('dynamics_params', '')
2251
+ if dynamics_file_name and dynamics_file_name.upper() != 'NULL':
2252
+ try:
2253
+ syn_params = self._load_synaptic_params_from_config(config, dynamics_file_name)
2254
+ conn_settings['spec_syn_param'] = syn_params
2255
+ except Exception as e:
2256
+ print(f"Warning: could not load dynamics_params file '{dynamics_file_name}': {e}")
2257
+
2258
+ conn_type_settings[conn_name] = conn_settings
2259
+
2260
+ return conn_type_settings
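Each entry produced by _build_conn_type_settings_from_config has the shape below; the connection and cell names here are hypothetical, following the f"{source_pop}2{target_pop}_gj" naming used above:

conn_type_settings = {
    "FSI2FSI_gj": {
        "cell": "FSI_Cell",     # hoc template name derived from source_model_template
        "sec_id": 0,
        "sec_x": 0.5,
        "iclamp_amp": -0.01,    # nA
        "spec_syn_param": {},   # contents of the dynamics_params JSON, when one is found
    }
}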
2261
+
2262
+ def _switch_connection(self, new_connection: str) -> None:
1839
2263
  """
1840
- t_range = [
1841
- self.general_settings["tstart"] - 100.0,
2264
+ Switch to a different gap junction connection and update all related properties.
2265
+
2266
+ Parameters:
2267
+ -----------
2268
+ new_connection : str
2269
+ Name of the new connection type to switch to.
2270
+ """
2271
+ if new_connection not in self.conn_type_settings:
2272
+ raise ValueError(f"Connection '{new_connection}' not found in conn_type_settings")
2273
+
2274
+ # Update current connection
2275
+ self.current_connection = new_connection
2276
+ self.conn = self.conn_type_settings[new_connection]
2277
+
2278
+ # Check if cell type changed
2279
+ new_cell_name = self.conn['cell']
2280
+ if self.cell_name != new_cell_name:
2281
+ self.cell_name = new_cell_name
2282
+
2283
+ # Recreate cells
2284
+ if self.hoc_cell is None:
2285
+ self.cell1 = getattr(h, self.cell_name)()
2286
+ self.cell2 = getattr(h, self.cell_name)()
2287
+ else:
2288
+ # For hoc_cell, recreate the second cell
2289
+ self.cell2 = getattr(h, self.cell_name)()
2290
+
2291
+ # Recreate IClamp
2292
+ self.icl = h.IClamp(self.cell1.soma[0](0.5))
2293
+ self.icl.delay = self.general_settings["tstart"]
2294
+ self.icl.dur = self.general_settings["tdur"]
2295
+ self.icl.amp = self.general_settings["iclamp_amp"]
2296
+ else:
2297
+ # Update IClamp parameters even if same cell type
2298
+ self.icl.amp = self.general_settings["iclamp_amp"]
2299
+
2300
+ # Always recreate gap junctions when switching connections
2301
+ # (even for same cell type, sec_id or sec_x might differ)
2302
+
2303
+ # Clean up previous gap junctions and parallel context
2304
+ if hasattr(self, 'gap_junc_1'):
2305
+ del self.gap_junc_1
2306
+ if hasattr(self, 'gap_junc_2'):
2307
+ del self.gap_junc_2
2308
+
2309
+ # Properly clean up the existing parallel context
2310
+ if hasattr(self, 'pc'):
2311
+ self.pc.done() # Clean up existing parallel context
2312
+
2313
+ # Force garbage collection and reset NEURON state
2314
+ import gc
2315
+ gc.collect()
2316
+ h.finitialize()
2317
+
2318
+ # Create a fresh parallel context after cleanup
2319
+ self.pc = h.ParallelContext()
2320
+
2321
+ try:
2322
+ sec1 = list(self.cell1.all)[self.conn["sec_id"]]
2323
+ sec2 = list(self.cell2.all)[self.conn["sec_id"]]
2324
+
2325
+ # Use unique IDs to avoid conflicts with existing parallel context setups
2326
+ import time
2327
+ unique_id = int(time.time() * 1000) % 10000 # Use timestamp as unique base ID
2328
+
2329
+ self.pc.source_var(sec1(self.conn["sec_x"])._ref_v, unique_id, sec=sec1)
2330
+ self.gap_junc_1 = h.Gap(sec1(0.5))
2331
+ self.pc.target_var(self.gap_junc_1._ref_vgap, unique_id + 1)
2332
+
2333
+ self.pc.source_var(sec2(self.conn["sec_x"])._ref_v, unique_id + 1, sec=sec2)
2334
+ self.gap_junc_2 = h.Gap(sec2(0.5))
2335
+ self.pc.target_var(self.gap_junc_2._ref_vgap, unique_id)
2336
+
2337
+ self.pc.setup_transfer()
2338
+ except Exception as e:
2339
+ print(f"Error setting up gap junctions: {e}")
2340
+ # Try to continue with basic setup
2341
+ self.gap_junc_1 = h.Gap(list(self.cell1.all)[self.conn["sec_id"]](0.5))
2342
+ self.gap_junc_2 = h.Gap(list(self.cell2.all)[self.conn["sec_id"]](0.5))
2343
+
2344
+ # Reset NEURON state after complete setup
2345
+ h.finitialize()
2346
+
2347
+ print(f"Successfully switched to connection: {new_connection}")
2348
+
2349
+ def model(self, resistance):
2350
+ """
2351
+ Run a simulation with a specified gap junction resistance.
2352
+
2353
+ Parameters:
2354
+ -----------
2355
+ resistance : float
2356
+ The gap junction resistance value (in MOhm) to use for the simulation.
2357
+
2358
+ Notes:
2359
+ ------
2360
+ This method sets up the gap junction resistance, initializes recording vectors for time
2361
+ and membrane voltages of both cells, and runs the NEURON simulation.
2362
+ """
2363
+ self.gap_junc_1.g = resistance
2364
+ self.gap_junc_2.g = resistance
2365
+
2366
+ t_vec = h.Vector()
2367
+ soma_v_1 = h.Vector()
2368
+ soma_v_2 = h.Vector()
2369
+ t_vec.record(h._ref_t)
2370
+ soma_v_1.record(self.cell1.soma[0](0.5)._ref_v)
2371
+ soma_v_2.record(self.cell2.soma[0](0.5)._ref_v)
2372
+
2373
+ self.t_vec = t_vec
2374
+ self.soma_v_1 = soma_v_1
2375
+ self.soma_v_2 = soma_v_2
2376
+
2377
+ h.finitialize(-70 * mV)
2378
+ h.continuerun(h.tstop * ms)
2379
+
2380
+ def plot_model(self):
2381
+ """
2382
+ Plot the voltage traces of both cells to visualize gap junction coupling.
2383
+
2384
+ This method creates a plot showing the membrane potential of both cells over time,
2385
+ highlighting the effect of gap junction coupling when a current step is applied to cell 1.
2386
+ """
2387
+ t_range = [
2388
+ self.general_settings["tstart"] - 100.0,
1842
2389
  self.general_settings["tstart"] + self.general_settings["tdur"] + 100.0,
1843
2390
  ]
1844
2391
  t = np.array(self.t_vec)
@@ -1853,6 +2400,7 @@ class GapJunctionTuner:
1853
2400
  plt.xlabel("Time (ms)")
1854
2401
  plt.ylabel("Membrane Voltage (mV)")
1855
2402
  plt.legend()
2403
+ self.last_figure = plt.gcf()
1856
2404
 
1857
2405
  def coupling_coefficient(self, t, v1, v2, t_start, t_end, dt=h.dt):
1858
2406
  """
@@ -1887,648 +2435,1133 @@ class GapJunctionTuner:
1887
2435
  return (v2[idx2] - v2[idx1]) / (v1[idx2] - v1[idx1])
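coupling_coefficient returns the ratio of the two cells' voltage deflections over the chosen window, i.e. cc = (V2(t_end) - V2(t_start)) / (V1(t_end) - V1(t_start)). For example, if the current step hyperpolarizes cell 1 by 10 mV and cell 2 by 0.5 mV between 500 and 1000 ms, cc = 0.5 / 10 = 0.05. A usage sketch, assuming model() has already been run on a GapJunctionTuner instance gj_tuner:

import numpy as np

t = np.array(gj_tuner.t_vec)
v1 = np.array(gj_tuner.soma_v_1)
v2 = np.array(gj_tuner.soma_v_2)
cc = gj_tuner.coupling_coefficient(t, v1, v2, t_start=500, t_end=1000)
print(f"coupling coefficient: {cc:.4f}")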
1888
2436
 
1889
2437
  def InteractiveTuner(self):
1890
- w_run = widgets.Button(description="Run", icon="history", button_style="primary")
1891
- values = [i * 10**-4 for i in range(1, 1001)] # From 1e-4 to 1e-1
1892
-
1893
- # Create the SelectionSlider widget with appropriate formatting
1894
- resistance = widgets.FloatLogSlider(
1895
- value=0.001,
1896
- base=10,
1897
- min=-4, # max exponent of base
1898
- max=-1, # min exponent of base
1899
- step=0.1, # exponent step
1900
- description="Resistance: ",
1901
- continuous_update=True,
1902
- )
1903
-
1904
- ui = VBox([w_run, resistance])
1905
-
1906
- # Create an output widget to control what gets cleared
1907
- output = widgets.Output()
1908
-
1909
- display(ui)
1910
- display(output)
1911
-
1912
- def on_button(*args):
1913
- with output:
1914
- # Clear only the output widget, not the entire cell
1915
- output.clear_output(wait=True)
1916
-
1917
- resistance_for_gap = resistance.value
1918
- print(f"Running simulation with resistance: {resistance_for_gap}")
1919
-
1920
- try:
1921
- self.model(resistance_for_gap)
1922
- self.plot_model()
1923
-
1924
- # Convert NEURON vectors to numpy arrays
1925
- t_array = np.array(self.t_vec)
1926
- v1_array = np.array(self.soma_v_1)
1927
- v2_array = np.array(self.soma_v_2)
1928
-
1929
- cc = self.coupling_coefficient(t_array, v1_array, v2_array, 500, 1000)
1930
- print(f"coupling_coefficient is {cc:0.4f}")
1931
- plt.show()
1932
-
1933
- except Exception as e:
1934
- print(f"Error during simulation or analysis: {e}")
1935
- import traceback
1936
-
1937
- traceback.print_exc()
1938
-
1939
- # Run once initially
1940
- on_button()
1941
- w_run.on_click(on_button)
1942
-
1943
-
1944
- # optimizers!
1945
-
1946
-
1947
- @dataclass
1948
- class SynapseOptimizationResult:
1949
- """Container for synaptic parameter optimization results"""
1950
-
1951
- optimal_params: Dict[str, float]
1952
- achieved_metrics: Dict[str, float]
1953
- target_metrics: Dict[str, float]
1954
- error: float
1955
- optimization_path: List[Dict[str, float]]
1956
-
1957
-
1958
- class SynapseOptimizer:
1959
- def __init__(self, tuner):
1960
2438
  """
1961
- Initialize the synapse optimizer with parameter scaling
1962
-
1963
- Parameters:
1964
- -----------
1965
- tuner : SynapseTuner
1966
- Instance of the SynapseTuner class
1967
- """
1968
- self.tuner = tuner
1969
- self.optimization_history = []
1970
- self.param_scales = {}
2439
+ Sets up interactive sliders for tuning short-term plasticity (STP) parameters in a Jupyter Notebook.
1971
2440
 
1972
- def _normalize_params(self, params: np.ndarray, param_names: List[str]) -> np.ndarray:
1973
- """
1974
- Normalize parameters to similar scales for better optimization performance.
2441
+ This method creates an interactive UI with sliders for:
2442
+ - Network selection dropdown (if multiple networks available and config provided)
2443
+ - Connection type selection dropdown
2444
+ - Input frequency
2445
+ - Delay between pulse trains
2446
+ - Duration of stimulation (for continuous input mode)
2447
+ - Synaptic parameters (e.g., Use, tau_f, tau_d) based on the syn model
1975
2448
 
1976
- Parameters:
1977
- -----------
1978
- params : np.ndarray
1979
- Original parameter values.
1980
- param_names : List[str]
1981
- Names of the parameters corresponding to the values.
2449
+ It also provides buttons for:
2450
+ - Running a single event simulation
2451
+ - Running a train input simulation
2452
+ - Toggling voltage clamp mode
2453
+ - Switching between standard and continuous input modes
1982
2454
 
1983
- Returns:
1984
- --------
1985
- np.ndarray
1986
- Normalized parameter values.
1987
- """
1988
- return np.array([params[i] / self.param_scales[name] for i, name in enumerate(param_names)])
2455
+ Network Dropdown Feature:
2456
+ ------------------------
2457
+ When the SynapseTuner is initialized with a BMTK config file containing multiple networks:
2458
+ - A network dropdown appears next to the connection dropdown
2459
+ - Users can dynamically switch between networks (e.g., 'network_to_network', 'external_to_network')
2460
+ - Switching networks rebuilds available connections and updates the connection dropdown
2461
+ - The current connection is preserved if it exists in the new network
2462
+ - If multiple networks exist but only one is specified during init, that network is used as default
1989
2463
 
1990
- def _denormalize_params(
1991
- self, normalized_params: np.ndarray, param_names: List[str]
1992
- ) -> np.ndarray:
2464
+ Notes:
2465
+ ------
2466
+ Ideal for exploratory parameter tuning and interactive visualization of
2467
+ synapse behavior with different parameter values and stimulation protocols.
2468
+ The network dropdown feature enables comprehensive exploration of multi-network
2469
+ BMTK simulations without needing to reinitialize the tuner.
1993
2470
  """
1994
- Convert normalized parameters back to original scale.
1995
-
1996
- Parameters:
1997
- -----------
1998
- normalized_params : np.ndarray
1999
- Normalized parameter values.
2000
- param_names : List[str]
2001
- Names of the parameters corresponding to the normalized values.
2471
+ # Widgets setup (Sliders)
2472
+ freqs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 35, 50, 100, 200]
2473
+ delays = [125, 250, 500, 1000, 2000, 4000]
2474
+ durations = [100, 300, 500, 1000, 2000, 5000, 10000]
2475
+ freq0 = 50
2476
+ delay0 = 250
2477
+ duration0 = 300
2478
+ vlamp_status = self.vclamp
2002
2479
 
2003
- Returns:
2004
- --------
2005
- np.ndarray
2006
- Denormalized parameter values in their original scale.
2007
- """
2008
- return np.array(
2009
- [normalized_params[i] * self.param_scales[name] for i, name in enumerate(param_names)]
2480
+ # Connection dropdown
2481
+ connection_options = sorted(list(self.conn_type_settings.keys()))
2482
+ w_connection = widgets.Dropdown(
2483
+ options=connection_options,
2484
+ value=self.current_connection,
2485
+ description="Connection:",
2486
+ style={'description_width': 'initial'}
2010
2487
  )
2011
2488
 
2012
- def _calculate_metrics(self) -> Dict[str, float]:
2013
- """
2014
- Calculate standard metrics from the current simulation.
2015
-
2016
- This method runs either a single event simulation, a train input simulation,
2017
- or both based on configuration flags, and calculates relevant synaptic metrics.
2018
-
2019
- Returns:
2020
- --------
2021
- Dict[str, float]
2022
- Dictionary of calculated metrics including:
2023
- - induction: measure of synaptic facilitation/depression
2024
- - ppr: paired-pulse ratio
2025
- - recovery: recovery from facilitation/depression
2026
- - max_amplitude: maximum synaptic response amplitude
2027
- - rise_time: time for synaptic response to rise from 20% to 80% of peak
2028
- - decay_time: time constant of synaptic response decay
2029
- - latency: synaptic response latency
2030
- - half_width: synaptic response half-width
2031
- - baseline: baseline current
2032
- - amp: peak amplitude from syn_props
2033
- """
2034
- # Set these to 0 for when we return the dict
2035
- induction = 0
2036
- ppr = 0
2037
- recovery = 0
2038
- amp = 0
2039
- rise_time = 0
2040
- decay_time = 0
2041
- latency = 0
2042
- half_width = 0
2043
- baseline = 0
2044
- syn_amp = 0
2045
-
2046
- if self.run_single_event:
2047
- self.tuner.SingleEvent(plot_and_print=False)
2048
- # Use the attributes set by SingleEvent method
2049
- rise_time = getattr(self.tuner, "rise_time", 0)
2050
- decay_time = getattr(self.tuner, "decay_time", 0)
2051
- # Get additional syn_props directly
2052
- syn_props = self.tuner._get_syn_prop()
2053
- latency = syn_props.get("latency", 0)
2054
- half_width = syn_props.get("half_width", 0)
2055
- baseline = syn_props.get("baseline", 0)
2056
- syn_amp = syn_props.get("amp", 0)
2057
-
2058
- if self.run_train_input:
2059
- self.tuner._simulate_model(self.train_frequency, self.train_delay)
2060
- amp = self.tuner._response_amplitude()
2061
- ppr, induction, recovery = self.tuner._calc_ppr_induction_recovery(
2062
- amp, print_math=False
2489
+ # Network dropdown - only shown if config was provided and multiple networks are available
2490
+ # This enables users to switch between different network datasets dynamically
2491
+ w_network = None
2492
+ if self.config is not None and len(self.available_networks) > 1:
2493
+ w_network = widgets.Dropdown(
2494
+ options=self.available_networks,
2495
+ value=self.current_network,
2496
+ description="Network:",
2497
+ style={'description_width': 'initial'}
2063
2498
  )
2064
- amp = self.tuner._find_max_amp(amp)
2065
-
2066
- return {
2067
- "induction": float(induction),
2068
- "ppr": float(ppr),
2069
- "recovery": float(recovery),
2070
- "max_amplitude": float(amp),
2071
- "rise_time": float(rise_time),
2072
- "decay_time": float(decay_time),
2073
- "latency": float(latency),
2074
- "half_width": float(half_width),
2075
- "baseline": float(baseline),
2076
- "amp": float(syn_amp),
2077
- }
2078
-
2079
- def _default_cost_function(
2080
- self, metrics: Dict[str, float], target_metrics: Dict[str, float]
2081
- ) -> float:
2082
- """
2083
- Default cost function that minimizes the squared difference between achieved and target induction.
2084
-
2085
- Parameters:
2086
- -----------
2087
- metrics : Dict[str, float]
2088
- Dictionary of calculated metrics from the current simulation.
2089
- target_metrics : Dict[str, float]
2090
- Dictionary of target metrics to optimize towards.
2091
-
2092
- Returns:
2093
- --------
2094
- float
2095
- The squared error between achieved and target induction.
2096
- """
2097
- return float((metrics["induction"] - target_metrics["induction"]) ** 2)
2098
-
2099
- def _objective_function(
2100
- self,
2101
- normalized_params: np.ndarray,
2102
- param_names: List[str],
2103
- cost_function: Callable,
2104
- target_metrics: Dict[str, float],
2105
- ) -> float:
2106
- """
2107
- Calculate error using provided cost function
2108
- """
2109
- # Denormalize parameters
2110
- params = self._denormalize_params(normalized_params, param_names)
2111
-
2112
- # Set parameters
2113
- for name, value in zip(param_names, params):
2114
- setattr(self.tuner.syn, name, value)
2115
-
2116
- # just do this and have the SingleEvent handle it
2117
- if self.run_single_event:
2118
- self.tuner.using_optimizer = True
2119
- self.tuner.param_names = param_names
2120
- self.tuner.params = params
2121
-
2122
- # Calculate metrics and error
2123
- metrics = self._calculate_metrics()
2124
- error = float(cost_function(metrics, target_metrics)) # Ensure error is scalar
2125
-
2126
- # Store history with denormalized values
2127
- history_entry = {
2128
- "params": dict(zip(param_names, params)),
2129
- "metrics": metrics,
2130
- "error": error,
2131
- }
2132
- self.optimization_history.append(history_entry)
2133
-
2134
- return error
2135
-
2136
- def optimize_parameters(
2137
- self,
2138
- target_metrics: Dict[str, float],
2139
- param_bounds: Dict[str, Tuple[float, float]],
2140
- run_single_event: bool = False,
2141
- run_train_input: bool = True,
2142
- train_frequency: float = 50,
2143
- train_delay: float = 250,
2144
- cost_function: Optional[Callable] = None,
2145
- method: str = "SLSQP",
2146
- init_guess="random",
2147
- ) -> SynapseOptimizationResult:
2148
- """
2149
- Optimize synaptic parameters to achieve target metrics.
2150
2499
 
2151
- Parameters:
2152
- -----------
2153
- target_metrics : Dict[str, float]
2154
- Target values for synaptic metrics (e.g., {'induction': 0.2, 'rise_time': 0.5})
2155
- param_bounds : Dict[str, Tuple[float, float]]
2156
- Bounds for each parameter to optimize (e.g., {'tau_d': (5, 50), 'Use': (0.1, 0.9)})
2157
- run_single_event : bool, optional
2158
- Whether to run single event simulations during optimization (default: False)
2159
- run_train_input : bool, optional
2160
- Whether to run train input simulations during optimization (default: True)
2161
- train_frequency : float, optional
2162
- Frequency of the stimulus train in Hz (default: 50)
2163
- train_delay : float, optional
2164
- Delay between pulse trains in ms (default: 250)
2165
- cost_function : Optional[Callable]
2166
- Custom cost function for optimization. If None, uses default cost function
2167
- that optimizes induction.
2168
- method : str, optional
2169
- Optimization method to use (default: 'SLSQP')
2170
- init_guess : str, optional
2171
- Method for initial parameter guess ('random' or 'middle_guess')
2172
-
2173
- Returns:
2174
- --------
2175
- SynapseOptimizationResult
2176
- Results of the optimization including optimal parameters, achieved metrics,
2177
- target metrics, final error, and optimization path.
2178
-
2179
- Notes:
2180
- ------
2181
- This function uses scipy.optimize.minimize to find the optimal parameter values
2182
- that minimize the difference between achieved and target metrics.
2183
- """
2184
- self.optimization_history = []
2185
- self.train_frequency = train_frequency
2186
- self.train_delay = train_delay
2187
- self.run_single_event = run_single_event
2188
- self.run_train_input = run_train_input
2500
+ w_run = widgets.Button(description="Run Train", icon="history", button_style="primary")
2501
+ w_single = widgets.Button(description="Single Event", icon="check", button_style="success")
2502
+ w_vclamp = widgets.ToggleButton(
2503
+ value=vlamp_status,
2504
+ description="Voltage Clamp",
2505
+ icon="fast-backward",
2506
+ button_style="warning",
2507
+ )
2508
+
2509
+ # Voltage clamp amplitude input
2510
+ default_vclamp_amp = getattr(self.conn['spec_settings'], 'vclamp_amp', -70.0)
2511
+ w_vclamp_amp = widgets.FloatText(
2512
+ value=default_vclamp_amp,
2513
+ description="V_clamp (mV):",
2514
+ step=5.0,
2515
+ style={'description_width': 'initial'},
2516
+ layout=widgets.Layout(width='150px')
2517
+ )
2518
+
2519
+ w_input_mode = widgets.ToggleButton(
2520
+ value=False, description="Continuous input", icon="eject", button_style="info"
2521
+ )
2522
+ w_input_freq = widgets.SelectionSlider(options=freqs, value=freq0, description="Input Freq")
2189
2523
 
2190
- param_names = list(param_bounds.keys())
2191
- bounds = [param_bounds[name] for name in param_names]
2524
+ # Sliders for delay and duration
2525
+ self.w_delay = widgets.SelectionSlider(options=delays, value=delay0, description="Delay")
2526
+ self.w_duration = widgets.SelectionSlider(
2527
+ options=durations, value=duration0, description="Duration"
2528
+ )
2192
2529
 
2193
- if cost_function is None:
2194
- cost_function = self._default_cost_function
2530
+ # Save functionality widgets
2531
+ save_path_text = widgets.Text(
2532
+ value="plot.png",
2533
+ description="Save path:",
2534
+ layout=widgets.Layout(width='300px')
2535
+ )
2536
+ save_button = widgets.Button(description="Save Plot", icon="save", button_style="success")
2195
2537
 
2196
- # Calculate scaling factors
2197
- self.param_scales = {
2198
- name: max(abs(bounds[i][0]), abs(bounds[i][1])) for i, name in enumerate(param_names)
2199
- }
2538
+ def save_plot(b):
2539
+ if hasattr(self, 'last_figure') and self.last_figure is not None:
2540
+ try:
2541
+ # Create a new figure with just the first subplot (synaptic current)
2542
+ fig, ax = plt.subplots(figsize=(8, 6))
2543
+
2544
+ # Get the axes from the original figure
2545
+ original_axes = self.last_figure.get_axes()
2546
+ if len(original_axes) > 0:
2547
+ first_ax = original_axes[0]
2548
+
2549
+ # Copy the data from the first subplot
2550
+ for line in first_ax.get_lines():
2551
+ ax.plot(line.get_xdata(), line.get_ydata(),
2552
+ color=line.get_color(), label=line.get_label())
2553
+
2554
+ # Copy axis labels and title
2555
+ ax.set_xlabel(first_ax.get_xlabel())
2556
+ ax.set_ylabel(first_ax.get_ylabel())
2557
+ ax.set_title(first_ax.get_title())
2558
+ ax.set_xlim(first_ax.get_xlim())
2559
+ ax.legend()
2560
+ ax.grid(True)
2561
+
2562
+ # Save the new figure
2563
+ fig.savefig(save_path_text.value)
2564
+ plt.close(fig) # Close the temporary figure
2565
+ print(f"Synaptic current plot saved to {save_path_text.value}")
2566
+ else:
2567
+ print("No subplots found in the figure")
2568
+
2569
+ except Exception as e:
2570
+ print(f"Error saving plot: {e}")
2571
+ else:
2572
+ print("No plot to save")
2200
2573
 
2201
- # Normalize bounds
2202
- normalized_bounds = [
2203
- (b[0] / self.param_scales[name], b[1] / self.param_scales[name])
2204
- for name, b in zip(param_names, bounds)
2205
- ]
2574
+ save_button.on_click(save_plot)
2206
2575
 
2207
- # picks with method of init value we want to use
2208
- if init_guess == "random":
2209
- x0 = np.array([np.random.uniform(b[0], b[1]) for b in bounds])
2210
- elif init_guess == "middle_guess":
2211
- x0 = [(b[0] + b[1]) / 2 for b in bounds]
2212
- else:
2213
- raise Exception("Pick a vaid init guess method either random or midde_guess")
2214
- normalized_x0 = self._normalize_params(np.array(x0), param_names)
2215
-
2216
- # Run optimization
2217
- result = minimize(
2218
- self._objective_function,
2219
- normalized_x0,
2220
- args=(param_names, cost_function, target_metrics),
2221
- method=method,
2222
- bounds=normalized_bounds,
2223
- )
2576
+ def create_dynamic_sliders():
2577
+ """Create sliders based on current connection's parameters"""
2578
+ sliders = {}
2579
+ for key, value in self.slider_vars.items():
2580
+ if isinstance(value, (int, float)): # Only create sliders for numeric values
2581
+ if hasattr(self.syn, key):
2582
+ if value == 0:
2583
+ print(
2584
+ f"{key} was set to zero, going to try to set a range of values, try settings the {key} to a nonzero value if you dont like the range!"
2585
+ )
2586
+ slider = widgets.FloatSlider(
2587
+ value=value, min=0, max=1000, step=1, description=key
2588
+ )
2589
+ else:
2590
+ slider = widgets.FloatSlider(
2591
+ value=value, min=0, max=value * 20, step=value / 5, description=key
2592
+ )
2593
+ sliders[key] = slider
2594
+ else:
2595
+ print(f"skipping slider for {key} due to not being a synaptic variable")
2596
+ return sliders
2224
2597
 
2225
- # Get final parameters and metrics
2226
- final_params = dict(zip(param_names, self._denormalize_params(result.x, param_names)))
2227
- for name, value in final_params.items():
2228
- setattr(self.tuner.syn, name, value)
2229
- final_metrics = self._calculate_metrics()
2230
-
2231
- return SynapseOptimizationResult(
2232
- optimal_params=final_params,
2233
- achieved_metrics=final_metrics,
2234
- target_metrics=target_metrics,
2235
- error=result.fun,
2236
- optimization_path=self.optimization_history,
2598
+ # Generate sliders dynamically based on valid numeric entries in self.slider_vars
2599
+ self.dynamic_sliders = create_dynamic_sliders()
2600
+ print(
2601
+ "Setting up slider! The sliders ranges are set by their init value so try changing that if you dont like the slider range!"
2237
2602
  )
2238
2603
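The slider ranges follow a simple rule; the sketch below restates that rule outside of ipywidgets (the helper name is illustrative), which can help when choosing initial parameter values:

    def slider_range(init_value):
        # Mirrors the logic above: a zero initial value falls back to a generic 0-1000 range;
        # otherwise the slider spans 0 to 20x the initial value with a step of init_value / 5.
        if init_value == 0:
            return {"min": 0, "max": 1000, "step": 1}
        return {"min": 0, "max": init_value * 20, "step": init_value / 5}

    print(slider_range(0.5))  # {'min': 0, 'max': 10.0, 'step': 0.1}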
 
2239
- def plot_optimization_results(self, result: SynapseOptimizationResult):
2604
+ # Create output widget for displaying results
2605
+ output_widget = widgets.Output()
2606
+
2607
+ def run_single_event(*args):
2608
+ clear_output()
2609
+ display(ui)
2610
+ display(output_widget)
2611
+
2612
+ self.vclamp = w_vclamp.value
2613
+ # Update voltage clamp amplitude if voltage clamp is enabled
2614
+ if self.vclamp:
2615
+ # Update the voltage clamp amplitude settings
2616
+ self.conn['spec_settings']['vclamp_amp'] = w_vclamp_amp.value
2617
+ # Update general settings if they exist
2618
+ if hasattr(self, 'general_settings'):
2619
+ self.general_settings['vclamp_amp'] = w_vclamp_amp.value
2620
+ # Update synaptic properties based on slider values
2621
+ self.ispk = None
2622
+
2623
+ # Clear previous results and run simulation
2624
+ output_widget.clear_output()
2625
+ with output_widget:
2626
+ self.SingleEvent()
2627
+
2628
+ def on_connection_change(*args):
2629
+ """Handle connection dropdown change"""
2630
+ try:
2631
+ new_connection = w_connection.value
2632
+ if new_connection != self.current_connection:
2633
+ # Switch to new connection
2634
+ self._switch_connection(new_connection)
2635
+
2636
+ # Recreate dynamic sliders for new connection
2637
+ self.dynamic_sliders = create_dynamic_sliders()
2638
+
2639
+ # Update UI
2640
+ update_ui_layout()
2641
+ update_ui()
2642
+
2643
+ except Exception as e:
2644
+ print(f"Error switching connection: {e}")
2645
+
2646
+ def on_network_change(*args):
2647
+ """
2648
+ Handle network dropdown change events.
2649
+
2650
+ This callback is triggered when the user selects a different network from
2651
+ the network dropdown. It coordinates the complete switching process:
2652
+ 1. Calls _switch_network() to rebuild connections for the new network
2653
+ 2. Updates the connection dropdown options with new network's connections
2654
+ 3. Recreates dynamic sliders for new connection parameters
2655
+ 4. Refreshes the entire UI to reflect all changes
2656
+ """
2657
+ if w_network is None:
2658
+ return
2659
+ try:
2660
+ new_network = w_network.value
2661
+ if new_network != self.current_network:
2662
+ # Switch to new network
2663
+ self._switch_network(new_network)
2664
+
2665
+ # Update connection dropdown options with new network's connections
2666
+ connection_options = list(self.conn_type_settings.keys())
2667
+ w_connection.options = connection_options
2668
+ if connection_options:
2669
+ w_connection.value = self.current_connection
2670
+
2671
+ # Recreate dynamic sliders for new connection
2672
+ self.dynamic_sliders = create_dynamic_sliders()
2673
+
2674
+ # Update UI
2675
+ update_ui_layout()
2676
+ update_ui()
2677
+
2678
+ except Exception as e:
2679
+ print(f"Error switching network: {e}")
2680
+
2681
+ def update_ui_layout():
2682
+ """
2683
+ Update the UI layout with new sliders and network dropdown.
2684
+
2685
+ This function reconstructs the entire UI layout including:
2686
+ - Network dropdown (if available) and connection dropdown in the top row
2687
+ - Button controls and input mode toggles
2688
+ - Parameter sliders arranged in columns
2689
+ """
2690
+ nonlocal ui, slider_columns
2691
+
2692
+ # Add the dynamic sliders to the UI
2693
+ slider_widgets = [slider for slider in self.dynamic_sliders.values()]
2694
+
2695
+ if slider_widgets:
2696
+ half = len(slider_widgets) // 2
2697
+ col1 = VBox(slider_widgets[:half])
2698
+ col2 = VBox(slider_widgets[half:])
2699
+ slider_columns = HBox([col1, col2])
2700
+ else:
2701
+ slider_columns = VBox([])
2702
+
2703
+ # Create button row with voltage clamp controls
2704
+ if w_vclamp.value: # Show voltage clamp amplitude input when toggle is on
2705
+ button_row = HBox([w_run, w_single, w_vclamp, w_vclamp_amp, w_input_mode])
2706
+ else: # Hide voltage clamp amplitude input when toggle is off
2707
+ button_row = HBox([w_run, w_single, w_vclamp, w_input_mode])
2708
+
2709
+ # Construct the top row - include network dropdown if available
2710
+ # This creates a horizontal layout with network dropdown (if present) and connection dropdown
2711
+ if w_network is not None:
2712
+ connection_row = HBox([w_network, w_connection])
2713
+ else:
2714
+ connection_row = HBox([w_connection])
2715
+ slider_row = HBox([w_input_freq, self.w_delay, self.w_duration])
2716
+ save_row = HBox([save_path_text, save_button])
2717
+
2718
+ ui = VBox([connection_row, button_row, slider_row, slider_columns, save_row])
2719
+
2720
+ # Function to update UI based on input mode
2721
+ def update_ui(*args):
2722
+ clear_output()
2723
+ display(ui)
2724
+ display(output_widget)
2725
+
2726
+ self.vclamp = w_vclamp.value
2727
+ # Update voltage clamp amplitude if voltage clamp is enabled
2728
+ if self.vclamp:
2729
+ self.conn['spec_settings']['vclamp_amp'] = w_vclamp_amp.value
2730
+ if hasattr(self, 'general_settings'):
2731
+ self.general_settings['vclamp_amp'] = w_vclamp_amp.value
2732
+
2733
+ self.input_mode = w_input_mode.value
2734
+ syn_props = {var: slider.value for var, slider in self.dynamic_sliders.items()}
2735
+ self._set_syn_prop(**syn_props)
2736
+
2737
+ # Clear previous results and run simulation
2738
+ output_widget.clear_output()
2739
+ with output_widget:
2740
+ if not self.input_mode:
2741
+ self._simulate_model(w_input_freq.value, self.w_delay.value, w_vclamp.value)
2742
+ else:
2743
+ self._simulate_model(w_input_freq.value, self.w_duration.value, w_vclamp.value)
2744
+ amp = self._response_amplitude()
2745
+ self._plot_model(
2746
+ [self.general_settings["tstart"] - self.nstim.interval / 3, self.tstop]
2747
+ )
2748
+ _ = self._calc_ppr_induction_recovery(amp)
2749
+
2750
+ # Function to switch between delay and duration sliders
2751
+ def switch_slider(*args):
2752
+ if w_input_mode.value:
2753
+ self.w_delay.layout.display = "none" # Hide delay slider
2754
+ self.w_duration.layout.display = "" # Show duration slider
2755
+ else:
2756
+ self.w_delay.layout.display = "" # Show delay slider
2757
+ self.w_duration.layout.display = "none" # Hide duration slider
2758
+
2759
+ # Function to handle voltage clamp toggle
2760
+ def on_vclamp_toggle(*args):
2761
+ """Handle voltage clamp toggle changes to show/hide amplitude input"""
2762
+ update_ui_layout()
2763
+ clear_output()
2764
+ display(ui)
2765
+ display(output_widget)
2766
+
2767
+ # Link widgets to their callback functions
2768
+ w_connection.observe(on_connection_change, names="value")
2769
+ # Link network dropdown callback only if network dropdown was created
2770
+ if w_network is not None:
2771
+ w_network.observe(on_network_change, names="value")
2772
+ w_input_mode.observe(switch_slider, names="value")
2773
+ w_vclamp.observe(on_vclamp_toggle, names="value")
2774
+
2775
+ # Hide the duration slider initially until the user selects it
2776
+ self.w_duration.layout.display = "none" # Hide duration slider
2777
+
2778
+ w_single.on_click(run_single_event)
2779
+ w_run.on_click(update_ui)
2780
+
2781
+ # Initial UI setup
2782
+ slider_columns = VBox([])
2783
+ ui = VBox([])
2784
+ update_ui_layout()
2785
+
2786
+ display(ui)
2787
+ update_ui()
2788
+
2789
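Taken together, the widgets above form the interactive tuning UI. A minimal sketch of launching it in a Jupyter notebook, assuming the surrounding method is SynapseTuner.InteractiveTuner (as referenced in the docstrings) and that the constructor accepts a BMTK config path; both assumptions should be checked against the class definition:

    from bmtool.synapses import SynapseTuner

    # Hypothetical constructor argument; see the class docstring for the full signature
    tuner = SynapseTuner(config="simulation_config.json")
    tuner.InteractiveTuner()  # shows the connection/network dropdowns, sliders, and save row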
+ def stp_frequency_response(
2790
+ self,
2791
+ freqs=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 35, 50, 100, 200],
2792
+ delay=250,
2793
+ plot=True,
2794
+ log_plot=True,
2795
+ ):
2240
2796
  """
2241
- Plot optimization results including convergence and final traces.
2797
+ Analyze synaptic response across different stimulation frequencies.
2798
+
2799
+ This method systematically tests how the synapse model responds to different
2800
+ stimulation frequencies, calculating key short-term plasticity (STP) metrics
2801
+ for each frequency.
2242
2802
 
2243
2803
  Parameters:
2244
2804
  -----------
2245
- result : SynapseOptimizationResult
2246
- Results from optimization as returned by optimize_parameters()
2805
+ freqs : list, optional
2806
+ List of frequencies to analyze (in Hz). Default covers a wide range from 1-200 Hz.
2807
+ delay : float, optional
2808
+ Delay between pulse trains in ms. Default is 250 ms.
2809
+ plot : bool, optional
2810
+ Whether to plot the results. Default is True.
2811
+ log_plot : bool, optional
2812
+ Whether to use logarithmic scale for frequency axis. Default is True.
2813
+
2814
+ Returns:
2815
+ --------
2816
+ dict
2817
+ Dictionary containing frequency-dependent metrics with keys:
2818
+ - 'frequencies': List of tested frequencies
2819
+ - 'ppr': Paired-pulse ratios at each frequency
2820
+ - 'simple_ppr': Simple paired-pulse ratios (2nd/1st pulse) at each frequency
2821
+ - 'induction': Induction values at each frequency
2822
+ - 'recovery': Recovery values at each frequency
2247
2823
 
2248
2824
  Notes:
2249
2825
  ------
2250
- This method generates three plots:
2251
- 1. Error convergence plot showing how the error decreased over iterations
2252
- 2. Parameter convergence plots showing how each parameter changed
2253
- 3. Final model response with the optimal parameters
2254
-
2255
- It also prints a summary of the optimization results including target vs. achieved
2256
- metrics and the optimal parameter values.
2257
- """
2258
- # Ensure errors are properly shaped for plotting
2259
- iterations = range(len(result.optimization_path))
2260
- errors = np.array([float(h["error"]) for h in result.optimization_path]).flatten()
2261
-
2262
- # Plot error convergence
2263
- fig1, ax1 = plt.subplots(figsize=(8, 5))
2264
- ax1.plot(iterations, errors, label="Error")
2265
- ax1.set_xlabel("Iteration")
2266
- ax1.set_ylabel("Error")
2267
- ax1.set_title("Error Convergence")
2268
- ax1.set_yscale("log")
2826
+ This method is particularly useful for characterizing the frequency-dependent
2827
+ behavior of synapses, such as identifying facilitating vs. depressing regimes
2828
+ or the frequency at which a synapse transitions between these behaviors.
2829
+ """
2830
+ results = {"frequencies": freqs, "ppr": [], "induction": [], "recovery": [], "simple_ppr": []}
2831
+
2832
+ # Store original state
2833
+ original_ispk = self.ispk
2834
+
2835
+ for freq in tqdm(freqs, desc="Analyzing frequencies"):
2836
+ self._simulate_model(freq, delay)
2837
+ amp = self._response_amplitude()
2838
+ ppr, induction, recovery, simple_ppr = self._calc_ppr_induction_recovery(amp, print_math=False)
2839
+
2840
+ results["ppr"].append(float(ppr))
2841
+ results["induction"].append(float(induction))
2842
+ results["recovery"].append(float(recovery))
2843
+ results["simple_ppr"].append(float(simple_ppr))
2844
+
2845
+ # Restore original state
2846
+ self.ispk = original_ispk
2847
+
2848
+ if plot:
2849
+ self._plot_frequency_analysis(results, log_plot=log_plot)
2850
+
2851
+ return results
2852
+
2853
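A short usage sketch for the frequency sweep above (`tuner` is an assumed SynapseTuner instance; the frequency list is illustrative):

    results = tuner.stp_frequency_response(freqs=[5, 10, 20, 50], delay=250, plot=False)
    for f, ppr, simple in zip(results["frequencies"], results["ppr"], results["simple_ppr"]):
        print(f"{f} Hz: normalized PPR = {ppr:.3f}, simple PPR = {simple:.3f}")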
+ def _plot_frequency_analysis(self, results, log_plot):
2854
+ """
2855
+ Plot the frequency-dependent synaptic properties.
2856
+
2857
+ Parameters:
2858
+ -----------
2859
+ results : dict
2860
+ Dictionary containing frequency analysis results with keys:
2861
+ - 'frequencies': List of tested frequencies
2862
+ - 'ppr': Paired-pulse ratios at each frequency
2863
+ - 'simple_ppr': Simple paired-pulse ratios at each frequency
2864
+ - 'induction': Induction values at each frequency
2865
+ - 'recovery': Recovery values at each frequency
2866
+ log_plot : bool
2867
+ Whether to use logarithmic scale for frequency axis
2868
+
2869
+ Notes:
2870
+ ------
2871
+ Creates a figure with three subplots showing:
2872
+ 1. Paired-pulse ratios (both normalized and simple) vs. frequency
2873
+ 2. Induction vs. frequency
2874
+ 3. Recovery vs. frequency
2875
+
2876
+ Each plot includes a horizontal reference line at y=0 or y=1 to indicate
2877
+ the boundary between facilitation and depression.
2878
+ """
2879
+ fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
2880
+
2881
+ # Plot both PPR measures
2882
+ if log_plot:
2883
+ ax1.semilogx(results["frequencies"], results["ppr"], "o-", label="Normalized PPR")
2884
+ ax1.semilogx(results["frequencies"], results["simple_ppr"], "s-", label="Simple PPR")
2885
+ else:
2886
+ ax1.plot(results["frequencies"], results["ppr"], "o-", label="Normalized PPR")
2887
+ ax1.plot(results["frequencies"], results["simple_ppr"], "s-", label="Simple PPR")
2888
+ ax1.axhline(y=1, color="gray", linestyle="--", alpha=0.5)
2889
+ ax1.set_xlabel("Frequency (Hz)")
2890
+ ax1.set_ylabel("Paired Pulse Ratio")
2891
+ ax1.set_title("PPR vs Frequency")
2269
2892
  ax1.legend()
2893
+ ax1.grid(True)
2894
+
2895
+ # Plot Induction
2896
+ if log_plot:
2897
+ ax2.semilogx(results["frequencies"], results["induction"], "o-")
2898
+ else:
2899
+ ax2.plot(results["frequencies"], results["induction"], "o-")
2900
+ ax2.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
2901
+ ax2.set_xlabel("Frequency (Hz)")
2902
+ ax2.set_ylabel("Induction")
2903
+ ax2.set_title("Induction vs Frequency")
2904
+ ax2.grid(True)
2905
+
2906
+ # Plot Recovery
2907
+ if log_plot:
2908
+ ax3.semilogx(results["frequencies"], results["recovery"], "o-")
2909
+ else:
2910
+ ax3.plot(results["frequencies"], results["recovery"], "o-")
2911
+ ax3.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
2912
+ ax3.set_xlabel("Frequency (Hz)")
2913
+ ax3.set_ylabel("Recovery")
2914
+ ax3.set_title("Recovery vs Frequency")
2915
+ ax3.grid(True)
2916
+
2270
2917
  plt.tight_layout()
2271
2918
  plt.show()
2272
2919
 
2273
- # Plot parameter convergence
2274
- param_names = list(result.optimal_params.keys())
2275
- num_params = len(param_names)
2276
- fig2, axs = plt.subplots(nrows=num_params, ncols=1, figsize=(8, 5 * num_params))
2920
+ def generate_synaptic_table(self, stp_frequency=50.0, stp_delay=250.0, plot=True):
2921
+ """
2922
+ Generate a comprehensive table of synaptic parameters for all connections.
2923
+
2924
+ This method iterates through all available connections, runs simulations to
2925
+ characterize each synapse, and compiles the results into a pandas DataFrame.
2926
+
2927
+ Parameters:
2928
+ -----------
2929
+ stp_frequency : float, optional
2930
+ Frequency in Hz to use for STP (short-term plasticity) analysis. Default is 50.0 Hz.
2931
+ stp_delay : float, optional
2932
+ Delay in ms between pulse trains for STP analysis. Default is 250.0 ms.
2933
+ plot : bool, optional
2934
+ Whether to display the resulting table. Default is True.
2935
+
2936
+ Returns:
2937
+ --------
2938
+ pd.DataFrame
2939
+ DataFrame containing synaptic parameters for each connection with columns:
2940
+ - connection: Connection name
2941
+ - rise_time: 20-80% rise time (ms)
2942
+ - decay_time: Decay time constant (ms)
2943
+ - latency: Response latency (ms)
2944
+ - half_width: Response half-width (ms)
2945
+ - peak_amplitude: Peak synaptic current amplitude (pA)
2946
+ - baseline: Baseline current (pA)
2947
+ - ppr: Paired-pulse ratio (normalized)
2948
+ - simple_ppr: Simple paired-pulse ratio (2nd/1st pulse)
2949
+ - induction: STP induction measure
2950
+ - recovery: STP recovery measure
2951
+
2952
+ Notes:
2953
+ ------
2954
+ This method temporarily switches between connections to characterize each one,
2955
+ then restores the original connection. The STP metrics are calculated at the
2956
+ specified frequency and delay.
2957
+ """
2958
+ # Store original connection to restore later
2959
+ original_connection = self.current_connection
2960
+
2961
+ # Initialize results list
2962
+ results = []
2963
+
2964
+ print(f"Analyzing {len(self.conn_type_settings)} connections...")
2965
+
2966
+ for conn_name in tqdm(self.conn_type_settings.keys(), desc="Analyzing connections"):
2967
+ try:
2968
+ # Switch to this connection
2969
+ self._switch_connection(conn_name)
2970
+
2971
+ # Run single event analysis
2972
+ self.SingleEvent(plot_and_print=False)
2973
+
2974
+ # Get synaptic properties from the single event
2975
+ syn_props = self._get_syn_prop()
2976
+
2977
+ # Run STP analysis at specified frequency
2978
+ stp_results = self.stp_frequency_response(
2979
+ freqs=[stp_frequency],
2980
+ delay=stp_delay,
2981
+ plot=False,
2982
+ log_plot=False
2983
+ )
2984
+
2985
+ # Extract STP metrics for this frequency
2986
+ freq_idx = 0 # Only one frequency tested
2987
+ ppr = stp_results['ppr'][freq_idx]
2988
+ induction = stp_results['induction'][freq_idx]
2989
+ recovery = stp_results['recovery'][freq_idx]
2990
+ simple_ppr = stp_results['simple_ppr'][freq_idx]
2991
+
2992
+ # Compile results for this connection
2993
+ conn_results = {
2994
+ 'connection': conn_name,
2995
+ 'rise_time': float(self.rise_time),
2996
+ 'decay_time': float(self.decay_time),
2997
+ 'latency': float(syn_props.get('latency', 0)),
2998
+ 'half_width': float(syn_props.get('half_width', 0)),
2999
+ 'peak_amplitude': float(syn_props.get('amp', 0)),
3000
+ 'baseline': float(syn_props.get('baseline', 0)),
3001
+ 'ppr': float(ppr),
3002
+ 'simple_ppr': float(simple_ppr),
3003
+ 'induction': float(induction),
3004
+ 'recovery': float(recovery)
3005
+ }
3006
+
3007
+ results.append(conn_results)
3008
+
3009
+ except Exception as e:
3010
+ print(f"Warning: Failed to analyze connection '{conn_name}': {e}")
3011
+ # Add partial results if possible
3012
+ results.append({
3013
+ 'connection': conn_name,
3014
+ 'rise_time': float('nan'),
3015
+ 'decay_time': float('nan'),
3016
+ 'latency': float('nan'),
3017
+ 'half_width': float('nan'),
3018
+ 'peak_amplitude': float('nan'),
3019
+ 'baseline': float('nan'),
3020
+ 'ppr': float('nan'),
3021
+ 'simple_ppr': float('nan'),
3022
+ 'induction': float('nan'),
3023
+ 'recovery': float('nan')
3024
+ })
3025
+
3026
+ # Restore original connection
3027
+ if original_connection in self.conn_type_settings:
3028
+ self._switch_connection(original_connection)
3029
+
3030
+ # Create DataFrame
3031
+ df = pd.DataFrame(results)
3032
+
3033
+ # Set connection as index for better display
3034
+ df = df.set_index('connection')
3035
+
3036
+ if plot:
3037
+ # Display the table
3038
+ print("\nSynaptic Parameters Table:")
3039
+ print("=" * 80)
3040
+ display(df.round(4))
3041
+
3042
+ # Optional: Create a simple bar plot for key metrics
3043
+ try:
3044
+ fig, axes = plt.subplots(2, 2, figsize=(15, 10))
3045
+ fig.suptitle(f'Synaptic Parameters Across Connections (STP at {stp_frequency}Hz)', fontsize=16)
3046
+
3047
+ # Plot rise/decay times
3048
+ df[['rise_time', 'decay_time']].plot(kind='bar', ax=axes[0,0])
3049
+ axes[0,0].set_title('Rise and Decay Times')
3050
+ axes[0,0].set_ylabel('Time (ms)')
3051
+ axes[0,0].tick_params(axis='x', rotation=45)
3052
+
3053
+ # Plot PPR metrics
3054
+ df[['ppr', 'simple_ppr']].plot(kind='bar', ax=axes[0,1])
3055
+ axes[0,1].set_title('Paired-Pulse Ratios')
3056
+ axes[0,1].axhline(y=1, color='gray', linestyle='--', alpha=0.5)
3057
+ axes[0,1].tick_params(axis='x', rotation=45)
3058
+
3059
+ # Plot induction
3060
+ df['induction'].plot(kind='bar', ax=axes[1,0], color='green')
3061
+ axes[1,0].set_title('STP Induction')
3062
+ axes[1,0].axhline(y=0, color='gray', linestyle='--', alpha=0.5)
3063
+ axes[1,0].set_ylabel('Induction')
3064
+ axes[1,0].tick_params(axis='x', rotation=45)
3065
+
3066
+ # Plot recovery
3067
+ df['recovery'].plot(kind='bar', ax=axes[1,1], color='orange')
3068
+ axes[1,1].set_title('STP Recovery')
3069
+ axes[1,1].axhline(y=0, color='gray', linestyle='--', alpha=0.5)
3070
+ axes[1,1].set_ylabel('Recovery')
3071
+ axes[1,1].tick_params(axis='x', rotation=45)
3072
+
3073
+ plt.tight_layout()
3074
+ plt.show()
3075
+
3076
+ except Exception as e:
3077
+ print(f"Warning: Could not create plots: {e}")
3078
+
3079
+ return df
3080
+
2277
3081
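A usage sketch for the summary table (`tuner` is an assumed SynapseTuner instance; the CSV path is illustrative):

    df = tuner.generate_synaptic_table(stp_frequency=50.0, stp_delay=250.0, plot=False)
    print(df[["rise_time", "decay_time", "ppr", "simple_ppr", "induction", "recovery"]])
    df.to_csv("synaptic_parameters.csv")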
 
2278
- if num_params == 1:
2279
- axs = [axs]
3082
+ class GapJunctionTuner:
3083
+ def __init__(
3084
+ self,
3085
+ mechanisms_dir: Optional[str] = None,
3086
+ templates_dir: Optional[str] = None,
3087
+ config: Optional[str] = None,
3088
+ general_settings: Optional[dict] = None,
3089
+ conn_type_settings: Optional[dict] = None,
3090
+ hoc_cell: Optional[object] = None,
3091
+ ):
3092
+ """
3093
+ Initialize the GapJunctionTuner class.
2280
3094
 
2281
- for ax, param in zip(axs, param_names):
2282
- values = [float(h["params"][param]) for h in result.optimization_path]
2283
- ax.plot(iterations, values, label=f"{param}")
2284
- ax.set_xlabel("Iteration")
2285
- ax.set_ylabel("Parameter Value")
2286
- ax.set_title(f"Convergence of {param}")
2287
- ax.legend()
3095
+ Parameters:
3096
+ -----------
3097
+ mechanisms_dir : str
3098
+ Directory path containing the compiled mod files needed for NEURON mechanisms.
3099
+ templates_dir : str
3100
+ Directory path containing cell template files (.hoc or .py) loaded into NEURON.
3101
+ config : str
3102
+ Path to a BMTK config.json file. Can be used to load mechanisms, templates, and other settings.
3103
+ general_settings : dict
3104
+ General settings dictionary including parameters like simulation time step, duration, and temperature.
3105
+ conn_type_settings : dict
3106
+ A dictionary containing connection-specific settings for gap junctions.
3107
+ hoc_cell : object, optional
3108
+ An already loaded NEURON cell object. If provided, template/mechanism loading is skipped and this object is used as the first cell; a second cell is still instantiated from its template.
3109
+ """
3110
+ self.hoc_cell = hoc_cell
2288
3111
 
2289
- plt.tight_layout()
2290
- plt.show()
3112
+ if hoc_cell is None:
3113
+ if config is None and (mechanisms_dir is None or templates_dir is None):
3114
+ raise ValueError(
3115
+ "Either a config file, both mechanisms_dir and templates_dir, or a hoc_cell must be provided."
3116
+ )
2291
3117
 
2292
- # Print final results
2293
- print("Optimization Results:")
2294
- print(f"Final Error: {float(result.error):.2e}\n")
2295
- print("Target Metrics:")
2296
- for metric, value in result.target_metrics.items():
2297
- achieved = result.achieved_metrics.get(metric)
2298
- if achieved is not None and metric != "amplitudes": # Skip amplitude array
2299
- print(f"{metric}: {float(achieved):.3f} (target: {float(value):.3f})")
2300
-
2301
- print("\nOptimal Parameters:")
2302
- for param, value in result.optimal_params.items():
2303
- print(f"{param}: {float(value):.3f}")
2304
-
2305
- # Plot final model response
2306
- if self.run_train_input:
2307
- self.tuner._plot_model(
2308
- [
2309
- self.tuner.general_settings["tstart"] - self.tuner.nstim.interval / 3,
2310
- self.tuner.tstop,
2311
- ]
2312
- )
2313
- amp = self.tuner._response_amplitude()
2314
- self.tuner._calc_ppr_induction_recovery(amp)
2315
- if self.run_single_event:
2316
- self.tuner.ispk = None
2317
- self.tuner.SingleEvent(plot_and_print=True)
3118
+ if config is None:
3119
+ neuron.load_mechanisms(mechanisms_dir)
3120
+ h.load_file(templates_dir)
3121
+ else:
3122
+ # this will load both mechs and templates
3123
+ load_templates_from_config(config)
2318
3124
 
3125
+ # Use default general settings if not provided, merge with user-provided
3126
+ if general_settings is None:
3127
+ self.general_settings: dict = DEFAULT_GAP_JUNCTION_GENERAL_SETTINGS.copy()
3128
+ else:
3129
+ self.general_settings = {**DEFAULT_GAP_JUNCTION_GENERAL_SETTINGS, **general_settings}
3130
+ self.conn_type_settings = conn_type_settings
3131
+
3132
+ self._syn_params_cache = {}
3133
+ self.config = config
3134
+ self.available_networks = []
3135
+ self.current_network = None
3136
+ self.last_figure = None
3137
+ if self.conn_type_settings is None and self.config is not None:
3138
+ self.conn_type_settings = self._build_conn_type_settings_from_config(self.config)
3139
+ if self.conn_type_settings is None or len(self.conn_type_settings) == 0:
3140
+ raise ValueError("conn_type_settings must be provided or config must be given to load gap junction connections from")
3141
+ self.current_connection = list(self.conn_type_settings.keys())[0]
3142
+ self.conn = self.conn_type_settings[self.current_connection]
2319
3143
 
2320
- # dataclass means just init the typehints as self.typehint. looks a bit cleaner
2321
- @dataclass
2322
- class GapOptimizationResult:
2323
- """Container for gap junction optimization results"""
3144
+ h.tstop = self.general_settings["tstart"] + self.general_settings["tdur"] + 100.0
3145
+ h.dt = self.general_settings["dt"] # Time step (resolution) of the simulation in ms
3146
+ h.steps_per_ms = 1 / h.dt
3147
+ h.celsius = self.general_settings["celsius"]
3148
+
3149
+ # Clean up any existing parallel context before setting up gap junctions
3150
+ try:
3151
+ pc_temp = h.ParallelContext()
3152
+ pc_temp.done() # Clean up any existing parallel context
3153
+ except:
3154
+ pass # Ignore errors if no existing context
3155
+
3156
+ # Force cleanup
3157
+ import gc
3158
+ gc.collect()
3159
+
3160
+ # set up gap junctions
3161
+ self.pc = h.ParallelContext()
3162
+
3163
+ # Use provided hoc_cell or create new cells
3164
+ if self.hoc_cell is not None:
3165
+ self.cell1 = self.hoc_cell
3166
+ # For gap junctions, we need two cells, so create a second one if using hoc_cell
3167
+ self.cell_name = self.conn['cell']
3168
+ self.cell2 = getattr(h, self.cell_name)()
3169
+ else:
3170
+ print(self.conn)
3171
+ self.cell_name = self.conn['cell']
3172
+ self.cell1 = getattr(h, self.cell_name)()
3173
+ self.cell2 = getattr(h, self.cell_name)()
3174
+
3175
+ self.icl = h.IClamp(self.cell1.soma[0](0.5))
3176
+ self.icl.delay = self.general_settings["tstart"]
3177
+ self.icl.dur = self.general_settings["tdur"]
3178
+ self.icl.amp = self.general_settings["iclamp_amp"] # nA
3179
+
3180
+ sec1 = list(self.cell1.all)[self.conn["sec_id"]]
3181
+ sec2 = list(self.cell2.all)[self.conn["sec_id"]]
3182
+
3183
+ # Use unique IDs to avoid conflicts with existing parallel context setups
3184
+ import time
3185
+ unique_id = int(time.time() * 1000) % 10000 # Use timestamp as unique base ID
3186
+
3187
+ self.pc.source_var(sec1(self.conn["sec_x"])._ref_v, unique_id, sec=sec1)
3188
+ self.gap_junc_1 = h.Gap(sec1(0.5))
3189
+ self.pc.target_var(self.gap_junc_1._ref_vgap, unique_id + 1)
3190
+
3191
+ self.pc.source_var(sec2(self.conn["sec_x"])._ref_v, unique_id + 1, sec=sec2)
3192
+ self.gap_junc_2 = h.Gap(sec2(0.5))
3193
+ self.pc.target_var(self.gap_junc_2._ref_vgap, unique_id)
3194
+
3195
+ self.pc.setup_transfer()
3196
+
3197
+ # Now it's safe to initialize NEURON
3198
+ h.finitialize()
3199
+
3200
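A minimal construction sketch for the class above, assuming a BMTK config whose edge files mark gap junctions with is_gap_junction (the config path is illustrative):

    from bmtool.synapses import GapJunctionTuner

    gj_tuner = GapJunctionTuner(config="simulation_config.json")
    print(list(gj_tuner.conn_type_settings.keys()))  # e.g. ['PV2PV_gj'] (names depend on the network)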
+ def _load_synaptic_params_from_config(self, config: dict, dynamics_params: str) -> dict:
3201
+ try:
3202
+ # Get the synaptic models directory from config
3203
+ synaptic_models_dir = config.get('components', {}).get('synaptic_models_dir', '')
3204
+ if synaptic_models_dir:
3205
+ # Handle path variables
3206
+ if synaptic_models_dir.startswith('$'):
3207
+ # This is a placeholder, try to resolve it
3208
+ config_dir = os.path.dirname(config.get('config_path', ''))
3209
+ synaptic_models_dir = synaptic_models_dir.replace('$COMPONENTS_DIR',
3210
+ os.path.join(config_dir, 'components'))
3211
+ synaptic_models_dir = synaptic_models_dir.replace('$BASE_DIR', config_dir)
3212
+
3213
+ dynamics_file = os.path.join(synaptic_models_dir, dynamics_params)
3214
+
3215
+ if os.path.exists(dynamics_file):
3216
+ with open(dynamics_file, 'r') as f:
3217
+ return json.load(f)
3218
+ else:
3219
+ print(f"Warning: Dynamics params file not found: {dynamics_file}")
3220
+ except Exception as e:
3221
+ print(f"Warning: Error loading synaptic parameters: {e}")
3222
+
3223
+ return {}
2324
3224
 
2325
- optimal_resistance: float
2326
- achieved_cc: float
2327
- target_cc: float
2328
- error: float
2329
- optimization_path: List[Dict[str, float]]
3225
+ def _load_available_networks(self) -> None:
3226
+ """
3227
+ Load available network names from the config file for the network dropdown feature.
3228
+
3229
+ This method is automatically called during initialization when a config file is provided.
3230
+ It populates the available_networks list which enables the network dropdown in
3231
+ InteractiveTuner when multiple networks are available.
3232
+
3233
+ Network Dropdown Behavior:
3234
+ -------------------------
3235
+ - If only one network exists: No network dropdown is shown
3236
+ - If multiple networks exist: Network dropdown appears next to connection dropdown
3237
+ - Networks are loaded from the edges data in the config file
3238
+ - Current network defaults to the first available if not specified during init
3239
+ """
3240
+ if self.config is None:
3241
+ self.available_networks = []
3242
+ return
3243
+
3244
+ try:
3245
+ edges = load_edges_from_config(self.config)
3246
+ self.available_networks = list(edges.keys())
3247
+
3248
+ # Set current network to first available if not specified
3249
+ if self.current_network is None and self.available_networks:
3250
+ self.current_network = self.available_networks[0]
3251
+ except Exception as e:
3252
+ print(f"Warning: Could not load networks from config: {e}")
3253
+ self.available_networks = []
2330
3254
 
3255
+ def _build_conn_type_settings_from_config(self, config_path: str) -> Dict[str, dict]:
3256
+ # Load configuration and get nodes and edges using util.py methods
3257
+ config = load_config(config_path)
3258
+ # Ensure the config dict knows its source path so path substitutions can be resolved
3259
+ try:
3260
+ config['config_path'] = config_path
3261
+ except Exception:
3262
+ pass
3263
+ nodes = load_nodes_from_config(config_path)
3264
+ edges = load_edges_from_config(config_path)
3265
+
3266
+ conn_type_settings = {}
3267
+
3268
+ # Process all edge datasets
3269
+ for edge_dataset_name, edge_df in edges.items():
3270
+ if edge_df.empty:
3271
+ continue
3272
+
3273
+ # Merging with node data to get model templates
3274
+ source_node_df = None
3275
+ target_node_df = None
3276
+
3277
+ # First, try to deterministically parse the edge_dataset_name for patterns like '<src>_to_<tgt>'
3278
+ if '_to_' in edge_dataset_name:
3279
+ parts = edge_dataset_name.split('_to_')
3280
+ if len(parts) == 2:
3281
+ src_name, tgt_name = parts
3282
+ if src_name in nodes:
3283
+ source_node_df = nodes[src_name].add_prefix('source_')
3284
+ if tgt_name in nodes:
3285
+ target_node_df = nodes[tgt_name].add_prefix('target_')
3286
+
3287
+ # If not found by parsing name, fall back to inspecting a sample edge row
3288
+ if source_node_df is None or target_node_df is None:
3289
+ sample_edge = edge_df.iloc[0] if len(edge_df) > 0 else None
3290
+ if sample_edge is not None:
3291
+ source_pop_name = sample_edge.get('source_population', '')
3292
+ target_pop_name = sample_edge.get('target_population', '')
3293
+ if source_pop_name in nodes:
3294
+ source_node_df = nodes[source_pop_name].add_prefix('source_')
3295
+ if target_pop_name in nodes:
3296
+ target_node_df = nodes[target_pop_name].add_prefix('target_')
3297
+
3298
+ # As a last resort, attempt to heuristically match
3299
+ if source_node_df is None or target_node_df is None:
3300
+ for pop_name, node_df in nodes.items():
3301
+ if source_node_df is None and (edge_dataset_name.startswith(pop_name) or edge_dataset_name.endswith(pop_name)):
3302
+ source_node_df = node_df.add_prefix('source_')
3303
+ if target_node_df is None and (edge_dataset_name.startswith(pop_name) or edge_dataset_name.endswith(pop_name)):
3304
+ target_node_df = node_df.add_prefix('target_')
3305
+
3306
+ if source_node_df is None or target_node_df is None:
3307
+ print(f"Warning: Could not find node data for edge dataset {edge_dataset_name}")
3308
+ continue
3309
+
3310
+ # Merge edge data with source node info
3311
+ edges_with_source = pd.merge(
3312
+ edge_df.reset_index(),
3313
+ source_node_df,
3314
+ how='left',
3315
+ left_on='source_node_id',
3316
+ right_index=True
3317
+ )
3318
+
3319
+ # Merge with target node info
3320
+ edges_with_nodes = pd.merge(
3321
+ edges_with_source,
3322
+ target_node_df,
3323
+ how='left',
3324
+ left_on='target_node_id',
3325
+ right_index=True
3326
+ )
3327
+
3328
+ # Skip edge datasets that don't have gap junction information
3329
+ if 'is_gap_junction' not in edges_with_nodes.columns:
3330
+ continue
3331
+
3332
+ # Filter to only gap junction edges
3333
+ # Handle NaN values in is_gap_junction column
3334
+ gap_junction_mask = edges_with_nodes['is_gap_junction'].fillna(False) == True
3335
+ gap_junction_edges = edges_with_nodes[gap_junction_mask]
3336
+ if gap_junction_edges.empty:
3337
+ continue
3338
+
3339
+ # Get unique edge types from the gap junction edges
3340
+ if 'edge_type_id' in gap_junction_edges.columns:
3341
+ edge_types = gap_junction_edges['edge_type_id'].unique()
3342
+ else:
3343
+ edge_types = [None] # Single edge type
3344
+
3345
+ # Process each edge type
3346
+ for edge_type_id in edge_types:
3347
+ # Filter edges for this type
3348
+ if edge_type_id is not None:
3349
+ edge_type_data = gap_junction_edges[gap_junction_edges['edge_type_id'] == edge_type_id]
3350
+ else:
3351
+ edge_type_data = gap_junction_edges
3352
+
3353
+ if len(edge_type_data) == 0:
3354
+ continue
3355
+
3356
+ # Get representative edge for this type
3357
+ edge_info = edge_type_data.iloc[0]
3358
+
3359
+ # Process gap junction
3360
+ source_model_template = edge_info.get('source_model_template', '')
3361
+ target_model_template = edge_info.get('target_model_template', '')
3362
+
3363
+ source_cell_type = source_model_template.replace('hoc:', '') if source_model_template.startswith('hoc:') else source_model_template
3364
+ target_cell_type = target_model_template.replace('hoc:', '') if target_model_template.startswith('hoc:') else target_model_template
3365
+
3366
+ if source_cell_type != target_cell_type:
3367
+ continue # Only process gap junctions between same cell types
3368
+
3369
+ source_pop = edge_info.get('source_pop_name', '')
3370
+ target_pop = edge_info.get('target_pop_name', '')
3371
+
3372
+ conn_name = f"{source_pop}2{target_pop}_gj"
3373
+ if edge_type_id is not None:
3374
+ conn_name += f"_type_{edge_type_id}"
3375
+
3376
+ conn_settings = {
3377
+ 'cell': source_cell_type,
3378
+ 'sec_id': 0,
3379
+ 'sec_x': 0.5,
3380
+ 'iclamp_amp': -0.01,
3381
+ 'spec_syn_param': {}
3382
+ }
3383
+
3384
+ # Load dynamics params
3385
+ dynamics_file_name = edge_info.get('dynamics_params', '')
3386
+ if dynamics_file_name and dynamics_file_name.upper() != 'NULL':
3387
+ try:
3388
+ syn_params = self._load_synaptic_params_from_config(config, dynamics_file_name)
3389
+ conn_settings['spec_syn_param'] = syn_params
3390
+ except Exception as e:
3391
+ print(f"Warning: could not load dynamics_params file '{dynamics_file_name}': {e}")
3392
+
3393
+ conn_type_settings[conn_name] = conn_settings
3394
+
3395
+ return conn_type_settings
2331
3396
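For orientation, each entry produced by the builder above has roughly the following shape (population and template names are illustrative):

    conn_type_settings = {
        "PV2PV_gj_type_100": {        # "{source_pop}2{target_pop}_gj" plus an optional "_type_{edge_type_id}" suffix
            "cell": "basket_cell",    # hoc template shared by both coupled cells
            "sec_id": 0,
            "sec_x": 0.5,
            "iclamp_amp": -0.01,      # nA
            "spec_syn_param": {},     # contents of the edge type's dynamics_params JSON, if found
        },
    }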
 
2332
- class GapJunctionOptimizer:
2333
- def __init__(self, tuner):
3397
+ def _switch_connection(self, new_connection: str) -> None:
2334
3398
  """
2335
- Initialize the gap junction optimizer
2336
-
3399
+ Switch to a different gap junction connection and update all related properties.
3400
+
2337
3401
  Parameters:
2338
3402
  -----------
2339
- tuner : GapJunctionTuner
2340
- Instance of the GapJunctionTuner class
3403
+ new_connection : str
3404
+ Name of the new connection type to switch to.
2341
3405
  """
2342
- self.tuner = tuner
2343
- self.optimization_history = []
3406
+ if new_connection not in self.conn_type_settings:
3407
+ raise ValueError(f"Connection '{new_connection}' not found in conn_type_settings")
3408
+
3409
+ # Update current connection
3410
+ self.current_connection = new_connection
3411
+ self.conn = self.conn_type_settings[new_connection]
3412
+
3413
+ # Check if cell type changed
3414
+ new_cell_name = self.conn['cell']
3415
+ if self.cell_name != new_cell_name:
3416
+ self.cell_name = new_cell_name
3417
+
3418
+ # Recreate cells
3419
+ if self.hoc_cell is None:
3420
+ self.cell1 = getattr(h, self.cell_name)()
3421
+ self.cell2 = getattr(h, self.cell_name)()
3422
+ else:
3423
+ # For hoc_cell, recreate the second cell
3424
+ self.cell2 = getattr(h, self.cell_name)()
3425
+
3426
+ # Recreate IClamp
3427
+ self.icl = h.IClamp(self.cell1.soma[0](0.5))
3428
+ self.icl.delay = self.general_settings["tstart"]
3429
+ self.icl.dur = self.general_settings["tdur"]
3430
+ self.icl.amp = self.general_settings["iclamp_amp"]
3431
+ else:
3432
+ # Update IClamp parameters even if same cell type
3433
+ self.icl.amp = self.general_settings["iclamp_amp"]
3434
+
3435
+ # Always recreate gap junctions when switching connections
3436
+ # (even for same cell type, sec_id or sec_x might differ)
3437
+
3438
+ # Clean up previous gap junctions and parallel context
3439
+ if hasattr(self, 'gap_junc_1'):
3440
+ del self.gap_junc_1
3441
+ if hasattr(self, 'gap_junc_2'):
3442
+ del self.gap_junc_2
3443
+
3444
+ # Properly clean up the existing parallel context
3445
+ if hasattr(self, 'pc'):
3446
+ self.pc.done() # Clean up existing parallel context
3447
+
3448
+ # Force garbage collection and reset NEURON state
3449
+ import gc
3450
+ gc.collect()
3451
+ h.finitialize()
3452
+
3453
+ # Create a fresh parallel context after cleanup
3454
+ self.pc = h.ParallelContext()
3455
+
3456
+ try:
3457
+ sec1 = list(self.cell1.all)[self.conn["sec_id"]]
3458
+ sec2 = list(self.cell2.all)[self.conn["sec_id"]]
3459
+
3460
+ # Use unique IDs to avoid conflicts with existing parallel context setups
3461
+ import time
3462
+ unique_id = int(time.time() * 1000) % 10000 # Use timestamp as unique base ID
3463
+
3464
+ self.pc.source_var(sec1(self.conn["sec_x"])._ref_v, unique_id, sec=sec1)
3465
+ self.gap_junc_1 = h.Gap(sec1(0.5))
3466
+ self.pc.target_var(self.gap_junc_1._ref_vgap, unique_id + 1)
3467
+
3468
+ self.pc.source_var(sec2(self.conn["sec_x"])._ref_v, unique_id + 1, sec=sec2)
3469
+ self.gap_junc_2 = h.Gap(sec2(0.5))
3470
+ self.pc.target_var(self.gap_junc_2._ref_vgap, unique_id)
3471
+
3472
+ self.pc.setup_transfer()
3473
+ except Exception as e:
3474
+ print(f"Error setting up gap junctions: {e}")
3475
+ # Try to continue with basic setup
3476
+ self.gap_junc_1 = h.Gap(list(self.cell1.all)[self.conn["sec_id"]](0.5))
3477
+ self.gap_junc_2 = h.Gap(list(self.cell2.all)[self.conn["sec_id"]](0.5))
3478
+
3479
+ # Reset NEURON state after complete setup
3480
+ h.finitialize()
3481
+
3482
+ print(f"Successfully switched to connection: {new_connection}")
2344
3483
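A usage sketch for switching connections on an existing tuner (the connection name is a hypothetical key):

    target = "PV2PV_gj"  # illustrative entry in gj_tuner.conn_type_settings
    if target in gj_tuner.conn_type_settings:
        gj_tuner._switch_connection(target)  # rebuilds cells and gap junctions as needed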
 
2345
- def _objective_function(self, resistance: float, target_cc: float) -> float:
3484
+ def model(self, resistance):
2346
3485
  """
2347
- Calculate error between achieved and target coupling coefficient
3486
+ Run a simulation with a specified gap junction resistance.
2348
3487
 
2349
3488
  Parameters:
2350
3489
  -----------
2351
3490
  resistance : float
2352
- Gap junction resistance to try
2353
- target_cc : float
2354
- Target coupling coefficient to match
2355
-
2356
- Returns:
2357
- --------
2358
- float : Error between achieved and target coupling coefficient
2359
- """
2360
- # Run model with current resistance
2361
- self.tuner.model(resistance)
2362
-
2363
- # Calculate coupling coefficient
2364
- achieved_cc = self.tuner.coupling_coefficient(
2365
- self.tuner.t_vec,
2366
- self.tuner.soma_v_1,
2367
- self.tuner.soma_v_2,
2368
- self.tuner.general_settings["tstart"],
2369
- self.tuner.general_settings["tstart"] + self.tuner.general_settings["tdur"],
2370
- )
2371
-
2372
- # Calculate error
2373
- error = (achieved_cc - target_cc) ** 2 # MSE
2374
-
2375
- # Store history
2376
- self.optimization_history.append(
2377
- {"resistance": resistance, "achieved_cc": achieved_cc, "error": error}
2378
- )
2379
-
2380
- return error
2381
-
2382
- def optimize_resistance(
2383
- self, target_cc: float, resistance_bounds: tuple = (1e-4, 1e-2), method: str = "bounded"
2384
- ) -> GapOptimizationResult:
2385
- """
2386
- Optimize gap junction resistance to achieve a target coupling coefficient.
2387
-
2388
- Parameters:
2389
- -----------
2390
- target_cc : float
2391
- Target coupling coefficient to achieve (between 0 and 1)
2392
- resistance_bounds : tuple, optional
2393
- (min, max) bounds for resistance search in MOhm. Default is (1e-4, 1e-2).
2394
- method : str, optional
2395
- Optimization method to use. Default is 'bounded' which works well
2396
- for single-parameter optimization.
2397
-
2398
- Returns:
2399
- --------
2400
- GapOptimizationResult
2401
- Container with optimization results including:
2402
- - optimal_resistance: The optimized resistance value
2403
- - achieved_cc: The coupling coefficient achieved with the optimal resistance
2404
- - target_cc: The target coupling coefficient
2405
- - error: The final error (squared difference between target and achieved)
2406
- - optimization_path: List of all values tried during optimization
3491
+ The gap junction resistance value (in MOhm) to use for the simulation.
2407
3492
 
2408
3493
  Notes:
2409
3494
  ------
2410
- Uses scipy.optimize.minimize_scalar with bounded method, which is
2411
- appropriate for this single-parameter optimization problem.
3495
+ This method sets up the gap junction resistance, initializes recording vectors for time
3496
+ and membrane voltages of both cells, and runs the NEURON simulation.
2412
3497
  """
2413
- self.optimization_history = []
2414
-
2415
- # Run optimization
2416
- result = minimize_scalar(
2417
- self._objective_function, args=(target_cc,), bounds=resistance_bounds, method=method
2418
- )
3498
+ self.gap_junc_1.g = resistance
3499
+ self.gap_junc_2.g = resistance
2419
3500
 
2420
- # Run final model with optimal resistance
2421
- self.tuner.model(result.x)
2422
- final_cc = self.tuner.coupling_coefficient(
2423
- self.tuner.t_vec,
2424
- self.tuner.soma_v_1,
2425
- self.tuner.soma_v_2,
2426
- self.tuner.general_settings["tstart"],
2427
- self.tuner.general_settings["tstart"] + self.tuner.general_settings["tdur"],
2428
- )
3501
+ t_vec = h.Vector()
3502
+ soma_v_1 = h.Vector()
3503
+ soma_v_2 = h.Vector()
3504
+ t_vec.record(h._ref_t)
3505
+ soma_v_1.record(self.cell1.soma[0](0.5)._ref_v)
3506
+ soma_v_2.record(self.cell2.soma[0](0.5)._ref_v)
2429
3507
 
2430
- # Package up our results
2431
- optimization_result = GapOptimizationResult(
2432
- optimal_resistance=result.x,
2433
- achieved_cc=final_cc,
2434
- target_cc=target_cc,
2435
- error=result.fun,
2436
- optimization_path=self.optimization_history,
2437
- )
3508
+ self.t_vec = t_vec
3509
+ self.soma_v_1 = soma_v_1
3510
+ self.soma_v_2 = soma_v_2
2438
3511
 
2439
- return optimization_result
3512
+ h.finitialize(-70 * mV)
3513
+ h.continuerun(h.tstop * ms)
2440
3514
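A sketch of running the paired-cell simulation at one resistance value and inspecting the recorded traces (`gj_tuner` as above; the resistance value is illustrative):

    import numpy as np

    gj_tuner.model(1e-3)              # assigns the value to both Gap mechanisms' g
    v1 = np.array(gj_tuner.soma_v_1)  # injected cell
    v2 = np.array(gj_tuner.soma_v_2)  # coupled cell
    print(v1.min(), v2.min())         # with the default negative current step, both traces dip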
 
2441
- def plot_optimization_results(self, result: GapOptimizationResult):
3515
+ def plot_model(self):
2442
3516
  """
2443
- Plot optimization results including convergence and final voltage traces
3517
+ Plot the voltage traces of both cells to visualize gap junction coupling.
2444
3518
 
2445
- Parameters:
2446
- -----------
2447
- result : GapOptimizationResult
2448
- Results from optimization
3519
+ This method creates a plot showing the membrane potential of both cells over time,
3520
+ highlighting the effect of gap junction coupling when a current step is applied to cell 1.
2449
3521
  """
2450
- fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
2451
-
2452
- # Plot voltage traces
2453
3522
  t_range = [
2454
- self.tuner.general_settings["tstart"] - 100.0,
2455
- self.tuner.general_settings["tstart"] + self.tuner.general_settings["tdur"] + 100.0,
3523
+ self.general_settings["tstart"] - 100.0,
3524
+ self.general_settings["tstart"] + self.general_settings["tdur"] + 100.0,
2456
3525
  ]
2457
- t = np.array(self.tuner.t_vec)
2458
- v1 = np.array(self.tuner.soma_v_1)
2459
- v2 = np.array(self.tuner.soma_v_2)
3526
+ t = np.array(self.t_vec)
3527
+ v1 = np.array(self.soma_v_1)
3528
+ v2 = np.array(self.soma_v_2)
2460
3529
  tidx = (t >= t_range[0]) & (t <= t_range[1])
2461
3530
 
2462
- ax1.plot(t[tidx], v1[tidx], "b", label=f"{self.tuner.cell_name} 1")
2463
- ax1.plot(t[tidx], v2[tidx], "r", label=f"{self.tuner.cell_name} 2")
2464
- ax1.set_xlabel("Time (ms)")
2465
- ax1.set_ylabel("Membrane Voltage (mV)")
2466
- ax1.legend()
2467
- ax1.set_title("Optimized Voltage Traces")
2468
-
2469
- # Plot error convergence
2470
- errors = [h["error"] for h in result.optimization_path]
2471
- ax2.plot(errors)
2472
- ax2.set_xlabel("Iteration")
2473
- ax2.set_ylabel("Error")
2474
- ax2.set_title("Error Convergence")
2475
- ax2.set_yscale("log")
2476
-
2477
- # Plot resistance convergence
2478
- resistances = [h["resistance"] for h in result.optimization_path]
2479
- ax3.plot(resistances)
2480
- ax3.set_xlabel("Iteration")
2481
- ax3.set_ylabel("Resistance")
2482
- ax3.set_title("Resistance Convergence")
2483
- ax3.set_yscale("log")
2484
-
2485
- # Print final results
2486
- result_text = (
2487
- f"Optimal Resistance: {result.optimal_resistance:.2e}\n"
2488
- f"Target CC: {result.target_cc:.3f}\n"
2489
- f"Achieved CC: {result.achieved_cc:.3f}\n"
2490
- f"Final Error: {result.error:.2e}"
2491
- )
2492
- ax4.text(0.1, 0.7, result_text, transform=ax4.transAxes, fontsize=10)
2493
- ax4.axis("off")
2494
-
2495
- plt.tight_layout()
2496
- plt.show()
3531
+ plt.figure()
3532
+ plt.plot(t[tidx], v1[tidx], "b", label=f"{self.cell_name} 1")
3533
+ plt.plot(t[tidx], v2[tidx], "r", label=f"{self.cell_name} 2")
3534
+ plt.title(f"{self.cell_name} gap junction")
3535
+ plt.xlabel("Time (ms)")
3536
+ plt.ylabel("Membrane Voltage (mV)")
3537
+ plt.legend()
3538
+ self.last_figure = plt.gcf()
2497
3539
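A sketch of measuring the coupling coefficient over the current-step window after a run (`gj_tuner` as above; the window bounds mirror the step defined in general_settings):

    gj_tuner.model(1e-3)
    gj_tuner.plot_model()
    cc = gj_tuner.coupling_coefficient(
        gj_tuner.t_vec,
        gj_tuner.soma_v_1,
        gj_tuner.soma_v_2,
        gj_tuner.general_settings["tstart"],
        gj_tuner.general_settings["tstart"] + gj_tuner.general_settings["tdur"],
    )
    print(f"coupling coefficient: {cc:.3f}")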
 
2498
- def parameter_sweep(self, resistance_range: np.ndarray) -> dict:
3540
+ def coupling_coefficient(self, t, v1, v2, t_start, t_end, dt=h.dt):
2499
3541
  """
2500
- Perform a parameter sweep across different resistance values.
3542
+ Calculate the coupling coefficient between two cells connected by a gap junction.
2501
3543
 
2502
3544
  Parameters:
2503
3545
  -----------
2504
- resistance_range : np.ndarray
2505
- Array of resistance values to test.
3546
+ t : array-like
3547
+ Time vector.
3548
+ v1 : array-like
3549
+ Voltage trace of the cell receiving the current injection.
3550
+ v2 : array-like
3551
+ Voltage trace of the coupled cell.
3552
+ t_start : float
3553
+ Start time for calculating the steady-state voltage change.
3554
+ t_end : float
3555
+ End time for calculating the steady-state voltage change.
3556
+ dt : float, optional
3557
+ Time step of the simulation. Default is h.dt.
2506
3558
 
2507
3559
  Returns:
2508
3560
  --------
2509
- dict
2510
- Dictionary containing the results of the parameter sweep, with keys:
2511
- - 'resistance': List of resistance values tested
2512
- - 'coupling_coefficient': Corresponding coupling coefficients
2513
-
2514
- Notes:
2515
- ------
2516
- This method is useful for understanding the relationship between gap junction
2517
- resistance and coupling coefficient before attempting optimization.
2518
- """
2519
- results = {"resistance": [], "coupling_coefficient": []}
2520
-
2521
- for resistance in tqdm(resistance_range, desc="Sweeping resistance values"):
2522
- self.tuner.model(resistance)
2523
- cc = self.tuner.coupling_coefficient(
2524
- self.tuner.t_vec,
2525
- self.tuner.soma_v_1,
2526
- self.tuner.soma_v_2,
2527
- self.tuner.general_settings["tstart"],
2528
- self.tuner.general_settings["tstart"] + self.tuner.general_settings["tdur"],
2529
- )
2530
-
2531
- results["resistance"].append(resistance)
2532
- results["coupling_coefficient"].append(cc)
2533
-
2534
- return results
3561
+ float
3562
+ The coupling coefficient, defined as the ratio of voltage change in cell 2
3563
+ to voltage change in cell 1 (ΔV₂/ΔV₁).
3564
+ """
3565
+ t = np.asarray(t)
3566
+ v1 = np.asarray(v1)
3567
+ v2 = np