sinabs 3.0.4.dev25__py3-none-any.whl → 3.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. sinabs/activation/reset_mechanism.py +3 -3
  2. sinabs/activation/surrogate_gradient_fn.py +4 -4
  3. sinabs/backend/dynapcnn/__init__.py +5 -4
  4. sinabs/backend/dynapcnn/chip_factory.py +33 -61
  5. sinabs/backend/dynapcnn/chips/dynapcnn.py +182 -86
  6. sinabs/backend/dynapcnn/chips/speck2e.py +6 -5
  7. sinabs/backend/dynapcnn/chips/speck2f.py +6 -5
  8. sinabs/backend/dynapcnn/config_builder.py +39 -59
  9. sinabs/backend/dynapcnn/connectivity_specs.py +48 -0
  10. sinabs/backend/dynapcnn/discretize.py +91 -155
  11. sinabs/backend/dynapcnn/dvs_layer.py +59 -101
  12. sinabs/backend/dynapcnn/dynapcnn_layer.py +185 -119
  13. sinabs/backend/dynapcnn/dynapcnn_layer_utils.py +335 -0
  14. sinabs/backend/dynapcnn/dynapcnn_network.py +602 -325
  15. sinabs/backend/dynapcnn/dynapcnnnetwork_module.py +370 -0
  16. sinabs/backend/dynapcnn/exceptions.py +122 -3
  17. sinabs/backend/dynapcnn/io.py +51 -91
  18. sinabs/backend/dynapcnn/mapping.py +111 -75
  19. sinabs/backend/dynapcnn/nir_graph_extractor.py +877 -0
  20. sinabs/backend/dynapcnn/sinabs_edges_handler.py +1024 -0
  21. sinabs/backend/dynapcnn/utils.py +214 -459
  22. sinabs/backend/dynapcnn/weight_rescaling_methods.py +53 -0
  23. sinabs/conversion.py +2 -2
  24. sinabs/from_torch.py +23 -1
  25. sinabs/hooks.py +38 -41
  26. sinabs/layers/alif.py +16 -16
  27. sinabs/layers/crop2d.py +2 -2
  28. sinabs/layers/exp_leak.py +1 -1
  29. sinabs/layers/iaf.py +11 -11
  30. sinabs/layers/lif.py +9 -9
  31. sinabs/layers/neuromorphic_relu.py +9 -8
  32. sinabs/layers/pool2d.py +5 -5
  33. sinabs/layers/quantize.py +1 -1
  34. sinabs/layers/stateful_layer.py +10 -7
  35. sinabs/layers/to_spike.py +9 -9
  36. sinabs/network.py +14 -12
  37. sinabs/synopcounter.py +10 -7
  38. sinabs/utils.py +155 -7
  39. sinabs/validate_memory_speck.py +0 -5
  40. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.0.dist-info}/METADATA +2 -1
  41. sinabs-3.1.0.dist-info/RECORD +65 -0
  42. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.0.dist-info}/licenses/AUTHORS +1 -0
  43. sinabs-3.1.0.dist-info/pbr.json +1 -0
  44. sinabs-3.0.4.dev25.dist-info/RECORD +0 -59
  45. sinabs-3.0.4.dev25.dist-info/pbr.json +0 -1
  46. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.0.dist-info}/WHEEL +0 -0
  47. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.0.dist-info}/licenses/LICENSE +0 -0
  48. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.0.dist-info}/top_level.txt +0 -0
@@ -3,49 +3,26 @@ from typing import Optional, Tuple
3
3
  import torch.nn as nn
4
4
 
5
5
  from sinabs.layers import SumPool2d
6
+ from sinabs.utils import expand_to_pair
6
7
 
7
8
  from .crop2d import Crop2d
8
9
  from .flipdims import FlipDims
9
10
 
10
11
 
