bsb-arbor 5.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bsb-arbor might be problematic. Click here for more details.

bsb_arbor/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ """
2
+ Arbor simulation adapter for the BSB framework.
3
+ """
4
+
5
+ from bsb import SimulationBackendPlugin
6
+
7
+ from .adapter import ArborAdapter
8
+ from .devices import PoissonGenerator, Probe, SpikeRecorder
9
+ from .simulation import ArborSimulation
10
+
11
+ __plugin__ = SimulationBackendPlugin(Simulation=ArborSimulation, Adapter=ArborAdapter)
12
+
13
+
14
+ __all__ = [
15
+ "PoissonGenerator",
16
+ "Probe",
17
+ "SpikeRecorder",
18
+ "ArborAdapter",
19
+ "ArborSimulation",
20
+ ]
bsb_arbor/adapter.py ADDED
@@ -0,0 +1,456 @@
1
+ import itertools
2
+ import itertools as it
3
+ import time
4
+ import typing
5
+
6
+ import arbor
7
+ import numpy as np
8
+ from arbor import units as U
9
+ from bsb import (
10
+ AdapterError,
11
+ Chunk,
12
+ SimulationData,
13
+ SimulatorAdapter,
14
+ UnknownGIDError,
15
+ report,
16
+ warn,
17
+ )
18
+
19
+ if typing.TYPE_CHECKING:
20
+ from .simulation import ArborSimulation
21
+
22
+
23
+ class ArborSimulationData(SimulationData):
24
+ """
25
+ Container class for simulation data.
26
+ """
27
+
28
+ def __init__(self, simulation):
29
+ """
30
+ Container class for simulation data.
31
+ """
32
+ super().__init__(simulation)
33
+ self.arbor_sim: arbor.simulation = None
34
+
35
+
36
+ class ReceiverCollection(list):
37
+ """
38
+ Receiver collections store the incoming connections and deduplicate them into multiple
39
+ targets.
40
+ """
41
+
42
+ def __init__(self):
43
+ super().__init__()
44
+ self._endpoint_counters = {}
45
+
46
+ def append(self, rcv):
47
+ endpoint = str(rcv.loc_on)
48
+ id = self._endpoint_counters.get(endpoint, 0)
49
+ self._endpoint_counters[endpoint] = id + 1
50
+ rcv.index = id
51
+ super().append(rcv)
52
+
53
+
54
+ class SingleReceiverCollection(list):
55
+ """
56
+ The single receiver collection redirects all incoming connections to the same
57
+ receiver.
58
+ """
59
+
60
+ def append(self, rcv):
61
+ rcv.index = 0
62
+ super().append(rcv)
63
+
64
+
65
+ class Population:
66
+ def __init__(self, simdata, cell_model, offset):
67
+ self._model = cell_model
68
+ self._simdata = simdata
69
+ ps = cell_model.get_placement_set(simdata.chunks)
70
+ self._ranges = self._get_ranges(simdata.chunks, ps, offset)
71
+ self._offset = offset
72
+
73
+ @property
74
+ def model(self):
75
+ return self._model
76
+
77
+ @property
78
+ def offset(self):
79
+ return self._offset
80
+
81
+ def __len__(self):
82
+ return sum(stop - start for start, stop in self._ranges)
83
+
84
+ def __contains__(self, i):
85
+ return any(start <= i < stop for start, stop in self._ranges)
86
+
87
+ def copy(self):
88
+ return Population(self._simdata, self._model, self._offset)
89
+
90
+ def __getitem__(self, item):
91
+ # Boolean masking, kind of
92
+ if getattr(item, "dtype", None) == bool or _all_bools(item): # noqa: E721
93
+ if len(item) != len(self):
94
+ raise ValueError(
95
+ f"Dimension mismatch between population ({len(self)}) "
96
+ f"and mask ({len(item)})"
97
+ )
98
+ return self._subpop_np(np.array(self)[item])
99
+ elif getattr(item, "dtype", None) == int or _all_ints(item): # noqa: E721
100
+ if getattr(item, "ndim", None) == 0:
101
+ return self._subpop_one(item)
102
+ return self._subpop_np(np.array(self)[item])
103
+ elif isinstance(item, slice):
104
+ return self._subpop_np(np.array(self)[item])
105
+ else:
106
+ return self._subpop_one(item)
107
+
108
+ def _get_ranges(self, chunks, ps, offset):
109
+ stats = ps.get_chunk_stats()
110
+ ranges = []
111
+ for chunk, len_ in sorted(
112
+ stats.items(), key=lambda k: Chunk.from_id(int(k[0]), None).id
113
+ ):
114
+ if chunk in chunks:
115
+ ranges.append((offset, offset + len_))
116
+ offset += len_
117
+ return ranges
118
+
119
+ def _subpop_np(self, arr):
120
+ pop = self.copy()
121
+ if not len(pop):
122
+ return pop
123
+ ranges = []
124
+ prev = None
125
+ start, stop = self._ranges[0]
126
+ for i in arr:
127
+ if prev is None:
128
+ start += i
129
+ stop = start + 1
130
+ elif i == prev + 1:
131
+ stop += 1
132
+ else:
133
+ ranges.append((start, stop))
134
+ start = i
135
+ stop = i + 1
136
+ prev = i
137
+ pop._ranges = ranges
138
+ return pop
139
+
140
+ def _subpop_one(self, item):
141
+ if item >= len(self):
142
+ raise IndexError(f"Index {item} out of bounds for size {len(self)}")
143
+ pop = self.copy()
144
+ ptr = 0
145
+ for start, stop in self._ranges:
146
+ if item < (ptr + stop - start):
147
+ pop._ranges = [(start + ptr - item, start + ptr - item + 1)]
148
+ return pop
149
+ else:
150
+ ptr += stop - start
151
+
152
+ def __iter__(self):
153
+ yield from itertools.chain.from_iterable(range(r[0], r[1]) for r in self._ranges)
154
+
155
+
156
+ class GIDManager:
157
+ def __init__(self, simulation, simdata):
158
+ self._gid_offsets = {}
159
+ self._model_order = self.sort_models(simulation.cell_models.values())
160
+ ctr = 0
161
+ for model in self._model_order:
162
+ self._gid_offsets[model] = ctr
163
+ ctr += len(model.get_placement_set())
164
+ self._populations = [
165
+ Population(simdata, model, offset)
166
+ for model, offset in self._gid_offsets.items()
167
+ ]
168
+
169
+ def sort_models(self, models):
170
+ return sorted(
171
+ models,
172
+ key=lambda model: len(model.get_placement_set()),
173
+ )
174
+
175
+ def lookup_offset(self, gid):
176
+ model = self.lookup_model(gid)
177
+ return self._gid_offsets[model]
178
+
179
+ def lookup_kind(self, gid):
180
+ return self._lookup(gid).model.get_cell_kind(gid)
181
+
182
+ def lookup_model(self, gid):
183
+ return self._lookup(gid).model
184
+
185
+ def _lookup(self, gid):
186
+ try:
187
+ return next(c for c in self._populations if gid in c)
188
+ except StopIteration:
189
+ raise UnknownGIDError(f"Can't find gid {gid}.") from None
190
+
191
+ def all(self):
192
+ yield from itertools.chain.from_iterable(self._populations)
193
+
194
+ def get_populations(self):
195
+ return {pop.model: pop for pop in self._populations}
196
+
197
+
198
+ class ArborRecipe(arbor.recipe):
199
+ def __init__(self, simulation, simdata):
200
+ super().__init__()
201
+ self._simulation = simulation
202
+ self._simdata = simdata
203
+ self._global_properties = arbor.neuron_cable_properties()
204
+ self._global_properties.set_property(
205
+ Vm=-65 * U.mV,
206
+ tempK=300 * U.Kelvin,
207
+ rL=35.4 * U.Ohm * U.cm,
208
+ cm=0.01 * U.F / U.m2,
209
+ )
210
+ self._global_properties.set_ion(
211
+ ion="na", int_con=10 * U.mM, ext_con=140 * U.mM, rev_pot=50 * U.mM
212
+ )
213
+ self._global_properties.set_ion(
214
+ ion="k", int_con=54.4 * U.mM, ext_con=2.5 * U.mM, rev_pot=-77 * U.mM
215
+ )
216
+ self._global_properties.set_ion(
217
+ ion="ca", int_con=0.0001 * U.mM, ext_con=2 * U.mM, rev_pot=132.5 * U.mM
218
+ )
219
+ self._global_properties.set_ion(
220
+ ion="h",
221
+ valence=1,
222
+ int_con=1.0 * U.mM,
223
+ ext_con=1.0 * U.mM,
224
+ rev_pot=-34 * U.mM,
225
+ )
226
+ self._global_properties.catalogue = self._get_catalogue()
227
+
228
+ def _get_catalogue(self):
229
+ catalogue = arbor.default_catalogue()
230
+ prefixes = set()
231
+ for model in self._simulation.cell_models.values():
232
+ prefix, model_catalogue = model.get_prefixed_catalogue()
233
+ if model_catalogue is not None and prefix not in prefixes:
234
+ prefixes.add(prefix)
235
+ catalogue.extend(model_catalogue, "")
236
+
237
+ return catalogue
238
+
239
+ def global_properties(self, kind):
240
+ return self._global_properties
241
+
242
+ def num_cells(self):
243
+ return sum(
244
+ len(model.get_placement_set())
245
+ for model in self._simulation.cell_models.values()
246
+ )
247
+
248
+ def cell_kind(self, gid):
249
+ return self._simdata.gid_manager.lookup_kind(gid)
250
+
251
+ def cell_description(self, gid):
252
+ model = self._simdata.gid_manager.lookup_model(gid)
253
+ return model.get_description(gid)
254
+
255
+ def connections_on(self, gid):
256
+ return [
257
+ arbor.connection(rcv.from_(), rcv.on(), rcv.weight, rcv.delay * U.ms)
258
+ for rcv in self._simdata.connections_on[gid]
259
+ ]
260
+
261
+ def gap_junctions_on(self, gid):
262
+ return [
263
+ c.model.gap_junction(c) for c in self._simdata.gap_junctions_on.get(gid, [])
264
+ ]
265
+
266
+ def probes(self, gid):
267
+ devices = self._simdata.devices_on[gid]
268
+ _ntag = 0
269
+ probes = []
270
+ for device in devices:
271
+ device_probes = device.implement_probes(self._simdata, gid)
272
+ for tag in range(_ntag, _ntag + len(device_probes)):
273
+ device.register_probe_id(gid, tag)
274
+ probes.extend(device_probes)
275
+ return probes
276
+
277
+ def event_generators(self, gid):
278
+ devices = self._simdata.devices_on[gid]
279
+ generators = []
280
+ for device in devices:
281
+ device_generators = device.implement_generators(self._simdata, gid)
282
+ generators.extend(device_generators)
283
+ return generators
284
+
285
+ def _name_of(self, gid):
286
+ return self._simdata.gid_manager.lookup_model(gid).cell_type.name
287
+
288
+
289
+ class ArborAdapter(SimulatorAdapter):
290
+ def __init__(self, comm=None):
291
+ super().__init__(comm)
292
+ self.simdata: dict[ArborSimulation, ArborSimulationData] = {}
293
+
294
+ def prepare(self, simulation: "ArborSimulation") -> ArborSimulationData:
295
+ """
296
+ Prepares the arbor simulation engine with the given simulation.
297
+ """
298
+ simdata = self._create_simdata(simulation)
299
+ try:
300
+ context = arbor.context(arbor.proc_allocation(threads=simulation.threads))
301
+ if self.comm.get_size() > 1:
302
+ if not arbor.config()["mpi4py"]:
303
+ warn(
304
+ f"Arbor does not seem to be built with MPI support, running"
305
+ f"duplicate simulations on {self.comm.get_size()} nodes."
306
+ )
307
+ else:
308
+ context = arbor.context(
309
+ arbor.proc_allocation(threads=simulation.threads),
310
+ mpi=self.comm.get_communicator(),
311
+ )
312
+ if simulation.profiling:
313
+ if arbor.config()["profiling"]:
314
+ report("enabling profiler", level=2)
315
+ arbor.profiler_initialize(context)
316
+ else:
317
+ raise RuntimeError(
318
+ "Arbor must be built with profiling support to use the "
319
+ "`profiling` flag."
320
+ )
321
+ simdata.gid_manager = self.get_gid_manager(simulation, simdata)
322
+ simdata.populations = simdata.gid_manager.get_populations()
323
+ report("preparing simulation", level=1)
324
+ report("MPI processes:", context.ranks, level=2)
325
+ report("Threads per process:", context.threads, level=2)
326
+ recipe = self.get_recipe(simulation, simdata)
327
+ # Gap junctions are required for domain decomposition
328
+ self.domain = arbor.partition_load_balance(recipe, context)
329
+ self.gids = set(it.chain.from_iterable(g.gids for g in self.domain.groups))
330
+ simdata.arbor_sim = arbor.simulation(recipe, context, self.domain)
331
+ self.prepare_samples(simulation, simdata)
332
+ report("prepared simulation", level=1)
333
+ return simdata
334
+ except Exception:
335
+ del self.simdata[simulation]
336
+ raise
337
+
338
+ def get_gid_manager(self, simulation, simdata):
339
+ return GIDManager(simulation, simdata)
340
+
341
+ def prepare_samples(self, simulation, simdata):
342
+ for device in simulation.devices.values():
343
+ device.prepare_samples(simdata, comm=self.comm)
344
+
345
+ def run(self, *simulations):
346
+ if len(simulations) != 1:
347
+ raise RuntimeError(
348
+ "Can not run multiple simultaneous simulations. Composition not "
349
+ "implemented."
350
+ )
351
+ simulation = simulations[0]
352
+ try:
353
+ simdata = self.simdata[simulation]
354
+ arbor_sim = simdata.arbor_sim
355
+ except KeyError:
356
+ raise AdapterError(
357
+ f"Can't run unprepared simulation '{simulation.name}'"
358
+ ) from None
359
+ try:
360
+ if not self.comm.get_rank():
361
+ arbor_sim.record(arbor.spike_recording.all)
362
+
363
+ start = time.time()
364
+ report("running simulation", level=1)
365
+ arbor_sim.run(simulation.duration * U.ms, dt=simulation.resolution * U.ms)
366
+ report(f"completed simulation. {time.time() - start:.2f}s", level=1)
367
+ if simulation.profiling and arbor.config()["profiling"]:
368
+ report("printing profiler summary", level=2)
369
+ report(arbor.profiler_summary(), level=1)
370
+ return [simdata.result]
371
+ finally:
372
+ del self.simdata[simulation]
373
+
374
+ def get_recipe(self, simulation, simdata=None):
375
+ if simdata is None:
376
+ simdata = self._create_simdata(simulation)
377
+ self._cache_gap_junctions(simulation, simdata)
378
+ self._cache_connections(simulation, simdata)
379
+ self._cache_devices(simulation, simdata)
380
+ return ArborRecipe(simulation, simdata)
381
+
382
+ def _create_simdata(self, simulation):
383
+ self.simdata[simulation] = simdata = ArborSimulationData(simulation)
384
+ self._assign_chunks(simulation, simdata)
385
+ return simdata
386
+
387
+ def _cache_gap_junctions(self, simulation, simdata):
388
+ simdata.gap_junctions_on = {}
389
+ for conn_model in simulation.connection_models.values():
390
+ if conn_model.gap:
391
+ conn_set = conn_model.get_connectivity_set()
392
+ conns = conn_set.load_connections().to(simdata.chunks).as_globals()
393
+ conn_model.create_gap_junctions_on(simdata.gap_junctions_on, conns)
394
+
395
+ def _cache_connections(self, simulation, simdata):
396
+ simdata.connections_on = {
397
+ gid: simdata.gid_manager.lookup_model(gid).make_receiver_collection()
398
+ for gid in simdata.gid_manager.all()
399
+ }
400
+ simdata.connections_from = {gid: [] for gid in simdata.gid_manager.all()}
401
+ for conn_model in simulation.connection_models.values():
402
+ if conn_model.gap:
403
+ continue
404
+ conn_set = conn_model.get_connectivity_set()
405
+ pop_pre, pop_post = None, None
406
+ for model in simulation.cell_models.values():
407
+ if model.cell_type is conn_set.pre_type:
408
+ pop_pre = simdata.populations[model]
409
+ if model.cell_type is conn_set.post_type:
410
+ pop_post = simdata.populations[model]
411
+ # Get the arriving connection iterator
412
+ conns_on = conn_set.load_connections().to(simdata.chunks).as_globals()
413
+ # Create the arriving connections
414
+ conn_model.create_connections_on(
415
+ simdata.connections_on, conns_on, pop_pre, pop_post
416
+ )
417
+ # Get the outgoing connection iterator
418
+ conns_from = conn_set.load_connections().from_(simdata.chunks).as_globals()
419
+ # Create the outgoing connections
420
+ conn_model.create_connections_from(
421
+ simdata.connections_from, conns_from, pop_pre, pop_post
422
+ )
423
+
424
+ def _cache_devices(self, simulation, simdata):
425
+ simdata.devices_on = {gid: [] for gid in simdata.gid_manager.all()}
426
+ for device in simulation.devices.values():
427
+ targets = device.targetting.get_targets(self, simulation, simdata)
428
+ for target in itertools.chain.from_iterable(targets.values()):
429
+ simdata.devices_on[target].append(device)
430
+
431
+ def _assign_chunks(self, simulation, simdata):
432
+ chunk_stats = simulation.scaffold.storage.get_chunk_stats()
433
+ size = self.comm.get_size()
434
+ all_chunks = [Chunk.from_id(int(chunk), None) for chunk in chunk_stats]
435
+ simdata.node_chunk_alloc = [all_chunks[rank::size] for rank in range(0, size)]
436
+ simdata.chunk_node_map = {}
437
+ for node, chunks in enumerate(simdata.node_chunk_alloc):
438
+ for chunk in chunks:
439
+ simdata.chunk_node_map[chunk] = node
440
+ simdata.chunks = simdata.node_chunk_alloc[self.comm.get_rank()]
441
+
442
+
443
+ def _all_bools(arr):
444
+ try:
445
+ return all(isinstance(b, bool) for b in arr)
446
+ except TypeError:
447
+ # Not iterable
448
+ return False
449
+
450
+
451
+ def _all_ints(arr):
452
+ try:
453
+ return all(isinstance(b, int) for b in arr)
454
+ except TypeError:
455
+ # Not iterable
456
+ return False
bsb_arbor/cell.py ADDED
@@ -0,0 +1,86 @@
1
+ import abc
2
+
3
+ import arbor
4
+ from bsb import CellModel, ConfigurationError, PlacementSet, config, types
5
+
6
+ from .adapter import SingleReceiverCollection
7
+
8
+
9
+ @config.dynamic(
10
+ attr_name="model_strategy",
11
+ auto_classmap=True,
12
+ required=True,
13
+ classmap_entry=None,
14
+ )
15
+ class ArborCell(CellModel):
16
+ model_strategy: config.ConfigurationAttribute
17
+ """
18
+ Optional importable reference to a different modelling strategy than the default
19
+ Arborize strategy.
20
+ """
21
+ gap = config.attr(type=bool, default=False)
22
+ """Is this synapse a gap junction?"""
23
+ model = config.attr(type=types.class_(), required=True)
24
+ """Importable reference to the arborize model describing the cell type."""
25
+
26
+ @abc.abstractmethod
27
+ def cache_population_data(self, simdata, ps: PlacementSet):
28
+ pass
29
+
30
+ @abc.abstractmethod
31
+ def discard_population_data(self):
32
+ pass
33
+
34
+ @abc.abstractmethod
35
+ def get_prefixed_catalogue(self):
36
+ pass
37
+
38
+ @abc.abstractmethod
39
+ def get_cell_kind(self, gid):
40
+ pass
41
+
42
+ @abc.abstractmethod
43
+ def make_receiver_collection(self):
44
+ pass
45
+
46
+ def get_description(self, gid):
47
+ morphology, labels, decor = self.model.cable_cell_template()
48
+ labels = self._add_labels(gid, labels, morphology)
49
+ decor = self._add_decor(gid, decor)
50
+ cc = arbor.cable_cell(morphology, labels, decor)
51
+ return cc
52
+
53
+
54
+ @config.node
55
+ class LIFCell(ArborCell, classmap_entry="lif"):
56
+ model = config.unset()
57
+ """Importable reference to the arborize model describing the cell type."""
58
+ constants = config.dict(type=types.any_())
59
+ """Dictionary linking the parameters' name to its value."""
60
+
61
+ def cache_population_data(self, simdata, ps: PlacementSet):
62
+ pass
63
+
64
+ def discard_population_data(self):
65
+ pass
66
+
67
+ def get_prefixed_catalogue(self):
68
+ return None, None
69
+
70
+ def get_cell_kind(self, gid):
71
+ return arbor.cell_kind.lif
72
+
73
+ def get_description(self, gid):
74
+ cell = arbor.lif_cell("-1_-1", "-1_-1_0")
75
+ try:
76
+ for k, v in self.constants.items():
77
+ setattr(cell, k, v * getattr(cell, k).units)
78
+ except AttributeError:
79
+ node_name = type(self).constants.get_node_name(self)
80
+ raise ConfigurationError(
81
+ f"'{k}' is not a valid LIF parameter in '{node_name}'."
82
+ ) from None
83
+ return cell
84
+
85
+ def make_receiver_collection(self):
86
+ return SingleReceiverCollection()
@@ -0,0 +1,68 @@
1
+ import arbor
2
+ import tqdm
3
+ from bsb import ConnectionModel, config
4
+
5
+
6
+ class Receiver:
7
+ def __init__(self, conn_model, from_gid, loc_from, loc_on, index=-1):
8
+ self.conn_model = conn_model
9
+ self.from_gid = from_gid
10
+ self.loc_from = loc_from
11
+ self.loc_on = loc_on
12
+ self.synapse = arbor.synapse("expsyn")
13
+ self.index = index
14
+
15
+ def from_(self):
16
+ b, p = self.loc_from
17
+ return arbor.cell_global_label(self.from_gid, f"{b}_{p}")
18
+
19
+ def on(self):
20
+ # self.index is set on us by the ReceiverCollection when we are appended.
21
+ b, p = self.loc_on
22
+ return arbor.cell_local_label(f"{b}_{p}_{self.index}")
23
+
24
+ @property
25
+ def weight(self):
26
+ return self.conn_model.weight
27
+
28
+ @property
29
+ def delay(self):
30
+ return self.conn_model.delay
31
+
32
+
33
+ class Connection:
34
+ def __init__(self, pre_loc, post_loc):
35
+ self.from_id = pre_loc[0]
36
+ self.to_id = post_loc[0]
37
+ self.pre_loc = pre_loc[1:]
38
+ self.post_loc = post_loc[1:]
39
+
40
+
41
+ @config.node
42
+ class ArborConnection(ConnectionModel):
43
+ gap = config.attr(type=bool, default=False)
44
+ """Is this synapce a gap junction?"""
45
+ weight = config.attr(type=float, required=True)
46
+ """Weight of the connection between the presynaptic and the postsynaptic cells."""
47
+ delay = config.attr(type=float, required=True)
48
+ """Delay of the transmission between the presynaptic and the postsynaptic cells."""
49
+
50
+ def create_gap_junctions_on(self, gj_on_gid, conns):
51
+ for pre_loc, post_loc in conns:
52
+ conn = Connection(pre_loc, post_loc)
53
+ gj_on_gid.setdefault(conn.from_id, []).append(conn)
54
+
55
+ def create_connections_on(self, conns_on_gid, conns, pop_pre, pop_post):
56
+ for pre_loc, post_loc in tqdm.tqdm(conns, total=len(conns), desc=self.name):
57
+ conns_on_gid[post_loc[0] + pop_post.offset].append(
58
+ Receiver(self, pre_loc[0] + pop_pre.offset, pre_loc[1:], post_loc[1:])
59
+ )
60
+
61
+ def create_connections_from(self, conns_from_gid, conns, pop_pre, pop_post):
62
+ for pre_loc, _post_loc in conns:
63
+ conns_from_gid[int(pre_loc[0] + pop_pre.offset)].append(pre_loc[1:])
64
+
65
+ def gap_junction(self, conn):
66
+ l_ = arbor.cell_local_label(f"gap_{conn.to_compartment.id}")
67
+ g = arbor.cell_global_label(int(conn.from_id), f"gap_{conn.from_compartment.id}")
68
+ return arbor.gap_junction_connection(g, l_, self.weight)
bsb_arbor/device.py ADDED
@@ -0,0 +1,50 @@
1
+ import abc
2
+
3
+ import arbor
4
+ from bsb import DeviceModel, Targetting, config, types
5
+
6
+
7
+ @config.dynamic(attr_name="device", auto_classmap=True, classmap_entry=None)
8
+ class ArborDevice(DeviceModel):
9
+ device: config.ConfigurationAttribute
10
+ """Optional importable reference to the device strategy."""
11
+ targetting = config.attr(type=Targetting, required=True)
12
+ """Targets of the device, which should be either a population or a nest rule."""
13
+ resolution = config.attr(type=float)
14
+ """Time resolution of the device."""
15
+ sampling_policy = config.attr(type=types.in_(["exact"]))
16
+ """Policy used to sample simulation data from the device."""
17
+
18
+ def __init__(self, **kwargs):
19
+ self._probe_ids = []
20
+
21
+ def __boot__(self):
22
+ self.resolution = self.resolution or self.simulation.resolution
23
+
24
+ def register_probe_id(self, gid, tag):
25
+ self._probe_ids.append((gid, tag))
26
+
27
+ def prepare_samples(self, simdata, comm):
28
+ self._handles = [
29
+ self.sample(simdata.arbor_sim, probe_id) for probe_id in self._probe_ids
30
+ ]
31
+
32
+ def sample(self, sim, probe_id):
33
+ schedule = arbor.regular_schedule(self.resolution)
34
+ sampling_policy = getattr(arbor.sampling_policy, self.sampling_policy)
35
+ return sim.sample(probe_id, schedule, sampling_policy)
36
+
37
+ def get_samples(self, sim):
38
+ return [sim.samples(handle) for handle in self._handles]
39
+
40
+ def get_meta(self):
41
+ attrs = ("name", "sampling_policy", "resolution")
42
+ return dict(zip(attrs, (getattr(self, attr) for attr in attrs), strict=False))
43
+
44
+ @abc.abstractmethod
45
+ def implement_probes(self, simdata, target):
46
+ pass
47
+
48
+ @abc.abstractmethod
49
+ def implement_generators(self, simdata, target):
50
+ pass
@@ -0,0 +1,9 @@
1
+ from .poisson_generator import PoissonGenerator
2
+ from .probe import Probe
3
+ from .spike_recorder import SpikeRecorder
4
+
5
+ __all__ = [
6
+ "PoissonGenerator",
7
+ "Probe",
8
+ "SpikeRecorder",
9
+ ]