bsb-nest 6.0.0a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bsb-nest might be problematic. Click here for more details.

bsb_nest/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ """
2
+ NEST simulation adapter for the BSB framework.
3
+ """
4
+
5
+ from bsb import SimulationBackendPlugin
6
+
7
+ from .adapter import NestAdapter
8
+ from .devices import (
9
+ DCGenerator,
10
+ Multimeter,
11
+ PoissonGenerator,
12
+ SinusoidalPoissonGenerator,
13
+ SpikeRecorder,
14
+ )
15
+ from .simulation import NestSimulation
16
+
17
# Plugin object bundling the simulation configuration class and the adapter class
# of this backend. Presumably discovered by the BSB plugin system; confirm against
# the package's entry point declaration.
__plugin__ = SimulationBackendPlugin(Simulation=NestSimulation, Adapter=NestAdapter)

# Public names re-exported from this package.
__all__ = [
    "DCGenerator",
    "Multimeter",
    "NestAdapter",
    "NestSimulation",
    "PoissonGenerator",
    "SinusoidalPoissonGenerator",
    "SpikeRecorder",
]
bsb_nest/adapter.py ADDED
@@ -0,0 +1,203 @@
1
+ import contextlib
2
+ import sys
3
+ import typing
4
+
5
+ import nest
6
+ from bsb import (
7
+ AdapterError,
8
+ AdapterProgress,
9
+ SimulationData,
10
+ SimulationResult,
11
+ SimulatorAdapter,
12
+ report,
13
+ warn,
14
+ )
15
+ from neo import SpikeTrain
16
+ from tqdm import tqdm
17
+
18
+ from .exceptions import KernelWarning, NestConnectError, NestModelError, NestModuleError
19
+
20
+ if typing.TYPE_CHECKING:
21
+ from .simulation import NestSimulation
22
+
23
+
24
class NestResult(SimulationResult):
    """
    Simulation result that gathers spikes through NEST spike recorders.
    """

    def record(self, nc, **annotations):
        """
        Attach an in-memory spike recorder to ``nc`` and register its flush hook.

        :param nc: Nodes used as the source of the connection to the recorder.
        :param annotations: Extra keyword arguments forwarded to the ``SpikeTrain``
          that is built when the recorder is flushed.
        """
        spike_recorder = nest.Create(
            "spike_recorder", params={"record_to": "memory"}
        )
        nest.Connect(nc, spike_recorder)

        def flush(segment):
            # Events are read at flush time, so the train holds everything the
            # recorder has stored up to that point.
            recorded = spike_recorder.events[0]
            train = SpikeTrain(
                recorded["times"],
                array_annotations={"senders": recorded["senders"]},
                t_stop=nest.biological_time,
                units="ms",
                **annotations,
            )
            segment.spiketrains.append(train)

        self.create_recorder(flush)
