bsb-arbor 0.0.0b1__py2.py3-none-any.whl → 0.0.0b3__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bsb-arbor might be problematic; see the package's registry page for more details.

bsb_arbor/adapter.py CHANGED
@@ -1,364 +1,443 @@
1
- import itertools
2
- import itertools as it
3
- import time
4
- import typing
5
-
6
- import arbor
7
- from bsb.exceptions import AdapterError, UnknownGIDError
8
- from bsb.reporting import report, warn
9
- from bsb.services import MPI
10
- from bsb.simulation.adapter import SimulationData, SimulatorAdapter
11
- from bsb.storage import Chunk
12
-
13
- if typing.TYPE_CHECKING:
14
- from .simulation import ArborSimulation
15
-
16
-
17
class ArborSimulationData(SimulationData):
    """
    Simulation data with an extra slot for the Arbor simulation object.
    """

    def __init__(self, simulation):
        super(ArborSimulationData, self).__init__(simulation)
        # Holds the ``arbor.simulation``; ``None`` until one is assigned.
        self.arbor_sim: "arbor.simulation" = None
21
-
22
-
23
class ReceiverCollection(list):
    """
    List of incoming receivers that numbers receivers per endpoint.

    Receivers are grouped by ``str(rcv.loc_on)``. The first receiver appended on
    an endpoint gets ``index`` 0, the next one on that endpoint gets 1, and so
    on. Only :meth:`append` assigns indices.
    """

    def __init__(self):
        super().__init__()
        # endpoint key -> number of receivers appended on it so far
        self._endpoint_counters = {}

    def append(self, rcv):
        key = str(rcv.loc_on)
        seen = self._endpoint_counters.get(key, 0)
        self._endpoint_counters[key] = seen + 1
        rcv.index = seen
        super().append(rcv)
39
-
40
-
41
class SingleReceiverCollection(list):
    """
    Receiver list in which every appended receiver gets ``index`` 0, so all
    incoming connections target the same receiver.
    """

    def append(self, rcv):
        setattr(rcv, "index", 0)
        super().append(rcv)
49
-
50
-
51
class Population:
    """
    The GIDs of one cell model that lie in this rank's chunks.

    GIDs are kept as half-open ``(start, stop)`` ranges, one per local chunk,
    ordered by chunk id.
    """

    def __init__(self, simdata, cell_model, offset):
        self._model = cell_model
        placement = cell_model.get_placement_set(simdata.chunks)
        self._ranges = self._get_ranges(simdata.chunks, placement, offset)
        self._offset = offset

    @property
    def model(self):
        """The cell model this population was created for."""
        return self._model

    @property
    def offset(self):
        """The GID offset this population was created with."""
        return self._offset

    def __contains__(self, i):
        for start, stop in self._ranges:
            if start <= i < stop:
                return True
        return False

    def _get_ranges(self, chunks, ps, offset):
        """
        Lay the placement set's chunks out consecutively from ``offset`` (in chunk
        id order) and return the GID ranges of the chunks that are in ``chunks``.
        """

        def chunk_order(entry):
            return Chunk.from_id(int(entry[0]), None).id

        ranges = []
        cursor = offset
        for chunk, count in sorted(ps.get_chunk_stats().items(), key=chunk_order):
            if chunk in chunks:
                ranges.append((cursor, cursor + count))
            # Non-local chunks still consume GIDs.
            cursor += count
        return ranges

    def __iter__(self):
        for start, stop in self._ranges:
            yield from range(start, stop)
84
-
85
-
86
class GIDManager:
    """
    Assign a contiguous GID block to every cell model and map GIDs back to the
    population and model that own them.

    Models are laid out in order of increasing placement set size. Each model's
    block starts where the previous model's block ends.
    """

    def __init__(self, simulation, simdata):
        self._gid_offsets = {}
        self._model_order = self.sort_models(simulation.cell_models.values())
        next_gid = 0
        for cell_model in self._model_order:
            self._gid_offsets[cell_model] = next_gid
            next_gid += len(cell_model.get_placement_set())
        self._populations = [
            Population(simdata, cell_model, first_gid)
            for cell_model, first_gid in self._gid_offsets.items()
        ]

    def sort_models(self, models):
        """Return ``models`` sorted by the size of their placement set."""

        def placement_size(cell_model):
            return len(cell_model.get_placement_set())

        return sorted(models, key=placement_size)

    def lookup_offset(self, gid):
        """Return the first GID of the model that owns ``gid``."""
        return self._gid_offsets[self.lookup_model(gid)]

    def lookup_kind(self, gid):
        """Return the Arbor cell kind of ``gid``, as reported by its model."""
        owner = self._lookup(gid)
        return owner.model.get_cell_kind(gid)

    def lookup_model(self, gid):
        """Return the cell model that owns ``gid``."""
        owner = self._lookup(gid)
        return owner.model

    def _lookup(self, gid):
        for population in self._populations:
            if gid in population:
                return population
        raise UnknownGIDError(f"Can't find gid {gid}.")

    def all(self):
        """Yield every GID of every population, population by population."""
        for population in self._populations:
            yield from population

    def get_populations(self):
        """Return a ``{model: population}`` mapping."""
        by_model = {}
        for population in self._populations:
            by_model[population.model] = population
        return by_model
