sinabs 3.0.4.dev25__py3-none-any.whl → 3.1.0__py3-none-any.whl
- sinabs/activation/reset_mechanism.py +3 -3
- sinabs/activation/surrogate_gradient_fn.py +4 -4
- sinabs/backend/dynapcnn/__init__.py +5 -4
- sinabs/backend/dynapcnn/chip_factory.py +33 -61
- sinabs/backend/dynapcnn/chips/dynapcnn.py +182 -86
- sinabs/backend/dynapcnn/chips/speck2e.py +6 -5
- sinabs/backend/dynapcnn/chips/speck2f.py +6 -5
- sinabs/backend/dynapcnn/config_builder.py +39 -59
- sinabs/backend/dynapcnn/connectivity_specs.py +48 -0
- sinabs/backend/dynapcnn/discretize.py +91 -155
- sinabs/backend/dynapcnn/dvs_layer.py +59 -101
- sinabs/backend/dynapcnn/dynapcnn_layer.py +185 -119
- sinabs/backend/dynapcnn/dynapcnn_layer_utils.py +335 -0
- sinabs/backend/dynapcnn/dynapcnn_network.py +602 -325
- sinabs/backend/dynapcnn/dynapcnnnetwork_module.py +370 -0
- sinabs/backend/dynapcnn/exceptions.py +122 -3
- sinabs/backend/dynapcnn/io.py +51 -91
- sinabs/backend/dynapcnn/mapping.py +111 -75
- sinabs/backend/dynapcnn/nir_graph_extractor.py +877 -0
- sinabs/backend/dynapcnn/sinabs_edges_handler.py +1024 -0
- sinabs/backend/dynapcnn/utils.py +214 -459
- sinabs/backend/dynapcnn/weight_rescaling_methods.py +53 -0
- sinabs/conversion.py +2 -2
- sinabs/from_torch.py +23 -1
- sinabs/hooks.py +38 -41
- sinabs/layers/alif.py +16 -16
- sinabs/layers/crop2d.py +2 -2
- sinabs/layers/exp_leak.py +1 -1
- sinabs/layers/iaf.py +11 -11
- sinabs/layers/lif.py +9 -9
- sinabs/layers/neuromorphic_relu.py +9 -8
- sinabs/layers/pool2d.py +5 -5
- sinabs/layers/quantize.py +1 -1
- sinabs/layers/stateful_layer.py +10 -7
- sinabs/layers/to_spike.py +9 -9
- sinabs/network.py +14 -12
- sinabs/synopcounter.py +10 -7
- sinabs/utils.py +155 -7
- sinabs/validate_memory_speck.py +0 -5
- {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.0.dist-info}/METADATA +2 -1
- sinabs-3.1.0.dist-info/RECORD +65 -0
- {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.0.dist-info}/licenses/AUTHORS +1 -0
- sinabs-3.1.0.dist-info/pbr.json +1 -0
- sinabs-3.0.4.dev25.dist-info/RECORD +0 -59
- sinabs-3.0.4.dev25.dist-info/pbr.json +0 -1
- {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.0.dist-info}/WHEEL +0 -0
- {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.0.dist-info}/licenses/LICENSE +0 -0
- {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.0.dist-info}/top_level.txt +0 -0
sinabs/backend/dynapcnn/nir_graph_extractor.py (new file)

@@ -0,0 +1,877 @@
+from copy import deepcopy
+from pprint import pformat
+from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+
+import nirtorch
+import torch
+import torch.nn as nn
+
+from sinabs import layers as sl
+from sinabs.utils import get_new_index
+
+from .connectivity_specs import (
+    LAYER_TYPES_WITH_MULTIPLE_INPUTS,
+    LAYER_TYPES_WITH_MULTIPLE_OUTPUTS,
+    SupportedNodeTypes,
+)
+from .dvs_layer import DVSLayer
+from .dynapcnn_layer_utils import construct_dynapcnnlayers_from_mapper
+from .dynapcnnnetwork_module import DynapcnnNetworkModule
+from .exceptions import InvalidGraphStructure, UnsupportedLayerType
+from .sinabs_edges_handler import (
+    collect_dynapcnn_layer_info,
+    fix_dvs_module_edges,
+    handle_batchnorm_nodes,
+)
+from .utils import Edge, topological_sorting
+from warnings import warn
+
+try:
+    from nirtorch.graph import TorchGraph
+except ImportError:
+    # In older nirtorch versions TorchGraph is called Graph
+    from nirtorch.graph import Graph as TorchGraph
+
+
+class GraphExtractor:
+    def __init__(
+        self,
+        spiking_model: nn.Module,
+        dummy_input: torch.Tensor,
+        dvs_input: Optional[bool] = None,
+        ignore_node_types: Optional[Iterable[Type]] = None,
+    ):
+        """Class implementing the extraction of the computational graph from
+        `spiking_model`, where each node represents a layer in the model and
+        the set of edges represents how the data flows between the layers.
+
+        Args:
+            spiking_model (nn.Module): a sinabs-compatible spiking network.
+            dummy_input (torch.Tensor): an input sample to be fed through the
+                model to acquire both the computational graph (via `nirtorch`)
+                and the I/O shapes of each node. It is a 4-D tensor of shape
+                `(batch, channels, height, width)`.
+            dvs_input (bool): optional (default `None`). Whether or not the
+                model should start with a `DVSLayer`.
+            ignore_node_types (iterable of types): Node types that should be
+                ignored completely from the graph. This can include, for
+                instance, `nn.Dropout2d`, which otherwise can result in wrongly
+                inferred graph structures by NIRTorch. Types such as
+                `nn.Flatten` or sinabs `Merge` should not be included here, as
+                they are needed to properly handle graph structure and
+                metadata. They can be removed after instantiation with
+                `remove_nodes_by_class`.
+
+        Attributes:
+            edges (set of 2-tuples of integers): Tuples describing the
+                connections between layers in `spiking_model`. Each layer
+                (node) is identified by a unique integer ID.
+            name_2_indx_map (dict): Keys are original variable names of layers
+                in `spiking_model`. Values are unique integer IDs.
+            entry_nodes (set of ints): IDs of nodes acting as entry points for
+                the network, i.e. receiving external input.
+            indx_2_module_map (dict): Map from layer ID to the corresponding
+                nn.Module instance.
+            nodes_io_shapes (dict): Map from node ID to dict containing the
+                node's in- and output shapes.
+            dvs_input (bool): optional (default `None`). Whether or not the
+                model should start with a `DVSLayer` (see Args).
+            ignore_node_types (iterable of types): Node types that are ignored
+                completely from the graph (see Args).
+        """
+
+        # Store state before it is changed due to NIRTorch and
+        # `self._get_nodes_io_shapes` passing dummy input
+        original_state = {
+            n: b.detach().clone() for n, b in spiking_model.named_buffers()
+        }
+
+        self._edges = set()
+        # Empty sequentials will cause nirtorch to fail. Treat this case separately
+        if isinstance(spiking_model, nn.Sequential) and len(spiking_model) == 0:
+            self._name_2_indx_map = dict()
+            self._edges = set()
+            original_state = {}
+        else:
+            nir_graph = nirtorch.graph.extract_torch_graph(
+                spiking_model, dummy_input, model_name=None
+            ).ignore_tensors()
+
+            if ignore_node_types is not None:
+                for node_type in ignore_node_types:
+                    nir_graph = nir_graph.ignore_nodes(node_type)
+
+            # Map node names to indices
+            self._name_2_indx_map = self._get_name_2_indx_map(nir_graph)
+
+            # Extract edges list from graph
+            self._edges = self._get_edges_from_nir(nir_graph, self._name_2_indx_map)
+
+        # Store the associated `nn.Module` (layer) of each node.
+        self._indx_2_module_map = self._get_name2module_map(spiking_model)
+
+        if len(self._name_2_indx_map) > 0:
+            # Merges BatchNorm2d/BatchNorm1d nodes with Conv2d/Linear ones.
+            handle_batchnorm_nodes(
+                self._edges, self._indx_2_module_map, self._name_2_indx_map
+            )
+
+        # Determine entry points to graph
+        self._entry_nodes = self._get_entry_nodes(self._edges)
+
+        # Make sure DVS input is properly integrated into graph
+        self._handle_dvs_input(input_shape=dummy_input.shape[1:], dvs_input=dvs_input)
+
+        # Retrieve the I/O shape of each node's module.
+        self._nodes_io_shapes = self._get_nodes_io_shapes(dummy_input)
+
+        # Restore original state - after forward passes from nirtorch and `_get_nodes_io_shapes`
+        for n, b in spiking_model.named_buffers():
+            b.set_(original_state[n].clone())
+
+        # Verify that graph is compatible
+        self.verify_graph_integrity()
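For orientation, here is a minimal usage sketch of the constructor above. The model and shapes are hypothetical; `GraphExtractor` and its attributes are the ones defined in this file, everything else is illustrative:

```python
import torch
import torch.nn as nn

from sinabs.backend.dynapcnn.nir_graph_extractor import GraphExtractor
from sinabs.layers import IAFSqueeze

# A small sinabs-compatible SNN (hypothetical example). Batch and time are
# flattened into dim 0, so the dummy input is (batch, channels, height, width).
snn = nn.Sequential(
    nn.Conv2d(2, 8, kernel_size=3, bias=False),
    IAFSqueeze(batch_size=1),
    nn.AvgPool2d(2),
    nn.Conv2d(8, 16, kernel_size=3, bias=False),
    IAFSqueeze(batch_size=1),
)

extractor = GraphExtractor(snn, dummy_input=torch.rand((1, 2, 34, 34)))

print(extractor.entry_nodes)         # e.g. {0}
print(extractor.edges)               # e.g. {(0, 1), (1, 2), (2, 3), (3, 4)}
print(extractor.nodes_io_shapes[0])  # {'input': ..., 'output': ...} shapes
```

The exact node indices printed depend on how nirtorch enumerates the traced graph.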
+
+    @property
+    def dvs_layer(self) -> Union[DVSLayer, None]:
+        idx = self.dvs_node_id
+        if idx is None:
+            return None
+        else:
+            return self.indx_2_module_map[self.dvs_node_id]
+
+    @property
+    def dvs_node_id(self) -> Union[int, None]:
+        return self._get_dvs_node_id()
+
+    @property
+    def entry_nodes(self) -> Set[int]:
+        return {n for n in self._entry_nodes}
+
+    @property
+    def edges(self) -> Set[Edge]:
+        return {(src, tgt) for src, tgt in self._edges}
+
+    @property
+    def has_dvs_layer(self) -> bool:
+        return self.dvs_layer is not None
+
+    @property
+    def name_2_indx_map(self) -> Dict[str, int]:
+        return {name: idx for name, idx in self._name_2_indx_map.items()}
+
+    @property
+    def nodes_io_shapes(self) -> Dict[int, Tuple[torch.Size]]:
+        return {n: size for n, size in self._nodes_io_shapes.items()}
+
+    @property
+    def sorted_nodes(self) -> List[int]:
+        return [n for n in self._sort_graph_nodes()]
+
+    @property
+    def indx_2_module_map(self) -> Dict[int, nn.Module]:
+        return {n: module for n, module in self._indx_2_module_map.items()}
+
+    def get_dynapcnn_network_module(
+        self, discretize: bool = True, weight_rescaling_fn: Optional[Callable] = None
+    ) -> DynapcnnNetworkModule:
+        """Create DynapcnnNetworkModule based on stored graph representation
+
+        This includes construction of the DynapcnnLayer instances
+
+        Args:
+            discretize (bool): If `True`, discretize the parameters and
+                thresholds. This is needed for uploading weights to DynapCNN.
+                Set to `False` only for testing purposes.
+            weight_rescaling_fn (callable): a method that handles how the
+                re-scaling factors from one or more `SumPool2d` layers
+                projecting to the same convolutional layer are combined before
+                being applied.
+
+        Returns:
+            The DynapcnnNetworkModule based on the graph representation of this
+            `GraphExtractor`.
+        """
+        # Make sure all nodes are supported and there are no isolated nodes.
+        self.verify_node_types()
+        self.verify_no_isolated_nodes()
+
+        # create a dict holding the data necessary to instantiate a `DynapcnnLayer`.
+        self.dcnnl_info, self.dvs_layer_info = collect_dynapcnn_layer_info(
+            indx_2_module_map=self.indx_2_module_map,
+            edges=self.edges,
+            nodes_io_shapes=self.nodes_io_shapes,
+            entry_nodes=self.entry_nodes,
+        )
+
+        # Special case where there is a disconnected `DVSLayer`: There are no
+        # edges for the edges handler to process. Instantiate layer info manually.
+        if self.dvs_layer_info is None and self.dvs_layer is not None:
+            self.dvs_layer_info = {
+                "node_id": self.dvs_node_id,
+                "input_shape": self.nodes_io_shapes[self.dvs_node_id]["input"],
+                "module": self.dvs_layer,
+                "pooling": None,
+                "destinations": None,
+            }
+
+        # build `DynapcnnLayer` instances from mapper.
+        (
+            dynapcnn_layers,
+            destination_map,
+            entry_points,
+        ) = construct_dynapcnnlayers_from_mapper(
+            dcnnl_map=self.dcnnl_info,
+            dvs_layer_info=self.dvs_layer_info,
+            discretize=discretize,
+            rescale_fn=weight_rescaling_fn,
+        )
+
+        # Instantiate the DynapcnnNetworkModule
+        return DynapcnnNetworkModule(
+            dynapcnn_layers, destination_map, entry_points, self.dvs_layer_info
+        )
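Continuing the sketch above, a hedged example of turning the extracted graph into hardware-ready layers. The `weight_rescaling_fn` signature is inferred from the docstring (a callable that combines the rescaling factors of pooling layers feeding the same convolution); the averaging policy is purely illustrative, not the library's prescribed method:

```python
from statistics import mean
from typing import List

def average_rescaling(rescaling_factors: List[float]) -> float:
    # Illustrative policy: combine the factors contributed by all
    # SumPool2d layers that project to the same convolution by averaging.
    return mean(rescaling_factors) if rescaling_factors else 1.0

hw_module = extractor.get_dynapcnn_network_module(
    discretize=True, weight_rescaling_fn=average_rescaling
)
```

Note that this release also adds `weight_rescaling_methods.py` (see the file list above), which presumably ships ready-made policies of this kind.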
+
+    def remove_nodes_by_class(self, node_classes: Tuple[Type]):
+        """
+        Remove all nodes of the specified classes from the graph, and update the graph structure in place.
+        When nodes are removed, the function rewires the graph to maintain data flow continuity.
+        Specifically, any incoming edges to the removed nodes are redirected to their valid successors,
+        effectively bypassing the removed nodes.
+
+        Example:
+            If the original structure is A → B → C, and B is removed, the resulting connection becomes A → C.
+            If an entry node (a node without predecessors) is removed, its valid successors become new entry nodes.
+
+        Special handling:
+            For `Flatten` layers: the function updates the input shape of successor nodes to match the
+            shape before flattening. Note that this may lead to incorrect results if multiple `Flatten`
+            layers appear consecutively.
+
+        After modification, the function updates:
+            - The edge set (`self._edges`)
+            - The entry nodes (`self._entry_nodes`)
+            - The node index mapping (`self._name_2_indx_map`)
+            - The stored input/output shapes (`self._nodes_io_shapes`)
+
+        Args:
+            node_classes (tuple of types): The types of nodes (layers) to remove from the graph.
+        """
+
+        # Build a mapping of remaining nodes to their valid successors
+        # (i.e., successors not belonging to the classes to be removed)
+        source2target: Dict[int, Set[int]] = {}
+
+        for node in self.sorted_nodes:
+            mod = self.indx_2_module_map[node]
+            if isinstance(mod, node_classes):
+                # If an entry node is removed, promote its valid successors to new entry nodes
+                if node in self.entry_nodes:
+                    targets = self._find_valid_targets(node, node_classes)
+                    self._entry_nodes.update(targets)
+
+                # If removing a Flatten layer, propagate the pre-flatten shape
+                # to the next nodes to maintain correct input shape metadata
+                if isinstance(mod, nn.Flatten):
+                    shape_before_flatten = self.nodes_io_shapes[node]["input"]
+                    for target_node in self._find_valid_targets(node, node_classes):
+                        self._nodes_io_shapes[target_node][
+                            "input"
+                        ] = shape_before_flatten
+            else:
+                # Retain this node and connect it to its valid successors
+                source2target[node] = self._find_valid_targets(node, node_classes)
+
+        # Reassign new contiguous indices to the remaining nodes (starting from 0)
+        remapped_nodes = {
+            old_idx: new_idx
+            for new_idx, old_idx in enumerate(sorted(source2target.keys()))
+        }
+
+        # Rebuild the edge set using the remapped indices
+        self._edges = {
+            (remapped_nodes[src], remapped_nodes[tgt])
+            for src, targets in source2target.items()
+            for tgt in targets
+        }
+
+        # Synchronize all internal structures with the updated graph topology
+        self._update_internal_representation(remapped_nodes)
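A short sketch of the bypass behaviour described in the docstring, continuing the hypothetical extractor from above:

```python
import torch.nn as nn

# Layers with no DynapCNN counterpart can be dropped once the graph is
# built; A -> Flatten -> B is rewired to A -> B and indices are compacted.
extractor.remove_nodes_by_class((nn.Flatten, nn.Dropout2d))
```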
+
+    def get_node_io_shapes(self, node: int) -> Tuple[torch.Size, torch.Size]:
+        """Returns the I/O tensors' shapes of `node`.
+
+        Returns:
+            input shape (torch.Size): shape of the input tensor to `node`.
+            output shape (torch.Size): shape of the output tensor from `node`.
+        """
+        return (
+            self._nodes_io_shapes[node]["input"],
+            self._nodes_io_shapes[node]["output"],
+        )
+
+    def verify_graph_integrity(self):
+        """
+        Apply consistency checks to verify that the graph structure is valid and supported.
+        This function ensures:
+        - No isolated nodes exist (except for `DVSLayer` instances).
+        - Only certain layer types are allowed to have multiple inputs or outputs.
+        - No self-recurrent edges are present (edges of the form `(x, x)`).
+
+        Raises:
+            InvalidGraphStructure: If any of the integrity checks fail.
+        """
+        # --- Check for self-recurrent edges -------------------------------------
+        self_recurrent_edges = {(src, tgt) for (src, tgt) in self.edges if src == tgt}
+        if self_recurrent_edges:
+            raise InvalidGraphStructure(
+                f"The graph contains self-recurrent edges: {self_recurrent_edges}. "
+                "Recurrent connections (edges of the form (x, x)) are not supported."
+            )
+        # --- Check node connectivity and input/output structure -----------------
+        for node, module in self.indx_2_module_map.items():
+            # Ensure there are no isolated (unconnected) nodes, except for DVSLayer instances
+            edges_with_node = {e for e in self.edges if node in e}
+            if not edges_with_node and not isinstance(module, DVSLayer):
+                raise InvalidGraphStructure(
+                    f"There is an isolated module of type {type(module)} (node {node}). "
+                    "Only `DVSLayer` instances can be completely disconnected from "
+                    "the rest of the network. Other than that, DynapCNN layers consist "
+                    "of groups of weight layers (`Linear` or `Conv2d`), spiking layers "
+                    "(`IAF` or `IAFSqueeze`), and optionally pooling layers "
+                    "(`SumPool2d`, `AvgPool2d`)."
+                )
+            # Ensure only specific module types can have multiple inputs
+            if not isinstance(module, LAYER_TYPES_WITH_MULTIPLE_INPUTS):
+                sources = self._find_all_sources_of_input_to(node)
+                if len(sources) > 1:
+                    raise InvalidGraphStructure(
+                        f"Node {node} of type {type(module)} has {len(sources)} inputs, "
+                        f"but only nodes of type {LAYER_TYPES_WITH_MULTIPLE_INPUTS} "
+                        "are allowed to have multiple inputs."
+                    )
+            # Ensure only specific module types can have multiple outputs
+            if not isinstance(module, LAYER_TYPES_WITH_MULTIPLE_OUTPUTS):
+                targets = self._find_valid_targets(node)
+                if len(targets) > 1:
+                    raise InvalidGraphStructure(
+                        f"Node {node} of type {type(module)} has {len(targets)} outputs, "
+                        f"but only nodes of type {LAYER_TYPES_WITH_MULTIPLE_OUTPUTS} "
+                        "are allowed to have multiple outputs."
+                    )
+
+    def verify_node_types(self):
+        """Verify that all nodes are of a supported type.
+
+        Raises:
+            UnsupportedLayerType: If any verification fails.
+        """
+        unsupported_nodes = dict()
+        for index, module in self.indx_2_module_map.items():
+            if not isinstance(module, SupportedNodeTypes):
+                node_type = type(module)
+                if node_type in unsupported_nodes:
+                    unsupported_nodes[node_type].add(index)
+                else:
+                    unsupported_nodes[node_type] = {index}
+        # Specific error message for non-squeezing IAF layer
+        iaf_layers = []
+        for idx in unsupported_nodes.pop(sl.IAF, []):
+            iaf_layers.append(self.indx_2_module_map[idx])
+        if iaf_layers:
+            layer_str = ", ".join(str(lyr) for lyr in iaf_layers)
+            raise UnsupportedLayerType(
+                f"The provided SNN contains IAF layers:\n{layer_str}.\n"
+                "For compatibility with torch's `nn.Conv2d` modules, please "
+                "use `IAFSqueeze` layers instead."
+            )
+        # Specific error message for leaky neuron types
+        lif_layers = []
+        for lif_type in (sl.LIF, sl.LIFSqueeze):
+            for idx in unsupported_nodes.pop(lif_type, []):
+                lif_layers.append(self.indx_2_module_map[idx])
+        if lif_layers:
+            layer_str = ", ".join(str(lyr) for lyr in lif_layers)
+            raise UnsupportedLayerType(
+                f"The provided SNN contains LIF layers:\n{layer_str}.\n"
+                "Leaky Integrate-and-Fire dynamics are not supported by "
+                "DynapCNN. Use non-leaky `IAF` or `IAFSqueeze` layers "
+                "instead."
+            )
+        # Specific error message for most common non-spiking activation layers
+        activation_layers = []
+        for activation_type in (nn.ReLU, nn.Sigmoid, nn.Tanh, sl.NeuromorphicReLU):
+            for idx in unsupported_nodes.pop(activation_type, []):
+                activation_layers.append(self.indx_2_module_map[idx])
+        if activation_layers:
+            layer_str = ", ".join(str(lyr) for lyr in activation_layers)
+            raise UnsupportedLayerType(
+                "The provided SNN contains non-spiking activation layers:\n"
+                f"{layer_str}.\nPlease convert them to `IAF` or `IAFSqueeze` "
+                "layers before instantiating a `DynapcnnNetwork`. You can "
+                "use the function `sinabs.from_model.from_torch` for this."
+            )
+        if unsupported_nodes:
+            # More generic error message for all remaining types
+            raise UnsupportedLayerType(
+                "One or more layers in the provided SNN are not supported: "
+                f"{pformat(unsupported_nodes)}. Supported layer types are: "
+                f"{pformat(SupportedNodeTypes)}."
+            )
+
+    def verify_no_isolated_nodes(self):
+        """Verify that there are no disconnected nodes except for `DVSLayer` instances.
+
+        Raises:
+            InvalidGraphStructure: when disconnected nodes are detected.
+        """
+        for node, module in self.indx_2_module_map.items():
+            # Make sure there are no individual, unconnected nodes
+            edges_with_node = {e for e in self.edges if node in e}
+            if not edges_with_node and not isinstance(module, DVSLayer):
+                raise InvalidGraphStructure(
+                    f"There is an isolated module of type {type(module)}. Only "
+                    "`DVSLayer` instances can be completely disconnected from "
+                    "any other module. Other than that, layers for DynapCNN "
+                    "consist of groups of weight layers (`Linear` or `Conv2d`), "
+                    "spiking layers (`IAF` or `IAFSqueeze`), and optionally "
+                    "pooling layers (`SumPool2d`, `AvgPool2d`)."
+                )
+
+    def _handle_dvs_input(
+        self, input_shape: Tuple[int, int, int], dvs_input: Optional[bool] = None
+    ):
+        """Make sure DVS input is properly integrated into graph
+
+        Decide whether a `DVSLayer` instance needs to be added to the graph.
+        This is the case when `dvs_input==True` and there is no `DVSLayer` yet.
+        Make sure edges between DVS related nodes are set properly.
+        Absorb pooling layers in DVS node if applicable.
+
+        Args:
+            input_shape (tuple of three integers): Input shape (features,
+                height, width).
+            dvs_input (bool or `None` (default)): If `False`, will raise
+                `InvalidModelWithDvsSetup` if a `DVSLayer` is part of the
+                graph. If `True`, a `DVSLayer` will be added to the graph if
+                there is none already. If `None`, the model is considered to be
+                using DVS input only if the graph contains a `DVSLayer`.
+        """
+        if self.has_dvs_layer:
+            # Make a copy of the layer so that the original version is not
+            # changed in place
+            new_dvs_layer = deepcopy(self.dvs_layer)
+            self._indx_2_module_map[self.dvs_node_id] = new_dvs_layer
+        elif dvs_input:
+            # Insert a DVSLayer node in the graph.
+            new_dvs_layer = self._add_dvs_node(dvs_input_shape=input_shape)
+        else:
+            dvs_input = None
+        if dvs_input is not None:
+            # Disable pixel array if `dvs_input` is False
+            new_dvs_layer.disable_pixel_array = not dvs_input
+
+        # Check whether the NIR edge extraction needs fixing when the DVS is a
+        # node in the graph. If the DVS is used, its node becomes the only
+        # entry node in the graph.
+        fix_dvs_module_edges(
+            self._edges,
+            self._indx_2_module_map,
+            self._name_2_indx_map,
+            self._entry_nodes,
+        )
+
+        # Check if graph structure and DVSLayer.merge_polarities are correctly set (if DVS node exists).
+        self._validate_dvs_setup(dvs_input_shape=input_shape)
+
+    def _add_dvs_node(self, dvs_input_shape: Tuple[int, int, int]) -> DVSLayer:
+        """In-place modification of `self._name_2_indx_map`,
+        `self._indx_2_module_map`, and `self._edges` to accommodate the creation
+        of an extra node in the graph representing the DVS camera of the chip.
+        The DVSLayer node will point to every other node that is up to this
+        point an entry node of the original graph, so `self._entry_nodes` is
+        modified in-place to have only one entry: the index of the DVS node.
+
+        Args:
+            dvs_input_shape (tuple): shape of the DVSLayer input in format
+                `(features, height, width)`
+
+        Returns:
+            A handle to the newly added `DVSLayer` instance.
+        """
+
+        (features, height, width) = dvs_input_shape
+        if features > 2:
+            raise ValueError(
+                f"A DVSLayer instance can have a max feature dimension of 2 but {features} was given."
+            )
+
+        # Find new index to be assigned to DVS node
+        self._name_2_indx_map["dvs"] = get_new_index(self._name_2_indx_map.values())
+        # add module entry for node 'dvs'.
+        dvs_layer = DVSLayer(
+            input_shape=(height, width),
+            merge_polarities=(features == 1),
+        )
+        self._indx_2_module_map[self._name_2_indx_map["dvs"]] = dvs_layer
+
+        # set DVS node as input to each entry node of the graph
+        self._edges.update(
+            {
+                (self._name_2_indx_map["dvs"], entry_node)
+                for entry_node in self._entry_nodes
+            }
+        )
+        # DVSLayer node becomes the only entry point of the graph
+        self._entry_nodes = {self._name_2_indx_map["dvs"]}
+
+        return dvs_layer
+
+    def _get_dvs_node_id(self) -> Union[int, None]:
+        """Return index of the `DVSLayer` instance if it exists.
+
+        Returns:
+            The node ID if exactly one `DVSLayer` is found, otherwise None.
+
+        Raises:
+            InvalidGraphStructure: if more than one DVSLayer is found.
+        """
+
+        dvs_layer_indices = {
+            index
+            for index, module in self._indx_2_module_map.items()
+            if isinstance(module, DVSLayer)
+        }
+
+        if (num_dvs := len(dvs_layer_indices)) == 0:
+            return
+        elif num_dvs == 1:
+            return dvs_layer_indices.pop()
+        else:
+            raise InvalidGraphStructure(
+                f"The provided model has {num_dvs} `DVSLayer`s. At most one is allowed."
+            )
+
+    def _validate_dvs_setup(self, dvs_input_shape: Tuple[int, int, int]) -> None:
+        """If a DVSLayer node exists, makes sure it is the only entry node of
+        the graph. Checks if its `merge_polarities` attribute matches
+        `dummy_input.shape[0]` (the number of features) and, if not, sets it
+        based on the number of features of the input.
+
+        Args:
+            dvs_input_shape (tuple): shape of the DVSLayer input in format
+                `(features, height, width)`.
+        """
+
+        if self.dvs_layer is None:
+            # No DVSLayer found - nothing to do here.
+            return
+
+        if (nb_entries := len(self._entry_nodes)) > 1:
+            raise ValueError(
+                f"A DVSLayer node exists and there are {nb_entries} entry nodes in the graph: the DVSLayer should be the only entry node."
+            )
+
+        (features, _, _) = dvs_input_shape
+
+        if features > 2:
+            raise ValueError(
+                f"A DVSLayer instance expects the feature dimension of its inputs to be 1 or 2, but {features} was given."
+            )
+
+        if self.dvs_layer.merge_polarities and features != 1:
+            raise ValueError(
+                f"'DVSLayer.merge_polarities' is set to 'True', which means the number of input features should be 1 (current input shape is {dvs_input_shape})."
+            )
+
+        if features == 1:
+            self.dvs_layer.merge_polarities = True
+
+    def _get_name_2_indx_map(self, nir_graph: TorchGraph) -> Dict[str, int]:
+        """Assign unique index to each node and return mapper from name to
+        index.
+
+        Args:
+            nir_graph (TorchGraph): a NIR graph representation of
+                `spiking_model`.
+
+        Returns:
+            A dictionary where `key` is the original variable name for a layer
+            in `spiking_model` and `value` is an integer representing the layer
+            in a standard format.
+        """
+
+        return {
+            node.name: node_idx for node_idx, node in enumerate(nir_graph.node_list)
+        }
+
+    def _get_edges_from_nir(
+        self, nir_graph: TorchGraph, name_2_indx_map: Dict[str, int]
+    ) -> Set[Edge]:
+        """Standardize the representation of a `TorchGraph` into a set of
+        edges, representing nodes by their indices.
+
+        Args:
+            nir_graph (TorchGraph): a NIR graph representation of
+                `spiking_model`.
+            name_2_indx_map (dict): Map from node names to unique indices.
+
+        Returns:
+            Set of 2-tuples describing the connections between layers in
+            `spiking_model`.
+        """
+        return {
+            (name_2_indx_map[src.name], name_2_indx_map[tgt.name])
+            for src in nir_graph.node_list
+            for tgt in src.outgoing_nodes
+        }
+
+    def _get_entry_nodes(self, edges: Set[Edge]) -> Set[int]:
+        """Find nodes that act as entry points to the graph.
+
+        Args:
+            edges (set): tuples describing the connections between layers in
+                `spiking_model`.
+
+        Returns:
+            IDs of nodes acting as entry points for the network (i.e.,
+            receiving external input).
+        """
+        if not edges:
+            return set()
+
+        all_sources, all_targets = zip(*edges)
+        return set(all_sources) - set(all_targets)
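The set difference on the method's last line is the entire algorithm; a standalone illustration:

```python
# Two branches, (0 -> 1 -> 2) and (3 -> 2), merging at node 2.
edges = {(0, 1), (1, 2), (3, 2)}
all_sources, all_targets = zip(*edges)
# Entry nodes are sources that never appear as a target: {0, 3}.
print(set(all_sources) - set(all_targets))
```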
+
+    def _get_name2module_map(self, model: nn.Module) -> Dict[int, nn.Module]:
+        """Find for each node in the graph what its associated layer in `model` is.
+
+        Args:
+            model (nn.Module): the `spiking_model` used as argument to the class
+                instance.
+
+        Returns:
+            The mapping between a node (`key` as an `int`) and its module
+            (`value` as a `nn.Module`).
+        """
+
+        indx_2_module_map = dict()
+
+        for name, module in model.named_modules():
+            # Make sure names match those provided by nirtorch nodes
+            if name in self._name_2_indx_map:
+                indx_2_module_map[self._name_2_indx_map[name]] = module
+            else:
+                # In older nirtorch versions, node names are "sanitized".
+                # Try with the sanitized version of the name
+                name = nirtorch.utils.sanitize_name(name)
+                if name in self._name_2_indx_map:
+                    indx_2_module_map[self._name_2_indx_map[name]] = module
+
+        return indx_2_module_map
+
+    def _update_internal_representation(self, remapped_nodes: Dict[int, int]):
+        """Update internal attributes after remapping of nodes
+
+        Args:
+            remapped_nodes (dict): Maps previous (key) to new (value) node
+                indices. Nodes that were removed are not included.
+        """
+
+        if len(self._name_2_indx_map) > 0:
+            # Update name-to-index map based on new node indices
+            self._name_2_indx_map = {
+                name: remapped_nodes[old_idx]
+                for name, old_idx in self._name_2_indx_map.items()
+                if old_idx in remapped_nodes
+            }
+
+        # Update entry nodes based on new node indices
+        self._entry_nodes = {
+            remapped_nodes[old_idx]
+            for old_idx in self._entry_nodes
+            if old_idx in remapped_nodes
+        }
+
+        # Update io-shapes based on new node indices
+        self._nodes_io_shapes = {
+            remapped_nodes[old_idx]: shape
+            for old_idx, shape in self._nodes_io_shapes.items()
+            if old_idx in remapped_nodes
+        }
+
+        # Update sinabs module map based on new node indices
+        self._indx_2_module_map = {
+            remapped_nodes[old_idx]: module
+            for old_idx, module in self._indx_2_module_map.items()
+            if old_idx in remapped_nodes
+        }
+
+    def _sort_graph_nodes(self) -> List[int]:
+        """Sort graph nodes topologically.
+
+        Returns:
+            sorted_nodes (list of integers): IDs of nodes, sorted.
+        """
+        # Make a temporary copy of edges and include inputs
+        temp_edges = self.edges
+        for node in self._entry_nodes:
+            temp_edges.add(("input", node))
+        return topological_sorting(temp_edges)
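`topological_sorting` itself lives in `utils.py` and is not part of this hunk. For reference, a standard Kahn's-algorithm equivalent might look like the sketch below; the library version additionally has to cope with the `"input"` pseudo-node added above, so treat this as an assumption about its behaviour, not its actual implementation:

```python
from collections import defaultdict, deque

def kahn_topological_sort(edges):
    """Return the nodes of a DAG, given as a set of (src, tgt) edges,
    in topological order (Kahn's algorithm)."""
    successors = defaultdict(list)
    in_degree = defaultdict(int)
    for src, tgt in edges:
        successors[src].append(tgt)
        in_degree[tgt] += 1
        in_degree.setdefault(src, 0)
    # Start from nodes without predecessors (the graph's entry points).
    queue = deque(n for n, d in in_degree.items() if d == 0)
    order = []
    while queue:
        n = queue.popleft()
        order.append(n)
        for m in successors[n]:
            in_degree[m] -= 1
            if in_degree[m] == 0:
                queue.append(m)
    if len(order) != len(in_degree):
        raise ValueError("Graph contains a cycle")
    return order
```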
+
+    def _get_nodes_io_shapes(
+        self, input_dummy: torch.Tensor
+    ) -> Dict[int, Dict[str, torch.Size]]:
+        """Iteratively calls the forward method of each `nn.Module` (i.e., a
+        layer/node in the graph) using the topologically sorted nodes extracted
+        from the computational graph of the model being parsed.
+
+        Args:
+            input_dummy (torch.Tensor): a sample (random) tensor of the sort of
+                input being fed to the network.
+
+        Returns:
+            A dictionary mapping nodes to their I/O shapes.
+        """
+        nodes_io_map = {}
+
+        # propagate inputs through the nodes.
+        for node in self.sorted_nodes:
+            if isinstance(self.indx_2_module_map[node], sl.merge.Merge):
+                # find `Merge` arguments (at this point the inputs to Merge should have been calculated).
+                input_nodes = self._find_merge_arguments(node)
+
+                # retrieve the output tensors of the argument nodes.
+                inputs = [nodes_io_map[n]["output"] for n in input_nodes]
+
+                # TODO - this is currently a limitation imposed by the validation
+                # checks done by Speck when loading a configuration: it wants
+                # different input sources to a core to have the same output shapes.
+                if any(inp.shape != inputs[0].shape for inp in inputs):
+                    raise ValueError(
+                        f"Layer `sinabs.layers.merge.Merge` (node {node}) requires input tensors with the same shape"
+                    )
+
+                # forward input through the node.
+                _output = self.indx_2_module_map[node](*inputs)
+
+                # save node's I/O tensors.
+                nodes_io_map[node] = {"input": inputs[0], "output": _output}
+
+            else:
+                if node in self._entry_nodes:
+                    # forward input dummy through node.
+                    _output = self.indx_2_module_map[node](input_dummy)
+
+                    # save node's I/O tensors.
+                    nodes_io_map[node] = {"input": input_dummy, "output": _output}
+
+                else:
+                    # find node generating the input to be used.
+                    input_node = self._find_source_of_input_to(node)
+                    _input = nodes_io_map[input_node]["output"]
+
+                    # forward input through the node.
+                    _output = self.indx_2_module_map[node](_input)
+
+                    # save node's I/O tensors.
+                    nodes_io_map[node] = {"input": _input, "output": _output}
+
+        # replace the I/O tensor information by its shape information, ignoring the batch/time axis
+        for node, io in nodes_io_map.items():
+            input_shape = io["input"].shape[1:]
+            output_shape = io["output"].shape[1:]
+            # Linear layers have fewer in/out dimensions. Extend by appending 1's
+            if (length := len(input_shape)) < 3:
+                input_shape = (*input_shape, *(1 for __ in range(3 - length)))
+            assert len(input_shape) == 3
+            if (length := len(output_shape)) < 3:
+                output_shape = (*output_shape, *(1 for __ in range(3 - length)))
+            assert len(output_shape) == 3
+            nodes_io_map[node]["input"] = input_shape
+            nodes_io_map[node]["output"] = output_shape
+
+        return nodes_io_map
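The padding step at the end normalizes all stored shapes to three dimensions, so downstream code can treat `Linear` outputs like 1×1 feature maps. In isolation:

```python
import torch

out = torch.rand(1, 10)  # e.g. output of an nn.Linear node (batch, features)
shape = out.shape[1:]    # torch.Size([10]) after dropping the batch/time axis
if (length := len(shape)) < 3:
    shape = (*shape, *(1 for _ in range(3 - length)))
print(shape)             # (10, 1, 1): features become channels of a 1x1 map
```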
+
+    def _find_all_sources_of_input_to(self, node: int) -> Set[int]:
+        """Finds all source nodes to `node`.
+
+        Args:
+            node (int): the node in the computational graph for which we wish
+                to find the input sources (either other nodes in the graph or
+                the original input itself to the network).
+
+        Returns:
+            IDs of the nodes in the computational graph providing the input to `node`.
+        """
+        return set(src for (src, tgt) in self._edges if tgt == node)
+
+    def _find_source_of_input_to(self, node: int) -> int:
+        """Finds the edge `(X, node)` and returns `X`.
+
+        Args:
+            node (int): the node in the computational graph for which we wish
+                to find the input source (either another node in the graph or
+                the original input itself to the network).
+
+        Returns:
+            ID of the node in the computational graph providing the input to
+            `node`. If `node` is receiving outside input (i.e., it is a starting
+            node) the return will be -1. For example, this will be the case when
+            two independent branches of a network (each starting from a
+            different "input node") merge along the computational graph.
+        """
+        sources = self._find_all_sources_of_input_to(node)
+        if len(sources) == 0:
+            return -1
+        if len(sources) > 1:
+            raise RuntimeError(f"Node {node} has more than 1 input")
+        return sources.pop()
+
+    def _find_merge_arguments(self, node: int) -> Edge:
+        """A `Merge` layer receives two inputs. Return the two input nodes of
+        `node`, which represents a `Merge` layer.
+
+        Returns:
+            The IDs of the nodes that provide the input arguments to a `Merge` layer.
+        """
+        sources = self._find_all_sources_of_input_to(node)
+
+        if len(sources) != 2:
+            raise ValueError(
+                f"Number of arguments found for `Merge` node {node} is {len(sources)} (should be 2)."
+            )
+
+        return tuple(sources)
+
+    def _find_valid_targets(
+        self, node: int, ignored_node_classes: Tuple[Type] = ()
+    ) -> Set[int]:
+        """Find all targets of a node that are not of an ignored class.
+
+        Return a set of all target nodes that are not of an ignored class.
+        For target nodes of ignored classes, recursively return their valid
+        targets.
+
+        Args:
+            node (int): ID of node whose targets should be found.
+            ignored_node_classes (tuple of types): Classes of which nodes
+                should be skipped.
+
+        Returns:
+            Set of all recursively found valid target IDs.
+        """
+        targets = set()
+        for src, tgt in self.edges:
+            # Search for all edges with node as source
+            if src == node:
+                if isinstance(self.indx_2_module_map[tgt], ignored_node_classes):
+                    # Find valid targets of target
+                    targets.update(self._find_valid_targets(tgt, ignored_node_classes))
+                else:
+                    # Target is valid, add it to `targets`
+                    targets.add(tgt)
+        return targets