bsb-nest 4.0.0rc2__py2.py3-none-any.whl → 4.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bsb-nest might be problematic; see the package registry's advisory page for more details.

bsb_nest/connection.py CHANGED
@@ -1,165 +1,164 @@
1
- import functools
2
- import sys
3
-
4
- import nest
5
- import numpy as np
6
- import psutil
7
- from bsb import MPI, ConnectionModel, compose_nodes, config, types
8
- from tqdm import tqdm
9
-
10
- from .exceptions import NestConnectError
11
-
12
-
13
@config.node
class NestSynapseSettings:
    """Synapse parameters that end up in the NEST ``syn_spec`` dictionary."""

    # NEST synapse model name.
    model = config.attr(type=str, default="static_synapse")
    # Synaptic weight (mandatory).
    weight = config.attr(type=float, required=True)
    # Synaptic delay (mandatory).
    delay = config.attr(type=float, required=True)
    # Optional receptor port of the postsynaptic model; omitted when unset.
    receptor_type = config.attr(type=int)
    # Every other key is collected here and passed through unchanged.
    constants = config.catch_all(type=types.any_())
20
-
21
-
22
@config.node
class NestConnectionSettings:
    """Connection-rule settings that end up in the NEST ``conn_spec`` dictionary."""

    # NEST connection rule. When unset, connections are instead read from the
    # connectivity set and created one-to-one.
    rule = config.attr(type=str)
    # Every other key is collected here and merged into the ``conn_spec``.
    constants = config.catch_all(type=types.any_())
26
-
27
-
28
class LazySynapseCollection:
    """Proxy for the synapses between ``pre`` and ``post`` that defers the query.

    ``nest.GetConnections`` is only called the first time the collection is
    needed; the result is cached on the instance. Attribute lookups that fail
    on the proxy are forwarded to the underlying collection.
    """

    # The proxy's own state. These names must never be forwarded: if they are
    # missing (e.g. on an instance created by ``copy``/``pickle`` without
    # running ``__init__``), forwarding would evaluate ``collection``, which
    # reads ``_pre``, which calls ``__getattr__`` again -> RecursionError.
    _OWN_ATTRS = frozenset(("_pre", "_post", "collection"))

    def __init__(self, pre, post):
        self._pre = pre
        self._post = post

    def __len__(self):
        return len(self.collection)

    def __str__(self):
        return str(self.collection)

    def __iter__(self):
        return iter(self.collection)

    def __getattr__(self, attr):
        # Only invoked when normal attribute lookup has already failed.
        if attr in self._OWN_ATTRS:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )
        return getattr(self.collection, attr)

    @functools.cached_property
    def collection(self):
        """The actual NEST synapse collection, queried on first access."""
        return nest.GetConnections(self._pre, self._post)