126
-
127
-
128
class ArborRecipe(arbor.recipe):
    """
    Arbor recipe describing a BSB simulation.

    Cell kinds, descriptions, connections, gap junctions, probes and event
    generators are looked up per GID in the simulation data prepared by the
    adapter (``gid_manager``, ``connections_on``, ``gap_junctions_on`` and
    ``devices_on``).
    """

    def __init__(self, simulation, simdata):
        super().__init__()
        self._simulation = simulation
        self._simdata = simdata
        # Global cable properties and ion defaults shared by every cell kind.
        self._global_properties = arbor.neuron_cable_properties()
        self._global_properties.set_property(Vm=-65, tempK=300, rL=35.4, cm=0.01)
        self._global_properties.set_ion(ion="na", int_con=10, ext_con=140, rev_pot=50)
        self._global_properties.set_ion(ion="k", int_con=54.4, ext_con=2.5, rev_pot=-77)
        self._global_properties.set_ion(
            ion="ca", int_con=0.0001, ext_con=2, rev_pot=132.5
        )
        self._global_properties.set_ion(
            ion="h", valence=1, int_con=1.0, ext_con=1.0, rev_pot=-34
        )
        self._global_properties.catalogue = self._get_catalogue()

    def _get_catalogue(self):
        """
        Return Arbor's default catalogue extended with the cell models' catalogues.

        Only the first catalogue seen for each prefix is merged in.
        """
        catalogue = arbor.default_catalogue()
        prefixes = set()
        for model in self._simulation.cell_models.values():
            prefix, model_catalogue = model.get_prefixed_catalogue()
            if model_catalogue is not None and prefix not in prefixes:
                prefixes.add(prefix)
                # NOTE(review): merged with an empty prefix, so ``prefix`` only
                # serves deduplication here. Confirm the model catalogue already
                # carries prefixed mechanism names.
                catalogue.extend(model_catalogue, "")

        return catalogue

    def global_properties(self, kind):
        """Return the shared global properties, whatever the cell kind."""
        return self._global_properties

    def num_cells(self):
        """Return the total number of placed cells over all cell models."""
        return sum(
            len(model.get_placement_set())
            for model in self._simulation.cell_models.values()
        )

    def cell_kind(self, gid):
        return self._simdata.gid_manager.lookup_kind(gid)

    def cell_description(self, gid):
        model = self._simdata.gid_manager.lookup_model(gid)
        return model.get_description(gid)

    def connections_on(self, gid):
        """Return the Arbor connections arriving on ``gid``."""
        return [
            arbor.connection(rcv.from_(), rcv.on(), rcv.weight, rcv.delay)
            for rcv in self._simdata.connections_on[gid]
        ]

    def gap_junctions_on(self, gid):
        """Return the gap junctions on ``gid``; an empty list if it has none."""
        return [
            c.model.gap_junction(c) for c in self._simdata.gap_junctions_on.get(gid, [])
        ]

    def probes(self, gid):
        """
        Collect the probes of every device targeting ``gid``.

        Each probe is registered with its device under its position in the
        returned list. The positions run on across devices, so two devices on the
        same cell never share a tag.
        """
        probes = []
        for device in self._simdata.devices_on[gid]:
            device_probes = device.implement_probes(self._simdata, gid)
            # Tags continue after the probes of previously handled devices.
            first_tag = len(probes)
            for tag in range(first_tag, first_tag + len(device_probes)):
                device.register_probe_id(gid, tag)
            probes.extend(device_probes)
        return probes

    def event_generators(self, gid):
        """Collect the event generators of every device targeting ``gid``."""
        devices = self._simdata.devices_on[gid]
        generators = []
        for device in devices:
            device_generators = device.implement_generators(self._simdata, gid)
            generators.extend(device_generators)
        return generators

    def _name_of(self, gid):
        return self._simdata.gid_manager.lookup_model(gid).cell_type.name