43
+
44
+
45
class NestAdapter(SimulatorAdapter):
    """
    Simulator adapter that prepares and runs BSB simulations with NEST.
    """

    def __init__(self, comm=None):
        super().__init__(comm=comm)
        # Modules seen during the current reset/prepare cycle, used to warn about
        # modules that are requested more than once.
        self.loaded_modules = set()

    def simulate(self, *simulations, post_prepare=None):
        """
        Run the parent ``simulate`` between two kernel resets.

        The kernel is reset before simulating and again afterwards, including
        when the simulation raises.
        """
        try:
            self.reset_kernel()
            return super().simulate(*simulations, post_prepare=post_prepare)
        finally:
            self.reset_kernel()

    def prepare(self, simulation):
        """
        Prepare the simulation environment in NEST.

        This method initializes internal data structures and performs all
        setup steps required before running the simulation:

        - Loads and installs required NEST modules.
        - Applies simulation-level settings (e.g., resolution, verbosity, seed).
        - Creates neuron populations based on cell models.
        - Establishes connectivity between neurons using connection models.
        - Instantiates devices (e.g., recorders, stimuli) used in the simulation.

        If any error occurs during preparation, the corresponding internal state
        is cleaned up to avoid partial setups.

        :param simulation: The simulation configuration to prepare.
        :type simulation: NestSimulation
        :returns: The prepared simulation data associated with the given simulation.
        :rtype: bsb.simulation.adapter.SimulationData
        """
        self.simdata[simulation] = SimulationData(
            simulation, result=NestResult(simulation)
        )
        try:
            report("Installing NEST modules...", level=2)
            self.load_modules(simulation)
            self.set_settings(simulation)
            report("Creating neurons...", level=2)
            self.create_neurons(simulation)
            report("Creating connections...", level=2)
            self.connect_neurons(simulation)
            report("Creating devices...", level=2)
            self.create_devices(simulation)
            return self.simdata[simulation]
        except Exception:
            # Drop the partially prepared state before propagating the error.
            del self.simdata[simulation]
            raise

    def reset_kernel(self):
        """
        Reset the NEST kernel and forget which modules were loaded this cycle.
        """
        nest.ResetKernel()
        # Reset which modules we should consider explicitly loaded by the user
        # to appropriately warn them when they load them twice.
        self.loaded_modules = set()

    def run(self, *simulations):
        """
        Run the prepared simulations and return their results.

        NEST is advanced for the longest ``duration`` among the given simulations.
        The simulation data of every given simulation is removed afterwards, even
        when the run fails.

        :param simulations: Simulations that were previously prepared.
        :returns: The result objects of the simulations, in the given order.
        :rtype: list
        :raises AdapterError: If any of the simulations was not prepared.
        """
        unprepared = [sim for sim in simulations if sim not in self.simdata]
        if unprepared:
            # Simulations are objects, not strings: joining them directly would
            # raise a TypeError and hide the actual problem.
            names = ", ".join(str(getattr(sim, "name", sim)) for sim in unprepared)
            raise AdapterError(f"Unprepared for simulations: {names}")
        report("Simulating...", level=2)
        duration = max(sim.duration for sim in simulations)
        progress = AdapterProgress(duration)
        try:
            with nest.RunManager():
                for oi, i in progress.steps(step=1):
                    nest.Run(i - oi)
                    progress.tick(i)
        finally:
            results = [self.simdata[sim].result for sim in simulations]
            for sim in simulations:
                del self.simdata[sim]
        progress.complete()
        report("Simulation done.", level=2)
        return results

    def load_modules(self, simulation):
        """
        Install the NEST modules listed in ``simulation.modules``.

        A module that NEST reports as loaded already is tolerated; a warning is
        emitted only when it was already requested during the current
        reset/prepare cycle.

        :raises NestModuleError: If NEST reports that the module file is not found.
        """
        for module in simulation.modules:
            try:
                nest.Install(module)
            except Exception as e:
                # Not every exception carries NEST's `errorname`/`message`
                # attributes; read them defensively so unrelated errors are
                # re-raised as-is instead of being masked by an AttributeError.
                errorname = getattr(e, "errorname", None)
                message = getattr(e, "message", None) or str(e)
                if errorname != "DynamicModuleManagementError":
                    raise
                if "loaded already" in message:
                    # Modules stay loaded in between `ResetKernel` calls.
                    # If the module is not in the `loaded_modules` set, then
                    # it's the first time this `reset`/`prepare` cycle,
                    # and there is no user-side issue.
                    if module in self.loaded_modules:
                        warn(f"Already loaded '{module}'.", KernelWarning)
                elif "file not found" in message:
                    raise NestModuleError(f"Module {module} not found") from None
                else:
                    raise
            # Reached after a successful install or a tolerated "loaded already":
            # remember the module so a repeated request in this cycle is reported.
            self.loaded_modules.add(module)

    def create_neurons(self, simulation):
        """
        Create a population of nodes in the NEST simulator based on the cell model
        configurations.
        """
        simdata = self.simdata[simulation]
        for cell_model in simulation.cell_models.values():
            simdata.populations[cell_model] = cell_model.create_population(simdata)

    def connect_neurons(self, simulation):
        """
        Connect the cells in NEST according to the connection model configurations

        :raises NestModelError: If no population exists for the pre- or
          postsynaptic cell type of a connectivity set.
        :raises NestConnectError: If a connection model fails to create its
          connections; the original error is chained as the cause.
        """
        simdata = self.simdata[simulation]
        models = simulation.connection_models.values()
        if self.comm.get_rank() == 0:
            models = tqdm(models, desc="", file=sys.stdout)
        for connection_model in models:
            with contextlib.suppress(AttributeError):
                # Only rank 0 should report progress bar
                models.set_description(connection_model.name)
            cs = simulation.scaffold.get_connectivity_set(
                connection_model.tag or connection_model.name
            )
            try:
                pre_nodes = simdata.populations[simulation.get_model_of(cs.pre_type)]
            except KeyError:
                raise NestModelError(f"No model found for {cs.pre_type}") from None
            try:
                post_nodes = simdata.populations[simulation.get_model_of(cs.post_type)]
            except KeyError:
                raise NestModelError(f"No model found for {cs.post_type}") from None
            try:
                simdata.connections[connection_model] = (
                    connection_model.create_connections(
                        simdata, pre_nodes, post_nodes, cs, self.comm
                    )
                )
            except Exception as e:
                # Keep the underlying error as the cause: the wrapper message
                # alone does not say what went wrong.
                raise NestConnectError(
                    f"{connection_model} error during connect."
                ) from e

    def create_devices(self, simulation):
        """
        Let every device model of the simulation implement itself in NEST.
        """
        simdata = self.simdata[simulation]
        for device_model in simulation.devices.values():
            device_model.implement(self, simulation, simdata)

    def set_settings(self, simulation: "NestSimulation"):
        """
        Apply the verbosity, resolution and (when given) seed of the simulation
        to the NEST kernel, and allow NEST to overwrite files.
        """
        nest.set_verbosity(simulation.verbosity)
        nest.resolution = simulation.resolution
        nest.overwrite_files = True
        if simulation.seed is not None:
            nest.rng_seed = simulation.seed

    def check_comm(self):
        """
        Verify that NEST and the adapter's communicator agree on the process count.

        :raises RuntimeError: If the two process counts differ.
        """
        if nest.NumProcesses() != self.comm.get_size():
            raise RuntimeError(
                f"NEST is managing {nest.NumProcesses()} processes, but "
                f"{self.comm.get_size()} were detected. Please check your MPI setup."
            )