48
-
49
-
50
@config.dynamic(attr_name="model_strategy", required=False)
class NestConnection(compose_nodes(NestConnectionSettings, ConnectionModel)):
    """Connection model that instantiates a connectivity set as NEST synapses."""

    tag = config.attr(type=str)
    synapse = config.attr(type=NestSynapseSettings, required=True)

    def create_connections(self, simdata, pre_nodes, post_nodes, cs):
        """Create the NEST synapses between ``pre_nodes`` and ``post_nodes``.

        If a NEST ``rule`` is configured, NEST generates the connectivity
        itself. Otherwise the connections stored in ``cs`` are loaded (in
        memory-bounded pieces, see :meth:`predict_mem_iterator`) and created
        one-to-one. Duplicate (pre, post) cell pairs are merged into a single
        synapse whose weight is multiplied by the number of duplicates.

        :returns: A lazily evaluated collection of the synapses between the
            two node collections.
        :raises NestConnectError: if the synapse model is unknown to NEST.
        """
        syn_spec = self.get_syn_spec()
        if syn_spec["synapse_model"] not in nest.Models(mtype="synapses"):
            raise NestConnectError(
                f"Unknown synapse model '{syn_spec['synapse_model']}'."
            )
        if self.rule is not None:
            nest.Connect(pre_nodes, post_nodes, self.get_conn_spec(), syn_spec)
        else:
            MPI.barrier()
            # Cell index -> NEST node id lookup tables; invariant across pieces.
            prel = pre_nodes.tolist()
            postl = post_nodes.tolist()
            # NOTE(review): the per-piece barrier is collective; the iterators
            # walk rank-local chunks, so if ranks yield a different number of
            # pieces this could deadlock — confirm.
            for pre_locs, post_locs in self.predict_mem_iterator(
                pre_nodes, post_nodes, cs
            ):
                MPI.barrier()
                # Column 0 of each location row indexes into prel/postl; count
                # how often every (pre, post) pair occurs.
                cell_pairs, multiplicity = np.unique(
                    np.column_stack((pre_locs[:, 0], post_locs[:, 0])),
                    return_counts=True,
                    axis=0,
                )
                ssw = {**syn_spec}
                bw = syn_spec["weight"]
                ssw["weight"] = [bw * m for m in multiplicity]
                ssw["delay"] = [syn_spec["delay"]] * len(ssw["weight"])
                nest.Connect(
                    [prel[x] for x in cell_pairs[:, 0]],
                    [postl[x] for x in cell_pairs[:, 1]],
                    "one_to_one",
                    ssw,
                    return_synapsecollection=False,
                )
            MPI.barrier()
        return LazySynapseCollection(pre_nodes, post_nodes)

    def predict_mem_iterator(self, pre_nodes, post_nodes, cs):
        """Choose how to stream the connections of ``cs`` given free memory.

        An estimate of the total memory needed (scaled by the number of MPI
        ranks) is compared with half of the currently available memory:

        * if the per-local-chunk share exceeds it, iterate block by block;
        * else if the total exceeds it, iterate one local hyperblock at a time;
        * otherwise load all connections at once.

        :returns: An iterable of ``(pre_locs, post_locs)`` pairs.
        """
        avmem = psutil.virtual_memory().available
        predicted_all_mem = (
            len(pre_nodes) * 8 * 2
            + len(post_nodes) * 8 * 2
            + len(cs) * 6 * 8 * (16 + 2)
        ) * MPI.get_size()
        # A rank may own no "out" chunks at all; don't divide by zero then.
        n_local_chunks = len(cs.get_local_chunks("out"))
        predicted_local_mem = predicted_all_mem / max(n_local_chunks, 1)
        if predicted_local_mem > avmem / 2:
            # Iterate block-by-block
            return self.block_iterator(cs)
        elif predicted_all_mem > avmem / 2:
            # Iterate local hyperblocks
            return self.local_iterator(cs)
        else:
            # Iterate all
            return (cs.load_connections().as_globals().all(),)

    def block_iterator(self, cs):
        """Yield the connections of ``cs`` block by block.

        Walks this rank's local ``"out"`` chunks and, for each, the blocks
        loaded from it. Rank 0 shows nested progress bars on stdout.
        """
        local_chunks = cs.get_local_chunks("out")

        def block_iter():
            chunks = local_chunks
            if MPI.get_rank() == 0:
                chunks = tqdm(chunks, desc="hyperblocks", file=sys.stdout)
            for local in chunks:
                inner_iter = cs.load_connections().as_globals().from_(local)
                if MPI.get_rank() == 0:
                    yield from tqdm(
                        inner_iter,
                        desc="blocks",
                        total=len(cs.get_global_chunks("out", local)),
                        file=sys.stdout,
                        leave=False,
                    )
                else:
                    yield from inner_iter

        return block_iter()

    def local_iterator(self, cs):
        """Yield all connections of each of this rank's ``"out"`` chunks in turn."""
        chunks = cs.get_local_chunks("out")
        if MPI.get_rank() == 0:
            chunks = tqdm(chunks, desc="hyperblocks", file=sys.stdout)
        for local in chunks:
            yield cs.load_connections().as_globals().from_(local).all()

    def get_connectivity_set(self):
        """Return the tagged connectivity set, or the connection model if untagged."""
        if self.tag is not None:
            return self.scaffold.get_connectivity_set(self.tag)
        else:
            return self.connection_model

    def get_conn_spec(self):
        """Build the NEST ``conn_spec``; catch-all constants may override ``rule``."""
        return {
            "rule": self.rule,
            **self.constants,
        }

    def get_syn_spec(self):
        """Build the NEST ``syn_spec`` from :attr:`synapse`.

        Standard attributes that are ``None`` are omitted; the catch-all
        constants are merged last and may override them.
        """
        spec = {}
        for attr, label in (
            ("model", "synapse_model"),
            ("weight", "weight"),
            ("delay", "delay"),
            ("receptor_type", "receptor_type"),
        ):
            value = getattr(self.synapse, attr)
            if value is not None:
                spec[label] = value
        return {**spec, **self.synapse.constants}
