sinabs 3.0.4.dev25__py3-none-any.whl → 3.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. sinabs/activation/reset_mechanism.py +3 -3
  2. sinabs/activation/surrogate_gradient_fn.py +4 -4
  3. sinabs/backend/dynapcnn/__init__.py +5 -4
  4. sinabs/backend/dynapcnn/chip_factory.py +33 -61
  5. sinabs/backend/dynapcnn/chips/dynapcnn.py +182 -86
  6. sinabs/backend/dynapcnn/chips/speck2e.py +6 -5
  7. sinabs/backend/dynapcnn/chips/speck2f.py +6 -5
  8. sinabs/backend/dynapcnn/config_builder.py +39 -59
  9. sinabs/backend/dynapcnn/connectivity_specs.py +48 -0
  10. sinabs/backend/dynapcnn/discretize.py +91 -155
  11. sinabs/backend/dynapcnn/dvs_layer.py +59 -101
  12. sinabs/backend/dynapcnn/dynapcnn_layer.py +185 -119
  13. sinabs/backend/dynapcnn/dynapcnn_layer_utils.py +335 -0
  14. sinabs/backend/dynapcnn/dynapcnn_network.py +602 -325
  15. sinabs/backend/dynapcnn/dynapcnnnetwork_module.py +370 -0
  16. sinabs/backend/dynapcnn/exceptions.py +122 -3
  17. sinabs/backend/dynapcnn/io.py +55 -92
  18. sinabs/backend/dynapcnn/mapping.py +111 -75
  19. sinabs/backend/dynapcnn/nir_graph_extractor.py +877 -0
  20. sinabs/backend/dynapcnn/sinabs_edges_handler.py +1024 -0
  21. sinabs/backend/dynapcnn/utils.py +214 -459
  22. sinabs/backend/dynapcnn/weight_rescaling_methods.py +53 -0
  23. sinabs/conversion.py +2 -2
  24. sinabs/from_torch.py +23 -1
  25. sinabs/hooks.py +38 -41
  26. sinabs/layers/alif.py +16 -16
  27. sinabs/layers/crop2d.py +2 -2
  28. sinabs/layers/exp_leak.py +1 -1
  29. sinabs/layers/iaf.py +11 -11
  30. sinabs/layers/lif.py +9 -9
  31. sinabs/layers/neuromorphic_relu.py +9 -8
  32. sinabs/layers/pool2d.py +5 -5
  33. sinabs/layers/quantize.py +1 -1
  34. sinabs/layers/stateful_layer.py +10 -7
  35. sinabs/layers/to_spike.py +9 -9
  36. sinabs/network.py +14 -12
  37. sinabs/nir.py +4 -3
  38. sinabs/synopcounter.py +10 -7
  39. sinabs/utils.py +155 -7
  40. sinabs/validate_memory_speck.py +0 -5
  41. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.1.dist-info}/METADATA +3 -2
  42. sinabs-3.1.1.dist-info/RECORD +65 -0
  43. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.1.dist-info}/licenses/AUTHORS +1 -0
  44. sinabs-3.1.1.dist-info/pbr.json +1 -0
  45. sinabs-3.0.4.dev25.dist-info/RECORD +0 -59
  46. sinabs-3.0.4.dev25.dist-info/pbr.json +0 -1
  47. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.1.dist-info}/WHEEL +0 -0
  48. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.1.dist-info}/licenses/LICENSE +0 -0
  49. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.1.dist-info}/top_level.txt +0 -0
@@ -1,450 +1,28 @@
1
+ from collections import defaultdict, deque
1
2
  from copy import deepcopy
2
- from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
3
+ from typing import TYPE_CHECKING, List, Optional, Set, Tuple, TypeVar, Union
3
4
 
4
5
  import torch
5
6
  import torch.nn as nn
7
+ import warnings
6
8
 
7
- import sinabs
8
9
  import sinabs.layers as sl
9
10
 
10
11
  from .crop2d import Crop2d