204
-
205
-
206
class ArborAdapter(SimulatorAdapter):
    """
    Simulator adapter that prepares and runs BSB simulations with Arbor.

    :meth:`prepare` builds an :class:`ArborRecipe`, partitions it over the Arbor
    execution context and creates the ``arbor.simulation``. :meth:`run` runs one
    prepared simulation and then discards its simulation data.
    """

    def __init__(self):
        super().__init__()
        # Simulation data per simulation, kept from `prepare` until `run` finishes.
        self.simdata: typing.Dict["ArborSimulation", "SimulationData"] = {}

    def get_rank(self):
        """Return the MPI rank of this process."""
        return MPI.get_rank()

    def get_size(self):
        """Return the number of MPI processes."""
        return MPI.get_size()

    def broadcast(self, data, root=0):
        """Broadcast ``data`` from ``root`` via ``MPI.bcast``."""
        return MPI.bcast(data, root)

    def barrier(self):
        """Synchronize the MPI processes via ``MPI.barrier``."""
        return MPI.barrier()

    def prepare(self, simulation: "ArborSimulation", comm=None):
        """
        Build the recipe, domain decomposition and Arbor simulation.

        :param simulation: The BSB simulation to prepare.
        :param comm: MPI communicator handed to Arbor. Defaults to
            ``MPI.get_communicator()``.
        :returns: The prepared simulation data, also stored in ``self.simdata``.
        :raises RuntimeError: If profiling is requested but Arbor was built
            without profiling support.
        """
        simdata = self._create_simdata(simulation)
        try:
            context = arbor.context(arbor.proc_allocation(threads=simulation.threads))
            if MPI.get_size() > 1:
                if not arbor.config()["mpi4py"]:
                    warn(
                        "Arbor does not seem to be built with MPI support, running "
                        f"duplicate simulations on {MPI.get_size()} nodes."
                    )
                else:
                    context = arbor.context(
                        arbor.proc_allocation(threads=simulation.threads),
                        mpi=comm or MPI.get_communicator(),
                    )
            if simulation.profiling:
                if arbor.config()["profiling"]:
                    report("enabling profiler", level=2)
                    arbor.profiler_initialize(context)
                else:
                    raise RuntimeError(
                        "Arbor must be built with profiling support to use the `profiling` flag."
                    )
            simdata.gid_manager = self.get_gid_manager(simulation, simdata)
            simdata.populations = simdata.gid_manager.get_populations()
            report("preparing simulation", level=1)
            report("MPI processes:", context.ranks, level=2)
            report("Threads per process:", context.threads, level=2)
            recipe = self.get_recipe(simulation, simdata)
            # Gap junctions are required for domain decomposition
            self.domain = arbor.partition_load_balance(recipe, context)
            self.gids = set(it.chain.from_iterable(g.gids for g in self.domain.groups))
            simdata.arbor_sim = arbor.simulation(recipe, context, self.domain)
            self.prepare_samples(simulation, simdata)
            report("prepared simulation", level=1)
            return simdata
        except Exception:
            # Don't leave half-prepared simulation data behind.
            del self.simdata[simulation]
            raise

    def get_gid_manager(self, simulation, simdata):
        """Create the :class:`GIDManager` that maps GIDs to cell models."""
        return GIDManager(simulation, simdata)

    def prepare_samples(self, simulation, simdata):
        """Let every device of ``simulation`` prepare its samples on ``simdata``."""
        for device in simulation.devices.values():
            device.prepare_samples(simdata)

    def run(self, *simulations):
        """
        Run a single prepared simulation.

        :returns: A one-element list with the simulation's result.
        :raises RuntimeError: If not exactly one simulation is given.
        :raises AdapterError: If the simulation was not prepared.
        """
        if len(simulations) != 1:
            raise RuntimeError(
                "Can not run multiple simultaneous simulations. Composition not implemented."
            )
        simulation = simulations[0]
        try:
            simdata = self.simdata[simulation]
            arbor_sim = simdata.arbor_sim
        except KeyError:
            raise AdapterError(
                f"Can't run unprepared simulation '{simulation.name}'"
            ) from None
        try:
            # Only rank 0 requests spike recording (``spike_recording.all``).
            if not MPI.get_rank():
                arbor_sim.record(arbor.spike_recording.all)

            start = time.time()
            report("running simulation", level=1)
            arbor_sim.run(simulation.duration, dt=simulation.resolution)
            report(f"completed simulation. {time.time() - start:.2f}s", level=1)
            if simulation.profiling and arbor.config()["profiling"]:
                report("printing profiler summary", level=2)
                report(arbor.profiler_summary(), level=1)
            return [simdata.result]
        finally:
            # The data is dropped after one run; running again requires `prepare`.
            del self.simdata[simulation]

    def get_recipe(self, simulation, simdata=None):
        """
        Cache gap junctions, connections and device targets on ``simdata`` and
        return the recipe built from it. Creates fresh simulation data if none
        is given.
        """
        if simdata is None:
            simdata = self._create_simdata(simulation)
        self._cache_gap_junctions(simulation, simdata)
        self._cache_connections(simulation, simdata)
        self._cache_devices(simulation, simdata)
        return ArborRecipe(simulation, simdata)

    def _create_simdata(self, simulation):
        # Use the Arbor-specific data class so `arbor_sim` always exists.
        self.simdata[simulation] = simdata = ArborSimulationData(simulation)
        self._assign_chunks(simulation, simdata)
        return simdata

    def _cache_gap_junctions(self, simulation, simdata):
        """Collect the gap junctions arriving in this rank's chunks, per GID."""
        simdata.gap_junctions_on = {}
        for conn_model in simulation.connection_models.values():
            if conn_model.gap:
                conn_set = conn_model.get_connectivity_set()
                conns = conn_set.load_connections().to(simdata.chunks).as_globals()
                conn_model.create_gap_junctions_on(simdata.gap_junctions_on, conns)

    def _cache_connections(self, simulation, simdata):
        """
        For every non-gap connection model, fill the per-GID incoming
        (``connections_on``) and outgoing (``connections_from``) collections.
        """
        simdata.connections_on = {
            gid: simdata.gid_manager.lookup_model(gid).make_receiver_collection()
            for gid in simdata.gid_manager.all()
        }
        simdata.connections_from = {gid: [] for gid in simdata.gid_manager.all()}
        for conn_model in simulation.connection_models.values():
            if conn_model.gap:
                continue
            conn_set = conn_model.get_connectivity_set()
            pop_pre, pop_post = None, None
            for model in simulation.cell_models.values():
                if model.cell_type is conn_set.pre_type:
                    pop_pre = simdata.populations[model]
                if model.cell_type is conn_set.post_type:
                    pop_post = simdata.populations[model]
            # Get the arriving connection iterator
            conns_on = conn_set.load_connections().to(simdata.chunks).as_globals()
            # Create the arriving connections
            conn_model.create_connections_on(
                simdata.connections_on, conns_on, pop_pre, pop_post
            )
            # Get the outgoing connection iterator
            conns_from = conn_set.load_connections().from_(simdata.chunks).as_globals()
            # Create the outgoing connections
            conn_model.create_connections_from(
                simdata.connections_from, conns_from, pop_pre, pop_post
            )

    def _cache_devices(self, simulation, simdata):
        """Map every GID to the list of devices that target it."""
        simdata.devices_on = {gid: [] for gid in simdata.gid_manager.all()}
        for device in simulation.devices.values():
            targets = device.targetting.get_targets(self, simulation, simdata)
            for target in itertools.chain.from_iterable(targets.values()):
                simdata.devices_on[target].append(device)

    def _assign_chunks(self, simulation, simdata):
        """Deal the stored chunks round-robin over the MPI ranks."""
        chunk_stats = simulation.scaffold.storage.get_chunk_stats()
        size = MPI.get_size()
        all_chunks = [Chunk.from_id(int(chunk), None) for chunk in chunk_stats.keys()]
        simdata.node_chunk_alloc = [all_chunks[rank::size] for rank in range(0, size)]
        simdata.chunk_node_map = {}
        for node, chunks in enumerate(simdata.node_chunk_alloc):
            for chunk in chunks:
                simdata.chunk_node_map[chunk] = node
        simdata.chunks = simdata.node_chunk_alloc[MPI.get_rank()]
