brainstate 0.1.0.post20241125__py2.py3-none-any.whl → 0.1.0.post20241209__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +15 -1
- brainstate/augment/_mapping_test.py +84 -0
- brainstate/compile/_loop_collect_return.py +5 -1
- brainstate/compile/_make_jaxpr.py +30 -25
- brainstate/compile/_progress_bar.py +30 -12
- brainstate/functional/_activations.py +4 -12
- brainstate/graph/_graph_operation.py +4 -1
- brainstate/nn/_collective_ops.py +18 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -1
- brainstate/nn/_elementwise/_dropout.py +31 -22
- brainstate/nn/_interaction/_normalizations.py +598 -66
- brainstate/util/_tracers.py +0 -7
- {brainstate-0.1.0.post20241125.dist-info → brainstate-0.1.0.post20241209.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20241125.dist-info → brainstate-0.1.0.post20241209.dist-info}/RECORD +17 -19
- {brainstate-0.1.0.post20241125.dist-info → brainstate-0.1.0.post20241209.dist-info}/top_level.txt +0 -1
- benchmark/COBA_2005.py +0 -125
- benchmark/CUBA_2005.py +0 -149
- {brainstate-0.1.0.post20241125.dist-info → brainstate-0.1.0.post20241209.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20241125.dist-info → brainstate-0.1.0.post20241209.dist-info}/WHEEL +0 -0
brainstate/_state.py
CHANGED
@@ -99,6 +99,9 @@ def catch_new_states(tag: str = None) -> List:
 
 
 class Catcher:
+    """
+    The catcher to catch the new states.
+    """
     def __init__(self, tag: str):
         self.tag = tag
         self.state_ids = set()
@@ -231,6 +234,7 @@ class State(Generic[A], PrettyRepr):
         # avoid using self._setattr to avoid the check
         vars(self).update(metadata)
 
+        # record the state initialization
         record_state_init(self)
 
         if not TYPE_CHECKING:
@@ -290,7 +294,6 @@ class State(Generic[A], PrettyRepr):
            v: The value.
         """
         self.write_value(v)
-        self._been_writen = True
 
     def write_value(self, v) -> None:
         # value checking
@@ -301,6 +304,8 @@ class State(Generic[A], PrettyRepr):
         record_state_value_write(self)
         # set the value
         self._value = v
+        # set flag
+        self._been_writen = True
 
     def restore_value(self, v) -> None:
         """
@@ -511,6 +516,15 @@ class LongTermState(State):
     __module__ = 'brainstate'
 
 