bsb_nest/cell.py ADDED
@@ -0,0 +1,32 @@
1
+ import nest
2
+ from bsb import CellModel, config
3
+
4
+ from .distributions import NestRandomDistribution, nest_parameter
5
+
6
+
7
@config.node
class NestCell(CellModel):
    """
    Cell model whose population is created as nodes of a NEST model.
    """

    model = config.attr(type=str, default="iaf_psc_alpha")
    """Importable reference to the NEST model describing the cell type."""
    constants = config.dict(type=nest_parameter())
    """Dictionary of the constants values to assign to the cell model."""

    def create_population(self, simdata):
        """
        Create the NEST nodes for this cell model and configure them.

        :param simdata: Simulation data holding the placement of this cell model.
        :returns: The created node collection; an empty collection when the
          placement holds no cells.
        """
        n = len(simdata.placement[self])
        population = nest.Create(self.model, n) if n else nest.NodeCollection([])
        self.set_constants(population)
        self.set_parameters(population, simdata)
        return population

    def set_constants(self, population):
        """
        Assign the configured constants to every node of the population.

        Constants given as a ``NestRandomDistribution`` are called first, and the
        value they return is what gets assigned.
        """
        population.set(
            {
                k: (v() if isinstance(v, NestRandomDistribution) else v)
                for k, v in self.constants.items()
            }
        )

    def set_parameters(self, population, simdata):
        """
        Assign the values of the cell model parameters to the population.

        :param population: Node collection to configure.
        :param simdata: Simulation data holding the placement of this cell model.
        """
        ps = simdata.placement[self]
        for param in self.parameters:
            # `NodeCollection.set` takes a single params dict (or keyword
            # arguments); passing the name and the value as two positional
            # arguments raises a TypeError.
            population.set({param.name: param.get_value(ps)})
