superneuroabm 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
superneuroabm/model.py ADDED
@@ -0,0 +1,689 @@
1
+ """
2
+ Model class for building an SNN
3
+ """
4
+
5
import copy
from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Optional, Set, Tuple

import cupy as cp
import numpy as np
from sagesim.breed import Breed
from sagesim.model import Model
from sagesim.space import NetworkSpace

from superneuroabm.step_functions.soma.izh import izh_soma_step_func
from superneuroabm.step_functions.soma.lif import lif_soma_step_func
from superneuroabm.step_functions.soma.lif_soma_adaptive_thr import lif_soma_adaptive_thr_step_func
from superneuroabm.step_functions.synapse.single_exp import synapse_single_exp_step_func
from superneuroabm.step_functions.synapse.stdp.learning_rule_selector import (
    learning_rule_selector,
)
from superneuroabm.util import load_component_configurations
24
+
25
+ CURRENT_DIR_ABSPATH = Path(__file__).resolve().parent
26
+
27
+
28
+ class NeuromorphicModel(Model):
29
def __init__(
    self,
    soma_breed_info: Optional[Dict[str, List[Tuple[Callable, Path]]]] = None,
    synapse_breed_info: Optional[Dict[str, List[Tuple[Callable, Path]]]] = None,
    enable_internal_state_tracking: bool = True,
) -> None:
    """
    Creates an SNN Model and provides methods to create, simulate,
    and monitor somas and synapses.

    :param soma_breed_info: Dict of soma breed name to a list of
        ``(step_function, module_path)`` tuples. The step functions
        are registered on the breed in list order with priorities
        0, 1, 2, ... If None, the built-in ``izh_soma``, ``lif_soma``
        and ``lif_soma_adaptive_thr`` breeds are registered.
    :param synapse_breed_info: Dict of synapse breed name to a list of
        ``(step_function, module_path)`` tuples, registered in list
        order with priorities 100, 101, ... If None, the built-in
        ``single_exp_synapse`` breed (synapse step followed by the
        learning rule selector) is registered.
    :param enable_internal_state_tracking: If True, tracks and stores
        internal states history for all agents during simulation.
        If False, disables tracking to reduce memory usage and improve
        performance. Default is True for backward compatibility.
    """
    # Defaults are built per call instead of living in the signature, so
    # no dict/list object is shared between model instances.
    if soma_breed_info is None:
        soma_dir = CURRENT_DIR_ABSPATH / "step_functions" / "soma"
        soma_breed_info = {
            "izh_soma": [(izh_soma_step_func, soma_dir / "izh.py")],
            "lif_soma": [(lif_soma_step_func, soma_dir / "lif.py")],
            "lif_soma_adaptive_thr": [
                (
                    lif_soma_adaptive_thr_step_func,
                    soma_dir / "lif_soma_adaptive_thr.py",
                )
            ],
        }
    if synapse_breed_info is None:
        synapse_dir = CURRENT_DIR_ABSPATH / "step_functions" / "synapse"
        synapse_breed_info = {
            "single_exp_synapse": [
                (synapse_single_exp_step_func, synapse_dir / "single_exp.py"),
                (
                    learning_rule_selector,
                    synapse_dir / "stdp" / "learning_rule_selector.py",
                ),
            ],
        }

    super().__init__(space=NetworkSpace(ordered=True))

    self.enable_internal_state_tracking = enable_internal_state_tracking

    # Integration time step; 1e-3 is 1 ms if dt is expressed in seconds
    # (NOTE(review): confirm the unit expected by the step functions).
    self.register_global_property("dt", 1e-3)
    self.register_global_property("I_bias", 0)  # No bias current

    # Each entry is (default_value, neighbor_visible). neighbor_visible is
    # passed straight to Breed.register_property; presumably it controls
    # whether the property is sent to neighbors during MPI sync.
    # NOTE(review): the slot layouts of the fixed-size lists are defined by
    # the step functions and component configurations, not by this file.
    soma_properties = {
        "hyperparameters": ([0.0, 0.0, 0.0, 0.0, 0.0], False),  # 5 slots
        "learning_hyperparameters": ([0.0 for _ in range(5)], False),
        "internal_state": ([0.0, 0.0, 0.0, 0.0], False),  # 4 slots
        "internal_learning_state": ([0.0 for _ in range(3)], False),
        "synapse_delay_reg": ([], False),
        "input_spikes_tensor": ([], False),
        # Only soma property marked neighbor-visible (synapses read spikes)
        "output_spikes_tensor": ([], True),
        "internal_states_buffer": ([], False),
        "internal_learning_states_buffer": ([], False),
    }
    synapse_properties = {
        "hyperparameters": ([0.0 for _ in range(10)], False),
        "learning_hyperparameters": ([0.0 for _ in range(5)], False),
        # Only synapse property marked neighbor-visible (somas read it)
        "internal_state": ([0.0 for _ in range(4)], True),
        "internal_learning_state": ([0.0 for _ in range(3)], False),
        "synapse_delay_reg": ([], False),
        "input_spikes_tensor": ([], False),
        "output_spikes_tensor": ([], False),
        "internal_states_buffer": ([], False),
        "internal_learning_states_buffer": ([], False),
    }
    self._synapse_ids = []
    self._soma_ids = []
    self._soma_reset_states = {}

    def register_breeds(breed_info, properties, base_priority, no_double_buffer, registry):
        """Build, register and record one Breed per entry of breed_info."""
        for breed_name, step_funcs in breed_info.items():
            breed = Breed(breed_name)
            for prop_name, (default_val, neighbor_visible) in properties.items():
                breed.register_property(
                    prop_name, default_val, neighbor_visible=neighbor_visible
                )
            for step_func_order, (step_func, module_fpath) in enumerate(step_funcs):
                if module_fpath is None:
                    # NOTE(review): this fallback path is unchanged from the
                    # original; the built-in step functions live under
                    # step_functions/, so confirm that this file exists.
                    module_fpath = CURRENT_DIR_ABSPATH / "izh_soma.py"
                breed.register_step_func(
                    step_func=step_func,
                    module_fpath=module_fpath,
                    priority=base_priority + step_func_order,
                    no_double_buffer=no_double_buffer,
                )
            self.register_breed(breed)
            registry[breed_name] = breed

    # Properties written by soma step functions, registered without
    # double buffering.
    soma_no_double_buffer = [
        "internal_state",
        "output_spikes_tensor",
        "internal_states_buffer",
    ]
    self._soma_breeds: Dict[str, Breed] = {}
    register_breeds(
        soma_breed_info, soma_properties, 0, soma_no_double_buffer, self._soma_breeds
    )

    # Properties written by synapse step functions, registered without
    # double buffering. Synapse step functions get priorities from 100
    # upwards, soma step functions from 0 upwards.
    synapse_no_double_buffer = [
        "internal_state",
        "internal_states_buffer",
    ]
    self._synapse_breeds: Dict[str, Breed] = {}
    register_breeds(
        synapse_breed_info,
        synapse_properties,
        100,
        synapse_no_double_buffer,
        self._synapse_breeds,
    )

    self.tag2component = defaultdict(set)  # tag -> set of agent ids

    # Load and hold configurations
    self.agentid2config = {}  # agent id -> configuration name
    self._component_configurations = load_component_configurations()

    # synapse_id -> "pre" or "post" -> soma_id
    self.synapse2soma_map = defaultdict(dict)
    # soma_id -> "pre" or "post" -> set of synapse ids
    self.soma2synapse_map = defaultdict(lambda: defaultdict(set))
    # Creation-time values per synapse, used when resetting the model
    self._synapse2defaultparameters: Dict[int, List[float]] = {}
    self._synapse2defaultlearningparameters: Dict[int, List[float]] = {}
    self._synapse2defaultinternalstate: Dict[int, List[float]] = {}
    self._synapse2defaultinternallearningstate: Dict[int, List[float]] = {}