11
- from .dvs_layer import DVSLayer, expand_to_pair
12
- from .dynapcnn_layer import DynapcnnLayer
13
- from .exceptions import InputConfigurationError, MissingLayer, UnexpectedLayer
14
- from .flipdims import FlipDims
12
+ from .dvs_layer import DVSLayer
13
+ from .exceptions import InputConfigurationError
15
14
 
16
15
  if TYPE_CHECKING:
17
16
  from sinabs.backend.dynapcnn.dynapcnn_network import DynapcnnNetwork
18
17
 
19
- DEFAULT_IGNORED_LAYER_TYPES = (nn.Identity, nn.Dropout, nn.Dropout2d, nn.Flatten)
18
+ # Unlike `COMPLETELY_IGNORED_LAYER_TYPES`, `IGNORED_LAYER_TYPES` are
19
+ # part of the graph initially and are needed to ensure proper handling of
20
+ # graph structure (e.g. Merge nodes) or meta-information (e.g.
21
+ # `nn.Flatten` for io-shapes)
22
+ COMPLETELY_IGNORED_LAYER_TYPES = (nn.Identity, nn.Dropout, nn.Dropout2d)
23
+ IGNORED_LAYER_TYPES = (nn.Flatten, sl.Merge)
20
24
 
21
-
22
- def infer_input_shape(
23
- layers: List[nn.Module], input_shape: Optional[Tuple[int, int, int]] = None
24
- ) -> Tuple[int, int, int]:
25
- """Checks if the input_shape is specified. If either of them are specified, then it checks if
26
- the information is consistent and returns the input shape.
27
-
28
- Parameters
29
- ----------
30
- layers:
31
- List of modules
32
- input_shape :
33
- (channels, height, width)
34
-
35
- Returns
36
- -------
37
- Output shape:
38
- (channels, height, width)
39
- """
40
- if input_shape is not None and len(input_shape) != 3:
41
- raise InputConfigurationError(
42
- f"input_shape expected to have length 3 or None but input_shape={input_shape} given."
43
- )
44
-
45
- input_shape_from_layer = None
46
- if layers and isinstance(layers[0], DVSLayer):
47
- input_shape_from_layer = layers[0].input_shape
48
- if len(input_shape_from_layer) != 3:
49
- raise InputConfigurationError(
50
- f"input_shape of layer {layers[0]} expected to have length 3 or None but input_shape={input_shape_from_layer} found."
51
- )
52
- if (input_shape is not None) and (input_shape_from_layer is not None):
53
- if input_shape == input_shape_from_layer:
54
- return input_shape
55
- else:
56
- raise InputConfigurationError(
57
- f"Input shape from the layer {input_shape_from_layer} does not match the specified input_shape {input_shape}"
58
- )
59
- elif input_shape_from_layer is not None:
60
- return input_shape_from_layer
61
- elif input_shape is not None:
62
- return input_shape
63
- else:
64
- raise InputConfigurationError("No input shape could be inferred")
65
-
66
-
67
- def convert_cropping2dlayer_to_crop2d(
68
- layer: sl.Cropping2dLayer, input_shape: Tuple[int, int]
69
- ) -> Crop2d:
70
- """Convert a sinabs layer of type Cropping2dLayer to Crop2d layer.
71
-
72
- Parameters
73
- ----------
74
- layer:
75
- Cropping2dLayer
76
- input_shape:
77
- (height, width) input dimensions
78
-
79
- Returns
80
- -------
81
- Equivalent Crop2d layer
82
- """
83
- h, w = input_shape
84
- top = layer.top_crop
85
- left = layer.left_crop
86
- bottom = h - layer.bottom_crop
87
- right = w - layer.right_crop
88
- print(h, w, left, right, top, bottom, layer.right_crop, layer.bottom_crop)
89
- return Crop2d(((top, bottom), (left, right)))
90
-
91
-
92
- def construct_dvs_layer(
93
- layers: List[nn.Module],
94
- input_shape: Tuple[int, int, int],
95
- idx_start: int = 0,
96
- dvs_input: bool = False,
97
- ) -> Tuple[Optional[DVSLayer], int, float]:
98
- """
99
- Generate a DVSLayer given a list of layers. If `layers` does not start
100
- with a pooling, cropping or flipping layer and `dvs_input` is False,
101
- will return `None` instead of a DVSLayer.
102
- NOTE: The number of channels is implicitly assumed to be 2 because of DVS
103
-
104
- Parameters
105
- ----------
106
- layers:
107
- List of layers
108
- input_shape:
109
- Shape of input (channels, height, width)
110
- idx_start:
111
- Starting index to scan the list. Default 0
112
-
113
- Returns
114
- -------
115
- dvs_layer:
116
- None or DVSLayer
117
- idx_next: int or None
118
- Index of first layer after this layer is constructed
119
- rescale_factor: float
120
- Rescaling factor needed when turning AvgPool to SumPool. May
121
- differ from the pooling kernel in certain cases.
122
- dvs_input: bool
123
- Whether DVSLayer should have pixel array activated.
124
- """
125
- # Start with defaults
126
- layer_idx_next = idx_start
127
- crop_lyr = None
128
- flip_lyr = None
129
-
130
- if len(input_shape) != 3:
131
- raise ValueError(
132
- f"Input shape should be 3 dimensional but input_shape={input_shape} was given."
133
- )
134
-
135
- # Return existing DVS layer as is
136
- if len(layers) and isinstance(layers[0], DVSLayer):
137
- return deepcopy(layers[0]), 1, 1
138
-
139
- # Construct pooling layer
140
- pool_lyr, layer_idx_next, rescale_factor = construct_next_pooling_layer(
141
- layers, layer_idx_next
142
- )
143
-
144
- # Find next layer (check twice for two layers)
145
- for __ in range(2):
146
- # Go to the next layer
147
- if layer_idx_next < len(layers):
148
- layer = layers[layer_idx_next]
149
- else:
150
- break
151
- # Check layer type
152
- if isinstance(layer, sl.Cropping2dLayer):
153
- # The shape after pooling is
154
- pool = expand_to_pair(pool_lyr.kernel_size)
155
- h = input_shape[1] // pool[0]
156
- w = input_shape[2] // pool[1]
157
- print(f"Input shape to the cropping layer is {h}, {w}")
158
- crop_lyr = convert_cropping2dlayer_to_crop2d(layer, (h, w))
159
- elif isinstance(layer, Crop2d):
160
- crop_lyr = layer
161
- elif isinstance(layer, FlipDims):
162
- flip_lyr = layer
163
- else:
164
- break
165
-
166
- layer_idx_next += 1
167
-
168
- # If any parameters have been found or dvs_input is True
169
- if (layer_idx_next > 0) or dvs_input:
170
- dvs_layer = DVSLayer.from_layers(
171
- pool_layer=pool_lyr,
172
- crop_layer=crop_lyr,
173
- flip_layer=flip_lyr,
174
- input_shape=input_shape,
175
- disable_pixel_array=not dvs_input,
176
- )
177
- return dvs_layer, layer_idx_next, rescale_factor
178
- else:
179
- # No parameters/layers pertaining to DVS preprocessing found
180
- return None, 0, 1
181
-
182
-
183
- def merge_conv_bn(conv, bn):
184
- """Merge a convolutional layer with subsequent batch normalization.
185
-
186
- Parameters
187
- ----------
188
- conv: torch.nn.Conv2d
189
- Convolutional layer
190
- bn: torch.nn.Batchnorm2d
191
- Batch normalization
192
-
193
- Returns
194
- -------
195
- torch.nn.Conv2d: Convolutional layer including batch normalization
196
- """
197
- mu = bn.running_mean
198
- sigmasq = bn.running_var
199
-
200
- if bn.affine:
201
- gamma, beta = bn.weight, bn.bias
202
- else:
203
- gamma, beta = 1.0, 0.0
204
-
205
- factor = gamma / sigmasq.sqrt()
206
-
207
- c_weight = conv.weight.data.clone().detach()
208
- c_bias = 0.0 if conv.bias is None else conv.bias.data.clone().detach()
209
-
210
- conv = deepcopy(conv) # TODO: this will cause copying twice
211
-
212
- conv.weight.data = c_weight * factor[:, None, None, None]
213
- conv.bias.data = beta + (c_bias - mu) * factor
214
-
215
- return conv
216
-
217
-
218
- def construct_next_pooling_layer(
219
- layers: List[nn.Module], idx_start: int
220
- ) -> Tuple[Optional[sl.SumPool2d], int, float]:
221
- """Consolidate the first `AvgPool2d` objects in `layers` until the first object of different
222
- type.
223
-
224
- Parameters
225
- ----------
226
- layers: Sequence of layer objects
227
- Contains `AvgPool2d` and other objects.
228
- idx_start: int
229
- Layer index to start construction from
230
- Returns
231
- -------
232
- lyr_pool: int or tuple of ints
233
- Consolidated pooling size.
234
- idx_next: int
235
- Index of first object in `layers` that is not a `AvgPool2d`,
236
- rescale_factor: float
237
- Rescaling factor needed when turning AvgPool to SumPool. May
238
- differ from the pooling kernel in certain cases.
239
- """
240
-
241
- rescale_factor = 1
242
- cumulative_pooling = expand_to_pair(1)
243
-
244
- idx_next = idx_start
245
- # Figure out pooling dims
246
- while idx_next < len(layers):
247
- lyr = layers[idx_next]
248
- if isinstance(lyr, nn.AvgPool2d):
249
- if lyr.padding != 0:
250
- raise ValueError("Padding is not supported for the pooling layers")
251
- elif isinstance(lyr, sl.SumPool2d):
252
- ...
253
- else:
254
- # Reached a non pooling layer
255
- break
256
- # Increment if it is a pooling layer
257
- idx_next += 1
258
-
259
- pooling = expand_to_pair(lyr.kernel_size)
260
- if lyr.stride is not None:
261
- stride = expand_to_pair(lyr.stride)
262
- if pooling != stride:
263
- raise ValueError(
264
- f"Stride length {lyr.stride} should be the same as pooling kernel size {lyr.kernel_size}"
265
- )
266
- # Compute cumulative pooling
267
- cumulative_pooling = (
268
- cumulative_pooling[0] * pooling[0],
269
- cumulative_pooling[1] * pooling[1],
270
- )
271
- # Update rescaling factor
272
- if isinstance(lyr, nn.AvgPool2d):
273
- rescale_factor *= pooling[0] * pooling[1]
274
-
275
- # If there are no layers
276
- if cumulative_pooling == (1, 1):
277
- return None, idx_next, 1
278
- else:
279
- lyr_pool = sl.SumPool2d(cumulative_pooling)
280
- return lyr_pool, idx_next, rescale_factor
281
-
282
-
283
- def construct_next_dynapcnn_layer(
284
- layers: List[nn.Module],
285
- idx_start: int,
286
- in_shape: Tuple[int, int, int],
287
- discretize: bool,
288
- rescale_factor: float = 1,
289
- ) -> Tuple[DynapcnnLayer, int, float]:
290
- """Generate a DynapcnnLayer from a Conv2d layer and its subsequent spiking and pooling layers.
291
-
292
- Parameters
293
- ----------
294
-
295
- layers: sequence of layer objects
296
- First object must be Conv2d, next must be an IAF layer. All pooling
297
- layers that follow immediately are consolidated. Layers after this
298
- will be ignored.
299
- idx_start:
300
- Layer index to start construction from
301
- in_shape: tuple of integers
302
- Shape of the input to the first layer in `layers`. Convention:
303
- (input features, height, width)
304
- discretize: bool
305
- Discretize weights and thresholds if True
306
- rescale_factor: float
307
- Weights of Conv2d layer are scaled down by this factor. Can be
308
- used to account for preceding average pooling that gets converted
309
- to sum pooling.
310
-
311
- Returns
312
- -------
313
- dynapcnn_layer: DynapcnnLayer
314
- DynapcnnLayer
315
- layer_idx_next: int
316
- Index of the next layer after this layer is constructed
317
- rescale_factor: float
318
- rescaling factor to account for average pooling
319
- """
320
- layer_idx_next = idx_start # Keep track of layer indices
321
-
322
- # Check that the first layer is Conv2d, or Linear
323
- if not isinstance(layers[layer_idx_next], (nn.Conv2d, nn.Linear)):
324
- raise UnexpectedLayer(nn.Conv2d, layers[layer_idx_next])
325
-
326
- # Identify and consolidate conv layer
327
- lyr_conv = layers[layer_idx_next]
328
- layer_idx_next += 1
329
- if layer_idx_next >= len(layers):
330
- raise MissingLayer(layer_idx_next)
331
- # Check and consolidate batch norm
332
- if isinstance(layers[layer_idx_next], nn.BatchNorm2d):
333
- lyr_conv = merge_conv_bn(lyr_conv, layers[layer_idx_next])
334
- layer_idx_next += 1
335
-
336
- # Check next layer exists
337
- try:
338
- lyr_spk = layers[layer_idx_next]
339
- layer_idx_next += 1
340
- except IndexError:
341
- raise MissingLayer(layer_idx_next)
342
-
343
- # Check that the next layer is spiking
344
- # TODO: Check that the next layer is an IAF layer
345
- if not isinstance(lyr_spk, sl.IAF):
346
- raise TypeError(
347
- f"Convolution must be followed by IAF spiking layer, found {type(lyr_spk)}"
348
- )
349
-
350
- # Check for next pooling layer
351
- lyr_pool, i_next, rescale_factor_after_pooling = construct_next_pooling_layer(
352
- layers, layer_idx_next
353
- )
354
- # Increment layer index to after the pooling layers
355
- layer_idx_next = i_next
356
-
357
- # Compose DynapcnnLayer
358
- dynapcnn_layer = DynapcnnLayer(
359
- conv=lyr_conv,
360
- spk=lyr_spk,
361
- pool=lyr_pool,
362
- in_shape=in_shape,
363
- discretize=discretize,
364
- rescale_weights=rescale_factor,
365
- )
366
-
367
- return dynapcnn_layer, layer_idx_next, rescale_factor_after_pooling
368
-
369
-
370
- def build_from_list(
371
- layers: List[nn.Module],
372
- in_shape,
373
- discretize=True,
374
- dvs_input=False,
375
- ) -> nn.Sequential:
376
- """Build a sequential model of DVSLayer and DynapcnnLayer(s) given a list of layers comprising
377
- a spiking CNN.
378
-
379
- Parameters
380
- ----------
381
-
382
- layers: sequence of layer objects
383
- in_shape: tuple of integers
384
- Shape of the input to the first layer in `layers`. Convention:
385
- (channels, height, width)
386
- discretize: bool
387
- Discretize weights and thresholds if True
388
- dvs_input: bool
389
- Whether model should receive DVS input. If `True`, the returned model
390
- will begin with a DVSLayer with `disable_pixel_array` set to False.
391
- Otherwise, the model starts with a DVSLayer only if the first element
392
- in `layers` is a pooling, cropping or flipping layer.
393
-
394
- Returns
395
- -------
396
- nn.Sequential
397
- """
398
- compatible_layers = []
399
- lyr_indx_next = 0
400
- # Find and populate dvs layer (NOTE: We are ignoring the channel information here and could lead to problems)
401
- dvs_layer, lyr_indx_next, rescale_factor = construct_dvs_layer(
402
- layers, input_shape=in_shape, idx_start=lyr_indx_next, dvs_input=dvs_input
403
- )
404
- if dvs_layer is not None:
405
- compatible_layers.append(dvs_layer)
406
- in_shape = dvs_layer.get_output_shape()
407
- # Find and populate dynapcnn layers
408
- while lyr_indx_next < len(layers):
409
- if isinstance(layers[lyr_indx_next], DEFAULT_IGNORED_LAYER_TYPES):
410
- # - Ignore identity, dropout and flatten layers
411
- lyr_indx_next += 1
412
- continue
413
- dynapcnn_layer, lyr_indx_next, rescale_factor = construct_next_dynapcnn_layer(
414
- layers,
415
- lyr_indx_next,
416
- in_shape=in_shape,
417
- discretize=discretize,
418
- rescale_factor=rescale_factor,
419
- )
420
- in_shape = dynapcnn_layer.get_output_shape()
421
- compatible_layers.append(dynapcnn_layer)
422
-
423
- return nn.Sequential(*compatible_layers)
424
-
425
-
426
- def convert_model_to_layer_list(
427
- model: Union[nn.Sequential, sinabs.Network],
428
- ignore: Union[Type, Tuple[Type, ...]] = (),
429
- ) -> List[nn.Module]:
430
- """Convert a model to a list of layers.
431
-
432
- Parameters
433
- ----------
434
- model: nn.Sequential or sinabs.Network
435
- ignore: type or tuple of types of modules to be ignored
436
-
437
- Returns
438
- -------
439
- List[nn.Module]
440
- """
441
- if isinstance(model, sinabs.Network):
442
- return convert_model_to_layer_list(model.spiking_model)
443
- elif isinstance(model, nn.Sequential):
444
- layers = [layer for layer in model if not isinstance(layer, ignore)]
445
- else:
446
- raise TypeError("Expected torch.nn.Sequential or sinabs.Network")
447
- return layers
25
+ Edge = Tuple[int, int] # Define edge-type alias
448
26
 
