brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/nn/_elementwise/_dropout.py
@@ -0,0 +1,418 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ from __future__ import annotations
+
+ from functools import partial
+ from typing import Optional
+
+ import brainunit as u
+ import jax.numpy as jnp
+
+ from brainstate import random, environ, init
+ from brainstate._state import ShortTermState
+ from brainstate.nn._module import ElementWiseBlock
+ from brainstate.typing import Size
+
+ __all__ = [
+     'DropoutFixed', 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
+     'AlphaDropout', 'FeatureAlphaDropout',
+ ]
+
+
+ class Dropout(ElementWiseBlock):
+     """A layer that stochastically ignores a subset of inputs each training step.
+
+     In training, to compensate for the fraction of input values dropped (``1 - prob``),
+     all surviving values are multiplied by ``1 / prob``.
+
+     This layer is active only during training (i.e., when the ``fit`` environment flag
+     is set). In other circumstances it is a no-op.
+
+     .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
+            neural networks from overfitting." The journal of machine learning
+            research 15.1 (2014): 1929-1958.
+
+     Args:
+         prob: Probability of keeping an element of the tensor.
+         name: str. The name of the dynamic system.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         prob: float = 0.5,
+         name: Optional[str] = None
+     ) -> None:
+         super().__init__(name=name)
+         assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
+         self.prob = prob
+
+     def __call__(self, x):
+         dtype = u.math.get_dtype(x)
+         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+         if fit_phase and self.prob < 1.:
+             keep_mask = random.bernoulli(self.prob, x.shape)
+             return jnp.where(keep_mask,
+                              jnp.asarray(x / self.prob, dtype=dtype),
+                              jnp.asarray(0., dtype=dtype))
+         else:
+             return x
+
+
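The ``prob`` argument above is the keep probability, so during training each surviving element is scaled by ``1 / prob`` and the expected value of every element is unchanged. A minimal sanity-check sketch of that behaviour, borrowing the ``bst`` alias and the ``fit`` context from the test file further below (the array size is arbitrary)::

    import numpy as np
    import brainstate as bst

    layer = bst.nn.Dropout(prob=0.5)            # keep each element with probability 0.5
    x = np.ones((10000,), dtype=np.float32)

    with bst.environ.context(fit=True):
        y = layer(x)

    # Survivors are scaled by 1 / prob, so the sample mean stays close to the input mean.
    print(float(np.asarray(y).mean()))          # ~1.0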
+ class _DropoutNd(ElementWiseBlock):
+     __module__ = 'brainstate.nn'
+     prob: float
+     channel_axis: int
+     minimal_dim: int
+
+     def __init__(
+         self,
+         prob: float = 0.5,
+         channel_axis: int = -1,
+         name: Optional[str] = None
+     ) -> None:
+         super().__init__(name=name)
+         assert 0. <= prob < 1., f"Dropout probability must be in the range [0, 1). But got {prob}."
+         self.prob = prob
+         self.channel_axis = channel_axis
+
+     def __call__(self, x):
+
+         # check input shape
+         inp_dim = u.math.ndim(x)
+         if inp_dim not in (self.minimal_dim, self.minimal_dim + 1):
+             raise RuntimeError(f"{type(self).__name__}: Expected {self.minimal_dim}D or {self.minimal_dim + 1}D input, "
+                                f"but received a {inp_dim}D input. {self._get_msg(x)}")
+         is_not_batched = inp_dim == self.minimal_dim
+         if is_not_batched:
+             channel_axis = self.channel_axis if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
+             mask_shape = [(dim if i == channel_axis else 1) for i, dim in enumerate(x.shape)]
+         else:
+             channel_axis = (self.channel_axis + 1) if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
+             assert channel_axis != 0, f"Channel axis must not be 0. But got {self.channel_axis}."
+             mask_shape = [(dim if i in (channel_axis, 0) else 1) for i, dim in enumerate(x.shape)]
+
+         # get fit phase
+         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+
+         # generate mask
+         if fit_phase:
+             dtype = u.math.get_dtype(x)
+             keep_mask = jnp.broadcast_to(random.bernoulli(self.prob, mask_shape), x.shape)
+             return jnp.where(keep_mask,
+                              jnp.asarray(x / self.prob, dtype=dtype),
+                              jnp.asarray(0., dtype=dtype))
+         else:
+             return x
+
+     def _get_msg(self, x):
+         return ''
+
+
+ class Dropout1d(_DropoutNd):
+     r"""Randomly zero out entire channels (a channel is a 1D feature map,
+     e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
+     batched input is a 1D tensor :math:`\text{input}[i, j]`).
+     Each channel will be zeroed out independently on every forward call with
+     probability ``1 - prob`` using samples from a Bernoulli distribution.
+
+     Usually the input comes from :class:`nn.Conv1d` modules.
+
+     As described in the paper
+     `Efficient Object Localization Using Convolutional Networks`_ ,
+     if adjacent pixels within feature maps are strongly correlated
+     (as is normally the case in early convolution layers) then i.i.d. dropout
+     will not regularize the activations and will otherwise just result
+     in an effective learning rate decrease.
+
+     In this case, :class:`nn.Dropout1d` will help promote independence between
+     feature maps and should be used instead.
+
+     Args:
+         prob: float. Probability of an element to be kept.
+
+     Shape:
+         - Input: :math:`(N, C, L)` or :math:`(C, L)`.
+         - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).
+
+     Examples::
+
+         >>> m = Dropout1d(prob=0.2)
+         >>> x = random.randn(20, 32, 16)
+         >>> output = m(x)
+         >>> output.shape
+         (20, 32, 16)
+
+     .. _Efficient Object Localization Using Convolutional Networks:
+        https://arxiv.org/abs/1411.4280
+     """
+     __module__ = 'brainstate.nn'
+     minimal_dim: int = 2
+
+     def _get_msg(self, x):
+         return ("Note that dropout1d exists to provide channel-wise dropout on inputs with 1 "
+                 "spatial dimension, a channel dimension, and an optional batch dimension "
+                 "(i.e. 2D or 3D inputs).")
+
+
+ class Dropout2d(_DropoutNd):
+     r"""Randomly zero out entire channels (a channel is a 2D feature map,
+     e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
+     batched input is a 2D tensor :math:`\text{input}[i, j]`).
+     Each channel will be zeroed out independently on every forward call with
+     probability ``1 - prob`` using samples from a Bernoulli distribution.
+
+     Usually the input comes from :class:`nn.Conv2d` modules.
+
+     As described in the paper
+     `Efficient Object Localization Using Convolutional Networks`_ ,
+     if adjacent pixels within feature maps are strongly correlated
+     (as is normally the case in early convolution layers) then i.i.d. dropout
+     will not regularize the activations and will otherwise just result
+     in an effective learning rate decrease.
+
+     In this case, :class:`nn.Dropout2d` will help promote independence between
+     feature maps and should be used instead.
+
+     Args:
+         prob: float. Probability of an element to be kept.
+
+     Shape:
+         - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
+         - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input).
+
+     Examples::
+
+         >>> m = Dropout2d(prob=0.2)
+         >>> x = random.randn(20, 32, 32, 16)
+         >>> output = m(x)
+
+     .. _Efficient Object Localization Using Convolutional Networks:
+        https://arxiv.org/abs/1411.4280
+     """
+     __module__ = 'brainstate.nn'
+     minimal_dim: int = 3
+
+     def _get_msg(self, x):
+         return ("Note that dropout2d exists to provide channel-wise dropout on inputs with 2 "
+                 "spatial dimensions, a channel dimension, and an optional batch dimension "
+                 "(i.e. 3D or 4D inputs).")
+
+
+ class Dropout3d(_DropoutNd):
+     r"""Randomly zero out entire channels (a channel is a 3D feature map,
+     e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
+     batched input is a 3D tensor :math:`\text{input}[i, j]`).
+     Each channel will be zeroed out independently on every forward call with
+     probability ``1 - prob`` using samples from a Bernoulli distribution.
+
+     Usually the input comes from :class:`nn.Conv3d` modules.
+
+     As described in the paper
+     `Efficient Object Localization Using Convolutional Networks`_ ,
+     if adjacent pixels within feature maps are strongly correlated
+     (as is normally the case in early convolution layers) then i.i.d. dropout
+     will not regularize the activations and will otherwise just result
+     in an effective learning rate decrease.
+
+     In this case, :class:`nn.Dropout3d` will help promote independence between
+     feature maps and should be used instead.
+
+     Args:
+         prob: float. Probability of an element to be kept.
+
+     Shape:
+         - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
+         - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
+
+     Examples::
+
+         >>> m = Dropout3d(prob=0.2)
+         >>> x = random.randn(20, 16, 4, 32, 32)
+         >>> output = m(x)
+
+     .. _Efficient Object Localization Using Convolutional Networks:
+        https://arxiv.org/abs/1411.4280
+     """
+     __module__ = 'brainstate.nn'
+     minimal_dim: int = 4
+
+     def _get_msg(self, x):
+         return ("Note that dropout3d exists to provide channel-wise dropout on inputs with 3 "
+                 "spatial dimensions, a channel dimension, and an optional batch dimension "
+                 "(i.e. 4D or 5D inputs).")
+
+
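All three channel-wise variants above funnel through ``_DropoutNd.__call__``: the Bernoulli mask keeps its full extent only along the channel axis (plus the batch axis when the input is batched) and is broadcast over the remaining axes, so an entire feature map is kept or zeroed as a unit. A small sketch of that mask-shape computation for a batched input with the default ``channel_axis=-1``; the channels-last layout and the sizes are illustrative assumptions::

    # Mirrors the batched branch of _DropoutNd.__call__ for channel_axis=-1.
    x_shape = (20, 32, 16)     # assumed (batch, length, channels) layout
    channel_axis = -1
    ndim = len(x_shape)

    axis = (channel_axis + 1) if channel_axis >= 0 else (ndim + channel_axis)
    mask_shape = [(dim if i in (axis, 0) else 1) for i, dim in enumerate(x_shape)]
    print(mask_shape)          # [20, 1, 16]: one Bernoulli draw per (sample, channel)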
+ class AlphaDropout(_DropoutNd):
+     r"""Applies Alpha Dropout over the input.
+
+     Alpha Dropout is a type of Dropout that maintains the self-normalizing
+     property.
+     For an input with zero mean and unit standard deviation, the output of
+     Alpha Dropout maintains the original mean and standard deviation of the
+     input.
+     Alpha Dropout goes hand-in-hand with the SELU activation function, which ensures
+     that the outputs have zero mean and unit standard deviation.
+
+     During training, it randomly masks some of the elements of the input
+     tensor with probability ``1 - prob`` using samples from a Bernoulli distribution.
+     The elements to be masked are randomized on every forward call, and scaled
+     and shifted to maintain zero mean and unit standard deviation.
+
+     During evaluation the module simply computes an identity function.
+
+     More details can be found in the paper `Self-Normalizing Neural Networks`_ .
+
+     Args:
+         prob: float. Probability of an element to be kept.
+
+     Shape:
+         - Input: :math:`(*)`. Input can be of any shape.
+         - Output: :math:`(*)`. Output is of the same shape as input.
+
+     Examples::
+
+         >>> m = AlphaDropout(prob=0.2)
+         >>> x = random.randn(20, 16)
+         >>> output = m(x)
+
+     .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
+     """
+     __module__ = 'brainstate.nn'
+
+     def forward(self, x):
+         return F.alpha_dropout(x, self.p, self.training)
+
+
+ class FeatureAlphaDropout(_DropoutNd):
+     r"""Randomly masks out entire channels (a channel is a feature map,
+     e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input
+     is a tensor :math:`\text{input}[i, j]`) of the input tensor. Instead of
+     setting activations to zero, as in regular Dropout, the activations are set
+     to the negative saturation value of the SELU activation function. More details
+     can be found in the paper `Self-Normalizing Neural Networks`_ .
+
+     Each element will be masked independently for each sample on every forward
+     call with probability ``1 - prob`` using samples from a Bernoulli distribution.
+     The elements to be masked are randomized on every forward call, and scaled
+     and shifted to maintain zero mean and unit variance.
+
+     Usually the input comes from :class:`nn.AlphaDropout` modules.
+
+     As described in the paper
+     `Efficient Object Localization Using Convolutional Networks`_ ,
+     if adjacent pixels within feature maps are strongly correlated
+     (as is normally the case in early convolution layers) then i.i.d. dropout
+     will not regularize the activations and will otherwise just result
+     in an effective learning rate decrease.
+
+     In this case, :class:`nn.AlphaDropout` will help promote independence between
+     feature maps and should be used instead.
+
+     Args:
+         prob: float. Probability of an element to be kept.
+
+     Shape:
+         - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
+         - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
+
+     Examples::
+
+         >>> m = FeatureAlphaDropout(prob=0.2)
+         >>> x = random.randn(20, 16, 4, 32, 32)
+         >>> output = m(x)
+
+     .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
+     .. _Efficient Object Localization Using Convolutional Networks:
+        https://arxiv.org/abs/1411.4280
+     """
+     __module__ = 'brainstate.nn'
+
+     def forward(self, x):
+         return F.feature_alpha_dropout(x, self.p, self.training)
+
+
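Note that the ``forward`` bodies of ``AlphaDropout`` and ``FeatureAlphaDropout`` shown above still refer to a torch-style ``F`` helper that this module never imports, so they are effectively placeholders in this release. For orientation only, the transform their docstrings describe (mask to the SELU negative saturation value, then apply an affine correction that restores zero mean and unit variance) can be sketched as follows; the function name and its placement are illustrative and not part of the package API::

    import jax.numpy as jnp
    from brainstate import random

    def alpha_dropout_sketch(x, prob):
        # Standard SELU constants from "Self-Normalizing Neural Networks".
        alpha = 1.6732632423543772
        scale = 1.0507009873554805
        alpha_p = -scale * alpha                   # SELU's negative saturation value
        keep = random.bernoulli(prob, x.shape)     # keep each element with probability prob
        # Affine correction so a zero-mean, unit-variance input keeps those statistics.
        a = (prob + alpha_p ** 2 * prob * (1 - prob)) ** -0.5
        b = -a * alpha_p * (1 - prob)
        return a * jnp.where(keep, x, alpha_p) + b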
+ class DropoutFixed(ElementWiseBlock):
+     """
+     A dropout layer whose dropout mask is fixed along the time axis once it has been initialized.
+
+     In training, to compensate for the fraction of input values dropped (``1 - prob``),
+     all surviving values are multiplied by ``1 / prob``.
+
+     This layer is active only during training (i.e., when the ``fit`` environment flag
+     is set). In other circumstances it is a no-op.
+
+     .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
+            neural networks from overfitting." The journal of machine learning
+            research 15.1 (2014): 1929-1958.
+
+     .. admonition:: Tip
+         :class: tip
+
+         This kind of Dropout was first described in `Enabling Spike-based Backpropagation for Training Deep Neural
+         Network Architectures <https://arxiv.org/abs/1903.06379>`_:
+
+         There is a subtle difference in the way dropout is applied in SNNs compared to ANNs. In ANNs, each epoch of
+         training has several iterations of mini-batches. In each iteration, randomly selected units (with dropout ratio of :math:`p`)
+         are disconnected from the network while weighting by its posterior probability (:math:`1-p`). However, in SNNs, each
+         iteration has more than one forward propagation depending on the time length of the spike train. We back-propagate
+         the output error and modify the network parameters only at the last time step. For dropout to be effective in
+         our training method, it has to be ensured that the set of connected units within an iteration of mini-batch
+         data is not changed, such that the neural network is constituted by the same random subset of units during
+         each forward propagation within a single iteration. On the other hand, if the units are randomly connected at
+         each time-step, the effect of dropout will be averaged out over the entire forward propagation time within an
+         iteration. Then, the dropout effect would fade-out once the output error is propagated backward and the parameters
+         are updated at the last time step. Therefore, we need to keep the set of randomly connected units for the entire
+         time window within an iteration.
+
+     Args:
+         in_size: The size of the input tensor.
+         prob: Probability of keeping an element of the tensor.
+         name: str. The name of the dynamic system.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         in_size: Size,
+         prob: float = 0.5,
+         name: Optional[str] = None
+     ) -> None:
+         super().__init__(name=name)
+         assert 0. <= prob < 1., f"Dropout probability must be in the range [0, 1). But got {prob}."
+         self.prob = prob
+         self.in_size = in_size
+         self.out_size = in_size
+
+     def init_state(self, batch_size=None, **kwargs):
+         self.mask = ShortTermState(init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size))
+
+     def update(self, x):
+         dtype = u.math.get_dtype(x)
+         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+         if fit_phase:
+             if self.mask.value.shape != x.shape:
+                 raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
+                                  f"Please call `init_state()` method first.")
+             return jnp.where(self.mask.value,
+                              jnp.asarray(x / self.prob, dtype=dtype),
+                              jnp.asarray(0., dtype=dtype))
+         else:
+             return x
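The tip in the ``DropoutFixed`` docstring is the reason its mask is sampled once in ``init_state`` and then reused by every ``update`` call: within one iteration the same units stay disconnected at every time step. A minimal sketch of that behaviour, using the same ``bst`` alias and ``fit`` context as the test file below (the sizes are arbitrary)::

    import numpy as np
    import brainstate as bst

    layer = bst.nn.DropoutFixed(in_size=(4,), prob=0.5)
    layer.init_state(batch_size=2)                  # the mask is drawn once here

    with bst.environ.context(fit=True):
        x = np.ones((2, 4), dtype=np.float32)
        y1 = layer.update(x)                        # time step 1
        y2 = layer.update(x)                        # time step 2

    # The same entries are zeroed at both steps, because the mask is not resampled.
    print(np.array_equal(np.asarray(y1), np.asarray(y2)))   # True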
brainstate/nn/_elementwise/_dropout_test.py
@@ -0,0 +1,100 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ import unittest
+
+ import numpy as np
+
+ import brainstate as bst
+
+
+ class TestDropout(unittest.TestCase):
+
+     def test_dropout(self):
+         # Create a Dropout layer with a dropout rate of 0.5
+         dropout_layer = bst.nn.Dropout(0.5)
+
+         # Input data
+         input_data = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
+
+         with bst.environ.context(fit=True):
+             # Apply dropout
+             output_data = dropout_layer(input_data)
+
+         # Check that the output has the same shape as the input
+         self.assertEqual(input_data.shape, output_data.shape)
+
+         # Check that some elements are zeroed out
+         self.assertTrue(np.any(output_data == 0))
+
+         # Check that the non-zero elements are scaled by 1/(1-rate)
+         scale_factor = 1 / (1 - 0.5)
+         non_zero_elements = output_data[output_data != 0]
+         expected_non_zero_elements = input_data[output_data != 0] * scale_factor
+         np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements)
+
+     def test_DropoutFixed(self):
+         dropout_layer = bst.nn.DropoutFixed(in_size=(2, 3), prob=0.5)
+         dropout_layer.init_state(batch_size=2)
+         input_data = np.random.randn(2, 2, 3)
+         with bst.environ.context(fit=True):
+             output_data = dropout_layer.update(input_data)
+         self.assertEqual(input_data.shape, output_data.shape)
+         self.assertTrue(np.any(output_data == 0))
+         scale_factor = 1 / (1 - 0.5)
+         non_zero_elements = output_data[output_data != 0]
+         expected_non_zero_elements = input_data[output_data != 0] * scale_factor
+         np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements)
+
+     def test_Dropout1d(self):
+         dropout_layer = bst.nn.Dropout1d(prob=0.5)
+         input_data = np.random.randn(2, 3, 4)
+         with bst.environ.context(fit=True):
+             output_data = dropout_layer(input_data)
+         self.assertEqual(input_data.shape, output_data.shape)
+         self.assertTrue(np.any(output_data == 0))
+         scale_factor = 1 / (1 - 0.5)
+         non_zero_elements = output_data[output_data != 0]
+         expected_non_zero_elements = input_data[output_data != 0] * scale_factor
+         np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
+
+     def test_Dropout2d(self):
+         dropout_layer = bst.nn.Dropout2d(prob=0.5)
+         input_data = np.random.randn(2, 3, 4, 5)
+         with bst.environ.context(fit=True):
+             output_data = dropout_layer(input_data)
+         self.assertEqual(input_data.shape, output_data.shape)
+         self.assertTrue(np.any(output_data == 0))
+         scale_factor = 1 / (1 - 0.5)
+         non_zero_elements = output_data[output_data != 0]
+         expected_non_zero_elements = input_data[output_data != 0] * scale_factor
+         np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
+
+     def test_Dropout3d(self):
+         dropout_layer = bst.nn.Dropout3d(prob=0.5)
+         input_data = np.random.randn(2, 3, 4, 5, 6)
+         with bst.environ.context(fit=True):
+             output_data = dropout_layer(input_data)
+         self.assertEqual(input_data.shape, output_data.shape)
+         self.assertTrue(np.any(output_data == 0))
+         scale_factor = 1 / (1 - 0.5)
+         non_zero_elements = output_data[output_data != 0]
+         expected_non_zero_elements = input_data[output_data != 0] * scale_factor
+         np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
+
+
+ if __name__ == '__main__':
+     unittest.main()