214
+
215
def get_agent_config_name(self, agent_id: int) -> Optional[str]:
    """
    Returns the name of the configuration recorded for an agent.

    :param agent_id: agent ID to look up in ``agentid2config``.
    :return: The recorded configuration name, or None if nothing is
        recorded for this ID.
    """
    return self.agentid2config.get(agent_id)
220
+
221
def get_agent_breed(self, agent_id: int) -> str:
    """
    Looks up the breed of an agent.

    The agent's ``breed`` property is read as an integer position into
    the agent factory's ``breeds`` collection.

    :param agent_id: agent ID whose breed is wanted.
    :return: The entry of ``self._agent_factory.breeds`` at that position.
        NOTE(review): whether this is a breed name or a Breed object
        depends on what iterating ``breeds`` yields - confirm in SAGESim.
    """
    stored_index = self.get_agent_property_value(
        id=agent_id, property_name="breed"
    )
    position = int(stored_index)
    registered = list(self._agent_factory.breeds)
    return registered[position]
229
+
230
def get_synapse_connectivity(self, synapse_id: int) -> List[int]:
    """
    Returns the ``locations`` property of the given synapse.

    For a synapse this is expected to be the ordered pair
    [pre_soma_id, post_soma_id]; the entries are agent IDs rather than
    local indices.

    :param synapse_id: agent ID of the synapse.
    :return: The value stored under the synapse's ``locations`` property.
    """
    locations = self.get_agent_property_value(
        id=synapse_id, property_name="locations"
    )
    return locations