+class BatchState(LongTermState):
+    """
+    The batch state, which is used to store the batch data in the program.
+    """
+
+    __module__ = 'brainstate'
+
+
+
 class HiddenState(ShortTermState):
     """
     The hidden state, which is used to store the hidden data in a dynamic model.
brainstate/augment/_mapping_test.py
CHANGED
@@ -204,7 +204,91 @@ class TestVmap(unittest.TestCase):
         self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4)))
         self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
 
+        write_state_ids = [id(st) for st in trace.get_write_states()]
+        read_state_ids = [id(st) for st in trace.get_read_states()]
+
+        assert id(foo.a) in read_state_ids
+        assert id(foo.b) in write_state_ids
+
+        print(trace.get_write_states())
+        print(trace.get_read_states())
+
+
+
+    def test_vmap_jit(self):
+        class Foo(bst.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = bst.ParamState(jnp.arange(4))
+                self.b = bst.ShortTermState(jnp.arange(4))
+
+            def __call__(self):
+                self.b.value = self.a.value * self.b.value
+
+        @bst.augment.vmap
+        def mul(foo):
+            foo()
+
+        @bst.compile.jit
+        def mul_jit(inp):
+            mul(foo)
+            foo.a.value += inp
+
+        foo = Foo()
+        with bst.StateTraceStack() as trace:
+            mul_jit(1.)
+
+        print(foo.a.value)
+        print(foo.b.value)
+        self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
+        self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
+
+        write_state_ids = [id(st) for st in trace.get_write_states()]
+        read_state_ids = [id(st) for st in trace.get_read_states()]
+
+        assert id(foo.a) in write_state_ids
+        assert id(foo.b) in write_state_ids
+
         print(trace.get_write_states())
         print(trace.get_read_states())
 
 
+    def test_vmap_grad(self):
+        class Foo(bst.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = bst.ParamState(jnp.arange(4.))
+                self.b = bst.ShortTermState(jnp.arange(4.))
+
+            def __call__(self):
+                self.b.value = self.a.value * self.b.value
+
+        @bst.augment.vmap
+        def mul(foo):
+            foo()
+
+        def loss():
+            mul(foo)
+            return jnp.sum(foo.b.value)
+
+        foo = Foo()
+        with bst.StateTraceStack() as trace:
+            grads, loss = bst.augment.grad(loss, foo.states(bst.ParamState), return_value=True)()
+            print(grads)
+            print(loss)
+
+        # print(foo.a.value)
+        # print(foo.b.value)
+        # self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
+        # self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
+        #
+        # write_state_ids = [id(st) for st in trace.get_write_states()]
+        # read_state_ids = [id(st) for st in trace.get_read_states()]
+        #
+        # assert id(foo.a) in write_state_ids
+        # assert id(foo.b) in write_state_ids
+        #
+        # print(trace.get_write_states())
+        # print(trace.get_read_states())
+
+
brainstate/compile/_loop_collect_return.py
CHANGED
@@ -211,7 +211,11 @@ def scan(
 
     # scan
     init = (all_writen_state_vals, init)
-    (all_writen_state_vals, carry), ys = jax.lax.scan(wrapped_f,
+    (all_writen_state_vals, carry), ys = jax.lax.scan(wrapped_f,
+                                                      init,
+                                                      xs,
+                                                      length=length,
+                                                      reverse=reverse,
                                                       unroll=unroll)
    # assign the written state values and restore the read state values
    write_back_state_values(state_trace, all_read_state_vals, all_writen_state_vals)
brainstate/compile/_make_jaxpr.py
CHANGED
@@ -72,7 +72,6 @@ from jax.util import wraps
 from brainstate._state import State, StateTraceStack
 from brainstate._utils import set_module_as
 from brainstate.typing import PyTree
-from brainstate.util._tracers import new_jax_trace
 
 AxisName = Hashable
 
@@ -112,28 +111,27 @@ def _new_arg_fn(frame, trace, aval):
     return tracer
 
 
-def _init_state_trace_stack() -> StateTraceStack:
-    # Should be within the calling of ``jax.make_jaxpr()``
-    frame, trace = new_jax_trace()
+def _new_jax_trace():
+    main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
+    frame = main.jaxpr_stack[-1]
+    trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
+    return frame, trace
+
+
+def _init_state_trace_stack() -> StateTraceStack:
     state_trace: StateTraceStack = StateTraceStack()
-    # Set the function to transform the new argument to a tracer
-    state_trace.set_new_arg(functools.partial(_new_arg_fn, frame, trace))
-    return state_trace
 
+    if jax.__version_info__ < (0, 4, 36):
+        # Should be within the calling of ``jax.make_jaxpr()``
+        frame, trace = _new_jax_trace()
+        # Set the function to transform the new argument to a tracer
+        state_trace.set_new_arg(functools.partial(_new_arg_fn, frame, trace))
+        return state_trace
 
-
-
-
-
-    # Args:
-    #   x: The input.
-    #
-    # Returns:
-    #   The abstractified input.
-    # """
-    # if isinstance(x, pe.DynamicJaxprTracer):
-    #     return jax.core.ShapedArray(x.aval.shape, x.aval.dtype, weak_type=x.aval.weak_type)
-    # return shaped_abstractify(x)
+    else:
+        trace = jax.core.trace_ctx.trace
+        state_trace.set_new_arg(trace.new_arg)
+        return state_trace
 
 
 class StatefulFunction(object):
