bmtool 0.7.8__tar.gz → 0.7.8.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bmtool might be problematic. Click here for more details.
- {bmtool-0.7.8 → bmtool-0.7.8.1}/PKG-INFO +1 -1
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/SLURM.py +15 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/synapses.py +548 -443
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool.egg-info/PKG-INFO +1 -1
- {bmtool-0.7.8 → bmtool-0.7.8.1}/setup.py +1 -1
- {bmtool-0.7.8 → bmtool-0.7.8.1}/LICENSE +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/README.md +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/__init__.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/__main__.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/analysis/__init__.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/analysis/entrainment.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/analysis/lfp.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/analysis/netcon_reports.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/analysis/spikes.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/bmplot/__init__.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/bmplot/connections.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/bmplot/entrainment.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/bmplot/lfp.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/bmplot/netcon_reports.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/bmplot/spikes.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/connectors.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/debug/__init__.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/debug/commands.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/debug/debug.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/graphs.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/manage.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/plot_commands.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/singlecell.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/util/__init__.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/util/commands.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/util/neuron/__init__.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/util/neuron/celltuner.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool/util/util.py +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool.egg-info/SOURCES.txt +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool.egg-info/dependency_links.txt +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool.egg-info/entry_points.txt +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool.egg-info/requires.txt +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/bmtool.egg-info/top_level.txt +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/pyproject.toml +0 -0
- {bmtool-0.7.8 → bmtool-0.7.8.1}/setup.cfg +0 -0
|
@@ -409,6 +409,16 @@ class BlockRunner:
|
|
|
409
409
|
self.param_name = param_name
|
|
410
410
|
self.json_file_path = json_file_path
|
|
411
411
|
self.syn_dict = syn_dict
|
|
412
|
+
# Store original component paths to restore later
|
|
413
|
+
self.original_component_paths = [block.component_path for block in self.blocks]
|
|
414
|
+
|
|
415
|
+
def restore_component_paths(self):
|
|
416
|
+
"""
|
|
417
|
+
Restores all blocks' component_path to their original values.
|
|
418
|
+
"""
|
|
419
|
+
for i, block in enumerate(self.blocks):
|
|
420
|
+
block.component_path = self.original_component_paths[i]
|
|
421
|
+
print("Component paths restored to original values.", flush=True)
|
|
412
422
|
|
|
413
423
|
def submit_blocks_sequentially(self):
|
|
414
424
|
"""
|
|
@@ -473,6 +483,8 @@ class BlockRunner:
|
|
|
473
483
|
|
|
474
484
|
print(f"Block {block.block_name} completed.", flush=True)
|
|
475
485
|
print("All blocks are done!", flush=True)
|
|
486
|
+
# Restore component paths to their original values
|
|
487
|
+
self.restore_component_paths()
|
|
476
488
|
if self.webhook:
|
|
477
489
|
message = "SIMULATION UPDATE: Simulation are Done!"
|
|
478
490
|
send_teams_message(self.webhook, message)
|
|
@@ -531,6 +543,9 @@ class BlockRunner:
|
|
|
531
543
|
print(f"Waiting for the last block {i} to complete...")
|
|
532
544
|
time.sleep(self.check_interval)
|
|
533
545
|
|
|
546
|
+
print("All blocks are done!", flush=True)
|
|
547
|
+
# Restore component paths to their original values
|
|
548
|
+
self.restore_component_paths()
|
|
534
549
|
if self.webhook:
|
|
535
550
|
message = "SIMULATION UPDATE: Simulations are Done!"
|
|
536
551
|
send_teams_message(self.webhook, message)
|
|
@@ -3078,490 +3078,595 @@ class GapJunctionTuner:
|
|
|
3078
3078
|
|
|
3079
3079
|
return df
|
|
3080
3080
|
|
|
3081
|
+
# optimizers!
|
|
3081
3082
|
|
|
3082
|
-
class GapJunctionTuner:
|
|
3083
|
-
def __init__(
|
|
3084
|
-
self,
|
|
3085
|
-
mechanisms_dir: Optional[str] = None,
|
|
3086
|
-
templates_dir: Optional[str] = None,
|
|
3087
|
-
config: Optional[str] = None,
|
|
3088
|
-
general_settings: Optional[dict] = None,
|
|
3089
|
-
conn_type_settings: Optional[dict] = None,
|
|
3090
|
-
hoc_cell: Optional[object] = None,
|
|
3091
|
-
):
|
|
3092
|
-
"""
|
|
3093
|
-
Initialize the GapJunctionTuner class.
|
|
3094
|
-
|
|
3095
|
-
Parameters:
|
|
3096
|
-
-----------
|
|
3097
|
-
mechanisms_dir : str
|
|
3098
|
-
Directory path containing the compiled mod files needed for NEURON mechanisms.
|
|
3099
|
-
templates_dir : str
|
|
3100
|
-
Directory path containing cell template files (.hoc or .py) loaded into NEURON.
|
|
3101
|
-
config : str
|
|
3102
|
-
Path to a BMTK config.json file. Can be used to load mechanisms, templates, and other settings.
|
|
3103
|
-
general_settings : dict
|
|
3104
|
-
General settings dictionary including parameters like simulation time step, duration, and temperature.
|
|
3105
|
-
conn_type_settings : dict
|
|
3106
|
-
A dictionary containing connection-specific settings for gap junctions.
|
|
3107
|
-
hoc_cell : object, optional
|
|
3108
|
-
An already loaded NEURON cell object. If provided, template loading and cell creation will be skipped.
|
|
3109
|
-
"""
|
|
3110
|
-
self.hoc_cell = hoc_cell
|
|
3111
|
-
|
|
3112
|
-
if hoc_cell is None:
|
|
3113
|
-
if config is None and (mechanisms_dir is None or templates_dir is None):
|
|
3114
|
-
raise ValueError(
|
|
3115
|
-
"Either a config file, both mechanisms_dir and templates_dir, or a hoc_cell must be provided."
|
|
3116
|
-
)
|
|
3117
|
-
|
|
3118
|
-
if config is None:
|
|
3119
|
-
neuron.load_mechanisms(mechanisms_dir)
|
|
3120
|
-
h.load_file(templates_dir)
|
|
3121
|
-
else:
|
|
3122
|
-
# this will load both mechs and templates
|
|
3123
|
-
load_templates_from_config(config)
|
|
3124
3083
|
|
|
3125
|
-
|
|
3126
|
-
|
|
3127
|
-
|
|
3128
|
-
else:
|
|
3129
|
-
self.general_settings = {**DEFAULT_GAP_JUNCTION_GENERAL_SETTINGS, **general_settings}
|
|
3130
|
-
self.conn_type_settings = conn_type_settings
|
|
3084
|
+
@dataclass
|
|
3085
|
+
class SynapseOptimizationResult:
|
|
3086
|
+
"""Container for synaptic parameter optimization results"""
|
|
3131
3087
|
|
|
3132
|
-
|
|
3133
|
-
|
|
3134
|
-
|
|
3135
|
-
|
|
3136
|
-
|
|
3137
|
-
if self.conn_type_settings is None and self.config is not None:
|
|
3138
|
-
self.conn_type_settings = self._build_conn_type_settings_from_config(self.config)
|
|
3139
|
-
if self.conn_type_settings is None or len(self.conn_type_settings) == 0:
|
|
3140
|
-
raise ValueError("conn_type_settings must be provided or config must be given to load gap junction connections from")
|
|
3141
|
-
self.current_connection = list(self.conn_type_settings.keys())[0]
|
|
3142
|
-
self.conn = self.conn_type_settings[self.current_connection]
|
|
3088
|
+
optimal_params: Dict[str, float]
|
|
3089
|
+
achieved_metrics: Dict[str, float]
|
|
3090
|
+
target_metrics: Dict[str, float]
|
|
3091
|
+
error: float
|
|
3092
|
+
optimization_path: List[Dict[str, float]]
|
|
3143
3093
|
|
|
3144
|
-
h.tstop = self.general_settings["tstart"] + self.general_settings["tdur"] + 100.0
|
|
3145
|
-
h.dt = self.general_settings["dt"] # Time step (resolution) of the simulation in ms
|
|
3146
|
-
h.steps_per_ms = 1 / h.dt
|
|
3147
|
-
h.celsius = self.general_settings["celsius"]
|
|
3148
3094
|
|
|
3149
|
-
|
|
3150
|
-
|
|
3151
|
-
|
|
3152
|
-
|
|
3153
|
-
except:
|
|
3154
|
-
pass # Ignore errors if no existing context
|
|
3155
|
-
|
|
3156
|
-
# Force cleanup
|
|
3157
|
-
import gc
|
|
3158
|
-
gc.collect()
|
|
3159
|
-
|
|
3160
|
-
# set up gap junctions
|
|
3161
|
-
self.pc = h.ParallelContext()
|
|
3162
|
-
|
|
3163
|
-
# Use provided hoc_cell or create new cells
|
|
3164
|
-
if self.hoc_cell is not None:
|
|
3165
|
-
self.cell1 = self.hoc_cell
|
|
3166
|
-
# For gap junctions, we need two cells, so create a second one if using hoc_cell
|
|
3167
|
-
self.cell_name = self.conn['cell']
|
|
3168
|
-
self.cell2 = getattr(h, self.cell_name)()
|
|
3169
|
-
else:
|
|
3170
|
-
print(self.conn)
|
|
3171
|
-
self.cell_name = self.conn['cell']
|
|
3172
|
-
self.cell1 = getattr(h, self.cell_name)()
|
|
3173
|
-
self.cell2 = getattr(h, self.cell_name)()
|
|
3174
|
-
|
|
3175
|
-
self.icl = h.IClamp(self.cell1.soma[0](0.5))
|
|
3176
|
-
self.icl.delay = self.general_settings["tstart"]
|
|
3177
|
-
self.icl.dur = self.general_settings["tdur"]
|
|
3178
|
-
self.icl.amp = self.general_settings["iclamp_amp"] # nA
|
|
3179
|
-
|
|
3180
|
-
sec1 = list(self.cell1.all)[self.conn["sec_id"]]
|
|
3181
|
-
sec2 = list(self.cell2.all)[self.conn["sec_id"]]
|
|
3182
|
-
|
|
3183
|
-
# Use unique IDs to avoid conflicts with existing parallel context setups
|
|
3184
|
-
import time
|
|
3185
|
-
unique_id = int(time.time() * 1000) % 10000 # Use timestamp as unique base ID
|
|
3186
|
-
|
|
3187
|
-
self.pc.source_var(sec1(self.conn["sec_x"])._ref_v, unique_id, sec=sec1)
|
|
3188
|
-
self.gap_junc_1 = h.Gap(sec1(0.5))
|
|
3189
|
-
self.pc.target_var(self.gap_junc_1._ref_vgap, unique_id + 1)
|
|
3190
|
-
|
|
3191
|
-
self.pc.source_var(sec2(self.conn["sec_x"])._ref_v, unique_id + 1, sec=sec2)
|
|
3192
|
-
self.gap_junc_2 = h.Gap(sec2(0.5))
|
|
3193
|
-
self.pc.target_var(self.gap_junc_2._ref_vgap, unique_id)
|
|
3194
|
-
|
|
3195
|
-
self.pc.setup_transfer()
|
|
3196
|
-
|
|
3197
|
-
# Now it's safe to initialize NEURON
|
|
3198
|
-
h.finitialize()
|
|
3199
|
-
|
|
3200
|
-
def _load_synaptic_params_from_config(self, config: dict, dynamics_params: str) -> dict:
|
|
3201
|
-
try:
|
|
3202
|
-
# Get the synaptic models directory from config
|
|
3203
|
-
synaptic_models_dir = config.get('components', {}).get('synaptic_models_dir', '')
|
|
3204
|
-
if synaptic_models_dir:
|
|
3205
|
-
# Handle path variables
|
|
3206
|
-
if synaptic_models_dir.startswith('$'):
|
|
3207
|
-
# This is a placeholder, try to resolve it
|
|
3208
|
-
config_dir = os.path.dirname(config.get('config_path', ''))
|
|
3209
|
-
synaptic_models_dir = synaptic_models_dir.replace('$COMPONENTS_DIR',
|
|
3210
|
-
os.path.join(config_dir, 'components'))
|
|
3211
|
-
synaptic_models_dir = synaptic_models_dir.replace('$BASE_DIR', config_dir)
|
|
3212
|
-
|
|
3213
|
-
dynamics_file = os.path.join(synaptic_models_dir, dynamics_params)
|
|
3214
|
-
|
|
3215
|
-
if os.path.exists(dynamics_file):
|
|
3216
|
-
with open(dynamics_file, 'r') as f:
|
|
3217
|
-
return json.load(f)
|
|
3218
|
-
else:
|
|
3219
|
-
print(f"Warning: Dynamics params file not found: {dynamics_file}")
|
|
3220
|
-
except Exception as e:
|
|
3221
|
-
print(f"Warning: Error loading synaptic parameters: {e}")
|
|
3222
|
-
|
|
3223
|
-
return {}
|
|
3095
|
+
class SynapseOptimizer:
|
|
3096
|
+
def __init__(self, tuner):
|
|
3097
|
+
"""
|
|
3098
|
+
Initialize the synapse optimizer with parameter scaling
|
|
3224
3099
|
|
|
3225
|
-
|
|
3100
|
+
Parameters:
|
|
3101
|
+
-----------
|
|
3102
|
+
tuner : SynapseTuner
|
|
3103
|
+
Instance of the SynapseTuner class
|
|
3226
3104
|
"""
|
|
3227
|
-
|
|
3228
|
-
|
|
3229
|
-
|
|
3230
|
-
|
|
3231
|
-
|
|
3232
|
-
|
|
3233
|
-
Network Dropdown Behavior:
|
|
3234
|
-
-------------------------
|
|
3235
|
-
- If only one network exists: No network dropdown is shown
|
|
3236
|
-
- If multiple networks exist: Network dropdown appears next to connection dropdown
|
|
3237
|
-
- Networks are loaded from the edges data in the config file
|
|
3238
|
-
- Current network defaults to the first available if not specified during init
|
|
3105
|
+
self.tuner = tuner
|
|
3106
|
+
self.optimization_history = []
|
|
3107
|
+
self.param_scales = {}
|
|
3108
|
+
|
|
3109
|
+
def _normalize_params(self, params: np.ndarray, param_names: List[str]) -> np.ndarray:
|
|
3239
3110
|
"""
|
|
3240
|
-
|
|
3241
|
-
self.available_networks = []
|
|
3242
|
-
return
|
|
3243
|
-
|
|
3244
|
-
try:
|
|
3245
|
-
edges = load_edges_from_config(self.config)
|
|
3246
|
-
self.available_networks = list(edges.keys())
|
|
3247
|
-
|
|
3248
|
-
# Set current network to first available if not specified
|
|
3249
|
-
if self.current_network is None and self.available_networks:
|
|
3250
|
-
self.current_network = self.available_networks[0]
|
|
3251
|
-
except Exception as e:
|
|
3252
|
-
print(f"Warning: Could not load networks from config: {e}")
|
|
3253
|
-
self.available_networks = []
|
|
3111
|
+
Normalize parameters to similar scales for better optimization performance.
|
|
3254
3112
|
|
|
3255
|
-
|
|
3256
|
-
|
|
3257
|
-
|
|
3258
|
-
|
|
3259
|
-
|
|
3260
|
-
|
|
3261
|
-
except Exception:
|
|
3262
|
-
pass
|
|
3263
|
-
nodes = load_nodes_from_config(config_path)
|
|
3264
|
-
edges = load_edges_from_config(config_path)
|
|
3265
|
-
|
|
3266
|
-
conn_type_settings = {}
|
|
3267
|
-
|
|
3268
|
-
# Process all edge datasets
|
|
3269
|
-
for edge_dataset_name, edge_df in edges.items():
|
|
3270
|
-
if edge_df.empty:
|
|
3271
|
-
continue
|
|
3272
|
-
|
|
3273
|
-
# Merging with node data to get model templates
|
|
3274
|
-
source_node_df = None
|
|
3275
|
-
target_node_df = None
|
|
3276
|
-
|
|
3277
|
-
# First, try to deterministically parse the edge_dataset_name for patterns like '<src>_to_<tgt>'
|
|
3278
|
-
if '_to_' in edge_dataset_name:
|
|
3279
|
-
parts = edge_dataset_name.split('_to_')
|
|
3280
|
-
if len(parts) == 2:
|
|
3281
|
-
src_name, tgt_name = parts
|
|
3282
|
-
if src_name in nodes:
|
|
3283
|
-
source_node_df = nodes[src_name].add_prefix('source_')
|
|
3284
|
-
if tgt_name in nodes:
|
|
3285
|
-
target_node_df = nodes[tgt_name].add_prefix('target_')
|
|
3286
|
-
|
|
3287
|
-
# If not found by parsing name, fall back to inspecting a sample edge row
|
|
3288
|
-
if source_node_df is None or target_node_df is None:
|
|
3289
|
-
sample_edge = edge_df.iloc[0] if len(edge_df) > 0 else None
|
|
3290
|
-
if sample_edge is not None:
|
|
3291
|
-
source_pop_name = sample_edge.get('source_population', '')
|
|
3292
|
-
target_pop_name = sample_edge.get('target_population', '')
|
|
3293
|
-
if source_pop_name in nodes:
|
|
3294
|
-
source_node_df = nodes[source_pop_name].add_prefix('source_')
|
|
3295
|
-
if target_pop_name in nodes:
|
|
3296
|
-
target_node_df = nodes[target_pop_name].add_prefix('target_')
|
|
3297
|
-
|
|
3298
|
-
# As a last resort, attempt to heuristically match
|
|
3299
|
-
if source_node_df is None or target_node_df is None:
|
|
3300
|
-
for pop_name, node_df in nodes.items():
|
|
3301
|
-
if source_node_df is None and (edge_dataset_name.startswith(pop_name) or edge_dataset_name.endswith(pop_name)):
|
|
3302
|
-
source_node_df = node_df.add_prefix('source_')
|
|
3303
|
-
if target_node_df is None and (edge_dataset_name.startswith(pop_name) or edge_dataset_name.endswith(pop_name)):
|
|
3304
|
-
target_node_df = node_df.add_prefix('target_')
|
|
3305
|
-
|
|
3306
|
-
if source_node_df is None or target_node_df is None:
|
|
3307
|
-
print(f"Warning: Could not find node data for edge dataset {edge_dataset_name}")
|
|
3308
|
-
continue
|
|
3309
|
-
|
|
3310
|
-
# Merge edge data with source node info
|
|
3311
|
-
edges_with_source = pd.merge(
|
|
3312
|
-
edge_df.reset_index(),
|
|
3313
|
-
source_node_df,
|
|
3314
|
-
how='left',
|
|
3315
|
-
left_on='source_node_id',
|
|
3316
|
-
right_index=True
|
|
3317
|
-
)
|
|
3318
|
-
|
|
3319
|
-
# Merge with target node info
|
|
3320
|
-
edges_with_nodes = pd.merge(
|
|
3321
|
-
edges_with_source,
|
|
3322
|
-
target_node_df,
|
|
3323
|
-
how='left',
|
|
3324
|
-
left_on='target_node_id',
|
|
3325
|
-
right_index=True
|
|
3326
|
-
)
|
|
3327
|
-
|
|
3328
|
-
# Skip edge datasets that don't have gap junction information
|
|
3329
|
-
if 'is_gap_junction' not in edges_with_nodes.columns:
|
|
3330
|
-
continue
|
|
3331
|
-
|
|
3332
|
-
# Filter to only gap junction edges
|
|
3333
|
-
# Handle NaN values in is_gap_junction column
|
|
3334
|
-
gap_junction_mask = edges_with_nodes['is_gap_junction'].fillna(False) == True
|
|
3335
|
-
gap_junction_edges = edges_with_nodes[gap_junction_mask]
|
|
3336
|
-
if gap_junction_edges.empty:
|
|
3337
|
-
continue
|
|
3338
|
-
|
|
3339
|
-
# Get unique edge types from the gap junction edges
|
|
3340
|
-
if 'edge_type_id' in gap_junction_edges.columns:
|
|
3341
|
-
edge_types = gap_junction_edges['edge_type_id'].unique()
|
|
3342
|
-
else:
|
|
3343
|
-
edge_types = [None] # Single edge type
|
|
3344
|
-
|
|
3345
|
-
# Process each edge type
|
|
3346
|
-
for edge_type_id in edge_types:
|
|
3347
|
-
# Filter edges for this type
|
|
3348
|
-
if edge_type_id is not None:
|
|
3349
|
-
edge_type_data = gap_junction_edges[gap_junction_edges['edge_type_id'] == edge_type_id]
|
|
3350
|
-
else:
|
|
3351
|
-
edge_type_data = gap_junction_edges
|
|
3352
|
-
|
|
3353
|
-
if len(edge_type_data) == 0:
|
|
3354
|
-
continue
|
|
3355
|
-
|
|
3356
|
-
# Get representative edge for this type
|
|
3357
|
-
edge_info = edge_type_data.iloc[0]
|
|
3358
|
-
|
|
3359
|
-
# Process gap junction
|
|
3360
|
-
source_model_template = edge_info.get('source_model_template', '')
|
|
3361
|
-
target_model_template = edge_info.get('target_model_template', '')
|
|
3362
|
-
|
|
3363
|
-
source_cell_type = source_model_template.replace('hoc:', '') if source_model_template.startswith('hoc:') else source_model_template
|
|
3364
|
-
target_cell_type = target_model_template.replace('hoc:', '') if target_model_template.startswith('hoc:') else target_model_template
|
|
3365
|
-
|
|
3366
|
-
if source_cell_type != target_cell_type:
|
|
3367
|
-
continue # Only process gap junctions between same cell types
|
|
3368
|
-
|
|
3369
|
-
source_pop = edge_info.get('source_pop_name', '')
|
|
3370
|
-
target_pop = edge_info.get('target_pop_name', '')
|
|
3371
|
-
|
|
3372
|
-
conn_name = f"{source_pop}2{target_pop}_gj"
|
|
3373
|
-
if edge_type_id is not None:
|
|
3374
|
-
conn_name += f"_type_{edge_type_id}"
|
|
3375
|
-
|
|
3376
|
-
conn_settings = {
|
|
3377
|
-
'cell': source_cell_type,
|
|
3378
|
-
'sec_id': 0,
|
|
3379
|
-
'sec_x': 0.5,
|
|
3380
|
-
'iclamp_amp': -0.01,
|
|
3381
|
-
'spec_syn_param': {}
|
|
3382
|
-
}
|
|
3383
|
-
|
|
3384
|
-
# Load dynamics params
|
|
3385
|
-
dynamics_file_name = edge_info.get('dynamics_params', '')
|
|
3386
|
-
if dynamics_file_name and dynamics_file_name.upper() != 'NULL':
|
|
3387
|
-
try:
|
|
3388
|
-
syn_params = self._load_synaptic_params_from_config(config, dynamics_file_name)
|
|
3389
|
-
conn_settings['spec_syn_param'] = syn_params
|
|
3390
|
-
except Exception as e:
|
|
3391
|
-
print(f"Warning: could not load dynamics_params file '{dynamics_file_name}': {e}")
|
|
3392
|
-
|
|
3393
|
-
conn_type_settings[conn_name] = conn_settings
|
|
3394
|
-
|
|
3395
|
-
return conn_type_settings
|
|
3113
|
+
Parameters:
|
|
3114
|
+
-----------
|
|
3115
|
+
params : np.ndarray
|
|
3116
|
+
Original parameter values.
|
|
3117
|
+
param_names : List[str]
|
|
3118
|
+
Names of the parameters corresponding to the values.
|
|
3396
3119
|
|
|
3397
|
-
|
|
3120
|
+
Returns:
|
|
3121
|
+
--------
|
|
3122
|
+
np.ndarray
|
|
3123
|
+
Normalized parameter values.
|
|
3398
3124
|
"""
|
|
3399
|
-
|
|
3400
|
-
|
|
3125
|
+
return np.array([params[i] / self.param_scales[name] for i, name in enumerate(param_names)])
|
|
3126
|
+
|
|
3127
|
+
def _denormalize_params(
|
|
3128
|
+
self, normalized_params: np.ndarray, param_names: List[str]
|
|
3129
|
+
) -> np.ndarray:
|
|
3130
|
+
"""
|
|
3131
|
+
Convert normalized parameters back to original scale.
|
|
3132
|
+
|
|
3401
3133
|
Parameters:
|
|
3402
3134
|
-----------
|
|
3403
|
-
|
|
3404
|
-
|
|
3135
|
+
normalized_params : np.ndarray
|
|
3136
|
+
Normalized parameter values.
|
|
3137
|
+
param_names : List[str]
|
|
3138
|
+
Names of the parameters corresponding to the normalized values.
|
|
3139
|
+
|
|
3140
|
+
Returns:
|
|
3141
|
+
--------
|
|
3142
|
+
np.ndarray
|
|
3143
|
+
Denormalized parameter values in their original scale.
|
|
3405
3144
|
"""
|
|
3406
|
-
|
|
3407
|
-
|
|
3408
|
-
|
|
3409
|
-
|
|
3410
|
-
|
|
3411
|
-
|
|
3412
|
-
|
|
3413
|
-
|
|
3414
|
-
|
|
3415
|
-
|
|
3416
|
-
|
|
3417
|
-
|
|
3418
|
-
|
|
3419
|
-
|
|
3420
|
-
|
|
3421
|
-
|
|
3422
|
-
|
|
3423
|
-
|
|
3424
|
-
|
|
3425
|
-
|
|
3426
|
-
|
|
3427
|
-
|
|
3428
|
-
|
|
3429
|
-
|
|
3430
|
-
|
|
3145
|
+
return np.array(
|
|
3146
|
+
[normalized_params[i] * self.param_scales[name] for i, name in enumerate(param_names)]
|
|
3147
|
+
)
|
|
3148
|
+
|
|
3149
|
+
def _calculate_metrics(self) -> Dict[str, float]:
|
|
3150
|
+
"""
|
|
3151
|
+
Calculate standard metrics from the current simulation.
|
|
3152
|
+
|
|
3153
|
+
This method runs either a single event simulation, a train input simulation,
|
|
3154
|
+
or both based on configuration flags, and calculates relevant synaptic metrics.
|
|
3155
|
+
|
|
3156
|
+
Returns:
|
|
3157
|
+
--------
|
|
3158
|
+
Dict[str, float]
|
|
3159
|
+
Dictionary of calculated metrics including:
|
|
3160
|
+
- induction: measure of synaptic facilitation/depression
|
|
3161
|
+
- ppr: paired-pulse ratio
|
|
3162
|
+
- recovery: recovery from facilitation/depression
|
|
3163
|
+
- max_amplitude: maximum synaptic response amplitude
|
|
3164
|
+
- rise_time: time for synaptic response to rise from 20% to 80% of peak
|
|
3165
|
+
- decay_time: time constant of synaptic response decay
|
|
3166
|
+
- latency: synaptic response latency
|
|
3167
|
+
- half_width: synaptic response half-width
|
|
3168
|
+
- baseline: baseline current
|
|
3169
|
+
- amp: peak amplitude from syn_props
|
|
3170
|
+
"""
|
|
3171
|
+
# Set these to 0 for when we return the dict
|
|
3172
|
+
induction = 0
|
|
3173
|
+
ppr = 0
|
|
3174
|
+
recovery = 0
|
|
3175
|
+
simple_ppr = 0
|
|
3176
|
+
amp = 0
|
|
3177
|
+
rise_time = 0
|
|
3178
|
+
decay_time = 0
|
|
3179
|
+
latency = 0
|
|
3180
|
+
half_width = 0
|
|
3181
|
+
baseline = 0
|
|
3182
|
+
syn_amp = 0
|
|
3183
|
+
|
|
3184
|
+
if self.run_single_event:
|
|
3185
|
+
self.tuner.SingleEvent(plot_and_print=False)
|
|
3186
|
+
# Use the attributes set by SingleEvent method
|
|
3187
|
+
rise_time = getattr(self.tuner, "rise_time", 0)
|
|
3188
|
+
decay_time = getattr(self.tuner, "decay_time", 0)
|
|
3189
|
+
# Get additional syn_props directly
|
|
3190
|
+
syn_props = self.tuner._get_syn_prop()
|
|
3191
|
+
latency = syn_props.get("latency", 0)
|
|
3192
|
+
half_width = syn_props.get("half_width", 0)
|
|
3193
|
+
baseline = syn_props.get("baseline", 0)
|
|
3194
|
+
syn_amp = syn_props.get("amp", 0)
|
|
3195
|
+
|
|
3196
|
+
if self.run_train_input:
|
|
3197
|
+
self.tuner._simulate_model(self.train_frequency, self.train_delay)
|
|
3198
|
+
amp = self.tuner._response_amplitude()
|
|
3199
|
+
ppr, induction, recovery, simple_ppr = self.tuner._calc_ppr_induction_recovery(
|
|
3200
|
+
amp, print_math=False
|
|
3201
|
+
)
|
|
3202
|
+
amp = self.tuner._find_max_amp(amp)
|
|
3203
|
+
|
|
3204
|
+
return {
|
|
3205
|
+
"induction": float(induction),
|
|
3206
|
+
"ppr": float(ppr),
|
|
3207
|
+
"recovery": float(recovery),
|
|
3208
|
+
"simple_ppr": float(simple_ppr),
|
|
3209
|
+
"max_amplitude": float(amp),
|
|
3210
|
+
"rise_time": float(rise_time),
|
|
3211
|
+
"decay_time": float(decay_time),
|
|
3212
|
+
"latency": float(latency),
|
|
3213
|
+
"half_width": float(half_width),
|
|
3214
|
+
"baseline": float(baseline),
|
|
3215
|
+
"amp": float(syn_amp),
|
|
3216
|
+
}
|
|
3217
|
+
|
|
3218
|
+
def _default_cost_function(
|
|
3219
|
+
self, metrics: Dict[str, float], target_metrics: Dict[str, float]
|
|
3220
|
+
) -> float:
|
|
3221
|
+
"""
|
|
3222
|
+
Default cost function that minimizes the squared difference between achieved and target induction.
|
|
3223
|
+
|
|
3224
|
+
Parameters:
|
|
3225
|
+
-----------
|
|
3226
|
+
metrics : Dict[str, float]
|
|
3227
|
+
Dictionary of calculated metrics from the current simulation.
|
|
3228
|
+
target_metrics : Dict[str, float]
|
|
3229
|
+
Dictionary of target metrics to optimize towards.
|
|
3230
|
+
|
|
3231
|
+
Returns:
|
|
3232
|
+
--------
|
|
3233
|
+
float
|
|
3234
|
+
The squared error between achieved and target induction.
|
|
3235
|
+
"""
|
|
3236
|
+
return float((metrics["induction"] - target_metrics["induction"]) ** 2)
|
|
3237
|
+
|
|
3238
|
+
def _objective_function(
|
|
3239
|
+
self,
|
|
3240
|
+
normalized_params: np.ndarray,
|
|
3241
|
+
param_names: List[str],
|
|
3242
|
+
cost_function: Callable,
|
|
3243
|
+
target_metrics: Dict[str, float],
|
|
3244
|
+
) -> float:
|
|
3245
|
+
"""
|
|
3246
|
+
Calculate error using provided cost function
|
|
3247
|
+
"""
|
|
3248
|
+
# Denormalize parameters
|
|
3249
|
+
params = self._denormalize_params(normalized_params, param_names)
|
|
3250
|
+
|
|
3251
|
+
# Set parameters
|
|
3252
|
+
for name, value in zip(param_names, params):
|
|
3253
|
+
setattr(self.tuner.syn, name, value)
|
|
3254
|
+
|
|
3255
|
+
# just do this and have the SingleEvent handle it
|
|
3256
|
+
if self.run_single_event:
|
|
3257
|
+
self.tuner.using_optimizer = True
|
|
3258
|
+
self.tuner.param_names = param_names
|
|
3259
|
+
self.tuner.params = params
|
|
3260
|
+
|
|
3261
|
+
# Calculate metrics and error
|
|
3262
|
+
metrics = self._calculate_metrics()
|
|
3263
|
+
error = float(cost_function(metrics, target_metrics)) # Ensure error is scalar
|
|
3264
|
+
|
|
3265
|
+
# Store history with denormalized values
|
|
3266
|
+
history_entry = {
|
|
3267
|
+
"params": dict(zip(param_names, params)),
|
|
3268
|
+
"metrics": metrics,
|
|
3269
|
+
"error": error,
|
|
3270
|
+
}
|
|
3271
|
+
self.optimization_history.append(history_entry)
|
|
3272
|
+
|
|
3273
|
+
return error
|
|
3274
|
+
|
|
3275
|
+
def optimize_parameters(
|
|
3276
|
+
self,
|
|
3277
|
+
target_metrics: Dict[str, float],
|
|
3278
|
+
param_bounds: Dict[str, Tuple[float, float]],
|
|
3279
|
+
run_single_event: bool = False,
|
|
3280
|
+
run_train_input: bool = True,
|
|
3281
|
+
train_frequency: float = 50,
|
|
3282
|
+
train_delay: float = 250,
|
|
3283
|
+
cost_function: Optional[Callable] = None,
|
|
3284
|
+
method: str = "SLSQP",
|
|
3285
|
+
init_guess="random",
|
|
3286
|
+
) -> SynapseOptimizationResult:
|
|
3287
|
+
"""
|
|
3288
|
+
Optimize synaptic parameters to achieve target metrics.
|
|
3289
|
+
|
|
3290
|
+
Parameters:
|
|
3291
|
+
-----------
|
|
3292
|
+
target_metrics : Dict[str, float]
|
|
3293
|
+
Target values for synaptic metrics (e.g., {'induction': 0.2, 'rise_time': 0.5})
|
|
3294
|
+
param_bounds : Dict[str, Tuple[float, float]]
|
|
3295
|
+
Bounds for each parameter to optimize (e.g., {'tau_d': (5, 50), 'Use': (0.1, 0.9)})
|
|
3296
|
+
run_single_event : bool, optional
|
|
3297
|
+
Whether to run single event simulations during optimization (default: False)
|
|
3298
|
+
run_train_input : bool, optional
|
|
3299
|
+
Whether to run train input simulations during optimization (default: True)
|
|
3300
|
+
train_frequency : float, optional
|
|
3301
|
+
Frequency of the stimulus train in Hz (default: 50)
|
|
3302
|
+
train_delay : float, optional
|
|
3303
|
+
Delay between pulse trains in ms (default: 250)
|
|
3304
|
+
cost_function : Optional[Callable]
|
|
3305
|
+
Custom cost function for optimization. If None, uses default cost function
|
|
3306
|
+
that optimizes induction.
|
|
3307
|
+
method : str, optional
|
|
3308
|
+
Optimization method to use (default: 'SLSQP')
|
|
3309
|
+
init_guess : str, optional
|
|
3310
|
+
Method for initial parameter guess ('random' or 'middle_guess')
|
|
3311
|
+
|
|
3312
|
+
Returns:
|
|
3313
|
+
--------
|
|
3314
|
+
SynapseOptimizationResult
|
|
3315
|
+
Results of the optimization including optimal parameters, achieved metrics,
|
|
3316
|
+
target metrics, final error, and optimization path.
|
|
3317
|
+
|
|
3318
|
+
Notes:
|
|
3319
|
+
------
|
|
3320
|
+
This function uses scipy.optimize.minimize to find the optimal parameter values
|
|
3321
|
+
that minimize the difference between achieved and target metrics.
|
|
3322
|
+
"""
|
|
3323
|
+
self.optimization_history = []
|
|
3324
|
+
self.train_frequency = train_frequency
|
|
3325
|
+
self.train_delay = train_delay
|
|
3326
|
+
self.run_single_event = run_single_event
|
|
3327
|
+
self.run_train_input = run_train_input
|
|
3328
|
+
|
|
3329
|
+
param_names = list(param_bounds.keys())
|
|
3330
|
+
bounds = [param_bounds[name] for name in param_names]
|
|
3331
|
+
|
|
3332
|
+
if cost_function is None:
|
|
3333
|
+
cost_function = self._default_cost_function
|
|
3334
|
+
|
|
3335
|
+
# Calculate scaling factors
|
|
3336
|
+
self.param_scales = {
|
|
3337
|
+
name: max(abs(bounds[i][0]), abs(bounds[i][1])) for i, name in enumerate(param_names)
|
|
3338
|
+
}
|
|
3339
|
+
|
|
3340
|
+
# Normalize bounds
|
|
3341
|
+
normalized_bounds = [
|
|
3342
|
+
(b[0] / self.param_scales[name], b[1] / self.param_scales[name])
|
|
3343
|
+
for name, b in zip(param_names, bounds)
|
|
3344
|
+
]
|
|
3345
|
+
|
|
3346
|
+
# picks with method of init value we want to use
|
|
3347
|
+
if init_guess == "random":
|
|
3348
|
+
x0 = np.array([np.random.uniform(b[0], b[1]) for b in bounds])
|
|
3349
|
+
elif init_guess == "middle_guess":
|
|
3350
|
+
x0 = [(b[0] + b[1]) / 2 for b in bounds]
|
|
3431
3351
|
else:
|
|
3432
|
-
|
|
3433
|
-
|
|
3434
|
-
|
|
3435
|
-
#
|
|
3436
|
-
|
|
3437
|
-
|
|
3438
|
-
|
|
3439
|
-
|
|
3440
|
-
|
|
3441
|
-
|
|
3442
|
-
|
|
3443
|
-
|
|
3444
|
-
# Properly clean up the existing parallel context
|
|
3445
|
-
if hasattr(self, 'pc'):
|
|
3446
|
-
self.pc.done() # Clean up existing parallel context
|
|
3447
|
-
|
|
3448
|
-
# Force garbage collection and reset NEURON state
|
|
3449
|
-
import gc
|
|
3450
|
-
gc.collect()
|
|
3451
|
-
h.finitialize()
|
|
3452
|
-
|
|
3453
|
-
# Create a fresh parallel context after cleanup
|
|
3454
|
-
self.pc = h.ParallelContext()
|
|
3455
|
-
|
|
3456
|
-
try:
|
|
3457
|
-
sec1 = list(self.cell1.all)[self.conn["sec_id"]]
|
|
3458
|
-
sec2 = list(self.cell2.all)[self.conn["sec_id"]]
|
|
3459
|
-
|
|
3460
|
-
# Use unique IDs to avoid conflicts with existing parallel context setups
|
|
3461
|
-
import time
|
|
3462
|
-
unique_id = int(time.time() * 1000) % 10000 # Use timestamp as unique base ID
|
|
3463
|
-
|
|
3464
|
-
self.pc.source_var(sec1(self.conn["sec_x"])._ref_v, unique_id, sec=sec1)
|
|
3465
|
-
self.gap_junc_1 = h.Gap(sec1(0.5))
|
|
3466
|
-
self.pc.target_var(self.gap_junc_1._ref_vgap, unique_id + 1)
|
|
3467
|
-
|
|
3468
|
-
self.pc.source_var(sec2(self.conn["sec_x"])._ref_v, unique_id + 1, sec=sec2)
|
|
3469
|
-
self.gap_junc_2 = h.Gap(sec2(0.5))
|
|
3470
|
-
self.pc.target_var(self.gap_junc_2._ref_vgap, unique_id)
|
|
3471
|
-
|
|
3472
|
-
self.pc.setup_transfer()
|
|
3473
|
-
except Exception as e:
|
|
3474
|
-
print(f"Error setting up gap junctions: {e}")
|
|
3475
|
-
# Try to continue with basic setup
|
|
3476
|
-
self.gap_junc_1 = h.Gap(list(self.cell1.all)[self.conn["sec_id"]](0.5))
|
|
3477
|
-
self.gap_junc_2 = h.Gap(list(self.cell2.all)[self.conn["sec_id"]](0.5))
|
|
3478
|
-
|
|
3479
|
-
# Reset NEURON state after complete setup
|
|
3480
|
-
h.finitialize()
|
|
3481
|
-
|
|
3482
|
-
print(f"Successfully switched to connection: {new_connection}")
|
|
3352
|
+
raise Exception("Pick a valid init guess method: either 'random' or 'middle_guess'")
|
|
3353
|
+
normalized_x0 = self._normalize_params(np.array(x0), param_names)
|
|
3354
|
+
|
|
3355
|
+
# Run optimization
|
|
3356
|
+
result = minimize(
|
|
3357
|
+
self._objective_function,
|
|
3358
|
+
normalized_x0,
|
|
3359
|
+
args=(param_names, cost_function, target_metrics),
|
|
3360
|
+
method=method,
|
|
3361
|
+
bounds=normalized_bounds,
|
|
3362
|
+
)
|
|
3483
3363
|
|
|
3484
|
-
|
|
3364
|
+
# Get final parameters and metrics
|
|
3365
|
+
final_params = dict(zip(param_names, self._denormalize_params(result.x, param_names)))
|
|
3366
|
+
for name, value in final_params.items():
|
|
3367
|
+
setattr(self.tuner.syn, name, value)
|
|
3368
|
+
final_metrics = self._calculate_metrics()
|
|
3369
|
+
|
|
3370
|
+
return SynapseOptimizationResult(
|
|
3371
|
+
optimal_params=final_params,
|
|
3372
|
+
achieved_metrics=final_metrics,
|
|
3373
|
+
target_metrics=target_metrics,
|
|
3374
|
+
error=result.fun,
|
|
3375
|
+
optimization_path=self.optimization_history,
|
|
3376
|
+
)
|
|
3377
|
+
|
|
3378
|
+
def plot_optimization_results(self, result: SynapseOptimizationResult):
|
|
3485
3379
|
"""
|
|
3486
|
-
|
|
3380
|
+
Plot optimization results including convergence and final traces.
|
|
3381
|
+
|
|
3382
|
+
Parameters:
|
|
3383
|
+
-----------
|
|
3384
|
+
result : SynapseOptimizationResult
|
|
3385
|
+
Results from optimization as returned by optimize_parameters()
|
|
3386
|
+
|
|
3387
|
+
Notes:
|
|
3388
|
+
------
|
|
3389
|
+
This method generates three plots:
|
|
3390
|
+
1. Error convergence plot showing how the error decreased over iterations
|
|
3391
|
+
2. Parameter convergence plots showing how each parameter changed
|
|
3392
|
+
3. Final model response with the optimal parameters
|
|
3393
|
+
|
|
3394
|
+
It also prints a summary of the optimization results including target vs. achieved
|
|
3395
|
+
metrics and the optimal parameter values.
|
|
3396
|
+
"""
|
|
3397
|
+
# Ensure errors are properly shaped for plotting
|
|
3398
|
+
iterations = range(len(result.optimization_path))
|
|
3399
|
+
errors = np.array([float(h["error"]) for h in result.optimization_path]).flatten()
|
|
3400
|
+
|
|
3401
|
+
# Plot error convergence
|
|
3402
|
+
fig1, ax1 = plt.subplots(figsize=(8, 5))
|
|
3403
|
+
ax1.plot(iterations, errors, label="Error")
|
|
3404
|
+
ax1.set_xlabel("Iteration")
|
|
3405
|
+
ax1.set_ylabel("Error")
|
|
3406
|
+
ax1.set_title("Error Convergence")
|
|
3407
|
+
ax1.set_yscale("log")
|
|
3408
|
+
ax1.legend()
|
|
3409
|
+
plt.tight_layout()
|
|
3410
|
+
plt.show()
|
|
3411
|
+
|
|
3412
|
+
# Plot parameter convergence
|
|
3413
|
+
param_names = list(result.optimal_params.keys())
|
|
3414
|
+
num_params = len(param_names)
|
|
3415
|
+
fig2, axs = plt.subplots(nrows=num_params, ncols=1, figsize=(8, 5 * num_params))
|
|
3416
|
+
|
|
3417
|
+
if num_params == 1:
|
|
3418
|
+
axs = [axs]
|
|
3419
|
+
|
|
3420
|
+
for ax, param in zip(axs, param_names):
|
|
3421
|
+
values = [float(h["params"][param]) for h in result.optimization_path]
|
|
3422
|
+
ax.plot(iterations, values, label=f"{param}")
|
|
3423
|
+
ax.set_xlabel("Iteration")
|
|
3424
|
+
ax.set_ylabel("Parameter Value")
|
|
3425
|
+
ax.set_title(f"Convergence of {param}")
|
|
3426
|
+
ax.legend()
|
|
3427
|
+
|
|
3428
|
+
plt.tight_layout()
|
|
3429
|
+
plt.show()
|
|
3430
|
+
|
|
3431
|
+
# Print final results
|
|
3432
|
+
print("Optimization Results:")
|
|
3433
|
+
print(f"Final Error: {float(result.error):.2e}\n")
|
|
3434
|
+
print("Target Metrics:")
|
|
3435
|
+
for metric, value in result.target_metrics.items():
|
|
3436
|
+
achieved = result.achieved_metrics.get(metric)
|
|
3437
|
+
if achieved is not None and metric != "amplitudes": # Skip amplitude array
|
|
3438
|
+
print(f"{metric}: {float(achieved):.3f} (target: {float(value):.3f})")
|
|
3439
|
+
|
|
3440
|
+
print("\nOptimal Parameters:")
|
|
3441
|
+
for param, value in result.optimal_params.items():
|
|
3442
|
+
print(f"{param}: {float(value):.3f}")
|
|
3443
|
+
|
|
3444
|
+
# Plot final model response
|
|
3445
|
+
if self.run_train_input:
|
|
3446
|
+
self.tuner._plot_model(
|
|
3447
|
+
[
|
|
3448
|
+
self.tuner.general_settings["tstart"] - self.tuner.nstim.interval / 3,
|
|
3449
|
+
self.tuner.tstop,
|
|
3450
|
+
]
|
|
3451
|
+
)
|
|
3452
|
+
amp = self.tuner._response_amplitude()
|
|
3453
|
+
self.tuner._calc_ppr_induction_recovery(amp)
|
|
3454
|
+
if self.run_single_event:
|
|
3455
|
+
self.tuner.ispk = None
|
|
3456
|
+
self.tuner.SingleEvent(plot_and_print=True)
|
|
3457
|
+
|
|
3458
|
+
# dataclass decorator automatically generates __init__ from type-annotated class variables for cleaner code
|
|
3459
|
+
@dataclass
|
|
3460
|
+
class GapOptimizationResult:
|
|
3461
|
+
"""Container for gap junction optimization results"""
|
|
3462
|
+
|
|
3463
|
+
optimal_resistance: float
|
|
3464
|
+
achieved_cc: float
|
|
3465
|
+
target_cc: float
|
|
3466
|
+
error: float
|
|
3467
|
+
optimization_path: List[Dict[str, float]]
|
|
3468
|
+
|
|
3469
|
+
|
|
3470
|
+
class GapJunctionOptimizer:
|
|
3471
|
+
def __init__(self, tuner):
|
|
3472
|
+
"""
|
|
3473
|
+
Initialize the gap junction optimizer
|
|
3474
|
+
|
|
3475
|
+
Parameters:
|
|
3476
|
+
-----------
|
|
3477
|
+
tuner : GapJunctionTuner
|
|
3478
|
+
Instance of the GapJunctionTuner class
|
|
3479
|
+
"""
|
|
3480
|
+
self.tuner = tuner
|
|
3481
|
+
self.optimization_history = []
|
|
3482
|
+
|
|
3483
|
+
def _objective_function(self, resistance: float, target_cc: float) -> float:
|
|
3484
|
+
"""
|
|
3485
|
+
Calculate error between achieved and target coupling coefficient
|
|
3487
3486
|
|
|
3488
3487
|
Parameters:
|
|
3489
3488
|
-----------
|
|
3490
3489
|
resistance : float
|
|
3491
|
-
|
|
3490
|
+
Gap junction resistance to try
|
|
3491
|
+
target_cc : float
|
|
3492
|
+
Target coupling coefficient to match
|
|
3493
|
+
|
|
3494
|
+
Returns:
|
|
3495
|
+
--------
|
|
3496
|
+
float : Error between achieved and target coupling coefficient
|
|
3497
|
+
"""
|
|
3498
|
+
# Run model with current resistance
|
|
3499
|
+
self.tuner.model(resistance)
|
|
3500
|
+
|
|
3501
|
+
# Calculate coupling coefficient
|
|
3502
|
+
achieved_cc = self.tuner.coupling_coefficient(
|
|
3503
|
+
self.tuner.t_vec,
|
|
3504
|
+
self.tuner.soma_v_1,
|
|
3505
|
+
self.tuner.soma_v_2,
|
|
3506
|
+
self.tuner.general_settings["tstart"],
|
|
3507
|
+
self.tuner.general_settings["tstart"] + self.tuner.general_settings["tdur"],
|
|
3508
|
+
)
|
|
3509
|
+
|
|
3510
|
+
# Calculate error
|
|
3511
|
+
error = (achieved_cc - target_cc) ** 2 # MSE
|
|
3512
|
+
|
|
3513
|
+
# Store history
|
|
3514
|
+
self.optimization_history.append(
|
|
3515
|
+
{"resistance": resistance, "achieved_cc": achieved_cc, "error": error}
|
|
3516
|
+
)
|
|
3517
|
+
|
|
3518
|
+
return error
|
|
3519
|
+
|
|
3520
|
+
def optimize_resistance(
|
|
3521
|
+
self, target_cc: float, resistance_bounds: tuple = (1e-4, 1e-2), method: str = "bounded"
|
|
3522
|
+
) -> GapOptimizationResult:
|
|
3523
|
+
"""
|
|
3524
|
+
Optimize gap junction resistance to achieve a target coupling coefficient.
|
|
3525
|
+
|
|
3526
|
+
Parameters:
|
|
3527
|
+
-----------
|
|
3528
|
+
target_cc : float
|
|
3529
|
+
Target coupling coefficient to achieve (between 0 and 1)
|
|
3530
|
+
resistance_bounds : tuple, optional
|
|
3531
|
+
(min, max) bounds for resistance search in MOhm. Default is (1e-4, 1e-2).
|
|
3532
|
+
method : str, optional
|
|
3533
|
+
Optimization method to use. Default is 'bounded' which works well
|
|
3534
|
+
for single-parameter optimization.
|
|
3535
|
+
|
|
3536
|
+
Returns:
|
|
3537
|
+
--------
|
|
3538
|
+
GapOptimizationResult
|
|
3539
|
+
Container with optimization results including:
|
|
3540
|
+
- optimal_resistance: The optimized resistance value
|
|
3541
|
+
- achieved_cc: The coupling coefficient achieved with the optimal resistance
|
|
3542
|
+
- target_cc: The target coupling coefficient
|
|
3543
|
+
- error: The final error (squared difference between target and achieved)
|
|
3544
|
+
- optimization_path: List of all values tried during optimization
|
|
3492
3545
|
|
|
3493
3546
|
Notes:
|
|
3494
3547
|
------
|
|
3495
|
-
|
|
3496
|
-
|
|
3548
|
+
Uses scipy.optimize.minimize_scalar with bounded method, which is
|
|
3549
|
+
appropriate for this single-parameter optimization problem.
|
|
3497
3550
|
"""
|
|
3498
|
-
self.
|
|
3499
|
-
self.gap_junc_2.g = resistance
|
|
3551
|
+
self.optimization_history = []
|
|
3500
3552
|
|
|
3501
|
-
|
|
3502
|
-
|
|
3503
|
-
|
|
3504
|
-
|
|
3505
|
-
soma_v_1.record(self.cell1.soma[0](0.5)._ref_v)
|
|
3506
|
-
soma_v_2.record(self.cell2.soma[0](0.5)._ref_v)
|
|
3553
|
+
# Run optimization
|
|
3554
|
+
result = minimize_scalar(
|
|
3555
|
+
self._objective_function, args=(target_cc,), bounds=resistance_bounds, method=method
|
|
3556
|
+
)
|
|
3507
3557
|
|
|
3508
|
-
|
|
3509
|
-
self.
|
|
3510
|
-
|
|
3558
|
+
# Run final model with optimal resistance
|
|
3559
|
+
self.tuner.model(result.x)
|
|
3560
|
+
final_cc = self.tuner.coupling_coefficient(
|
|
3561
|
+
self.tuner.t_vec,
|
|
3562
|
+
self.tuner.soma_v_1,
|
|
3563
|
+
self.tuner.soma_v_2,
|
|
3564
|
+
self.tuner.general_settings["tstart"],
|
|
3565
|
+
self.tuner.general_settings["tstart"] + self.tuner.general_settings["tdur"],
|
|
3566
|
+
)
|
|
3511
3567
|
|
|
3512
|
-
|
|
3513
|
-
|
|
3568
|
+
# Package up our results
|
|
3569
|
+
optimization_result = GapOptimizationResult(
|
|
3570
|
+
optimal_resistance=result.x,
|
|
3571
|
+
achieved_cc=final_cc,
|
|
3572
|
+
target_cc=target_cc,
|
|
3573
|
+
error=result.fun,
|
|
3574
|
+
optimization_path=self.optimization_history,
|
|
3575
|
+
)
|
|
3514
3576
|
|
|
3515
|
-
|
|
3577
|
+
return optimization_result
|
|
3578
|
+
|
|
3579
|
+
def plot_optimization_results(self, result: GapOptimizationResult):
|
|
3516
3580
|
"""
|
|
3517
|
-
Plot
|
|
3581
|
+
Plot optimization results including convergence and final voltage traces
|
|
3518
3582
|
|
|
3519
|
-
|
|
3520
|
-
|
|
3583
|
+
Parameters:
|
|
3584
|
+
-----------
|
|
3585
|
+
result : GapOptimizationResult
|
|
3586
|
+
Results from optimization
|
|
3521
3587
|
"""
|
|
3588
|
+
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
|
|
3589
|
+
|
|
3590
|
+
# Plot voltage traces
|
|
3522
3591
|
t_range = [
|
|
3523
|
-
self.general_settings["tstart"] - 100.0,
|
|
3524
|
-
self.general_settings["tstart"] + self.general_settings["tdur"] + 100.0,
|
|
3592
|
+
self.tuner.general_settings["tstart"] - 100.0,
|
|
3593
|
+
self.tuner.general_settings["tstart"] + self.tuner.general_settings["tdur"] + 100.0,
|
|
3525
3594
|
]
|
|
3526
|
-
t = np.array(self.t_vec)
|
|
3527
|
-
v1 = np.array(self.soma_v_1)
|
|
3528
|
-
v2 = np.array(self.soma_v_2)
|
|
3595
|
+
t = np.array(self.tuner.t_vec)
|
|
3596
|
+
v1 = np.array(self.tuner.soma_v_1)
|
|
3597
|
+
v2 = np.array(self.tuner.soma_v_2)
|
|
3529
3598
|
tidx = (t >= t_range[0]) & (t <= t_range[1])
|
|
3530
3599
|
|
|
3531
|
-
|
|
3532
|
-
|
|
3533
|
-
|
|
3534
|
-
|
|
3535
|
-
|
|
3536
|
-
|
|
3537
|
-
|
|
3538
|
-
|
|
3600
|
+
ax1.plot(t[tidx], v1[tidx], "b", label=f"{self.tuner.cell_name} 1")
|
|
3601
|
+
ax1.plot(t[tidx], v2[tidx], "r", label=f"{self.tuner.cell_name} 2")
|
|
3602
|
+
ax1.set_xlabel("Time (ms)")
|
|
3603
|
+
ax1.set_ylabel("Membrane Voltage (mV)")
|
|
3604
|
+
ax1.legend()
|
|
3605
|
+
ax1.set_title("Optimized Voltage Traces")
|
|
3606
|
+
|
|
3607
|
+
# Plot error convergence
|
|
3608
|
+
errors = [h["error"] for h in result.optimization_path]
|
|
3609
|
+
ax2.plot(errors)
|
|
3610
|
+
ax2.set_xlabel("Iteration")
|
|
3611
|
+
ax2.set_ylabel("Error")
|
|
3612
|
+
ax2.set_title("Error Convergence")
|
|
3613
|
+
ax2.set_yscale("log")
|
|
3614
|
+
|
|
3615
|
+
# Plot resistance convergence
|
|
3616
|
+
resistances = [h["resistance"] for h in result.optimization_path]
|
|
3617
|
+
ax3.plot(resistances)
|
|
3618
|
+
ax3.set_xlabel("Iteration")
|
|
3619
|
+
ax3.set_ylabel("Resistance")
|
|
3620
|
+
ax3.set_title("Resistance Convergence")
|
|
3621
|
+
ax3.set_yscale("log")
|
|
3622
|
+
|
|
3623
|
+
# Print final results
|
|
3624
|
+
result_text = (
|
|
3625
|
+
f"Optimal Resistance: {result.optimal_resistance:.2e}\n"
|
|
3626
|
+
f"Target CC: {result.target_cc:.3f}\n"
|
|
3627
|
+
f"Achieved CC: {result.achieved_cc:.3f}\n"
|
|
3628
|
+
f"Final Error: {result.error:.2e}"
|
|
3629
|
+
)
|
|
3630
|
+
ax4.text(0.1, 0.7, result_text, transform=ax4.transAxes, fontsize=10)
|
|
3631
|
+
ax4.axis("off")
|
|
3539
3632
|
|
|
3540
|
-
|
|
3633
|
+
plt.tight_layout()
|
|
3634
|
+
plt.show()
|
|
3635
|
+
|
|
3636
|
+
def parameter_sweep(self, resistance_range: np.ndarray) -> dict:
|
|
3541
3637
|
"""
|
|
3542
|
-
|
|
3638
|
+
Perform a parameter sweep across different resistance values.
|
|
3543
3639
|
|
|
3544
3640
|
Parameters:
|
|
3545
3641
|
-----------
|
|
3546
|
-
|
|
3547
|
-
|
|
3548
|
-
v1 : array-like
|
|
3549
|
-
Voltage trace of the cell receiving the current injection.
|
|
3550
|
-
v2 : array-like
|
|
3551
|
-
Voltage trace of the coupled cell.
|
|
3552
|
-
t_start : float
|
|
3553
|
-
Start time for calculating the steady-state voltage change.
|
|
3554
|
-
t_end : float
|
|
3555
|
-
End time for calculating the steady-state voltage change.
|
|
3556
|
-
dt : float, optional
|
|
3557
|
-
Time step of the simulation. Default is h.dt.
|
|
3642
|
+
resistance_range : np.ndarray
|
|
3643
|
+
Array of resistance values to test.
|
|
3558
3644
|
|
|
3559
3645
|
Returns:
|
|
3560
3646
|
--------
|
|
3561
|
-
|
|
3562
|
-
|
|
3563
|
-
|
|
3564
|
-
|
|
3565
|
-
|
|
3566
|
-
|
|
3567
|
-
|
|
3647
|
+
dict
|
|
3648
|
+
Dictionary containing the results of the parameter sweep, with keys:
|
|
3649
|
+
- 'resistance': List of resistance values tested
|
|
3650
|
+
- 'coupling_coefficient': Corresponding coupling coefficients
|
|
3651
|
+
|
|
3652
|
+
Notes:
|
|
3653
|
+
------
|
|
3654
|
+
This method is useful for understanding the relationship between gap junction
|
|
3655
|
+
resistance and coupling coefficient before attempting optimization.
|
|
3656
|
+
"""
|
|
3657
|
+
results = {"resistance": [], "coupling_coefficient": []}
|
|
3658
|
+
|
|
3659
|
+
for resistance in tqdm(resistance_range, desc="Sweeping resistance values"):
|
|
3660
|
+
self.tuner.model(resistance)
|
|
3661
|
+
cc = self.tuner.coupling_coefficient(
|
|
3662
|
+
self.tuner.t_vec,
|
|
3663
|
+
self.tuner.soma_v_1,
|
|
3664
|
+
self.tuner.soma_v_2,
|
|
3665
|
+
self.tuner.general_settings["tstart"],
|
|
3666
|
+
self.tuner.general_settings["tstart"] + self.tuner.general_settings["tdur"],
|
|
3667
|
+
)
|
|
3668
|
+
|
|
3669
|
+
results["resistance"].append(resistance)
|
|
3670
|
+
results["coupling_coefficient"].append(cc)
|
|
3671
|
+
|
|
3672
|
+
return results
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|