242
+
243
+
244
def get_agent_config_diff(self, agent_id: int) -> Dict[str, Dict[str, float]]:
    """
    Returns the configuration overrides for the agent with the given ID.

    Each property section of the agent's named configuration is compared,
    position by position, with the agent's current property values.

    :param agent_id: agent ID of a soma or synapse.
    :return: Dict of property name to a dict of entry name to current
        value, holding only the entries whose current value differs
        from the configured one. A property with no differences maps to
        an empty dict.
    """
    component_class = (
        "soma" if agent_id in self.get_agents_with_tag("soma") else "synapse"
    )
    # Accept either a Breed-like object exposing ``.name`` or a plain
    # breed name.
    breed = self.get_agent_breed(agent_id)
    breed_name = getattr(breed, "name", breed)
    config_name = self.get_agent_config_name(agent_id)
    config = self._component_configurations[component_class][breed_name][
        config_name
    ]

    overrides = {}
    for property_name, configured in config.items():
        current = self.get_agent_property_value(
            id=agent_id, property_name=property_name
        )
        # Configured entries are matched to current values by position,
        # so this relies on the configuration dict's insertion order.
        overrides[property_name] = {
            key: current[position]
            for position, (key, configured_value) in enumerate(configured.items())
            if configured_value != current[position]
        }
    return overrides
280
+
281
def get_agents_with_tag(self, tag: str) -> Set[int]:
    """
    Returns the set of agent IDs associated with the given tag.

    :param tag: The tag to filter agents by.
    :return: The set stored in ``tag2component`` for this tag, or a new
        empty set if the tag is unknown. Looking up an unknown tag does
        not add it to ``tag2component``.
    """
    if tag in self.tag2component:
        return self.tag2component[tag]
    return set()
289
+
290
def _reset_agents(self, retain_parameters: bool = True) -> None:
    """
    Internal method that puts every soma and synapse back into its
    recorded initial state.

    For each synapse the input spike list becomes ``[-1, 0.0]`` (flattened
    ``[tick, value]`` layout), the delay register is zeroed at its current
    length, and both internal state vectors are restored from the values
    recorded at creation. For each soma only ``internal_state`` is
    restored.

    :param retain_parameters: If True, keeps current learned parameters.
        If False, also restores synapse ``hyperparameters`` and
        ``learning_hyperparameters`` to their creation-time values.
    """
    read_property = super().get_agent_property_value
    write_property = super().set_agent_property_value

    # (property name, per-synapse defaults) pairs, in the order applied
    restore_plan = [
        ("internal_state", self._synapse2defaultinternalstate),
        ("internal_learning_state", self._synapse2defaultinternallearningstate),
    ]
    if not retain_parameters:
        restore_plan.append(("hyperparameters", self._synapse2defaultparameters))
        restore_plan.append(
            ("learning_hyperparameters", self._synapse2defaultlearningparameters)
        )

    for synapse_id in self._synapse_ids:
        write_property(
            id=synapse_id,
            property_name="input_spikes_tensor",
            value=[-1, 0.0],
        )
        # Keep the delay register's length, clear its contents
        register_length = len(
            read_property(id=synapse_id, property_name="synapse_delay_reg")
        )
        write_property(
            id=synapse_id,
            property_name="synapse_delay_reg",
            value=[0] * register_length,
        )
        for property_name, defaults in restore_plan:
            # Copy so the recorded defaults are never handed out directly
            write_property(
                id=synapse_id,
                property_name=property_name,
                value=defaults[synapse_id].copy(),
            )

    for soma_id in self._soma_ids:
        write_property(
            id=soma_id,
            property_name="internal_state",
            value=self._soma_reset_states[soma_id].copy(),
        )