449
27
 
450
28
  def parse_device_id(device_id: str) -> Tuple[str, int]:
@@ -497,6 +75,138 @@ def standardize_device_id(device_id: str) -> str:
497
75
  return get_device_id(device_type=device_type, index=index)
498
76
 
499
77
 
78
def topological_sorting(edges: Set[Tuple[int, int]]) -> List[int]:
    """Topologically sort the nodes of a graph using Kahn's algorithm.

    An entry node `X` of the graph may be flagged inside `edges` by a tuple
    `('input', X)`. Such a tuple only registers `X` as a node; the `'input'`
    marker itself never appears in the result. Source nodes that are not
    flagged this way are detected automatically.

    Args:
        edges (set): the edges `(source, target)` describing the *acyclic*
            graph.

    Returns:
        The nodes sorted by the graph's topology.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    graph = defaultdict(list)
    in_degree = defaultdict(int)

    # Initialize the graph and in-degrees.
    for u, v in edges:
        if u != "input":
            graph[u].append(v)
            in_degree[v] += 1
        elif v not in in_degree:
            # Flagged entry node: register it without adding an incoming edge.
            in_degree[v] = 0

    # Register source nodes that only appear on the left-hand side of an
    # edge. Without this, an unflagged entry node would never be visited and
    # a valid acyclic graph would be misreported as cyclic.
    for u in graph:
        if u not in in_degree:
            in_degree[u] = 0

    # Find all nodes with zero in-degree.
    zero_in_degree_nodes = deque(
        node for node, degree in in_degree.items() if degree == 0
    )

    # Process nodes and create the topological order.
    topological_order = []
    while zero_in_degree_nodes:
        node = zero_in_degree_nodes.popleft()
        topological_order.append(node)

        # `.get` avoids inserting empty entries for sink nodes.
        for neighbor in graph.get(node, ()):
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                zero_in_degree_nodes.append(neighbor)

    # Nodes on a cycle never reach in-degree zero and stay unprocessed.
    if len(topological_order) == len(in_degree):
        return topological_order

    raise ValueError("The graph has a cycle and cannot be topologically sorted.")
126
+
127
+
128
def convert_cropping2dlayer_to_crop2d(
    layer: sl.Cropping2dLayer, input_shape: Tuple[int, int]
) -> Crop2d:
    """Convert a sinabs layer of type Cropping2dLayer to Crop2d layer.

    The per-side crop amounts of `layer` are turned into bounds relative to
    the input: the lower bounds are `top_crop` / `left_crop`, the upper
    bounds are the input height / width minus `bottom_crop` / `right_crop`.

    Args:
        layer: Cropping2dLayer.
        input_shape: (height, width) input dimensions.

    Returns:
        Equivalent Crop2d layer, built from `((top, bottom), (left, right))`.
    """
    h, w = input_shape
    top = layer.top_crop
    left = layer.left_crop
    bottom = h - layer.bottom_crop
    right = w - layer.right_crop
    return Crop2d(((top, bottom), (left, right)))
147
+
148
+
149
WeightLayer = TypeVar("WeightLayer", nn.Linear, nn.Conv2d)


def merge_bn(
    weight_layer: WeightLayer, bn: Union[nn.BatchNorm1d, nn.BatchNorm2d]
) -> WeightLayer:
    """Merge a convolutional or linear layer with subsequent batch
    normalization.

    The returned layer computes the same function as applying `weight_layer`
    followed by `bn` in evaluation mode, i.e. using the running statistics
    of `bn`. Neither `weight_layer` nor `bn` is modified.

    Args:
        weight_layer: torch.nn.Conv2d or nn.Linear. Convolutional or linear
            layer.
        bn: torch.nn.Batchnorm2d or nn.Batchnorm1d. Batch normalization.

    Returns:
        Copy of the weight layer including batch normalization. It always
        has a bias, even if `weight_layer` has none.

    Raises:
        ValueError: If `bn` does not hold running statistics.
    """
    mu = bn.running_mean
    sigmasq = bn.running_var
    if mu is None or sigmasq is None:
        raise ValueError(
            "Batch normalization layer has no running statistics to merge."
        )

    if bn.affine:
        gamma, beta = bn.weight, bn.bias
    else:
        gamma, beta = 1.0, 0.0

    merged = deepcopy(weight_layer)

    # The merged parameters are plain data; keep them out of the autograd
    # graph of the batch-norm parameters.
    with torch.no_grad():
        # Batch norm divides by sqrt(var + eps); leaving out eps makes the
        # merged layer deviate, noticeably so for small running variances.
        factor = gamma / torch.sqrt(sigmasq + bn.eps)

        bias = 0.0 if weight_layer.bias is None else weight_layer.bias
        new_bias = beta + (bias - mu) * factor

        # One factor per output channel, broadcast over remaining weight dims.
        factor = factor.reshape(-1, *([1] * (weight_layer.weight.ndim - 1)))
        new_weight = weight_layer.weight * factor

    if merged.bias is None:
        merged.bias = nn.Parameter(new_bias)
    else:
        merged.bias.data = new_bias
    merged.weight.data = new_weight

    return merged
192
+
193
+
194
def merge_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Merge a convolutional layer with subsequent batch normalization.

    Deprecated: thin wrapper around `merge_bn`, kept for backward
    compatibility. Emits a `DeprecationWarning` on every call.

    Args:
        conv: torch.nn.Conv2d. Convolutional layer.
        bn: torch.nn.Batchnorm2d. Batch normalization.

    Returns:
        Convolutional layer including batch normalization.
    """
    warnings.warn(
        "`merge_conv_bn` is deprecated. Use `merge_bn` instead.",
        DeprecationWarning,
        # Attribute the warning to the caller, not to this wrapper.
        stacklevel=2,
    )
    return merge_bn(conv, bn)
