brainstate-0.1.10-py2.py3-none-any.whl → brainstate-0.2.1-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
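The renames above consolidate the old `brainstate.augment` and `brainstate.compile` packages into a single `brainstate.transform` package, and move the initializers from `brainstate.init` into `brainstate.nn.init` (the dropout diff below imports them via `from . import init`). A minimal, hypothetical import sketch of that migration follows; only the module paths are confirmed by this file list, and the `init.param` / `random.bernoulli` calls are copied from the new `_dropout.py` hunk below — anything not shown there should be treated as an assumption.

    # Hypothetical 0.1.x -> 0.2.x import migration implied by the renames above.
    from functools import partial

    import brainstate
    import brainstate.transform          # was: brainstate.compile / brainstate.augment
    from brainstate.nn import init       # was: from brainstate import init

    # init.param(fn, sizes, batch_size) and random.bernoulli(prob, shape)
    # are used exactly this way in the new _dropout.py shown below.
    mask = init.param(partial(brainstate.random.bernoulli, 0.5), (20,), None)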
brainstate/nn/_dropout.py CHANGED
@@ -1,426 +1,618 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- from functools import partial
18
- from typing import Optional, Sequence
19
-
20
- import brainunit as u
21
- import jax.numpy as jnp
22
-
23
- from brainstate import random, environ, init
24
- from brainstate._state import ShortTermState
25
- from brainstate.typing import Size
26
- from ._module import ElementWiseBlock
27
-
28
- __all__ = [
29
- 'DropoutFixed', 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
30
- ]
31
-
32
-
33
- class Dropout(ElementWiseBlock):
34
- """A layer that stochastically ignores a subset of inputs each training step.
35
-
36
- In training, to compensate for the fraction of input values dropped (`rate`),
37
- all surviving values are multiplied by `1 / (1 - rate)`.
38
-
39
- This layer is active only during training (``mode=brainstate.mixin.Training``). In other
40
- circumstances it is a no-op.
41
-
42
- .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
43
- neural networks from overfitting." The journal of machine learning
44
- research 15.1 (2014): 1929-1958.
45
-
46
- Args:
47
- prob: Probability to keep element of the tensor.
48
- broadcast_dims: dimensions that will share the same dropout mask.
49
- name: str. The name of the dynamic system.
50
-
51
- """
52
- __module__ = 'brainstate.nn'
53
-
54
- def __init__(
55
- self,
56
- prob: float = 0.5,
57
- broadcast_dims: Sequence[int] = (),
58
- name: Optional[str] = None
59
- ) -> None:
60
- super().__init__(name=name)
61
- assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
62
- self.prob = prob
63
- self.broadcast_dims = broadcast_dims
64
-
65
- def __call__(self, x):
66
- dtype = u.math.get_dtype(x)
67
- fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
68
- if fit_phase and self.prob < 1.:
69
- broadcast_shape = list(x.shape)
70
- for dim in self.broadcast_dims:
71
- broadcast_shape[dim] = 1
72
- keep_mask = random.bernoulli(self.prob, broadcast_shape)
73
- keep_mask = u.math.broadcast_to(keep_mask, x.shape)
74
- return u.math.where(
75
- keep_mask,
76
- u.math.asarray(x / self.prob, dtype=dtype),
77
- u.math.asarray(0., dtype=dtype)
78
- )
79
- else:
80
- return x
81
-
82
-
83
- class _DropoutNd(ElementWiseBlock):
84
- __module__ = 'brainstate.nn'
85
- prob: float
86
- channel_axis: int
87
- minimal_dim: int
88
-
89
- def __init__(
90
- self,
91
- prob: float = 0.5,
92
- channel_axis: int = -1,
93
- name: Optional[str] = None
94
- ) -> None:
95
- super().__init__(name=name)
96
- assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
97
- self.prob = prob
98
- self.channel_axis = channel_axis
99
-
100
- def __call__(self, x):
101
- # check input shape
102
- inp_dim = u.math.ndim(x)
103
- if inp_dim not in (self.minimal_dim, self.minimal_dim + 1):
104
- raise RuntimeError(f"dropout1d: Expected {self.minimal_dim}D or {self.minimal_dim + 1}D input, "
105
- f"but received a {inp_dim}D input. {self._get_msg(x)}")
106
- is_not_batched = self.minimal_dim
107
- if is_not_batched:
108
- channel_axis = self.channel_axis if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
109
- mask_shape = [(dim if i == channel_axis else 1) for i, dim in enumerate(x.shape)]
110
- else:
111
- channel_axis = (self.channel_axis + 1) if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
112
- assert channel_axis != 0, f"Channel axis must not be 0. But got {self.channel_axis}."
113
- mask_shape = [(dim if i in (channel_axis, 0) else 1) for i, dim in enumerate(x.shape)]
114
-
115
- # get fit phase
116
- fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
117
-
118
- # generate mask
119
- if fit_phase and self.prob < 1.:
120
- dtype = u.math.get_dtype(x)
121
- keep_mask = random.bernoulli(self.prob, mask_shape)
122
- keep_mask = jnp.broadcast_to(keep_mask, x.shape)
123
- return jnp.where(
124
- keep_mask,
125
- jnp.asarray(x / self.prob, dtype=dtype),
126
- jnp.asarray(0., dtype=dtype)
127
- )
128
- else:
129
- return x
130
-
131
- def _get_msg(self, x):
132
- return ''
133
-
134
-
135
- class Dropout1d(_DropoutNd):
136
- r"""Randomly zero out entire channels (a channel is a 1D feature map,
137
- e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
138
- batched input is a 1D tensor :math:`\text{input}[i, j]`).
139
- Each channel will be zeroed out independently on every forward call with
140
- probability :attr:`p` using samples from a Bernoulli distribution.
141
-
142
- Usually the input comes from :class:`nn.Conv1d` modules.
143
-
144
- As described in the paper
145
- `Efficient Object Localization Using Convolutional Networks`_ ,
146
- if adjacent pixels within feature maps are strongly correlated
147
- (as is normally the case in early convolution layers) then i.i.d. dropout
148
- will not regularize the activations and will otherwise just result
149
- in an effective learning rate decrease.
150
-
151
- In this case, :func:`nn.Dropout1d` will help promote independence between
152
- feature maps and should be used instead.
153
-
154
- Args:
155
- prob: float. probability of an element to be zero-ed.
156
-
157
- Shape:
158
- - Input: :math:`(N, C, L)` or :math:`(C, L)`.
159
- - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).
160
-
161
- Examples::
162
-
163
- >>> m = Dropout1d(p=0.2)
164
- >>> x = random.randn(20, 32, 16)
165
- >>> output = m(x)
166
- >>> output.shape
167
- (20, 32, 16)
168
-
169
- .. _Efficient Object Localization Using Convolutional Networks:
170
- https://arxiv.org/abs/1411.4280
171
- """
172
- __module__ = 'brainstate.nn'
173
- minimal_dim: int = 2
174
-
175
- def _get_msg(self, x):
176
- return ("Note that dropout1d exists to provide channel-wise dropout on inputs with 1 "
177
- "spatial dimension, a channel dimension, and an optional batch dimension "
178
- "(i.e. 2D or 3D inputs).")
179
-
180
-
181
- class Dropout2d(_DropoutNd):
182
- r"""Randomly zero out entire channels (a channel is a 2D feature map,
183
- e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
184
- batched input is a 2D tensor :math:`\text{input}[i, j]`).
185
- Each channel will be zeroed out independently on every forward call with
186
- probability :attr:`p` using samples from a Bernoulli distribution.
187
-
188
- Usually the input comes from :class:`nn.Conv2d` modules.
189
-
190
- As described in the paper
191
- `Efficient Object Localization Using Convolutional Networks`_ ,
192
- if adjacent pixels within feature maps are strongly correlated
193
- (as is normally the case in early convolution layers) then i.i.d. dropout
194
- will not regularize the activations and will otherwise just result
195
- in an effective learning rate decrease.
196
-
197
- In this case, :func:`nn.Dropout2d` will help promote independence between
198
- feature maps and should be used instead.
199
-
200
- Args:
201
- prob: float. probability of an element to be kept.
202
-
203
- Shape:
204
- - Input: :math:`(N, C, H, W)` or :math:`(N, C, L)`.
205
- - Output: :math:`(N, C, H, W)` or :math:`(N, C, L)` (same shape as input).
206
-
207
- Examples::
208
-
209
- >>> m = Dropout2d(p=0.2)
210
- >>> x = random.randn(20, 32, 32, 16)
211
- >>> output = m(x)
212
-
213
- .. _Efficient Object Localization Using Convolutional Networks:
214
- https://arxiv.org/abs/1411.4280
215
- """
216
- __module__ = 'brainstate.nn'
217
- minimal_dim: int = 3
218
-
219
- def _get_msg(self, x):
220
- return ("Note that dropout2d exists to provide channel-wise dropout on inputs with 2 "
221
- "spatial dimensions, a channel dimension, and an optional batch dimension "
222
- "(i.e. 3D or 4D inputs).")
223
-
224
-
225
- class Dropout3d(_DropoutNd):
226
- r"""Randomly zero out entire channels (a channel is a 3D feature map,
227
- e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
228
- batched input is a 3D tensor :math:`\text{input}[i, j]`).
229
- Each channel will be zeroed out independently on every forward call with
230
- probability :attr:`p` using samples from a Bernoulli distribution.
231
-
232
- Usually the input comes from :class:`nn.Conv3d` modules.
233
-
234
- As described in the paper
235
- `Efficient Object Localization Using Convolutional Networks`_ ,
236
- if adjacent pixels within feature maps are strongly correlated
237
- (as is normally the case in early convolution layers) then i.i.d. dropout
238
- will not regularize the activations and will otherwise just result
239
- in an effective learning rate decrease.
240
-
241
- In this case, :func:`nn.Dropout3d` will help promote independence between
242
- feature maps and should be used instead.
243
-
244
- Args:
245
- prob: float. probability of an element to be kept.
246
-
247
- Shape:
248
- - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
249
- - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
250
-
251
- Examples::
252
-
253
- >>> m = Dropout3d(p=0.2)
254
- >>> x = random.randn(20, 16, 4, 32, 32)
255
- >>> output = m(x)
256
-
257
- .. _Efficient Object Localization Using Convolutional Networks:
258
- https://arxiv.org/abs/1411.4280
259
- """
260
- __module__ = 'brainstate.nn'
261
- minimal_dim: int = 4
262
-
263
- def _get_msg(self, x):
264
- return ("Note that dropout3d exists to provide channel-wise dropout on inputs with 3 "
265
- "spatial dimensions, a channel dimension, and an optional batch dimension "
266
- "(i.e. 4D or 5D inputs).")
267
-
268
-
269
- class AlphaDropout(_DropoutNd):
270
- r"""Applies Alpha Dropout over the input.
271
-
272
- Alpha Dropout is a type of Dropout that maintains the self-normalizing
273
- property.
274
- For an input with zero mean and unit standard deviation, the output of
275
- Alpha Dropout maintains the original mean and standard deviation of the
276
- input.
277
- Alpha Dropout goes hand-in-hand with SELU activation function, which ensures
278
- that the outputs have zero mean and unit standard deviation.
279
-
280
- During training, it randomly masks some of the elements of the input
281
- tensor with probability *p* using samples from a bernoulli distribution.
282
- The elements to masked are randomized on every forward call, and scaled
283
- and shifted to maintain zero mean and unit standard deviation.
284
-
285
- During evaluation the module simply computes an identity function.
286
-
287
- More details can be found in the paper `Self-Normalizing Neural Networks`_ .
288
-
289
- Args:
290
- prob: float. probability of an element to be kept.
291
-
292
- Shape:
293
- - Input: :math:`(*)`. Input can be of any shape
294
- - Output: :math:`(*)`. Output is of the same shape as input
295
-
296
- Examples::
297
-
298
- >>> m = AlphaDropout(p=0.2)
299
- >>> x = random.randn(20, 16)
300
- >>> output = m(x)
301
-
302
- .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
303
- """
304
- __module__ = 'brainstate.nn'
305
-
306
- def update(self, *args, **kwargs):
307
- raise NotImplementedError("AlphaDropout is not supported in the current version.")
308
-
309
-
310
- class FeatureAlphaDropout(_DropoutNd):
311
- r"""Randomly masks out entire channels (a channel is a feature map,
312
- e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input
313
- is a tensor :math:`\text{input}[i, j]`) of the input tensor). Instead of
314
- setting activations to zero, as in regular Dropout, the activations are set
315
- to the negative saturation value of the SELU activation function. More details
316
- can be found in the paper `Self-Normalizing Neural Networks`_ .
317
-
318
- Each element will be masked independently for each sample on every forward
319
- call with probability :attr:`p` using samples from a Bernoulli distribution.
320
- The elements to be masked are randomized on every forward call, and scaled
321
- and shifted to maintain zero mean and unit variance.
322
-
323
- Usually the input comes from :class:`nn.AlphaDropout` modules.
324
-
325
- As described in the paper
326
- `Efficient Object Localization Using Convolutional Networks`_ ,
327
- if adjacent pixels within feature maps are strongly correlated
328
- (as is normally the case in early convolution layers) then i.i.d. dropout
329
- will not regularize the activations and will otherwise just result
330
- in an effective learning rate decrease.
331
-
332
- In this case, :func:`nn.AlphaDropout` will help promote independence between
333
- feature maps and should be used instead.
334
-
335
- Args:
336
- prob: float. probability of an element to be kept.
337
-
338
- Shape:
339
- - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
340
- - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
341
-
342
- Examples::
343
-
344
- >>> m = FeatureAlphaDropout(p=0.2)
345
- >>> x = random.randn(20, 16, 4, 32, 32)
346
- >>> output = m(x)
347
-
348
- .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
349
- .. _Efficient Object Localization Using Convolutional Networks:
350
- https://arxiv.org/abs/1411.4280
351
- """
352
- __module__ = 'brainstate.nn'
353
-
354
- def update(self, *args, **kwargs):
355
- raise NotImplementedError("FeatureAlphaDropout is not supported in the current version.")
356
-
357
-
358
- class DropoutFixed(ElementWiseBlock):
359
- """
360
- A dropout layer with the fixed dropout mask along the time axis once after initialized.
361
-
362
- In training, to compensate for the fraction of input values dropped (`rate`),
363
- all surviving values are multiplied by `1 / (1 - rate)`.
364
-
365
- This layer is active only during training (``mode=brainstate.mixin.Training``). In other
366
- circumstances it is a no-op.
367
-
368
- .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
369
- neural networks from overfitting." The journal of machine learning
370
- research 15.1 (2014): 1929-1958.
371
-
372
- .. admonition:: Tip
373
- :class: tip
374
-
375
- This kind of Dropout is firstly described in `Enabling Spike-based Backpropagation for Training Deep Neural
376
- Network Architectures <https://arxiv.org/abs/1903.06379>`_:
377
-
378
- There is a subtle difference in the way dropout is applied in SNNs compared to ANNs. In ANNs, each epoch of
379
- training has several iterations of mini-batches. In each iteration, randomly selected units (with dropout ratio of :math:`p`)
380
- are disconnected from the network while weighting by its posterior probability (:math:`1-p`). However, in SNNs, each
381
- iteration has more than one forward propagation depending on the time length of the spike train. We back-propagate
382
- the output error and modify the network parameters only at the last time step. For dropout to be effective in
383
- our training method, it has to be ensured that the set of connected units within an iteration of mini-batch
384
- data is not changed, such that the neural network is constituted by the same random subset of units during
385
- each forward propagation within a single iteration. On the other hand, if the units are randomly connected at
386
- each time-step, the effect of dropout will be averaged out over the entire forward propagation time within an
387
- iteration. Then, the dropout effect would fade-out once the output error is propagated backward and the parameters
388
- are updated at the last time step. Therefore, we need to keep the set of randomly connected units for the entire
389
- time window within an iteration.
390
-
391
- Args:
392
- in_size: The size of the input tensor.
393
- prob: Probability to keep element of the tensor.
394
- mode: Mode. The computation mode of the object.
395
- name: str. The name of the dynamic system.
396
- """
397
- __module__ = 'brainstate.nn'
398
-
399
- def __init__(
400
- self,
401
- in_size: Size,
402
- prob: float = 0.5,
403
- name: Optional[str] = None
404
- ) -> None:
405
- super().__init__(name=name)
406
- assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
407
- self.prob = prob
408
- self.in_size = in_size
409
- self.out_size = in_size
410
-
411
- def init_state(self, batch_size=None, **kwargs):
412
- if self.prob < 1.:
413
- self.mask = ShortTermState(init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size))
414
-
415
- def update(self, x):
416
- dtype = u.math.get_dtype(x)
417
- fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
418
- if fit_phase and self.prob < 1.:
419
- if self.mask.value.shape != x.shape:
420
- raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
421
- f"Please call `init_state()` method first.")
422
- return u.math.where(self.mask.value,
423
- u.math.asarray(x / self.prob, dtype=dtype),
424
- u.math.asarray(0., dtype=dtype) * u.get_unit(x))
425
- else:
426
- return x
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ from functools import partial
18
+ from typing import Optional, Sequence
19
+
20
+ import brainunit as u
21
+ import jax.numpy as jnp
22
+
23
+ from brainstate import random, environ
24
+ from brainstate._state import ShortTermState
25
+ from brainstate.typing import Size
26
+ from . import init as init
27
+ from ._module import ElementWiseBlock
28
+
29
+ __all__ = [
30
+ 'Dropout',
31
+ 'Dropout1d',
32
+ 'Dropout2d',
33
+ 'Dropout3d',
34
+ 'AlphaDropout',
35
+ 'FeatureAlphaDropout',
36
+ 'DropoutFixed',
37
+ ]
38
+
39
+
40
+ class Dropout(ElementWiseBlock):
41
+ """A layer that stochastically ignores a subset of inputs each training step.
42
+
43
+ In training, to compensate for the fraction of input values dropped (`rate`),
44
+ all surviving values are multiplied by `1 / (1 - rate)`.
45
+
46
+ This layer is active only during training (``mode=brainstate.mixin.Training``). In other
47
+ circumstances it is a no-op.
48
+
49
+ Parameters
50
+ ----------
51
+ prob : float
52
+ Probability to keep element of the tensor. Default is 0.5.
53
+ broadcast_dims : Sequence[int]
54
+ Dimensions that will share the same dropout mask. Default is ().
55
+ name : str, optional
56
+ The name of the dynamic system.
57
+
58
+ References
59
+ ----------
60
+ .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
61
+ neural networks from overfitting." The journal of machine learning
62
+ research 15.1 (2014): 1929-1958.
63
+
64
+ Examples
65
+ --------
66
+ .. code-block:: python
67
+
68
+ >>> import brainstate
69
+ >>> layer = brainstate.nn.Dropout(prob=0.8)
70
+ >>> x = brainstate.random.randn(10, 20)
71
+ >>> with brainstate.environ.context(fit=True):
72
+ ... output = layer(x)
73
+ >>> output.shape
74
+ (10, 20)
75
+
76
+ """
77
+ __module__ = 'brainstate.nn'
78
+
79
+ def __init__(
80
+ self,
81
+ prob: float = 0.5,
82
+ broadcast_dims: Sequence[int] = (),
83
+ name: Optional[str] = None
84
+ ) -> None:
85
+ super().__init__(name=name)
86
+ assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
87
+ self.prob = prob
88
+ self.broadcast_dims = broadcast_dims
89
+
90
+ def __call__(self, x):
91
+ dtype = u.math.get_dtype(x)
92
+ fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
93
+ if fit_phase and self.prob < 1.:
94
+ broadcast_shape = list(x.shape)
95
+ for dim in self.broadcast_dims:
96
+ broadcast_shape[dim] = 1
97
+ keep_mask = random.bernoulli(self.prob, broadcast_shape)
98
+ keep_mask = u.math.broadcast_to(keep_mask, x.shape)
99
+ return u.math.where(
100
+ keep_mask,
101
+ u.math.asarray(x / self.prob, dtype=dtype),
102
+ u.math.asarray(0., dtype=dtype)
103
+ )
104
+ else:
105
+ return x
106
+
107
+
108
+ class _DropoutNd(ElementWiseBlock):
109
+ __module__ = 'brainstate.nn'
110
+ prob: float
111
+ channel_axis: int
112
+ minimal_dim: int
113
+
114
+ def __init__(
115
+ self,
116
+ prob: float = 0.5,
117
+ channel_axis: int = -1,
118
+ name: Optional[str] = None
119
+ ) -> None:
120
+ super().__init__(name=name)
121
+ assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
122
+ self.prob = prob
123
+ self.channel_axis = channel_axis
124
+
125
+ def __call__(self, x):
126
+ # check input shape
127
+ inp_dim = u.math.ndim(x)
128
+ if inp_dim not in (self.minimal_dim, self.minimal_dim + 1):
129
+ raise RuntimeError(f"dropout1d: Expected {self.minimal_dim}D or {self.minimal_dim + 1}D input, "
130
+ f"but received a {inp_dim}D input. {self._get_msg(x)}")
131
+ is_not_batched = self.minimal_dim
132
+ if is_not_batched:
133
+ channel_axis = self.channel_axis if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
134
+ mask_shape = [(dim if i == channel_axis else 1) for i, dim in enumerate(x.shape)]
135
+ else:
136
+ channel_axis = (self.channel_axis + 1) if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
137
+ assert channel_axis != 0, f"Channel axis must not be 0. But got {self.channel_axis}."
138
+ mask_shape = [(dim if i in (channel_axis, 0) else 1) for i, dim in enumerate(x.shape)]
139
+
140
+ # get fit phase
141
+ fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
142
+
143
+ # generate mask
144
+ if fit_phase and self.prob < 1.:
145
+ dtype = u.math.get_dtype(x)
146
+ keep_mask = random.bernoulli(self.prob, mask_shape)
147
+ keep_mask = jnp.broadcast_to(keep_mask, x.shape)
148
+ return jnp.where(
149
+ keep_mask,
150
+ jnp.asarray(x / self.prob, dtype=dtype),
151
+ jnp.asarray(0., dtype=dtype)
152
+ )
153
+ else:
154
+ return x
155
+
156
+ def _get_msg(self, x):
157
+ return ''
158
+
159
+
160
+ class Dropout1d(_DropoutNd):
161
+ r"""Randomly zero out entire channels (a channel is a 1D feature map).
162
+
163
+ Each channel will be zeroed out independently on every forward call with
164
+ probability ``1 - prob`` using samples from a Bernoulli distribution. The channel is
165
+ a 1D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample
166
+ in the batched input is a 1D tensor :math:`\text{input}[i, j]`.
167
+
168
+ Usually the input comes from :class:`Conv1d` modules.
169
+
170
+ As described in the paper [1]_, if adjacent pixels within feature maps are
171
+ strongly correlated (as is normally the case in early convolution layers)
172
+ then i.i.d. dropout will not regularize the activations and will otherwise
173
+ just result in an effective learning rate decrease.
174
+
175
+ In this case, :class:`Dropout1d` will help promote independence between
176
+ feature maps and should be used instead.
177
+
178
+ Parameters
179
+ ----------
180
+ prob : float
181
+ Probability of an element to be kept. Default is 0.5.
182
+ channel_axis : int
183
+ The axis representing the channel dimension. Default is -1.
184
+ name : str, optional
185
+ The name of the dynamic system.
186
+
187
+ Notes
188
+ -----
189
+ Input shape: :math:`(N, C, L)` or :math:`(C, L)`.
190
+
191
+ Output shape: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).
192
+
193
+ References
194
+ ----------
195
+ .. [1] Tompson et al., "Efficient Object Localization Using Convolutional Networks"
196
+ https://arxiv.org/abs/1411.4280
197
+
198
+ Examples
199
+ --------
200
+ .. code-block:: python
201
+
202
+ >>> import brainstate
203
+ >>> m = brainstate.nn.Dropout1d(prob=0.8)
204
+ >>> x = brainstate.random.randn(20, 32, 16)
205
+ >>> with brainstate.environ.context(fit=True):
206
+ ... output = m(x)
207
+ >>> output.shape
208
+ (20, 32, 16)
209
+
210
+ """
211
+ __module__ = 'brainstate.nn'
212
+ minimal_dim: int = 2
213
+
214
+ def _get_msg(self, x):
215
+ return ("Note that dropout1d exists to provide channel-wise dropout on inputs with 1 "
216
+ "spatial dimension, a channel dimension, and an optional batch dimension "
217
+ "(i.e. 2D or 3D inputs).")
218
+
219
+
220
+ class Dropout2d(_DropoutNd):
221
+ r"""Randomly zero out entire channels (a channel is a 2D feature map).
222
+
223
+ Each channel will be zeroed out independently on every forward call with
224
+ probability ``1 - prob`` using samples from a Bernoulli distribution. The channel is
225
+ a 2D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample
226
+ in the batched input is a 2D tensor :math:`\text{input}[i, j]`.
227
+
228
+ Usually the input comes from :class:`Conv2d` modules.
229
+
230
+ As described in the paper [1]_, if adjacent pixels within feature maps are
231
+ strongly correlated (as is normally the case in early convolution layers)
232
+ then i.i.d. dropout will not regularize the activations and will otherwise
233
+ just result in an effective learning rate decrease.
234
+
235
+ In this case, :class:`Dropout2d` will help promote independence between
236
+ feature maps and should be used instead.
237
+
238
+ Parameters
239
+ ----------
240
+ prob : float
241
+ Probability of an element to be kept. Default is 0.5.
242
+ channel_axis : int
243
+ The axis representing the channel dimension. Default is -1.
244
+ name : str, optional
245
+ The name of the dynamic system.
246
+
247
+ Notes
248
+ -----
249
+ Input shape: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
250
+
251
+ Output shape: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input).
252
+
253
+ References
254
+ ----------
255
+ .. [1] Tompson et al., "Efficient Object Localization Using Convolutional Networks"
256
+ https://arxiv.org/abs/1411.4280
257
+
258
+ Examples
259
+ --------
260
+ .. code-block:: python
261
+
262
+ >>> import brainstate
263
+ >>> m = brainstate.nn.Dropout2d(prob=0.8)
264
+ >>> x = brainstate.random.randn(20, 32, 32, 16)
265
+ >>> with brainstate.environ.context(fit=True):
266
+ ... output = m(x)
267
+ >>> output.shape
268
+ (20, 32, 32, 16)
269
+
270
+ """
271
+ __module__ = 'brainstate.nn'
272
+ minimal_dim: int = 3
273
+
274
+ def _get_msg(self, x):
275
+ return ("Note that dropout2d exists to provide channel-wise dropout on inputs with 2 "
276
+ "spatial dimensions, a channel dimension, and an optional batch dimension "
277
+ "(i.e. 3D or 4D inputs).")
278
+
279
+
280
+ class Dropout3d(_DropoutNd):
281
+ r"""Randomly zero out entire channels (a channel is a 3D feature map).
282
+
283
+ Each channel will be zeroed out independently on every forward call with
284
+ probability ``1 - prob`` using samples from a Bernoulli distribution. The channel is
285
+ a 3D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample
286
+ in the batched input is a 3D tensor :math:`\text{input}[i, j]`.
287
+
288
+ Usually the input comes from :class:`Conv3d` modules.
289
+
290
+ As described in the paper [1]_, if adjacent pixels within feature maps are
291
+ strongly correlated (as is normally the case in early convolution layers)
292
+ then i.i.d. dropout will not regularize the activations and will otherwise
293
+ just result in an effective learning rate decrease.
294
+
295
+ In this case, :class:`Dropout3d` will help promote independence between
296
+ feature maps and should be used instead.
297
+
298
+ Parameters
299
+ ----------
300
+ prob : float
301
+ Probability of an element to be kept. Default is 0.5.
302
+ channel_axis : int
303
+ The axis representing the channel dimension. Default is -1.
304
+ name : str, optional
305
+ The name of the dynamic system.
306
+
307
+ Notes
308
+ -----
309
+ Input shape: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
310
+
311
+ Output shape: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
312
+
313
+ References
314
+ ----------
315
+ .. [1] Tompson et al., "Efficient Object Localization Using Convolutional Networks"
316
+ https://arxiv.org/abs/1411.4280
317
+
318
+ Examples
319
+ --------
320
+ .. code-block:: python
321
+
322
+ >>> import brainstate
323
+ >>> m = brainstate.nn.Dropout3d(prob=0.8)
324
+ >>> x = brainstate.random.randn(20, 16, 4, 32, 32)
325
+ >>> with brainstate.environ.context(fit=True):
326
+ ... output = m(x)
327
+ >>> output.shape
328
+ (20, 16, 4, 32, 32)
329
+
330
+ """
331
+ __module__ = 'brainstate.nn'
332
+ minimal_dim: int = 4
333
+
334
+ def _get_msg(self, x):
335
+ return ("Note that dropout3d exists to provide channel-wise dropout on inputs with 3 "
336
+ "spatial dimensions, a channel dimension, and an optional batch dimension "
337
+ "(i.e. 4D or 5D inputs).")
338
+
339
+
340
+ class AlphaDropout(_DropoutNd):
341
+ r"""Applies Alpha Dropout over the input.
342
+
343
+ Alpha Dropout is a type of Dropout that maintains the self-normalizing
344
+ property. For an input with zero mean and unit standard deviation, the output of
345
+ Alpha Dropout maintains the original mean and standard deviation of the
346
+ input.
347
+
348
+ Alpha Dropout goes hand-in-hand with SELU activation function, which ensures
349
+ that the outputs have zero mean and unit standard deviation.
350
+
351
+ During training, it randomly masks some of the elements of the input
352
+ tensor with probability ``1 - prob`` using samples from a Bernoulli distribution.
353
+ The elements to be masked are randomized on every forward call, and scaled
354
+ and shifted to maintain zero mean and unit standard deviation.
355
+
356
+ During evaluation the module simply computes an identity function.
357
+
358
+ Parameters
359
+ ----------
360
+ prob : float
361
+ Probability of an element to be kept. Default is 0.5.
362
+ name : str, optional
363
+ The name of the dynamic system.
364
+
365
+ Notes
366
+ -----
367
+ Input shape: :math:`(*)`. Input can be of any shape.
368
+
369
+ Output shape: :math:`(*)`. Output is of the same shape as input.
370
+
371
+ References
372
+ ----------
373
+ .. [1] Klambauer et al., "Self-Normalizing Neural Networks"
374
+ https://arxiv.org/abs/1706.02515
375
+
376
+ Examples
377
+ --------
378
+ .. code-block:: python
379
+
380
+ >>> import brainstate
381
+ >>> m = brainstate.nn.AlphaDropout(prob=0.8)
382
+ >>> x = brainstate.random.randn(20, 16)
383
+ >>> with brainstate.environ.context(fit=True):
384
+ ... output = m(x)
385
+ >>> output.shape
386
+ (20, 16)
387
+
388
+ """
389
+ __module__ = 'brainstate.nn'
390
+
391
+ def __init__(
392
+ self,
393
+ prob: float = 0.5,
394
+ name: Optional[str] = None
395
+ ) -> None:
396
+ super().__init__(name=name)
397
+ assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
398
+ self.prob = prob
399
+
400
+ # SELU parameters
401
+ alpha = -1.7580993408473766
402
+ self.alpha = alpha
403
+
404
+ # Affine transformation parameters to maintain mean and variance
405
+ self.a = ((1 - prob) * (1 + prob * alpha ** 2)) ** -0.5
406
+ self.b = -self.a * alpha * prob
407
+
408
+ def __call__(self, x):
409
+ dtype = u.math.get_dtype(x)
410
+ fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
411
+ if fit_phase and self.prob < 1.:
412
+ keep_mask = random.bernoulli(self.prob, x.shape)
413
+ return u.math.where(
414
+ keep_mask,
415
+ u.math.asarray(x, dtype=dtype),
416
+ u.math.asarray(self.alpha, dtype=dtype)
417
+ ) * self.a + self.b
418
+ else:
419
+ return x
420
+
421
+
422
+ class FeatureAlphaDropout(ElementWiseBlock):
423
+ r"""Randomly masks out entire channels with Alpha Dropout properties.
424
+
425
+ Instead of setting activations to zero as in regular Dropout, the activations
426
+ are set to the negative saturation value of the SELU activation function to
427
+ maintain self-normalizing properties.
428
+
429
+ Each channel (e.g., the :math:`j`-th channel of the :math:`i`-th sample in
430
+ the batch input is a tensor :math:`\text{input}[i, j]`) will be masked
431
+ independently for each sample on every forward call with probability ``1 - prob`` using
432
+ samples from a Bernoulli distribution. The elements to be masked are randomized
433
+ on every forward call, and scaled and shifted to maintain zero mean and unit
434
+ variance.
435
+
436
+ Usually the input comes from convolutional layers with SELU activation.
437
+
438
+ As described in the paper [2]_, if adjacent pixels within feature maps are
439
+ strongly correlated (as is normally the case in early convolution layers)
440
+ then i.i.d. dropout will not regularize the activations and will otherwise
441
+ just result in an effective learning rate decrease.
442
+
443
+ In this case, :class:`FeatureAlphaDropout` will help promote independence between
444
+ feature maps and should be used instead.
445
+
446
+ Parameters
447
+ ----------
448
+ prob : float
449
+ Probability of an element to be kept. Default is 0.5.
450
+ channel_axis : int
451
+ The axis representing the channel dimension. Default is -1.
452
+ name : str, optional
453
+ The name of the dynamic system.
454
+
455
+ Notes
456
+ -----
457
+ Input shape: :math:`(N, C, *)` where C is the channel dimension.
458
+
459
+ Output shape: Same shape as input.
460
+
461
+ References
462
+ ----------
463
+ .. [1] Klambauer et al., "Self-Normalizing Neural Networks"
464
+ https://arxiv.org/abs/1706.02515
465
+ .. [2] Tompson et al., "Efficient Object Localization Using Convolutional Networks"
466
+ https://arxiv.org/abs/1411.4280
467
+
468
+ Examples
469
+ --------
470
+ .. code-block:: python
471
+
472
+ >>> import brainstate
473
+ >>> m = brainstate.nn.FeatureAlphaDropout(prob=0.8)
474
+ >>> x = brainstate.random.randn(20, 16, 4, 32, 32)
475
+ >>> with brainstate.environ.context(fit=True):
476
+ ... output = m(x)
477
+ >>> output.shape
478
+ (20, 16, 4, 32, 32)
479
+
480
+ """
481
+ __module__ = 'brainstate.nn'
482
+
483
+ def __init__(
484
+ self,
485
+ prob: float = 0.5,
486
+ channel_axis: int = -1,
487
+ name: Optional[str] = None
488
+ ) -> None:
489
+ super().__init__(name=name)
490
+ assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
491
+ self.prob = prob
492
+ self.channel_axis = channel_axis
493
+
494
+ # SELU parameters
495
+ alpha = -1.7580993408473766
496
+ self.alpha = alpha
497
+
498
+ # Affine transformation parameters to maintain mean and variance
499
+ self.a = ((1 - prob) * (1 + prob * alpha ** 2)) ** -0.5
500
+ self.b = -self.a * alpha * prob
501
+
502
+ def __call__(self, x):
503
+ dtype = u.math.get_dtype(x)
504
+ fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
505
+ if fit_phase and self.prob < 1.:
506
+ # Create mask shape with 1s except for batch and channel dimensions
507
+ channel_axis = self.channel_axis if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
508
+ mask_shape = [1] * x.ndim
509
+ mask_shape[0] = x.shape[0] # batch dimension
510
+ mask_shape[channel_axis] = x.shape[channel_axis] # channel dimension
511
+
512
+ keep_mask = random.bernoulli(self.prob, mask_shape)
513
+ keep_mask = u.math.broadcast_to(keep_mask, x.shape)
514
+ return u.math.where(
515
+ keep_mask,
516
+ u.math.asarray(x, dtype=dtype),
517
+ u.math.asarray(self.alpha, dtype=dtype)
518
+ ) * self.a + self.b
519
+ else:
520
+ return x
521
+
522
+
523
+ class DropoutFixed(ElementWiseBlock):
524
+ """A dropout layer with a fixed dropout mask along the time axis.
525
+
526
+ In training, to compensate for the fraction of input values dropped,
527
+ all surviving values are multiplied by `1 / prob` (with `prob` the keep probability).
528
+
529
+ This layer is active only during training (``mode=brainstate.mixin.Training``). In other
530
+ circumstances it is a no-op.
531
+
532
+ This kind of Dropout is particularly useful for spiking neural networks (SNNs) where
533
+ the same dropout mask needs to be applied across multiple time steps within a single
534
+ mini-batch iteration.
535
+
536
+ Parameters
537
+ ----------
538
+ in_size : tuple or int
539
+ The size of the input tensor.
540
+ prob : float
541
+ Probability to keep element of the tensor. Default is 0.5.
542
+ name : str, optional
543
+ The name of the dynamic system.
544
+
545
+ Notes
546
+ -----
547
+ As described in [2]_, there is a subtle difference in the way dropout is applied in
548
+ SNNs compared to ANNs. In ANNs, each epoch of training has several iterations of
549
+ mini-batches. In each iteration, randomly selected units (with dropout ratio of
550
+ :math:`p`) are disconnected from the network while weighting by its posterior
551
+ probability (:math:`1-p`).
552
+
553
+ However, in SNNs, each iteration has more than one forward propagation depending on
554
+ the time length of the spike train. We back-propagate the output error and modify
555
+ the network parameters only at the last time step. For dropout to be effective in
556
+ our training method, it has to be ensured that the set of connected units within an
557
+ iteration of mini-batch data is not changed, such that the neural network is
558
+ constituted by the same random subset of units during each forward propagation within
559
+ a single iteration.
560
+
561
+ On the other hand, if the units are randomly connected at each time-step, the effect
562
+ of dropout will be averaged out over the entire forward propagation time within an
563
+ iteration. Then, the dropout effect would fade-out once the output error is propagated
564
+ backward and the parameters are updated at the last time step. Therefore, we need to
565
+ keep the set of randomly connected units for the entire time window within an iteration.
566
+
567
+ References
568
+ ----------
569
+ .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
570
+ neural networks from overfitting." The journal of machine learning
571
+ research 15.1 (2014): 1929-1958.
572
+ .. [2] Lee et al., "Enabling Spike-based Backpropagation for Training Deep Neural
573
+ Network Architectures" https://arxiv.org/abs/1903.06379
574
+
575
+ Examples
576
+ --------
577
+ .. code-block:: python
578
+
579
+ >>> import brainstate
580
+ >>> layer = brainstate.nn.DropoutFixed(in_size=(20,), prob=0.8)
581
+ >>> layer.init_state(batch_size=10)
582
+ >>> x = brainstate.random.randn(10, 20)
583
+ >>> with brainstate.environ.context(fit=True):
584
+ ... output = layer.update(x)
585
+ >>> output.shape
586
+ (10, 20)
587
+
588
+ """
589
+ __module__ = 'brainstate.nn'
590
+
591
+ def __init__(
592
+ self,
593
+ in_size: Size,
594
+ prob: float = 0.5,
595
+ name: Optional[str] = None
596
+ ) -> None:
597
+ super().__init__(name=name)
598
+ assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
599
+ self.prob = prob
600
+ self.in_size = in_size
601
+ self.out_size = in_size
602
+
603
+ def init_state(self, batch_size=None, **kwargs):
604
+ if self.prob < 1.:
605
+ self.mask = ShortTermState(init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size))
606
+
607
+ def update(self, x):
608
+ dtype = u.math.get_dtype(x)
609
+ fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
610
+ if fit_phase and self.prob < 1.:
611
+ if self.mask.value.shape != x.shape:
612
+ raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
613
+ f"Please call `init_state()` method first.")
614
+ return u.math.where(self.mask.value,
615
+ u.math.asarray(x / self.prob, dtype=dtype),
616
+ u.math.asarray(0., dtype=dtype) * u.get_unit(x))
617
+ else:
618
+ return x
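Taken together, the new docstrings in this file describe one usage pattern: dropout layers are active only while the `fit` flag is set in `brainstate.environ`, and `DropoutFixed` additionally requires `init_state()` so that a single mask is sampled and then reused across time steps. A short sketch assembled from the examples in the added docstrings; nothing here goes beyond those examples apart from the `fit=False` evaluation branch, which assumes the same context manager toggles evaluation mode.

    import brainstate

    # Element-wise dropout: keep probability 0.8, active only while fitting.
    layer = brainstate.nn.Dropout(prob=0.8)
    x = brainstate.random.randn(10, 20)
    with brainstate.environ.context(fit=True):
        y_train = layer(x)        # survivors are rescaled by 1 / prob
    with brainstate.environ.context(fit=False):
        y_eval = layer(x)         # no-op outside of training

    # Fixed-mask dropout for SNN-style time loops: the mask is sampled once in
    # init_state() and reused by every update() call within the iteration.
    fixed = brainstate.nn.DropoutFixed(in_size=(20,), prob=0.8)
    fixed.init_state(batch_size=10)
    with brainstate.environ.context(fit=True):
        for _ in range(5):        # same mask applied at every time step
            out = fixed.update(brainstate.random.randn(10, 20))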