1
+ import itertools
2
+ import itertools as it
3
+ import time
4
+ import typing
5
+ import numpy as np
6
+
7
+ import arbor
8
+ from bsb import (
9
+ MPI,
10
+ AdapterError,
11
+ Chunk,
12
+ SimulationData,
13
+ SimulatorAdapter,
14
+ UnknownGIDError,
15
+ report,
16
+ warn,
17
+ )
18
+
19
+ if typing.TYPE_CHECKING:
20
+ from .simulation import ArborSimulation
21
+
22
+
23
class ArborSimulationData(SimulationData):
    """
    Simulation data with an extra slot for the Arbor simulation object.
    """

    def __init__(self, simulation):
        super(ArborSimulationData, self).__init__(simulation)
        # Holds the ``arbor.simulation``; ``None`` until one is assigned.
        self.arbor_sim: "arbor.simulation" = None
27
+
28
+
29
class ReceiverCollection(list):
    """
    List of incoming receivers that numbers receivers per endpoint.

    Receivers are grouped by ``str(rcv.loc_on)``. The first receiver appended on
    an endpoint gets ``index`` 0, the next one on that endpoint gets 1, and so
    on. Only :meth:`append` assigns indices.
    """

    def __init__(self):
        super().__init__()
        # endpoint key -> number of receivers appended on it so far
        self._endpoint_counters = {}

    def append(self, rcv):
        key = str(rcv.loc_on)
        seen = self._endpoint_counters.get(key, 0)
        self._endpoint_counters[key] = seen + 1
        rcv.index = seen
        super().append(rcv)
45
+
46
+
47
class SingleReceiverCollection(list):
    """
    Receiver list in which every appended receiver gets ``index`` 0, so all
    incoming connections target the same receiver.
    """

    def append(self, rcv):
        setattr(rcv, "index", 0)
        super().append(rcv)