1
+ import functools
2
+ import sys
3
+
4
+ import nest
5
+ import numpy as np
6
+ import psutil
7
+ from bsb import MPI, ConnectionModel, compose_nodes, config, types
8
+ from tqdm import tqdm
9
+
10
+ from .distributions import nest_parameter
11
+ from .exceptions import NestConnectError
12
+
13
+
14
@config.node
class NestSynapseSettings:
    """Synapse parameters that end up in the NEST ``syn_spec`` dictionary."""

    # NEST synapse model name.
    model = config.attr(type=str, default="static_synapse")
    # Synaptic weight (mandatory).
    weight = config.attr(type=float, required=True)
    # Synaptic delay (mandatory).
    delay = config.attr(type=float, required=True)
    # Optional receptor port of the postsynaptic model; omitted when unset.
    receptor_type = config.attr(type=int)
    # Every other key is collected here and passed through unchanged. The
    # ``nest_parameter`` type presumably accepts NEST parameter/distribution
    # specs in addition to plain values — confirm in ``.distributions``.
    constants = config.catch_all(type=nest_parameter())
21
+
22
+
23
@config.node
class NestConnectionSettings:
    """Connection-rule settings that end up in the NEST ``conn_spec`` dictionary."""

    # NEST connection rule. When unset, connections are instead read from the
    # connectivity set and created one-to-one.
    rule = config.attr(type=str)
    # Every other key is collected here and merged into the ``conn_spec``.
    constants = config.catch_all(type=types.any_())
27
+
28
+
29
class LazySynapseCollection:
    """Proxy for the synapses between ``pre`` and ``post`` that defers the query.

    ``nest.GetConnections`` is only called the first time the collection is
    needed; the result is cached on the instance. Attribute lookups that fail
    on the proxy are forwarded to the underlying collection.
    """

    # The proxy's own state. These names must never be forwarded: if they are
    # missing (e.g. on an instance created by ``copy``/``pickle`` without
    # running ``__init__``), forwarding would evaluate ``collection``, which
    # reads ``_pre``, which calls ``__getattr__`` again -> RecursionError.
    _OWN_ATTRS = frozenset(("_pre", "_post", "collection"))

    def __init__(self, pre, post):
        self._pre = pre
        self._post = post

    def __len__(self):
        return len(self.collection)

    def __str__(self):
        return str(self.collection)

    def __iter__(self):
        return iter(self.collection)

    def __getattr__(self, attr):
        # Only invoked when normal attribute lookup has already failed.
        if attr in self._OWN_ATTRS:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )
        return getattr(self.collection, attr)

    @functools.cached_property
    def collection(self):
        """The actual NEST synapse collection, queried on first access."""
        return nest.GetConnections(self._pre, self._post)
