brainstate 0.2.1-py2.py3-none-any.whl → 0.2.2-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/nn/_dropout.py CHANGED
@@ -1,618 +1,618 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- from functools import partial
18
- from typing import Optional, Sequence
19
-
20
- import brainunit as u
21
- import jax.numpy as jnp
22
-
23
- from brainstate import random, environ
24
- from brainstate._state import ShortTermState
25
- from brainstate.typing import Size
26
- from . import init as init
27
- from ._module import ElementWiseBlock
28
-
29
- __all__ = [
30
- 'Dropout',
31
- 'Dropout1d',
32
- 'Dropout2d',
33
- 'Dropout3d',
34
- 'AlphaDropout',
35
- 'FeatureAlphaDropout',
36
- 'DropoutFixed',
37
- ]
38
-
39
-
40
- class Dropout(ElementWiseBlock):
41
- """A layer that stochastically ignores a subset of inputs each training step.
42
-
43
-     In training, to compensate for the fraction of input values dropped (`1 - prob`),
44
-     all surviving values are multiplied by `1 / prob`.
45
-
46
- This layer is active only during training (``mode=brainstate.mixin.Training``). In other
47
- circumstances it is a no-op.
48
-
49
- Parameters
50
- ----------
51
- prob : float
52
- Probability to keep element of the tensor. Default is 0.5.
53
- broadcast_dims : Sequence[int]
54
- Dimensions that will share the same dropout mask. Default is ().
55
- name : str, optional
56
- The name of the dynamic system.
57
-
58
- References
59
- ----------
60
- .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
61
- neural networks from overfitting." The journal of machine learning
62
- research 15.1 (2014): 1929-1958.
63
-
64
- Examples
65
- --------
66
- .. code-block:: python
67
-
68
- >>> import brainstate
69
- >>> layer = brainstate.nn.Dropout(prob=0.8)
70
- >>> x = brainstate.random.randn(10, 20)
71
- >>> with brainstate.environ.context(fit=True):
72
- ... output = layer(x)
73
- >>> output.shape
74
- (10, 20)
75
-
76
- """
77
- __module__ = 'brainstate.nn'
78
-
79
- def __init__(
80
- self,
81
- prob: float = 0.5,
82
- broadcast_dims: Sequence[int] = (),
83
- name: Optional[str] = None
84
- ) -> None:
85
- super().__init__(name=name)
86
- assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
87
- self.prob = prob
88
- self.broadcast_dims = broadcast_dims
89
-
90
- def __call__(self, x):
91
- dtype = u.math.get_dtype(x)
92
- fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
93
- if fit_phase and self.prob < 1.:
94
- broadcast_shape = list(x.shape)
95
- for dim in self.broadcast_dims:
96
- broadcast_shape[dim] = 1
97
- keep_mask = random.bernoulli(self.prob, broadcast_shape)
98
- keep_mask = u.math.broadcast_to(keep_mask, x.shape)
99
- return u.math.where(
100
- keep_mask,
101
- u.math.asarray(x / self.prob, dtype=dtype),
102
- u.math.asarray(0., dtype=dtype)
103
- )
104
- else:
105
- return x
106
-
107
-
108
- class _DropoutNd(ElementWiseBlock):
109
- __module__ = 'brainstate.nn'
110
- prob: float
111
- channel_axis: int
112
- minimal_dim: int
113
-
114
- def __init__(
115
- self,
116
- prob: float = 0.5,
117
- channel_axis: int = -1,
118
- name: Optional[str] = None
119
- ) -> None:
120
- super().__init__(name=name)
121
- assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
122
- self.prob = prob
123
- self.channel_axis = channel_axis
124
-
125
- def __call__(self, x):
126
- # check input shape
127
- inp_dim = u.math.ndim(x)
128
- if inp_dim not in (self.minimal_dim, self.minimal_dim + 1):
129
- raise RuntimeError(f"dropout1d: Expected {self.minimal_dim}D or {self.minimal_dim + 1}D input, "
130
- f"but received a {inp_dim}D input. {self._get_msg(x)}")
131
- is_not_batched = self.minimal_dim
132
- if is_not_batched:
133
- channel_axis = self.channel_axis if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
134
- mask_shape = [(dim if i == channel_axis else 1) for i, dim in enumerate(x.shape)]
135
- else:
136
- channel_axis = (self.channel_axis + 1) if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
137
- assert channel_axis != 0, f"Channel axis must not be 0. But got {self.channel_axis}."
138
- mask_shape = [(dim if i in (channel_axis, 0) else 1) for i, dim in enumerate(x.shape)]
139
-
140
- # get fit phase
141
- fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
142
-
143
- # generate mask
144
- if fit_phase and self.prob < 1.:
145
- dtype = u.math.get_dtype(x)
146
- keep_mask = random.bernoulli(self.prob, mask_shape)
147
- keep_mask = jnp.broadcast_to(keep_mask, x.shape)
148
- return jnp.where(
149
- keep_mask,
150
- jnp.asarray(x / self.prob, dtype=dtype),
151
- jnp.asarray(0., dtype=dtype)
152
- )
153
- else:
154
- return x
155
-
156
- def _get_msg(self, x):
157
- return ''
158
-
159
-
160
- class Dropout1d(_DropoutNd):
161
- r"""Randomly zero out entire channels (a channel is a 1D feature map).
162
-
163
- Each channel will be zeroed out independently on every forward call with
164
-     probability `1 - prob` using samples from a Bernoulli distribution. The channel is
165
- a 1D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample
166
- in the batched input is a 1D tensor :math:`\text{input}[i, j]`.
167
-
168
- Usually the input comes from :class:`Conv1d` modules.
169
-
170
- As described in the paper [1]_, if adjacent pixels within feature maps are
171
- strongly correlated (as is normally the case in early convolution layers)
172
- then i.i.d. dropout will not regularize the activations and will otherwise
173
- just result in an effective learning rate decrease.
174
-
175
- In this case, :class:`Dropout1d` will help promote independence between
176
- feature maps and should be used instead.
177
-
178
- Parameters
179
- ----------
180
- prob : float
181
- Probability of an element to be kept. Default is 0.5.
182
- channel_axis : int
183
- The axis representing the channel dimension. Default is -1.
184
- name : str, optional
185
- The name of the dynamic system.
186
-
187
- Notes
188
- -----
189
- Input shape: :math:`(N, C, L)` or :math:`(C, L)`.
190
-
191
- Output shape: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).
192
-
193
- References
194
- ----------
195
-     .. [1] Tompson et al., "Efficient Object Localization Using Convolutional Networks"
196
- https://arxiv.org/abs/1411.4280
197
-
198
- Examples
199
- --------
200
- .. code-block:: python
201
-
202
- >>> import brainstate
203
- >>> m = brainstate.nn.Dropout1d(prob=0.8)
204
- >>> x = brainstate.random.randn(20, 32, 16)
205
- >>> with brainstate.environ.context(fit=True):
206
- ... output = m(x)
207
- >>> output.shape
208
- (20, 32, 16)
209
-
210
- """
211
- __module__ = 'brainstate.nn'
212
- minimal_dim: int = 2
213
-
214
- def _get_msg(self, x):
215
- return ("Note that dropout1d exists to provide channel-wise dropout on inputs with 1 "
216
- "spatial dimension, a channel dimension, and an optional batch dimension "
217
- "(i.e. 2D or 3D inputs).")
218
-
219
-
220
- class Dropout2d(_DropoutNd):
221
- r"""Randomly zero out entire channels (a channel is a 2D feature map).
222
-
223
- Each channel will be zeroed out independently on every forward call with
224
-     probability `1 - prob` using samples from a Bernoulli distribution. The channel is
225
- a 2D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample
226
- in the batched input is a 2D tensor :math:`\text{input}[i, j]`.
227
-
228
- Usually the input comes from :class:`Conv2d` modules.
229
-
230
- As described in the paper [1]_, if adjacent pixels within feature maps are
231
- strongly correlated (as is normally the case in early convolution layers)
232
- then i.i.d. dropout will not regularize the activations and will otherwise
233
- just result in an effective learning rate decrease.
234
-
235
- In this case, :class:`Dropout2d` will help promote independence between
236
- feature maps and should be used instead.
237
-
238
- Parameters
239
- ----------
240
- prob : float
241
- Probability of an element to be kept. Default is 0.5.
242
- channel_axis : int
243
- The axis representing the channel dimension. Default is -1.
244
- name : str, optional
245
- The name of the dynamic system.
246
-
247
- Notes
248
- -----
249
- Input shape: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
250
-
251
- Output shape: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input).
252
-
253
- References
254
- ----------
255
-     .. [1] Tompson et al., "Efficient Object Localization Using Convolutional Networks"
256
- https://arxiv.org/abs/1411.4280
257
-
258
- Examples
259
- --------
260
- .. code-block:: python
261
-
262
- >>> import brainstate
263
- >>> m = brainstate.nn.Dropout2d(prob=0.8)
264
- >>> x = brainstate.random.randn(20, 32, 32, 16)
265
- >>> with brainstate.environ.context(fit=True):
266
- ... output = m(x)
267
- >>> output.shape
268
- (20, 32, 32, 16)
269
-
270
- """
271
- __module__ = 'brainstate.nn'
272
- minimal_dim: int = 3
273
-
274
- def _get_msg(self, x):
275
- return ("Note that dropout2d exists to provide channel-wise dropout on inputs with 2 "
276
- "spatial dimensions, a channel dimension, and an optional batch dimension "
277
- "(i.e. 3D or 4D inputs).")
278
-
279
-
280
- class Dropout3d(_DropoutNd):
281
- r"""Randomly zero out entire channels (a channel is a 3D feature map).
282
-
283
- Each channel will be zeroed out independently on every forward call with
284
-     probability `1 - prob` using samples from a Bernoulli distribution. The channel is
285
- a 3D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample
286
- in the batched input is a 3D tensor :math:`\text{input}[i, j]`.
287
-
288
- Usually the input comes from :class:`Conv3d` modules.
289
-
290
- As described in the paper [1]_, if adjacent pixels within feature maps are
291
- strongly correlated (as is normally the case in early convolution layers)
292
- then i.i.d. dropout will not regularize the activations and will otherwise
293
- just result in an effective learning rate decrease.
294
-
295
- In this case, :class:`Dropout3d` will help promote independence between
296
- feature maps and should be used instead.
297
-
298
- Parameters
299
- ----------
300
- prob : float
301
- Probability of an element to be kept. Default is 0.5.
302
- channel_axis : int
303
- The axis representing the channel dimension. Default is -1.
304
- name : str, optional
305
- The name of the dynamic system.
306
-
307
- Notes
308
- -----
309
- Input shape: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
310
-
311
- Output shape: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
312
-
313
- References
314
- ----------
315
-     .. [1] Tompson et al., "Efficient Object Localization Using Convolutional Networks"
316
- https://arxiv.org/abs/1411.4280
317
-
318
- Examples
319
- --------
320
- .. code-block:: python
321
-
322
- >>> import brainstate
323
- >>> m = brainstate.nn.Dropout3d(prob=0.8)
324
- >>> x = brainstate.random.randn(20, 16, 4, 32, 32)
325
- >>> with brainstate.environ.context(fit=True):
326
- ... output = m(x)
327
- >>> output.shape
328
- (20, 16, 4, 32, 32)
329
-
330
- """
331
- __module__ = 'brainstate.nn'
332
- minimal_dim: int = 4
333
-
334
- def _get_msg(self, x):
335
- return ("Note that dropout3d exists to provide channel-wise dropout on inputs with 3 "
336
- "spatial dimensions, a channel dimension, and an optional batch dimension "
337
- "(i.e. 4D or 5D inputs).")
338
-
339
-
340
- class AlphaDropout(_DropoutNd):
341
- r"""Applies Alpha Dropout over the input.
342
-
343
- Alpha Dropout is a type of Dropout that maintains the self-normalizing
344
- property. For an input with zero mean and unit standard deviation, the output of
345
- Alpha Dropout maintains the original mean and standard deviation of the
346
- input.
347
-
348
- Alpha Dropout goes hand-in-hand with SELU activation function, which ensures
349
- that the outputs have zero mean and unit standard deviation.
350
-
351
- During training, it randomly masks some of the elements of the input
352
-     tensor with probability `1 - prob` using samples from a Bernoulli distribution.
353
- The elements to be masked are randomized on every forward call, and scaled
354
- and shifted to maintain zero mean and unit standard deviation.
355
-
356
- During evaluation the module simply computes an identity function.
357
-
358
- Parameters
359
- ----------
360
- prob : float
361
- Probability of an element to be kept. Default is 0.5.
362
- name : str, optional
363
- The name of the dynamic system.
364
-
365
- Notes
366
- -----
367
- Input shape: :math:`(*)`. Input can be of any shape.
368
-
369
- Output shape: :math:`(*)`. Output is of the same shape as input.
370
-
371
- References
372
- ----------
373
- .. [1] Klambauer et al., "Self-Normalizing Neural Networks"
374
- https://arxiv.org/abs/1706.02515
375
-
376
- Examples
377
- --------
378
- .. code-block:: python
379
-
380
- >>> import brainstate
381
- >>> m = brainstate.nn.AlphaDropout(prob=0.8)
382
- >>> x = brainstate.random.randn(20, 16)
383
- >>> with brainstate.environ.context(fit=True):
384
- ... output = m(x)
385
- >>> output.shape
386
- (20, 16)
387
-
388
- """
389
- __module__ = 'brainstate.nn'
390
-
391
- def __init__(
392
- self,
393
- prob: float = 0.5,
394
- name: Optional[str] = None
395
- ) -> None:
396
- super().__init__(name=name)
397
- assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
398
- self.prob = prob
399
-
400
- # SELU parameters
401
- alpha = -1.7580993408473766
402
- self.alpha = alpha
403
-
404
- # Affine transformation parameters to maintain mean and variance
405
- self.a = ((1 - prob) * (1 + prob * alpha ** 2)) ** -0.5
406
- self.b = -self.a * alpha * prob
407
-
408
- def __call__(self, x):
409
- dtype = u.math.get_dtype(x)
410
- fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
411
- if fit_phase and self.prob < 1.:
412
- keep_mask = random.bernoulli(self.prob, x.shape)
413
- return u.math.where(
414
- keep_mask,
415
- u.math.asarray(x, dtype=dtype),
416
- u.math.asarray(self.alpha, dtype=dtype)
417
- ) * self.a + self.b
418
- else:
419
- return x
420
-
421
-
422
- class FeatureAlphaDropout(ElementWiseBlock):
423
- r"""Randomly masks out entire channels with Alpha Dropout properties.
424
-
425
- Instead of setting activations to zero as in regular Dropout, the activations
426
- are set to the negative saturation value of the SELU activation function to
427
- maintain self-normalizing properties.
428
-
429
- Each channel (e.g., the :math:`j`-th channel of the :math:`i`-th sample in
430
- the batch input is a tensor :math:`\text{input}[i, j]`) will be masked
431
-     independently for each sample on every forward call with probability `1 - prob` using
432
- samples from a Bernoulli distribution. The elements to be masked are randomized
433
- on every forward call, and scaled and shifted to maintain zero mean and unit
434
- variance.
435
-
436
- Usually the input comes from convolutional layers with SELU activation.
437
-
438
- As described in the paper [2]_, if adjacent pixels within feature maps are
439
- strongly correlated (as is normally the case in early convolution layers)
440
- then i.i.d. dropout will not regularize the activations and will otherwise
441
- just result in an effective learning rate decrease.
442
-
443
- In this case, :class:`FeatureAlphaDropout` will help promote independence between
444
- feature maps and should be used instead.
445
-
446
- Parameters
447
- ----------
448
- prob : float
449
- Probability of an element to be kept. Default is 0.5.
450
- channel_axis : int
451
- The axis representing the channel dimension. Default is -1.
452
- name : str, optional
453
- The name of the dynamic system.
454
-
455
- Notes
456
- -----
457
- Input shape: :math:`(N, C, *)` where C is the channel dimension.
458
-
459
- Output shape: Same shape as input.
460
-
461
- References
462
- ----------
463
- .. [1] Klambauer et al., "Self-Normalizing Neural Networks"
464
- https://arxiv.org/abs/1706.02515
465
-     .. [2] Tompson et al., "Efficient Object Localization Using Convolutional Networks"
466
- https://arxiv.org/abs/1411.4280
467
-
468
- Examples
469
- --------
470
- .. code-block:: python
471
-
472
- >>> import brainstate
473
- >>> m = brainstate.nn.FeatureAlphaDropout(prob=0.8)
474
- >>> x = brainstate.random.randn(20, 16, 4, 32, 32)
475
- >>> with brainstate.environ.context(fit=True):
476
- ... output = m(x)
477
- >>> output.shape
478
- (20, 16, 4, 32, 32)
479
-
480
- """
481
- __module__ = 'brainstate.nn'
482
-
483
- def __init__(
484
- self,
485
- prob: float = 0.5,
486
- channel_axis: int = -1,
487
- name: Optional[str] = None
488
- ) -> None:
489
- super().__init__(name=name)
490
- assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
491
- self.prob = prob
492
- self.channel_axis = channel_axis
493
-
494
- # SELU parameters
495
- alpha = -1.7580993408473766
496
- self.alpha = alpha
497
-
498
- # Affine transformation parameters to maintain mean and variance
499
- self.a = ((1 - prob) * (1 + prob * alpha ** 2)) ** -0.5
500
- self.b = -self.a * alpha * prob
501
-
502
- def __call__(self, x):
503
- dtype = u.math.get_dtype(x)
504
- fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
505
- if fit_phase and self.prob < 1.:
506
- # Create mask shape with 1s except for batch and channel dimensions
507
- channel_axis = self.channel_axis if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
508
- mask_shape = [1] * x.ndim
509
- mask_shape[0] = x.shape[0] # batch dimension
510
- mask_shape[channel_axis] = x.shape[channel_axis] # channel dimension
511
-
512
- keep_mask = random.bernoulli(self.prob, mask_shape)
513
- keep_mask = u.math.broadcast_to(keep_mask, x.shape)
514
- return u.math.where(
515
- keep_mask,
516
- u.math.asarray(x, dtype=dtype),
517
- u.math.asarray(self.alpha, dtype=dtype)
518
- ) * self.a + self.b
519
- else:
520
- return x
521
-
522
-
523
- class DropoutFixed(ElementWiseBlock):
524
- """A dropout layer with a fixed dropout mask along the time axis.
525
-
526
- In training, to compensate for the fraction of input values dropped,
527
-     all surviving values are multiplied by `1 / prob`.
528
-
529
- This layer is active only during training (``mode=brainstate.mixin.Training``). In other
530
- circumstances it is a no-op.
531
-
532
- This kind of Dropout is particularly useful for spiking neural networks (SNNs) where
533
- the same dropout mask needs to be applied across multiple time steps within a single
534
- mini-batch iteration.
535
-
536
- Parameters
537
- ----------
538
- in_size : tuple or int
539
- The size of the input tensor.
540
- prob : float
541
- Probability to keep element of the tensor. Default is 0.5.
542
- name : str, optional
543
- The name of the dynamic system.
544
-
545
- Notes
546
- -----
547
- As described in [2]_, there is a subtle difference in the way dropout is applied in
548
- SNNs compared to ANNs. In ANNs, each epoch of training has several iterations of
549
- mini-batches. In each iteration, randomly selected units (with dropout ratio of
550
- :math:`p`) are disconnected from the network while weighting by its posterior
551
- probability (:math:`1-p`).
552
-
553
- However, in SNNs, each iteration has more than one forward propagation depending on
554
- the time length of the spike train. We back-propagate the output error and modify
555
- the network parameters only at the last time step. For dropout to be effective in
556
- our training method, it has to be ensured that the set of connected units within an
557
- iteration of mini-batch data is not changed, such that the neural network is
558
- constituted by the same random subset of units during each forward propagation within
559
- a single iteration.
560
-
561
- On the other hand, if the units are randomly connected at each time-step, the effect
562
- of dropout will be averaged out over the entire forward propagation time within an
563
- iteration. Then, the dropout effect would fade-out once the output error is propagated
564
- backward and the parameters are updated at the last time step. Therefore, we need to
565
- keep the set of randomly connected units for the entire time window within an iteration.
566
-
567
- References
568
- ----------
569
- .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
570
- neural networks from overfitting." The journal of machine learning
571
- research 15.1 (2014): 1929-1958.
572
- .. [2] Lee et al., "Enabling Spike-based Backpropagation for Training Deep Neural
573
- Network Architectures" https://arxiv.org/abs/1903.06379
574
-
575
- Examples
576
- --------
577
- .. code-block:: python
578
-
579
- >>> import brainstate
580
- >>> layer = brainstate.nn.DropoutFixed(in_size=(20,), prob=0.8)
581
- >>> layer.init_state(batch_size=10)
582
- >>> x = brainstate.random.randn(10, 20)
583
- >>> with brainstate.environ.context(fit=True):
584
- ... output = layer.update(x)
585
- >>> output.shape
586
- (10, 20)
587
-
588
- """
589
- __module__ = 'brainstate.nn'
590
-
591
- def __init__(
592
- self,
593
- in_size: Size,
594
- prob: float = 0.5,
595
- name: Optional[str] = None
596
- ) -> None:
597
- super().__init__(name=name)
598
- assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
599
- self.prob = prob
600
- self.in_size = in_size
601
- self.out_size = in_size
602
-
603
- def init_state(self, batch_size=None, **kwargs):
604
- if self.prob < 1.:
605
- self.mask = ShortTermState(init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size))
606
-
607
- def update(self, x):
608
- dtype = u.math.get_dtype(x)
609
- fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
610
- if fit_phase and self.prob < 1.:
611
- if self.mask.value.shape != x.shape:
612
- raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
613
- f"Please call `init_state()` method first.")
614
- return u.math.where(self.mask.value,
615
- u.math.asarray(x / self.prob, dtype=dtype),
616
- u.math.asarray(0., dtype=dtype) * u.get_unit(x))
617
- else:
618
- return x
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ from functools import partial
18
+ from typing import Optional, Sequence
19
+
20
+ import brainunit as u
21
+ import jax.numpy as jnp
22
+
23
+ from brainstate import random, environ
24
+ from brainstate._state import ShortTermState
25
+ from brainstate.typing import Size
26
+ from . import init as init
27
+ from ._module import ElementWiseBlock
28
+
29
+ __all__ = [
30
+ 'Dropout',
31
+ 'Dropout1d',
32
+ 'Dropout2d',
33
+ 'Dropout3d',
34
+ 'AlphaDropout',
35
+ 'FeatureAlphaDropout',
36
+ 'DropoutFixed',
37
+ ]
38
+
39
+
40
+ class Dropout(ElementWiseBlock):
41
+ """A layer that stochastically ignores a subset of inputs each training step.
42
+
43
+     In training, to compensate for the fraction of input values dropped (`1 - prob`),
44
+     all surviving values are multiplied by `1 / prob`.
45
+
46
+ This layer is active only during training (``mode=brainstate.mixin.Training``). In other
47
+ circumstances it is a no-op.
48
+
49
+ Parameters
50
+ ----------
51
+ prob : float
52
+ Probability to keep element of the tensor. Default is 0.5.
53
+ broadcast_dims : Sequence[int]
54
+ Dimensions that will share the same dropout mask. Default is ().
55
+ name : str, optional
56
+ The name of the dynamic system.
57
+
58
+ References
59
+ ----------
60
+ .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
61
+ neural networks from overfitting." The journal of machine learning
62
+ research 15.1 (2014): 1929-1958.
63
+
64
+ Examples
65
+ --------
66
+ .. code-block:: python
67
+
68
+ >>> import brainstate
69
+ >>> layer = brainstate.nn.Dropout(prob=0.8)
70
+ >>> x = brainstate.random.randn(10, 20)
71
+ >>> with brainstate.environ.context(fit=True):
72
+ ... output = layer(x)
73
+ >>> output.shape
74
+ (10, 20)
75
+
76
+ """
77
+ __module__ = 'brainstate.nn'
78
+
79
+ def __init__(
80
+ self,
81
+ prob: float = 0.5,
82
+ broadcast_dims: Sequence[int] = (),
83
+ name: Optional[str] = None
84
+ ) -> None:
85
+ super().__init__(name=name)
86
+ assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
87
+ self.prob = prob
88
+ self.broadcast_dims = broadcast_dims
89
+
90
+ def __call__(self, x):
91
+ dtype = u.math.get_dtype(x)
92
+ fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
93
+ if fit_phase and self.prob < 1.:
94
+ broadcast_shape = list(x.shape)
95
+ for dim in self.broadcast_dims:
96
+ broadcast_shape[dim] = 1
97
+ keep_mask = random.bernoulli(self.prob, broadcast_shape)
98
+ keep_mask = u.math.broadcast_to(keep_mask, x.shape)
99
+ return u.math.where(
100
+ keep_mask,
101
+ u.math.asarray(x / self.prob, dtype=dtype),
102
+ u.math.asarray(0., dtype=dtype)
103
+ )
104
+ else:
105
+ return x
106
+
107
+
108
+ class _DropoutNd(ElementWiseBlock):
109
+ __module__ = 'brainstate.nn'
110
+ prob: float
111
+ channel_axis: int
112
+ minimal_dim: int
113
+
114
+ def __init__(
115
+ self,
116
+ prob: float = 0.5,
117
+ channel_axis: int = -1,
118
+ name: Optional[str] = None
119
+ ) -> None:
120
+ super().__init__(name=name)
121
+ assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
122
+ self.prob = prob
123
+ self.channel_axis = channel_axis
124
+
125
+ def __call__(self, x):
126
+ # check input shape
127
+ inp_dim = u.math.ndim(x)
128
+ if inp_dim not in (self.minimal_dim, self.minimal_dim + 1):
129
+ raise RuntimeError(f"dropout1d: Expected {self.minimal_dim}D or {self.minimal_dim + 1}D input, "
130
+ f"but received a {inp_dim}D input. {self._get_msg(x)}")
131
+ is_not_batched = self.minimal_dim
132
+ if is_not_batched:
133
+ channel_axis = self.channel_axis if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
134
+ mask_shape = [(dim if i == channel_axis else 1) for i, dim in enumerate(x.shape)]
135
+ else:
136
+ channel_axis = (self.channel_axis + 1) if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
137
+ assert channel_axis != 0, f"Channel axis must not be 0. But got {self.channel_axis}."
138
+ mask_shape = [(dim if i in (channel_axis, 0) else 1) for i, dim in enumerate(x.shape)]
139
+
140
+ # get fit phase
141
+ fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
142
+
143
+ # generate mask
144
+ if fit_phase and self.prob < 1.:
145
+ dtype = u.math.get_dtype(x)
146
+ keep_mask = random.bernoulli(self.prob, mask_shape)
147
+ keep_mask = jnp.broadcast_to(keep_mask, x.shape)
148
+ return jnp.where(
149
+ keep_mask,
150
+ jnp.asarray(x / self.prob, dtype=dtype),
151
+ jnp.asarray(0., dtype=dtype)
152
+ )
153
+ else:
154
+ return x
155
+
156
+ def _get_msg(self, x):
157
+ return ''
158
+
159
+
160
+ class Dropout1d(_DropoutNd):
161
+ r"""Randomly zero out entire channels (a channel is a 1D feature map).
162
+
163
+ Each channel will be zeroed out independently on every forward call with
164
+     probability `1 - prob` using samples from a Bernoulli distribution. The channel is
165
+ a 1D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample
166
+ in the batched input is a 1D tensor :math:`\text{input}[i, j]`.
167
+
168
+ Usually the input comes from :class:`Conv1d` modules.
169
+
170
+ As described in the paper [1]_, if adjacent pixels within feature maps are
171
+ strongly correlated (as is normally the case in early convolution layers)
172
+ then i.i.d. dropout will not regularize the activations and will otherwise
173
+ just result in an effective learning rate decrease.
174
+
175
+ In this case, :class:`Dropout1d` will help promote independence between
176
+ feature maps and should be used instead.
177
+
178
+ Parameters
179
+ ----------
180
+ prob : float
181
+ Probability of an element to be kept. Default is 0.5.
182
+ channel_axis : int
183
+ The axis representing the channel dimension. Default is -1.
184
+ name : str, optional
185
+ The name of the dynamic system.
186
+
187
+ Notes
188
+ -----
189
+ Input shape: :math:`(N, C, L)` or :math:`(C, L)`.
190
+
191
+ Output shape: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).
192
+
193
+ References
194
+ ----------
195
+     .. [1] Tompson et al., "Efficient Object Localization Using Convolutional Networks"
196
+ https://arxiv.org/abs/1411.4280
197
+
198
+ Examples
199
+ --------
200
+ .. code-block:: python
201
+
202
+ >>> import brainstate
203
+ >>> m = brainstate.nn.Dropout1d(prob=0.8)
204
+ >>> x = brainstate.random.randn(20, 32, 16)
205
+ >>> with brainstate.environ.context(fit=True):
206
+ ... output = m(x)
207
+ >>> output.shape
208
+ (20, 32, 16)
209
+
210
+ """
211
+ __module__ = 'brainstate.nn'
212
+ minimal_dim: int = 2
213
+
214
+ def _get_msg(self, x):
215
+ return ("Note that dropout1d exists to provide channel-wise dropout on inputs with 1 "
216
+ "spatial dimension, a channel dimension, and an optional batch dimension "
217
+ "(i.e. 2D or 3D inputs).")
218
+
219
+
220
+ class Dropout2d(_DropoutNd):
221
+ r"""Randomly zero out entire channels (a channel is a 2D feature map).
222
+
223
+ Each channel will be zeroed out independently on every forward call with
224
+     probability `1 - prob` using samples from a Bernoulli distribution. The channel is
225
+ a 2D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample
226
+ in the batched input is a 2D tensor :math:`\text{input}[i, j]`.
227
+
228
+ Usually the input comes from :class:`Conv2d` modules.
229
+
230
+ As described in the paper [1]_, if adjacent pixels within feature maps are
231
+ strongly correlated (as is normally the case in early convolution layers)
232
+ then i.i.d. dropout will not regularize the activations and will otherwise
233
+ just result in an effective learning rate decrease.
234
+
235
+ In this case, :class:`Dropout2d` will help promote independence between
236
+ feature maps and should be used instead.
237
+
238
+ Parameters
239
+ ----------
240
+ prob : float
241
+ Probability of an element to be kept. Default is 0.5.
242
+ channel_axis : int
243
+ The axis representing the channel dimension. Default is -1.
244
+ name : str, optional
245
+ The name of the dynamic system.
246
+
247
+ Notes
248
+ -----
249
+ Input shape: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
250
+
251
+ Output shape: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input).
252
+
253
+ References
254
+ ----------
255
+     .. [1] Tompson et al., "Efficient Object Localization Using Convolutional Networks"
256
+ https://arxiv.org/abs/1411.4280
257
+
258
+ Examples
259
+ --------
260
+ .. code-block:: python
261
+
262
+ >>> import brainstate
263
+ >>> m = brainstate.nn.Dropout2d(prob=0.8)
264
+ >>> x = brainstate.random.randn(20, 32, 32, 16)
265
+ >>> with brainstate.environ.context(fit=True):
266
+ ... output = m(x)
267
+ >>> output.shape
268
+ (20, 32, 32, 16)
269
+
270
+ """
271
+ __module__ = 'brainstate.nn'
272
+ minimal_dim: int = 3
273
+
274
+ def _get_msg(self, x):
275
+ return ("Note that dropout2d exists to provide channel-wise dropout on inputs with 2 "
276
+ "spatial dimensions, a channel dimension, and an optional batch dimension "
277
+ "(i.e. 3D or 4D inputs).")
278
+
279
+
280
+ class Dropout3d(_DropoutNd):
281
+ r"""Randomly zero out entire channels (a channel is a 3D feature map).
282
+
283
+ Each channel will be zeroed out independently on every forward call with
284
+     probability `1 - prob` using samples from a Bernoulli distribution. The channel is
285
+ a 3D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample
286
+ in the batched input is a 3D tensor :math:`\text{input}[i, j]`.
287
+
288
+ Usually the input comes from :class:`Conv3d` modules.
289
+
290
+ As described in the paper [1]_, if adjacent pixels within feature maps are
291
+ strongly correlated (as is normally the case in early convolution layers)
292
+ then i.i.d. dropout will not regularize the activations and will otherwise
293
+ just result in an effective learning rate decrease.
294
+
295
+ In this case, :class:`Dropout3d` will help promote independence between
296
+ feature maps and should be used instead.
297
+
298
+ Parameters
299
+ ----------
300
+ prob : float
301
+ Probability of an element to be kept. Default is 0.5.
302
+ channel_axis : int
303
+ The axis representing the channel dimension. Default is -1.
304
+ name : str, optional
305
+ The name of the dynamic system.
306
+
307
+ Notes
308
+ -----
309
+ Input shape: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
310
+
311
+ Output shape: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
312
+
313
+ References
314
+ ----------
315
+     .. [1] Tompson et al., "Efficient Object Localization Using Convolutional Networks"
316
+ https://arxiv.org/abs/1411.4280
317
+
318
+ Examples
319
+ --------
320
+ .. code-block:: python
321
+
322
+ >>> import brainstate
323
+ >>> m = brainstate.nn.Dropout3d(prob=0.8)
324
+ >>> x = brainstate.random.randn(20, 16, 4, 32, 32)
325
+ >>> with brainstate.environ.context(fit=True):
326
+ ... output = m(x)
327
+ >>> output.shape
328
+ (20, 16, 4, 32, 32)
329
+
330
+ """
331
+ __module__ = 'brainstate.nn'
332
+ minimal_dim: int = 4
333
+
334
+ def _get_msg(self, x):
335
+ return ("Note that dropout3d exists to provide channel-wise dropout on inputs with 3 "
336
+ "spatial dimensions, a channel dimension, and an optional batch dimension "
337
+ "(i.e. 4D or 5D inputs).")
338
+
339
+
340
+ class AlphaDropout(_DropoutNd):
341
+ r"""Applies Alpha Dropout over the input.
342
+
343
+ Alpha Dropout is a type of Dropout that maintains the self-normalizing
344
+ property. For an input with zero mean and unit standard deviation, the output of
345
+ Alpha Dropout maintains the original mean and standard deviation of the
346
+ input.
347
+
348
+ Alpha Dropout goes hand-in-hand with SELU activation function, which ensures
349
+ that the outputs have zero mean and unit standard deviation.
350
+
351
+ During training, it randomly masks some of the elements of the input
352
+     tensor with probability `1 - prob` using samples from a Bernoulli distribution.
353
+ The elements to be masked are randomized on every forward call, and scaled
354
+ and shifted to maintain zero mean and unit standard deviation.
355
+
356
+ During evaluation the module simply computes an identity function.
357
+
358
+ Parameters
359
+ ----------
360
+ prob : float
361
+ Probability of an element to be kept. Default is 0.5.
362
+ name : str, optional
363
+ The name of the dynamic system.
364
+
365
+ Notes
366
+ -----
367
+ Input shape: :math:`(*)`. Input can be of any shape.
368
+
369
+ Output shape: :math:`(*)`. Output is of the same shape as input.
370
+
371
+ References
372
+ ----------
373
+ .. [1] Klambauer et al., "Self-Normalizing Neural Networks"
374
+ https://arxiv.org/abs/1706.02515
375
+
376
+ Examples
377
+ --------
378
+ .. code-block:: python
379
+
380
+ >>> import brainstate
381
+ >>> m = brainstate.nn.AlphaDropout(prob=0.8)
382
+ >>> x = brainstate.random.randn(20, 16)
383
+ >>> with brainstate.environ.context(fit=True):
384
+ ... output = m(x)
385
+ >>> output.shape
386
+ (20, 16)
387
+
388
+ """
389
+ __module__ = 'brainstate.nn'
390
+
391
+ def __init__(
392
+ self,
393
+ prob: float = 0.5,
394
+ name: Optional[str] = None
395
+ ) -> None:
396
+ super().__init__(name=name)
397
+ assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
398
+ self.prob = prob
399
+
400
+ # SELU parameters
401
+ alpha = -1.7580993408473766
402
+ self.alpha = alpha
403
+
404
+ # Affine transformation parameters to maintain mean and variance
405
+ self.a = ((1 - prob) * (1 + prob * alpha ** 2)) ** -0.5
406
+ self.b = -self.a * alpha * prob
407
+
408
+ def __call__(self, x):
409
+ dtype = u.math.get_dtype(x)
410
+ fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
411
+ if fit_phase and self.prob < 1.:
412
+ keep_mask = random.bernoulli(self.prob, x.shape)
413
+ return u.math.where(
414
+ keep_mask,
415
+ u.math.asarray(x, dtype=dtype),
416
+ u.math.asarray(self.alpha, dtype=dtype)
417
+ ) * self.a + self.b
418
+ else:
419
+ return x
420
+
421
+
422
+ class FeatureAlphaDropout(ElementWiseBlock):
423
+ r"""Randomly masks out entire channels with Alpha Dropout properties.
424
+
425
+ Instead of setting activations to zero as in regular Dropout, the activations
426
+ are set to the negative saturation value of the SELU activation function to
427
+ maintain self-normalizing properties.
428
+
429
+ Each channel (e.g., the :math:`j`-th channel of the :math:`i`-th sample in
430
+ the batch input is a tensor :math:`\text{input}[i, j]`) will be masked
431
+     independently for each sample on every forward call with probability `1 - prob` using
432
+ samples from a Bernoulli distribution. The elements to be masked are randomized
433
+ on every forward call, and scaled and shifted to maintain zero mean and unit
434
+ variance.
435
+
436
+ Usually the input comes from convolutional layers with SELU activation.
437
+
438
+ As described in the paper [2]_, if adjacent pixels within feature maps are
439
+ strongly correlated (as is normally the case in early convolution layers)
440
+ then i.i.d. dropout will not regularize the activations and will otherwise
441
+ just result in an effective learning rate decrease.
442
+
443
+ In this case, :class:`FeatureAlphaDropout` will help promote independence between
444
+ feature maps and should be used instead.
445
+
446
+ Parameters
447
+ ----------
448
+ prob : float
449
+ Probability of an element to be kept. Default is 0.5.
450
+ channel_axis : int
451
+ The axis representing the channel dimension. Default is -1.
452
+ name : str, optional
453
+ The name of the dynamic system.
454
+
455
+ Notes
456
+ -----
457
+ Input shape: :math:`(N, C, *)` where C is the channel dimension.
458
+
459
+ Output shape: Same shape as input.
460
+
461
+ References
462
+ ----------
463
+ .. [1] Klambauer et al., "Self-Normalizing Neural Networks"
464
+ https://arxiv.org/abs/1706.02515
465
+     .. [2] Tompson et al., "Efficient Object Localization Using Convolutional Networks"
466
+ https://arxiv.org/abs/1411.4280
467
+
468
+ Examples
469
+ --------
470
+ .. code-block:: python
471
+
472
+ >>> import brainstate
473
+ >>> m = brainstate.nn.FeatureAlphaDropout(prob=0.8)
474
+ >>> x = brainstate.random.randn(20, 16, 4, 32, 32)
475
+ >>> with brainstate.environ.context(fit=True):
476
+ ... output = m(x)
477
+ >>> output.shape
478
+ (20, 16, 4, 32, 32)
479
+
480
+ """
481
+ __module__ = 'brainstate.nn'
482
+
483
+ def __init__(
484
+ self,
485
+ prob: float = 0.5,
486
+ channel_axis: int = -1,
487
+ name: Optional[str] = None
488
+ ) -> None:
489
+ super().__init__(name=name)
490
+ assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
491
+ self.prob = prob
492
+ self.channel_axis = channel_axis
493
+
494
+ # SELU parameters
495
+ alpha = -1.7580993408473766
496
+ self.alpha = alpha
497
+
498
+ # Affine transformation parameters to maintain mean and variance
499
+ self.a = ((1 - prob) * (1 + prob * alpha ** 2)) ** -0.5
500
+ self.b = -self.a * alpha * prob
501
+
502
+ def __call__(self, x):
503
+ dtype = u.math.get_dtype(x)
504
+ fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
505
+ if fit_phase and self.prob < 1.:
506
+ # Create mask shape with 1s except for batch and channel dimensions
507
+ channel_axis = self.channel_axis if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
508
+ mask_shape = [1] * x.ndim
509
+ mask_shape[0] = x.shape[0] # batch dimension
510
+ mask_shape[channel_axis] = x.shape[channel_axis] # channel dimension
511
+
512
+ keep_mask = random.bernoulli(self.prob, mask_shape)
513
+ keep_mask = u.math.broadcast_to(keep_mask, x.shape)
514
+ return u.math.where(
515
+ keep_mask,
516
+ u.math.asarray(x, dtype=dtype),
517
+ u.math.asarray(self.alpha, dtype=dtype)
518
+ ) * self.a + self.b
519
+ else:
520
+ return x
521
+
522
+
523
+ class DropoutFixed(ElementWiseBlock):
524
+ """A dropout layer with a fixed dropout mask along the time axis.
525
+
526
+ In training, to compensate for the fraction of input values dropped,
527
+     all surviving values are multiplied by `1 / prob`.
528
+
529
+ This layer is active only during training (``mode=brainstate.mixin.Training``). In other
530
+ circumstances it is a no-op.
531
+
532
+ This kind of Dropout is particularly useful for spiking neural networks (SNNs) where
533
+ the same dropout mask needs to be applied across multiple time steps within a single
534
+ mini-batch iteration.
535
+
536
+ Parameters
537
+ ----------
538
+ in_size : tuple or int
539
+ The size of the input tensor.
540
+ prob : float
541
+ Probability to keep element of the tensor. Default is 0.5.
542
+ name : str, optional
543
+ The name of the dynamic system.
544
+
545
+ Notes
546
+ -----
547
+ As described in [2]_, there is a subtle difference in the way dropout is applied in
548
+ SNNs compared to ANNs. In ANNs, each epoch of training has several iterations of
549
+ mini-batches. In each iteration, randomly selected units (with dropout ratio of
550
+ :math:`p`) are disconnected from the network while weighting by its posterior
551
+ probability (:math:`1-p`).
552
+
553
+ However, in SNNs, each iteration has more than one forward propagation depending on
554
+ the time length of the spike train. We back-propagate the output error and modify
555
+ the network parameters only at the last time step. For dropout to be effective in
556
+ our training method, it has to be ensured that the set of connected units within an
557
+ iteration of mini-batch data is not changed, such that the neural network is
558
+ constituted by the same random subset of units during each forward propagation within
559
+ a single iteration.
560
+
561
+ On the other hand, if the units are randomly connected at each time-step, the effect
562
+ of dropout will be averaged out over the entire forward propagation time within an
563
+ iteration. Then, the dropout effect would fade-out once the output error is propagated
564
+ backward and the parameters are updated at the last time step. Therefore, we need to
565
+ keep the set of randomly connected units for the entire time window within an iteration.
566
+
567
+ References
568
+ ----------
569
+ .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
570
+ neural networks from overfitting." The journal of machine learning
571
+ research 15.1 (2014): 1929-1958.
572
+ .. [2] Lee et al., "Enabling Spike-based Backpropagation for Training Deep Neural
573
+ Network Architectures" https://arxiv.org/abs/1903.06379
574
+
575
+ Examples
576
+ --------
577
+ .. code-block:: python
578
+
579
+ >>> import brainstate
580
+ >>> layer = brainstate.nn.DropoutFixed(in_size=(20,), prob=0.8)
581
+ >>> layer.init_state(batch_size=10)
582
+ >>> x = brainstate.random.randn(10, 20)
583
+ >>> with brainstate.environ.context(fit=True):
584
+ ... output = layer.update(x)
585
+ >>> output.shape
586
+ (10, 20)
587
+
588
+ """
589
+ __module__ = 'brainstate.nn'
590
+
591
+ def __init__(
592
+ self,
593
+ in_size: Size,
594
+ prob: float = 0.5,
595
+ name: Optional[str] = None
596
+ ) -> None:
597
+ super().__init__(name=name)
598
+ assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
599
+ self.prob = prob
600
+ self.in_size = in_size
601
+ self.out_size = in_size
602
+
603
+ def init_state(self, batch_size=None, **kwargs):
604
+ if self.prob < 1.:
605
+ self.mask = ShortTermState(init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size))
606
+
607
+ def update(self, x):
608
+ dtype = u.math.get_dtype(x)
609
+ fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
610
+ if fit_phase and self.prob < 1.:
611
+ if self.mask.value.shape != x.shape:
612
+ raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
613
+ f"Please call `init_state()` method first.")
614
+ return u.math.where(self.mask.value,
615
+ u.math.asarray(x / self.prob, dtype=dtype),
616
+ u.math.asarray(0., dtype=dtype) * u.get_unit(x))
617
+ else:
618
+ return x
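
For orientation, the sketch below strings together the docstring examples from `_dropout.py` into one usage sample. It is a minimal sketch, assuming brainstate 0.2.2 is installed and that, as documented above, the `fit` environment flag controls whether the dropout masks are applied; the variable names are illustrative only.

import brainstate

# Element-wise dropout: each element is kept with probability `prob`
# and survivors are rescaled by 1 / prob while fitting.
layer = brainstate.nn.Dropout(prob=0.8)
x = brainstate.random.randn(10, 20)
with brainstate.environ.context(fit=True):
    y_train = layer(x)   # mask sampled and applied
with brainstate.environ.context(fit=False):
    y_eval = layer(x)    # identity: no-op outside the fitting phase

# Channel-wise dropout, channels-last input as in the Dropout2d example.
m = brainstate.nn.Dropout2d(prob=0.8, channel_axis=-1)
img = brainstate.random.randn(20, 32, 32, 16)
with brainstate.environ.context(fit=True):
    out = m(img)

# Fixed-mask dropout for SNNs: init_state() samples one mask per
# mini-batch iteration and update() reuses it at every time step.
snn_drop = brainstate.nn.DropoutFixed(in_size=(20,), prob=0.8)
snn_drop.init_state(batch_size=10)
with brainstate.environ.context(fit=True):
    out_t = snn_drop.update(brainstate.random.randn(10, 20))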