brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
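Reading the moves above, the former `brainstate.compile` and `brainstate.augment` sub-packages appear to have been merged into a single `brainstate.transform` sub-package in 0.2.0 (see `brainstate/{compile → transform}/_jit.py`, `brainstate/{augment → transform}/_autograd.py`, and the removal of `brainstate/compile/__init__.py` and `brainstate/augment/__init__.py`). A minimal migration sketch under the assumption that the public entry point `jit` keeps its name in the new namespace; only the module moves themselves are confirmed by this listing:

    # Hedged migration sketch for the compile/augment -> transform merge in 0.2.0.
    # Assumption: `jit` is re-exported as `brainstate.transform.jit`; only the file
    # moves listed above are confirmed by this diff.
    import brainstate

    def loss(w):
        return (w ** 2).sum()

    # 0.1.9 (namespaces removed in 0.2.0):
    #     jitted_loss = brainstate.compile.jit(loss)
    # 0.2.0 (merged namespace):
    jitted_loss = brainstate.transform.jit(loss)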
brainstate/nn/_dropout.py
CHANGED
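One implementation detail in the diff below is worth a standalone illustration: the new `FeatureAlphaDropout.__call__` draws its Bernoulli keep-mask only over the batch and channel axes and broadcasts it across the remaining axes, so each channel is kept or masked as a whole. A NumPy sketch of that mask construction (the array shape, keep probability, and channel axis are illustrative only, not taken from the package):

    # Standalone NumPy illustration of the per-channel mask construction used by
    # FeatureAlphaDropout in the diff below; the shape and probability are made up.
    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal((20, 4, 32, 32, 16))   # e.g. (N, D, H, W, C), channel_axis=-1
    prob, channel_axis = 0.8, -1

    axis = channel_axis if channel_axis >= 0 else x.ndim + channel_axis
    mask_shape = [1] * x.ndim
    mask_shape[0] = x.shape[0]          # batch axis: one decision per sample
    mask_shape[axis] = x.shape[axis]    # channel axis: one decision per channel

    keep_mask = rng.random(mask_shape) < prob        # Bernoulli(prob) keep-mask
    keep_mask = np.broadcast_to(keep_mask, x.shape)  # shared across spatial positions
    print(keep_mask.shape)                           # (20, 4, 32, 32, 16)
    print(keep_mask.all(axis=(1, 2, 3)).mean())      # fraction of whole channels kept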
@@ -1,4 +1,4 @@
-# Copyright 2024
+# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,13 +20,20 @@ from typing import Optional, Sequence
 import brainunit as u
 import jax.numpy as jnp
 
-from brainstate import random, environ
+from brainstate import random, environ
 from brainstate._state import ShortTermState
 from brainstate.typing import Size
+from . import init as init
 from ._module import ElementWiseBlock
 
 __all__ = [
-    '
+    'Dropout',
+    'Dropout1d',
+    'Dropout2d',
+    'Dropout3d',
+    'AlphaDropout',
+    'FeatureAlphaDropout',
+    'DropoutFixed',
 ]
 
 
@@ -39,14 +46,32 @@ class Dropout(ElementWiseBlock):
     This layer is active only during training (``mode=brainstate.mixin.Training``). In other
     circumstances it is a no-op.
 
+    Parameters
+    ----------
+    prob : float
+        Probability to keep element of the tensor. Default is 0.5.
+    broadcast_dims : Sequence[int]
+        Dimensions that will share the same dropout mask. Default is ().
+    name : str, optional
+        The name of the dynamic system.
+
+    References
+    ----------
     .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
            neural networks from overfitting." The journal of machine learning
            research 15.1 (2014): 1929-1958.
 
-
-
-
-
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>> layer = brainstate.nn.Dropout(prob=0.8)
+        >>> x = brainstate.random.randn(10, 20)
+        >>> with brainstate.environ.context(fit=True):
+        ...     output = layer(x)
+        >>> output.shape
+        (10, 20)
 
     """
     __module__ = 'brainstate.nn'
@@ -133,41 +158,55 @@ class _DropoutNd(ElementWiseBlock):
 
 
 class Dropout1d(_DropoutNd):
-    r"""Randomly zero out entire channels (a channel is a 1D feature map
-
-    batched input is a 1D tensor :math:`\text{input}[i, j]`).
+    r"""Randomly zero out entire channels (a channel is a 1D feature map).
+
     Each channel will be zeroed out independently on every forward call with
-    probability
+    probability using samples from a Bernoulli distribution. The channel is
+    a 1D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample
+    in the batched input is a 1D tensor :math:`\text{input}[i, j]`.
 
-    Usually the input comes from :class:`
+    Usually the input comes from :class:`Conv1d` modules.
 
-    As described in the paper
-
-
-
-    will not regularize the activations and will otherwise just result
-    in an effective learning rate decrease.
+    As described in the paper [1]_, if adjacent pixels within feature maps are
+    strongly correlated (as is normally the case in early convolution layers)
+    then i.i.d. dropout will not regularize the activations and will otherwise
+    just result in an effective learning rate decrease.
 
-    In this case, :
+    In this case, :class:`Dropout1d` will help promote independence between
     feature maps and should be used instead.
 
-
-
-
-
-
-
-
-
-
-
-
-
+    Parameters
+    ----------
+    prob : float
+        Probability of an element to be kept. Default is 0.5.
+    channel_axis : int
+        The axis representing the channel dimension. Default is -1.
+    name : str, optional
+        The name of the dynamic system.
+
+    Notes
+    -----
+    Input shape: :math:`(N, C, L)` or :math:`(C, L)`.
+
+    Output shape: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).
+
+    References
+    ----------
+    .. [1] Springenberg et al., "Striving for Simplicity: The All Convolutional Net"
+           https://arxiv.org/abs/1411.4280
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>> m = brainstate.nn.Dropout1d(prob=0.8)
+        >>> x = brainstate.random.randn(20, 32, 16)
+        >>> with brainstate.environ.context(fit=True):
+        ...     output = m(x)
         >>> output.shape
         (20, 32, 16)
 
-    .. _Efficient Object Localization Using Convolutional Networks:
-       https://arxiv.org/abs/1411.4280
     """
     __module__ = 'brainstate.nn'
     minimal_dim: int = 2
@@ -179,39 +218,55 @@ class Dropout1d(_DropoutNd):
 
 
 class Dropout2d(_DropoutNd):
-    r"""Randomly zero out entire channels (a channel is a 2D feature map
-
-    batched input is a 2D tensor :math:`\text{input}[i, j]`).
+    r"""Randomly zero out entire channels (a channel is a 2D feature map).
+
     Each channel will be zeroed out independently on every forward call with
-    probability
+    probability using samples from a Bernoulli distribution. The channel is
+    a 2D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample
+    in the batched input is a 2D tensor :math:`\text{input}[i, j]`.
 
-    Usually the input comes from :class:`
+    Usually the input comes from :class:`Conv2d` modules.
 
-    As described in the paper
-
-
-
-    will not regularize the activations and will otherwise just result
-    in an effective learning rate decrease.
+    As described in the paper [1]_, if adjacent pixels within feature maps are
+    strongly correlated (as is normally the case in early convolution layers)
+    then i.i.d. dropout will not regularize the activations and will otherwise
+    just result in an effective learning rate decrease.
 
-    In this case, :
+    In this case, :class:`Dropout2d` will help promote independence between
    feature maps and should be used instead.
 
-
-
-
-
-
-
-
-
-
-
-
-
+    Parameters
+    ----------
+    prob : float
+        Probability of an element to be kept. Default is 0.5.
+    channel_axis : int
+        The axis representing the channel dimension. Default is -1.
+    name : str, optional
+        The name of the dynamic system.
+
+    Notes
+    -----
+    Input shape: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
+
+    Output shape: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input).
+
+    References
+    ----------
+    .. [1] Springenberg et al., "Striving for Simplicity: The All Convolutional Net"
+           https://arxiv.org/abs/1411.4280
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>> m = brainstate.nn.Dropout2d(prob=0.8)
+        >>> x = brainstate.random.randn(20, 32, 32, 16)
+        >>> with brainstate.environ.context(fit=True):
+        ...     output = m(x)
+        >>> output.shape
+        (20, 32, 32, 16)
 
-    .. _Efficient Object Localization Using Convolutional Networks:
-       https://arxiv.org/abs/1411.4280
     """
     __module__ = 'brainstate.nn'
     minimal_dim: int = 3
@@ -223,39 +278,55 @@ class Dropout2d(_DropoutNd):
 
 
 class Dropout3d(_DropoutNd):
-    r"""Randomly zero out entire channels (a channel is a 3D feature map
-
-    batched input is a 3D tensor :math:`\text{input}[i, j]`).
+    r"""Randomly zero out entire channels (a channel is a 3D feature map).
+
     Each channel will be zeroed out independently on every forward call with
-    probability
+    probability using samples from a Bernoulli distribution. The channel is
+    a 3D feature map, e.g., the :math:`j`-th channel of the :math:`i`-th sample
+    in the batched input is a 3D tensor :math:`\text{input}[i, j]`.
 
-    Usually the input comes from :class:`
+    Usually the input comes from :class:`Conv3d` modules.
 
-    As described in the paper
-
-
-
-    will not regularize the activations and will otherwise just result
-    in an effective learning rate decrease.
+    As described in the paper [1]_, if adjacent pixels within feature maps are
+    strongly correlated (as is normally the case in early convolution layers)
+    then i.i.d. dropout will not regularize the activations and will otherwise
+    just result in an effective learning rate decrease.
 
-    In this case, :
+    In this case, :class:`Dropout3d` will help promote independence between
     feature maps and should be used instead.
 
-
-
-
-
-
-
-
-
-
-
-
-
+    Parameters
+    ----------
+    prob : float
+        Probability of an element to be kept. Default is 0.5.
+    channel_axis : int
+        The axis representing the channel dimension. Default is -1.
+    name : str, optional
+        The name of the dynamic system.
+
+    Notes
+    -----
+    Input shape: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
+
+    Output shape: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
+
+    References
+    ----------
+    .. [1] Springenberg et al., "Striving for Simplicity: The All Convolutional Net"
+           https://arxiv.org/abs/1411.4280
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>> m = brainstate.nn.Dropout3d(prob=0.8)
+        >>> x = brainstate.random.randn(20, 16, 4, 32, 32)
+        >>> with brainstate.environ.context(fit=True):
+        ...     output = m(x)
+        >>> output.shape
+        (20, 16, 4, 32, 32)
 
-    .. _Efficient Object Localization Using Convolutional Networks:
-       https://arxiv.org/abs/1411.4280
     """
     __module__ = 'brainstate.nn'
     minimal_dim: int = 4
@@ -270,129 +341,250 @@ class AlphaDropout(_DropoutNd):
     r"""Applies Alpha Dropout over the input.
 
     Alpha Dropout is a type of Dropout that maintains the self-normalizing
-    property.
-    For an input with zero mean and unit standard deviation, the output of
+    property. For an input with zero mean and unit standard deviation, the output of
     Alpha Dropout maintains the original mean and standard deviation of the
     input.
+
     Alpha Dropout goes hand-in-hand with SELU activation function, which ensures
     that the outputs have zero mean and unit standard deviation.
 
     During training, it randomly masks some of the elements of the input
-    tensor with probability
-    The elements to masked are randomized on every forward call, and scaled
+    tensor with probability using samples from a Bernoulli distribution.
+    The elements to be masked are randomized on every forward call, and scaled
     and shifted to maintain zero mean and unit standard deviation.
 
     During evaluation the module simply computes an identity function.
 
-
+    Parameters
+    ----------
+    prob : float
+        Probability of an element to be kept. Default is 0.5.
+    name : str, optional
+        The name of the dynamic system.
+
+    Notes
+    -----
+    Input shape: :math:`(*)`. Input can be of any shape.
+
+    Output shape: :math:`(*)`. Output is of the same shape as input.
+
+    References
+    ----------
+    .. [1] Klambauer et al., "Self-Normalizing Neural Networks"
+           https://arxiv.org/abs/1706.02515
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>> m = brainstate.nn.AlphaDropout(prob=0.8)
+        >>> x = brainstate.random.randn(20, 16)
+        >>> with brainstate.environ.context(fit=True):
+        ...     output = m(x)
+        >>> output.shape
+        (20, 16)
 
-
-
+    """
+    __module__ = 'brainstate.nn'
 
-
-
-
+    def __init__(
+        self,
+        prob: float = 0.5,
+        name: Optional[str] = None
+    ) -> None:
+        super().__init__(name=name)
+        assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
+        self.prob = prob
 
-
+        # SELU parameters
+        alpha = -1.7580993408473766
+        self.alpha = alpha
 
-
-
-
+        # Affine transformation parameters to maintain mean and variance
+        self.a = ((1 - prob) * (1 + prob * alpha ** 2)) ** -0.5
+        self.b = -self.a * alpha * prob
 
-
-
-
+    def __call__(self, x):
+        dtype = u.math.get_dtype(x)
+        fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+        if fit_phase and self.prob < 1.:
+            keep_mask = random.bernoulli(self.prob, x.shape)
+            return u.math.where(
+                keep_mask,
+                u.math.asarray(x, dtype=dtype),
+                u.math.asarray(self.alpha, dtype=dtype)
+            ) * self.a + self.b
+        else:
+            return x
 
-    def update(self, *args, **kwargs):
-        raise NotImplementedError("AlphaDropout is not supported in the current version.")
 
+class FeatureAlphaDropout(ElementWiseBlock):
+    r"""Randomly masks out entire channels with Alpha Dropout properties.
 
-
-
-
-    is a tensor :math:`\text{input}[i, j]`) of the input tensor). Instead of
-    setting activations to zero, as in regular Dropout, the activations are set
-    to the negative saturation value of the SELU activation function. More details
-    can be found in the paper `Self-Normalizing Neural Networks`_ .
+    Instead of setting activations to zero as in regular Dropout, the activations
+    are set to the negative saturation value of the SELU activation function to
+    maintain self-normalizing properties.
 
-    Each
-
-
-
+    Each channel (e.g., the :math:`j`-th channel of the :math:`i`-th sample in
+    the batch input is a tensor :math:`\text{input}[i, j]`) will be masked
+    independently for each sample on every forward call with probability using
+    samples from a Bernoulli distribution. The elements to be masked are randomized
+    on every forward call, and scaled and shifted to maintain zero mean and unit
+    variance.
 
-    Usually the input comes from
+    Usually the input comes from convolutional layers with SELU activation.
 
-    As described in the paper
-
-
-
-    will not regularize the activations and will otherwise just result
-    in an effective learning rate decrease.
+    As described in the paper [2]_, if adjacent pixels within feature maps are
+    strongly correlated (as is normally the case in early convolution layers)
+    then i.i.d. dropout will not regularize the activations and will otherwise
+    just result in an effective learning rate decrease.
 
-    In this case, :
+    In this case, :class:`FeatureAlphaDropout` will help promote independence between
     feature maps and should be used instead.
 
-
-
+    Parameters
+    ----------
+    prob : float
+        Probability of an element to be kept. Default is 0.5.
+    channel_axis : int
+        The axis representing the channel dimension. Default is -1.
+    name : str, optional
+        The name of the dynamic system.
+
+    Notes
+    -----
+    Input shape: :math:`(N, C, *)` where C is the channel dimension.
+
+    Output shape: Same shape as input.
+
+    References
+    ----------
+    .. [1] Klambauer et al., "Self-Normalizing Neural Networks"
+           https://arxiv.org/abs/1706.02515
+    .. [2] Springenberg et al., "Striving for Simplicity: The All Convolutional Net"
+           https://arxiv.org/abs/1411.4280
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>> m = brainstate.nn.FeatureAlphaDropout(prob=0.8)
+        >>> x = brainstate.random.randn(20, 16, 4, 32, 32)
+        >>> with brainstate.environ.context(fit=True):
+        ...     output = m(x)
+        >>> output.shape
+        (20, 16, 4, 32, 32)
 
-
-
-    - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
+    """
+    __module__ = 'brainstate.nn'
 
-
+    def __init__(
+        self,
+        prob: float = 0.5,
+        channel_axis: int = -1,
+        name: Optional[str] = None
+    ) -> None:
+        super().__init__(name=name)
+        assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
+        self.prob = prob
+        self.channel_axis = channel_axis
 
-
-
-
+        # SELU parameters
+        alpha = -1.7580993408473766
+        self.alpha = alpha
 
-
-
-
-    """
-    __module__ = 'brainstate.nn'
+        # Affine transformation parameters to maintain mean and variance
+        self.a = ((1 - prob) * (1 + prob * alpha ** 2)) ** -0.5
+        self.b = -self.a * alpha * prob
 
-    def
-
+    def __call__(self, x):
+        dtype = u.math.get_dtype(x)
+        fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+        if fit_phase and self.prob < 1.:
+            # Create mask shape with 1s except for batch and channel dimensions
+            channel_axis = self.channel_axis if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
+            mask_shape = [1] * x.ndim
+            mask_shape[0] = x.shape[0]  # batch dimension
+            mask_shape[channel_axis] = x.shape[channel_axis]  # channel dimension
+
+            keep_mask = random.bernoulli(self.prob, mask_shape)
+            keep_mask = u.math.broadcast_to(keep_mask, x.shape)
+            return u.math.where(
+                keep_mask,
+                u.math.asarray(x, dtype=dtype),
+                u.math.asarray(self.alpha, dtype=dtype)
+            ) * self.a + self.b
+        else:
+            return x
 
 
 class DropoutFixed(ElementWiseBlock):
-    """
-    A dropout layer with the fixed dropout mask along the time axis once after initialized.
+    """A dropout layer with a fixed dropout mask along the time axis.
 
-    In training, to compensate for the fraction of input values dropped
-    all surviving values are multiplied by `1 / (1 -
+    In training, to compensate for the fraction of input values dropped,
+    all surviving values are multiplied by `1 / (1 - prob)`.
 
     This layer is active only during training (``mode=brainstate.mixin.Training``). In other
     circumstances it is a no-op.
 
+    This kind of Dropout is particularly useful for spiking neural networks (SNNs) where
+    the same dropout mask needs to be applied across multiple time steps within a single
+    mini-batch iteration.
+
+    Parameters
+    ----------
+    in_size : tuple or int
+        The size of the input tensor.
+    prob : float
+        Probability to keep element of the tensor. Default is 0.5.
+    name : str, optional
+        The name of the dynamic system.
+
+    Notes
+    -----
+    As described in [2]_, there is a subtle difference in the way dropout is applied in
+    SNNs compared to ANNs. In ANNs, each epoch of training has several iterations of
+    mini-batches. In each iteration, randomly selected units (with dropout ratio of
+    :math:`p`) are disconnected from the network while weighting by its posterior
+    probability (:math:`1-p`).
+
+    However, in SNNs, each iteration has more than one forward propagation depending on
+    the time length of the spike train. We back-propagate the output error and modify
+    the network parameters only at the last time step. For dropout to be effective in
+    our training method, it has to be ensured that the set of connected units within an
+    iteration of mini-batch data is not changed, such that the neural network is
+    constituted by the same random subset of units during each forward propagation within
+    a single iteration.
+
+    On the other hand, if the units are randomly connected at each time-step, the effect
+    of dropout will be averaged out over the entire forward propagation time within an
+    iteration. Then, the dropout effect would fade-out once the output error is propagated
+    backward and the parameters are updated at the last time step. Therefore, we need to
+    keep the set of randomly connected units for the entire time window within an iteration.
+
+    References
+    ----------
     .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
            neural networks from overfitting." The journal of machine learning
            research 15.1 (2014): 1929-1958.
+    .. [2] Lee et al., "Enabling Spike-based Backpropagation for Training Deep Neural
+           Network Architectures" https://arxiv.org/abs/1903.06379
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>> layer = brainstate.nn.DropoutFixed(in_size=(20,), prob=0.8)
+        >>> layer.init_state(batch_size=10)
+        >>> x = brainstate.random.randn(10, 20)
+        >>> with brainstate.environ.context(fit=True):
+        ...     output = layer.update(x)
+        >>> output.shape
+        (10, 20)
 
-    .. admonition:: Tip
-       :class: tip
-
-       This kind of Dropout is firstly described in `Enabling Spike-based Backpropagation for Training Deep Neural
-       Network Architectures <https://arxiv.org/abs/1903.06379>`_:
-
-       There is a subtle difference in the way dropout is applied in SNNs compared to ANNs. In ANNs, each epoch of
-       training has several iterations of mini-batches. In each iteration, randomly selected units (with dropout ratio of :math:`p`)
-       are disconnected from the network while weighting by its posterior probability (:math:`1-p`). However, in SNNs, each
-       iteration has more than one forward propagation depending on the time length of the spike train. We back-propagate
-       the output error and modify the network parameters only at the last time step. For dropout to be effective in
-       our training method, it has to be ensured that the set of connected units within an iteration of mini-batch
-       data is not changed, such that the neural network is constituted by the same random subset of units during
-       each forward propagation within a single iteration. On the other hand, if the units are randomly connected at
-       each time-step, the effect of dropout will be averaged out over the entire forward propagation time within an
-       iteration. Then, the dropout effect would fade-out once the output error is propagated backward and the parameters
-       are updated at the last time step. Therefore, we need to keep the set of randomly connected units for the entire
-       time window within an iteration.
-
-    Args:
-      in_size: The size of the input tensor.
-      prob: Probability to keep element of the tensor.
-      mode: Mode. The computation mode of the object.
-      name: str. The name of the dynamic system.
     """
     __module__ = 'brainstate.nn'
 