49
+
50
+
51
@config.dynamic(attr_name="model_strategy", required=False)
class NestConnection(compose_nodes(NestConnectionSettings, ConnectionModel)):
    """Connection model that instantiates a connectivity set as NEST synapses."""

    tag = config.attr(type=str)
    synapse = config.attr(type=NestSynapseSettings, required=True)

    def create_connections(self, simdata, pre_nodes, post_nodes, cs):
        """Create the NEST synapses between ``pre_nodes`` and ``post_nodes``.

        If a NEST ``rule`` is configured, NEST generates the connectivity
        itself. Otherwise the connections stored in ``cs`` are loaded (in
        memory-bounded pieces, see :meth:`predict_mem_iterator`) and created
        one-to-one. Duplicate (pre, post) cell pairs are merged into a single
        synapse whose weight is multiplied by the number of duplicates.

        :returns: A lazily evaluated collection of the synapses between the
            two node collections.
        :raises NestConnectError: if the synapse model is unknown to NEST.
        """
        syn_spec = self.get_syn_spec()
        if syn_spec["synapse_model"] not in nest.Models(mtype="synapses"):
            raise NestConnectError(
                f"Unknown synapse model '{syn_spec['synapse_model']}'."
            )
        if self.rule is not None:
            nest.Connect(pre_nodes, post_nodes, self.get_conn_spec(), syn_spec)
        else:
            MPI.barrier()
            # Cell index -> NEST node id lookup tables; invariant across pieces.
            prel = pre_nodes.tolist()
            postl = post_nodes.tolist()
            # NOTE(review): the per-piece barrier is collective; the iterators
            # walk rank-local chunks, so if ranks yield a different number of
            # pieces this could deadlock — confirm.
            for pre_locs, post_locs in self.predict_mem_iterator(
                pre_nodes, post_nodes, cs
            ):
                MPI.barrier()
                # Column 0 of each location row indexes into prel/postl; count
                # how often every (pre, post) pair occurs.
                cell_pairs, multiplicity = np.unique(
                    np.column_stack((pre_locs[:, 0], post_locs[:, 0])),
                    return_counts=True,
                    axis=0,
                )
                ssw = {**syn_spec}
                bw = syn_spec["weight"]
                ssw["weight"] = [bw * m for m in multiplicity]
                ssw["delay"] = [syn_spec["delay"]] * len(ssw["weight"])
                nest.Connect(
                    [prel[x] for x in cell_pairs[:, 0]],
                    [postl[x] for x in cell_pairs[:, 1]],
                    "one_to_one",
                    ssw,
                    return_synapsecollection=False,
                )
            MPI.barrier()
        return LazySynapseCollection(pre_nodes, post_nodes)

    def predict_mem_iterator(self, pre_nodes, post_nodes, cs):
        """Choose how to stream the connections of ``cs`` given free memory.

        An estimate of the total memory needed (scaled by the number of MPI
        ranks) is compared with half of the currently available memory:

        * if the per-local-chunk share exceeds it, iterate block by block;
        * else if the total exceeds it, iterate one local hyperblock at a time;
        * otherwise load all connections at once.

        :returns: An iterable of ``(pre_locs, post_locs)`` pairs.
        """
        avmem = psutil.virtual_memory().available
        predicted_all_mem = (
            len(pre_nodes) * 8 * 2
            + len(post_nodes) * 8 * 2
            + len(cs) * 6 * 8 * (16 + 2)
        ) * MPI.get_size()
        # A rank may own no "out" chunks at all; don't divide by zero then.
        n_local_chunks = len(cs.get_local_chunks("out"))
        predicted_local_mem = predicted_all_mem / max(n_local_chunks, 1)
        if predicted_local_mem > avmem / 2:
            # Iterate block-by-block
            return self.block_iterator(cs)
        elif predicted_all_mem > avmem / 2:
            # Iterate local hyperblocks
            return self.local_iterator(cs)
        else:
            # Iterate all
            return (cs.load_connections().as_globals().all(),)

    def block_iterator(self, cs):
        """Yield the connections of ``cs`` block by block.

        Walks this rank's local ``"out"`` chunks and, for each, the blocks
        loaded from it. Rank 0 shows nested progress bars on stdout.
        """
        local_chunks = cs.get_local_chunks("out")

        def block_iter():
            chunks = local_chunks
            if MPI.get_rank() == 0:
                chunks = tqdm(chunks, desc="hyperblocks", file=sys.stdout)
            for local in chunks:
                inner_iter = cs.load_connections().as_globals().from_(local)
                if MPI.get_rank() == 0:
                    yield from tqdm(
                        inner_iter,
                        desc="blocks",
                        total=len(cs.get_global_chunks("out", local)),
                        file=sys.stdout,
                        leave=False,
                    )
                else:
                    yield from inner_iter

        return block_iter()

    def local_iterator(self, cs):
        """Yield all connections of each of this rank's ``"out"`` chunks in turn."""
        chunks = cs.get_local_chunks("out")
        if MPI.get_rank() == 0:
            chunks = tqdm(chunks, desc="hyperblocks", file=sys.stdout)
        for local in chunks:
            yield cs.load_connections().as_globals().from_(local).all()

    def get_connectivity_set(self):
        """Return the tagged connectivity set, or the connection model if untagged."""
        if self.tag is not None:
            return self.scaffold.get_connectivity_set(self.tag)
        else:
            return self.connection_model

    def get_conn_spec(self):
        """Build the NEST ``conn_spec``; catch-all constants may override ``rule``."""
        return {
            "rule": self.rule,
            **self.constants,
        }

    def get_syn_spec(self):
        """Build the NEST ``syn_spec`` from :attr:`synapse`.

        Standard attributes that are ``None`` are omitted; the catch-all
        constants are merged last and may override them.
        """
        spec = {}
        for attr, label in (
            ("model", "synapse_model"),
            ("weight", "weight"),
            ("delay", "delay"),
            ("receptor_type", "receptor_type"),
        ):
            value = getattr(self.synapse, attr)
            if value is not None:
                spec[label] = value
        return {**spec, **self.synapse.constants}
