bsb-arbor 6.0.0a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bsb-arbor might be problematic. Click here for more details.
- bsb_arbor/__init__.py +20 -0
- bsb_arbor/adapter.py +456 -0
- bsb_arbor/cell.py +86 -0
- bsb_arbor/connection.py +68 -0
- bsb_arbor/device.py +50 -0
- bsb_arbor/devices/__init__.py +9 -0
- bsb_arbor/devices/poisson_generator.py +30 -0
- bsb_arbor/devices/probe.py +54 -0
- bsb_arbor/devices/spike_recorder.py +42 -0
- bsb_arbor/simulation.py +38 -0
- bsb_arbor-6.0.0a5.dist-info/METADATA +36 -0
- bsb_arbor-6.0.0a5.dist-info/RECORD +15 -0
- bsb_arbor-6.0.0a5.dist-info/WHEEL +4 -0
- bsb_arbor-6.0.0a5.dist-info/entry_points.txt +3 -0
- bsb_arbor-6.0.0a5.dist-info/licenses/LICENSE +619 -0
bsb_arbor/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Arbor simulation adapter for the BSB framework.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from bsb import SimulationBackendPlugin
|
|
6
|
+
|
|
7
|
+
from .adapter import ArborAdapter
|
|
8
|
+
from .devices import PoissonGenerator, Probe, SpikeRecorder
|
|
9
|
+
from .simulation import ArborSimulation
|
|
10
|
+
|
|
11
|
+
# Plugin object exposed to BSB that pairs this package's simulation configuration
# class with the adapter that executes it (presumably discovered through the
# package's entry point — see entry_points.txt).
__plugin__ = SimulationBackendPlugin(Simulation=ArborSimulation, Adapter=ArborAdapter)