208
+
209
+
500
210
  def extend_readout_layer(model: "DynapcnnNetwork") -> "DynapcnnNetwork":
501
211
  """Return a copied and extended model with the readout layer extended to 4 times the number of
502
212
  output channels. For Speck 2E and 2F, to get readout with correct output index, we need to
@@ -510,30 +220,75 @@ def extend_readout_layer(model: "DynapcnnNetwork") -> "DynapcnnNetwork":
510
220
  """
511
221
  model = deepcopy(model)
512
222
  input_shape = model.input_shape
513
- og_readout_conv_layer = model.sequence[
514
- -1
515
- ].conv_layer # extract the conv layer from dynapcnn network
516
- og_weight_data = og_readout_conv_layer.weight.data
517
- og_bias_data = og_readout_conv_layer.bias
518
- og_bias = og_bias_data is not None
519
- # modify the out channels
520
- og_out_channels = og_readout_conv_layer.out_channels
521
- new_out_channels = (og_out_channels - 1) * 4 + 1
522
- og_readout_conv_layer.out_channels = new_out_channels
523
- # build extended weight and replace the old one
524
- ext_weight_shape = (new_out_channels, *og_weight_data.shape[1:])
525
- ext_weight_data = torch.zeros(ext_weight_shape, dtype=og_weight_data.dtype)
526
- for i in range(og_out_channels):
527
- ext_weight_data[i * 4] = og_weight_data[i]
528
- og_readout_conv_layer.weight.data = ext_weight_data
529
- # build extended bias and replace if necessary
530
- if og_bias:
531
- ext_bias_shape = (new_out_channels,)
532
- ext_bias_data = torch.zeros(ext_bias_shape, dtype=og_bias_data.dtype)
223
+ for exit_layer in model.exit_layers:
224
+ # extract the conv layer from dynapcnn network
225
+ og_readout_conv_layer = exit_layer.conv_layer
226
+ og_weight_data = og_readout_conv_layer.weight.data
227
+ og_bias_data = og_readout_conv_layer.bias
228
+ og_bias = og_bias_data is not None
229
+ # modify the out channels
230
+ og_out_channels = og_readout_conv_layer.out_channels
231
+ new_out_channels = (og_out_channels - 1) * 4 + 1
232
+ og_readout_conv_layer.out_channels = new_out_channels
233
+ # build extended weight and replace the old one
234
+ ext_weight_shape = (new_out_channels, *og_weight_data.shape[1:])
235
+ ext_weight_data = torch.zeros(ext_weight_shape, dtype=og_weight_data.dtype)
533
236
  for i in range(og_out_channels):