11
- def expand_to_pair(value) -> (int, int):
12
- """Expand a given value to a pair (tuple) if an int is passed.
13
-
14
- Parameters
15
- ----------
16
- value:
17
- int
18
-
19
- Returns
20
- -------
21
- pair:
22
- (int, int)
23
- """
24
- return (value, value) if isinstance(value, int) else value
25
-
26
-
27
12
  class DVSLayer(nn.Module):
28
13
  """DVSLayer representing the DVS pixel array on chip and/or the pre-processing. The order of
29
14
  processing is as follows MergePolarity -> Pool -> Cut -> Flip.
30
15
 
31
- Parameters
32
- ----------
33
- input_shape;
34
- Shape of input (height, width)
35
- pool:
36
- Sum pooling kernel size (height, width)
37
- crop:
38
- Crop the input to the given ROI ((top, bottom), (left, right))
39
- merge_polarities:
40
- If true, events from both polarities will be merged.
41
- flip_x:
42
- Flip the X axis
43
- flip_y:
44
- Flip the Y axis
45
- swap_xy:
46
- Swap X and Y dimensions
47
- disable_pixel_array:
48
- Disable the pixel array. This is useful if you want to use the DVS layer for input preprocessing.
16
+ Args:
17
+ input_shape: Shape of input (height, width).
18
+ pool: Sum pooling kernel size (height, width).
19
+ crop: Crop the input to the given ROI ((top, bottom), (left, right)).
20
+ merge_polarities: If true, events from both polarities will be merged.
21
+ flip_x: Flip the X axis.
22
+ flip_y: Flip the Y axis.
23
+ swap_xy: Swap X and Y dimensions.
24
+ disable_pixel_array: Disable the pixel array. This is useful if you want to use the
25
+ DVS layer for input preprocessing.
49
26
  """
50
27
 
51
28
  def __init__(
@@ -97,22 +74,15 @@ class DVSLayer(nn.Module):
97
74
  ) -> "DVSLayer":
98
75
  """Alternative factory method. Generate a DVSLayer from a set of torch layers.
99
76
 
100
- Parameters
101
- ----------
102
- input_shape:
103
- (channels, height, width)
104
- pool_layer:
105
- SumPool2d layer
106
- crop_layer:
107
- Crop2d layer
108
- flip_layer:
109
- FlipDims layer
110
- disable_pixel_array:
111
- Whether pixel array of new DVSLayer should be disabled.
112
-
113
- Returns
114
- -------
115
- DVSLayer
77
+ Args:
78
+ input_shape: (channels, height, width).
79
+ pool_layer: SumPool2d layer.
80
+ crop_layer: Crop2d layer.
81
+ flip_layer: FlipDims layer.
82
+ disable_pixel_array: Whether pixel array of new DVSLayer should be disabled.
83
+
84
+ Returns:
85
+ DVSLayer
116
86
  """
117
87
  pool = (1, 1)
118
88
  crop = None
@@ -156,9 +126,8 @@ class DVSLayer(nn.Module):
156
126
  def input_shape_dict(self) -> dict:
157
127
  """The configuration dictionary for the input shape.
158
128
 
159
- Returns
160
- -------
161
- dict
129
+ Returns:
130
+ dict
162
131
  """
163
132
  channel_count, input_size_y, input_size_x = self.input_shape
164
133
 
@@ -173,9 +142,8 @@ class DVSLayer(nn.Module):
173
142
  def get_output_shape_after_pooling(self) -> Tuple[int, int, int]:
174
143
  """Get the shape of data just after the pooling layer.
175
144
 
176
- Returns
177
- -------
178
- (channel, height, width)
145
+ Returns:
146
+ (channel, height, width)
179
147
  """
180
148
  channel_count, input_size_y, input_size_x = self.input_shape
181
149
 
@@ -191,9 +159,8 @@ class DVSLayer(nn.Module):
191
159
  def get_output_shape_dict(self) -> dict:
192
160
  """Configuration dictionary for output shape.
193
161
 
194
- Returns
195
- -------
196
- dict
162
+ Returns:
163
+ dict
197
164
  """
198
165
  (
199
166
  channel_count,
@@ -202,14 +169,13 @@ class DVSLayer(nn.Module):
202
169
  ) = self.get_output_shape_after_pooling()