bsb_nest/connection.py ADDED
@@ -0,0 +1,193 @@
1
+ import functools
2
+ import sys
3
+
4
+ import nest
5
+ import numpy as np
6
+ import psutil
7
+ from bsb import ConnectionModel, compose_nodes, config, types
8
+ from tqdm import tqdm
9
+
10
+ from .distributions import nest_parameter
11
+ from .exceptions import NestConnectError
12
+
13
+
14
@config.node
class NestSynapseSettings:
    """
    Class interfacing a NEST synapse model.

    Configuration node describing the synapse used by a connection: the synapse
    model name, its required weight and delay, and an optional receptor type.
    """

    # Defaults to NEST's "static_synapse" when not configured.
    model = config.attr(type=str, default="static_synapse")
    """Importable reference to the NEST model describing the synapse type."""
    weight = config.attr(type=float, required=True)
    """Weight of the connection between the presynaptic and the postsynaptic cells."""
    delay = config.attr(type=float, required=True)
    """Delay of the transmission between the presynaptic and the postsynaptic cells."""
    # No default is declared for the receptor type.
    receptor_type = config.attr(type=int)
    """Index of the postsynaptic receptor to target."""
    # NOTE(review): presumably collects every key not matched by the attributes
    # above - confirm against `config.catch_all`.
    constants = config.catch_all(type=nest_parameter())
    """Dictionary of the constants values to assign to the synapse model."""
30
+
31
+
32
@config.node
class NestConnectionSettings:
    """
    Class interfacing a NEST connection rule.

    Configuration node holding the optional rule name and the parameters that
    accompany it.
    """

    # Not required, and no default is declared here.
    rule = config.attr(type=str)
    """Importable reference to the Nest connection rule used to connect the cells."""
    # NOTE(review): presumably collects every key not matched by `rule` - confirm
    # against `config.catch_all`.
    constants = config.catch_all(type=types.any_())
    """Dictionary of parameters to assign to the connection rule."""