bsb_nest/device.py CHANGED
@@ -1,70 +1,72 @@
1
- import warnings
2
-
3
- import nest
4
- from bsb import DeviceModel, Targetting, config, refs, types
5
-
6
-
7
@config.node
class NestRule:
    """Device targetting expressed as a NEST rule, optionally per cell model."""

    # Name of the NEST rule (mandatory).
    rule = config.attr(type=str, required=True)
    # Every other key is collected here unchanged.
    constants = config.catch_all(type=types.any_())
    # Cell models whose populations are targeted; an empty list targets all.
    # NOTE(review): `rule`/`constants` are not read by NestDevice.get_target_nodes
    # in this module — confirm where they are consumed.
    cell_models = config.reflist(refs.sim_cell_model_ref)
12
-
13
-
14
@config.dynamic(attr_name="device", auto_classmap=True, default="external")
class NestDevice(DeviceModel):
    """Base class for NEST devices that are wired to a set of target nodes."""

    # Weight of the connections between the device and its targets.
    weight = config.attr(type=float, required=True)
    # Delay of the connections between the device and its targets.
    delay = config.attr(type=float, required=True)
    # Either a BSB ``Targetting`` strategy or a ``NestRule``.
    targetting = config.attr(
        type=types.or_(Targetting, NestRule), default=dict, call_default=True
    )

    def get_target_nodes(self, adapter, simulation, simdata):
        """Merge every targeted population into a single NEST NodeCollection."""
        if isinstance(self.targetting, Targetting):
            found = self.targetting.get_targets(adapter, simulation, simdata)
            collections = found.values()
        else:
            selected_models = self.targetting.cell_models
            # NOTE(review): each population is indexed with itself
            # (`populations[model][targets]` where targets == populations[model]);
            # confirm that this is intended.
            collections = (
                simdata.populations[model][targets]
                for model, targets in simdata.populations.items()
                if not selected_models or model in selected_models
            )
        return sum(collections, start=nest.NodeCollection())

    def connect_to_nodes(self, device, nodes):
        """Connect ``device`` to ``nodes``, falling back to ``nodes`` -> ``device``.

        The device -> nodes direction is tried first. If NEST rejects it with an
        error mentioning "does not send output" (presumably recording devices),
        the connection is made in the reverse direction instead. Any other error
        propagates. With no target nodes, only a warning is issued.
        """
        if len(nodes) == 0:
            warnings.warn(f"{self.name} has no targets")
            return
        try:
            nest.Connect(
                device, nodes, syn_spec={"weight": self.weight, "delay": self.delay}
            )
        except Exception as err:
            if "does not send output" not in str(err):
                raise
            nest.Connect(
                nodes, device, syn_spec={"weight": self.weight, "delay": self.delay}
            )

    def register_device(self, simdata, device):
        """Store ``device`` in ``simdata.devices`` under this model and return it."""
        simdata.devices[self] = device
        return device