# Public names re-exported by the package; kept as a literal list so static
# tooling can read it.
__all__ = [
    "PoissonGenerator",
    "Probe",
    "SpikeRecorder",
    "ArborAdapter",
    "ArborSimulation",
]
|
bsb_arbor/adapter.py
ADDED
|
@@ -0,0 +1,456 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
import itertools as it
|
|
3
|
+
import time
|
|
4
|
+
import typing
|
|
5
|
+
|
|
6
|
+
import arbor
|
|
7
|
+
import numpy as np
|
|
8
|
+
from arbor import units as U
|
|
9
|
+
from bsb import (
|
|
10
|
+
AdapterError,
|
|
11
|
+
Chunk,
|
|
12
|
+
SimulationData,
|
|
13
|
+
SimulatorAdapter,
|
|
14
|
+
UnknownGIDError,
|
|
15
|
+
report,
|
|
16
|
+
warn,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
if typing.TYPE_CHECKING:
|
|
20
|
+
from .simulation import ArborSimulation
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ArborSimulationData(SimulationData):
    """
    Per-simulation state kept by the Arbor adapter.

    On top of the generic :class:`SimulationData` fields it carries the
    ``arbor.simulation`` object, which stays ``None`` until the adapter has
    built it.
    """

    def __init__(self, simulation):
        super().__init__(simulation)
        # Assigned by the adapter once the Arbor simulation has been constructed.
        self.arbor_sim: arbor.simulation = None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ReceiverCollection(list):
    """
    List of the incoming connections (receivers) of one cell.

    Receivers that land on the same location (``loc_on``) are numbered 0, 1, 2, ...
    in order of arrival; the number is written to ``rcv.index`` so every
    connection on a shared location gets its own target index.
    """

    def __init__(self):
        super().__init__()
        # Maps str(loc_on) -> the index the next receiver on that location gets.
        self._endpoint_counters = {}

    def append(self, rcv):
        endpoint_key = str(rcv.loc_on)
        slot = self._endpoint_counters.get(endpoint_key, 0)
        self._endpoint_counters[endpoint_key] = slot + 1
        rcv.index = slot
        super().append(rcv)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class SingleReceiverCollection(list):
    """
    Receiver list in which every incoming connection targets the same receiver:
    each appended receiver gets ``index`` 0 regardless of its location.
    """

    def append(self, rcv):
        # All connections share target index 0.
        rcv.index = 0
        return super().append(rcv)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class Population:
    """
    The GIDs of one cell model that live in this rank's chunks.

    GIDs are stored as a list of half-open ``(start, stop)`` ranges. The population
    supports ``len``, ``in``, iteration and indexing with an integer, a slice, an
    integer array/list or a boolean mask; indexing returns a new
    :class:`Population` restricted to the selected GIDs.
    """

    def __init__(self, simdata, cell_model, offset):
        self._model = cell_model
        self._simdata = simdata
        ps = cell_model.get_placement_set(simdata.chunks)
        self._ranges = self._get_ranges(simdata.chunks, ps, offset)
        self._offset = offset

    @property
    def model(self):
        """The cell model this population belongs to."""
        return self._model

    @property
    def offset(self):
        """GID offset of the cell model: the GID its first placed cell maps to."""
        return self._offset

    def __len__(self):
        return sum(stop - start for start, stop in self._ranges)

    def __contains__(self, i):
        return any(start <= i < stop for start, stop in self._ranges)

    def copy(self):
        """
        Return a copy selecting the same GIDs as this population.

        The ranges are copied directly instead of being recomputed from the
        placement set, so a copy of a sub-population stays a sub-population.
        """
        pop = type(self).__new__(type(self))
        pop._model = self._model
        pop._simdata = self._simdata
        pop._offset = self._offset
        pop._ranges = list(self._ranges)
        return pop

    def __getitem__(self, item):
        # Boolean masking, kind of
        if getattr(item, "dtype", None) == bool or _all_bools(item):  # noqa: E721
            if len(item) != len(self):
                raise ValueError(
                    f"Dimension mismatch between population ({len(self)}) "
                    f"and mask ({len(item)})"
                )
            return self._subpop_np(self._gid_array()[item])
        elif getattr(item, "dtype", None) == int or _all_ints(item):  # noqa: E721
            if getattr(item, "ndim", None) == 0:
                return self._subpop_one(item)
            return self._subpop_np(self._gid_array()[item])
        elif isinstance(item, slice):
            return self._subpop_np(self._gid_array()[item])
        else:
            return self._subpop_one(item)

    def _gid_array(self):
        # Materialize the GIDs explicitly through iteration, with an integer dtype
        # so that empty populations do not turn into float arrays.
        return np.fromiter(iter(self), dtype=int, count=len(self))

    def _get_ranges(self, chunks, ps, offset):
        # Walk the chunks in id order, accumulating GIDs, and keep only the
        # ranges of the chunks that belong to this rank.
        stats = ps.get_chunk_stats()
        ranges = []
        for chunk, len_ in sorted(
            stats.items(), key=lambda k: Chunk.from_id(int(k[0]), None).id
        ):
            if chunk in chunks:
                ranges.append((offset, offset + len_))
            offset += len_
        return ranges

    def _subpop_np(self, arr):
        """
        Build a sub-population from an iterable of absolute GIDs.

        Consecutive increasing GIDs are merged into a single ``(start, stop)``
        range; any gap starts a new range.
        """
        pop = self.copy()
        ranges = []
        for gid in arr:
            gid = int(gid)
            if ranges and ranges[-1][1] == gid:
                ranges[-1] = (ranges[-1][0], gid + 1)
            else:
                ranges.append((gid, gid + 1))
        pop._ranges = ranges
        return pop

    def _subpop_one(self, item):
        """
        Return a single-cell sub-population for position ``item``.

        Negative positions count from the end, like Python sequences.

        :raises IndexError: if ``item`` is outside the population.
        """
        size = len(self)
        index = item + size if item < 0 else item
        if not 0 <= index < size:
            raise IndexError(f"Index {item} out of bounds for size {size}")
        pop = self.copy()
        ptr = 0
        for start, stop in self._ranges:
            width = stop - start
            if index < ptr + width:
                # Offset of the position inside this range, added to its start.
                gid = start + (index - ptr)
                pop._ranges = [(gid, gid + 1)]
                return pop
            ptr += width

    def __iter__(self):
        yield from itertools.chain.from_iterable(range(r[0], r[1]) for r in self._ranges)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class GIDManager:
    """
    Gives every cell model a contiguous block of GIDs and resolves GID lookups.

    Models are ordered by :meth:`sort_models`; a model's offset is the total number
    of placed cells of all the models ordered before it.
    """

    def __init__(self, simulation, simdata):
        self._gid_offsets = {}
        self._model_order = self.sort_models(simulation.cell_models.values())
        next_offset = 0
        for cell_model in self._model_order:
            self._gid_offsets[cell_model] = next_offset
            next_offset += len(cell_model.get_placement_set())
        self._populations = [
            Population(simdata, cell_model, first_gid)
            for cell_model, first_gid in self._gid_offsets.items()
        ]

    def sort_models(self, models):
        """Return ``models`` sorted by ascending number of placed cells."""

        def cell_count(cell_model):
            return len(cell_model.get_placement_set())

        return sorted(models, key=cell_count)

    def lookup_offset(self, gid):
        """Return the GID offset of the cell model that owns ``gid``."""
        return self._gid_offsets[self.lookup_model(gid)]

    def lookup_kind(self, gid):
        """Return the Arbor cell kind of ``gid`` as reported by its cell model."""
        population = self._lookup(gid)
        return population.model.get_cell_kind(gid)

    def lookup_model(self, gid):
        """Return the cell model that owns ``gid``."""
        population = self._lookup(gid)
        return population.model

    def _lookup(self, gid):
        # Linear scan: the first population containing the GID wins.
        for population in self._populations:
            if gid in population:
                return population
        raise UnknownGIDError(f"Can't find gid {gid}.")

    def all(self):
        """Yield every GID of every population, one population after another."""
        for population in self._populations:
            yield from population

    def get_populations(self):
        """Return a mapping of each cell model to its population."""
        return {population.model: population for population in self._populations}
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class ArborRecipe(arbor.recipe):
    """
    Arbor recipe describing a BSB simulation: cells, connections, gap junctions,
    probes and event generators, all looked up through the simulation data.
    """

    def __init__(self, simulation, simdata):
        super().__init__()
        self._simulation = simulation
        self._simdata = simdata
        self._global_properties = arbor.neuron_cable_properties()
        self._global_properties.set_property(
            Vm=-65 * U.mV,
            tempK=300 * U.Kelvin,
            rL=35.4 * U.Ohm * U.cm,
            cm=0.01 * U.F / U.m2,
        )
        # Concentrations are in mM; reversal potentials are voltages, so in mV.
        self._global_properties.set_ion(
            ion="na", int_con=10 * U.mM, ext_con=140 * U.mM, rev_pot=50 * U.mV
        )
        self._global_properties.set_ion(
            ion="k", int_con=54.4 * U.mM, ext_con=2.5 * U.mM, rev_pot=-77 * U.mV
        )
        self._global_properties.set_ion(
            ion="ca", int_con=0.0001 * U.mM, ext_con=2 * U.mM, rev_pot=132.5 * U.mV
        )
        self._global_properties.set_ion(
            ion="h",
            valence=1,
            int_con=1.0 * U.mM,
            ext_con=1.0 * U.mM,
            rev_pot=-34 * U.mV,
        )
        self._global_properties.catalogue = self._get_catalogue()

    def _get_catalogue(self):
        """
        Return the default Arbor catalogue extended with each distinct prefixed
        catalogue provided by the simulation's cell models.
        """
        catalogue = arbor.default_catalogue()
        prefixes = set()
        for model in self._simulation.cell_models.values():
            prefix, model_catalogue = model.get_prefixed_catalogue()
            if model_catalogue is not None and prefix not in prefixes:
                prefixes.add(prefix)
                catalogue.extend(model_catalogue, "")

        return catalogue

    def global_properties(self, kind):
        # The same cable properties are returned for every cell kind.
        return self._global_properties

    def num_cells(self):
        return sum(
            len(model.get_placement_set())
            for model in self._simulation.cell_models.values()
        )

    def cell_kind(self, gid):
        return self._simdata.gid_manager.lookup_kind(gid)

    def cell_description(self, gid):
        model = self._simdata.gid_manager.lookup_model(gid)
        return model.get_description(gid)

    def connections_on(self, gid):
        return [
            arbor.connection(rcv.from_(), rcv.on(), rcv.weight, rcv.delay * U.ms)
            for rcv in self._simdata.connections_on[gid]
        ]

    def gap_junctions_on(self, gid):
        return [
            c.model.gap_junction(c) for c in self._simdata.gap_junctions_on.get(gid, [])
        ]

    def probes(self, gid):
        """
        Collect the probes of every device targeting ``gid`` and register the
        probe tags with their device.

        Tags are numbered by position in the returned list, continuing across
        devices, so two devices on the same cell never share a tag.
        NOTE(review): assumes probes are addressed by their position in this
        list — confirm against the device ``implement_probes`` implementations.
        """
        probes = []
        for device in self._simdata.devices_on[gid]:
            device_probes = device.implement_probes(self._simdata, gid)
            first_tag = len(probes)
            for tag in range(first_tag, first_tag + len(device_probes)):
                device.register_probe_id(gid, tag)
            probes.extend(device_probes)
        return probes

    def event_generators(self, gid):
        generators = []
        for device in self._simdata.devices_on[gid]:
            generators.extend(device.implement_generators(self._simdata, gid))
        return generators

    def _name_of(self, gid):
        return self._simdata.gid_manager.lookup_model(gid).cell_type.name
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
class ArborAdapter(SimulatorAdapter):
    """
    BSB simulator adapter that builds and runs Arbor simulations.
    """

    def __init__(self, comm=None):
        super().__init__(comm)
        # Per-simulation state, created by ``_create_simdata`` and removed after
        # ``run`` or after a failed ``prepare``.
        self.simdata: dict[ArborSimulation, ArborSimulationData] = {}

    def prepare(self, simulation: "ArborSimulation") -> ArborSimulationData:
        """
        Prepares the arbor simulation engine with the given simulation.

        Creates the execution context (MPI-aware when running on several ranks and
        Arbor supports it), assigns GIDs, builds and load-balances the recipe and
        creates the ``arbor.simulation``.

        :param simulation: The simulation to prepare.
        :returns: The simulation data, with ``arbor_sim`` set.
        :raises RuntimeError: If profiling is requested but Arbor was built
          without profiling support.
        """
        simdata = self._create_simdata(simulation)
        try:
            context = arbor.context(arbor.proc_allocation(threads=simulation.threads))
            if self.comm.get_size() > 1:
                if not arbor.config()["mpi4py"]:
                    warn(
                        "Arbor does not seem to be built with MPI support, running "
                        f"duplicate simulations on {self.comm.get_size()} nodes."
                    )
                else:
                    context = arbor.context(
                        arbor.proc_allocation(threads=simulation.threads),
                        mpi=self.comm.get_communicator(),
                    )
            if simulation.profiling:
                if arbor.config()["profiling"]:
                    report("enabling profiler", level=2)
                    arbor.profiler_initialize(context)
                else:
                    raise RuntimeError(
                        "Arbor must be built with profiling support to use the "
                        "`profiling` flag."
                    )
            simdata.gid_manager = self.get_gid_manager(simulation, simdata)
            simdata.populations = simdata.gid_manager.get_populations()
            report("preparing simulation", level=1)
            report("MPI processes:", context.ranks, level=2)
            report("Threads per process:", context.threads, level=2)
            recipe = self.get_recipe(simulation, simdata)
            # Gap junctions are required for domain decomposition
            self.domain = arbor.partition_load_balance(recipe, context)
            self.gids = set(it.chain.from_iterable(g.gids for g in self.domain.groups))
            simdata.arbor_sim = arbor.simulation(recipe, context, self.domain)
            self.prepare_samples(simulation, simdata)
            report("prepared simulation", level=1)
            return simdata
        except Exception:
            # Drop the half-prepared state; ``pop`` so cleanup can never mask the
            # original error with a KeyError.
            self.simdata.pop(simulation, None)
            raise

    def get_gid_manager(self, simulation, simdata):
        return GIDManager(simulation, simdata)

    def prepare_samples(self, simulation, simdata):
        for device in simulation.devices.values():
            device.prepare_samples(simdata, comm=self.comm)

    def run(self, *simulations):
        """
        Run a single prepared simulation and return ``[simdata.result]``.

        The simulation's data is discarded afterwards, whether the run succeeds
        or not.

        :raises RuntimeError: If more or fewer than one simulation is given.
        :raises AdapterError: If the simulation was not prepared.
        """
        if len(simulations) != 1:
            raise RuntimeError(
                "Can not run multiple simultaneous simulations. Composition not "
                "implemented."
            )
        simulation = simulations[0]
        simdata = self.simdata.get(simulation)
        # Simdata created by ``get_recipe`` alone has no Arbor simulation yet.
        if simdata is None or simdata.arbor_sim is None:
            raise AdapterError(f"Can't run unprepared simulation '{simulation.name}'")
        arbor_sim = simdata.arbor_sim
        try:
            if not self.comm.get_rank():
                arbor_sim.record(arbor.spike_recording.all)

            start = time.time()
            report("running simulation", level=1)
            arbor_sim.run(simulation.duration * U.ms, dt=simulation.resolution * U.ms)
            report(f"completed simulation. {time.time() - start:.2f}s", level=1)
            if simulation.profiling and arbor.config()["profiling"]:
                report("printing profiler summary", level=2)
                report(arbor.profiler_summary(), level=1)
            return [simdata.result]
        finally:
            del self.simdata[simulation]

    def get_recipe(self, simulation, simdata=None):
        """
        Build the :class:`ArborRecipe` for ``simulation``.

        When ``simdata`` is not given, fresh simulation data is created, including
        the GID manager and populations the connection/device caches rely on.
        """
        if simdata is None:
            simdata = self._create_simdata(simulation)
            simdata.gid_manager = self.get_gid_manager(simulation, simdata)
            simdata.populations = simdata.gid_manager.get_populations()
        self._cache_gap_junctions(simulation, simdata)
        self._cache_connections(simulation, simdata)
        self._cache_devices(simulation, simdata)
        return ArborRecipe(simulation, simdata)

    def _create_simdata(self, simulation):
        self.simdata[simulation] = simdata = ArborSimulationData(simulation)
        self._assign_chunks(simulation, simdata)
        return simdata

    def _cache_gap_junctions(self, simulation, simdata):
        simdata.gap_junctions_on = {}
        for conn_model in simulation.connection_models.values():
            if conn_model.gap:
                conn_set = conn_model.get_connectivity_set()
                conns = conn_set.load_connections().to(simdata.chunks).as_globals()
                conn_model.create_gap_junctions_on(simdata.gap_junctions_on, conns)

    def _cache_connections(self, simulation, simdata):
        simdata.connections_on = {
            gid: simdata.gid_manager.lookup_model(gid).make_receiver_collection()
            for gid in simdata.gid_manager.all()
        }
        simdata.connections_from = {gid: [] for gid in simdata.gid_manager.all()}
        for conn_model in simulation.connection_models.values():
            if conn_model.gap:
                continue
            conn_set = conn_model.get_connectivity_set()
            pop_pre, pop_post = None, None
            for model in simulation.cell_models.values():
                if model.cell_type is conn_set.pre_type:
                    pop_pre = simdata.populations[model]
                if model.cell_type is conn_set.post_type:
                    pop_post = simdata.populations[model]
            # Get the arriving connection iterator
            conns_on = conn_set.load_connections().to(simdata.chunks).as_globals()
            # Create the arriving connections
            conn_model.create_connections_on(
                simdata.connections_on, conns_on, pop_pre, pop_post
            )
            # Get the outgoing connection iterator
            conns_from = conn_set.load_connections().from_(simdata.chunks).as_globals()
            # Create the outgoing connections
            conn_model.create_connections_from(
                simdata.connections_from, conns_from, pop_pre, pop_post
            )

    def _cache_devices(self, simulation, simdata):
        simdata.devices_on = {gid: [] for gid in simdata.gid_manager.all()}
        for device in simulation.devices.values():
            targets = device.targetting.get_targets(self, simulation, simdata)
            for target in itertools.chain.from_iterable(targets.values()):
                simdata.devices_on[target].append(device)

    def _assign_chunks(self, simulation, simdata):
        # Round-robin the chunks over the ranks and keep this rank's share.
        chunk_stats = simulation.scaffold.storage.get_chunk_stats()
        size = self.comm.get_size()
        all_chunks = [Chunk.from_id(int(chunk), None) for chunk in chunk_stats]
        simdata.node_chunk_alloc = [all_chunks[rank::size] for rank in range(0, size)]
        simdata.chunk_node_map = {}
        for node, chunks in enumerate(simdata.node_chunk_alloc):
            for chunk in chunks:
                simdata.chunk_node_map[chunk] = node
        simdata.chunks = simdata.node_chunk_alloc[self.comm.get_rank()]
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def _all_bools(arr):
|
|
444
|
+
try:
|
|
445
|
+
return all(isinstance(b, bool) for b in arr)
|
|
446
|
+
except TypeError:
|
|
447
|
+
# Not iterable
|
|
448
|
+
return False
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def _all_ints(arr):
|
|
452
|
+
try:
|
|
453
|
+
return all(isinstance(b, int) for b in arr)
|
|
454
|
+
except TypeError:
|
|
455
|
+
# Not iterable
|
|
456
|
+
return False
|
bsb_arbor/cell.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
|
|
3
|
+
import arbor
|
|
4
|
+
from bsb import CellModel, ConfigurationError, PlacementSet, config, types
|
|
5
|
+
|
|
6
|
+
from .adapter import SingleReceiverCollection
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@config.dynamic(
    attr_name="model_strategy",
    auto_classmap=True,
    required=True,
    classmap_entry=None,
)
class ArborCell(CellModel):
    """
    Base class for cell models simulated with Arbor.

    Subclasses implement the abstract hooks below; the default
    :meth:`get_description` builds an ``arbor.cable_cell`` from the configured
    model's cable cell template.
    """

    model_strategy: config.ConfigurationAttribute
    """
    Optional importable reference to a different modelling strategy than the default
    Arborize strategy.
    """
    gap = config.attr(type=bool, default=False)
    """Is this synapse a gap junction?"""
    model = config.attr(type=types.class_(), required=True)
    """Importable reference to the arborize model describing the cell type."""

    @abc.abstractmethod
    def cache_population_data(self, simdata, ps: PlacementSet):
        """Hook to cache per-population data from the placement set ``ps``."""

    @abc.abstractmethod
    def discard_population_data(self):
        """Hook to drop whatever :meth:`cache_population_data` stored."""

    @abc.abstractmethod
    def get_prefixed_catalogue(self):
        """Return a ``(prefix, catalogue)`` pair; ``catalogue`` may be ``None``."""

    @abc.abstractmethod
    def get_cell_kind(self, gid):
        """Return the ``arbor.cell_kind`` of cell ``gid``."""

    @abc.abstractmethod
    def make_receiver_collection(self):
        """Return an empty collection for the incoming connections of a cell."""

    def get_description(self, gid):
        """
        Build the ``arbor.cable_cell`` for ``gid`` from ``model.cable_cell_template()``.

        NOTE(review): ``_add_labels`` and ``_add_decor`` are not defined on this
        class; subclasses are presumably expected to provide them — confirm.
        """
        morphology, labels, decor = self.model.cable_cell_template()
        return arbor.cable_cell(
            morphology,
            self._add_labels(gid, labels, morphology),
            self._add_decor(gid, decor),
        )
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@config.node
class LIFCell(ArborCell, classmap_entry="lif"):
    """
    Leaky integrate-and-fire cell model backed by ``arbor.lif_cell``, whose
    parameters are overridden from :attr:`constants`.
    """

    model = config.unset()
    """Importable reference to the arborize model describing the cell type."""
    constants = config.dict(type=types.any_())
    """Dictionary linking the parameters' name to its value."""

    def cache_population_data(self, simdata, ps: PlacementSet):
        """LIF cells cache no population data."""

    def discard_population_data(self):
        """LIF cells cache no population data, so there is nothing to discard."""

    def get_prefixed_catalogue(self):
        """LIF cells need no mechanism catalogue."""
        return None, None

    def get_cell_kind(self, gid):
        return arbor.cell_kind.lif

    def get_description(self, gid):
        """
        Build an ``arbor.lif_cell`` with source label ``"-1_-1"`` and target label
        ``"-1_-1_0"``, applying each configured constant in the unit of the
        attribute it replaces.

        :raises ConfigurationError: if a constant is not an attribute of the cell.
        """
        cell = arbor.lif_cell("-1_-1", "-1_-1_0")
        try:
            for name, value in self.constants.items():
                unit = getattr(cell, name).units
                setattr(cell, name, value * unit)
        except AttributeError:
            node_name = type(self).constants.get_node_name(self)
            raise ConfigurationError(
                f"'{name}' is not a valid LIF parameter in '{node_name}'."
            ) from None
        return cell

    def make_receiver_collection(self):
        # Every incoming connection targets the cell's single receiver.
        return SingleReceiverCollection()
|
bsb_arbor/connection.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import arbor
|
|
2
|
+
import tqdm
|
|
3
|
+
from bsb import ConnectionModel, config
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Receiver:
    """
    One incoming connection onto a cell: where it comes from, where it lands, and
    the connection model providing its weight and delay.
    """

    def __init__(self, conn_model, from_gid, loc_from, loc_on, index=-1):
        self.conn_model = conn_model
        self.from_gid = from_gid
        self.loc_from = loc_from
        self.loc_on = loc_on
        self.synapse = arbor.synapse("expsyn")
        self.index = index

    def from_(self):
        """Global label ``"<b>_<p>"`` of the source location on cell ``from_gid``."""
        branch, point = self.loc_from
        label = f"{branch}_{point}"
        return arbor.cell_global_label(self.from_gid, label)

    def on(self):
        """Local label ``"<b>_<p>_<index>"`` of the target on the receiving cell."""
        # ``index`` is assigned by the receiver collection when we are appended.
        branch, point = self.loc_on
        label = f"{branch}_{point}_{self.index}"
        return arbor.cell_local_label(label)

    @property
    def weight(self):
        return self.conn_model.weight

    @property
    def delay(self):
        return self.conn_model.delay
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Connection:
    """
    A pre/post connection pair split into cell ids (first element of each
    location) and the remaining location data.
    """

    def __init__(self, pre_loc, post_loc):
        self.from_id, self.to_id = pre_loc[0], post_loc[0]
        self.pre_loc, self.post_loc = pre_loc[1:], post_loc[1:]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@config.node
class ArborConnection(ConnectionModel):
    """
    Arbor connection model: fixed weight and delay for chemical connections, or a
    gap junction when :attr:`gap` is set.
    """

    gap = config.attr(type=bool, default=False)
    """Is this synapse a gap junction?"""
    weight = config.attr(type=float, required=True)
    """Weight of the connection between the presynaptic and the postsynaptic cells."""
    delay = config.attr(type=float, required=True)
    """Delay of the transmission between the presynaptic and the postsynaptic cells."""

    def create_gap_junctions_on(self, gj_on_gid, conns):
        """
        Group gap junction connections under the cell they are declared on.

        Junctions are stored under the postsynaptic id (``to_id``), matching
        :meth:`gap_junction`, which uses the postsynaptic side as the local site
        and the presynaptic cell as the remote peer.
        NOTE(review): unlike ``create_connections_on`` no population offset is
        added to these ids — confirm they are already GIDs.
        """
        for pre_loc, post_loc in conns:
            conn = Connection(pre_loc, post_loc)
            gj_on_gid.setdefault(conn.to_id, []).append(conn)

    def create_connections_on(self, conns_on_gid, conns, pop_pre, pop_post):
        """Append a :class:`Receiver` to the postsynaptic GID of every connection."""
        for pre_loc, post_loc in tqdm.tqdm(conns, total=len(conns), desc=self.name):
            # Cast to plain ints, as ``create_connections_from`` does.
            post_gid = int(post_loc[0] + pop_post.offset)
            pre_gid = int(pre_loc[0] + pop_pre.offset)
            conns_on_gid[post_gid].append(
                Receiver(self, pre_gid, pre_loc[1:], post_loc[1:])
            )

    def create_connections_from(self, conns_from_gid, conns, pop_pre, pop_post):
        """Record the presynaptic location of every outgoing connection."""
        for pre_loc, _post_loc in conns:
            conns_from_gid[int(pre_loc[0] + pop_pre.offset)].append(pre_loc[1:])

    def gap_junction(self, conn):
        """Build the ``arbor.gap_junction_connection`` for ``conn``."""
        # NOTE(review): ``Connection`` only defines from_id/to_id/pre_loc/post_loc;
        # the ``to_compartment``/``from_compartment`` attributes read here do not
        # exist on it — the label scheme needs confirming before this can work.
        l_ = arbor.cell_local_label(f"gap_{conn.to_compartment.id}")
        g = arbor.cell_global_label(int(conn.from_id), f"gap_{conn.from_compartment.id}")
        return arbor.gap_junction_connection(g, l_, self.weight)
|
bsb_arbor/device.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
|
|
3
|
+
import arbor
|
|
4
|
+
from bsb import DeviceModel, Targetting, config, types
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@config.dynamic(attr_name="device", auto_classmap=True, classmap_entry=None)
class ArborDevice(DeviceModel):
    """
    Base class for Arbor devices. Subclasses provide probes and/or event
    generators for the cells they target; this base handles sampling of the
    probes registered by the recipe.
    """

    device: config.ConfigurationAttribute
    """Optional importable reference to the device strategy."""
    targetting = config.attr(type=Targetting, required=True)
    """Targets of the device, which should be either a population or a nest rule."""
    resolution = config.attr(type=float)
    """Time resolution of the device."""
    sampling_policy = config.attr(type=types.in_(["exact"]))
    """Policy used to sample simulation data from the device."""

    def __init__(self, **kwargs):
        # (gid, tag) pairs registered by the recipe for this device's probes.
        self._probe_ids = []

    def __boot__(self):
        # Fall back to the simulation's resolution when the device sets none.
        self.resolution = self.resolution or self.simulation.resolution

    def register_probe_id(self, gid, tag):
        """Remember probe ``(gid, tag)`` so it can be sampled later."""
        self._probe_ids.append((gid, tag))

    def prepare_samples(self, simdata, comm):
        """Request sampling of every registered probe on ``simdata.arbor_sim``."""
        self._handles = [
            self.sample(simdata.arbor_sim, probe_id) for probe_id in self._probe_ids
        ]

    def sample(self, sim, probe_id):
        """
        Sample ``probe_id`` on ``sim`` every ``resolution`` and return the handle.

        An unset ``sampling_policy`` falls back to ``"exact"``, the only
        configurable value.
        """
        # Arbor schedules take unit-carrying quantities; the resolution is treated
        # as milliseconds, like the simulation resolution the adapter passes to
        # ``run``.
        schedule = arbor.regular_schedule(self.resolution * arbor.units.ms)
        policy_name = self.sampling_policy or "exact"
        sampling_policy = getattr(arbor.sampling_policy, policy_name)
        return sim.sample(probe_id, schedule, sampling_policy)

    def get_samples(self, sim):
        return [sim.samples(handle) for handle in self._handles]

    def get_meta(self):
        attrs = ("name", "sampling_policy", "resolution")
        return {attr: getattr(self, attr) for attr in attrs}

    @abc.abstractmethod
    def implement_probes(self, simdata, target):
        pass

    @abc.abstractmethod
    def implement_generators(self, simdata, target):
        pass
|