534
- ext_bias_data[i * 4] = og_bias_data[i]
535
- og_readout_conv_layer.bias.data = ext_bias_data
536
- _ = model(
537
- torch.zeros(size=(1, *input_shape))
538
- ) # run a forward pass to initialize the new weights and last IAF
237
+ ext_weight_data[i * 4] = og_weight_data[i]
238
+ og_readout_conv_layer.weight.data = ext_weight_data
239
+ # build extended bias and replace if necessary
240
+ if og_bias:
241
+ ext_bias_shape = (new_out_channels,)
242
+ ext_bias_data = torch.zeros(ext_bias_shape, dtype=og_bias_data.dtype)
243
+ for i in range(og_out_channels):
244
+ ext_bias_data[i * 4] = og_bias_data[i]
245
+ og_readout_conv_layer.bias.data = ext_bias_data
246
+ # run a forward pass to initialize the new weights and last IAF
247
+ model(torch.zeros(size=(1, *input_shape)))
539
248
  return model
249
+
250
+
251
def infer_input_shape(
    snn: nn.Module, input_shape: Optional[Tuple[int, int, int]] = None
) -> Tuple[int, int, int]:
    """Infer expected shape of input for `snn` either from `input_shape`
    or from a `DVSLayer` instance within `snn` which provides it.

    If neither is available, raise an InputConfigurationError.
    If both are available, verify that the information is consistent.

    Args:
        snn (nn.Module): The SNN whose input shape is to be inferred.
        input_shape (tuple or None): Explicitly provide input shape.
            If not None, must be of the format `(channels, height, width)`.

    Returns:
        The input shape to `snn`, in the format `(channels, height, width)`.

    Raises:
        InputConfigurationError: If `input_shape` does not have length 3,
            if it contradicts the input shape of the first `DVSLayer` found
            in `snn`, or if no input shape can be determined at all.
    """
    if input_shape is not None and len(input_shape) != 3:
        raise InputConfigurationError(
            f"input_shape expected to have length 3 or None but input_shape={input_shape} given."
        )

    # Find the first `DVSLayer` instance and infer input shape from it
    for module in snn.modules():
        if isinstance(module, DVSLayer):
            input_shape_from_layer = module.input_shape
            # Make sure `input_shape_from_layer` matches a provided
            # `input_shape`. Compare as tuples so that an equivalent list
            # (or other sequence) is not rejected as a mismatch.
            if input_shape is not None and tuple(input_shape) != tuple(
                input_shape_from_layer
            ):
                raise InputConfigurationError(
                    f"Input shape from `DVSLayer` {input_shape_from_layer} does "
                    f"not match the specified input_shape {input_shape}"
                )
            return input_shape_from_layer

    # If no `DVSLayer` is found, `input_shape` must be provided explicitly
    if input_shape is None:
        raise InputConfigurationError(
            "No input shape could be inferred. Either provide it explicitly "
            "with the `input_shape` argument, or provide a model with "
            "`DVSLayer` instance."
        )
    return input_shape