58
-
59
-
60
@config.node
class ExtNestDevice(NestDevice, classmap_entry="external"):
    """Device backed by an arbitrary NEST model created with ``constants``."""

    # Name of the NEST model to instantiate.
    nest_model = config.attr(type=str, required=True)
    # Parameters passed to ``nest.Create``.
    constants = config.dict(type=types.or_(types.number(), str))

    def implement(self, adapter, simulation, simdata):
        """Create and register the NEST device, then connect it to its targets."""
        device = self.register_device(
            simdata, nest.Create(self.nest_model, params=self.constants)
        )
        # get_target_nodes expects (adapter, simulation, simdata); omitting
        # ``simulation`` raises a TypeError.
        nodes = self.get_target_nodes(adapter, simulation, simdata)
        self.connect_to_nodes(device, nodes)
1
+ import warnings
2
+
3
+ import nest
4
+ from bsb import DeviceModel, Targetting, config, refs, types
5
+
6
+
7
@config.node
class NestRule:
    """Device targetting expressed as a NEST rule, optionally per cell model."""

    # Name of the NEST rule (mandatory).
    rule = config.attr(type=str, required=True)
    # Every other key is collected here unchanged.
    constants = config.catch_all(type=types.any_())
    # Cell models whose populations are targeted; an empty list targets all.
    # NOTE(review): `rule`/`constants` are not read by NestDevice.get_target_nodes
    # in this module — confirm where they are consumed.
    cell_models = config.reflist(refs.sim_cell_model_ref)
12
+
13
+
14
@config.dynamic(attr_name="device", auto_classmap=True, default="external")
class NestDevice(DeviceModel):
    """Base class for NEST devices that are wired to a set of target nodes."""

    weight = config.attr(type=float, required=True)
    """Weight of the connection between the device and its target."""
    delay = config.attr(type=float, required=True)
    """Delay of the transmission between the device and its target."""
    targetting = config.attr(
        type=types.or_(Targetting, NestRule), default=dict, call_default=True
    )
    """Targets of the device: either a BSB targetting strategy or a NEST rule."""

    def get_target_nodes(self, adapter, simulation, simdata):
        """Merge every targeted population into a single NEST NodeCollection."""
        if isinstance(self.targetting, Targetting):
            found = self.targetting.get_targets(adapter, simulation, simdata)
            collections = found.values()
        else:
            selected_models = self.targetting.cell_models
            # NOTE(review): each population is indexed with itself
            # (`populations[model][targets]` where targets == populations[model]);
            # confirm that this is intended.
            collections = (
                simdata.populations[model][targets]
                for model, targets in simdata.populations.items()
                if not selected_models or model in selected_models
            )
        return sum(collections, start=nest.NodeCollection())

    def connect_to_nodes(self, device, nodes):
        """Connect ``device`` to ``nodes``, falling back to ``nodes`` -> ``device``.

        The device -> nodes direction is tried first. If NEST rejects it with an
        error mentioning "does not send output" (presumably recording devices),
        the connection is made in the reverse direction instead. Any other error
        propagates. With no target nodes, only a warning is issued.
        """
        if len(nodes) == 0:
            warnings.warn(f"{self.name} has no targets")
            return
        try:
            nest.Connect(
                device, nodes, syn_spec={"weight": self.weight, "delay": self.delay}
            )
        except Exception as err:
            if "does not send output" not in str(err):
                raise
            nest.Connect(
                nodes, device, syn_spec={"weight": self.weight, "delay": self.delay}
            )

    def register_device(self, simdata, device):
        """Store ``device`` in ``simdata.devices`` under this model and return it."""
        simdata.devices[self] = device
        return device