349
+
350
def reset(self, retain_parameters: bool = True) -> None:
    """
    Resets all soma and synapse agents to their initial states, then
    resets the parent model.

    :param retain_parameters: If True, keeps current learned parameters.
        If False, resets parameters to their default values.
    """
    self._reset_agents(retain_parameters=retain_parameters)
    # Empty the agent factory's previous-agent-data cache before the
    # parent reset (intended to avoid expensive comparisons on the next
    # simulation).
    previous_agent_data = self._agent_factory._prev_agent_data
    previous_agent_data.clear()
    super().reset()
361
+
362
def setup(
    self,
    use_gpu: bool = True,
    retain_parameters: bool = True,
) -> None:
    """
    Resets all agents and then runs the parent model's setup.

    :param use_gpu: forwarded unchanged to the parent ``setup``.
    :param retain_parameters: True by default. If True, current (learned)
        synapse parameters are kept; if False, they are reset to their
        default values.
    """
    # Reset all agents using the shared helper function
    self._reset_agents(retain_parameters=retain_parameters)
    super().setup(use_gpu=use_gpu)
376
+
377
def simulate(self, ticks: int, update_data_ticks: int = 1) -> None:
    """
    Allocates the per-agent recording buffers for a run and then calls
    the parent model's ``simulate``.

    Every soma gets an ``output_spikes_tensor`` of ``ticks`` zeros. Every
    soma and synapse gets ``internal_states_buffer`` and
    ``internal_learning_states_buffer`` filled with copies of its current
    state: one copy per tick when internal state tracking is enabled,
    a single copy otherwise.

    :param ticks: number of simulation ticks to run.
    :param update_data_ticks: forwarded unchanged to the parent
        ``simulate``.
    """
    read_property = super().get_agent_property_value
    write_property = super().set_agent_property_value
    # Full history when tracking, otherwise one placeholder row
    history_rows = ticks if self.enable_internal_state_tracking else 1

    def allocate_history(agent_id, state_name, buffer_name):
        """Fill buffer_name with copies of the agent's current state_name."""
        seed_state = read_property(id=agent_id, property_name=state_name)
        write_property(
            id=agent_id,
            property_name=buffer_name,
            value=[seed_state[::] for _ in range(history_rows)],
        )

    for soma_id in self._soma_ids:
        # Fresh, all-zero output buffer
        write_property(
            id=soma_id,
            property_name="output_spikes_tensor",
            value=[0] * ticks,
        )
        allocate_history(soma_id, "internal_state", "internal_states_buffer")
        # Somas get a learning-state buffer too, so both agent kinds have
        # the same property structure
        allocate_history(
            soma_id, "internal_learning_state", "internal_learning_states_buffer"
        )

    for synapse_id in self._synapse_ids:
        allocate_history(synapse_id, "internal_state", "internal_states_buffer")
        allocate_history(
            synapse_id, "internal_learning_state", "internal_learning_states_buffer"
        )

    super().simulate(ticks, update_data_ticks)
457
+
458
def create_soma(
    self,
    breed: str,
    config_name: str,
    hyperparameters_overrides: Optional[Dict[str, float]] = None,
    default_internal_state_overrides: Optional[Dict[str, float]] = None,
    tags: Optional[Set[str]] = None,
) -> int:
    """
    Creates a soma agent.

    :param breed: name of a registered soma breed.
    :param config_name: name of the configuration to use, looked up under
        ``self._component_configurations["soma"][breed]``.
    :param hyperparameters_overrides: optional dict of hyperparameter
        name to value, applied on top of the configuration.
    :param default_internal_state_overrides: optional dict of internal
        state name to value, applied on top of the configuration.
    :param tags: optional tags to associate with the soma. The set passed
        in is not modified; "soma" and the breed name are always added.
    :return: SAGESim agent id of soma
    """
    # Work on a private copy so the caller's set is never mutated
    tags = set(tags) if tags else set()

    # Deep copy so overrides never leak into the shared configuration
    config = copy.deepcopy(
        self._component_configurations["soma"][breed][config_name]
    )
    if hyperparameters_overrides:
        config["hyperparameters"].update(hyperparameters_overrides)
    if default_internal_state_overrides:
        config["internal_state"].update(default_internal_state_overrides)

    # Values are flattened in the configuration's key order
    hyperparameters = [float(val) for val in config["hyperparameters"].values()]
    default_internal_state = [
        float(val) for val in config["internal_state"].values()
    ]

    soma_id = self.create_agent_of_breed(
        breed=self._soma_breeds[breed],
        hyperparameters=hyperparameters,
        internal_state=default_internal_state,
    )

    self._soma_ids.append(soma_id)
    # Keep an independent copy: the list handed to the agent must not be
    # able to change the recorded reset state
    self._soma_reset_states[soma_id] = list(default_internal_state)

    self.agentid2config[soma_id] = config_name

    tags.update({"soma", breed})
    for tag in tags:
        self.tag2component[tag].add(soma_id)
    return soma_id
