brainstate-0.1.9-py2.py3-none-any.whl → brainstate-0.2.0-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
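The rename entries above reshuffle several top-level modules: brainstate.compile and brainstate.augment are folded into brainstate.transform, the initializers move from brainstate.init to brainstate.nn.init, and util/error.py becomes the private _error.py. The sketch below is only an illustration read off those rename entries; the helper name and the mapping table are hypothetical and not part of the package, and dotted paths beginning with an underscore are internal modules, so check the public 0.2.0 API before relying on any of them. Deleted entries such as brainstate/optim/ and brainstate/surrogate.py have no rename target in this listing, so this diff does not say where (or whether) they resurface.

    # Illustrative only: an old-to-new module map read off the rename entries in the
    # file list above. The helper and table names are hypothetical, not brainstate API.
    OLD_TO_NEW_MODULES = {
        "brainstate.compile": "brainstate.transform",       # {compile -> transform}/_jit.py, _conditions.py, ...
        "brainstate.augment": "brainstate.transform",       # {augment -> transform}/_autograd.py, _random.py, ...
        "brainstate.init": "brainstate.nn.init",            # {init/_random_inits.py -> nn/init.py}
        "brainstate.nn._rate_rnns": "brainstate.nn._rnns",  # {_rate_rnns.py -> _rnns.py}
        "brainstate.util.error": "brainstate._error",       # {util/error.py -> _error.py}
    }

    def suggest_new_module(old_path: str) -> str:
        # Rewrite a dotted module path according to the table; leave unknown paths untouched.
        for old, new in OLD_TO_NEW_MODULES.items():
            if old_path == old or old_path.startswith(old + "."):
                return new + old_path[len(old):]
        return old_path

    print(suggest_new_module("brainstate.compile._jit"))  # -> brainstate.transform._jit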
brainstate/nn/_dropout.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -20,13 +20,20 @@ from typing import Optional, Sequence
  import brainunit as u
  import jax.numpy as jnp

- from brainstate import random, environ, init
+ from brainstate import random, environ
  from brainstate._state import ShortTermState
  from brainstate.typing import Size
+ from . import init as init
  from ._module import ElementWiseBlock

  __all__ = [
- 'DropoutFixed', 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
+ 'Dropout',
+ 'Dropout1d',
+ 'Dropout2d',
+ 'Dropout3d',
+ 'AlphaDropout',
+ 'FeatureAlphaDropout',
+ 'DropoutFixed',
  ]


@@ -39,14 +46,32 @@ class Dropout(ElementWiseBlock):
  This layer is active only during training (``mode=brainstate.mixin.Training``). In other
  circumstances it is a no-op.

+ Parameters
+ ----------
+ prob : float
+ Probability to keep element of the tensor. Default is 0.5.
+ broadcast_dims : Sequence[int]
+ Dimensions that will share the same dropout mask. Default is ().
+ name : str, optional
+ The name of the dynamic system.
+
+ References
+ ----------
  .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
  neural networks from overfitting." The journal of machine learning
  research 15.1 (2014): 1929-1958.

- Args:
- prob: Probability to keep element of the tensor.
- broadcast_dims: dimensions that will share the same dropout mask.
- name: str. The name of the dynamic system.
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate
+ >>> layer = brainstate.nn.Dropout(prob=0.8)
+ >>> x = brainstate.random.randn(10, 20)
+ >>> with brainstate.environ.context(fit=True):
+ ... output = layer(x)
+ >>> output.shape
+ (10, 20)

  """
  __module__ = 'brainstate.nn'
@@ -133,41 +158,55 @@ class _DropoutNd(ElementWiseBlock):


  class Dropout1d(_DropoutNd):
- r"""Randomly zero out entire channels (a channel is a 1D feature map,
- e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
- batched input is a 1D tensor :math:`\text{input}[i, j]`).
+ r"""Randomly zero out entire channels (a channel is a 1D feature map).
+
  Each channel will be zeroed out independently on every forward call with
- probability :attr:`p` using samples from a Bernoulli distribution.
+ probability using samples from a Bernoulli distribution. The channel is
+ a 1D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample
+ in the batched input is a 1D tensor :math:`\text{input}[i, j]`.

- Usually the input comes from :class:`nn.Conv1d` modules.
+ Usually the input comes from :class:`Conv1d` modules.

- As described in the paper
- `Efficient Object Localization Using Convolutional Networks`_ ,
- if adjacent pixels within feature maps are strongly correlated
- (as is normally the case in early convolution layers) then i.i.d. dropout
- will not regularize the activations and will otherwise just result
- in an effective learning rate decrease.
+ As described in the paper [1]_, if adjacent pixels within feature maps are
+ strongly correlated (as is normally the case in early convolution layers)
+ then i.i.d. dropout will not regularize the activations and will otherwise
+ just result in an effective learning rate decrease.

- In this case, :func:`nn.Dropout1d` will help promote independence between
+ In this case, :class:`Dropout1d` will help promote independence between
  feature maps and should be used instead.

- Args:
- prob: float. probability of an element to be zero-ed.
-
- Shape:
- - Input: :math:`(N, C, L)` or :math:`(C, L)`.
- - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).
-
- Examples::
-
- >>> m = Dropout1d(p=0.2)
- >>> x = random.randn(20, 32, 16)
- >>> output = m(x)
+ Parameters
+ ----------
+ prob : float
+ Probability of an element to be kept. Default is 0.5.
+ channel_axis : int
+ The axis representing the channel dimension. Default is -1.
+ name : str, optional
+ The name of the dynamic system.
+
+ Notes
+ -----
+ Input shape: :math:`(N, C, L)` or :math:`(C, L)`.
+
+ Output shape: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).
+
+ References
+ ----------
+ .. [1] Springenberg et al., "Striving for Simplicity: The All Convolutional Net"
+ https://arxiv.org/abs/1411.4280
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate
+ >>> m = brainstate.nn.Dropout1d(prob=0.8)
+ >>> x = brainstate.random.randn(20, 32, 16)
+ >>> with brainstate.environ.context(fit=True):
+ ... output = m(x)
  >>> output.shape
  (20, 32, 16)

- .. _Efficient Object Localization Using Convolutional Networks:
- https://arxiv.org/abs/1411.4280
  """
  __module__ = 'brainstate.nn'
  minimal_dim: int = 2
@@ -179,39 +218,55 @@ class Dropout1d(_DropoutNd):


  class Dropout2d(_DropoutNd):
- r"""Randomly zero out entire channels (a channel is a 2D feature map,
- e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
- batched input is a 2D tensor :math:`\text{input}[i, j]`).
+ r"""Randomly zero out entire channels (a channel is a 2D feature map).
+
  Each channel will be zeroed out independently on every forward call with
- probability :attr:`p` using samples from a Bernoulli distribution.
+ probability using samples from a Bernoulli distribution. The channel is
+ a 2D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample
+ in the batched input is a 2D tensor :math:`\text{input}[i, j]`.

- Usually the input comes from :class:`nn.Conv2d` modules.
+ Usually the input comes from :class:`Conv2d` modules.

- As described in the paper
- `Efficient Object Localization Using Convolutional Networks`_ ,
- if adjacent pixels within feature maps are strongly correlated
- (as is normally the case in early convolution layers) then i.i.d. dropout
- will not regularize the activations and will otherwise just result
- in an effective learning rate decrease.
+ As described in the paper [1]_, if adjacent pixels within feature maps are
+ strongly correlated (as is normally the case in early convolution layers)
+ then i.i.d. dropout will not regularize the activations and will otherwise
+ just result in an effective learning rate decrease.

- In this case, :func:`nn.Dropout2d` will help promote independence between
+ In this case, :class:`Dropout2d` will help promote independence between
  feature maps and should be used instead.

- Args:
- prob: float. probability of an element to be kept.
-
- Shape:
- - Input: :math:`(N, C, H, W)` or :math:`(N, C, L)`.
- - Output: :math:`(N, C, H, W)` or :math:`(N, C, L)` (same shape as input).
-
- Examples::
-
- >>> m = Dropout2d(p=0.2)
- >>> x = random.randn(20, 32, 32, 16)
- >>> output = m(x)
+ Parameters
+ ----------
+ prob : float
+ Probability of an element to be kept. Default is 0.5.
+ channel_axis : int
+ The axis representing the channel dimension. Default is -1.
+ name : str, optional
+ The name of the dynamic system.
+
+ Notes
+ -----
+ Input shape: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
+
+ Output shape: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input).
+
+ References
+ ----------
+ .. [1] Springenberg et al., "Striving for Simplicity: The All Convolutional Net"
+ https://arxiv.org/abs/1411.4280
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate
+ >>> m = brainstate.nn.Dropout2d(prob=0.8)
+ >>> x = brainstate.random.randn(20, 32, 32, 16)
+ >>> with brainstate.environ.context(fit=True):
+ ... output = m(x)
+ >>> output.shape
+ (20, 32, 32, 16)

- .. _Efficient Object Localization Using Convolutional Networks:
- https://arxiv.org/abs/1411.4280
  """
  __module__ = 'brainstate.nn'
  minimal_dim: int = 3
@@ -223,39 +278,55 @@ class Dropout2d(_DropoutNd):


  class Dropout3d(_DropoutNd):
- r"""Randomly zero out entire channels (a channel is a 3D feature map,
- e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
- batched input is a 3D tensor :math:`\text{input}[i, j]`).
+ r"""Randomly zero out entire channels (a channel is a 3D feature map).
+
  Each channel will be zeroed out independently on every forward call with
- probability :attr:`p` using samples from a Bernoulli distribution.
+ probability using samples from a Bernoulli distribution. The channel is
+ a 3D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample
+ in the batched input is a 3D tensor :math:`\text{input}[i, j]`.

- Usually the input comes from :class:`nn.Conv3d` modules.
+ Usually the input comes from :class:`Conv3d` modules.

- As described in the paper
- `Efficient Object Localization Using Convolutional Networks`_ ,
- if adjacent pixels within feature maps are strongly correlated
- (as is normally the case in early convolution layers) then i.i.d. dropout
- will not regularize the activations and will otherwise just result
- in an effective learning rate decrease.
+ As described in the paper [1]_, if adjacent pixels within feature maps are
+ strongly correlated (as is normally the case in early convolution layers)
+ then i.i.d. dropout will not regularize the activations and will otherwise
+ just result in an effective learning rate decrease.

- In this case, :func:`nn.Dropout3d` will help promote independence between
+ In this case, :class:`Dropout3d` will help promote independence between
  feature maps and should be used instead.

- Args:
- prob: float. probability of an element to be kept.
-
- Shape:
- - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
- - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
-
- Examples::
-
- >>> m = Dropout3d(p=0.2)
- >>> x = random.randn(20, 16, 4, 32, 32)
- >>> output = m(x)
+ Parameters
+ ----------
+ prob : float
+ Probability of an element to be kept. Default is 0.5.
+ channel_axis : int
+ The axis representing the channel dimension. Default is -1.
+ name : str, optional
+ The name of the dynamic system.
+
+ Notes
+ -----
+ Input shape: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
+
+ Output shape: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
+
+ References
+ ----------
+ .. [1] Springenberg et al., "Striving for Simplicity: The All Convolutional Net"
+ https://arxiv.org/abs/1411.4280
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate
+ >>> m = brainstate.nn.Dropout3d(prob=0.8)
+ >>> x = brainstate.random.randn(20, 16, 4, 32, 32)
+ >>> with brainstate.environ.context(fit=True):
+ ... output = m(x)
+ >>> output.shape
+ (20, 16, 4, 32, 32)

- .. _Efficient Object Localization Using Convolutional Networks:
- https://arxiv.org/abs/1411.4280
  """
  __module__ = 'brainstate.nn'
  minimal_dim: int = 4
@@ -270,129 +341,250 @@ class AlphaDropout(_DropoutNd):
  r"""Applies Alpha Dropout over the input.

  Alpha Dropout is a type of Dropout that maintains the self-normalizing
- property.
- For an input with zero mean and unit standard deviation, the output of
+ property. For an input with zero mean and unit standard deviation, the output of
  Alpha Dropout maintains the original mean and standard deviation of the
  input.
+
  Alpha Dropout goes hand-in-hand with SELU activation function, which ensures
  that the outputs have zero mean and unit standard deviation.

  During training, it randomly masks some of the elements of the input
- tensor with probability *p* using samples from a bernoulli distribution.
- The elements to masked are randomized on every forward call, and scaled
+ tensor with probability using samples from a Bernoulli distribution.
+ The elements to be masked are randomized on every forward call, and scaled
  and shifted to maintain zero mean and unit standard deviation.

  During evaluation the module simply computes an identity function.

- More details can be found in the paper `Self-Normalizing Neural Networks`_ .
+ Parameters
+ ----------
+ prob : float
+ Probability of an element to be kept. Default is 0.5.
+ name : str, optional
+ The name of the dynamic system.
+
+ Notes
+ -----
+ Input shape: :math:`(*)`. Input can be of any shape.
+
+ Output shape: :math:`(*)`. Output is of the same shape as input.
+
+ References
+ ----------
+ .. [1] Klambauer et al., "Self-Normalizing Neural Networks"
+ https://arxiv.org/abs/1706.02515
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate
+ >>> m = brainstate.nn.AlphaDropout(prob=0.8)
+ >>> x = brainstate.random.randn(20, 16)
+ >>> with brainstate.environ.context(fit=True):
+ ... output = m(x)
+ >>> output.shape
+ (20, 16)

- Args:
- prob: float. probability of an element to be kept.
+ """
+ __module__ = 'brainstate.nn'

- Shape:
- - Input: :math:`(*)`. Input can be of any shape
- - Output: :math:`(*)`. Output is of the same shape as input
+ def __init__(
+ self,
+ prob: float = 0.5,
+ name: Optional[str] = None
+ ) -> None:
+ super().__init__(name=name)
+ assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
+ self.prob = prob

- Examples::
+ # SELU parameters
+ alpha = -1.7580993408473766
+ self.alpha = alpha

- >>> m = AlphaDropout(p=0.2)
- >>> x = random.randn(20, 16)
- >>> output = m(x)
+ # Affine transformation parameters to maintain mean and variance
+ self.a = ((1 - prob) * (1 + prob * alpha ** 2)) ** -0.5
+ self.b = -self.a * alpha * prob

- .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
- """
- __module__ = 'brainstate.nn'
+ def __call__(self, x):
+ dtype = u.math.get_dtype(x)
+ fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+ if fit_phase and self.prob < 1.:
+ keep_mask = random.bernoulli(self.prob, x.shape)
+ return u.math.where(
+ keep_mask,
+ u.math.asarray(x, dtype=dtype),
+ u.math.asarray(self.alpha, dtype=dtype)
+ ) * self.a + self.b
+ else:
+ return x

- def update(self, *args, **kwargs):
- raise NotImplementedError("AlphaDropout is not supported in the current version.")

+ class FeatureAlphaDropout(ElementWiseBlock):
+ r"""Randomly masks out entire channels with Alpha Dropout properties.

- class FeatureAlphaDropout(_DropoutNd):
- r"""Randomly masks out entire channels (a channel is a feature map,
- e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input
- is a tensor :math:`\text{input}[i, j]`) of the input tensor). Instead of
- setting activations to zero, as in regular Dropout, the activations are set
- to the negative saturation value of the SELU activation function. More details
- can be found in the paper `Self-Normalizing Neural Networks`_ .
+ Instead of setting activations to zero as in regular Dropout, the activations
+ are set to the negative saturation value of the SELU activation function to
+ maintain self-normalizing properties.

- Each element will be masked independently for each sample on every forward
- call with probability :attr:`p` using samples from a Bernoulli distribution.
- The elements to be masked are randomized on every forward call, and scaled
- and shifted to maintain zero mean and unit variance.
+ Each channel (e.g., the :math:`j`-th channel of the :math:`i`-th sample in
+ the batch input is a tensor :math:`\text{input}[i, j]`) will be masked
+ independently for each sample on every forward call with probability using
+ samples from a Bernoulli distribution. The elements to be masked are randomized
+ on every forward call, and scaled and shifted to maintain zero mean and unit
+ variance.

- Usually the input comes from :class:`nn.AlphaDropout` modules.
+ Usually the input comes from convolutional layers with SELU activation.

- As described in the paper
- `Efficient Object Localization Using Convolutional Networks`_ ,
- if adjacent pixels within feature maps are strongly correlated
- (as is normally the case in early convolution layers) then i.i.d. dropout
- will not regularize the activations and will otherwise just result
- in an effective learning rate decrease.
+ As described in the paper [2]_, if adjacent pixels within feature maps are
+ strongly correlated (as is normally the case in early convolution layers)
+ then i.i.d. dropout will not regularize the activations and will otherwise
+ just result in an effective learning rate decrease.

- In this case, :func:`nn.AlphaDropout` will help promote independence between
+ In this case, :class:`FeatureAlphaDropout` will help promote independence between
  feature maps and should be used instead.

- Args:
- prob: float. probability of an element to be kept.
+ Parameters
+ ----------
+ prob : float
+ Probability of an element to be kept. Default is 0.5.
+ channel_axis : int
+ The axis representing the channel dimension. Default is -1.
+ name : str, optional
+ The name of the dynamic system.
+
+ Notes
+ -----
+ Input shape: :math:`(N, C, *)` where C is the channel dimension.
+
+ Output shape: Same shape as input.
+
+ References
+ ----------
+ .. [1] Klambauer et al., "Self-Normalizing Neural Networks"
+ https://arxiv.org/abs/1706.02515
+ .. [2] Springenberg et al., "Striving for Simplicity: The All Convolutional Net"
+ https://arxiv.org/abs/1411.4280
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate
+ >>> m = brainstate.nn.FeatureAlphaDropout(prob=0.8)
+ >>> x = brainstate.random.randn(20, 16, 4, 32, 32)
+ >>> with brainstate.environ.context(fit=True):
+ ... output = m(x)
+ >>> output.shape
+ (20, 16, 4, 32, 32)

- Shape:
- - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
- - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
+ """
+ __module__ = 'brainstate.nn'

- Examples::
+ def __init__(
+ self,
+ prob: float = 0.5,
+ channel_axis: int = -1,
+ name: Optional[str] = None
+ ) -> None:
+ super().__init__(name=name)
+ assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
+ self.prob = prob
+ self.channel_axis = channel_axis

- >>> m = FeatureAlphaDropout(p=0.2)
- >>> x = random.randn(20, 16, 4, 32, 32)
- >>> output = m(x)
+ # SELU parameters
+ alpha = -1.7580993408473766
+ self.alpha = alpha

- .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
- .. _Efficient Object Localization Using Convolutional Networks:
- https://arxiv.org/abs/1411.4280
- """
- __module__ = 'brainstate.nn'
+ # Affine transformation parameters to maintain mean and variance
+ self.a = ((1 - prob) * (1 + prob * alpha ** 2)) ** -0.5
+ self.b = -self.a * alpha * prob

- def update(self, *args, **kwargs):
- raise NotImplementedError("FeatureAlphaDropout is not supported in the current version.")
+ def __call__(self, x):
+ dtype = u.math.get_dtype(x)
+ fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+ if fit_phase and self.prob < 1.:
+ # Create mask shape with 1s except for batch and channel dimensions
+ channel_axis = self.channel_axis if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
+ mask_shape = [1] * x.ndim
+ mask_shape[0] = x.shape[0] # batch dimension
+ mask_shape[channel_axis] = x.shape[channel_axis] # channel dimension
+
+ keep_mask = random.bernoulli(self.prob, mask_shape)
+ keep_mask = u.math.broadcast_to(keep_mask, x.shape)
+ return u.math.where(
+ keep_mask,
+ u.math.asarray(x, dtype=dtype),
+ u.math.asarray(self.alpha, dtype=dtype)
+ ) * self.a + self.b
+ else:
+ return x


  class DropoutFixed(ElementWiseBlock):
- """
- A dropout layer with the fixed dropout mask along the time axis once after initialized.
+ """A dropout layer with a fixed dropout mask along the time axis.

- In training, to compensate for the fraction of input values dropped (`rate`),
- all surviving values are multiplied by `1 / (1 - rate)`.
+ In training, to compensate for the fraction of input values dropped,
+ all surviving values are multiplied by `1 / (1 - prob)`.

  This layer is active only during training (``mode=brainstate.mixin.Training``). In other
  circumstances it is a no-op.

+ This kind of Dropout is particularly useful for spiking neural networks (SNNs) where
+ the same dropout mask needs to be applied across multiple time steps within a single
+ mini-batch iteration.
+
+ Parameters
+ ----------
+ in_size : tuple or int
+ The size of the input tensor.
+ prob : float
+ Probability to keep element of the tensor. Default is 0.5.
+ name : str, optional
+ The name of the dynamic system.
+
+ Notes
+ -----
+ As described in [2]_, there is a subtle difference in the way dropout is applied in
+ SNNs compared to ANNs. In ANNs, each epoch of training has several iterations of
+ mini-batches. In each iteration, randomly selected units (with dropout ratio of
+ :math:`p`) are disconnected from the network while weighting by its posterior
+ probability (:math:`1-p`).
+
+ However, in SNNs, each iteration has more than one forward propagation depending on
+ the time length of the spike train. We back-propagate the output error and modify
+ the network parameters only at the last time step. For dropout to be effective in
+ our training method, it has to be ensured that the set of connected units within an
+ iteration of mini-batch data is not changed, such that the neural network is
+ constituted by the same random subset of units during each forward propagation within
+ a single iteration.
+
+ On the other hand, if the units are randomly connected at each time-step, the effect
+ of dropout will be averaged out over the entire forward propagation time within an
+ iteration. Then, the dropout effect would fade-out once the output error is propagated
+ backward and the parameters are updated at the last time step. Therefore, we need to
+ keep the set of randomly connected units for the entire time window within an iteration.
+
+ References
+ ----------
  .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
  neural networks from overfitting." The journal of machine learning
  research 15.1 (2014): 1929-1958.
+ .. [2] Lee et al., "Enabling Spike-based Backpropagation for Training Deep Neural
+ Network Architectures" https://arxiv.org/abs/1903.06379
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate
+ >>> layer = brainstate.nn.DropoutFixed(in_size=(20,), prob=0.8)
+ >>> layer.init_state(batch_size=10)
+ >>> x = brainstate.random.randn(10, 20)
+ >>> with brainstate.environ.context(fit=True):
+ ... output = layer.update(x)
+ >>> output.shape
+ (10, 20)

- .. admonition:: Tip
- :class: tip
-
- This kind of Dropout is firstly described in `Enabling Spike-based Backpropagation for Training Deep Neural
- Network Architectures <https://arxiv.org/abs/1903.06379>`_:
-
- There is a subtle difference in the way dropout is applied in SNNs compared to ANNs. In ANNs, each epoch of
- training has several iterations of mini-batches. In each iteration, randomly selected units (with dropout ratio of :math:`p`)
- are disconnected from the network while weighting by its posterior probability (:math:`1-p`). However, in SNNs, each
- iteration has more than one forward propagation depending on the time length of the spike train. We back-propagate
- the output error and modify the network parameters only at the last time step. For dropout to be effective in
- our training method, it has to be ensured that the set of connected units within an iteration of mini-batch
- data is not changed, such that the neural network is constituted by the same random subset of units during
- each forward propagation within a single iteration. On the other hand, if the units are randomly connected at
- each time-step, the effect of dropout will be averaged out over the entire forward propagation time within an
- iteration. Then, the dropout effect would fade-out once the output error is propagated backward and the parameters
- are updated at the last time step. Therefore, we need to keep the set of randomly connected units for the entire
- time window within an iteration.
-
- Args:
- in_size: The size of the input tensor.
- prob: Probability to keep element of the tensor.
- mode: Mode. The computation mode of the object.
- name: str. The name of the dynamic system.
  """
  __module__ = 'brainstate.nn'
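The new AlphaDropout and FeatureAlphaDropout classes both precompute affine parameters a and b from the SELU saturation value alpha. The snippet below is a minimal NumPy check (not brainstate code) of that rescaling under the convention of the SELU paper, where p denotes the drop probability and entries are kept with probability 1 - p; how that convention maps onto the prob argument in the code above is an assumption to verify against the 0.2.0 sources.

    # Standalone NumPy check of the alpha-dropout affine rescaling shown above.
    # Assumption: p is the *drop* probability (SELU-paper convention), so entries are
    # kept with probability 1 - p; alpha_prime is the SELU saturation value from the diff.
    import numpy as np

    alpha_prime = -1.7580993408473766
    p = 0.2
    a = ((1 - p) * (1 + p * alpha_prime ** 2)) ** -0.5
    b = -a * alpha_prime * p

    rng = np.random.default_rng(0)
    x = rng.standard_normal(1_000_000)          # zero-mean, unit-variance input
    keep = rng.random(x.shape) >= p             # keep each entry with probability 1 - p
    y = np.where(keep, x, alpha_prime) * a + b  # dropped entries saturate at alpha_prime

    print(round(float(y.mean()), 3), round(float(y.var()), 3))  # both stay close to 0.0 and 1.0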