brainstate 0.0.1__py2.py3-none-any.whl → 0.0.1.post20240612__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_module.py +43 -5
- brainstate/_state.py +17 -0
- brainstate/environ.py +2 -1
- brainstate/functional/__init__.py +3 -2
- brainstate/functional/_activations.py +1 -1
- brainstate/functional/_normalization.py +3 -0
- brainstate/functional/_others.py +49 -0
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_base.py +10 -7
- brainstate/nn/_dynamics.py +20 -0
- brainstate/nn/_embedding.py +66 -0
- brainstate/nn/_rate_rnns.py +17 -0
- brainstate/nn/_readout.py +6 -0
- brainstate/optim/_lr_scheduler_test.py +13 -0
- brainstate/transform/_jit.py +47 -21
- brainstate/transform/_make_jaxpr.py +165 -3
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240612.dist-info}/METADATA +8 -6
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240612.dist-info}/RECORD +21 -29
- brainstate/nn/functional/__init__.py +0 -25
- brainstate/nn/functional/_activations.py +0 -754
- brainstate/nn/functional/_normalization.py +0 -69
- brainstate/nn/functional/_spikes.py +0 -90
- brainstate/nn/init/__init__.py +0 -26
- brainstate/nn/init/_base.py +0 -36
- brainstate/nn/init/_generic.py +0 -175
- brainstate/nn/init/_random_inits.py +0 -489
- brainstate/nn/init/_regular_inits.py +0 -109
- brainstate/nn/surrogate.py +0 -1740
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240612.dist-info}/LICENSE +0 -0
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240612.dist-info}/WHEEL +0 -0
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240612.dist-info}/top_level.txt +0 -0
brainstate/_module.py
CHANGED
@@ -46,7 +46,6 @@ For handling the delays:
 
 """
 
-import inspect
 import math
 import numbers
 from collections import namedtuple
@@ -58,12 +57,12 @@ import jax.numpy as jnp
 import numpy as np
 
 from . import environ
-from ._utils import set_module_as
 from ._state import State, StateDictManager, visible_state_dict
-from .
+from ._utils import set_module_as
 from .math import get_dtype
 from .mixin import Mixin, Mode, DelayedInit, AllOfTypes, Batching, UpdateReturn
 from .transform._jit_error import jit_error
+from .util import unique_name, DictManager, get_unique_name
 
 Shape = Union[int, Sequence[int]]
 PyTree = Any
@@ -92,7 +91,7 @@ __all__ = [
   'call_order',
 
   # state processing
-  'init_states', 'load_states', 'save_states', 'assign_state_values',
+  'init_states', 'reset_states', 'load_states', 'save_states', 'assign_state_values',
 ]
 
 
@@ -271,6 +270,12 @@ class Module(object):
     """
     pass
 
+  def reset_state(self, *args, **kwargs):
+    """
+    State resetting function.
+    """
+    pass
+
   def save_state(self, **kwargs) -> Dict:
     """Save states as a dictionary. """
     return self.states(include_self=True, level=0, method='absolute')
@@ -1115,6 +1120,12 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
     fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
     self.history = State(jax.tree.map(fun, self.target_info))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    if batch_size is not None:
+      assert self.mode.has(Batching), 'The mode should have Batching behavior when batch_size is not None.'
+    fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
+    self.history.value = jax.tree.map(fun, self.target_info)
+
   def register_entry(
       self,
       entry: str,
@@ -1344,7 +1355,7 @@ def call_order(level: int = 0):
 @set_module_as('brainstate')
 def init_states(target: Module, *args, **kwargs) -> Module:
   """
-
+  Initialize states of all children nodes in the given target.
 
   Args:
     target: The target Module.
@@ -1368,6 +1379,33 @@ def init_states(target: Module, *args, **kwargs) -> Module:
   return target
 
 
+@set_module_as('brainstate')
+def reset_states(target: Module, *args, **kwargs) -> Module:
+  """
+  Reset states of all children nodes in the given target.
+
+  Args:
+    target: The target Module.
+
+  Returns:
+    The target Module.
+  """
+  nodes_with_order = []
+
+  # reset node whose `init_state` has no `call_order`
+  for node in list(target.nodes().values()):
+    if not hasattr(node.reset_state, 'call_order'):
+      node.reset_state(*args, **kwargs)
+    else:
+      nodes_with_order.append(node)
+
+  # reset the node's states
+  for node in sorted(nodes_with_order, key=lambda x: x.reset_state.call_order):
+    node.reset_state(*args, **kwargs)
+
+  return target
+
+
 @set_module_as('brainstate')
 def load_states(target: Module, state_dict: Dict, **kwargs):
   """Copy parameters and buffers from :attr:`state_dict` into
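The new `reset_states` helper mirrors the existing `init_states` entry point: `init_states` allocates each child's `State` objects once, while `reset_states` re-assigns their values in place, honoring any `call_order` attached to the `reset_state` methods. A minimal usage sketch follows; the `bst.nn.LIF(10)` constructor call is an assumption for illustration only.

```python
import brainstate as bst

net = bst.nn.LIF(10)     # illustrative module; constructor arguments are assumed
bst.init_states(net)     # allocate State objects for all children (once)
# ... simulate one trial ...
bst.reset_states(net)    # re-initialize the state values in place before the next trial
```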
brainstate/_state.py
CHANGED
@@ -59,6 +59,23 @@ _global_context_to_check_state_tree = [False]
 def check_state_value_tree() -> None:
   """
   The contex manager to check weather the tree structure of the state value keeps consistently.
+
+  Once a :py:class:`~.State` is created, the tree structure of the value is fixed. In default,
+  the tree structure of the value is not checked to avoid off the repeated evaluation.
+  If you want to check the tree structure of the value once the new value is assigned,
+  you can use this context manager.
+
+  Example::
+
+    ```python
+    state = brainstate.ShortTermState(jnp.zeros((2, 3)))
+    with check_state_value_tree():
+      state.value = jnp.zeros((2, 3))
+
+      # The following code will raise an error.
+      state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))
+    ```
+
   """
   try:
     _global_context_to_check_state_tree.append(True)
brainstate/environ.py
CHANGED
@@ -18,7 +18,8 @@ from .util import MemScaling, IdMemScaling
 __all__ = [
   'set', 'context', 'get', 'all',
   'set_host_device_count', 'set_platform',
-  'get_host_device_count', 'get_platform',
+  'get_host_device_count', 'get_platform',
+  'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
   'tolerance',
   'dftype', 'ditype', 'dutype', 'dctype',
 ]
brainstate/functional/__init__.py
CHANGED
@@ -20,6 +20,7 @@ from ._normalization import *
 from ._normalization import __all__ as __others_all__
 from ._spikes import *
 from ._spikes import __all__ as __spikes_all__
+from ._others import *
+from ._others import __all__ as __others_all__
 
-__all__ = __spikes_all__ + __others_all__ + __activations_all__
-
+__all__ = __spikes_all__ + __others_all__ + __activations_all__ + __others_all__
brainstate/functional/_normalization.py
CHANGED
@@ -20,11 +20,14 @@ from typing import Optional
 import jax
 import jax.numpy as jnp
 
+from .._utils import set_module_as
+
 __all__ = [
   'weight_standardization',
 ]
 
 
+@set_module_as('brainstate.functional')
 def weight_standardization(
     w: jax.typing.ArrayLike,
     eps: float = 1e-4,
brainstate/functional/_others.py
ADDED
@@ -0,0 +1,49 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import annotations
+
+from functools import partial
+from typing import Any
+
+import jax
+import jax.numpy as jnp
+
+PyTree = Any
+
+__all__ = [
+  'clip_grad_norm',
+]
+
+
+def clip_grad_norm(
+    grad: PyTree,
+    max_norm: float | jax.Array,
+    norm_type: int | str | None = None
+):
+  """
+  Clips gradient norm of an iterable of parameters.
+
+  The norm is computed over all gradients together, as if they were
+  concatenated into a single vector. Gradients are modified in-place.
+
+  Args:
+    grad (PyTree): an iterable of Tensors or a single Tensor that will have gradients normalized
+    max_norm (float): max norm of the gradients.
+    norm_type (int, str, None): type of the used p-norm. Can be ``'inf'`` for infinity norm.
+  """
+  norm_fn = partial(jnp.linalg.norm, ord=norm_type)
+  norm = norm_fn(jnp.asarray(jax.tree.leaves(jax.tree.map(norm_fn, grad))))
+  return jax.tree.map(lambda x: jnp.where(norm < max_norm, x, x * max_norm / (norm + 1e-6)), grad)
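A usage sketch of the new `clip_grad_norm`, assuming it is re-exported as `brainstate.functional.clip_grad_norm` as the `functional/__init__.py` change above suggests. With `norm_type=None`, the `jnp.linalg.norm` defaults apply: a per-leaf norm first, then the 2-norm over the per-leaf norms.

```python
import jax
import jax.numpy as jnp
import brainstate as bst

params = {'w': jnp.ones((3, 3)), 'b': jnp.zeros((3,))}
loss = lambda p: jnp.sum(p['w'] ** 2) + jnp.sum(p['b'])
grads = jax.grad(loss)(params)

# Gradients are rescaled only when the global norm reaches max_norm.
clipped = bst.functional.clip_grad_norm(grads, max_norm=1.0)
```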
brainstate/nn/__init__.py
CHANGED
@@ -21,6 +21,8 @@ from ._dynamics import *
 from ._dynamics import __all__ as dynamics_all
 from ._elementwise import *
 from ._elementwise import __all__ as elementwise_all
+from ._embedding import *
+from ._embedding import __all__ as embed_all
 from ._misc import *
 from ._misc import __all__ as _misc_all
 from ._normalizations import *
@@ -43,6 +45,7 @@ __all__ = (
   connections_all +
   dynamics_all +
   elementwise_all +
+  embed_all +
   normalizations_all +
   others_all +
   poolings_all +
@@ -58,6 +61,7 @@ del (
   connections_all,
   dynamics_all,
   elementwise_all,
+  embed_all,
   normalizations_all,
   others_all,
   poolings_all,
brainstate/nn/_base.py
CHANGED
@@ -55,22 +55,24 @@ class ExplicitInOutSize(Mixin):
 
   @property
   def in_size(self) -> Tuple[int, ...]:
-    if self._in_size is None:
-      raise ValueError(f"The input shape is not set in this node: {self} ")
     return self._in_size
 
   @in_size.setter
-  def in_size(self, in_size: Sequence[int]):
+  def in_size(self, in_size: Sequence[int] | int):
+    if isinstance(in_size, int):
+      in_size = (in_size,)
+    assert isinstance(in_size, (tuple, list)), f"Invalid type of in_size: {type(in_size)}"
     self._in_size = tuple(in_size)
 
   @property
   def out_size(self) -> Tuple[int, ...]:
-    if self._out_size is None:
-      raise ValueError(f"The output shape is not set in this node: {self}")
     return self._out_size
 
   @out_size.setter
-  def out_size(self, out_size: Sequence[int]):
+  def out_size(self, out_size: Sequence[int] | int):
+    if isinstance(out_size, int):
+      out_size = (out_size,)
+    assert isinstance(out_size, (tuple, list)), f"Invalid type of out_size: {type(out_size)}"
     self._out_size = tuple(out_size)
 
 
@@ -152,7 +154,8 @@ class Sequential(Module, UpdateReturn, Container, ExplicitInOutSize):
     self.children = visible_module_dict(self.format_elements(object, first, *tuple_modules, **dict_modules))
 
     # the input and output shape
-
+    if first.in_size is not None:
+      self.in_size = first.in_size
     self.out_size = tuple(in_size)
 
   def _format_module(self, module, in_size):
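With the relaxed setters, a plain integer is now accepted and normalized to a one-element tuple, and reading an unset size no longer raises. A sketch using the `Embedding` layer introduced in this release (see `_embedding.py` below), chosen only because it appears to inherit these setters through `DnnLayer`, as the `self.out_size = (embedding_size,)` assignment there suggests; that inheritance is an assumption here.

```python
import brainstate as bst

emb = bst.nn.Embedding(num_embeddings=10, embedding_size=4)
emb.in_size = 10        # previously a sequence such as (10,) was required
print(emb.in_size)      # (10,), normalized by the new setter
emb.out_size = (4,)     # sequences still work as before
```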
brainstate/nn/_dynamics.py
CHANGED
@@ -103,6 +103,9 @@ class IF(Neuron):
   def init_state(self, batch_size: int = None, **kwargs):
     self.V = ShortTermState(init.param(jnp.zeros, self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.V.value = init.param(jnp.zeros, self.varshape, batch_size)
+
   def get_spike(self, V=None):
     V = self.V.value if V is None else V
     v_scaled = (V - self.V_th) / self.V_th
@@ -160,6 +163,9 @@ class LIF(Neuron):
   def init_state(self, batch_size: int = None, **kwargs):
     self.V = ShortTermState(init.param(init.Constant(self.V_reset), self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.V.value = init.param(init.Constant(self.V_reset), self.varshape, batch_size)
+
   def get_spike(self, V=None):
     V = self.V.value if V is None else V
     v_scaled = (V - self.V_th) / self.V_th
@@ -214,6 +220,10 @@ class ALIF(Neuron):
     self.V = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))
     self.a = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.V.value = init.param(init.Constant(0.), self.varshape, batch_size)
+    self.a.value = init.param(init.Constant(0.), self.varshape, batch_size)
+
   def get_spike(self, V=None, a=None):
     V = self.V.value if V is None else V
     a = self.a.value if a is None else a
@@ -275,6 +285,9 @@ class Expon(Synapse):
   def init_state(self, batch_size: int = None, **kwargs):
     self.g = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.g.value = init.param(init.Constant(0.), self.varshape, batch_size)
+
   def update(self, x=None):
     self.g.value = exp_euler_step(self.dg, self.g.value, environ.get('t'))
     if x is not None:
@@ -325,6 +338,10 @@ class STP(Synapse):
     self.x = ShortTermState(init.param(init.Constant(1.), self.varshape, batch_size))
     self.u = ShortTermState(init.param(init.Constant(self.U), self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
+    self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)
+
   def du(self, u, t):
     return self.U - u / self.tau_f
 
@@ -390,6 +407,9 @@ class STD(Synapse):
   def init_state(self, batch_size: int = None, **kwargs):
     self.x = ShortTermState(init.param(init.Constant(1.), self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
+
   def update(self, pre_spike):
     t = environ.get('t')
     x = exp_euler_step(self.dx, self.x.value, t)
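Each `reset_state` added above re-assigns the values of the existing `ShortTermState` objects rather than creating new ones, so a population can be reset between trials without re-running `init_state`. A sketch for a single LIF group; the constructor arguments are assumed for illustration.

```python
import brainstate as bst

lif = bst.nn.LIF(10)             # constructor arguments assumed
lif.init_state(batch_size=32)    # allocates self.V as a ShortTermState once
# ... advance the dynamics for one trial ...
lif.reset_state(batch_size=32)   # writes the reset value back into the existing state
```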
brainstate/nn/_embedding.py
ADDED
@@ -0,0 +1,66 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Optional, Callable, Union
+
+from ._base import DnnLayer
+from .. import init
+from .._state import ParamState
+from ..mixin import Mode, Training
+from ..typing import ArrayLike
+
+__all__ = [
+  'Embedding',
+]
+
+
+class Embedding(DnnLayer):
+  r"""
+  A simple lookup table that stores embeddings of a fixed size.
+
+  Args:
+    num_embeddings: Size of embedding dictionary. Must be non-negative.
+    embedding_size: Size of each embedding vector. Must be non-negative.
+    embed_init: The initializer for the embedding lookup table, of shape `(num_embeddings, embedding_size)`.
+
+  """
+
+  def __init__(
+      self,
+      num_embeddings: int,
+      embedding_size: int,
+      embed_init: Union[Callable, ArrayLike] = init.LecunUniform(),
+      name: Optional[str] = None,
+      mode: Optional[Mode] = None,
+  ):
+    super().__init__(name=name, mode=mode)
+    if num_embeddings < 0:
+      raise ValueError("num_embeddings must not be negative.")
+    if embedding_size < 0:
+      raise ValueError("embedding_size must not be negative.")
+    self.num_embeddings = num_embeddings
+    self.embedding_size = embedding_size
+    self.out_size = (embedding_size,)
+
+    weight = init.param(embed_init, (self.num_embeddings, self.embedding_size))
+    if self.mode.has(Training):
+      self.weight = ParamState(weight)
+    else:
+      self.weight = weight
+
+  def update(self, indices: ArrayLike):
+    if self.mode.has(Training):
+      return self.weight.value[indices]
+    return self.weight[indices]
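A usage sketch of the new `Embedding` layer, which the `nn/__init__.py` change above exposes as `brainstate.nn.Embedding`. Whether the table is wrapped in a `ParamState` depends on the layer's mode; `update` performs a plain row lookup either way.

```python
import jax.numpy as jnp
import brainstate as bst

emb = bst.nn.Embedding(num_embeddings=100, embedding_size=16)
tokens = jnp.array([3, 7, 42])
vectors = emb.update(tokens)   # rows of the (100, 16) table, shape (3, 16)
```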
brainstate/nn/_rate_rnns.py
CHANGED
@@ -90,6 +90,9 @@ class ValinaRNNCell(RNNCell):
   def init_state(self, batch_size: int = None, **kwargs):
     self.h = ShortTermState(init.param(self._state_initializer, self.num_out, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.h.value = init.param(self._state_initializer, self.num_out, batch_size)
+
   def update(self, x):
     xh = jnp.concatenate([x, self.h.value], axis=-1)
     h = self.W(xh)
@@ -147,6 +150,9 @@ class GRUCell(RNNCell):
   def init_state(self, batch_size: int = None, **kwargs):
     self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
   def update(self, x):
     old_h = self.h.value
     xh = jnp.concatenate([x, old_h], axis=-1)
@@ -224,6 +230,9 @@ class MGUCell(RNNCell):
   def init_state(self, batch_size: int = None, **kwargs):
     self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
   def update(self, x):
     old_h = self.h.value
     xh = jnp.concatenate([x, old_h], axis=-1)
@@ -327,6 +336,10 @@ class LSTMCell(RNNCell):
     self.c = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
     self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
+    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
   def update(self, x):
     h, c = self.h.value, self.c.value
     xh = jnp.concat([x, h], axis=-1)
@@ -379,6 +392,10 @@ class URLSTMCell(RNNCell):
     self.c = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
     self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
+    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
   def update(self, x: ArrayLike) -> ArrayLike:
     h, c = self.h.value, self.c.value
     xh = jnp.concat([x, h], axis=-1)
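The recurrent cells follow the same pattern: `init_state` once, `reset_state` between independent sequences. A sketch; the `(num_in, num_out)` positional constructor order is an assumption for illustration.

```python
import jax.numpy as jnp
import brainstate as bst

cell = bst.nn.GRUCell(8, 16)          # constructor arguments assumed
cell.init_state(batch_size=4)         # allocate the hidden state once
for _ in range(3):                    # e.g. three independent sequences
    cell.reset_state(batch_size=4)    # re-initialize h in place
    for _ in range(20):
        h = cell.update(jnp.zeros((4, 8)))
```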
brainstate/nn/_readout.py
CHANGED
@@ -66,6 +66,9 @@ class LeakyRateReadout(DnnLayer):
   def init_state(self, batch_size=None, **kwargs):
     self.r = ShortTermState(init.param(init.Constant(0.), self.out_size, batch_size))
 
+  def reset_state(self, batch_size=None, **kwargs):
+    self.r.value = init.param(init.Constant(0.), self.out_size, batch_size)
+
   def update(self, x):
     r = self.decay * self.r.value + x @ self.weight.value
     self.r.value = r
@@ -109,6 +112,9 @@ class LeakySpikeReadout(Neuron):
   def init_state(self, batch_size, **kwargs):
     self.V = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))
 
+  def reset_state(self, batch_size, **kwargs):
+    self.V.value = init.param(init.Constant(0.), self.varshape, batch_size)
+
   @property
   def spike(self):
     return self.get_spike(self.V.value)
brainstate/optim/_lr_scheduler_test.py
CHANGED
@@ -34,3 +34,16 @@ class TestMultiStepLR(unittest.TestCase):
         self.assertTrue(jnp.allclose(r, 0.001))
       else:
         self.assertTrue(jnp.allclose(r, 0.0001))
+
+  def test2(self):
+    lr = bst.transform.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
+    for i in range(40):
+      r = lr(i)
+      if i < 10:
+        self.assertEqual(r, 0.1)
+      elif i < 20:
+        self.assertTrue(jnp.allclose(r, 0.01))
+      elif i < 30:
+        self.assertTrue(jnp.allclose(r, 0.001))
+      else:
+        self.assertTrue(jnp.allclose(r, 0.0001))
brainstate/transform/_jit.py
CHANGED
@@ -23,8 +23,8 @@ import jax
 from jax._src import sharding_impls
 from jax.lib import xla_client as xc
 
-from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values
 from brainstate._utils import set_module_as
+from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values
 
 __all__ = ['jit']
 
@@ -33,10 +33,13 @@ class JittedFunction(Callable):
   """
   A wrapped version of ``fun``, set up for just-in-time compilation.
   """
-  origin_fun: Callable
+  origin_fun: Callable  # the original function
   stateful_fun: StatefulFunction  # the stateful function for extracting states
   jitted_fun: jax.stages.Wrapped  # the jitted function
-  clear_cache: Callable
+  clear_cache: Callable  # clear the cache of the jitted function
+
+  def __call__(self, *args, **kwargs):
+    pass
 
 
 def _get_jitted_fun(
@@ -85,12 +88,16 @@ def _get_jitted_fun(
     jit_fun.clear_cache()
 
   jitted_fun: JittedFunction
+
   # the original function
   jitted_fun.origin_fun = fun.fun
+
   # the stateful function for extracting states
   jitted_fun.stateful_fun = fun
+
   # the jitted function
   jitted_fun.jitted_fun = jit_fun
+
   # clear cache
   jitted_fun.clear_cache = clear_cache
 
@@ -99,18 +106,18 @@
 
 @set_module_as('brainstate.transform')
 def jit(
-
-
-
-
-
-
-
-
-
-
-
-
+    fun: Callable = None,
+    in_shardings=sharding_impls.UNSPECIFIED,
+    out_shardings=sharding_impls.UNSPECIFIED,
+    static_argnums: int | Sequence[int] | None = None,
+    donate_argnums: int | Sequence[int] | None = None,
+    donate_argnames: str | Iterable[str] | None = None,
+    keep_unused: bool = False,
+    device: xc.Device | None = None,
+    backend: str | None = None,
+    inline: bool = False,
+    abstracted_axes: Any | None = None,
+    **kwargs
 ) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]:
   """
   Sets up ``fun`` for just-in-time compilation with XLA.
@@ -228,12 +235,31 @@ def jit(
 
   if fun is None:
     def wrapper(fun_again: Callable) -> JittedFunction:
-      return _get_jitted_fun(fun_again,
-
-
+      return _get_jitted_fun(fun_again,
+                             in_shardings,
+                             out_shardings,
+                             static_argnums,
+                             donate_argnums,
+                             donate_argnames,
+                             keep_unused,
+                             device,
+                             backend,
+                             inline,
+                             abstracted_axes,
+                             **kwargs)
+
     return wrapper
 
   else:
-    return _get_jitted_fun(fun,
-
-
+    return _get_jitted_fun(fun,
+                           in_shardings,
+                           out_shardings,
+                           static_argnums,
+                           donate_argnums,
+                           donate_argnames,
+                           keep_unused,
+                           device,
+                           backend,
+                           inline,
+                           abstracted_axes,
+                           **kwargs)