506
+
507
def create_synapse(
    self,
    breed: str,
    pre_soma_id: int,
    post_soma_id: int,
    config_name: str,
    hyperparameters_overrides: Optional[Dict[str, float]] = None,
    default_internal_state_overrides: Optional[Dict[str, float]] = None,
    learning_hyperparameters_overrides: Optional[Dict[str, float]] = None,
    default_internal_learning_state_overrides: Optional[Dict[str, float]] = None,
    tags: Optional[Set[str]] = None,
) -> int:
    """
    Creates and adds a Synapse agent.

    :param breed: Synapse breed name (e.g., 'single_exp_synapse').
    :param pre_soma_id: Presynaptic soma agent ID (or -1 for external input).
    :param post_soma_id: Postsynaptic soma agent ID (or -1 for external output).
    :param config_name: Name of the configuration to use for this synapse.
    :param hyperparameters_overrides: optional dict of hyperparameter overrides.
    :param default_internal_state_overrides: optional dict of internal
        state overrides.
    :param learning_hyperparameters_overrides: optional dict of learning
        hyperparameter overrides. If the configuration has no
        ``learning_hyperparameters`` section, the overrides are applied
        on top of ``{"stdp_type": -1}``.
    :param default_internal_learning_state_overrides: optional dict of
        internal learning state overrides.
    :param tags: optional tags to associate with this synapse. The set
        passed in is not modified; "synapse" and the breed name are
        always added, plus "input_synapse" when ``pre_soma_id`` is -1.
    :return: SAGESim agent ID of the created synapse.
    """
    # Work on a private copy so the caller's set is never mutated
    tags = set(tags) if tags else set()

    # Deep copy so overrides never leak into the shared configuration
    config = copy.deepcopy(
        self._component_configurations["synapse"][breed][config_name]
    )

    if hyperparameters_overrides:
        config["hyperparameters"].update(hyperparameters_overrides)
    if default_internal_state_overrides:
        config["internal_state"].update(default_internal_state_overrides)
    # The learning sections are optional in a configuration, so create
    # them (with the same fallbacks used below) before applying overrides.
    if learning_hyperparameters_overrides:
        config.setdefault("learning_hyperparameters", {"stdp_type": -1}).update(
            learning_hyperparameters_overrides
        )
    if default_internal_learning_state_overrides:
        config.setdefault("internal_learning_state", {}).update(
            default_internal_learning_state_overrides
        )

    # Values are flattened in the configuration's key order
    hyperparameters = [float(val) for val in config["hyperparameters"].values()]
    default_internal_state = [
        float(val) for val in config["internal_state"].values()
    ]
    learning_hyperparameters = [
        float(val)
        for val in config.get(
            "learning_hyperparameters", {"stdp_type": -1}
        ).values()
    ]
    default_internal_learning_state = [
        float(val) for val in config.get("internal_learning_state", {}).values()
    ]

    # The second hyperparameter is used as the delay, in ticks
    # (NOTE(review): this relies on the configuration's key order).
    synaptic_delay = int(hyperparameters[1])
    delay_reg = [0 for _ in range(synaptic_delay)]
    synapse_id = self.create_agent_of_breed(
        breed=self._synapse_breeds[breed],
        hyperparameters=hyperparameters,
        learning_hyperparameters=learning_hyperparameters,
        internal_state=default_internal_state,
        internal_learning_state=default_internal_learning_state,
        synapse_delay_reg=delay_reg,
    )
    # Record independent copies: the lists handed to the agent must not be
    # able to change the defaults used when resetting.
    self._synapse2defaultparameters[synapse_id] = list(hyperparameters)
    self._synapse2defaultlearningparameters[synapse_id] = list(
        learning_hyperparameters
    )
    self._synapse2defaultinternalstate[synapse_id] = list(default_internal_state)
    self._synapse2defaultinternallearningstate[synapse_id] = list(
        default_internal_learning_state
    )
    self._synapse_ids.append(synapse_id)

    network_space = self.get_space()

    # Connect in the order [pre, post]: the space is created with
    # ordered=True, so the synapse's locations keep this order.
    if pre_soma_id != -1:
        network_space.connect_agents(synapse_id, pre_soma_id, directed=True)
        self.soma2synapse_map[pre_soma_id]["post"].add(synapse_id)
        self.synapse2soma_map[synapse_id]["pre"] = pre_soma_id
    else:
        # External input: put -1 in the pre slot by hand to keep the
        # [pre, post] layout
        network_space.get_location(synapse_id).append(-1)
        self.synapse2soma_map[synapse_id]["pre"] = -1
        tags.add("input_synapse")

    if post_soma_id != -1:
        network_space.connect_agents(synapse_id, post_soma_id, directed=True)
        # Reverse edge as well, so the connection is bidirectional (for STDP)
        network_space.connect_agents(post_soma_id, synapse_id, directed=True)
        self.synapse2soma_map[synapse_id]["post"] = post_soma_id
        self.soma2synapse_map[post_soma_id]["pre"].add(synapse_id)
    else:
        # External output: put -1 in the post slot by hand
        self.synapse2soma_map[synapse_id]["post"] = -1
        network_space.get_location(synapse_id).append(-1)

    self.agentid2config[synapse_id] = config_name
    tags.update({"synapse", breed})
    for tag in tags:
        self.tag2component[tag].add(synapse_id)
    return synapse_id