55
+
56
+
57
class Population:
    """
    The GIDs of one cell model that lie in this rank's chunks.

    GIDs are kept as half-open ``(start, stop)`` ranges. Indexing a population
    (with an integer, a slice, an integer array or a boolean mask) uses positions
    in iteration order. It returns a new :class:`Population` restricted to the
    selected GIDs.
    """

    def __init__(self, simdata, cell_model, offset):
        self._model = cell_model
        self._simdata = simdata
        ps = cell_model.get_placement_set(simdata.chunks)
        self._ranges = self._get_ranges(simdata.chunks, ps, offset)
        self._offset = offset

    @property
    def model(self):
        """The cell model this population was created for."""
        return self._model

    @property
    def offset(self):
        """The GID offset this population was created with."""
        return self._offset

    def __len__(self):
        return sum(stop - start for start, stop in self._ranges)

    def __contains__(self, i):
        return any(start <= i < stop for start, stop in self._ranges)

    def copy(self):
        """
        Return a copy that keeps this population's current GID ranges.

        The copy shares model and simulation data but owns its range list.
        Copying a sub-population therefore does not widen it back to the full
        model, and no placement data is re-read.
        """
        pop = type(self).__new__(type(self))
        pop.__dict__.update(self.__dict__)
        pop._ranges = list(self._ranges)
        return pop

    def __getitem__(self, item):
        """
        Select a sub-population by position.

        :param item: An integer (negative counts from the end), a slice, an
            integer array/list of positions, or a boolean mask of ``len(self)``.
        :raises IndexError: For an out-of-bounds integer position.
        :raises ValueError: For a boolean mask of the wrong length.
        """
        # Plain integer positions (Python or NumPy scalars) select one cell.
        if isinstance(item, (int, np.integer)):
            return self._subpop_one(item)
        if isinstance(item, slice):
            return self._subpop_np(self._gids()[item])
        # An empty list is an empty selection, not a (mismatched) boolean mask.
        if isinstance(item, list) and not item:
            return self._subpop_np([])
        kind = getattr(getattr(item, "dtype", None), "kind", None)
        # Boolean masking, kind of
        if kind == "b" or (kind not in ("i", "u") and _all_bools(item)):
            if len(item) != len(self):
                raise ValueError(
                    f"Dimension mismatch between population ({len(self)}) and mask ({len(item)})"
                )
            return self._subpop_np(self._gids()[np.asarray(item, dtype=bool)])
        if kind in ("i", "u") or _all_ints(item):
            if getattr(item, "ndim", None) == 0:
                return self._subpop_one(item)
            return self._subpop_np(self._gids()[item])
        return self._subpop_one(item)

    def _get_ranges(self, chunks, ps, offset):
        """
        Lay the placement set's chunks out consecutively from ``offset`` (in chunk
        id order) and return the GID ranges of the chunks that are in ``chunks``.
        """
        stats = ps.get_chunk_stats()
        ranges = []
        for chunk, len_ in sorted(
            stats.items(), key=lambda k: Chunk.from_id(int(k[0]), None).id
        ):
            if chunk in chunks:
                ranges.append((offset, offset + len_))
            # Non-local chunks still consume GIDs.
            offset += len_
        return ranges

    def _gids(self):
        """Return this population's GIDs, in iteration order, as a NumPy array."""
        return np.fromiter(iter(self), dtype=int, count=len(self))

    def _subpop_np(self, arr):
        """
        Return a sub-population containing exactly the GIDs in ``arr``.

        ``arr`` holds GIDs (not positions), in the desired iteration order.
        Runs of consecutive GIDs are merged into a single range.
        """
        ranges = []
        for gid in arr:
            gid = int(gid)
            if ranges and ranges[-1][1] == gid:
                # Extend the current run of consecutive GIDs.
                ranges[-1] = (ranges[-1][0], gid + 1)
            else:
                ranges.append((gid, gid + 1))
        pop = self.copy()
        pop._ranges = ranges
        return pop

    def _subpop_one(self, item):
        """
        Return a one-cell sub-population holding the GID at position ``item``.

        Negative positions count from the end, like Python sequences.
        """
        size = len(self)
        index = item + size if item < 0 else item
        if not 0 <= index < size:
            raise IndexError(f"Index {item} out of bounds for size {size}")
        ptr = 0
        for start, stop in self._ranges:
            if index < ptr + (stop - start):
                # Position `index` lies in this range, `index - ptr` cells in.
                gid = int(start + (index - ptr))
                pop = self.copy()
                pop._ranges = [(gid, gid + 1)]
                return pop
            ptr += stop - start

    def __iter__(self):
        yield from itertools.chain.from_iterable(
            range(r[0], r[1]) for r in self._ranges
        )
