sinabs 3.0.4.dev25__py3-none-any.whl → 3.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. sinabs/activation/reset_mechanism.py +3 -3
  2. sinabs/activation/surrogate_gradient_fn.py +4 -4
  3. sinabs/backend/dynapcnn/__init__.py +5 -4
  4. sinabs/backend/dynapcnn/chip_factory.py +33 -61
  5. sinabs/backend/dynapcnn/chips/dynapcnn.py +182 -86
  6. sinabs/backend/dynapcnn/chips/speck2e.py +6 -5
  7. sinabs/backend/dynapcnn/chips/speck2f.py +6 -5
  8. sinabs/backend/dynapcnn/config_builder.py +39 -59
  9. sinabs/backend/dynapcnn/connectivity_specs.py +48 -0
  10. sinabs/backend/dynapcnn/discretize.py +91 -155
  11. sinabs/backend/dynapcnn/dvs_layer.py +59 -101
  12. sinabs/backend/dynapcnn/dynapcnn_layer.py +185 -119
  13. sinabs/backend/dynapcnn/dynapcnn_layer_utils.py +335 -0
  14. sinabs/backend/dynapcnn/dynapcnn_network.py +602 -325
  15. sinabs/backend/dynapcnn/dynapcnnnetwork_module.py +370 -0
  16. sinabs/backend/dynapcnn/exceptions.py +122 -3
  17. sinabs/backend/dynapcnn/io.py +55 -92
  18. sinabs/backend/dynapcnn/mapping.py +111 -75
  19. sinabs/backend/dynapcnn/nir_graph_extractor.py +877 -0
  20. sinabs/backend/dynapcnn/sinabs_edges_handler.py +1024 -0
  21. sinabs/backend/dynapcnn/utils.py +214 -459
  22. sinabs/backend/dynapcnn/weight_rescaling_methods.py +53 -0
  23. sinabs/conversion.py +2 -2
  24. sinabs/from_torch.py +23 -1
  25. sinabs/hooks.py +38 -41
  26. sinabs/layers/alif.py +16 -16
  27. sinabs/layers/crop2d.py +2 -2
  28. sinabs/layers/exp_leak.py +1 -1
  29. sinabs/layers/iaf.py +11 -11
  30. sinabs/layers/lif.py +9 -9
  31. sinabs/layers/neuromorphic_relu.py +9 -8
  32. sinabs/layers/pool2d.py +5 -5
  33. sinabs/layers/quantize.py +1 -1
  34. sinabs/layers/stateful_layer.py +10 -7
  35. sinabs/layers/to_spike.py +9 -9
  36. sinabs/network.py +14 -12
  37. sinabs/nir.py +4 -3
  38. sinabs/synopcounter.py +10 -7
  39. sinabs/utils.py +155 -7
  40. sinabs/validate_memory_speck.py +0 -5
  41. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.1.dist-info}/METADATA +3 -2
  42. sinabs-3.1.1.dist-info/RECORD +65 -0
  43. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.1.dist-info}/licenses/AUTHORS +1 -0
  44. sinabs-3.1.1.dist-info/pbr.json +1 -0
  45. sinabs-3.0.4.dev25.dist-info/RECORD +0 -59
  46. sinabs-3.0.4.dev25.dist-info/pbr.json +0 -1
  47. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.1.dist-info}/WHEEL +0 -0
  48. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.1.dist-info}/licenses/LICENSE +0 -0
  49. {sinabs-3.0.4.dev25.dist-info → sinabs-3.1.1.dist-info}/top_level.txt +0 -0
sinabs/backend/dynapcnn/weight_rescaling_methods.py ADDED
@@ -0,0 +1,53 @@
+ import statistics
+ from typing import Iterable
+
+ import numpy as np
+
+
+ def rescale_method_1(scaling_factors: Iterable[int], lambda_: float = 0.5) -> float:
+     """
+     This method will use the average (scaled by `lambda_`) of the computed re-scaling factor
+     for the pooling layer(s) feeding into a convolutional layer.
+
+     Arguments
+     ---------
+     - scaling_factors (list): the list of re-scaling factors computed by each `SumPool2d` layer targeting a
+         single `Conv2d` layer within a `DynapcnnLayer` instance.
+     - lambda_ (float): a scaling variable that multiplies the computed average re-scaling factor of the pooling layers.
+
+     Returns
+     ---------
+     - the averaged re-scaling factor multiplied by `lambda_` if `len(scaling_factors) > 0`, else `1` is returned.
+     """
+
+     if len(scaling_factors) > 0:
+         return np.round(np.mean(list(scaling_factors)) * lambda_, 2)
+     else:
+         return 1.0
+
+
+ def rescale_method_2(scaling_factors: Iterable[int], lambda_: float = 0.5) -> float:
+     """
+     This method will use the harmonic mean (scaled by `lambda_`) of the computed re-scaling factor
+     for the pooling layer(s) feeding into a convolutional layer.
+
+     Arguments
+     ---------
+     - scaling_factors (list): the list of re-scaling factors computed by each `SumPool2d` layer targeting a
+         single `Conv2d` layer within a `DynapcnnLayer` instance.
+     - lambda_ (float): a scaling variable that multiplies the computed average re-scaling factor of the pooling layers.
+
+     Returns
+     ---------
+     - the averaged re-scaling factor multiplied by `lambda_` if `len(scaling_factors) > 0`, else `1` is returned.
+
+     Note
+     ---------
+     - since the harmonic mean is less sensitive to outliers it **could be** that this is a better method
+         for weight re-scaling when multiple poolings with big differentces in kernel sizes are being considered.
+     """
+
+     if len(scaling_factors) > 0:
+         return np.round(statistics.harmonic_mean(list(scaling_factors)) * lambda_, 2)
+     else:
+         return 1.0
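As a quick, hedged illustration of the two helpers added in this new module (the factor values below are made up; the import path follows the file listing above):

```python
# Hypothetical usage of the new weight-rescaling helpers.
from sinabs.backend.dynapcnn.weight_rescaling_methods import (
    rescale_method_1,
    rescale_method_2,
)

# One re-scaling factor per SumPool2d layer feeding the same Conv2d layer.
factors = [4, 9, 16]

print(rescale_method_1(factors))  # 0.5 * arithmetic mean -> 4.83
print(rescale_method_2(factors))  # 0.5 * harmonic mean   -> 3.54
print(rescale_method_1([]))       # no pooling layers     -> 1.0
```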
sinabs/conversion.py CHANGED
@@ -9,7 +9,7 @@ def replace_module(model: nn.Module, source_class: type, mapper_fn: Callable):
  """A utility function that returns a copy of the model, where specific layers are replaced with
  another type depending on the mapper function.

- Parameters:
+ Args:
  model: A PyTorch model.
  source_class: the layer class to replace. Each find will be passed to mapper_fn
  mapper_fn: A callable that takes as argument the layer to replace and returns the new object.
@@ -31,7 +31,7 @@ def replace_module_(model: nn.Sequential, source_class: type, mapper_fn: Callabl
  """In-place version of replace_module that will step through modules that have children and
  apply the mapper_fn.

- Parameters:
+ Args:
  model: A PyTorch model.
  source_class: the layer class to replace. Each find will be passed to mapper_fn
  mapper_fn: A callable that takes as argument the layer to replace and returns the new object.
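For context, a minimal sketch of how `replace_module` is typically called; the model and mapper below are illustrative, not taken from the diff:

```python
# Swap every nn.ReLU in a copy of the model for a spiking IAF layer.
import torch.nn as nn
import sinabs.layers as sl
from sinabs.conversion import replace_module

ann = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
snn = replace_module(ann, source_class=nn.ReLU, mapper_fn=lambda old_layer: sl.IAF())
```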
sinabs/from_torch.py CHANGED
@@ -31,7 +31,7 @@ def from_model(
  analyzed, and a copy is returned, with all ReLUs and NeuromorphicReLUs turned into
  SpikingLayers.

- Parameters:
+ Args:
  model: Torch model
  input_shape: If provided, the layer dimensions are computed. Otherwise they will be computed at the first forward pass.
  spike_threshold: The membrane potential threshold for spiking (same for all layers).
@@ -101,6 +101,28 @@ def from_model(
  **kwargs_backend,
  ).to(device),
  )
+
+ elif isinstance(model, nn.Module):
+ layers = [layer for _, layer in model.named_children()]
+
+ if not isinstance(layers[-1], (nn.ReLU, sl.NeuromorphicReLU)):
+ snn.add_module(
+ "spike_output",
+ spike_layer_class(
+ spike_threshold=spike_threshold,
+ spike_fn=spike_fn,
+ reset_fn=reset_fn,
+ surrogate_grad_fn=surrogate_grad_fn,
+ min_v_mem=min_v_mem,
+ **kwargs_backend,
+ ).to(device),
+ )
+
+ else:
+ warn(
+ "Spiking output can only be added to sequential models that do not end in a ReLU. No layer has been added."
+ )
+
  else:
  warn(
  "Spiking output can only be added to sequential models that do not end in a ReLU. No layer has been added."
sinabs/hooks.py CHANGED
@@ -12,7 +12,7 @@ from sinabs.layers import SqueezeMixin, StatefulLayer
  def _extract_single_input(input_data: List[Any]) -> Any:
  """Extract single element of a list.

- Parameters:
+ Args:
  input_data: List that should have only one element

  Returns:
@@ -36,14 +36,13 @@ def conv_connection_map(
  """Generate connectivity map for a convolutional layer The map indicates for each element in
  the layer input to how many postsynaptic neurons it connects (i.e. the fanout)

- Parameters:
- layer: Convolutional layer for which connectivity map is to be
- generated
+ Args:
+ layer: Convolutional layer for which connectivity map is to be generated
  input_shape: Shape of the input data (N, C, Y, X)
  output_shape: Shape of layer output given `input_shape`
  device: Device on which the connectivity map should reside.
- Should be the same as that of the input to `layer`.
- If None, will select device of the weight of `layer`.
+ Should be the same as that of the input to `layer`.
+ If None, will select device of the weight of `layer`.

  Returns:
  torch.Tensor: Connectivity map indicating the fanout for each
@@ -78,10 +77,9 @@ def get_hook_data_dict(module: nn.Module) -> Dict:
  """Convenience function to get `hook_data` attribute of a module if it has one and create it
  otherwise.

- Parameters:
+ Args:
  module: The module whose `hook_data` dict is to be fetched.
- If it does not have an attribute of that name, it
- will add an empty dict.
+ If it does not have an attribute of that name, it will add an empty dict.
  Returns:
  The `hook_data` attribute of `module`. Should be a Dict.
  """
@@ -107,15 +105,15 @@ def input_diff_hook(
  at each forward pass. Afterwards the data can be accessed with
  `module.hook_data['diff_output']`

- Parameters:
+ If `module` does not already have a `hook_data` attribute, it
+ will be added and the difference value described above will be
+ stored under the key 'diff_output'. It is a tensor of the same
+ shape as `output`.
+
+ Args:
  module: Either a torch.nn.Conv2d or Linear layer
  input_: List of inputs to the layer. Should hold a single tensor.
  output: The layer's output.
- Effect:
- If `module` does not already have a `hook_data` attribute, it
- will be added and the difference value described above will be
- stored under the key 'diff_output'. It is a tensor of the same
- shape as `output`.
  """
  data = get_hook_data_dict(module)
  input_ = _extract_single_input(input_)
@@ -147,14 +145,14 @@ def firing_rate_hook(module: StatefulLayer, input_: Any, output: torch.Tensor):
  at each forward pass. Afterwards the data can be accessed with
  `module.hook_data['firing_rate']`

- Parameters:
+ If `module` does not already have a `hook_data` attribute, it
+ will be added and the mean firing rate will be stored under the
+ key 'firing_rate'. It is a scalar value.
+
+ Args:
  module: A spiking sinabs layer, such as `IAF` or `LIF`.
  input_: List of inputs to the layer. Ignored here.
  output: The layer's output.
- Effect:
- If `module` does not already have a `hook_data` attribute, it
- will be added and the mean firing rate will be stored under the
- key 'firing_rate'. It is a scalar value.
  """
  data = get_hook_data_dict(module)
  data["firing_rate"] = output.mean()
@@ -171,16 +169,15 @@ def firing_rate_per_neuron_hook(
  `torch.register_forward_hook`. It will be called automatically
  at each forward pass. Afterwards the data can be accessed with
  `module.hook_data['firing_rate_per_neuron']`
+ If `module` does not already have a `hook_data` attribute, it
+ will be added and the mean firing rate will be stored under the
+ key 'firing_rate_per_neuron'. It is a tensor of the same
+ shape as neurons of the spiking layer.

- Parameters:
+ Args:
  module: A spiking sinabs layer, such as `IAF` or `LIF`.
  input_: List of inputs to the layer. Ignored here.
  output: The layer's output.
- Effect:
- If `module` does not already have a `hook_data` attribute, it
- will be added and the mean firing rate will be stored under the
- key 'firing_rate_per_neuron'. It is a tensor of the same
- shape as neurons of the spiking layer.
  """
  data = get_hook_data_dict(module)
  if isinstance(module, SqueezeMixin):
@@ -208,16 +205,16 @@ def conv_layer_synops_hook(
  at each forward pass. Afterwards the data can be accessed with
  `module.hook_data['layer_synops_per_timestep']`

- Parameters:
+ If `module` does not already have a `hook_data` attribute, it
+ will be added and the mean firing rate will be stored under the
+ key 'layer_synops_per_timestep'. It is a scalar value.
+ It will also store a connectivity map under the key 'connection_map',
+ which holds the fanout for each input neuron.
+
+ Args:
  module: A torch.nn.Conv2d layer
  input_: List of inputs to the layer. Must contain exactly one tensor
  output: The layer's output.
- Effect:
- If `module` does not already have a `hook_data` attribute, it
- will be added and the mean firing rate will be stored under the
- key 'layer_synops_per_timestep'. It is a scalar value.
- It will also store a connectivity map under the key 'connection_map',
- which holds the fanout for each input neuron.
  """
  data = get_hook_data_dict(module)
  input_ = _extract_single_input(input_)
@@ -252,14 +249,14 @@ def linear_layer_synops_hook(
  at each forward pass. Afterwards the data can be accessed with
  `module.hook_data['layer_synops_per_timestep']`

- Parameters:
+ If `module` does not already have a `hook_data` attribute, it
+ will be added and the mean firing rate will be stored under the
+ key 'layer_synops_per_timestep'.
+
+ Args:
  module: A torch.nn.Linear layer.
  input_: List of inputs to the layer. Must contain exactly one tensor
  output: The layer's output.
- Effect:
- If `module` does not already have a `hook_data` attribute, it
- will be added and the mean firing rate will be stored under the
- key 'layer_synops_per_timestep'.
  """
  data = get_hook_data_dict(module)
  input_ = _extract_single_input(input_)
@@ -308,7 +305,7 @@ class ModelSynopsHook:
  the synaptic operations per second, under the assumption that `dt`
  is the time step in seconds.

- Parameters:
+ Attributes:
  dt: If not None, should be a float that indicates the simulation
  time step in seconds. The synaptic operations will be also
  provided in terms of synops per second.
@@ -320,7 +317,7 @@ class ModelSynopsHook:
  """Forward call of the synops model hook. Should not be called manually but only by PyTorch
  during a forward pass.

- Parameters:
+ Args:
  module: A torch.nn.Sequential
  input_: List of inputs to the module.
  output: The module output.
@@ -380,7 +377,7 @@ def register_synops_hooks(module: nn.Sequential, dt: Optional[float] = None):
  This can be used instead of calling the torch function
  `register_forward_hook` on all layers.

- Parameters:
+ Args:
  module: Sequential model for which the hooks should be registered.
  dt: If not None, should be a float indicating the simulation
  time step in seconds. Will also calculate synaptic operations per second.
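Putting the relocated docstrings into practice, a small sketch of the hook API they describe (the layer choice and input shape are illustrative):

```python
# Register two of the hooks documented above on a spiking layer and read the
# results back from its hook_data dict after a forward pass.
import torch
import sinabs.layers as sl
from sinabs.hooks import firing_rate_hook, firing_rate_per_neuron_hook

layer = sl.IAF()
layer.register_forward_hook(firing_rate_hook)
layer.register_forward_hook(firing_rate_per_neuron_hook)

spikes = layer(torch.rand(4, 10, 16))                    # (batch, time, channel)
print(layer.hook_data["firing_rate"])                    # scalar mean firing rate
print(layer.hook_data["firing_rate_per_neuron"].shape)   # one entry per neuron
```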
sinabs/layers/alif.py CHANGED
@@ -38,19 +38,19 @@ class ALIF(StatefulLayer):
  i(t+1) = \\alpha_{syn} i(t) (1-\\alpha_{syn}) + input


- Parameters:
+ Args:
  tau_mem: Membrane potential time constant.
  tau_adapt: Spike threshold time constant.
  tau_syn: Synaptic decay time constants. If None, no synaptic dynamics are used, which is the default.
  adapt_scale: The amount that the spike threshold is bumped up for every spike, after which it decays back to the initial threshold.
  spike_threshold: Spikes are emitted if v_mem is above that threshold. By default set to 1.0.
  spike_fn: Choose a Sinabs or custom torch.autograd.Function that takes a dict of states,
- a spike threshold and a surrogate gradient function and returns spikes. Be aware
- that the class itself is passed here (because torch.autograd methods are static)
- rather than an object instance.
+ a spike threshold and a surrogate gradient function and returns spikes. Be aware
+ that the class itself is passed here (because torch.autograd methods are static)
+ rather than an object instance.
  reset_fn: A function that defines how the membrane potential is reset after a spike.
  surrogate_grad_fn: Choose how to define gradients for the spiking non-linearity during the
- backward pass. This is a function of membrane potential.
+ backward pass. This is a function of membrane potential.
  min_v_mem: Lower bound for membrane potential v_mem, clipped at every time step.
  shape: Optionally initialise the layer state with given shape. If None, will be inferred from input_size.
  train_alphas: When True, the discrete decay factor exp(-1/tau) is used for training rather than tau itself.
@@ -58,8 +58,8 @@ class ALIF(StatefulLayer):
  record_states: When True, will record all internal states such as v_mem or i_syn in a dictionary attribute `recordings`. Default is False.

  Shape:
- - Input: :math:`(Batch, Time, Channel, Height, Width)` or :math:`(Batch, Time, Channel)`
- - Output: Same as input.
+ Input: :math:`(Batch, Time, Channel, Height, Width)` or :math:`(Batch, Time, Channel)`
+ Output: Same as input.

  Attributes:
  v_mem: The membrane potential resets according to reset_fn for every spike.
@@ -137,7 +137,7 @@ class ALIF(StatefulLayer):

  def forward(self, input_data: torch.Tensor):
  """
- Parameters:
+ Args:
  input_data: Data to be processed. Expected shape: (batch, time, ...)

  Returns:
@@ -244,7 +244,7 @@ class ALIFRecurrent(ALIF):
  .. math ::
  i(t+1) = \\alpha_{syn} i(t) (1-\\alpha_{syn}) + input

- Parameters:
+ Args:
  tau_mem: Membrane potential time constant.
  tau_adapt: Spike threshold time constant.
  rec_connect: An nn.Module which defines the recurrent connectivity, e.g. nn.Linear
@@ -252,12 +252,12 @@ class ALIFRecurrent(ALIF):
  adapt_scale: The amount that the spike threshold is bumped up for every spike, after which it decays back to the initial threshold.
  spike_threshold: Spikes are emitted if v_mem is above that threshold. By default set to 1.0.
  spike_fn: Choose a Sinabs or custom torch.autograd.Function that takes a dict of states,
- a spike threshold and a surrogate gradient function and returns spikes. Be aware
- that the class itself is passed here (because torch.autograd methods are static)
- rather than an object instance.
+ a spike threshold and a surrogate gradient function and returns spikes. Be aware
+ that the class itself is passed here (because torch.autograd methods are static)
+ rather than an object instance.
  reset_fn: A function that defines how the membrane potential is reset after a spike.
  surrogate_grad_fn: Choose how to define gradients for the spiking non-linearity during the
- backward pass. This is a function of membrane potential.
+ backward pass. This is a function of membrane potential.
  min_v_mem: Lower bound for membrane potential v_mem, clipped at every time step.
  shape: Optionally initialise the layer state with given shape. If None, will be inferred from input_size.
  train_alphas: When True, the discrete decay factor exp(-1/tau) is used for training rather than tau itself.
@@ -265,8 +265,8 @@ class ALIFRecurrent(ALIF):
  record_states: When True, will record all internal states such as v_mem or i_syn in a dictionary attribute `recordings`. Default is False.

  Shape:
- - Input: :math:`(Batch, Time, Channel, Height, Width)` or :math:`(Batch, Time, Channel)`
- - Output: Same as input.
+ Input: :math:`(Batch, Time, Channel, Height, Width)` or :math:`(Batch, Time, Channel)`
+ Output: Same as input.

  Attributes:
  v_mem: The membrane potential resets according to reset_fn for every spike.
@@ -311,7 +311,7 @@ class ALIFRecurrent(ALIF):

  def forward(self, input_data: torch.Tensor):
  """
- Parameters:
+ Args:
  input_data: Data to be processed. Expected shape: (batch, time, ...)

  Returns:
sinabs/layers/crop2d.py CHANGED
@@ -9,7 +9,7 @@ ArrayLike = Union[np.ndarray, List, Tuple]
  class Cropping2dLayer(nn.Module):
  """Crop input image by.

- Parameters:
+ Args:
  cropping: ((top, bottom), (left, right))
  """

@@ -38,7 +38,7 @@ class Cropping2dLayer(nn.Module):
  def get_output_shape(self, input_shape: Tuple) -> Tuple:
  """Retuns the output dimensions.

- Parameters:
+ Args:
  input_shape: (channels, height, width)

  Returns:
sinabs/layers/exp_leak.py CHANGED
@@ -17,7 +17,7 @@ class ExpLeak(LIF):

  where :math:`\\alpha = e^{-1/tau_{mem}}` and :math:`\\sum z(t)` represents the sum of all input currents at time :math:`t`.

- Parameters:
+ Args:
  tau_mem: Membrane potential time constant.
  min_v_mem: Lower bound for membrane potential v_mem, clipped at every time step.
  train_alphas: When True, the discrete decay factor exp(-1/tau) is used for training rather than tau itself.
sinabs/layers/iaf.py CHANGED
@@ -22,24 +22,24 @@ class IAF(LIF):

  where :math:`\\sum z(t)` represents the sum of all input currents at time :math:`t`.

- Parameters:
+ Args:
  spike_threshold: Spikes are emitted if v_mem is above that threshold. By default set to 1.0.
  spike_fn: Choose a Sinabs or custom torch.autograd.Function that takes a dict of states,
- a spike threshold and a surrogate gradient function and returns spikes. Be aware
- that the class itself is passed here (because torch.autograd methods are static)
- rather than an object instance.
+ a spike threshold and a surrogate gradient function and returns spikes. Be aware
+ that the class itself is passed here (because torch.autograd methods are static)
+ rather than an object instance.
  reset_fn: A function that defines how the membrane potential is reset after a spike.
  surrogate_grad_fn: Choose how to define gradients for the spiking non-linearity during the
- backward pass. This is a function of membrane potential.
+ backward pass. This is a function of membrane potential.
  tau_syn: Synaptic decay time constants. If None, no synaptic dynamics are used, which is the default.
  min_v_mem: Lower bound for membrane potential v_mem, clipped at every time step.
  shape: Optionally initialise the layer state with given shape. If None, will be inferred from input_size.
  record_states: When True, will record all internal states such as v_mem or i_syn in a dictionary attribute
- `recordings`. Default is False.
+ `recordings`. Default is False.

  Shape:
- - Input: :math:`(Batch, Time, Channel, Height, Width)` or :math:`(Batch, Time, Channel)`
- - Output: Same as input.
+ Input: :math:`(Batch, Time, Channel, Height, Width)` or :math:`(Batch, Time, Channel)`
+ Output: Same as input.

  Attributes:
  v_mem: The membrane potential resets according to reset_fn for every spike.
@@ -99,7 +99,7 @@ class IAFRecurrent(LIFRecurrent):

  where :math:`\\sum z(t)` represents the sum of all input currents at time :math:`t`.

- Parameters:
+ Args:
  rec_connect: An nn.Module which defines the recurrent connectivity, e.g. nn.Linear
  spike_threshold: Spikes are emitted if v_mem is above that threshold. By default set to 1.0.
  spike_fn: Choose a Sinabs or custom torch.autograd.Function that takes a dict of states,
@@ -115,8 +115,8 @@ class IAFRecurrent(LIFRecurrent):
  record_states: When True, will record all internal states such as v_mem or i_syn in a dictionary attribute `recordings`. Default is False.

  Shape:
- - Input: :math:`(Batch, Time, Channel, Height, Width)` or :math:`(Batch, Time, Channel)`
- - Output: Same as input.
+ Input: :math:`(Batch, Time, Channel, Height, Width)` or :math:`(Batch, Time, Channel)`
+ Output: Same as input.

  Attributes:
  v_mem: The membrane potential resets according to reset_fn for every spike.
sinabs/layers/lif.py CHANGED
@@ -31,7 +31,7 @@ class LIF(StatefulLayer):
  .. math ::
  \\text{if } V_{mem}(t) >= V_{th} \\text{, then } V_{mem} \\rightarrow V_{reset}

- Parameters:
+ Args:
  tau_mem: Membrane potential time constant.
  tau_syn: Synaptic decay time constants. If None, no synaptic dynamics are used, which is the default.
  spike_threshold: Spikes are emitted if v_mem is above that threshold. By default set to 1.0.
@@ -50,8 +50,8 @@ class LIF(StatefulLayer):
  attribute `recordings`. Default is False.

  Shape:
- - Input: :math:`(Batch, Time, Channel, Height, Width)` or :math:`(Batch, Time, Channel)`
- - Output: Same as input.
+ Input: :math:`(Batch, Time, Channel, Height, Width)` or :math:`(Batch, Time, Channel)`
+ Output: Same as input.

  Attributes:
  v_mem: The membrane potential resets according to reset_fn for every spike.
@@ -155,7 +155,7 @@ class LIF(StatefulLayer):

  def forward(self, input_data: torch.Tensor) -> torch.Tensor:
  """
- Parameters:
+ Args:
  input_data: Data to be processed. Expected shape: (batch, time, ...)

  Returns:
@@ -254,7 +254,7 @@ class LIFRecurrent(LIF):
  .. math ::
  \\text{if } V_{mem}(t) >= V_{th} \\text{, then } V_{mem} \\rightarrow V_{reset}

- Parameters:
+ Args:
  tau_mem: Membrane potential time constant.
  rec_connect: An nn.Module which defines the recurrent connectivity, e.g. nn.Linear
  tau_syn: Synaptic decay time constants. If None, no synaptic dynamics are used, which is the default.
@@ -269,8 +269,8 @@ class LIFRecurrent(LIF):
  record_states: When True, will record all internal states such as v_mem or i_syn in a dictionary attribute `recordings`. Default is False.

  Shape:
- - Input: :math:`(Batch, Time, Channel, Height, Width)` or :math:`(Batch, Time, Channel)`
- - Output: Same as input.
+ Input: :math:`(Batch, Time, Channel, Height, Width)` or :math:`(Batch, Time, Channel)`
+ Output: Same as input.

  Attributes:
  v_mem: The membrane potential resets according to reset_fn for every spike.
@@ -309,8 +309,8 @@ class LIFRecurrent(LIF):

  def forward(self, input_data: torch.Tensor):
  """
- Parameters:
- input_data: Data to be processed. Expected shape: (batch, time, ...)
+ Args:
+ input_data: Data to be processed. Expected shape: (batch, time, ...).

  Returns:
  Output data with same shape as `input_data`.
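For reference, a minimal sketch of the LIF usage these docstrings describe (time constant and shapes are illustrative):

```python
# Run a LIF layer on dummy (batch, time, channel) data and inspect its state.
import torch
import sinabs.layers as sl

lif = sl.LIF(tau_mem=20.0, record_states=True)
out = lif(torch.rand(8, 100, 32))   # output has the same shape as the input
print(out.shape)                    # torch.Size([8, 100, 32])
print(lif.v_mem.shape)              # membrane potential, one value per neuron
print(list(lif.recordings))         # recorded internal states, e.g. 'v_mem'
```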
sinabs/layers/neuromorphic_relu.py CHANGED
@@ -7,14 +7,15 @@ class NeuromorphicReLU(torch.nn.Module):
  """NeuromorphicReLU layer. This layer is NOT used for Sinabs networks; it's useful while
  training analogue pyTorch networks for future use with Sinabs.

- Parameters:
- quantize: Whether or not to quantize the output (i.e. floor it to \
- the integer below), in order to mimic spiking behavior.
- fanout: Useful when computing the number of SynOps of a quantized \
- NeuromorphicReLU. The activity can be accessed through \
- NeuromorphicReLU.activity, and is multiplied by the value of fanout.
- stochastic_rounding: Upon quantization, should the value be rounded stochastically or floored
- Only done during training. During evaluation mode, the value is simply floored
+ Args:
+ quantize: Whether or not to quantize the output (i.e. floor it to the
+ integer below), in order to mimic spiking behavior.
+ fanout: Useful when computing the number of SynOps of a quantized
+ NeuromorphicReLU. The activity can be accessed through
+ NeuromorphicReLU.activity, and is multiplied by the value of fanout.
+ stochastic_rounding: Upon quantization, should the value be rounded
+ stochastically or floored. Only done during training. During
+ evaluation mode, the value is simply floored
  """

  def __init__(self, quantize=True, fanout=1, stochastic_rounding=False):
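A short sketch of the `activity` bookkeeping documented above (values are illustrative, and the exact aggregation behind `activity` is not shown in this diff):

```python
# NeuromorphicReLU quantizes activations and exposes an `activity` value,
# scaled by `fanout`, for SynOps estimates in analogue networks.
import torch
import sinabs.layers as sl

relu = sl.NeuromorphicReLU(quantize=True, fanout=9)  # e.g. a 3x3 conv follows
out = relu(torch.rand(1, 16) * 5.0)
print(out)             # floored (quantized) activations
print(relu.activity)   # activity measure multiplied by fanout
```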
sinabs/layers/pool2d.py CHANGED
@@ -74,7 +74,7 @@ class SpikingMaxPooling2dLayer(nn.Module):
  def get_output_shape(self, input_shape: Tuple) -> Tuple:
  """Returns the shape of output, given an input to this layer.

- Parameters:
+ Args:
  input_shape: (channels, height, width)

  Returns:
@@ -95,10 +95,10 @@ class SumPool2d(torch.nn.LPPool2d):
  """Non-spiking sumpooling layer to be used in analogue Torch models. It is identical to
  torch.nn.LPPool2d with p=1.

- Parameters:
- kernel_size: the size of the window
- stride: the stride of the window. Default value is kernel_size
- ceil_mode: when True, will use ceil instead of floor to compute the output shape
+ Args:
+ kernel_size: the size of the window.
+ stride: the stride of the window. Default value is kernel_size.
+ ceil_mode: when True, will use ceil instead of floor to compute the output shape.
  """

  def __init__(self, kernel_size, stride=None, ceil_mode=False):
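A quick sketch of the `SumPool2d` behaviour described above (the input is illustrative):

```python
# SumPool2d behaves like torch.nn.LPPool2d with p=1: each window is summed.
import torch
import sinabs.layers as sl

pool = sl.SumPool2d(kernel_size=2)   # stride defaults to kernel_size
x = torch.ones(1, 1, 4, 4)
print(pool(x))                       # every 2x2 window sums to 4.0
```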
sinabs/layers/quantize.py CHANGED
@@ -6,7 +6,7 @@ from sinabs.activation import Quantize
  class QuantizeLayer(nn.Module):
  """Layer that quantizes the input, i.e. returns floor(input).

- Parameters:
+ Args:
  quantize: If False, this layer will pass on the input without modifying it.
  """

sinabs/layers/stateful_layer.py CHANGED
@@ -8,7 +8,7 @@ class StatefulLayer(torch.nn.Module):
  """A base class that instantiates buffers/states which update at every time step and provides
  helper methods that manage those states.

- Parameters:
+ Args:
  state_names: the PyTorch buffers to initialise. These are not parameters.
  """

@@ -97,12 +97,15 @@ class StatefulLayer(torch.nn.Module):
  ):
  """Reset the state/buffers in a layer.

- Parameters:
- randomize: If true, reset the states between a range provided. Else, the states are reset to zero.
- value_ranges: A dictionary of key value pairs: buffer_name -> (min, max) for each state that needs to be reset.
- The states are reset with a uniform distribution between the min and max values specified.
- Any state with an undefined key in this dictionary will be reset between 0 and 1
- This parameter is only used if randomize is set to true.
+ Args:
+ randomize: If true, reset the states between a range provided.
+ Else, the states are reset to zero.
+ value_ranges: A dictionary of key value pairs: buffer_name -> (min,
+ max) for each state that needs to be reset. The states are
+ reset with a uniform distribution between the min and max
+ values specified. Any state with an undefined key in this
+ dictionary will be reset between 0 and 1. This parameter is
+ only used if randomize is set to true.

  .. note:: If you would like to reset the state with a custom distribution, you can do this individually for each parameter as follows::

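A brief sketch of the `reset_states` call the reworked docstring describes (the layer and value ranges are illustrative):

```python
# Reset a spiking layer's state buffers to zero, or randomize them within
# per-buffer ranges as described above.
import torch
import sinabs.layers as sl

lif = sl.LIF(tau_mem=10.0)
lif(torch.rand(2, 5, 8))       # initialises the state buffers
lif.reset_states()             # back to zero
lif.reset_states(randomize=True, value_ranges={"v_mem": (-1.0, 1.0)})
```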
sinabs/layers/to_spike.py CHANGED
@@ -7,15 +7,15 @@ from torch import nn
  class Img2SpikeLayer(nn.Module):
  """Layer to convert images to spikes.

- Parameters:
- image_shape: tuple image shape
- tw: int Time window length
- max_rate: maximum firing rate of neurons
- layer_name: string layer name
- norm: the supposed maximum value of the input (default 255.0)
- squeeze: whether to remove singleton dimensions from the input
+ Args:
+ image_shape: tuple image shape.
+ tw: int Time window length.
+ max_rate: maximum firing rate of neurons.
+ layer_name: string layer name.
+ norm: the supposed maximum value of the input (default 255.0).
+ squeeze: whether to remove singleton dimensions from the input.
  negative_spikes: whether to allow negative spikes in response
- to negative input
+ to negative input.
  """

  def __init__(
@@ -60,7 +60,7 @@ class Img2SpikeLayer(nn.Module):
  class Sig2SpikeLayer(torch.nn.Module):
  """Layer to convert analog Signals to spikes.

- Parameters:
+ Args:
  channels_in: number of channels in the analog signal
  tw: int number of time steps for each sample of the signal (up sampling)
  layer_name: string layer name