631
+
632
def add_spike(self, synapse_id: int, tick: int, value: float) -> None:
    """
    Schedules one external input spike on the given synapse.

    The synapse's ``input_spikes_tensor`` is kept as a flat
    ``[tick, value, tick, value, ...]`` list; the new pair is appended to
    the list returned by the getter, which is then written back.

    :param synapse_id: agent ID of the synapse receiving the spike
    :param tick: tick at which spike should be triggered
    :param value: spike value
    """
    scheduled = self.get_agent_property_value(
        id=synapse_id,
        property_name="input_spikes_tensor",
    )
    scheduled.extend((tick, value))
    self.set_agent_property_value(synapse_id, "input_spikes_tensor", scheduled)
649
+
650
def add_spike_list(self, synapse_id: int, spike_list):
    """
    Schedules several external input spikes on the given synapse.

    Each pair is flattened onto the synapse's ``input_spikes_tensor``,
    which is stored as ``[tick, value, tick, value, ...]``.

    :param synapse_id: agent ID of the synapse receiving the spikes
    :param spike_list: List of [tick, value] pairs
    """
    scheduled = self.get_agent_property_value(
        id=synapse_id,
        property_name="input_spikes_tensor",
    )
    for pair in spike_list:
        # Only the first two entries of each pair are used: tick, value
        scheduled.extend((pair[0], pair[1]))
    self.set_agent_property_value(synapse_id, "input_spikes_tensor", scheduled)
667
+
668
def get_spike_times(self, soma_id: int) -> List[int]:
    """
    Returns the ticks at which the given soma spiked.

    :param soma_id: agent ID of the soma.
    :return: List of indices into the soma's ``output_spikes_tensor``
        whose entry is greater than zero, in increasing order.
    """
    spike_train = super().get_agent_property_value(
        id=soma_id,
        property_name="output_spikes_tensor",
    )
    return [tick for tick, spike in enumerate(spike_train) if spike > 0]
675
+
676
def get_internal_states_history(self, agent_id: int) -> np.array:
    """
    Returns the recorded internal-state history of an agent.

    :param agent_id: agent ID of a soma or synapse.
    :return: The agent's ``internal_states_buffer`` property, or an empty
        list when internal state tracking is disabled.
    """
    if self.enable_internal_state_tracking:
        return super().get_agent_property_value(
            id=agent_id, property_name="internal_states_buffer"
        )
    return []
682
+
683
def get_internal_learning_states_history(self, agent_id: int) -> np.array:
    """
    Returns the recorded internal learning-state history of an agent.

    :param agent_id: agent ID of a soma or synapse.
    :return: The agent's ``internal_learning_states_buffer`` property, or
        an empty list when internal state tracking is disabled.
    """
    if self.enable_internal_state_tracking:
        return super().get_agent_property_value(
            id=agent_id, property_name="internal_learning_states_buffer"
        )
    return []
689
+