42
+
43
+
44
+ class LazySynapseCollection:
45
+ def __init__(self, pre, post):
46
+ self._pre = pre
47
+ self._post = post
48
+
49
+ def __len__(self):
50
+ return self.collection.__len__()
51
+
52
+ def __str__(self):
53
+ return self.collection.__str__()
54
+
55
+ def __iter__(self):
56
+ return iter(self.collection)
57
+
58
+ def __getattr__(self, attr):
59
+ return getattr(self.collection, attr)
60
+
61
+ @functools.cached_property
62
+ def collection(self):
63
+ return nest.GetConnections(self._pre, self._post)
64
+
65
+
66
@config.dynamic(attr_name="model_strategy", required=False)
class NestConnection(compose_nodes(NestConnectionSettings, ConnectionModel)):
    """
    Class interfacing a NEST connection, including its connection rule and synaptic
    parameters.
    """

    model_strategy: str
    """
    Specifies the strategy used by the connection model for synapse creation and
    management.
    """

    synapse = config.attr(type=NestSynapseSettings, required=True)
    """Nest synapse model with its parameters."""

    def create_connections(self, simdata, pre_nodes, post_nodes, cs, comm):
        """
        Create the NEST connections between two populations.

        When a ``rule`` is configured, the populations are connected with that
        rule. Otherwise the connectivity set is read in pieces and each distinct
        (pre, post) cell pair is connected one-to-one, with the synaptic weight
        multiplied by the number of times the pair occurs.

        :param simdata: Simulation data (not used by this implementation).
        :param pre_nodes: Node collection of the presynaptic population.
        :param post_nodes: Node collection of the postsynaptic population.
        :param cs: Connectivity set to read the cell pairs from.
        :param comm: Communicator used for barriers, rank and size.
        :returns: Lazy handle on the connections between the two populations.
        :rtype: LazySynapseCollection
        :raises NestConnectError: If the synapse model is unknown to NEST.
        """
        syn_spec = self.get_syn_spec()
        if syn_spec["synapse_model"] not in nest.Models(mtype="synapses"):
            raise NestConnectError(
                f"Unknown synapse model '{syn_spec['synapse_model']}'."
            )
        if self.rule is not None:
            nest.Connect(pre_nodes, post_nodes, self.get_conn_spec(), syn_spec)
        else:
            comm.barrier()
            # The node id tables do not change between chunks: build them once,
            # on the first chunk that actually holds connections.
            pre_ids = post_ids = None
            for pre_locs, post_locs in self.predict_mem_iterator(
                pre_nodes, post_nodes, cs, comm
            ):
                comm.barrier()
                if len(pre_locs) == 0 or len(post_locs) == 0:
                    continue
                # Collapse repeated (pre, post) pairs and count the repeats.
                cell_pairs, multiplicity = np.unique(
                    np.column_stack((pre_locs[:, 0], post_locs[:, 0])),
                    return_counts=True,
                    axis=0,
                )
                if pre_ids is None:
                    pre_ids = pre_nodes.tolist()
                    post_ids = post_nodes.tolist()
                base_weight = syn_spec["weight"]
                pair_spec = {**syn_spec}
                pair_spec["weight"] = [base_weight * m for m in multiplicity]
                pair_spec["delay"] = [syn_spec["delay"]] * len(pair_spec["weight"])
                nest.Connect(
                    [pre_ids[x] for x in cell_pairs[:, 0]],
                    [post_ids[x] for x in cell_pairs[:, 1]],
                    "one_to_one",
                    pair_spec,
                    return_synapsecollection=False,
                )
                comm.barrier()
        return LazySynapseCollection(pre_nodes, post_nodes)

    def predict_mem_iterator(self, pre_nodes, post_nodes, cs, comm):
        """
        Pick how finely the connectivity set is read, based on available memory.

        A rough memory estimate is compared against half of the available
        memory: if the per-chunk estimate is too large the set is read block by
        block, if only the total estimate is too large it is read per local
        chunk, and otherwise everything is read at once.

        :returns: Iterable of ``(pre_locs, post_locs)`` pairs.
        """
        avmem = psutil.virtual_memory().available
        predicted_all_mem = (
            len(pre_nodes) * 8 * 2 + len(post_nodes) * 8 * 2 + len(cs) * 6 * 8 * (16 + 2)
        ) * comm.get_size()
        n_chunks = len(cs.get_local_chunks("out"))
        predicted_local_mem = (predicted_all_mem / n_chunks) if n_chunks > 0 else 0.0
        if predicted_local_mem > avmem / 2:
            # Iterate block-by-block
            return self.block_iterator(cs, comm)
        elif predicted_all_mem > avmem / 2:
            # Iterate local hyperblocks
            return self.local_iterator(cs, comm)
        else:
            # Iterate all
            return (cs.load_connections().as_globals().all(),)

    def block_iterator(self, cs, comm):
        """
        Iterate the connectivity set block by block, per local chunk.

        Rank 0 wraps both levels of iteration in progress bars.
        """
        local_chunks = cs.get_local_chunks("out")

        def block_iter():
            chunks = local_chunks
            if comm.get_rank() == 0:
                chunks = tqdm(chunks, desc="hyperblocks", file=sys.stdout)
            for local in chunks:
                inner_iter = cs.load_connections().as_globals().from_(local)
                if comm.get_rank() == 0:
                    yield from tqdm(
                        inner_iter,
                        desc="blocks",
                        total=len(cs.get_global_chunks("out", local)),
                        file=sys.stdout,
                        leave=False,
                    )
                else:
                    yield from inner_iter

        return block_iter()

    def local_iterator(self, cs, comm):
        """
        Iterate the connectivity set one local chunk at a time.

        Rank 0 wraps the iteration in a progress bar.
        """
        chunks = cs.get_local_chunks("out")
        if comm.get_rank() == 0:
            chunks = tqdm(chunks, desc="hyperblocks", file=sys.stdout)
        yield from (
            cs.load_connections().as_globals().from_(local).all() for local in chunks
        )

    def get_connectivity_set(self):
        """
        Return the connectivity set this connection model refers to.

        The set is looked up by ``tag`` when one is configured and by the name of
        the model otherwise, which mirrors how the adapter resolves the set of a
        connection model.
        """
        # The previous fallback returned `self.connection_model`, an attribute
        # that is not defined on this class.
        set_name = self.tag if self.tag is not None else self.name
        return self.scaffold.get_connectivity_set(set_name)

    def get_conn_spec(self):
        """
        Build the NEST connection specification from the rule and its constants.
        """
        return {
            "rule": self.rule,
            **self.constants,
        }

    def get_syn_spec(self):
        """
        Build the NEST synapse specification from the synapse settings.

        Unset (``None``) settings are left out; the extra synapse constants are
        applied last and therefore take precedence on a key clash.
        """
        spec = {}
        # (attribute on the synapse settings, key in the NEST synapse spec)
        for attr, label in (
            ("model", "synapse_model"),
            ("weight", "weight"),
            ("delay", "delay"),
            ("receptor_type", "receptor_type"),
        ):
            value = getattr(self.synapse, attr)
            if value is not None:
                spec[label] = value
        return {**spec, **self.synapse.constants}