@@ -383,12 +381,15 @@ class StatefulFunction(object):
           A tuple of the states that are read and written by the function and the output of the function.
         """
         # state trace
-        state_trace =
+        state_trace = _init_state_trace_stack()
         self._cached_state_trace[cache_key] = state_trace
         with state_trace:
             out = self.fun(*args, **kwargs)
-            state_values =
-
+            state_values = (
+                state_trace.get_write_state_values(True)
+                if return_only_write else
+                state_trace.get_state_values()
+            )
         state_trace.recovery_original_values()
 
         # State instance as functional returns is not allowed.
@@ -419,17 +420,21 @@ class StatefulFunction(object):
         try:
             # jaxpr
             jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
-                functools.partial(
+                functools.partial(
+                    self._wrapped_fun_to_eval,
+                    cache_key,
+                    return_only_write=return_only_write
+                ),
                 static_argnums=self.static_argnums,
                 axis_env=self.axis_env,
                 return_shape=True,
                 abstracted_axes=self.abstracted_axes
             )(*args, **kwargs)
-
             # returns
             self._cached_jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes))
             self._cached_out_shapes[cache_key] = (out_shapes, state_shapes)
             self._cached_jaxpr[cache_key] = jaxpr
+
         except Exception as e:
             try:
                 self._cached_state_trace.pop(cache_key)
brainstate/compile/_progress_bar.py
CHANGED
@@ -93,19 +93,37 @@ class ProgressBarRunner(object):
         self.tqdm_bars[0].update(self.remainder)
         self.tqdm_bars[0].close()
 
+    def _tqdm(self, is_init, is_print, is_final):
+        if is_init:
+            self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
+            self.tqdm_bars[0].set_description(self.message, refresh=False)
+        if is_print:
+            self.tqdm_bars[0].update(self.print_freq)
+        if is_final:
+            if self.remainder > 0:
+                self.tqdm_bars[0].update(self.remainder)
+            self.tqdm_bars[0].close()
+
     def __call__(self, iter_num, *args, **kwargs):
-        _ = jax.lax.cond(
+        jax.debug.callback(
+            self._tqdm,
             iter_num == 0,
-            lambda: jax.debug.callback(self._define_tqdm),
-            lambda: None,
-        )
-        _ = jax.lax.cond(
             (iter_num + 1) % self.print_freq == 0,
-            lambda: jax.debug.callback(self._update_tqdm),
-            lambda: None,
-        )
-        _ = jax.lax.cond(
-            iter_num == self.n - 1,
-            lambda: jax.debug.callback(self._close_tqdm),
-            lambda: None,
+            iter_num == self.n - 1
         )
+
+        # _ = jax.lax.cond(
+        #     iter_num == 0,
+        #     lambda: jax.debug.callback(self._define_tqdm, ordered=True),
+        #     lambda: None,
+        # )
+        # _ = jax.lax.cond(
+        #     (iter_num + 1) % self.print_freq == 0,
+        #     lambda: jax.debug.callback(self._update_tqdm, ordered=True),
+        #     lambda: None,
+        # )
+        # _ = jax.lax.cond(
+        #     iter_num == self.n - 1,
+        #     lambda: jax.debug.callback(self._close_tqdm, ordered=True),
+        #     lambda: None,
+        # )
brainstate/functional/_activations.py
CHANGED
@@ -588,8 +588,7 @@ def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
 
 def log_softmax(x: ArrayLike,
                 axis: int | tuple[int, ...] | None = -1,
-                where: ArrayLike | None = None,
-                initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
+                where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
     r"""Log-Softmax function.
 
     Computes the logarithm of the :code:`softmax` function, which rescales
