brainstate-0.0.2.post20241009-py2.py3-none-any.whl → brainstate-0.1.0-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1360 -1318
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
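
The bulk of the 0.1.0 release is a package reorganization: `brainstate.transform` is split into `brainstate.compile` and the new `brainstate.augment` package, `brainstate.nn.event` moves to the top-level `brainstate.event`, new `brainstate.graph` and `brainstate.util` packages appear, and the old monolithic `brainstate._module` and `brainstate/util.py` files are removed. The following is a rough, hypothetical sketch of how downstream imports shift — it is inferred only from the renamed paths in the list above; the exact public symbols each package re-exports are not visible in this diff.

```python
# Hypothetical migration sketch based only on the package renames listed above;
# the public symbols re-exported by each package are not shown in this diff.

# 0.0.2.post20241009 layout (removed in 0.1.0):
#   brainstate.transform   -> split into brainstate.compile / brainstate.augment
#   brainstate.nn.event    -> moved to brainstate.event

# 0.1.0 layout:
import brainstate.compile   # jit, loops, conditionals, make_jaxpr, ...
import brainstate.augment   # autograd, mapping, eval_shape, ...
import brainstate.event     # event-driven CSR / fixed-probability / linear ops
import brainstate.graph     # graph nodes and graph operations
import brainstate.util      # filters, nested dicts, pretty repr, ...
```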
brainstate/nn/_elementwise/_dropout.py
@@ -0,0 +1,418 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from __future__ import annotations
+
+from functools import partial
+from typing import Optional
+
+import brainunit as u
+import jax.numpy as jnp
+
+from brainstate import random, environ, init
+from brainstate._state import ShortTermState
+from brainstate.nn._module import ElementWiseBlock
+from brainstate.typing import Size
+
+__all__ = [
+    'DropoutFixed', 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
+    'AlphaDropout', 'FeatureAlphaDropout',
+]
+
+
+class Dropout(ElementWiseBlock):
+    """A layer that stochastically ignores a subset of inputs each training step.
+
+    In training, to compensate for the fraction of input values dropped (`rate`),
+    all surviving values are multiplied by `1 / (1 - rate)`.
+
+    This layer is active only during training (``mode=brainstate.mixin.Training``). In other
+    circumstances it is a no-op.
+
+    .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
+           neural networks from overfitting." The journal of machine learning
+           research 15.1 (2014): 1929-1958.
+
+    Args:
+      prob: Probability to keep element of the tensor.
+      mode: Mode. The computation mode of the object.
+      name: str. The name of the dynamic system.
+
+    """
+    __module__ = 'brainstate.nn'
+
+    def __init__(
+        self,
+        prob: float = 0.5,
+        name: Optional[str] = None
+    ) -> None:
+        super().__init__(name=name)
+        assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
+        self.prob = prob
+
+    def __call__(self, x):
+        dtype = u.math.get_dtype(x)
+        fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+        if fit_phase and self.prob < 1.:
+            keep_mask = random.bernoulli(self.prob, x.shape)
+            return jnp.where(keep_mask,
+                             jnp.asarray(x / self.prob, dtype=dtype),
+                             jnp.asarray(0., dtype=dtype))
+        else:
+            return x
+
+
+class _DropoutNd(ElementWiseBlock):
+    __module__ = 'brainstate.nn'
+    prob: float
+    channel_axis: int
+    minimal_dim: int
+
+    def __init__(
+        self,
+        prob: float = 0.5,
+        channel_axis: int = -1,
+        name: Optional[str] = None
+    ) -> None:
+        super().__init__(name=name)
+        assert 0. <= prob < 1., f"Dropout probability must be in the range [0, 1). But got {prob}."
+        self.prob = prob
+        self.channel_axis = channel_axis
+
+    def __call__(self, x):
+
+        # check input shape
+        inp_dim = u.math.ndim(x)
+        if inp_dim not in (self.minimal_dim, self.minimal_dim + 1):
+            raise RuntimeError(f"dropout1d: Expected {self.minimal_dim}D or {self.minimal_dim + 1}D input, "
+                               f"but received a {inp_dim}D input. {self._get_msg(x)}")
+        is_not_batched = self.minimal_dim
+        if is_not_batched:
+            channel_axis = self.channel_axis if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
+            mask_shape = [(dim if i == channel_axis else 1) for i, dim in enumerate(x.shape)]
+        else:
+            channel_axis = (self.channel_axis + 1) if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
+            assert channel_axis != 0, f"Channel axis must not be 0. But got {self.channel_axis}."
+            mask_shape = [(dim if i in (channel_axis, 0) else 1) for i, dim in enumerate(x.shape)]
+
+        # get fit phase
+        fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+
+        # generate mask
+        if fit_phase:
+            dtype = u.math.get_dtype(x)
+            keep_mask = jnp.broadcast_to(random.bernoulli(self.prob, mask_shape), x.shape)
+            return jnp.where(keep_mask,
+                             jnp.asarray(x / self.prob, dtype=dtype),
+                             jnp.asarray(0., dtype=dtype))
+        else:
+            return x
+
+    def _get_msg(self, x):
+        return ''
+
+
+class Dropout1d(_DropoutNd):
+    r"""Randomly zero out entire channels (a channel is a 1D feature map,
+    e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
+    batched input is a 1D tensor :math:`\text{input}[i, j]`).
+    Each channel will be zeroed out independently on every forward call with
+    probability :attr:`p` using samples from a Bernoulli distribution.
+
+    Usually the input comes from :class:`nn.Conv1d` modules.
+
+    As described in the paper
+    `Efficient Object Localization Using Convolutional Networks`_ ,
+    if adjacent pixels within feature maps are strongly correlated
+    (as is normally the case in early convolution layers) then i.i.d. dropout
+    will not regularize the activations and will otherwise just result
+    in an effective learning rate decrease.
+
+    In this case, :func:`nn.Dropout1d` will help promote independence between
+    feature maps and should be used instead.
+
+    Args:
+      prob: float. probability of an element to be zero-ed.
+
+    Shape:
+        - Input: :math:`(N, C, L)` or :math:`(C, L)`.
+        - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).
+
+    Examples::
+
+        >>> m = Dropout1d(p=0.2)
+        >>> x = random.randn(20, 32, 16)
+        >>> output = m(x)
+        >>> output.shape
+        (20, 32, 16)
+
+    .. _Efficient Object Localization Using Convolutional Networks:
+       https://arxiv.org/abs/1411.4280
+    """
+    __module__ = 'brainstate.nn'
+    minimal_dim: int = 2
+
+    def _get_msg(self, x):
+        return ("Note that dropout1d exists to provide channel-wise dropout on inputs with 1 "
+                "spatial dimension, a channel dimension, and an optional batch dimension "
+                "(i.e. 2D or 3D inputs).")
+
+
+class Dropout2d(_DropoutNd):
+    r"""Randomly zero out entire channels (a channel is a 2D feature map,
+    e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
+    batched input is a 2D tensor :math:`\text{input}[i, j]`).
+    Each channel will be zeroed out independently on every forward call with
+    probability :attr:`p` using samples from a Bernoulli distribution.
+
+    Usually the input comes from :class:`nn.Conv2d` modules.
+
+    As described in the paper
+    `Efficient Object Localization Using Convolutional Networks`_ ,
+    if adjacent pixels within feature maps are strongly correlated
+    (as is normally the case in early convolution layers) then i.i.d. dropout
+    will not regularize the activations and will otherwise just result
+    in an effective learning rate decrease.
+
+    In this case, :func:`nn.Dropout2d` will help promote independence between
+    feature maps and should be used instead.
+
+    Args:
+      prob: float. probability of an element to be kept.
+
+    Shape:
+        - Input: :math:`(N, C, H, W)` or :math:`(N, C, L)`.
+        - Output: :math:`(N, C, H, W)` or :math:`(N, C, L)` (same shape as input).
+
+    Examples::
+
+        >>> m = Dropout2d(p=0.2)
+        >>> x = random.randn(20, 32, 32, 16)
+        >>> output = m(x)
+
+    .. _Efficient Object Localization Using Convolutional Networks:
+       https://arxiv.org/abs/1411.4280
+    """
+    __module__ = 'brainstate.nn'
+    minimal_dim: int = 3
+
+    def _get_msg(self, x):
+        return ("Note that dropout2d exists to provide channel-wise dropout on inputs with 2 "
+                "spatial dimensions, a channel dimension, and an optional batch dimension "
+                "(i.e. 3D or 4D inputs).")
+
+
+class Dropout3d(_DropoutNd):
+    r"""Randomly zero out entire channels (a channel is a 3D feature map,
+    e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
+    batched input is a 3D tensor :math:`\text{input}[i, j]`).
+    Each channel will be zeroed out independently on every forward call with
+    probability :attr:`p` using samples from a Bernoulli distribution.
+
+    Usually the input comes from :class:`nn.Conv3d` modules.
+
+    As described in the paper
+    `Efficient Object Localization Using Convolutional Networks`_ ,
+    if adjacent pixels within feature maps are strongly correlated
+    (as is normally the case in early convolution layers) then i.i.d. dropout
+    will not regularize the activations and will otherwise just result
+    in an effective learning rate decrease.
+
+    In this case, :func:`nn.Dropout3d` will help promote independence between
+    feature maps and should be used instead.
+
+    Args:
+      prob: float. probability of an element to be kept.
+
+    Shape:
+        - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
+        - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
+
+    Examples::
+
+        >>> m = Dropout3d(p=0.2)
+        >>> x = random.randn(20, 16, 4, 32, 32)
+        >>> output = m(x)
+
+    .. _Efficient Object Localization Using Convolutional Networks:
+       https://arxiv.org/abs/1411.4280
+    """
+    __module__ = 'brainstate.nn'
+    minimal_dim: int = 4
+
+    def _get_msg(self, x):
+        return ("Note that dropout3d exists to provide channel-wise dropout on inputs with 3 "
+                "spatial dimensions, a channel dimension, and an optional batch dimension "
+                "(i.e. 4D or 5D inputs).")
+
+
+class AlphaDropout(_DropoutNd):
+    r"""Applies Alpha Dropout over the input.
+
+    Alpha Dropout is a type of Dropout that maintains the self-normalizing
+    property.
+    For an input with zero mean and unit standard deviation, the output of
+    Alpha Dropout maintains the original mean and standard deviation of the
+    input.
+    Alpha Dropout goes hand-in-hand with SELU activation function, which ensures
+    that the outputs have zero mean and unit standard deviation.
+
+    During training, it randomly masks some of the elements of the input
+    tensor with probability *p* using samples from a bernoulli distribution.
+    The elements to masked are randomized on every forward call, and scaled
+    and shifted to maintain zero mean and unit standard deviation.
+
+    During evaluation the module simply computes an identity function.
+
+    More details can be found in the paper `Self-Normalizing Neural Networks`_ .
+
+    Args:
+      prob: float. probability of an element to be kept.
+
+    Shape:
+        - Input: :math:`(*)`. Input can be of any shape
+        - Output: :math:`(*)`. Output is of the same shape as input
+
+    Examples::
+
+        >>> m = AlphaDropout(p=0.2)
+        >>> x = random.randn(20, 16)
+        >>> output = m(x)
+
+    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
+    """
+    __module__ = 'brainstate.nn'
+
+    def forward(self, x):
+        return F.alpha_dropout(x, self.p, self.training)
+
+
+class FeatureAlphaDropout(_DropoutNd):
+    r"""Randomly masks out entire channels (a channel is a feature map,
+    e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input
+    is a tensor :math:`\text{input}[i, j]`) of the input tensor). Instead of
+    setting activations to zero, as in regular Dropout, the activations are set
+    to the negative saturation value of the SELU activation function. More details
+    can be found in the paper `Self-Normalizing Neural Networks`_ .
+
+    Each element will be masked independently for each sample on every forward
+    call with probability :attr:`p` using samples from a Bernoulli distribution.
+    The elements to be masked are randomized on every forward call, and scaled
+    and shifted to maintain zero mean and unit variance.
+
+    Usually the input comes from :class:`nn.AlphaDropout` modules.
+
+    As described in the paper
+    `Efficient Object Localization Using Convolutional Networks`_ ,
+    if adjacent pixels within feature maps are strongly correlated
+    (as is normally the case in early convolution layers) then i.i.d. dropout
+    will not regularize the activations and will otherwise just result
+    in an effective learning rate decrease.
+
+    In this case, :func:`nn.AlphaDropout` will help promote independence between
+    feature maps and should be used instead.
+
+    Args:
+      prob: float. probability of an element to be kept.
+
+    Shape:
+        - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
+        - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
+
+    Examples::
+
+        >>> m = FeatureAlphaDropout(p=0.2)
+        >>> x = random.randn(20, 16, 4, 32, 32)
+        >>> output = m(x)
+
+    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
+    .. _Efficient Object Localization Using Convolutional Networks:
+       https://arxiv.org/abs/1411.4280
+    """
+    __module__ = 'brainstate.nn'
+
+    def forward(self, x):
+        return F.feature_alpha_dropout(x, self.p, self.training)
+
+
+class DropoutFixed(ElementWiseBlock):
+    """
+    A dropout layer with the fixed dropout mask along the time axis once after initialized.
+
+    In training, to compensate for the fraction of input values dropped (`rate`),
+    all surviving values are multiplied by `1 / (1 - rate)`.
+
+    This layer is active only during training (``mode=brainstate.mixin.Training``). In other
+    circumstances it is a no-op.
+
+    .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
+           neural networks from overfitting." The journal of machine learning
+           research 15.1 (2014): 1929-1958.
+
+    .. admonition:: Tip
+        :class: tip
+
+        This kind of Dropout is firstly described in `Enabling Spike-based Backpropagation for Training Deep Neural
+        Network Architectures <https://arxiv.org/abs/1903.06379>`_:
+
+        There is a subtle difference in the way dropout is applied in SNNs compared to ANNs. In ANNs, each epoch of
+        training has several iterations of mini-batches. In each iteration, randomly selected units (with dropout ratio of :math:`p`)
+        are disconnected from the network while weighting by its posterior probability (:math:`1-p`). However, in SNNs, each
+        iteration has more than one forward propagation depending on the time length of the spike train. We back-propagate
+        the output error and modify the network parameters only at the last time step. For dropout to be effective in
+        our training method, it has to be ensured that the set of connected units within an iteration of mini-batch
+        data is not changed, such that the neural network is constituted by the same random subset of units during
+        each forward propagation within a single iteration. On the other hand, if the units are randomly connected at
+        each time-step, the effect of dropout will be averaged out over the entire forward propagation time within an
+        iteration. Then, the dropout effect would fade-out once the output error is propagated backward and the parameters
+        are updated at the last time step. Therefore, we need to keep the set of randomly connected units for the entire
+        time window within an iteration.
+
+    Args:
+      in_size: The size of the input tensor.
+      prob: Probability to keep element of the tensor.
+      mode: Mode. The computation mode of the object.
+      name: str. The name of the dynamic system.
+    """
+    __module__ = 'brainstate.nn'
+
+    def __init__(
+        self,
+        in_size: Size,
+        prob: float = 0.5,
+        name: Optional[str] = None
+    ) -> None:
+        super().__init__(name=name)
+        assert 0. <= prob < 1., f"Dropout probability must be in the range [0, 1). But got {prob}."
+        self.prob = prob
+        self.in_size = in_size
+        self.out_size = in_size
+
+    def init_state(self, batch_size=None, **kwargs):
+        self.mask = ShortTermState(init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size))
+
+    def update(self, x):
+        dtype = u.math.get_dtype(x)
+        fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+        if fit_phase:
+            if self.mask.value.shape != x.shape:
+                raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
+                                 f"Please call `init_state()` method first.")
+            return jnp.where(self.mask.value,
+                             jnp.asarray(x / self.prob, dtype=dtype),
+                             jnp.asarray(0., dtype=dtype))
+        else:
+            return x
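
For orientation, here is a minimal usage sketch (not part of the diff) that follows the same patterns as the test file below: the `fit` flag in `brainstate.environ` switches the layers into training mode, and `DropoutFixed` samples its keep-mask once in `init_state()` and then reuses it on every `update()` call within a time window.

```python
# Illustrative only; mirrors the usage in _dropout_test.py below.
import numpy as np
import brainstate as bst

x = np.random.randn(2, 2, 3)

# Ordinary dropout: a fresh mask is drawn on every call while fit=True.
drop = bst.nn.Dropout(0.5)

# Fixed-mask dropout for spiking networks: the mask is sampled once in
# init_state() and shared by every update() within the time window.
fixed = bst.nn.DropoutFixed(in_size=(2, 3), prob=0.5)
fixed.init_state(batch_size=2)

with bst.environ.context(fit=True):
    y = drop(x)               # independently re-masked each call
    y_t0 = fixed.update(x)    # time step 0
    y_t1 = fixed.update(x)    # time step 1: identical zero pattern
    assert np.array_equal(np.asarray(y_t0 == 0), np.asarray(y_t1 == 0))
```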
brainstate/nn/_elementwise/_dropout_test.py
@@ -0,0 +1,100 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+import unittest
+
+import numpy as np
+
+import brainstate as bst
+
+
+class TestDropout(unittest.TestCase):
+
+    def test_dropout(self):
+        # Create a Dropout layer with a dropout rate of 0.5
+        dropout_layer = bst.nn.Dropout(0.5)
+
+        # Input data
+        input_data = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
+
+        with bst.environ.context(fit=True):
+            # Apply dropout
+            output_data = dropout_layer(input_data)
+
+        # Check that the output has the same shape as the input
+        self.assertEqual(input_data.shape, output_data.shape)
+
+        # Check that some elements are zeroed out
+        self.assertTrue(np.any(output_data == 0))
+
+        # Check that the non-zero elements are scaled by 1/(1-rate)
+        scale_factor = 1 / (1 - 0.5)
+        non_zero_elements = output_data[output_data != 0]
+        expected_non_zero_elements = input_data[output_data != 0] * scale_factor
+        np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements)
+
+    def test_DropoutFixed(self):
+        dropout_layer = bst.nn.DropoutFixed(in_size=(2, 3), prob=0.5)
+        dropout_layer.init_state(batch_size=2)
+        input_data = np.random.randn(2, 2, 3)
+        with bst.environ.context(fit=True):
+            output_data = dropout_layer.update(input_data)
+        self.assertEqual(input_data.shape, output_data.shape)
+        self.assertTrue(np.any(output_data == 0))
+        scale_factor = 1 / (1 - 0.5)
+        non_zero_elements = output_data[output_data != 0]
+        expected_non_zero_elements = input_data[output_data != 0] * scale_factor
+        np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements)
+
+    def test_Dropout1d(self):
+        dropout_layer = bst.nn.Dropout1d(prob=0.5)
+        input_data = np.random.randn(2, 3, 4)
+        with bst.environ.context(fit=True):
+            output_data = dropout_layer(input_data)
+        self.assertEqual(input_data.shape, output_data.shape)
+        self.assertTrue(np.any(output_data == 0))
+        scale_factor = 1 / (1 - 0.5)
+        non_zero_elements = output_data[output_data != 0]
+        expected_non_zero_elements = input_data[output_data != 0] * scale_factor
+        np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
+
+    def test_Dropout2d(self):
+        dropout_layer = bst.nn.Dropout2d(prob=0.5)
+        input_data = np.random.randn(2, 3, 4, 5)
+        with bst.environ.context(fit=True):
+            output_data = dropout_layer(input_data)
+        self.assertEqual(input_data.shape, output_data.shape)
+        self.assertTrue(np.any(output_data == 0))
+        scale_factor = 1 / (1 - 0.5)
+        non_zero_elements = output_data[output_data != 0]
+        expected_non_zero_elements = input_data[output_data != 0] * scale_factor
+        np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
+
+    def test_Dropout3d(self):
+        dropout_layer = bst.nn.Dropout3d(prob=0.5)
+        input_data = np.random.randn(2, 3, 4, 5, 6)
+        with bst.environ.context(fit=True):
+            output_data = dropout_layer(input_data)
+        self.assertEqual(input_data.shape, output_data.shape)
+        self.assertTrue(np.any(output_data == 0))
+        scale_factor = 1 / (1 - 0.5)
+        non_zero_elements = output_data[output_data != 0]
+        expected_non_zero_elements = input_data[output_data != 0] * scale_factor
+        np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
+
+
+if __name__ == '__main__':
+    unittest.main()
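
As a side note on the channel-wise variants above (`Dropout1d`/`Dropout2d`/`Dropout3d`), the operation they build on is a per-channel Bernoulli keep-mask that is broadcast over the spatial axes, with survivors rescaled by `1 / prob`. The following is a minimal standalone sketch of that idea, not code from the package, assuming a channels-last layout as in the default `channel_axis=-1`.

```python
# Standalone illustration of channel-wise dropout masking: one Bernoulli draw
# per (sample, channel), broadcast over the spatial axis, survivors rescaled.
import jax
import jax.numpy as jnp

prob = 0.5                                   # probability of keeping a channel
x = jnp.ones((20, 32, 16))                   # (batch, length, channels)
mask_shape = (x.shape[0], 1, x.shape[-1])    # one draw per (sample, channel)
keep = jax.random.bernoulli(jax.random.PRNGKey(0), prob, mask_shape)
y = jnp.where(jnp.broadcast_to(keep, x.shape), x / prob, 0.)  # dropped channels are zeroed along the whole length axis
```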