brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -58
- brainstate/_compatible_import.py +340 -148
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +45 -55
- brainstate/_state.py +1652 -1605
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -563
- brainstate/environ_test.py +1223 -62
- brainstate/graph/__init__.py +22 -29
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1433 -365
- brainstate/mixin_test.py +1017 -77
- brainstate/nn/__init__.py +137 -135
- brainstate/nn/_activations.py +1100 -808
- brainstate/nn/_activations_test.py +354 -331
- brainstate/nn/_collective_ops.py +633 -514
- brainstate/nn/_collective_ops_test.py +774 -43
- brainstate/nn/_common.py +226 -178
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +2010 -501
- brainstate/nn/_conv_test.py +849 -238
- brainstate/nn/_delay.py +575 -588
- brainstate/nn/_delay_test.py +243 -238
- brainstate/nn/_dropout.py +618 -426
- brainstate/nn/_dropout_test.py +477 -100
- brainstate/nn/_dynamics.py +1267 -1343
- brainstate/nn/_dynamics_test.py +67 -78
- brainstate/nn/_elementwise.py +1298 -1119
- brainstate/nn/_elementwise_test.py +830 -169
- brainstate/nn/_embedding.py +408 -58
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
- brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
- brainstate/nn/_exp_euler.py +254 -92
- brainstate/nn/_exp_euler_test.py +377 -35
- brainstate/nn/_linear.py +744 -424
- brainstate/nn/_linear_test.py +475 -107
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +384 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -975
- brainstate/nn/_normalizations_test.py +699 -73
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +2239 -1177
- brainstate/nn/_poolings_test.py +953 -217
- brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +216 -89
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +809 -553
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
- brainstate/random/__init__.py +270 -24
- brainstate/random/_rand_funs.py +3938 -3616
- brainstate/random/_rand_funs_test.py +640 -567
- brainstate/random/_rand_seed.py +675 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1409
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
- brainstate/{augment → transform}/_autograd.py +1025 -778
- brainstate/{augment → transform}/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +220 -220
- brainstate/{compile → transform}/_error_if.py +94 -92
- brainstate/{compile → transform}/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +38 -38
- brainstate/{compile → transform}/_jit.py +399 -346
- brainstate/{compile → transform}/_jit_test.py +143 -143
- brainstate/{compile → transform}/_loop_collect_return.py +675 -536
- brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
- brainstate/{compile → transform}/_loop_no_collection.py +283 -184
- brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +255 -202
- brainstate/{augment → transform}/_random.py +171 -151
- brainstate/{compile → transform}/_unvmap.py +256 -159
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +837 -304
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +27 -50
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +945 -469
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +910 -523
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/_dropout.py
CHANGED
@@ -1,426 +1,618 @@
- # Copyright 2024
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
-
- from functools import partial
- from typing import Optional, Sequence
-
- import brainunit as u
- import jax.numpy as jnp
-
- from brainstate import random, environ
- from brainstate._state import ShortTermState
- from brainstate.typing import Size
- from .
[… removed lines 27-426 of the old file were not preserved in the extracted diff view …]
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ from functools import partial
+ from typing import Optional, Sequence
+
+ import brainunit as u
+ import jax.numpy as jnp
+
+ from brainstate import random, environ
+ from brainstate._state import ShortTermState
+ from brainstate.typing import Size
+ from . import init as init
+ from ._module import ElementWiseBlock
+
+ __all__ = [
+     'Dropout',
+     'Dropout1d',
+     'Dropout2d',
+     'Dropout3d',
+     'AlphaDropout',
+     'FeatureAlphaDropout',
+     'DropoutFixed',
+ ]
+
+
+ class Dropout(ElementWiseBlock):
+     """A layer that stochastically ignores a subset of inputs each training step.
+
+     In training, to compensate for the fraction of input values dropped
+     (``1 - prob``), all surviving values are multiplied by ``1 / prob``.
+
+     This layer is active only during training (``mode=brainstate.mixin.Training``).
+     In other circumstances it is a no-op.
+
+     Parameters
+     ----------
+     prob : float
+         Probability of each element being kept. Default is 0.5.
+     broadcast_dims : Sequence[int]
+         Dimensions that will share the same dropout mask. Default is ().
+     name : str, optional
+         The name of the dynamic system.
+
+     References
+     ----------
+     .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
+            neural networks from overfitting." The journal of machine learning
+            research 15.1 (2014): 1929-1958.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> layer = brainstate.nn.Dropout(prob=0.8)
+         >>> x = brainstate.random.randn(10, 20)
+         >>> with brainstate.environ.context(fit=True):
+         ...     output = layer(x)
+         >>> output.shape
+         (10, 20)
+
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         prob: float = 0.5,
+         broadcast_dims: Sequence[int] = (),
+         name: Optional[str] = None
+     ) -> None:
+         super().__init__(name=name)
+         assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
+         self.prob = prob
+         self.broadcast_dims = broadcast_dims
+
+     def __call__(self, x):
+         dtype = u.math.get_dtype(x)
+         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+         if fit_phase and self.prob < 1.:
+             broadcast_shape = list(x.shape)
+             for dim in self.broadcast_dims:
+                 broadcast_shape[dim] = 1
+             keep_mask = random.bernoulli(self.prob, broadcast_shape)
+             keep_mask = u.math.broadcast_to(keep_mask, x.shape)
+             return u.math.where(
+                 keep_mask,
+                 u.math.asarray(x / self.prob, dtype=dtype),
+                 u.math.asarray(0., dtype=dtype)
+             )
+         else:
+             return x
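The division by ``self.prob`` above is the standard "inverted dropout" trick: each element survives with probability ``prob``, so rescaling survivors by ``1 / prob`` keeps the expected activation unchanged and no rescaling is needed at inference. A minimal NumPy sketch of that identity (illustrative only, not part of the package):

    # Inverted dropout is unbiased in expectation: each element survives
    # with probability `prob` and is rescaled by 1/prob.
    import numpy as np

    prob = 0.8                                 # keep probability
    rng = np.random.default_rng(0)
    x = np.ones(1_000_000)
    y = np.where(rng.random(x.shape) < prob, x / prob, 0.0)
    print(y.mean())                            # ≈ 1.0, matching x.mean()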
+
+
+ class _DropoutNd(ElementWiseBlock):
+     __module__ = 'brainstate.nn'
+     prob: float
+     channel_axis: int
+     minimal_dim: int
+
+     def __init__(
+         self,
+         prob: float = 0.5,
+         channel_axis: int = -1,
+         name: Optional[str] = None
+     ) -> None:
+         super().__init__(name=name)
+         assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
+         self.prob = prob
+         self.channel_axis = channel_axis
+
+     def __call__(self, x):
+         # check input shape
+         inp_dim = u.math.ndim(x)
+         if inp_dim not in (self.minimal_dim, self.minimal_dim + 1):
+             raise RuntimeError(f"dropout: Expected {self.minimal_dim}D or {self.minimal_dim + 1}D input, "
+                                f"but received a {inp_dim}D input. {self._get_msg(x)}")
+         is_not_batched = (inp_dim == self.minimal_dim)
+         if is_not_batched:
+             channel_axis = self.channel_axis if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
+             mask_shape = [(dim if i == channel_axis else 1) for i, dim in enumerate(x.shape)]
+         else:
+             channel_axis = (self.channel_axis + 1) if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
+             assert channel_axis != 0, f"Channel axis must not be 0. But got {self.channel_axis}."
+             mask_shape = [(dim if i in (channel_axis, 0) else 1) for i, dim in enumerate(x.shape)]
+
+         # get fit phase
+         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+
+         # generate mask
+         if fit_phase and self.prob < 1.:
+             dtype = u.math.get_dtype(x)
+             keep_mask = random.bernoulli(self.prob, mask_shape)
+             keep_mask = jnp.broadcast_to(keep_mask, x.shape)
+             return jnp.where(
+                 keep_mask,
+                 jnp.asarray(x / self.prob, dtype=dtype),
+                 jnp.asarray(0., dtype=dtype)
+             )
+         else:
+             return x
+
+     def _get_msg(self, x):
+         return ''
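The mask-shape logic above collapses every axis except the (optional) batch axis and the channel axis to size 1, so the Bernoulli mask broadcasts over the spatial axes and each feature map is kept or dropped as a whole. A small standalone sketch of the batched branch, assuming a hypothetical ``(N, L, C)`` input with ``channel_axis=-1``:

    # For a batched (20, 32, 16) input with channel_axis=-1, only the batch
    # and channel axes survive in the mask shape; the spatial axis broadcasts.
    x_shape = (20, 32, 16)
    channel_axis = -1
    channel_axis = channel_axis if channel_axis >= 0 else len(x_shape) + channel_axis
    mask_shape = [dim if i in (channel_axis, 0) else 1 for i, dim in enumerate(x_shape)]
    print(mask_shape)  # [20, 1, 16] -> one Bernoulli draw per (sample, channel)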
+
+
+ class Dropout1d(_DropoutNd):
+     r"""Randomly zero out entire channels (a channel is a 1D feature map).
+
+     Each channel will be zeroed out independently on every forward call with
+     probability ``1 - prob``, using samples from a Bernoulli distribution. The
+     channel is a 1D feature map, e.g., the :math:`j`-th channel of the
+     :math:`i`-th sample in the batched input is a 1D tensor
+     :math:`\text{input}[i, j]`.
+
+     Usually the input comes from :class:`Conv1d` modules.
+
+     As described in the paper [1]_, if adjacent pixels within feature maps are
+     strongly correlated (as is normally the case in early convolution layers)
+     then i.i.d. dropout will not regularize the activations and will otherwise
+     just result in an effective learning rate decrease.
+
+     In this case, :class:`Dropout1d` will help promote independence between
+     feature maps and should be used instead.
+
+     Parameters
+     ----------
+     prob : float
+         Probability of each element being kept. Default is 0.5.
+     channel_axis : int
+         The axis representing the channel dimension. Default is -1.
+     name : str, optional
+         The name of the dynamic system.
+
+     Notes
+     -----
+     Input shape: :math:`(N, C, L)` or :math:`(C, L)`.
+
+     Output shape: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).
+
+     References
+     ----------
+     .. [1] Tompson et al., "Efficient Object Localization Using Convolutional
+            Networks" https://arxiv.org/abs/1411.4280
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> m = brainstate.nn.Dropout1d(prob=0.8)
+         >>> x = brainstate.random.randn(20, 32, 16)
+         >>> with brainstate.environ.context(fit=True):
+         ...     output = m(x)
+         >>> output.shape
+         (20, 32, 16)
+
+     """
+     __module__ = 'brainstate.nn'
+     minimal_dim: int = 2
+
+     def _get_msg(self, x):
+         return ("Note that dropout1d exists to provide channel-wise dropout on inputs with 1 "
+                 "spatial dimension, a channel dimension, and an optional batch dimension "
+                 "(i.e. 2D or 3D inputs).")
+
+
+ class Dropout2d(_DropoutNd):
+     r"""Randomly zero out entire channels (a channel is a 2D feature map).
+
+     Each channel will be zeroed out independently on every forward call with
+     probability ``1 - prob``, using samples from a Bernoulli distribution. The
+     channel is a 2D feature map, e.g., the :math:`j`-th channel of the
+     :math:`i`-th sample in the batched input is a 2D tensor
+     :math:`\text{input}[i, j]`.
+
+     Usually the input comes from :class:`Conv2d` modules.
+
+     As described in the paper [1]_, if adjacent pixels within feature maps are
+     strongly correlated (as is normally the case in early convolution layers)
+     then i.i.d. dropout will not regularize the activations and will otherwise
+     just result in an effective learning rate decrease.
+
+     In this case, :class:`Dropout2d` will help promote independence between
+     feature maps and should be used instead.
+
+     Parameters
+     ----------
+     prob : float
+         Probability of each element being kept. Default is 0.5.
+     channel_axis : int
+         The axis representing the channel dimension. Default is -1.
+     name : str, optional
+         The name of the dynamic system.
+
+     Notes
+     -----
+     Input shape: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
+
+     Output shape: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input).
+
+     References
+     ----------
+     .. [1] Tompson et al., "Efficient Object Localization Using Convolutional
+            Networks" https://arxiv.org/abs/1411.4280
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> m = brainstate.nn.Dropout2d(prob=0.8)
+         >>> x = brainstate.random.randn(20, 32, 32, 16)
+         >>> with brainstate.environ.context(fit=True):
+         ...     output = m(x)
+         >>> output.shape
+         (20, 32, 32, 16)
+
+     """
+     __module__ = 'brainstate.nn'
+     minimal_dim: int = 3
+
+     def _get_msg(self, x):
+         return ("Note that dropout2d exists to provide channel-wise dropout on inputs with 2 "
+                 "spatial dimensions, a channel dimension, and an optional batch dimension "
+                 "(i.e. 3D or 4D inputs).")
+
+
+ class Dropout3d(_DropoutNd):
+     r"""Randomly zero out entire channels (a channel is a 3D feature map).
+
+     Each channel will be zeroed out independently on every forward call with
+     probability ``1 - prob``, using samples from a Bernoulli distribution. The
+     channel is a 3D feature map, e.g., the :math:`j`-th channel of the
+     :math:`i`-th sample in the batched input is a 3D tensor
+     :math:`\text{input}[i, j]`.
+
+     Usually the input comes from :class:`Conv3d` modules.
+
+     As described in the paper [1]_, if adjacent pixels within feature maps are
+     strongly correlated (as is normally the case in early convolution layers)
+     then i.i.d. dropout will not regularize the activations and will otherwise
+     just result in an effective learning rate decrease.
+
+     In this case, :class:`Dropout3d` will help promote independence between
+     feature maps and should be used instead.
+
+     Parameters
+     ----------
+     prob : float
+         Probability of each element being kept. Default is 0.5.
+     channel_axis : int
+         The axis representing the channel dimension. Default is -1.
+     name : str, optional
+         The name of the dynamic system.
+
+     Notes
+     -----
+     Input shape: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
+
+     Output shape: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
+
+     References
+     ----------
+     .. [1] Tompson et al., "Efficient Object Localization Using Convolutional
+            Networks" https://arxiv.org/abs/1411.4280
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> m = brainstate.nn.Dropout3d(prob=0.8)
+         >>> x = brainstate.random.randn(20, 16, 4, 32, 32)
+         >>> with brainstate.environ.context(fit=True):
+         ...     output = m(x)
+         >>> output.shape
+         (20, 16, 4, 32, 32)
+
+     """
+     __module__ = 'brainstate.nn'
+     minimal_dim: int = 4
+
+     def _get_msg(self, x):
+         return ("Note that dropout3d exists to provide channel-wise dropout on inputs with 3 "
+                 "spatial dimensions, a channel dimension, and an optional batch dimension "
+                 "(i.e. 4D or 5D inputs).")
+
+
+ class AlphaDropout(_DropoutNd):
+     r"""Applies Alpha Dropout over the input.
+
+     Alpha Dropout is a type of Dropout that maintains the self-normalizing
+     property. For an input with zero mean and unit standard deviation, the
+     output of Alpha Dropout maintains the original mean and standard deviation
+     of the input.
+
+     Alpha Dropout goes hand-in-hand with the SELU activation function, which
+     ensures that the outputs have zero mean and unit standard deviation.
+
+     During training, it randomly masks some of the elements of the input
+     tensor with probability ``1 - prob``, using samples from a Bernoulli
+     distribution. The elements to be masked are randomized on every forward
+     call, and scaled and shifted to maintain zero mean and unit standard
+     deviation.
+
+     During evaluation the module simply computes an identity function.
+
+     Parameters
+     ----------
+     prob : float
+         Probability of each element being kept. Default is 0.5.
+     name : str, optional
+         The name of the dynamic system.
+
+     Notes
+     -----
+     Input shape: :math:`(*)`. Input can be of any shape.
+
+     Output shape: :math:`(*)`. Output is of the same shape as input.
+
+     References
+     ----------
+     .. [1] Klambauer et al., "Self-Normalizing Neural Networks"
+            https://arxiv.org/abs/1706.02515
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> m = brainstate.nn.AlphaDropout(prob=0.8)
+         >>> x = brainstate.random.randn(20, 16)
+         >>> with brainstate.environ.context(fit=True):
+         ...     output = m(x)
+         >>> output.shape
+         (20, 16)
+
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         prob: float = 0.5,
+         name: Optional[str] = None
+     ) -> None:
+         super().__init__(name=name)
+         assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
+         self.prob = prob
+
+         # SELU parameters: the negative saturation value of SELU
+         alpha = -1.7580993408473766
+         self.alpha = alpha
+
+         # Affine transformation parameters to maintain mean and variance
+         self.a = ((1 - prob) * (1 + prob * alpha ** 2)) ** -0.5
+         self.b = -self.a * alpha * prob
+
+     def __call__(self, x):
+         dtype = u.math.get_dtype(x)
+         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+         if fit_phase and self.prob < 1.:
+             keep_mask = random.bernoulli(self.prob, x.shape)
+             return u.math.where(
+                 keep_mask,
+                 u.math.asarray(x, dtype=dtype),
+                 u.math.asarray(self.alpha, dtype=dtype)
+             ) * self.a + self.b
+         else:
+             return x
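For reference, these are the standard alpha-dropout affine parameters: if units are replaced by the SELU saturation value :math:`\alpha'` with drop rate :math:`p`, then :math:`a = ((1-p)(1+p\alpha'^2))^{-1/2}` and :math:`b = -a\alpha' p` restore zero mean and unit variance (note the constructor above passes ``prob`` into the :math:`p` slot of these expressions). A NumPy check of the formulas, with ``rate`` as the drop rate (an illustrative sketch, not part of the package):

    # Empirical check that the alpha-dropout affine transform restores
    # zero mean / unit variance; `rate` here is the drop probability.
    import numpy as np

    rate = 0.2
    alpha = -1.7580993408473766                # SELU negative saturation value
    a = ((1 - rate) * (1 + rate * alpha ** 2)) ** -0.5
    b = -a * alpha * rate

    rng = np.random.default_rng(0)
    x = rng.standard_normal(1_000_000)
    keep = rng.random(x.shape) >= rate         # keep with probability 1 - rate
    y = np.where(keep, x, alpha) * a + b
    print(y.mean(), y.std())                   # ≈ 0.0 and ≈ 1.0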
+
+
+ class FeatureAlphaDropout(ElementWiseBlock):
+     r"""Randomly masks out entire channels with Alpha Dropout properties.
+
+     Instead of setting activations to zero as in regular Dropout, the activations
+     are set to the negative saturation value of the SELU activation function to
+     maintain self-normalizing properties.
+
+     Each channel (e.g., the :math:`j`-th channel of the :math:`i`-th sample in
+     the batched input is a tensor :math:`\text{input}[i, j]`) will be masked
+     independently for each sample on every forward call with probability
+     ``1 - prob``, using samples from a Bernoulli distribution. The elements to
+     be masked are randomized on every forward call, and scaled and shifted to
+     maintain zero mean and unit variance.
+
+     Usually the input comes from convolutional layers with SELU activation.
+
+     As described in the paper [2]_, if adjacent pixels within feature maps are
+     strongly correlated (as is normally the case in early convolution layers)
+     then i.i.d. dropout will not regularize the activations and will otherwise
+     just result in an effective learning rate decrease.
+
+     In this case, :class:`FeatureAlphaDropout` will help promote independence between
+     feature maps and should be used instead.
+
+     Parameters
+     ----------
+     prob : float
+         Probability of each element being kept. Default is 0.5.
+     channel_axis : int
+         The axis representing the channel dimension. Default is -1.
+     name : str, optional
+         The name of the dynamic system.
+
+     Notes
+     -----
+     Input shape: :math:`(N, C, *)` where C is the channel dimension.
+
+     Output shape: Same shape as input.
+
+     References
+     ----------
+     .. [1] Klambauer et al., "Self-Normalizing Neural Networks"
+            https://arxiv.org/abs/1706.02515
+     .. [2] Tompson et al., "Efficient Object Localization Using Convolutional
+            Networks" https://arxiv.org/abs/1411.4280
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> m = brainstate.nn.FeatureAlphaDropout(prob=0.8)
+         >>> x = brainstate.random.randn(20, 16, 4, 32, 32)
+         >>> with brainstate.environ.context(fit=True):
+         ...     output = m(x)
+         >>> output.shape
+         (20, 16, 4, 32, 32)
+
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         prob: float = 0.5,
+         channel_axis: int = -1,
+         name: Optional[str] = None
+     ) -> None:
+         super().__init__(name=name)
+         assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
+         self.prob = prob
+         self.channel_axis = channel_axis
+
+         # SELU parameters: the negative saturation value of SELU
+         alpha = -1.7580993408473766
+         self.alpha = alpha
+
+         # Affine transformation parameters to maintain mean and variance
+         self.a = ((1 - prob) * (1 + prob * alpha ** 2)) ** -0.5
+         self.b = -self.a * alpha * prob
+
+     def __call__(self, x):
+         dtype = u.math.get_dtype(x)
+         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+         if fit_phase and self.prob < 1.:
+             # Create mask shape with 1s except for batch and channel dimensions
+             channel_axis = self.channel_axis if self.channel_axis >= 0 else (x.ndim + self.channel_axis)
+             mask_shape = [1] * x.ndim
+             mask_shape[0] = x.shape[0]                        # batch dimension
+             mask_shape[channel_axis] = x.shape[channel_axis]  # channel dimension
+
+             keep_mask = random.bernoulli(self.prob, mask_shape)
+             keep_mask = u.math.broadcast_to(keep_mask, x.shape)
+             return u.math.where(
+                 keep_mask,
+                 u.math.asarray(x, dtype=dtype),
+                 u.math.asarray(self.alpha, dtype=dtype)
+             ) * self.a + self.b
+         else:
+             return x
521
|
+
|
522
|
+
|
523
|
+
class DropoutFixed(ElementWiseBlock):
|
524
|
+
"""A dropout layer with a fixed dropout mask along the time axis.
|
525
|
+
|
526
|
+
In training, to compensate for the fraction of input values dropped,
|
527
|
+
all surviving values are multiplied by `1 / (1 - prob)`.
|
528
|
+
|
529
|
+
This layer is active only during training (``mode=brainstate.mixin.Training``). In other
|
530
|
+
circumstances it is a no-op.
|
531
|
+
|
532
|
+
This kind of Dropout is particularly useful for spiking neural networks (SNNs) where
|
533
|
+
the same dropout mask needs to be applied across multiple time steps within a single
|
534
|
+
mini-batch iteration.
|
535
|
+
|
536
|
+
Parameters
|
537
|
+
----------
|
538
|
+
in_size : tuple or int
|
539
|
+
The size of the input tensor.
|
540
|
+
prob : float
|
541
|
+
Probability to keep element of the tensor. Default is 0.5.
|
542
|
+
name : str, optional
|
543
|
+
The name of the dynamic system.
|
544
|
+
|
545
|
+
Notes
|
546
|
+
-----
|
547
|
+
As described in [2]_, there is a subtle difference in the way dropout is applied in
|
548
|
+
SNNs compared to ANNs. In ANNs, each epoch of training has several iterations of
|
549
|
+
mini-batches. In each iteration, randomly selected units (with dropout ratio of
|
550
|
+
:math:`p`) are disconnected from the network while weighting by its posterior
|
551
|
+
probability (:math:`1-p`).
|
552
|
+
|
553
|
+
However, in SNNs, each iteration has more than one forward propagation depending on
|
554
|
+
the time length of the spike train. We back-propagate the output error and modify
|
555
|
+
the network parameters only at the last time step. For dropout to be effective in
|
556
|
+
our training method, it has to be ensured that the set of connected units within an
|
557
|
+
iteration of mini-batch data is not changed, such that the neural network is
|
558
|
+
constituted by the same random subset of units during each forward propagation within
|
559
|
+
a single iteration.
|
560
|
+
|
561
|
+
On the other hand, if the units are randomly connected at each time-step, the effect
|
562
|
+
of dropout will be averaged out over the entire forward propagation time within an
|
563
|
+
iteration. Then, the dropout effect would fade-out once the output error is propagated
|
564
|
+
backward and the parameters are updated at the last time step. Therefore, we need to
|
565
|
+
keep the set of randomly connected units for the entire time window within an iteration.
|
566
|
+
|
567
|
+
References
|
568
|
+
----------
|
569
|
+
.. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
|
570
|
+
neural networks from overfitting." The journal of machine learning
|
571
|
+
research 15.1 (2014): 1929-1958.
|
572
|
+
.. [2] Lee et al., "Enabling Spike-based Backpropagation for Training Deep Neural
|
573
|
+
Network Architectures" https://arxiv.org/abs/1903.06379
|
574
|
+
|
575
|
+
Examples
|
576
|
+
--------
|
577
|
+
.. code-block:: python
|
578
|
+
|
579
|
+
>>> import brainstate
|
580
|
+
>>> layer = brainstate.nn.DropoutFixed(in_size=(20,), prob=0.8)
|
581
|
+
>>> layer.init_state(batch_size=10)
|
582
|
+
>>> x = brainstate.random.randn(10, 20)
|
583
|
+
>>> with brainstate.environ.context(fit=True):
|
584
|
+
... output = layer.update(x)
|
585
|
+
>>> output.shape
|
586
|
+
(10, 20)
|
587
|
+
|
588
|
+
"""
|
589
|
+
__module__ = 'brainstate.nn'
|
590
|
+
|
591
|
+
def __init__(
|
592
|
+
self,
|
593
|
+
in_size: Size,
|
594
|
+
prob: float = 0.5,
|
595
|
+
name: Optional[str] = None
|
596
|
+
) -> None:
|
597
|
+
super().__init__(name=name)
|
598
|
+
assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
|
599
|
+
self.prob = prob
|
600
|
+
self.in_size = in_size
|
601
|
+
self.out_size = in_size
|
602
|
+
|
603
|
+
def init_state(self, batch_size=None, **kwargs):
|
604
|
+
if self.prob < 1.:
|
605
|
+
self.mask = ShortTermState(init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size))
|
606
|
+
|
607
|
+
def update(self, x):
|
608
|
+
dtype = u.math.get_dtype(x)
|
609
|
+
fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
|
610
|
+
if fit_phase and self.prob < 1.:
|
611
|
+
if self.mask.value.shape != x.shape:
|
612
|
+
raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
|
613
|
+
f"Please call `init_state()` method first.")
|
614
|
+
return u.math.where(self.mask.value,
|
615
|
+
u.math.asarray(x / self.prob, dtype=dtype),
|
616
|
+
u.math.asarray(0., dtype=dtype) * u.get_unit(x))
|
617
|
+
else:
|
618
|
+
return x
|