bsb_nest/device.py ADDED
@@ -0,0 +1,149 @@
1
+ import abc
2
+ import warnings
3
+
4
+ import nest
5
+ from bsb import DeviceModel, Targetting, config, refs, types
6
+
7
+
8
@config.node
class NestRule:
    """
    Interface to connect a device directly through the NEST interface.

    Configuration node holding a required rule name, its parameters and an
    optional list of cell models to restrict the targets to.
    """

    rule = config.attr(type=str, required=True)
    """Connection rule to connect """
    # NOTE(review): presumably collects every key not matched by the other
    # attributes - confirm against `config.catch_all`.
    constants = config.catch_all(type=types.any_())
    """Dictionary of parameters for the targetting rule."""
    # An empty list is treated as "all cell models" by
    # `NestDevice.get_dict_targets`.
    cell_models = config.reflist(refs.sim_cell_model_ref)
    """Reference to the Nest cell model to target with the Device"""
20
+
21
+
22
@config.dynamic(attr_name="device", auto_classmap=True, default="external")
class NestDevice(DeviceModel):
    device: str
    """Name of the NEST device model (e.g., "spike_generator", "poisson_generator")."""
    weight = config.attr(type=float, required=True)
    """weight of the connection between the device and its target"""
    delay = config.attr(type=float, required=True)
    """delay of the transmission between the device and its target"""
    targetting = config.attr(
        type=types.or_(Targetting, NestRule), default=dict, call_default=True
    )
    """Targets of the device, which should be either a population or a nest rule"""
    receptor_type = config.attr(type=int, required=False, default=0)
    """Integer ID of the postsynaptic target receptor"""

    def get_dict_targets(
        self,
        adapter,
        simulation,
        simdata,
    ) -> dict:
        """
        Map every target group of the device to its NEST Collection.

        :param bsb_nest.adapter.NestAdapter adapter: Nest adapter instance
        :param bsb_nest.simulation.NestSimulation simulation: Nest simulation instance
        :param bsb.simulation.adapter.SimulationData simdata: Simulation data instance
        :return: dictionary of device target group to NEST Collection
        :rtype: dict
        """
        if isinstance(self.targetting, Targetting):
            return self.targetting.get_targets(adapter, simulation, simdata)

        # Rule based targetting: keep the populations of the listed cell models,
        # or all of them when no cell model is listed.
        node_collector = {}
        for model, targets in simdata.populations.items():
            if not self.targetting.cell_models or model in self.targetting.cell_models:
                node_collector[model] = simdata.populations[model][targets]
        return node_collector

    @staticmethod
    def _flatten_nodes_ids(dict_targets):
        # Concatenate all the collections, starting from an empty one.
        flattened = nest.NodeCollection()
        for nodes in dict_targets.values():
            flattened = flattened + nodes
        return flattened

    @staticmethod
    def _invert_targets_dict(dict_targets):
        # Map each node id to the name of the group it came from.
        inverted = {}
        for group, nodes in dict_targets.items():
            for elem in nodes.tolist():
                inverted[elem] = group.name
        return inverted

    def get_target_nodes(
        self,
        adapter,
        simulation,
        simdata,
    ):
        """
        Collect all the targets of the device in a single NEST Collection.

        :param bsb_nest.adapter.NestAdapter adapter:
        :param bsb_nest.simulation.NestSimulation simulation: Nest simulation instance
        :param bsb.simulation.adapter.SimulationData simdata: Simulation data instance
        :return: Flattened NEST collection with all the targets of the device
        """
        return self._flatten_nodes_ids(
            self.get_dict_targets(adapter, simulation, simdata)
        )

    def connect_to_nodes(self, device, nodes):
        """
        Connect the device to its target nodes.

        Warns and does nothing when there are no targets. The device is first
        connected as the source; if NEST reports that it does not send output,
        the nodes are connected to the device instead.
        """
        if len(nodes) == 0:
            warnings.warn(f"{self.name} has no targets", stacklevel=2)
            return

        try:
            outgoing_spec = {
                "weight": self.weight,
                "delay": self.delay,
                "receptor_type": self.receptor_type,
            }
            nest.Connect(device, nodes, syn_spec=outgoing_spec)
        except Exception as e:
            if "does not send output" not in str(e):
                raise
            # The device only receives: reverse the direction of the connection.
            incoming_spec = {"weight": self.weight, "delay": self.delay}
            nest.Connect(nodes, device, syn_spec=incoming_spec)

    def register_device(self, simdata, device):
        """Store the device in the simulation data under this model and return it."""
        simdata.devices[self] = device
        return device

    @abc.abstractmethod
    def implement(
        self,
        adapter,
        simulation,
        simdata,
    ):
        """
        Create, connect and register the Nest device.

        :param bsb_nest.adapter.NestAdapter adapter:
        :param bsb_nest.simulation.NestSimulation simulation: Nest simulation instance
        :param bsb.simulation.adapter.SimulationData simdata: Simulation data instance
        """
        pass