60
+
61
+
62
@config.node
class ExtNestDevice(NestDevice, classmap_entry="external"):
    """Device backed by an arbitrary NEST model created with ``constants``."""

    # Name of the NEST model to instantiate.
    nest_model = config.attr(type=str, required=True)
    # Parameters passed to ``nest.Create``.
    constants = config.dict(type=types.or_(types.number(), str))

    def implement(self, adapter, simulation, simdata):
        """Create and register the NEST device, then connect it to its targets."""
        device = self.register_device(
            simdata, nest.Create(self.nest_model, params=self.constants)
        )
        # get_target_nodes expects (adapter, simulation, simdata); omitting
        # ``simulation`` raises a TypeError.
        nodes = self.get_target_nodes(adapter, simulation, simdata)
        self.connect_to_nodes(device, nodes)
@@ -1,2 +1,4 @@
1
- from .poisson_generator import PoissonGenerator
2
- from .spike_recorder import SpikeRecorder
1
+ from .dc_generator import DCGenerator
2
+ from .multimeter import Multimeter
3
+ from .poisson_generator import PoissonGenerator
4
+ from .spike_recorder import SpikeRecorder
@@ -0,0 +1,23 @@
1
+ import nest
2
+ from bsb import config
3
+
4
+ from ..device import NestDevice
5
+
6
+
7
@config.node
class DCGenerator(NestDevice, classmap_entry="dc_generator"):
    """Constant-current source driving its targets through NEST's ``dc_generator``."""

    amplitude = config.attr(type=float, required=True)
    """Current amplitude of the dc generator."""
    start = config.attr(type=float, required=False, default=0.0)
    """Activation time in ms."""
    stop = config.attr(type=float, required=False, default=None)
    """Deactivation time in ms.
    If not specified, the generator lasts until the end of the simulation."""

    def implement(self, adapter, simulation, simdata):
        """Create and register the generator, then connect it to its targets.

        ``stop`` is only forwarded to NEST when it is set and later than
        ``start``; otherwise it is ignored without warning.
        """
        targets = self.get_target_nodes(adapter, simulation, simdata)
        params = dict(amplitude=self.amplitude, start=self.start)
        stop_is_usable = self.stop is not None and self.stop > self.start
        if stop_is_usable:
            params["stop"] = self.stop
        generator = nest.Create("dc_generator", params=params)
        device = self.register_device(simdata, generator)
        self.connect_to_nodes(device, targets)
@@ -0,0 +1,52 @@
1
+ import nest
2
+ import quantities as pq
3
+ from bsb import ConfigurationError, _util, config, types
4
+ from neo import AnalogSignal
5
+
6
+ from ..device import NestDevice
7
+
8
+
9
@config.node
class Multimeter(NestDevice, classmap_entry="multimeter"):
    """Record analog state variables of the target nodes with a NEST multimeter."""

    weight = config.provide(1)
    properties: list[str] = config.attr(type=types.list(str))
    """List of properties to record in the Nest model."""
    units: list[str] = config.attr(type=types.list(str))
    """List of properties' units."""

    def boot(self):
        """Check that properties and units pair up and that every unit exists.

        :raises ConfigurationError: if a unit name is not defined in
            ``quantities.units``.
        """
        _util.assert_samelen(self.properties, self.units)
        known_units = vars(pq.units)
        for unit in self.units:
            if unit not in known_units:
                raise ConfigurationError(
                    f"Unit {unit} not in the list of known units of quantities"
                )

    def implement(self, adapter, simulation, simdata):
        """Create and register the multimeter, connect it and register a recorder.

        The multimeter samples ``properties`` at the simulation resolution; the
        recorder appends one ``AnalogSignal`` per property to the result segment.
        """
        nodes = self.get_target_nodes(adapter, simulation, simdata)
        device = self.register_device(
            simdata,
            nest.Create(
                "multimeter",
                params={
                    "interval": self.simulation.resolution,
                    "record_from": self.properties,
                },
            ),
        )
        self.connect_to_nodes(device, nodes)

        def recorder(segment):
            # NOTE(review): with several target nodes the samples of all senders
            # share one 1-D array, so the fixed sampling_period would not match
            # per-sender timing — confirm how multi-target recordings are used.
            for prop, unit in zip(self.properties, self.units):
                segment.analogsignals.append(
                    AnalogSignal(
                        device.events[prop],
                        units=vars(pq.units)[unit],
                        sampling_period=self.simulation.resolution * pq.ms,
                        name=self.name,
                        senders=device.events["senders"],
                    )
                )

        simdata.result.create_recorder(recorder)