@@ -604,8 +603,6 @@ def log_softmax(x: ArrayLike,
       axis: the axis or axes along which the :code:`log_softmax` should be
         computed. Either an integer or a tuple of integers.
       where: Elements to include in the :code:`log_softmax`.
-      initial: The minimum value used to shift the input array. Must be present
-        when :code:`where` is not None.
 
     Returns:
       An array.
@@ -613,15 +610,12 @@
     See also:
       :func:`softmax`
     """
-
-    initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
-    return _keep_unit(jax.nn.log_softmax, x, axis=axis, where=where, initial=initial)
+    return _keep_unit(jax.nn.log_softmax, x, axis=axis, where=where)
 
 
 def softmax(x: ArrayLike,
             axis: int | tuple[int, ...] | None = -1,
-            where: ArrayLike | None = None,
-            initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
+            where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
     r"""Softmax function.
 
     Computes the function which rescales elements to the range :math:`[0, 1]`
@@ -645,9 +639,7 @@ def softmax(x: ArrayLike,
     See also:
       :func:`log_softmax`
     """
-
-    initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
-    return _keep_unit(jax.nn.softmax, x, axis=axis, where=where, initial=initial)
+    return _keep_unit(jax.nn.softmax, x, axis=axis, where=where)
 
 
 def standardize(x: ArrayLike,
brainstate/graph/_graph_operation.py
CHANGED
@@ -608,7 +608,10 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
                 if isinstance(value, TreefyState):
                     variable.update_from_ref(value)
                 elif isinstance(value, State):
-
+                    if value._been_writen:
+                        variable.write_value(value.value)
+                    else:
+                        variable.restore_value(value.value)
                 else:
                     raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
             else:  # if it doesn't, create a new variable
brainstate/nn/_collective_ops.py
CHANGED
@@ -20,8 +20,10 @@ from typing import Dict, Callable, TypeVar
 
 import jax
 
+from brainstate._state import catch_new_states
 from brainstate._utils import set_module_as
 from brainstate.graph import nodes
+from brainstate.util._filter import Filter
 from ._module import Module
 
 # the maximum order
@@ -74,16 +76,29 @@ def call_order(level: int = 0, check_order_boundary: bool = True):
 
 
 @set_module_as('brainstate.nn')
-def init_all_states(target: T, *args, exclude=None, **kwargs) -> T:
+def init_all_states(
+    target: T,
+    *args,
+    exclude: Filter = None,
+    **kwargs
+) -> T:
     """
     Collectively initialize states of all children nodes in the given target.
 
     Args:
       target: The target Module.
+      exclude: The filter to exclude some nodes.
+      tag: The tag for the new states.
+      args: The positional arguments for the initialization, which will be passed to the `init_state` method
+        of each node.
+      kwargs: The keyword arguments for the initialization, which will be passed to the `init_state` method
+        of each node.
 
     Returns:
       The target Module.
     """
+
+    # node that has `call_order` decorated
     nodes_with_order = []
 
     nodes_ = nodes(target).filter(Module)
@@ -97,7 +112,7 @@ def init_all_states(target: T, *args, exclude=None, **kwargs) -> T:
         else:
             node.init_state(*args, **kwargs)
 
-    # reset the node's states
+    # reset the node's states with `call_order`
     for node in sorted(nodes_with_order, key=lambda x: x.init_state.call_order):
         node.init_state(*args, **kwargs)
 
@@ -115,6 +130,7 @@ def reset_all_states(target: Module, *args, **kwargs) -> Module:
     Returns:
       The target Module.
     """
+
     nodes_with_order = []
 
     # reset node whose `init_state` has no `call_order`
brainstate/nn/_dyn_impl/_dynamics_synapse.py
CHANGED
@@ -112,7 +112,7 @@ class STP(Synapse):
         self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)
 
     def update(self, pre_spike):
-        u = exp_euler_step(lambda u:
+        u = exp_euler_step(lambda u: - u / self.tau_f, self.u.value)
         x = exp_euler_step(lambda x: (1 - x) / self.tau_d, self.x.value)
 
         # --- original code:
brainstate/nn/_elementwise/_dropout.py
CHANGED
@@ -17,7 +17,7 @@
 from __future__ import annotations
 
 from functools import partial
-from typing import Optional
+from typing import Optional, Sequence
 
 import brainunit as u
 import jax.numpy as jnp
@@ -29,7 +29,6 @@ from brainstate.typing import Size
 
 __all__ = [
     'DropoutFixed', 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
-    'AlphaDropout', 'FeatureAlphaDropout',
 ]
 
 
@@ -47,9 +46,9 @@ class Dropout(ElementWiseBlock):
            research 15.1 (2014): 1929-1958.
 
     Args:
-
-
-
+        prob: Probability to keep element of the tensor.
+        broadcast_dims: dimensions that will share the same dropout mask.
+        name: str. The name of the dynamic system.
 
     """
     __module__ = 'brainstate.nn'
@@ -57,20 +56,28 @@ class Dropout(ElementWiseBlock):
     def __init__(
         self,
         prob: float = 0.5,
+        broadcast_dims: Sequence[int] = (),
         name: Optional[str] = None
     ) -> None:
         super().__init__(name=name)
         assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
         self.prob = prob
+        self.broadcast_dims = broadcast_dims
 
     def __call__(self, x):
         dtype = u.math.get_dtype(x)
         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
         if fit_phase and self.prob < 1.:
-
-
-
-
+            broadcast_shape = list(x.shape)
+            for dim in self.broadcast_dims:
+                broadcast_shape[dim] = 1
+            keep_mask = random.bernoulli(self.prob, broadcast_shape)
+            keep_mask = jnp.broadcast_to(keep_mask, x.shape)
+            return jnp.where(
+                keep_mask,
+                jnp.asarray(x / self.prob, dtype=dtype),
+                jnp.asarray(0., dtype=dtype)
+            )
         else:
             return x
 
@@ -88,12 +95,11 @@ class _DropoutNd(ElementWiseBlock):
         name: Optional[str] = None
     ) -> None:
         super().__init__(name=name)
-        assert 0. <= prob
+        assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
         self.prob = prob
         self.channel_axis = channel_axis
 
     def __call__(self, x):
-
         # check input shape
         inp_dim = u.math.ndim(x)
         if inp_dim not in (self.minimal_dim, self.minimal_dim + 1):
@@ -112,12 +118,15 @@ class _DropoutNd(ElementWiseBlock):
         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
 
         # generate mask
-        if fit_phase
+        if fit_phase and self.prob < 1.:
             dtype = u.math.get_dtype(x)
-            keep_mask =
-
-
-
+            keep_mask = random.bernoulli(self.prob, mask_shape)
+            keep_mask = jnp.broadcast_to(keep_mask, x.shape)
+            return jnp.where(
+                keep_mask,
+                jnp.asarray(x / self.prob, dtype=dtype),
+                jnp.asarray(0., dtype=dtype)
+            )
         else:
             return x
 
@@ -296,8 +305,8 @@ class AlphaDropout(_DropoutNd):
     """
     __module__ = 'brainstate.nn'
 
-    def
-
+    def update(self, *args, **kwargs):
+        raise NotImplementedError("AlphaDropout is not supported in the current version.")
 
 
 class FeatureAlphaDropout(_DropoutNd):
@@ -344,8 +353,8 @@ class FeatureAlphaDropout(_DropoutNd):
     """
     __module__ = 'brainstate.nn'
 
-    def
-
+    def update(self, *args, **kwargs):
+        raise NotImplementedError("FeatureAlphaDropout is not supported in the current version.")
 
 
 class DropoutFixed(ElementWiseBlock):
@@ -396,7 +405,7 @@ class DropoutFixed(ElementWiseBlock):
         name: Optional[str] = None
     ) -> None:
         super().__init__(name=name)
-        assert 0. <= prob
+        assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
         self.prob = prob
         self.in_size = in_size
         self.out_size = in_size
@@ -407,7 +416,7 @@ class DropoutFixed(ElementWiseBlock):
     def update(self, x):
         dtype = u.math.get_dtype(x)
         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
-        if fit_phase
+        if fit_phase and self.prob < 1.:
             if self.mask.value.shape != x.shape:
                 raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
                                  f"Please call `init_state()` method first.")