131
+
132
+
133
@config.node
class ExtNestDevice(NestDevice, classmap_entry="external"):
    """
    Class interfacing Nest devices.
    """

    nest_model = config.attr(type=str, required=True)
    """Importable reference to the NEST model describing the device type."""
    constants = config.dict(type=types.or_(types.number(), str))
    """Dictionary of the constants values to assign to the device model."""

    def implement(self, adapter, simulation, simdata):
        """
        Create the configured NEST device, store it and connect it to its targets.
        """
        device = nest.Create(self.nest_model, params=self.constants)
        # The device is stored before its targets are resolved.
        simdata.devices[self] = device
        targets = self.get_target_nodes(adapter, simulation, simdata)
        self.connect_to_nodes(device, targets)
@@ -0,0 +1,13 @@
1
+ from .dc_generator import DCGenerator
2
+ from .multimeter import Multimeter
3
+ from .poisson_generator import PoissonGenerator
4
+ from .sinusoidal_poisson_generator import SinusoidalPoissonGenerator
5
+ from .spike_recorder import SpikeRecorder
6
+
7
# Public device classes re-exported from this package's submodules.
__all__ = [
    "DCGenerator",
    "Multimeter",
    "PoissonGenerator",
    "SinusoidalPoissonGenerator",
    "SpikeRecorder",
]
@@ -0,0 +1,23 @@
1
+ import nest
2
+ from bsb import config
3
+
4
+ from ..device import NestDevice
5
+
6
+
7
@config.node
class DCGenerator(NestDevice, classmap_entry="dc_generator"):
    amplitude = config.attr(type=float, required=True)
    """Current amplitude of the dc generator"""
    start = config.attr(type=float, required=False, default=0.0)
    """Activation time in ms"""
    stop = config.attr(type=float, required=False, default=None)
    """Deactivation time in ms.
    If not specified, generator will last until the end of the simulation."""

    def implement(self, adapter, simulation, simdata):
        """
        Create a NEST ``dc_generator``, register it and connect it to its targets.

        The ``stop`` time is only passed to NEST when it is set and later than
        ``start``.
        """
        targets = self.get_target_nodes(adapter, simulation, simdata)
        generator_params = {"amplitude": self.amplitude, "start": self.start}
        if self.stop is not None:
            if self.stop > self.start:
                generator_params["stop"] = self.stop
        generator = nest.Create("dc_generator", params=generator_params)
        device = self.register_device(simdata, generator)
        self.connect_to_nodes(device, targets)