@@ -1,32 +1,41 @@
1
- import nest
2
- from bsb import config
3
- from neo import SpikeTrain
4
-
5
- from ..device import NestDevice
6
-
7
-
8
@config.node
class PoissonGenerator(NestDevice, classmap_entry="poisson_generator"):
    """Poisson spike source whose output is also recorded into the results."""

    # Rate passed to NEST's ``poisson_generator``.
    rate = config.attr(type=float, required=True)

    def implement(self, adapter, simulation, simdata):
        """Create the generator, record its spikes and connect it to its targets."""
        targets = self.get_target_nodes(adapter, simulation, simdata)
        generator = nest.Create("poisson_generator", params={"rate": self.rate})
        device = self.register_device(simdata, generator)
        # NOTE(review): NEST's poisson_generator presumably draws an independent
        # train per target, so this recorder would not see the exact spikes
        # delivered to the targets — confirm.
        spike_rec = nest.Create("spike_recorder")
        nest.Connect(device, spike_rec)
        self.connect_to_nodes(device, targets)

        def recorder(segment):
            events = spike_rec.events
            train = SpikeTrain(
                events["times"],
                units="ms",
                senders=events["senders"],
                t_stop=simulation.duration,
                device=self.name,
            )
            segment.spiketrains.append(train)

        simdata.result.create_recorder(recorder)
1
+ import nest
2
+ from bsb import config
3
+ from neo import SpikeTrain
4
+
5
+ from ..device import NestDevice
6
+
7
+
8
@config.node
class PoissonGenerator(NestDevice, classmap_entry="poisson_generator"):
    """Poisson spike source whose output is also recorded into the results."""

    rate = config.attr(type=float, required=True)
    """Frequency of the poisson generator."""
    start = config.attr(type=float, required=False, default=0.0)
    """Activation time in ms."""
    stop = config.attr(type=float, required=False, default=None)
    """Deactivation time in ms.
    If not specified, the generator lasts until the end of the simulation."""

    def implement(self, adapter, simulation, simdata):
        """Create the generator, record its spikes and connect it to its targets.

        ``stop`` is only forwarded to NEST when it is set and later than
        ``start``; otherwise it is ignored without warning.
        """
        targets = self.get_target_nodes(adapter, simulation, simdata)
        params = dict(rate=self.rate, start=self.start)
        stop_is_usable = self.stop is not None and self.stop > self.start
        if stop_is_usable:
            params["stop"] = self.stop
        generator = nest.Create("poisson_generator", params=params)
        device = self.register_device(simdata, generator)
        # NOTE(review): NEST's poisson_generator presumably draws an independent
        # train per target, so this recorder would not see the exact spikes
        # delivered to the targets — confirm.
        spike_rec = nest.Create("spike_recorder")
        nest.Connect(device, spike_rec)
        self.connect_to_nodes(device, targets)

        def recorder(segment):
            events = spike_rec.events
            train = SpikeTrain(
                events["times"],
                units="ms",
                senders=events["senders"],
                t_stop=simulation.duration,
                device=self.name,
            )
            segment.spiketrains.append(train)

        simdata.result.create_recorder(recorder)