sinabs 3.0.4.dev25__py3-none-any.whl → 3.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. sinabs/activation/reset_mechanism.py +3 -3
  2. sinabs/activation/surrogate_gradient_fn.py +4 -4
  3. sinabs/backend/dynapcnn/__init__.py +5 -4
  4. sinabs/backend/dynapcnn/chip_factory.py +33 -61
  5. sinabs/backend/dynapcnn/chips/dynapcnn.py +182 -86
  6. sinabs/backend/dynapcnn/chips/speck2e.py +6 -5
  7. sinabs/backend/dynapcnn/chips/speck2f.py +6 -5
  8. sinabs/backend/dynapcnn/config_builder.py +39 -59
  9. sinabs/backend/dynapcnn/connectivity_specs.py +48 -0
  10. sinabs/backend/dynapcnn/discretize.py +91 -155
  11. sinabs/backend/dynapcnn/dvs_layer.py +59 -101
  12. sinabs/backend/dynapcnn/dynapcnn_layer.py +185 -119
  13. sinabs/backend/dynapcnn/dynapcnn_layer_utils.py +335 -0
  14. sinabs/backend/dynapcnn/dynapcnn_network.py +602 -325
  15. sinabs/backend/dynapcnn/dynapcnnnetwork_module.py +370 -0
  16. sinabs/backend/dynapcnn/exceptions.py +122 -3
  17. sinabs/backend/dynapcnn/io.py +51 -91
  18. sinabs/backend/dynapcnn/mapping.py +111 -75
  19. sinabs/backend/dynapcnn/nir_graph_extractor.py +877 -0
  20. sinabs/backend/dynapcnn/sinabs_edges_handler.py +1024 -0
  21. sinabs/backend/dynapcnn/utils.py +214 -459
  22. sinabs/backend/dynapcnn/weight_rescaling_methods.py +53 -0
  23. sinabs/conversion.py +2 -2
  24. sinabs/from_torch.py +23 -1
  25. sinabs/hooks.py +38 -41
  26. sinabs/layers/alif.py +16 -16
  27. sinabs/layers/crop2d.py +2 -2
  28. sinabs/layers/exp_leak.py +1 -1
  29. sinabs/layers/iaf.py +11 -11
  30. sinabs/layers/lif.py +9 -9
  31. sinabs/layers/neuromorphic_relu.py +9 -8
  32. sinabs/layers/pool2d.py +5 -5
  33. sinabs/layers/quantize.py +1 -1
  34. sinabs/layers/stateful_layer.py +10 -7
  35. sinabs/layers/to_spike.py +9 -9
  36. sinabs/network.py +14 -12
  37. sinabs/synopcounter.py +10 -7
  38. sinabs/utils.py +155 -7
  39. sinabs/validate_memory_speck.py +0 -5
  40. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.0.dist-info}/METADATA +2 -1
  41. sinabs-3.1.0.dist-info/RECORD +65 -0
  42. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.0.dist-info}/licenses/AUTHORS +1 -0
  43. sinabs-3.1.0.dist-info/pbr.json +1 -0
  44. sinabs-3.0.4.dev25.dist-info/RECORD +0 -59
  45. sinabs-3.0.4.dev25.dist-info/pbr.json +0 -1
  46. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.0.dist-info}/WHEEL +0 -0
  47. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.0.dist-info}/licenses/LICENSE +0 -0
  48. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.0.dist-info}/top_level.txt +0 -0
sinabs/backend/dynapcnn/nir_graph_extractor.py
@@ -0,0 +1,877 @@
+from copy import deepcopy
+from pprint import pformat
+from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+
+import nirtorch
+import torch
+import torch.nn as nn
+
+from sinabs import layers as sl
+from sinabs.utils import get_new_index
+
+from .connectivity_specs import (
+    LAYER_TYPES_WITH_MULTIPLE_INPUTS,
+    LAYER_TYPES_WITH_MULTIPLE_OUTPUTS,
+    SupportedNodeTypes,
+)
+from .dvs_layer import DVSLayer
+from .dynapcnn_layer_utils import construct_dynapcnnlayers_from_mapper
+from .dynapcnnnetwork_module import DynapcnnNetworkModule
+from .exceptions import InvalidGraphStructure, UnsupportedLayerType
+from .sinabs_edges_handler import (
+    collect_dynapcnn_layer_info,
+    fix_dvs_module_edges,
+    handle_batchnorm_nodes,
+)
+from .utils import Edge, topological_sorting
+from warnings import warn
+
+try:
+    from nirtorch.graph import TorchGraph
+except ImportError:
+    # In older nirtorch versions TorchGraph is called Graph
+    from nirtorch.graph import Graph as TorchGraph
+
+
+class GraphExtractor:
+    def __init__(
+        self,
+        spiking_model: nn.Module,
+        dummy_input: torch.Tensor,
+        dvs_input: Optional[bool] = None,
+        ignore_node_types: Optional[Iterable[Type]] = None,
+    ):
+        """Class implementing the extraction of the computational graph from `spiking_model`,
+        where each node represents a layer in the model and the set of edges represents how
+        the data flows between the layers.
+
+        Args:
+            spiking_model (nn.Module): a sinabs-compatible spiking network.
+            dummy_input (torch.Tensor): an input sample to be fed through
+                the model to acquire both the computational graph (via
+                `nirtorch`) and the I/O shapes of each node. It is a 4-D
+                tensor with shape `(batch, channels, height, width)`.
+            dvs_input (bool): optional (defaults to `None`). Whether or not the
+                model should start with a `DVSLayer`.
+            ignore_node_types (iterable of types): Node types that should be
+                ignored completely from the graph. This can include, for
+                instance, `nn.Dropout2d`, which otherwise can result in wrongly
+                inferred graph structures by NIRTorch. Types such as
+                `nn.Flatten` or sinabs `Merge` should not be included here, as
+                they are needed to properly handle graph structure and
+                metadata. They can be removed after instantiation with
+                `remove_nodes_by_class`.
+
+        Attributes:
+            edges (set of 2-tuples of integers): Tuples describing the
+                connections between layers in `spiking_model`. Each layer
+                (node) is identified by a unique integer ID.
+            name_2_indx_map (dict): Keys are original variable names of layers
+                in `spiking_model`. Values are unique integer IDs.
+            entry_nodes (set of ints): IDs of nodes acting as entry points for
+                the network, i.e. receiving external input.
+            indx_2_module_map (dict): Map from layer ID to the corresponding
+                nn.Module instance.
+            nodes_io_shapes (dict): Map from node ID to a dict containing the
+                node's in- and output shapes.
+            dvs_input (bool): optional (defaults to `None`). Whether or not the
+                model should start with a `DVSLayer`.
+            ignore_node_types (iterable of types): Node types that should be
+                ignored completely from the graph. This can include, for
+                instance, `nn.Dropout2d`, which otherwise can result in wrongly
+                inferred graph structures by NIRTorch. Types such as
+                `nn.Flatten` or sinabs `Merge` should not be included here, as
+                they are needed to properly handle graph structure and
+                metadata. They can be removed after instantiation with
+                `remove_nodes_by_class`.
+        """
+
+        # Store state before it is changed due to NIRTorch and
+        # `self._get_nodes_io_shapes` passing dummy input
+        original_state = {
+            n: b.detach().clone() for n, b in spiking_model.named_buffers()
+        }
+
+        self._edges = set()
+        # Empty sequentials will cause nirtorch to fail. Treat this case separately
+        if isinstance(spiking_model, nn.Sequential) and len(spiking_model) == 0:
+            self._name_2_indx_map = dict()
+            self._edges = set()
+            original_state = {}
+        else:
+            nir_graph = nirtorch.graph.extract_torch_graph(
+                spiking_model, dummy_input, model_name=None
+            ).ignore_tensors()
+
+            if ignore_node_types is not None:
+                for node_type in ignore_node_types:
+                    nir_graph = nir_graph.ignore_nodes(node_type)
+
+            # Map node names to indices
+            self._name_2_indx_map = self._get_name_2_indx_map(nir_graph)
+
+            # Extract edges list from graph
+            self._edges = self._get_edges_from_nir(nir_graph, self._name_2_indx_map)
+
+        # Store the associated `nn.Module` (layer) of each node.
+        self._indx_2_module_map = self._get_name2module_map(spiking_model)
+
+        if len(self._name_2_indx_map) > 0:
+            # Merge BatchNorm2d/BatchNorm1d nodes with Conv2d/Linear ones.
+            handle_batchnorm_nodes(
+                self._edges, self._indx_2_module_map, self._name_2_indx_map
+            )
+
+        # Determine entry points to graph
+        self._entry_nodes = self._get_entry_nodes(self._edges)
+
+        # Make sure DVS input is properly integrated into graph
+        self._handle_dvs_input(input_shape=dummy_input.shape[1:], dvs_input=dvs_input)
+
+        # Retrieve the I/O shapes of each node's module.
+        self._nodes_io_shapes = self._get_nodes_io_shapes(dummy_input)
+
+        # Restore original state - after forward passes from nirtorch and `_get_nodes_io_shapes`
+        for n, b in spiking_model.named_buffers():
+            b.set_(original_state[n].clone())
+
+        # Verify that graph is compatible
+        self.verify_graph_integrity()
+
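For orientation, here is a minimal sketch of how the new extractor might be driven, assuming the import path shown in this diff (the node IDs in the comments are illustrative, since the integers assigned via nirtorch can vary):

import torch
import torch.nn as nn
from sinabs.layers import IAFSqueeze
from sinabs.backend.dynapcnn.nir_graph_extractor import GraphExtractor

# A small sinabs-compatible SNN: weight layer -> spiking layer -> pooling
snn = nn.Sequential(
    nn.Conv2d(2, 8, kernel_size=3, bias=False),
    IAFSqueeze(batch_size=1),
    nn.AvgPool2d(2),
)

# 4-D dummy input of shape (batch, channels, height, width)
extractor = GraphExtractor(snn, dummy_input=torch.rand((1, 2, 34, 34)))

print(extractor.edges)        # e.g. {(0, 1), (1, 2)}
print(extractor.entry_nodes)  # e.g. {0}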
+    @property
+    def dvs_layer(self) -> Union[DVSLayer, None]:
+        idx = self.dvs_node_id
+        if idx is None:
+            return None
+        else:
+            return self.indx_2_module_map[idx]
+
+    @property
+    def dvs_node_id(self) -> Union[int, None]:
+        return self._get_dvs_node_id()
+
+    @property
+    def entry_nodes(self) -> Set[int]:
+        return {n for n in self._entry_nodes}
+
+    @property
+    def edges(self) -> Set[Edge]:
+        return {(src, tgt) for src, tgt in self._edges}
+
+    @property
+    def has_dvs_layer(self) -> bool:
+        return self.dvs_layer is not None
+
+    @property
+    def name_2_indx_map(self) -> Dict[str, int]:
+        return {name: idx for name, idx in self._name_2_indx_map.items()}
+
+    @property
+    def nodes_io_shapes(self) -> Dict[int, Tuple[torch.Size]]:
+        return {n: size for n, size in self._nodes_io_shapes.items()}
+
+    @property
+    def sorted_nodes(self) -> List[int]:
+        return [n for n in self._sort_graph_nodes()]
+
+    @property
+    def indx_2_module_map(self) -> Dict[int, nn.Module]:
+        return {n: module for n, module in self._indx_2_module_map.items()}
+
+    def get_dynapcnn_network_module(
+        self, discretize: bool = True, weight_rescaling_fn: Optional[Callable] = None
+    ) -> DynapcnnNetworkModule:
+        """Create a DynapcnnNetworkModule based on the stored graph representation.
+
+        This includes construction of the DynapcnnLayer instances.
+
+        Args:
+            discretize (bool): If `True`, discretize the parameters and
+                thresholds. This is needed for uploading weights to DynapCNN.
+                Set to `False` only for testing purposes.
+            weight_rescaling_fn (callable): a method that handles how the
+                re-scaling factors of one or more `SumPool2d` layers projecting
+                to the same convolutional layer are combined before being applied.
+
+        Returns:
+            The `DynapcnnNetworkModule` based on the graph representation of this `GraphExtractor`.
+
+        """
+        # Make sure all nodes are supported and there are no isolated nodes.
+        self.verify_node_types()
+        self.verify_no_isolated_nodes()
+
+        # create a dict holding the data necessary to instantiate a `DynapcnnLayer`.
+        self.dcnnl_info, self.dvs_layer_info = collect_dynapcnn_layer_info(
+            indx_2_module_map=self.indx_2_module_map,
+            edges=self.edges,
+            nodes_io_shapes=self.nodes_io_shapes,
+            entry_nodes=self.entry_nodes,
+        )
+
+        # Special case where there is a disconnected `DVSLayer`: There are no
+        # edges for the edges handler to process. Instantiate layer info manually.
+        if self.dvs_layer_info is None and self.dvs_layer is not None:
+            self.dvs_layer_info = {
+                "node_id": self.dvs_node_id,
+                "input_shape": self.nodes_io_shapes[self.dvs_node_id]["input"],
+                "module": self.dvs_layer,
+                "pooling": None,
+                "destinations": None,
+            }
+
+        # build `DynapcnnLayer` instances from mapper.
+        (
+            dynapcnn_layers,
+            destination_map,
+            entry_points,
+        ) = construct_dynapcnnlayers_from_mapper(
+            dcnnl_map=self.dcnnl_info,
+            dvs_layer_info=self.dvs_layer_info,
+            discretize=discretize,
+            rescale_fn=weight_rescaling_fn,
+        )
+
+        # Instantiate the DynapcnnNetworkModule
+        return DynapcnnNetworkModule(
+            dynapcnn_layers, destination_map, entry_points, self.dvs_layer_info
+        )
+
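Continuing the sketch above, the extracted graph can then be turned into a `DynapcnnNetworkModule`. The `mean_rescale` combiner below is a hypothetical stand-in written for illustration; the release also adds ready-made options in `weight_rescaling_methods.py` (file 22 in this diff):

def mean_rescale(rescaling_factors):
    # Hypothetical combiner: average the rescaling factors of all pooling
    # layers that project to the same convolutional layer.
    return sum(rescaling_factors) / len(rescaling_factors)

dynapcnn_module = extractor.get_dynapcnn_network_module(
    discretize=True, weight_rescaling_fn=mean_rescale
)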
+    def remove_nodes_by_class(self, node_classes: Tuple[Type]):
+        """
+        Remove all nodes of the specified classes from the graph, and update the graph structure in place.
+        When nodes are removed, the function rewires the graph to maintain data flow continuity.
+        Specifically, any incoming edges to the removed nodes are redirected to their valid successors,
+        effectively bypassing the removed nodes.
+
+        Example:
+            If the original structure is A → B → C, and B is removed, the resulting connection becomes A → C.
+            If an entry node (a node without predecessors) is removed, its valid successors become new entry nodes.
+
+        Special handling:
+            For `Flatten` layers: the function updates the input shape of successor nodes to match the
+            shape before flattening. Note that this may lead to incorrect results if multiple `Flatten`
+            layers appear consecutively.
+
+        After modification, the function updates:
+            - The edge set (`self._edges`)
+            - The entry nodes (`self._entry_nodes`)
+            - The node index mapping (`self._name_2_indx_map`)
+            - The stored input/output shapes (`self._nodes_io_shapes`)
+
+        Args:
+            node_classes (tuple of types): The types of nodes (layers) to remove from the graph.
+        """
+
+        # Build a mapping of remaining nodes to their valid successors
+        # (i.e., successors not belonging to the classes to be removed)
+        source2target: Dict[int, Set[int]] = {}
+
+        for node in self.sorted_nodes:
+            mod = self.indx_2_module_map[node]
+            if isinstance(mod, node_classes):
+                # If an entry node is removed, promote its valid successors as new entry nodes
+                if node in self.entry_nodes:
+                    targets = self._find_valid_targets(node, node_classes)
+                    self._entry_nodes.update(targets)
+
+                # If removing a Flatten layer, propagate the pre-flatten shape
+                # to the next nodes to maintain correct input shape metadata
+                if isinstance(mod, nn.Flatten):
+                    shape_before_flatten = self.nodes_io_shapes[node]["input"]
+                    for target_node in self._find_valid_targets(node, node_classes):
+                        self._nodes_io_shapes[target_node][
+                            "input"
+                        ] = shape_before_flatten
+            else:
+                # Retain this node and connect it to its valid successors
+                source2target[node] = self._find_valid_targets(node, node_classes)
+
+        # Reassign new contiguous indices to the remaining nodes (starting from 0)
+        remapped_nodes = {
+            old_idx: new_idx
+            for new_idx, old_idx in enumerate(sorted(source2target.keys()))
+        }
+
+        # Rebuild the edge set using the remapped indices
+        self._edges = {
+            (remapped_nodes[src], remapped_nodes[tgt])
+            for src, targets in source2target.items()
+            for tgt in targets
+        }
+
+        # Synchronize all internal structures with the updated graph topology
+        self._update_internal_representation(remapped_nodes)
+
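A sketch of the rewiring described in the docstring, assuming the extractor from the earlier example had an `nn.Dropout2d` as node 1 between the convolutional and spiking layers:

# Edges before: {(0, 1), (1, 2)}, with node 1 an nn.Dropout2d
extractor.remove_nodes_by_class((nn.Dropout2d,))
print(extractor.edges)  # {(0, 1)}: node 1 is bypassed, old node 2 becomes node 1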
+    def get_node_io_shapes(self, node: int) -> Tuple[torch.Size, torch.Size]:
+        """Returns the shapes of the I/O tensors of `node`.
+
+        Returns:
+            input shape (torch.Size): shape of the input tensor to `node`.
+            output shape (torch.Size): shape of the output tensor from `node`.
+        """
+        return (
+            self._nodes_io_shapes[node]["input"],
+            self._nodes_io_shapes[node]["output"],
+        )
+
+    def verify_graph_integrity(self):
+        """
+        Apply consistency checks to verify that the graph structure is valid and supported.
+        This function ensures:
+        - No isolated nodes exist (except for `DVSLayer` instances).
+        - Only certain layer types are allowed to have multiple inputs or outputs.
+        - No self-recurrent edges are present (edges of the form `(x, x)`).
+        Raises:
+            InvalidGraphStructure: If any of the integrity checks fail.
+        """
+        # --- Check for self-recurrent edges -------------------------------------
+        self_recurrent_edges = {(src, tgt) for (src, tgt) in self.edges if src == tgt}
+        if self_recurrent_edges:
+            raise InvalidGraphStructure(
+                f"The graph contains self-recurrent edges: {self_recurrent_edges}. "
+                "Recurrent connections (edges of the form (x, x)) are not supported."
+            )
+        # --- Check node connectivity and input/output structure -----------------
+        for node, module in self.indx_2_module_map.items():
+            # Ensure there are no isolated (unconnected) nodes, except for DVSLayer instances
+            edges_with_node = {e for e in self.edges if node in e}
+            if not edges_with_node and not isinstance(module, DVSLayer):
+                raise InvalidGraphStructure(
+                    f"There is an isolated module of type {type(module)} (node {node}). "
+                    "Only `DVSLayer` instances can be completely disconnected from "
+                    "the rest of the network. Other than that, DynapCNN layers consist "
+                    "of groups of weight layers (`Linear` or `Conv2d`), spiking layers "
+                    "(`IAF` or `IAFSqueeze`), and optionally pooling layers "
+                    "(`SumPool2d`, `AvgPool2d`)."
+                )
+            # Ensure only specific module types can have multiple inputs
+            if not isinstance(module, LAYER_TYPES_WITH_MULTIPLE_INPUTS):
+                sources = self._find_all_sources_of_input_to(node)
+                if len(sources) > 1:
+                    raise InvalidGraphStructure(
+                        f"Node {node} of type {type(module)} has {len(sources)} inputs, "
+                        f"but only nodes of type {LAYER_TYPES_WITH_MULTIPLE_INPUTS} "
+                        "are allowed to have multiple inputs."
+                    )
+            # Ensure only specific module types can have multiple outputs
+            if not isinstance(module, LAYER_TYPES_WITH_MULTIPLE_OUTPUTS):
+                targets = self._find_valid_targets(node)
+                if len(targets) > 1:
+                    raise InvalidGraphStructure(
+                        f"Node {node} of type {type(module)} has {len(targets)} outputs, "
+                        f"but only nodes of type {LAYER_TYPES_WITH_MULTIPLE_OUTPUTS} "
+                        "are allowed to have multiple outputs."
+                    )
+
+    def verify_node_types(self):
+        """Verify that all nodes are of a supported type.
+
+        Raises:
+            UnsupportedLayerType: If any verification fails.
+        """
+        unsupported_nodes = dict()
+        for index, module in self.indx_2_module_map.items():
+            if not isinstance(module, SupportedNodeTypes):
+                node_type = type(module)
+                if node_type in unsupported_nodes:
+                    unsupported_nodes[node_type].add(index)
+                else:
+                    unsupported_nodes[node_type] = {index}
+        # Specific error message for non-squeezing IAF layer
+        iaf_layers = []
+        for idx in unsupported_nodes.pop(sl.IAF, []):
+            iaf_layers.append(self.indx_2_module_map[idx])
+        if iaf_layers:
+            layer_str = ", ".join(str(lyr) for lyr in iaf_layers)
+            raise UnsupportedLayerType(
+                f"The provided SNN contains IAF layers:\n{layer_str}.\n"
+                "For compatibility with torch's `nn.Conv2d` modules, please "
+                "use `IAFSqueeze` layers instead."
+            )
+        # Specific error message for leaky neuron types
+        lif_layers = []
+        for lif_type in (sl.LIF, sl.LIFSqueeze):
+            for idx in unsupported_nodes.pop(lif_type, []):
+                lif_layers.append(self.indx_2_module_map[idx])
+        if lif_layers:
+            layer_str = ", ".join(str(lyr) for lyr in lif_layers)
+            raise UnsupportedLayerType(
+                f"The provided SNN contains LIF layers:\n{layer_str}.\n"
+                "Leaky Integrate-and-Fire dynamics are not supported by "
+                "DynapCNN. Use non-leaky `IAF` or `IAFSqueeze` layers "
+                "instead."
+            )
+        # Specific error message for most common non-spiking activation layers
+        activation_layers = []
+        for activation_type in (nn.ReLU, nn.Sigmoid, nn.Tanh, sl.NeuromorphicReLU):
+            for idx in unsupported_nodes.pop(activation_type, []):
+                activation_layers.append(self.indx_2_module_map[idx])
+        if activation_layers:
+            layer_str = ", ".join(str(lyr) for lyr in activation_layers)
+            raise UnsupportedLayerType(
+                "The provided SNN contains non-spiking activation layers:\n"
+                f"{layer_str}.\nPlease convert them to `IAF` or `IAFSqueeze` "
+                "layers before instantiating a `DynapcnnNetwork`. You can "
+                "use the function `sinabs.from_torch.from_model` for this."
+            )
+        if unsupported_nodes:
+            # More generic error message for all remaining types
+            raise UnsupportedLayerType(
+                "One or more layers in the provided SNN are not supported: "
+                f"{pformat(unsupported_nodes)}. Supported layer types are: "
+                f"{pformat(SupportedNodeTypes)}."
+            )
+
+    def verify_no_isolated_nodes(self):
+        """Verify that there are no disconnected nodes except for `DVSLayer` instances.
+
+        Raises:
+            InvalidGraphStructure: when disconnected nodes are detected.
+        """
+        for node, module in self.indx_2_module_map.items():
+            # Make sure there are no individual, unconnected nodes
+            edges_with_node = {e for e in self.edges if node in e}
+            if not edges_with_node and not isinstance(module, DVSLayer):
+                raise InvalidGraphStructure(
+                    f"There is an isolated module of type {type(module)}. Only "
+                    "`DVSLayer` instances can be completely disconnected from "
+                    "any other module. Other than that, layers for DynapCNN "
+                    "consist of groups of weight layers (`Linear` or `Conv2d`), "
+                    "spiking layers (`IAF` or `IAFSqueeze`), and optionally "
+                    "pooling layers (`SumPool2d`, `AvgPool2d`)."
+                )
+
+    def _handle_dvs_input(
+        self, input_shape: Tuple[int, int, int], dvs_input: Optional[bool] = None
+    ):
+        """Make sure DVS input is properly integrated into graph.
+
+        Decide whether a `DVSLayer` instance needs to be added to the graph. This
+        is the case when `dvs_input==True` and there is no `DVSLayer` yet.
+        Make sure edges between DVS related nodes are set properly.
+        Absorb pooling layers in DVS node if applicable.
+
+        Args:
+            input_shape (tuple of three integers): Input shape (features,
+                height, width).
+            dvs_input (bool or `None` (default)): If `False`, will raise
+                `InvalidModelWithDvsSetup` if a `DVSLayer` is part of the
+                graph. If `True`, a `DVSLayer` will be added to the graph if
+                there is none already. If `None`, the model is considered to be
+                using DVS input only if the graph contains a `DVSLayer`.
+        """
+        if self.has_dvs_layer:
+            # Make a copy of the layer so that the original version is not
+            # changed in place
+            new_dvs_layer = deepcopy(self.dvs_layer)
+            self._indx_2_module_map[self.dvs_node_id] = new_dvs_layer
+        elif dvs_input:
+            # Insert a DVSLayer node in the graph.
+            new_dvs_layer = self._add_dvs_node(dvs_input_shape=input_shape)
+        else:
+            dvs_input = None
+        if dvs_input is not None:
+            # Disable pixel array if `dvs_input` is False
+            new_dvs_layer.disable_pixel_array = not dvs_input
+
+        # Fix the NIR edge extraction if needed when the DVS is a node in the graph.
+        # If the DVS is used, its node becomes the only entry node in the graph.
+        fix_dvs_module_edges(
+            self._edges,
+            self._indx_2_module_map,
+            self._name_2_indx_map,
+            self._entry_nodes,
+        )
+
+        # Check if graph structure and DVSLayer.merge_polarities are correctly set (if DVS node exists).
+        self._validate_dvs_setup(dvs_input_shape=input_shape)
+
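The effect of the `dvs_input` flag can be sketched as follows, reusing the model from the first example (the printed ID is illustrative; `get_new_index` simply hands out the next free integer):

extractor = GraphExtractor(
    snn, dummy_input=torch.rand((1, 2, 128, 128)), dvs_input=True
)
print(extractor.has_dvs_layer)  # True
print(extractor.entry_nodes)    # e.g. {3}: the newly added DVS node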
+    def _add_dvs_node(self, dvs_input_shape: Tuple[int, int, int]) -> DVSLayer:
+        """In-place modification of `self._name_2_indx_map`,
+        `self._indx_2_module_map`, and `self._edges` to accommodate the creation
+        of an extra node in the graph representing the DVS camera of the chip.
+        The DVSLayer node will point to every other node that is up to this
+        point an entry node of the original graph, so `self._entry_nodes` is
+        modified in-place to have only one entry: the index of the DVS node.
+
+        Args:
+            dvs_input_shape (tuple): shape of the DVSLayer input in format
+                `(features, height, width)`.
+
+        Returns:
+            A handle to the newly added `DVSLayer` instance.
+        """
+
+        (features, height, width) = dvs_input_shape
+        if features > 2:
+            raise ValueError(
+                f"A DVSLayer instance can have a max feature dimension of 2 but {features} was given."
+            )
+
+        # Find new index to be assigned to DVS node
+        self._name_2_indx_map["dvs"] = get_new_index(self._name_2_indx_map.values())
+        # add module entry for node 'dvs'.
+        dvs_layer = DVSLayer(
+            input_shape=(height, width),
+            merge_polarities=(features == 1),
+        )
+        self._indx_2_module_map[self._name_2_indx_map["dvs"]] = dvs_layer
+
+        # set DVS node as input to each entry node of the graph
+        self._edges.update(
+            {
+                (self._name_2_indx_map["dvs"], entry_node)
+                for entry_node in self._entry_nodes
+            }
+        )
+        # DVSLayer node becomes the only entry point of the graph
+        self._entry_nodes = {self._name_2_indx_map["dvs"]}
+
+        return dvs_layer
+
+    def _get_dvs_node_id(self) -> Union[int, None]:
+        """Return the index of the `DVSLayer`
+        instance if it exists.
+
+        Returns:
+            Index of the `DVSLayer` node if exactly one is found, otherwise `None`.
+
+        Raises:
+            InvalidGraphStructure: if more than one DVSLayer is found.
+
+        """
+
+        dvs_layer_indices = {
+            index
+            for index, module in self._indx_2_module_map.items()
+            if isinstance(module, DVSLayer)
+        }
+
+        if (num_dvs := len(dvs_layer_indices)) == 0:
+            return
+        elif num_dvs == 1:
+            return dvs_layer_indices.pop()
+        else:
+            raise InvalidGraphStructure(
+                f"The provided model has {num_dvs} `DVSLayer`s. At most one is allowed."
+            )
+
+    def _validate_dvs_setup(self, dvs_input_shape: Tuple[int, int, int]) -> None:
+        """If a DVSLayer node exists, make sure it is the only entry node of
+        the graph. Check if its `merge_polarities` attribute matches
+        `dummy_input.shape[0]` (the number of features) and, if not, set it
+        based on the number of features of the input.
+
+        Args:
+            dvs_input_shape (tuple): shape of the DVSLayer input in format
+                `(features, height, width)`.
+        """
+
+        if self.dvs_layer is None:
+            # No DVSLayer found - nothing to do here.
+            return
+
+        if (nb_entries := len(self._entry_nodes)) > 1:
+            raise ValueError(
+                f"A DVSLayer node exists and there are {nb_entries} entry nodes in the graph: the DVSLayer should be the only entry node."
+            )
+
+        (features, _, _) = dvs_input_shape
+
+        if features > 2:
+            raise ValueError(
+                f"A DVSLayer instance expects an input feature dimension of 1 or 2, but {features} was given."
+            )
+
+        if self.dvs_layer.merge_polarities and features != 1:
+            raise ValueError(
+                f"'DVSLayer.merge_polarities' is set to 'True', which means the number of input features should be 1 (the current input shape is {dvs_input_shape})."
+            )
+
+        if features == 1:
+            self.dvs_layer.merge_polarities = True
+
+    def _get_name_2_indx_map(self, nir_graph: TorchGraph) -> Dict[str, int]:
+        """Assign a unique index to each node and return a mapper from name to
+        index.
+
+        Args:
+            nir_graph (TorchGraph): a NIR graph representation of
+                `spiking_model`.
+
+        Returns:
+            A dictionary where `key` is the original variable name for a layer
+            in `spiking_model` and `value` is an integer representing the layer
+            in a standard format.
+        """
+
+        return {
+            node.name: node_idx for node_idx, node in enumerate(nir_graph.node_list)
+        }
+
+    def _get_edges_from_nir(
+        self, nir_graph: TorchGraph, name_2_indx_map: Dict[str, int]
+    ) -> Set[Edge]:
+        """Standardize the representation of a `TorchGraph` into a set of edges,
+        representing nodes by their indices.
+
+        Args:
+            nir_graph (TorchGraph): a NIR graph representation of
+                `spiking_model`.
+            name_2_indx_map (dict): Map from node names to unique indices.
+
+        Returns:
+            Set of edges describing the connections between layers in `spiking_model`.
+        """
+        return {
+            (name_2_indx_map[src.name], name_2_indx_map[tgt.name])
+            for src in nir_graph.node_list
+            for tgt in src.outgoing_nodes
+        }
+
+    def _get_entry_nodes(self, edges: Set[Edge]) -> Set[int]:
+        """Find nodes that act as entry points to the graph.
+
+        Args:
+            edges (set): tuples describing the connections between layers in
+                `spiking_model`.
+
+        Returns:
+            IDs of nodes acting as entry points for the network (i.e.,
+            receiving external input).
+        """
+        if not edges:
+            return set()
+
+        all_sources, all_targets = zip(*edges)
+        return set(all_sources) - set(all_targets)
+
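The entry-node computation above is plain set arithmetic over the edge set: any node that appears as a source but never as a target must receive external input. A worked example:

edges = {(0, 1), (1, 2), (3, 1)}
sources, targets = zip(*edges)
print(set(sources) - set(targets))  # {0, 3}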
+    def _get_name2module_map(self, model: nn.Module) -> Dict[int, nn.Module]:
+        """Find, for each node in the graph, its associated layer in `model`.
+
+        Args:
+            model (nn.Module): the `spiking_model` used as argument to the class
+                instance.
+        Returns:
+            The mapping between a node (`key` as an `int`) and its module
+            (`value` as an `nn.Module`).
+        """
+
+        indx_2_module_map = dict()
+
+        for name, module in model.named_modules():
+            # Make sure names match those provided by nirtorch nodes
+            if name in self._name_2_indx_map:
+                indx_2_module_map[self._name_2_indx_map[name]] = module
+            else:
+                # In older nirtorch versions, node names are "sanitized".
+                # Try with the sanitized version of the name
+                name = nirtorch.utils.sanitize_name(name)
+                if name in self._name_2_indx_map:
+                    indx_2_module_map[self._name_2_indx_map[name]] = module
+
+        return indx_2_module_map
+
+    def _update_internal_representation(self, remapped_nodes: Dict[int, int]):
+        """Update internal attributes after remapping of nodes.
+
+        Args:
+            remapped_nodes (dict): Maps previous (key) to new (value) node
+                indices. Nodes that were removed are not included.
+        """
+
+        if len(self._name_2_indx_map) > 0:
+            # Update name-to-index map based on new node indices
+            self._name_2_indx_map = {
+                name: remapped_nodes[old_idx]
+                for name, old_idx in self._name_2_indx_map.items()
+                if old_idx in remapped_nodes
+            }
+
+        # Update entry nodes based on new node indices
+        self._entry_nodes = {
+            remapped_nodes[old_idx]
+            for old_idx in self._entry_nodes
+            if old_idx in remapped_nodes
+        }
+
+        # Update io-shapes based on new node indices
+        self._nodes_io_shapes = {
+            remapped_nodes[old_idx]: shape
+            for old_idx, shape in self._nodes_io_shapes.items()
+            if old_idx in remapped_nodes
+        }
+
+        # Update sinabs module map based on new node indices
+        self._indx_2_module_map = {
+            remapped_nodes[old_idx]: module
+            for old_idx, module in self._indx_2_module_map.items()
+            if old_idx in remapped_nodes
+        }
+
+    def _sort_graph_nodes(self) -> List[int]:
+        """Sort graph nodes topologically.
+
+        Returns:
+            sorted_nodes (list of integers): IDs of nodes, sorted.
+        """
+        # Make a temporary copy of edges and include inputs
+        temp_edges = self.edges
+        for node in self._entry_nodes:
+            temp_edges.add(("input", node))
+        return topological_sorting(temp_edges)
+
+    def _get_nodes_io_shapes(
+        self, input_dummy: torch.Tensor
+    ) -> Dict[int, Dict[str, torch.Size]]:
+        """Iteratively calls the forward method of each `nn.Module` (i.e., a
+        layer/node in the graph) using the topologically sorted nodes extracted
+        from the computational graph of the model being parsed.
+
+        Args:
+            input_dummy (torch.Tensor): a sample (random) tensor of the sort of
+                input being fed to the network.
+
+        Returns:
+            A dictionary mapping nodes to their I/O shapes.
+        """
+        nodes_io_map = {}
+
+        # propagate inputs through the nodes.
+        for node in self.sorted_nodes:
+            if isinstance(self.indx_2_module_map[node], sl.merge.Merge):
+                # find `Merge` arguments (at this point the inputs to Merge should have been calculated).
+                input_nodes = self._find_merge_arguments(node)
+
+                # retrieve the arguments' output tensors.
+                inputs = [nodes_io_map[n]["output"] for n in input_nodes]
+
+                # TODO - this is currently a limitation imposed by the validation checks Speck performs
+                # on a configuration: it requires different input sources to a core to have the same output shapes.
+                if any(inp.shape != inputs[0].shape for inp in inputs):
+                    raise ValueError(
+                        f"Layer `sinabs.layers.merge.Merge` (node {node}) requires input tensors with the same shape"
+                    )
+
+                # forward input through the node.
+                _output = self.indx_2_module_map[node](*inputs)
+
+                # save node's I/O tensors.
+                nodes_io_map[node] = {"input": inputs[0], "output": _output}
+
+            else:
+                if node in self._entry_nodes:
+                    # forward input dummy through node.
+                    _output = self.indx_2_module_map[node](input_dummy)
+
+                    # save node's I/O tensors.
+                    nodes_io_map[node] = {"input": input_dummy, "output": _output}
+
+                else:
+                    # find node generating the input to be used.
+                    input_node = self._find_source_of_input_to(node)
+                    _input = nodes_io_map[input_node]["output"]
+
+                    # forward input through the node.
+                    _output = self.indx_2_module_map[node](_input)
+
+                    # save node's I/O tensors.
+                    nodes_io_map[node] = {"input": _input, "output": _output}
+
+        # replace the I/O tensor information by its shape information, ignoring the batch/time axis
+        for node, io in nodes_io_map.items():
+            input_shape = io["input"].shape[1:]
+            output_shape = io["output"].shape[1:]
+            # Linear layers have fewer in/out dimensions. Extend by appending 1's
+            if (length := len(input_shape)) < 3:
+                input_shape = (*input_shape, *(1 for __ in range(3 - length)))
+            assert len(input_shape) == 3
+            if (length := len(output_shape)) < 3:
+                output_shape = (*output_shape, *(1 for __ in range(3 - length)))
+            assert len(output_shape) == 3
+            nodes_io_map[node]["input"] = input_shape
+            nodes_io_map[node]["output"] = output_shape
+
+        return nodes_io_map
+
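The trailing-dimension padding at the end of this method keeps every stored shape 3-D, so a `Linear` output of shape `(10,)` is recorded as `(10, 1, 1)`. The padding rule in isolation:

shape = (10,)
padded = (*shape, *(1 for _ in range(3 - len(shape))))
print(padded)  # (10, 1, 1)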
+    def _find_all_sources_of_input_to(self, node: int) -> Set[int]:
+        """Finds all source nodes to `node`.
+
+        Args:
+            node (int): the node in the computational graph for which we wish
+                to find the input source (either another node in the graph or
+                the original input itself to the network).
+
+        Returns:
+            IDs of the nodes in the computational graph providing the input to `node`.
+        """
+        return set(src for (src, tgt) in self._edges if tgt == node)
+
+    def _find_source_of_input_to(self, node: int) -> int:
+        """Finds the edge `(X, node)` and returns `X`.
+
+        Args:
+            node (int): the node in the computational graph for which we wish
+                to find the input source (either another node in the graph or
+                the original input itself to the network).
+
+        Returns:
+            ID of the node in the computational graph providing the input to
+            `node`. If `node` receives outside input (i.e., it is a starting
+            node), the return will be -1. For example, this will be the case
+            when a network with two independent branches (each starting from a
+            different "input node") merges along the computational graph.
+        """
+        sources = self._find_all_sources_of_input_to(node)
+        if len(sources) == 0:
+            return -1
+        if len(sources) > 1:
+            raise RuntimeError(f"Node {node} has more than 1 input")
+        return sources.pop()
+
+    def _find_merge_arguments(self, node: int) -> Edge:
+        """A `Merge` layer receives two inputs. Return the two input nodes of
+        `node`, which represents a `Merge` layer.
+
+        Returns:
+            The IDs of the nodes that provide the input arguments to a `Merge` layer.
+        """
+        sources = self._find_all_sources_of_input_to(node)
+
+        if len(sources) != 2:
+            raise ValueError(
+                f"Number of arguments found for `Merge` node {node} is {len(sources)} (should be 2)."
+            )
+
+        return tuple(sources)
+
+    def _find_valid_targets(
+        self, node: int, ignored_node_classes: Tuple[Type] = ()
+    ) -> Set[int]:
+        """Find all targets of a node that are not of ignored classes.
+
+        Return a set of all target nodes that are not of an ignored class.
+        For target nodes of ignored classes, recursively return their valid
+        targets.
+
+        Args:
+            node (int): ID of the node whose targets should be found.
+            ignored_node_classes (tuple of types): Classes of which nodes
+                should be skipped.
+
+        Returns:
+            Set of all recursively found valid target IDs.
+        """
+        targets = set()
+        for src, tgt in self.edges:
+            # Search for all edges with node as source
+            if src == node:
+                if isinstance(self.indx_2_module_map[tgt], ignored_node_classes):
+                    # Find valid targets of target
+                    targets.update(self._find_valid_targets(tgt, ignored_node_classes))
+                else:
+                    # Target is valid, add it to `targets`
+                    targets.add(tgt)
+        return targets
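To illustrate the recursion: for a chain A → B → C where B is of an ignored class, the valid targets of A are {C}. A standalone sketch of the same logic over a plain edge set (a hypothetical helper, not part of the diff):

def find_valid_targets(node, edges, modules, ignored=()):
    # Recursively collect the targets of `node`, skipping over (but following
    # through) any target whose module is of an ignored class.
    targets = set()
    for src, tgt in edges:
        if src == node:
            if isinstance(modules[tgt], ignored):
                targets |= find_valid_targets(tgt, edges, modules, ignored)
            else:
                targets.add(tgt)
    return targets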