brainstate 0.1.7__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl

This diff shows the differences between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (133)
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -146
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +509 -470
  58. brainstate/nn/_delay_test.py +238 -0
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1361
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1120
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -208
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
  127. brainstate-0.1.9.dist-info/RECORD +130 -0
  128. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.7.dist-info/RECORD +0 -131
  133. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
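The renamed entries above (e.g. brainstate/{functional → nn}/_activations.py and _others.py, together with the removal of brainstate/functional/_normalization.py and _spikes.py) indicate that most of the former brainstate.functional contents were folded into brainstate.nn in 0.1.9. A minimal, hypothetical migration sketch follows; the symbol name relu and the fallback import are illustrative assumptions, since this listing only shows file moves, not which names are re-exported.

# Hypothetical sketch: adapting an import to the functional -> nn move
# (symbol names are assumptions, not confirmed by this diff)
try:
    from brainstate.nn import relu          # 0.1.9 layout suggested by the renamed files above
except ImportError:
    from brainstate.functional import relu  # 0.1.7 layout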
brainstate/nn/_dropout.py CHANGED
@@ -1,426 +1,426 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from functools import partial
from typing import Optional, Sequence

import brainunit as u
import jax.numpy as jnp

from brainstate import random, environ, init
from brainstate._state import ShortTermState
from brainstate.typing import Size
from ._module import ElementWiseBlock

__all__ = [
    'DropoutFixed', 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
]


class Dropout(ElementWiseBlock):
    """A layer that stochastically ignores a subset of inputs each training step.

    In training, to compensate for the fraction of input values dropped (`rate`),
    all surviving values are multiplied by `1 / (1 - rate)`.

    This layer is active only during training (``mode=brainstate.mixin.Training``). In other
    circumstances it is a no-op.

    .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
           neural networks from overfitting." The journal of machine learning
           research 15.1 (2014): 1929-1958.

    Args:
        prob: Probability to keep element of the tensor.
        broadcast_dims: dimensions that will share the same dropout mask.
        name: str. The name of the dynamic system.

    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        prob: float = 0.5,
        broadcast_dims: Sequence[int] = (),
        name: Optional[str] = None
    ) -> None:
        super().__init__(name=name)
        assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
        self.prob = prob
        self.broadcast_dims = broadcast_dims

    def __call__(self, x):
        dtype = u.math.get_dtype(x)
        fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
        if fit_phase and self.prob < 1.:
            broadcast_shape = list(x.shape)
            for dim in self.broadcast_dims:
                broadcast_shape[dim] = 1
            keep_mask = random.bernoulli(self.prob, broadcast_shape)
            keep_mask = u.math.broadcast_to(keep_mask, x.shape)
            return u.math.where(
                keep_mask,
                u.math.asarray(x / self.prob, dtype=dtype),
                u.math.asarray(0., dtype=dtype)
            )
        else:
            return x


class _DropoutNd(ElementWiseBlock):
    __module__ = 'brainstate.nn'
    prob: float
    channel_axis: int
    minimal_dim: int

    def __init__(
        self,
        prob: float = 0.5,
        channel_axis: int = -1,
        name: Optional[str] = None
    ) -> None:
        super().__init__(name=name)
        assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
        self.prob = prob
        self.channel_axis = channel_axis

    def __call__(self, x):
        # check input shape
        inp_dim = u.math.ndim(x)
        if inp_dim not in (self.minimal_dim, self.minimal_dim + 1):
            raise RuntimeError(f"dropout1d: Expected {self.minimal_dim}D or {self.minimal_dim + 1}D input, "
                               f"but received a {inp_dim}D input. {self._get_msg(x)}")
        is_not_batched = self.minimal_dim
        if is_not_batched:
            channel_axis = self.channel_axis if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
            mask_shape = [(dim if i == channel_axis else 1) for i, dim in enumerate(x.shape)]
        else:
            channel_axis = (self.channel_axis + 1) if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
            assert channel_axis != 0, f"Channel axis must not be 0. But got {self.channel_axis}."
            mask_shape = [(dim if i in (channel_axis, 0) else 1) for i, dim in enumerate(x.shape)]

        # get fit phase
        fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')

        # generate mask
        if fit_phase and self.prob < 1.:
            dtype = u.math.get_dtype(x)
            keep_mask = random.bernoulli(self.prob, mask_shape)
            keep_mask = jnp.broadcast_to(keep_mask, x.shape)
            return jnp.where(
                keep_mask,
                jnp.asarray(x / self.prob, dtype=dtype),
                jnp.asarray(0., dtype=dtype)
            )
        else:
            return x

    def _get_msg(self, x):
        return ''


class Dropout1d(_DropoutNd):
    r"""Randomly zero out entire channels (a channel is a 1D feature map,
    e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
    batched input is a 1D tensor :math:`\text{input}[i, j]`).
    Each channel will be zeroed out independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.

    Usually the input comes from :class:`nn.Conv1d` modules.

    As described in the paper
    `Efficient Object Localization Using Convolutional Networks`_ ,
    if adjacent pixels within feature maps are strongly correlated
    (as is normally the case in early convolution layers) then i.i.d. dropout
    will not regularize the activations and will otherwise just result
    in an effective learning rate decrease.

    In this case, :func:`nn.Dropout1d` will help promote independence between
    feature maps and should be used instead.

    Args:
        prob: float. probability of an element to be zero-ed.

    Shape:
        - Input: :math:`(N, C, L)` or :math:`(C, L)`.
        - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).

    Examples::

        >>> m = Dropout1d(p=0.2)
        >>> x = random.randn(20, 32, 16)
        >>> output = m(x)
        >>> output.shape
        (20, 32, 16)

    .. _Efficient Object Localization Using Convolutional Networks:
       https://arxiv.org/abs/1411.4280
    """
    __module__ = 'brainstate.nn'
    minimal_dim: int = 2

    def _get_msg(self, x):
        return ("Note that dropout1d exists to provide channel-wise dropout on inputs with 1 "
                "spatial dimension, a channel dimension, and an optional batch dimension "
                "(i.e. 2D or 3D inputs).")


class Dropout2d(_DropoutNd):
    r"""Randomly zero out entire channels (a channel is a 2D feature map,
    e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
    batched input is a 2D tensor :math:`\text{input}[i, j]`).
    Each channel will be zeroed out independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.

    Usually the input comes from :class:`nn.Conv2d` modules.

    As described in the paper
    `Efficient Object Localization Using Convolutional Networks`_ ,
    if adjacent pixels within feature maps are strongly correlated
    (as is normally the case in early convolution layers) then i.i.d. dropout
    will not regularize the activations and will otherwise just result
    in an effective learning rate decrease.

    In this case, :func:`nn.Dropout2d` will help promote independence between
    feature maps and should be used instead.

    Args:
        prob: float. probability of an element to be kept.

    Shape:
        - Input: :math:`(N, C, H, W)` or :math:`(N, C, L)`.
        - Output: :math:`(N, C, H, W)` or :math:`(N, C, L)` (same shape as input).

    Examples::

        >>> m = Dropout2d(p=0.2)
        >>> x = random.randn(20, 32, 32, 16)
        >>> output = m(x)

    .. _Efficient Object Localization Using Convolutional Networks:
       https://arxiv.org/abs/1411.4280
    """
    __module__ = 'brainstate.nn'
    minimal_dim: int = 3

    def _get_msg(self, x):
        return ("Note that dropout2d exists to provide channel-wise dropout on inputs with 2 "
                "spatial dimensions, a channel dimension, and an optional batch dimension "
                "(i.e. 3D or 4D inputs).")


class Dropout3d(_DropoutNd):
    r"""Randomly zero out entire channels (a channel is a 3D feature map,
    e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
    batched input is a 3D tensor :math:`\text{input}[i, j]`).
    Each channel will be zeroed out independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.

    Usually the input comes from :class:`nn.Conv3d` modules.

    As described in the paper
    `Efficient Object Localization Using Convolutional Networks`_ ,
    if adjacent pixels within feature maps are strongly correlated
    (as is normally the case in early convolution layers) then i.i.d. dropout
    will not regularize the activations and will otherwise just result
    in an effective learning rate decrease.

    In this case, :func:`nn.Dropout3d` will help promote independence between
    feature maps and should be used instead.

    Args:
        prob: float. probability of an element to be kept.

    Shape:
        - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
        - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).

    Examples::

        >>> m = Dropout3d(p=0.2)
        >>> x = random.randn(20, 16, 4, 32, 32)
        >>> output = m(x)

    .. _Efficient Object Localization Using Convolutional Networks:
       https://arxiv.org/abs/1411.4280
    """
    __module__ = 'brainstate.nn'
    minimal_dim: int = 4

    def _get_msg(self, x):
        return ("Note that dropout3d exists to provide channel-wise dropout on inputs with 3 "
                "spatial dimensions, a channel dimension, and an optional batch dimension "
                "(i.e. 4D or 5D inputs).")


class AlphaDropout(_DropoutNd):
    r"""Applies Alpha Dropout over the input.

    Alpha Dropout is a type of Dropout that maintains the self-normalizing
    property.
    For an input with zero mean and unit standard deviation, the output of
    Alpha Dropout maintains the original mean and standard deviation of the
    input.
    Alpha Dropout goes hand-in-hand with SELU activation function, which ensures
    that the outputs have zero mean and unit standard deviation.

    During training, it randomly masks some of the elements of the input
    tensor with probability *p* using samples from a bernoulli distribution.
    The elements to masked are randomized on every forward call, and scaled
    and shifted to maintain zero mean and unit standard deviation.

    During evaluation the module simply computes an identity function.

    More details can be found in the paper `Self-Normalizing Neural Networks`_ .

    Args:
        prob: float. probability of an element to be kept.

    Shape:
        - Input: :math:`(*)`. Input can be of any shape
        - Output: :math:`(*)`. Output is of the same shape as input

    Examples::

        >>> m = AlphaDropout(p=0.2)
        >>> x = random.randn(20, 16)
        >>> output = m(x)

    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
    """
    __module__ = 'brainstate.nn'

    def update(self, *args, **kwargs):
        raise NotImplementedError("AlphaDropout is not supported in the current version.")


class FeatureAlphaDropout(_DropoutNd):
    r"""Randomly masks out entire channels (a channel is a feature map,
    e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input
    is a tensor :math:`\text{input}[i, j]`) of the input tensor). Instead of
    setting activations to zero, as in regular Dropout, the activations are set
    to the negative saturation value of the SELU activation function. More details
    can be found in the paper `Self-Normalizing Neural Networks`_ .

    Each element will be masked independently for each sample on every forward
    call with probability :attr:`p` using samples from a Bernoulli distribution.
    The elements to be masked are randomized on every forward call, and scaled
    and shifted to maintain zero mean and unit variance.

    Usually the input comes from :class:`nn.AlphaDropout` modules.

    As described in the paper
    `Efficient Object Localization Using Convolutional Networks`_ ,
    if adjacent pixels within feature maps are strongly correlated
    (as is normally the case in early convolution layers) then i.i.d. dropout
    will not regularize the activations and will otherwise just result
    in an effective learning rate decrease.

    In this case, :func:`nn.AlphaDropout` will help promote independence between
    feature maps and should be used instead.

    Args:
        prob: float. probability of an element to be kept.

    Shape:
        - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
        - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).

    Examples::

        >>> m = FeatureAlphaDropout(p=0.2)
        >>> x = random.randn(20, 16, 4, 32, 32)
        >>> output = m(x)

    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
    .. _Efficient Object Localization Using Convolutional Networks:
       https://arxiv.org/abs/1411.4280
    """
    __module__ = 'brainstate.nn'

    def update(self, *args, **kwargs):
        raise NotImplementedError("FeatureAlphaDropout is not supported in the current version.")


class DropoutFixed(ElementWiseBlock):
    """
    A dropout layer with the fixed dropout mask along the time axis once after initialized.

    In training, to compensate for the fraction of input values dropped (`rate`),
    all surviving values are multiplied by `1 / (1 - rate)`.

    This layer is active only during training (``mode=brainstate.mixin.Training``). In other
    circumstances it is a no-op.

    .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
           neural networks from overfitting." The journal of machine learning
           research 15.1 (2014): 1929-1958.

    .. admonition:: Tip
        :class: tip

        This kind of Dropout is firstly described in `Enabling Spike-based Backpropagation for Training Deep Neural
        Network Architectures <https://arxiv.org/abs/1903.06379>`_:

        There is a subtle difference in the way dropout is applied in SNNs compared to ANNs. In ANNs, each epoch of
        training has several iterations of mini-batches. In each iteration, randomly selected units (with dropout ratio of :math:`p`)
        are disconnected from the network while weighting by its posterior probability (:math:`1-p`). However, in SNNs, each
        iteration has more than one forward propagation depending on the time length of the spike train. We back-propagate
        the output error and modify the network parameters only at the last time step. For dropout to be effective in
        our training method, it has to be ensured that the set of connected units within an iteration of mini-batch
        data is not changed, such that the neural network is constituted by the same random subset of units during
        each forward propagation within a single iteration. On the other hand, if the units are randomly connected at
        each time-step, the effect of dropout will be averaged out over the entire forward propagation time within an
        iteration. Then, the dropout effect would fade-out once the output error is propagated backward and the parameters
        are updated at the last time step. Therefore, we need to keep the set of randomly connected units for the entire
        time window within an iteration.

    Args:
        in_size: The size of the input tensor.
        prob: Probability to keep element of the tensor.
        mode: Mode. The computation mode of the object.
        name: str. The name of the dynamic system.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        prob: float = 0.5,
        name: Optional[str] = None
    ) -> None:
        super().__init__(name=name)
        assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
        self.prob = prob
        self.in_size = in_size
        self.out_size = in_size

    def init_state(self, batch_size=None, **kwargs):
        if self.prob < 1.:
            self.mask = ShortTermState(init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size))

    def update(self, x):
        dtype = u.math.get_dtype(x)
        fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
        if fit_phase and self.prob < 1.:
            if self.mask.value.shape != x.shape:
                raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
                                 f"Please call `init_state()` method first.")
            return u.math.where(self.mask.value,
                                u.math.asarray(x / self.prob, dtype=dtype),
                                u.math.asarray(0., dtype=dtype) * u.get_unit(x))
        else:
            return x
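For orientation, a minimal usage sketch of the two main layers in this file. It assumes brainstate.environ.context(fit=...) is the usual way to set the 'fit' flag that the layers read through environ.get('fit', ...); that part is not shown in this diff.

import brainstate

layer = brainstate.nn.Dropout(prob=0.8)        # keep probability 0.8, i.e. drop about 20% of elements
x = brainstate.random.randn(32, 128)

with brainstate.environ.context(fit=True):     # training phase: random mask, survivors scaled by 1/prob
    y_train = layer(x)

with brainstate.environ.context(fit=False):    # evaluation phase: the layer is a no-op
    y_eval = layer(x)

# DropoutFixed draws one mask in init_state() and reuses it at every call to update(),
# so the same units stay dropped across all time steps of a simulation window.
fixed = brainstate.nn.DropoutFixed(in_size=(128,), prob=0.8)
fixed.init_state(batch_size=32)
with brainstate.environ.context(fit=True):
    y_fixed = fixed.update(x)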