203
170
 
204
171
  # Compute dims after cropping
205
- if self.crop_layer is not None:
206
- (
207
- channel_count,
208
- output_size_y,
209
- output_size_x,
210
- ) = self.crop_layer.get_output_shape(
211
- (channel_count, output_size_y, output_size_x)
212
- )
172
+ (
173
+ channel_count,
174
+ output_size_y,
175
+ output_size_x,
176
+ ) = self.crop_layer.get_output_shape(
177
+ (channel_count, output_size_y, output_size_x)
178
+ )
213
179
 
214
180
  # Compute dims after pooling
215
181
  return {
@@ -237,11 +203,13 @@ class DVSLayer(nn.Module):
237
203
  # Merge polarities
238
204
  if self.merge_polarities:
239
205
  data = data.sum(1, keepdim=True)
206
+
240
207
  # Pool
241
208
  out = self.pool_layer(data)
209
+
242
210
  # Crop
243
- if self.crop_layer is not None:
244
- out = self.crop_layer(out)
211
+ out = self.crop_layer(out)
212
+
245
213
  # Flip stuff
246
214
  out = self.flip_layer(out)
247
215
 
@@ -250,9 +218,8 @@ class DVSLayer(nn.Module):
250
218
  def get_pooling(self) -> Tuple[int, int]:
251
219
  """Pooling kernel shape.
252
220
 
253
- Returns
254
- -------
255
- (ky, kx)
221
+ Returns:
222
+ (ky, kx)
256
223
  """
257
224
  return expand_to_pair(self.pool_layer.kernel_size)
258
225
 
@@ -260,26 +227,20 @@ class DVSLayer(nn.Module):
260
227
  """The coordinates for ROI. Note that this is not the same as crop parameter passed during
261
228
  the object construction.
262
229
 
263
- Returns
264
- -------
265
- ((top, bottom), (left, right))
230
+ Returns:
231
+ ((top, bottom), (left, right))
266
232
  """
267
- if self.crop_layer is not None:
268
- _, h, w = self.get_output_shape_after_pooling()
269
- return (
270
- (self.crop_layer.top_crop, self.crop_layer.bottom_crop),
271
- (self.crop_layer.left_crop, self.crop_layer.right_crop),
272
- )
273
- else:
274
- _, output_size_y, output_size_x = self.get_output_shape()
275
- return (0, output_size_y), (0, output_size_x)
233
+ _, h, w = self.get_output_shape_after_pooling()
234
+ return (
235
+ (self.crop_layer.top_crop, self.crop_layer.bottom_crop),
236
+ (self.crop_layer.left_crop, self.crop_layer.right_crop),
237
+ )
276
238
 
277
239
  def get_output_shape(self) -> Tuple[int, int, int]:
278
240
  """Output shape of the layer.
279
241
 
280
- Returns
281
- -------
282
- (channel, height, width)
242
+ Returns:
243
+ (channel, height, width)
283
244
  """
284
245
  channel_count, input_size_y, input_size_x = self.input_shape
285
246
 
@@ -292,23 +253,21 @@ class DVSLayer(nn.Module):
292
253
  output_size_y = input_size_y // pooling[0]
293
254
 
294
255
  # Compute dims after cropping
295
- if self.crop_layer is not None:
296
- (
297
- channel_count,
298
- output_size_y,
299
- output_size_x,
300
- ) = self.crop_layer.get_output_shape(
301
- (channel_count, output_size_y, output_size_x)
302
- )
256
+ (
257
+ channel_count,
258
+ output_size_y,
259
+ output_size_x,
260
+ ) = self.crop_layer.get_output_shape(
261
+ (channel_count, output_size_y, output_size_x)
262
+ )
303
263
 
304
264
  return channel_count, output_size_y, output_size_x
305
265
 
306
266
  def get_flip_dict(self) -> dict:
307
267
  """Configuration dictionary for x, y flip.
308
268
 
309
- Returns
310
- -------
311
- dict
269
+ Returns:
270
+ dict
312
271
  """
313
272
 
314
273
  return {"x": self.flip_layer.flip_x, "y": self.flip_layer.flip_y}
@@ -316,8 +275,7 @@ class DVSLayer(nn.Module):
316
275
  def get_swap_xy(self) -> bool:
317
276
  """True if XY has to be swapped.
318
277
 
319
- Returns
320
- -------
321
- bool
278
+ Returns:
279
+ bool
322
280
  """
323
281
  return self.flip_layer.swap_xy
@@ -1,40 +1,76 @@
1
1
  from copy import deepcopy
2
- from typing import Dict, Optional, Tuple, Union
3
- from warnings import warn
2
+ from functools import partial
3
+ from typing import List, Tuple
4
4
 
5
5
  import numpy as np
6
6
  import torch
7
7
  from torch import nn
8
8
 
9
- import sinabs.activation
10
9
  import sinabs.layers as sl
11
10
 
12
11
  from .discretize import discretize_conv_spike_
13
- from .dvs_layer import expand_to_pair
12
+
13
+ # Define sum pooling functional as power-average pooling with power 1
14
+ sum_pool2d = partial(nn.functional.lp_pool2d, norm_type=1)
15
+
16
+
17
+ def convert_linear_to_conv(
18
+ lin: nn.Linear, input_shape: Tuple[int, int, int]
19
+ ) -> nn.Conv2d:
20
+ """Convert Linear layer to Conv2d.
21
+
22
+ Args:
23
+ lin (nn.Linear): linear layer to be converted.
24
+ input_shape (tuple): the tensor shape the layer expects.
25
+
26
+ Returns:
27
+ convolutional layer equivalent to `lin`.
28
+ """
29
+ in_chan, in_h, in_w = input_shape
30
+ if lin.in_features != in_chan * in_h * in_w:
31
+ raise ValueError(
32
+ "Shape of linear layer weight does not match provided input shape"
33
+ )
34
+
35
+ layer = nn.Conv2d(
36
+ in_channels=in_chan,
37
+ kernel_size=(in_h, in_w),
38
+ out_channels=lin.out_features,
39
+ padding=0,
40
+ bias=lin.bias is not None,
41
+ )
42
+
43
+ if lin.bias is not None:
44
+ layer.bias.data = lin.bias.data.clone().detach()
45
+
46
+ layer.weight.data = (
47
+ lin.weight.data.clone()
48
+ .detach()
49
+ .reshape((lin.out_features, in_chan, in_h, in_w))
50
+ )
51
+
52
+ return layer
14
53
 
15
54
 
16
55
  class DynapcnnLayer(nn.Module):
17
- """Create a DynapcnnLayer object representing a dynapcnn layer.
18
-
19
- Requires a convolutional layer, a sinabs spiking layer and an optional
20
- pooling value. The layers are used in the order conv -> spike -> pool.
21
-
22
- Parameters
23
- ----------
24
- conv: torch.nn.Conv2d or torch.nn.Linear
25
- Convolutional or linear layer (linear will be converted to convolutional)
26
- spk: sinabs.layers.IAFSqueeze
27
- Sinabs IAF layer
28
- in_shape: tuple of int
29
- The input shape, needed to create dynapcnn configs if the network does not
30
- contain an input layer. Convention: (features, height, width)
31
- pool: int or None
32
- Integer representing the sum pooling kernel and stride. If `None`, no
33
- pooling will be applied.
34
- discretize: bool
35
- Whether to discretize parameters.
36
- rescale_weights: int
37
- Layer weights will be divided by this value.
56
+ """Create a DynapcnnLayer object representing a layer on DynapCNN or Speck.
57
+
58
+ Requires a convolutional layer, a sinabs spiking layer and a list of
59
+ pooling values. The layers are used in the order conv -> spike -> pool.
60
+
61
+ Attributes:
62
+ conv: torch.nn.Conv2d or torch.nn.Linear. Convolutional or linear layer.
63
+ Linear will be converted to convolutional.
64
+ spk (sinabs.layers.IAFSqueeze): Sinabs IAF layer.
65
+ in_shape (tuple of int): The input shape, needed to create dynapcnn configs
66
+ if the network does not contain an input layer.
67
+ Convention: (features, height, width).
68
+ pool (List of integers): Each integer entry represents an output (destination
69
+ on chip) and whether pooling should be applied (values > 1) or not
70
+ (values equal to 1). The number of entries determines the number of tensors
71
+ the layer's forward method returns.
72
+ discretize (bool): Whether to discretize parameters.
73
+ rescale_weights (int): Layer weights will be multiplied by this value.
38
74
  """
39
75
 
40
76
  def __init__(
@@ -42,126 +78,136 @@ class DynapcnnLayer(nn.Module):
42
78
  conv: nn.Conv2d,
43
79
  spk: sl.IAFSqueeze,
44
80
  in_shape: Tuple[int, int, int],
45
- pool: Optional[sl.SumPool2d] = None,
81
+ pool: List[int],
46
82
  discretize: bool = True,
47
83
  rescale_weights: int = 1,
48
84
  ):
49
85
  super().__init__()
50
86
 
51
- self.input_shape = in_shape
87
+ self.in_shape = in_shape
88
+ self.pool = pool
89
+ self._discretize = discretize
90
+ self._rescale_weights = rescale_weights
52
91
 
92
+ if not isinstance(spk, sl.IAFSqueeze):
93
+ raise TypeError(
94
+ f"Unsupported spiking layer type {type(spk)}. "
95
+ "Only `IAFSqueeze` layers are supported."
96
+ )
53
97
  spk = deepcopy(spk)
98
+
99
+ # Convert `nn.Linear` to `nn.Conv2d`.
54
100
  if isinstance(conv, nn.Linear):
55
- conv = self._convert_linear_to_conv(conv)
56
- if spk.is_state_initialised():
57
- # Expand dims
58
- spk.v_mem = spk.v_mem.data.unsqueeze(-1).unsqueeze(-1)
101
+ conv = convert_linear_to_conv(conv, in_shape)
102
+ if spk.is_state_initialised() and (ndim := spk.v_mem.ndim) < 4:
103
+ for __ in range(4 - ndim):
104
+ # Expand spatial dimensions
105
+ spk.v_mem = spk.v_mem.data.unsqueeze(-1)
59
106
  else:
60
107
  conv = deepcopy(conv)
61
108
 
62
- if rescale_weights != 1:
109
+ if self._rescale_weights != 1:
63
110
  # this has to be done after copying but before discretizing
64
- conv.weight.data = (conv.weight / rescale_weights).clone().detach()
111
+ conv.weight.data = (conv.weight * self._rescale_weights).clone().detach()
112
+
113
+ # check if convolution kernel is a square.
114
+ if conv.kernel_size[0] != conv.kernel_size[1]:
115
+ raise ValueError(
116
+ "The kernel of a `nn.Conv2d` must have the same height and width."
117
+ )
118
+ for pool_size in pool:
119
+ if pool_size[0] != pool_size[1]:
120
+ raise ValueError("Only square pooling kernels are supported")
65
121
 
66
- self.discretize = discretize
67
- if discretize:
68
- # int conversion is done while writing the config.
122
+ # int conversion is done while writing the config.
123
+ if self._discretize:
69
124
  conv, spk = discretize_conv_spike_(conv, spk, to_int=False)
70
125
 
71
- self.conv_layer = conv
72
- self.spk_layer = spk
73
- if pool is not None:
74
- if pool.kernel_size[0] != pool.kernel_size[1]:
75
- raise ValueError("Only square kernels are supported")
76
- self.pool_layer = deepcopy(pool)
77
- else:
78
- self.pool_layer = None
126
+ self.conv = conv
127
+ self.spk = spk
79
128
 
80
- def _convert_linear_to_conv(self, lin: nn.Linear) -> nn.Conv2d:
81
- """Convert Linear layer to Conv2d.
129
+ @property
130
+ def conv_layer(self):
131
+ return self.conv
82
132
 
83
- Parameters
84
- ----------
85
- lin: nn.Linear
86
- Linear layer to be converted
133
+ @property
134
+ def spk_layer(self):
135
+ return self.spk
87
136
 
88
- Returns
89
- -------
90
- nn.Conv2d
91
- Convolutional layer equivalent to `lin`.
92
- """
137
+ @property
138
+ def discretize(self):
139
+ return self._discretize
93
140
 
94
- in_chan, in_h, in_w = self.input_shape
141
+ @property
142
+ def rescale_weights(self):
143
+ return self._rescale_weights
95
144
 
96
- if lin.in_features != in_chan * in_h * in_w:
97
- raise ValueError("Shapes don't match.")
145
+ @property
146
+ def conv_out_shape(self):
147
+ return self._get_conv_output_shape()
98
148
 
99
- layer = nn.Conv2d(
100
- in_channels=in_chan,
101
- kernel_size=(in_h, in_w),
102
- out_channels=lin.out_features,
103
- padding=0,
104
- bias=lin.bias is not None,
105
- )
149
+ def forward(self, x) -> List[torch.Tensor]:
150
+ """Torch forward pass.
106
151
 
107
- if lin.bias is not None:
108
- layer.bias.data = lin.bias.data.clone().detach()
152
+ ...
153
+ """
109
154
 
110
- layer.weight.data = (
111
- lin.weight.data.clone()
112
- .detach()
113
- .reshape((lin.out_features, in_chan, in_h, in_w))
114
- )
155
+ returns = []
115
156
 
116
- return layer
157
+ x = self.conv_layer(x)
158
+ x = self.spk_layer(x)
159
+
160
+ for pool in self.pool:
161
+ if pool == 1:
162
+ # no pooling is applied.
163
+ returns.append(x)
164
+ else:
165
+ # sum pooling of `(pool, pool)` is applied.
166
+ pool_out = sum_pool2d(x, kernel_size=pool)
167
+ returns.append(pool_out)
168
+
169
+ if len(returns) == 1:
170
+ return returns[0]
171
+ else:
172
+ return tuple(returns)
173
+
174
+ def zero_grad(self, set_to_none: bool = False) -> None:
175
+ """Call `zero_grad` method of spiking layer"""
176
+ return self.spk.zero_grad(set_to_none)
117
177
 
118
178
  def get_neuron_shape(self) -> Tuple[int, int, int]:
119
179
  """Return the output shape of the neuron layer.
120
180
 
121
- Returns
122
- -------
123
- features, height, width
181
+ Returns:
182
+ conv_out_shape (tuple): formatted as (features, height, width).
124
183
  """
184
+ # same as the convolution's output.
185
+ return self._get_conv_output_shape()
125
186
 
126
- def get_shape_after_conv(layer: nn.Conv2d, input_shape):
127
- (ch_in, h_in, w_in) = input_shape
128
- (kh, kw) = expand_to_pair(layer.kernel_size)
129
- (pad_h, pad_w) = expand_to_pair(layer.padding)
130
- (stride_h, stride_w) = expand_to_pair(layer.stride)
131
-
132
- def out_len(in_len, k, s, p):
133
- return (in_len - k + 2 * p) // s + 1
134
-
135
- out_h = out_len(h_in, kh, stride_h, pad_h)
136
- out_w = out_len(w_in, kw, stride_w, pad_w)
137
- ch_out = layer.out_channels
138
- return ch_out, out_h, out_w
187
+ def get_output_shape(self) -> List[Tuple[int, int, int]]:
188
+ """Return the output shapes of the layer, including pooling.
139
189
 
140
- conv_out_shape = get_shape_after_conv(
141
- self.conv_layer, input_shape=self.input_shape
142
- )
143
- return conv_out_shape
144
-
145
- def get_output_shape(self) -> Tuple[int, int, int]:
190
+ Returns:
191
+ One entry per destination, each formatted as (features, height, width).
192
+ """
146
193
  neuron_shape = self.get_neuron_shape()
147
194
  # this is the actual output shape, including pooling
148
- if self.pool_layer is not None:
149
- pool = expand_to_pair(self.pool_layer.kernel_size)
150
- return (
195
+ output_shape = []
196
+ for pool in self.pool:
197
+ output_shape.append(
151
198
  neuron_shape[0],
152
- neuron_shape[1] // pool[0],
153
- neuron_shape[2] // pool[1],
199
+ neuron_shape[1] // pool,
200
+ neuron_shape[2] // pool,
154
201
  )
155
- else:
156
- return neuron_shape
202
+ return output_shape
157
203
 
158
204
  def summary(self) -> dict:
205
+ """Returns a summary of the convolution's/pooling's kernel sizes and the output shape of the spiking layer."""
206
+
159
207
  return {
160
- "pool": (
161
- None if self.pool_layer is None else list(self.pool_layer.kernel_size)
162
- ),
208
+ "pool": (self.pool),
163
209
  "kernel": list(self.conv_layer.weight.data.shape),
164
- "neuron": self.get_neuron_shape(),
210
+ "neuron": self._get_conv_output_shape(), # neuron layer output has the same shape as the convolution layer ouput.
165
211
  }
166
212
 
167
213
  def memory_summary(self):
@@ -177,13 +223,18 @@ class DynapcnnLayer(nn.Module):
177
223
 
178
224
  N_{MT} = f \\cdot 2^{ \\lceil \\log_2\\left(f_y\\right) \\rceil + \\lceil \\log_2\\left(f_x\\right) \\rceil }
179
225
 
180
- Returns
181
- -------
182
- A dictionary with keys kernel, neuron and bias and the corresponding memory sizes
226
+ Returns:
227
+ A dictionary with keys kernel, neuron and bias and the corresponding memory sizes
183
228
  """
184
229
  summary = self.summary()
185
230
  f, c, h, w = summary["kernel"]
186
- f, neuron_height, neuron_width = self.get_neuron_shape()
231
+ (
232
+ f,
233
+ neuron_height,
234
+ neuron_width,
235
+ ) = (
236
+ self._get_conv_output_shape()
237
+ ) # neuron layer output has the same shape as the convolution layer ouput.
187
238
 
188
239
  return {
189
240
  "kernel": c * pow(2, np.ceil(np.log2(h * w)) + np.ceil(np.log2(f))),
@@ -192,13 +243,28 @@ class DynapcnnLayer(nn.Module):
192
243
  "bias": 0 if self.conv_layer.bias is None else len(self.conv_layer.bias),
193
244
  }
194
245
 
195
- def forward(self, x):
196
- """Torch forward pass."""
197
- x = self.conv_layer(x)
198
- x = self.spk_layer(x)
199
- if self.pool_layer is not None:
200
- x = self.pool_layer(x)
201
- return x
246
+ def _get_conv_output_shape(self) -> Tuple[int, int, int]:
247
+ """Computes the output dimensions of `conv_layer`.
202
248
 
203
- def zero_grad(self, set_to_none: bool = False) -> None:
204
- return self.spk_layer.zero_grad(set_to_none)
249
+ Returns:
250
+ output dimensions (tuple): a tuple describing `(output channels, height, width)`.
251
+ """
252
+ # get the layer's parameters.
253
+
254
+ out_channels = self.conv_layer.out_channels
255
+ kernel_size = self.conv_layer.kernel_size
256
+ stride = self.conv_layer.stride
257
+ padding = self.conv_layer.padding
258
+ dilation = self.conv_layer.dilation
259
+
260
+ # compute the output height and width.
261
+ out_height = (
262
+ (self.in_shape[1] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1)
263
+ // stride[0]
264
+ ) + 1
265
+ out_width = (
266
+ (self.in_shape[2] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1)
267
+ // stride[1]
268
+ ) + 1
269
+
270
+ return (out_channels, out_height, out_width)