147
+
148
+
149
class GIDManager:
    """
    Assign a contiguous GID block to every cell model and map GIDs back to the
    population and model that own them.

    Models are laid out in order of increasing placement set size. Each model's
    block starts where the previous model's block ends.
    """

    def __init__(self, simulation, simdata):
        self._gid_offsets = {}
        self._model_order = self.sort_models(simulation.cell_models.values())
        next_gid = 0
        for cell_model in self._model_order:
            self._gid_offsets[cell_model] = next_gid
            next_gid += len(cell_model.get_placement_set())
        self._populations = [
            Population(simdata, cell_model, first_gid)
            for cell_model, first_gid in self._gid_offsets.items()
        ]

    def sort_models(self, models):
        """Return ``models`` sorted by the size of their placement set."""

        def placement_size(cell_model):
            return len(cell_model.get_placement_set())

        return sorted(models, key=placement_size)

    def lookup_offset(self, gid):
        """Return the first GID of the model that owns ``gid``."""
        return self._gid_offsets[self.lookup_model(gid)]

    def lookup_kind(self, gid):
        """Return the Arbor cell kind of ``gid``, as reported by its model."""
        owner = self._lookup(gid)
        return owner.model.get_cell_kind(gid)

    def lookup_model(self, gid):
        """Return the cell model that owns ``gid``."""
        owner = self._lookup(gid)
        return owner.model

    def _lookup(self, gid):
        for population in self._populations:
            if gid in population:
                return population
        raise UnknownGIDError(f"Can't find gid {gid}.")

    def all(self):
        """Yield every GID of every population, population by population."""
        for population in self._populations:
            yield from population

    def get_populations(self):
        """Return a ``{model: population}`` mapping."""
        by_model = {}
        for population in self._populations:
            by_model[population.model] = population
        return by_model
189
+
190
+
191
class ArborRecipe(arbor.recipe):
    """
    Arbor recipe describing a BSB simulation.

    Cell kinds, descriptions, connections, gap junctions, probes and event
    generators are looked up per GID in the simulation data prepared by the
    adapter (``gid_manager``, ``connections_on``, ``gap_junctions_on`` and
    ``devices_on``).
    """

    def __init__(self, simulation, simdata):
        super().__init__()
        self._simulation = simulation
        self._simdata = simdata
        # Global cable properties and ion defaults shared by every cell kind.
        self._global_properties = arbor.neuron_cable_properties()
        self._global_properties.set_property(Vm=-65, tempK=300, rL=35.4, cm=0.01)
        self._global_properties.set_ion(ion="na", int_con=10, ext_con=140, rev_pot=50)
        self._global_properties.set_ion(ion="k", int_con=54.4, ext_con=2.5, rev_pot=-77)
        self._global_properties.set_ion(
            ion="ca", int_con=0.0001, ext_con=2, rev_pot=132.5
        )
        self._global_properties.set_ion(
            ion="h", valence=1, int_con=1.0, ext_con=1.0, rev_pot=-34
        )
        self._global_properties.catalogue = self._get_catalogue()

    def _get_catalogue(self):
        """
        Return Arbor's default catalogue extended with the cell models' catalogues.

        Only the first catalogue seen for each prefix is merged in.
        """
        catalogue = arbor.default_catalogue()
        prefixes = set()
        for model in self._simulation.cell_models.values():
            prefix, model_catalogue = model.get_prefixed_catalogue()
            if model_catalogue is not None and prefix not in prefixes:
                prefixes.add(prefix)
                # NOTE(review): merged with an empty prefix, so ``prefix`` only
                # serves deduplication here. Confirm the model catalogue already
                # carries prefixed mechanism names.
                catalogue.extend(model_catalogue, "")

        return catalogue

    def global_properties(self, kind):
        """Return the shared global properties, whatever the cell kind."""
        return self._global_properties

    def num_cells(self):
        """Return the total number of placed cells over all cell models."""
        return sum(
            len(model.get_placement_set())
            for model in self._simulation.cell_models.values()
        )

    def cell_kind(self, gid):
        return self._simdata.gid_manager.lookup_kind(gid)

    def cell_description(self, gid):
        model = self._simdata.gid_manager.lookup_model(gid)
        return model.get_description(gid)

    def connections_on(self, gid):
        """Return the Arbor connections arriving on ``gid``."""
        return [
            arbor.connection(rcv.from_(), rcv.on(), rcv.weight, rcv.delay)
            for rcv in self._simdata.connections_on[gid]
        ]

    def gap_junctions_on(self, gid):
        """Return the gap junctions on ``gid``; an empty list if it has none."""
        return [
            c.model.gap_junction(c) for c in self._simdata.gap_junctions_on.get(gid, [])
        ]

    def probes(self, gid):
        """
        Collect the probes of every device targeting ``gid``.

        Each probe is registered with its device under its position in the
        returned list. The positions run on across devices, so two devices on the
        same cell never share a tag.
        """
        probes = []
        for device in self._simdata.devices_on[gid]:
            device_probes = device.implement_probes(self._simdata, gid)
            # Tags continue after the probes of previously handled devices.
            first_tag = len(probes)
            for tag in range(first_tag, first_tag + len(device_probes)):
                device.register_probe_id(gid, tag)
            probes.extend(device_probes)
        return probes

    def event_generators(self, gid):
        """Collect the event generators of every device targeting ``gid``."""
        devices = self._simdata.devices_on[gid]
        generators = []
        for device in devices:
            device_generators = device.implement_generators(self._simdata, gid)
            generators.extend(device_generators)
        return generators

    def _name_of(self, gid):
        return self._simdata.gid_manager.lookup_model(gid).cell_type.name
267
+
268
+
269
class ArborAdapter(SimulatorAdapter):
    """
    Simulator adapter that prepares and runs BSB simulations with Arbor.

    :meth:`prepare` builds an :class:`ArborRecipe`, partitions it over the Arbor
    execution context and creates the ``arbor.simulation``. :meth:`run` runs one
    prepared simulation and then discards its simulation data.
    """

    def __init__(self):
        super().__init__()
        # Simulation data per simulation, kept from `prepare` until `run` finishes.
        self.simdata: typing.Dict["ArborSimulation", "SimulationData"] = {}

    def get_rank(self):
        """Return the MPI rank of this process."""
        return MPI.get_rank()

    def get_size(self):
        """Return the number of MPI processes."""
        return MPI.get_size()

    def broadcast(self, data, root=0):
        """Broadcast ``data`` from ``root`` via ``MPI.bcast``."""
        return MPI.bcast(data, root)

    def barrier(self):
        """Synchronize the MPI processes via ``MPI.barrier``."""
        return MPI.barrier()

    def prepare(self, simulation: "ArborSimulation", comm=None):
        """
        Build the recipe, domain decomposition and Arbor simulation.

        :param simulation: The BSB simulation to prepare.
        :param comm: MPI communicator handed to Arbor. Defaults to
            ``MPI.get_communicator()``.
        :returns: The prepared simulation data, also stored in ``self.simdata``.
        :raises RuntimeError: If profiling is requested but Arbor was built
            without profiling support.
        """
        simdata = self._create_simdata(simulation)
        try:
            context = arbor.context(arbor.proc_allocation(threads=simulation.threads))
            if MPI.get_size() > 1:
                if not arbor.config()["mpi4py"]:
                    warn(
                        "Arbor does not seem to be built with MPI support, running "
                        f"duplicate simulations on {MPI.get_size()} nodes."
                    )
                else:
                    context = arbor.context(
                        arbor.proc_allocation(threads=simulation.threads),
                        mpi=comm or MPI.get_communicator(),
                    )
            if simulation.profiling:
                if arbor.config()["profiling"]:
                    report("enabling profiler", level=2)
                    arbor.profiler_initialize(context)
                else:
                    raise RuntimeError(
                        "Arbor must be built with profiling support to use the `profiling` flag."
                    )
            simdata.gid_manager = self.get_gid_manager(simulation, simdata)
            simdata.populations = simdata.gid_manager.get_populations()
            report("preparing simulation", level=1)
            report("MPI processes:", context.ranks, level=2)
            report("Threads per process:", context.threads, level=2)
            recipe = self.get_recipe(simulation, simdata)
            # Gap junctions are required for domain decomposition
            self.domain = arbor.partition_load_balance(recipe, context)
            self.gids = set(it.chain.from_iterable(g.gids for g in self.domain.groups))
            simdata.arbor_sim = arbor.simulation(recipe, context, self.domain)
            self.prepare_samples(simulation, simdata)
            report("prepared simulation", level=1)
            return simdata
        except Exception:
            # Don't leave half-prepared simulation data behind.
            del self.simdata[simulation]
            raise

    def get_gid_manager(self, simulation, simdata):
        """Create the :class:`GIDManager` that maps GIDs to cell models."""
        return GIDManager(simulation, simdata)

    def prepare_samples(self, simulation, simdata):
        """Let every device of ``simulation`` prepare its samples on ``simdata``."""
        for device in simulation.devices.values():
            device.prepare_samples(simdata)

    def run(self, *simulations):
        """
        Run a single prepared simulation.

        :returns: A one-element list with the simulation's result.
        :raises RuntimeError: If not exactly one simulation is given.
        :raises AdapterError: If the simulation was not prepared.
        """
        if len(simulations) != 1:
            raise RuntimeError(
                "Can not run multiple simultaneous simulations. Composition not implemented."
            )
        simulation = simulations[0]
        try:
            simdata = self.simdata[simulation]
            arbor_sim = simdata.arbor_sim
        except KeyError:
            raise AdapterError(
                f"Can't run unprepared simulation '{simulation.name}'"
            ) from None
        try:
            # Only rank 0 requests spike recording (``spike_recording.all``).
            if not MPI.get_rank():
                arbor_sim.record(arbor.spike_recording.all)

            start = time.time()
            report("running simulation", level=1)
            arbor_sim.run(simulation.duration, dt=simulation.resolution)
            report(f"completed simulation. {time.time() - start:.2f}s", level=1)
            if simulation.profiling and arbor.config()["profiling"]:
                report("printing profiler summary", level=2)
                report(arbor.profiler_summary(), level=1)
            return [simdata.result]
        finally:
            # The data is dropped after one run; running again requires `prepare`.
            del self.simdata[simulation]

    def get_recipe(self, simulation, simdata=None):
        """
        Cache gap junctions, connections and device targets on ``simdata`` and
        return the recipe built from it. Creates fresh simulation data if none
        is given.
        """
        if simdata is None:
            simdata = self._create_simdata(simulation)
        self._cache_gap_junctions(simulation, simdata)
        self._cache_connections(simulation, simdata)
        self._cache_devices(simulation, simdata)
        return ArborRecipe(simulation, simdata)

    def _create_simdata(self, simulation):
        # Use the Arbor-specific data class so `arbor_sim` always exists.
        self.simdata[simulation] = simdata = ArborSimulationData(simulation)
        self._assign_chunks(simulation, simdata)
        return simdata

    def _cache_gap_junctions(self, simulation, simdata):
        """Collect the gap junctions arriving in this rank's chunks, per GID."""
        simdata.gap_junctions_on = {}
        for conn_model in simulation.connection_models.values():
            if conn_model.gap:
                conn_set = conn_model.get_connectivity_set()
                conns = conn_set.load_connections().to(simdata.chunks).as_globals()
                conn_model.create_gap_junctions_on(simdata.gap_junctions_on, conns)

    def _cache_connections(self, simulation, simdata):
        """
        For every non-gap connection model, fill the per-GID incoming
        (``connections_on``) and outgoing (``connections_from``) collections.
        """
        simdata.connections_on = {
            gid: simdata.gid_manager.lookup_model(gid).make_receiver_collection()
            for gid in simdata.gid_manager.all()
        }
        simdata.connections_from = {gid: [] for gid in simdata.gid_manager.all()}
        for conn_model in simulation.connection_models.values():
            if conn_model.gap:
                continue
            conn_set = conn_model.get_connectivity_set()
            pop_pre, pop_post = None, None
            for model in simulation.cell_models.values():
                if model.cell_type is conn_set.pre_type:
                    pop_pre = simdata.populations[model]
                if model.cell_type is conn_set.post_type:
                    pop_post = simdata.populations[model]
            # Get the arriving connection iterator
            conns_on = conn_set.load_connections().to(simdata.chunks).as_globals()
            # Create the arriving connections
            conn_model.create_connections_on(
                simdata.connections_on, conns_on, pop_pre, pop_post
            )
            # Get the outgoing connection iterator
            conns_from = conn_set.load_connections().from_(simdata.chunks).as_globals()
            # Create the outgoing connections
            conn_model.create_connections_from(
                simdata.connections_from, conns_from, pop_pre, pop_post
            )

    def _cache_devices(self, simulation, simdata):
        """Map every GID to the list of devices that target it."""
        simdata.devices_on = {gid: [] for gid in simdata.gid_manager.all()}
        for device in simulation.devices.values():
            targets = device.targetting.get_targets(self, simulation, simdata)
            for target in itertools.chain.from_iterable(targets.values()):
                simdata.devices_on[target].append(device)

    def _assign_chunks(self, simulation, simdata):
        """Deal the stored chunks round-robin over the MPI ranks."""
        chunk_stats = simulation.scaffold.storage.get_chunk_stats()
        size = MPI.get_size()
        all_chunks = [Chunk.from_id(int(chunk), None) for chunk in chunk_stats.keys()]
        simdata.node_chunk_alloc = [all_chunks[rank::size] for rank in range(0, size)]
        simdata.chunk_node_map = {}
        for node, chunks in enumerate(simdata.node_chunk_alloc):
            for chunk in chunks:
                simdata.chunk_node_map[chunk] = node
        simdata.chunks = simdata.node_chunk_alloc[MPI.get_rank()]
428
+
429
+
430
+ def _all_bools(arr):
431
+ try:
432
+ return all(isinstance(b, bool) for b in arr)
433
+ except TypeError:
434
+ # Not iterable
435
+ return False
436
+
437
+
438
+ def _all_ints(arr):
439
+ try:
440
+ return all(isinstance(b, int) for b in arr)
441
+ except TypeError:
442
+ # Not iterable
443
+ return False