brainstate 0.0.1__py2.py3-none-any.whl → 0.0.1.post20240622__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +4 -5
- brainstate/_module.py +191 -48
- brainstate/_module_test.py +95 -21
- brainstate/_state.py +17 -0
- brainstate/environ.py +2 -2
- brainstate/functional/__init__.py +3 -2
- brainstate/functional/_activations.py +7 -26
- brainstate/functional/_normalization.py +3 -0
- brainstate/functional/_others.py +49 -0
- brainstate/functional/_spikes.py +0 -1
- brainstate/mixin.py +2 -2
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_base.py +10 -7
- brainstate/nn/_dynamics.py +20 -0
- brainstate/nn/_elementwise.py +5 -4
- brainstate/nn/_embedding.py +66 -0
- brainstate/nn/_misc.py +4 -3
- brainstate/nn/_others.py +3 -2
- brainstate/nn/_poolings.py +21 -20
- brainstate/nn/_poolings_test.py +4 -4
- brainstate/nn/_rate_rnns.py +17 -0
- brainstate/nn/_readout.py +6 -0
- brainstate/optim/__init__.py +0 -1
- brainstate/optim/_lr_scheduler_test.py +13 -0
- brainstate/optim/_sgd_optimizer.py +18 -17
- brainstate/transform/__init__.py +2 -3
- brainstate/transform/_autograd.py +1 -1
- brainstate/transform/_autograd_test.py +0 -2
- brainstate/transform/_jit.py +47 -21
- brainstate/transform/_jit_test.py +0 -3
- brainstate/transform/_make_jaxpr.py +164 -3
- brainstate/transform/_make_jaxpr_test.py +0 -2
- brainstate/transform/_progress_bar.py +1 -3
- brainstate/util.py +0 -1
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/METADATA +9 -17
- brainstate-0.0.1.post20240622.dist-info/RECORD +64 -0
- brainstate/math/__init__.py +0 -21
- brainstate/math/_einops.py +0 -787
- brainstate/math/_einops_parsing.py +0 -169
- brainstate/math/_einops_parsing_test.py +0 -126
- brainstate/math/_einops_test.py +0 -346
- brainstate/math/_misc.py +0 -298
- brainstate/math/_misc_test.py +0 -58
- brainstate/nn/functional/__init__.py +0 -25
- brainstate/nn/functional/_activations.py +0 -754
- brainstate/nn/functional/_normalization.py +0 -69
- brainstate/nn/functional/_spikes.py +0 -90
- brainstate/nn/init/__init__.py +0 -26
- brainstate/nn/init/_base.py +0 -36
- brainstate/nn/init/_generic.py +0 -175
- brainstate/nn/init/_random_inits.py +0 -489
- brainstate/nn/init/_regular_inits.py +0 -109
- brainstate/nn/surrogate.py +0 -1740
- brainstate-0.0.1.dist-info/RECORD +0 -79
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/LICENSE +0 -0
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/WHEEL +0 -0
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
@@ -20,17 +20,16 @@ A ``State``-based Transformation System for Brain Dynamics Programming
|
|
20
20
|
__version__ = "0.0.1"
|
21
21
|
|
22
22
|
from . import environ
|
23
|
-
from . import
|
23
|
+
from . import functional
|
24
|
+
from . import init
|
24
25
|
from . import mixin
|
25
26
|
from . import nn
|
26
27
|
from . import optim
|
27
28
|
from . import random
|
29
|
+
from . import surrogate
|
28
30
|
from . import transform
|
29
31
|
from . import typing
|
30
32
|
from . import util
|
31
|
-
from . import surrogate
|
32
|
-
from . import functional
|
33
|
-
from . import init
|
34
33
|
from ._module import *
|
35
34
|
from ._module import __all__ as _module_all
|
36
35
|
from ._state import *
|
@@ -39,7 +38,7 @@ from ._state import __all__ as _state_all
|
|
39
38
|
__all__ = (
|
40
39
|
['environ', 'share', 'nn', 'optim', 'random',
|
41
40
|
'surrogate', 'functional', 'init',
|
42
|
-
'mixin', '
|
41
|
+
'mixin', 'transform', 'util', 'typing'] +
|
43
42
|
_module_all + _state_all
|
44
43
|
)
|
45
44
|
del _module_all, _state_all
|
brainstate/_module.py
CHANGED
@@ -46,7 +46,6 @@ For handling the delays:
|
|
46
46
|
|
47
47
|
"""
|
48
48
|
|
49
|
-
import inspect
|
50
49
|
import math
|
51
50
|
import numbers
|
52
51
|
from collections import namedtuple
|
@@ -58,20 +57,21 @@ import jax.numpy as jnp
|
|
58
57
|
import numpy as np
|
59
58
|
|
60
59
|
from . import environ
|
61
|
-
from ._utils import set_module_as
|
62
60
|
from ._state import State, StateDictManager, visible_state_dict
|
63
|
-
from .
|
64
|
-
from .math import get_dtype
|
61
|
+
from ._utils import set_module_as
|
65
62
|
from .mixin import Mixin, Mode, DelayedInit, AllOfTypes, Batching, UpdateReturn
|
66
|
-
from .transform
|
63
|
+
from .transform import jit_error
|
64
|
+
from .util import unique_name, DictManager, get_unique_name
|
67
65
|
|
68
66
|
Shape = Union[int, Sequence[int]]
|
69
67
|
PyTree = Any
|
70
68
|
ArrayLike = jax.typing.ArrayLike
|
71
69
|
|
72
70
|
delay_identifier = '_*_delay_of_'
|
73
|
-
|
74
|
-
|
71
|
+
_DELAY_ROTATE = 'rotation'
|
72
|
+
_DELAY_CONCAT = 'concat'
|
73
|
+
_INTERP_LINEAR = 'linear_interp'
|
74
|
+
_INTERP_ROUND = 'round'
|
75
75
|
|
76
76
|
StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
|
77
77
|
|
@@ -92,7 +92,7 @@ __all__ = [
|
|
92
92
|
'call_order',
|
93
93
|
|
94
94
|
# state processing
|
95
|
-
'init_states', 'load_states', 'save_states', 'assign_state_values',
|
95
|
+
'init_states', 'reset_states', 'load_states', 'save_states', 'assign_state_values',
|
96
96
|
]
|
97
97
|
|
98
98
|
|
@@ -258,9 +258,8 @@ class Module(object):
|
|
258
258
|
--------
|
259
259
|
|
260
260
|
>>> import brainstate as bst
|
261
|
-
>>> import brainscale as nn # noqa
|
262
261
|
>>> x = bst.random.rand((10, 10))
|
263
|
-
>>> l = nn.Activation(jax.numpy.tanh)
|
262
|
+
>>> l = bst.nn.Activation(jax.numpy.tanh)
|
264
263
|
>>> y = x >> l
|
265
264
|
"""
|
266
265
|
return self.__call__(other)
|
@@ -271,6 +270,12 @@ class Module(object):
|
|
271
270
|
"""
|
272
271
|
pass
|
273
272
|
|
273
|
+
def reset_state(self, *args, **kwargs):
|
274
|
+
"""
|
275
|
+
State resetting function.
|
276
|
+
"""
|
277
|
+
pass
|
278
|
+
|
274
279
|
def save_state(self, **kwargs) -> Dict:
|
275
280
|
"""Save states as a dictionary. """
|
276
281
|
return self.states(include_self=True, level=0, method='absolute')
|
@@ -396,8 +401,8 @@ class visible_module_list(list):
|
|
396
401
|
retrieved when using :py:func:`~.nodes()` function.
|
397
402
|
|
398
403
|
>>> import brainstate as bst
|
399
|
-
>>> l = bst.visible_module_list([
|
400
|
-
>>>
|
404
|
+
>>> l = bst.visible_module_list([bst.nn.Linear(1, 2),
|
405
|
+
>>> bst.nn.LSTMCell(2, 3)])
|
401
406
|
"""
|
402
407
|
|
403
408
|
__module__ = 'brainstate'
|
@@ -1033,14 +1038,14 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
|
|
1033
1038
|
delay = length data ]
|
1034
1039
|
entries: optional, dict. The delay access entries.
|
1035
1040
|
name: str. The delay name.
|
1036
|
-
|
1041
|
+
delay_method: str. The method used for updating delay, either 'rotation' or 'concat'. Default 'rotation'.
|
1037
1042
|
mode: Mode. The computing mode. Default None.
|
1038
1043
|
"""
|
1039
1044
|
|
1040
1045
|
__module__ = 'brainstate'
|
1041
1046
|
|
1042
|
-
|
1043
|
-
max_time: float
|
1047
|
+
non_hashable_params = ('time', 'entries', 'name')
|
1048
|
+
max_time: float #
|
1044
1049
|
max_length: int
|
1045
1050
|
history: Optional[State]
|
1046
1051
|
|
@@ -1048,20 +1053,27 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
|
|
1048
1053
|
self,
|
1049
1054
|
target_info: PyTree,
|
1050
1055
|
time: Optional[Union[int, float]] = None, # delay time
|
1051
|
-
init: Optional[Union[ArrayLike, Callable]] = None, # delay data
|
1056
|
+
init: Optional[Union[ArrayLike, Callable]] = None, # delay data before t0
|
1052
1057
|
entries: Optional[Dict] = None, # delay access entry
|
1053
|
-
|
1058
|
+
delay_method: Optional[str] = _DELAY_ROTATE, # delay method
|
1059
|
+
interp_method: str = _INTERP_LINEAR, # interpolation method
|
1054
1060
|
# others
|
1055
1061
|
name: Optional[str] = None,
|
1056
1062
|
mode: Optional[Mode] = None,
|
1057
1063
|
):
|
1058
1064
|
|
1059
1065
|
# target information
|
1060
|
-
self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape,
|
1066
|
+
self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)
|
1061
1067
|
|
1062
1068
|
# delay method
|
1063
|
-
assert
|
1064
|
-
|
1069
|
+
assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (f'Un-supported delay method {delay_method}. '
|
1070
|
+
f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}')
|
1071
|
+
self.delay_method = delay_method
|
1072
|
+
|
1073
|
+
# interp method
|
1074
|
+
assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (f'Un-supported interpolation method {interp_method}. '
|
1075
|
+
f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
|
1076
|
+
self.interp_method = interp_method
|
1065
1077
|
|
1066
1078
|
# delay length and time
|
1067
1079
|
self.max_time, delay_length = _get_delay(time, None)
|
@@ -1071,7 +1083,8 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
|
|
1071
1083
|
|
1072
1084
|
# delay data
|
1073
1085
|
if init is not None:
|
1074
|
-
|
1086
|
+
if not isinstance(init, (numbers.Number, jax.Array, np.ndarray, Callable)):
|
1087
|
+
raise TypeError(f'init should be Array, Callable, or None. But got {init}')
|
1075
1088
|
self._init = init
|
1076
1089
|
self._history = None
|
1077
1090
|
|
@@ -1085,7 +1098,11 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
|
|
1085
1098
|
|
1086
1099
|
def __repr__(self):
|
1087
1100
|
name = self.__class__.__name__
|
1088
|
-
return f'{name}(
|
1101
|
+
return (f'{name}('
|
1102
|
+
f'delay_length={self.max_length}, '
|
1103
|
+
f'target_info={self.target_info}, '
|
1104
|
+
f'delay_method="{self.delay_method}", '
|
1105
|
+
f'interp_method="{self.interp_method}")')
|
1089
1106
|
|
1090
1107
|
@property
|
1091
1108
|
def history(self):
|
@@ -1100,7 +1117,7 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
|
|
1100
1117
|
if batch_size is not None:
|
1101
1118
|
shape.insert(self.mode.batch_axis, batch_size)
|
1102
1119
|
shape.insert(0, length)
|
1103
|
-
if isinstance(self._init, (jax.Array, numbers.Number)):
|
1120
|
+
if isinstance(self._init, (jax.Array, np.ndarray, numbers.Number)):
|
1104
1121
|
data = jnp.broadcast_to(jnp.asarray(self._init, a.dtype), shape)
|
1105
1122
|
elif callable(self._init):
|
1106
1123
|
data = self._init(shape, dtype=a.dtype)
|
@@ -1115,13 +1132,20 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
|
|
1115
1132
|
fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
|
1116
1133
|
self.history = State(jax.tree.map(fun, self.target_info))
|
1117
1134
|
|
1135
|
+
def reset_state(self, batch_size: int = None, **kwargs):
|
1136
|
+
if batch_size is not None:
|
1137
|
+
assert self.mode.has(Batching), 'The mode should have Batching behavior when batch_size is not None.'
|
1138
|
+
fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
|
1139
|
+
self.history.value = jax.tree.map(fun, self.target_info)
|
1140
|
+
|
1118
1141
|
def register_entry(
|
1119
1142
|
self,
|
1120
1143
|
entry: str,
|
1121
1144
|
delay_time: Optional[Union[int, float]] = None,
|
1122
1145
|
delay_step: Optional[int] = None,
|
1123
1146
|
) -> 'Delay':
|
1124
|
-
"""
|
1147
|
+
"""
|
1148
|
+
Register an entry to access the delay data.
|
1125
1149
|
|
1126
1150
|
Args:
|
1127
1151
|
entry: str. The entry to access the delay data.
|
@@ -1151,7 +1175,8 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
|
|
1151
1175
|
return self
|
1152
1176
|
|
1153
1177
|
def at(self, entry: str, *indices) -> ArrayLike:
|
1154
|
-
"""
|
1178
|
+
"""
|
1179
|
+
Get the data at the given entry.
|
1155
1180
|
|
1156
1181
|
Args:
|
1157
1182
|
entry: str. The entry to access the data.
|
@@ -1167,20 +1192,28 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
|
|
1167
1192
|
delay_step = self._registered_entries[entry]
|
1168
1193
|
if delay_step is None:
|
1169
1194
|
delay_step = 0
|
1170
|
-
return self.
|
1195
|
+
return self.retrieve_at_step(delay_step, *indices)
|
1171
1196
|
|
1172
|
-
def
|
1173
|
-
"""
|
1197
|
+
def retrieve_at_step(self, delay_step, *indices) -> PyTree:
|
1198
|
+
"""
|
1199
|
+
Retrieve the delay data at the given delay time step (the integer to indicate the time step).
|
1174
1200
|
|
1175
1201
|
Parameters
|
1176
1202
|
----------
|
1177
|
-
delay_step:
|
1178
|
-
|
1203
|
+
delay_step: int_like
|
1204
|
+
Retrieve the data at the given time step.
|
1205
|
+
indices: tuple
|
1206
|
+
The indices to slice the data.
|
1207
|
+
|
1208
|
+
Returns
|
1209
|
+
-------
|
1210
|
+
delay_data: The delay data at the given delay step.
|
1211
|
+
|
1179
1212
|
"""
|
1180
1213
|
assert self.history is not None, 'The delay history is not initialized.'
|
1181
1214
|
assert delay_step is not None, 'The delay step should be given.'
|
1182
1215
|
|
1183
|
-
if environ.get(environ.JIT_ERROR_CHECK,
|
1216
|
+
if environ.get(environ.JIT_ERROR_CHECK, True):
|
1184
1217
|
def _check_delay(delay_len):
|
1185
1218
|
raise ValueError(f'The request delay length should be less than the '
|
1186
1219
|
f'maximum delay {self.max_length}. But we got {delay_len}')
|
@@ -1188,17 +1221,17 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
|
|
1188
1221
|
jit_error(delay_step >= self.max_length, _check_delay, delay_step)
|
1189
1222
|
|
1190
1223
|
# rotation method
|
1191
|
-
if self.
|
1192
|
-
i = environ.get(environ.I)
|
1224
|
+
if self.delay_method == _DELAY_ROTATE:
|
1225
|
+
i = environ.get(environ.I, desc='The time step index.')
|
1193
1226
|
di = i - delay_step
|
1194
1227
|
delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
|
1195
1228
|
delay_idx = jax.lax.stop_gradient(delay_idx)
|
1196
1229
|
|
1197
|
-
elif self.
|
1230
|
+
elif self.delay_method == _DELAY_CONCAT:
|
1198
1231
|
delay_idx = delay_step
|
1199
1232
|
|
1200
1233
|
else:
|
1201
|
-
raise ValueError(f'Unknown updating method "{self.
|
1234
|
+
raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
|
1202
1235
|
|
1203
1236
|
# the delay index
|
1204
1237
|
if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
|
@@ -1208,6 +1241,81 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
|
|
1208
1241
|
# the delay data
|
1209
1242
|
return jax.tree.map(lambda a: a[indices], self.history.value)
|
1210
1243
|
|
1244
|
+
def retrieve_at_time(self, delay_time, *indices) -> PyTree:
|
1245
|
+
"""
|
1246
|
+
Retrieve the delay data at the given time (a float time point rather than an integer step index).
|
1247
|
+
|
1248
|
+
Parameters
|
1249
|
+
----------
|
1250
|
+
delay_time: float
|
1251
|
+
Retrieve the data at the given time.
|
1252
|
+
indices: tuple
|
1253
|
+
The indices to slice the data.
|
1254
|
+
|
1255
|
+
Returns
|
1256
|
+
-------
|
1257
|
+
delay_data: The delay data at the given delay time.
|
1258
|
+
|
1259
|
+
"""
|
1260
|
+
assert self.history is not None, 'The delay history is not initialized.'
|
1261
|
+
assert delay_time is not None, 'The delay time should be given.'
|
1262
|
+
|
1263
|
+
current_time = environ.get(environ.T, desc='The current time.')
|
1264
|
+
dt = environ.get_dt()
|
1265
|
+
|
1266
|
+
if environ.get(environ.JIT_ERROR_CHECK, True):
|
1267
|
+
def _check_delay(args):
|
1268
|
+
t_now, t_delay = args
|
1269
|
+
raise ValueError(f'The request delay time should be within '
|
1270
|
+
f'[{t_now - self.max_time - dt}, {t_now}], '
|
1271
|
+
f'but we got {t_delay}')
|
1272
|
+
|
1273
|
+
jit_error(jnp.logical_or(delay_time > current_time,
|
1274
|
+
delay_time < current_time - self.max_time - dt),
|
1275
|
+
_check_delay,
|
1276
|
+
(current_time, delay_time))
|
1277
|
+
|
1278
|
+
diff = current_time - delay_time
|
1279
|
+
float_time_step = diff / dt
|
1280
|
+
|
1281
|
+
if self.interp_method == _INTERP_LINEAR: # "linear" interpolation
|
1282
|
+
# def _interp(target):
|
1283
|
+
# if len(indices) > 0:
|
1284
|
+
# raise NotImplementedError('The slicing indices are not supported in the linear interpolation.')
|
1285
|
+
# if self.delay_method == _DELAY_ROTATE:
|
1286
|
+
# i = environ.get(environ.I, desc='The time step index.')
|
1287
|
+
# _interp_fun = partial(jnp.interp, period=self.max_length)
|
1288
|
+
# for dim in range(1, target.ndim, 1):
|
1289
|
+
# _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
|
1290
|
+
# di = i - jnp.arange(self.max_length)
|
1291
|
+
# delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
|
1292
|
+
# return _interp_fun(float_time_step, delay_idx, target)
|
1293
|
+
#
|
1294
|
+
# elif self.delay_method == _DELAY_CONCAT:
|
1295
|
+
# _interp_fun = partial(jnp.interp, period=self.max_length)
|
1296
|
+
# for dim in range(1, target.ndim, 1):
|
1297
|
+
# _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
|
1298
|
+
# return _interp_fun(float_time_step, jnp.arange(self.max_length), target)
|
1299
|
+
#
|
1300
|
+
# else:
|
1301
|
+
# raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
|
1302
|
+
# return jax.tree.map(_interp, self.history.value)
|
1303
|
+
|
1304
|
+
data_at_t0 = self.retrieve_at_step(jnp.asarray(jnp.floor(float_time_step), dtype=jnp.int32), *indices)
|
1305
|
+
data_at_t1 = self.retrieve_at_step(jnp.asarray(jnp.ceil(float_time_step), dtype=jnp.int32), *indices)
|
1306
|
+
t_diff = float_time_step - jnp.floor(float_time_step)
|
1307
|
+
return jax.tree.map(lambda a, b: a * (1 - t_diff) + b * t_diff, data_at_t0, data_at_t1)
|
1308
|
+
|
1309
|
+
elif self.interp_method == _INTERP_ROUND: # "round" interpolation
|
1310
|
+
return self.retrieve_at_step(
|
1311
|
+
jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32),
|
1312
|
+
*indices
|
1313
|
+
)
|
1314
|
+
|
1315
|
+
else: # raise error
|
1316
|
+
raise ValueError(f'Un-supported interpolation method {self.interp_method}, '
|
1317
|
+
f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
|
1318
|
+
|
1211
1319
|
def update(self, current: PyTree) -> None:
|
1212
1320
|
"""
|
1213
1321
|
Update delay variable with the new data.
|
@@ -1215,25 +1323,29 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
|
|
1215
1323
|
assert self.history is not None, 'The delay history is not initialized.'
|
1216
1324
|
|
1217
1325
|
# update the delay data at the rotation index
|
1218
|
-
if self.
|
1326
|
+
if self.delay_method == _DELAY_ROTATE:
|
1219
1327
|
i = environ.get(environ.I)
|
1220
1328
|
idx = jnp.asarray(i % self.max_length, dtype=environ.dutype())
|
1221
1329
|
idx = jax.lax.stop_gradient(idx)
|
1222
|
-
self.history.value = jax.tree.map(
|
1223
|
-
|
1224
|
-
|
1330
|
+
self.history.value = jax.tree.map(
|
1331
|
+
lambda hist, cur: hist.at[idx].set(cur),
|
1332
|
+
self.history.value,
|
1333
|
+
current
|
1334
|
+
)
|
1225
1335
|
# update the delay data at the first position
|
1226
|
-
elif self.
|
1336
|
+
elif self.delay_method == _DELAY_CONCAT:
|
1227
1337
|
current = jax.tree.map(lambda a: jnp.expand_dims(a, 0), current)
|
1228
1338
|
if self.max_length > 1:
|
1229
|
-
self.history.value = jax.tree.map(
|
1230
|
-
|
1231
|
-
|
1339
|
+
self.history.value = jax.tree.map(
|
1340
|
+
lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
|
1341
|
+
self.history.value,
|
1342
|
+
current
|
1343
|
+
)
|
1232
1344
|
else:
|
1233
1345
|
self.history.value = current
|
1234
1346
|
|
1235
1347
|
else:
|
1236
|
-
raise ValueError(f'Unknown updating method "{self.
|
1348
|
+
raise ValueError(f'Unknown updating method "{self.delay_method}"')
|
1237
1349
|
|
1238
1350
|
|
1239
1351
|
class _StateDelay(Delay):
|
@@ -1254,14 +1366,18 @@ class _StateDelay(Delay):
|
|
1254
1366
|
time: Optional[Union[int, float]] = None, # delay time
|
1255
1367
|
init: Optional[Union[ArrayLike, Callable]] = None, # delay data init
|
1256
1368
|
entries: Optional[Dict] = None, # delay access entry
|
1257
|
-
|
1369
|
+
delay_method: Optional[str] = _DELAY_ROTATE, # delay method
|
1258
1370
|
# others
|
1259
1371
|
name: Optional[str] = None,
|
1260
1372
|
mode: Optional[Mode] = None,
|
1261
1373
|
):
|
1262
1374
|
super().__init__(target_info=target.value,
|
1263
|
-
time=time,
|
1264
|
-
|
1375
|
+
time=time,
|
1376
|
+
init=init,
|
1377
|
+
entries=entries,
|
1378
|
+
delay_method=delay_method,
|
1379
|
+
name=name,
|
1380
|
+
mode=mode)
|
1265
1381
|
self.target = target
|
1266
1382
|
|
1267
1383
|
def update(self, *args, **kwargs):
|
@@ -1344,7 +1460,7 @@ def call_order(level: int = 0):
|
|
1344
1460
|
@set_module_as('brainstate')
|
1345
1461
|
def init_states(target: Module, *args, **kwargs) -> Module:
|
1346
1462
|
"""
|
1347
|
-
|
1463
|
+
Initialize states of all children nodes in the given target.
|
1348
1464
|
|
1349
1465
|
Args:
|
1350
1466
|
target: The target Module.
|
@@ -1368,6 +1484,33 @@ def init_states(target: Module, *args, **kwargs) -> Module:
|
|
1368
1484
|
return target
|
1369
1485
|
|
1370
1486
|
|
1487
|
+
@set_module_as('brainstate')
|
1488
|
+
def reset_states(target: Module, *args, **kwargs) -> Module:
|
1489
|
+
"""
|
1490
|
+
Reset states of all children nodes in the given target.
|
1491
|
+
|
1492
|
+
Args:
|
1493
|
+
target: The target Module.
|
1494
|
+
|
1495
|
+
Returns:
|
1496
|
+
The target Module.
|
1497
|
+
"""
|
1498
|
+
nodes_with_order = []
|
1499
|
+
|
1500
|
+
# reset node whose `reset_state` has no `call_order`
|
1501
|
+
for node in list(target.nodes().values()):
|
1502
|
+
if not hasattr(node.reset_state, 'call_order'):
|
1503
|
+
node.reset_state(*args, **kwargs)
|
1504
|
+
else:
|
1505
|
+
nodes_with_order.append(node)
|
1506
|
+
|
1507
|
+
# reset the node's states
|
1508
|
+
for node in sorted(nodes_with_order, key=lambda x: x.reset_state.call_order):
|
1509
|
+
node.reset_state(*args, **kwargs)
|
1510
|
+
|
1511
|
+
return target
|
1512
|
+
|
1513
|
+
|
1371
1514
|
@set_module_as('brainstate')
|
1372
1515
|
def load_states(target: Module, state_dict: Dict, **kwargs):
|
1373
1516
|
"""Copy parameters and buffers from :attr:`state_dict` into
|
brainstate/_module_test.py
CHANGED
@@ -16,14 +16,15 @@
|
|
16
16
|
import unittest
|
17
17
|
|
18
18
|
import jax.numpy as jnp
|
19
|
+
import jaxlib.xla_extension
|
19
20
|
|
20
|
-
import brainstate as
|
21
|
+
import brainstate as bst
|
21
22
|
|
22
23
|
|
23
|
-
class
|
24
|
+
class TestDelay(unittest.TestCase):
|
24
25
|
def test_delay1(self):
|
25
|
-
a =
|
26
|
-
delay =
|
26
|
+
a = bst.State(bst.random.random(10, 20))
|
27
|
+
delay = bst.Delay(a.value)
|
27
28
|
delay.register_entry('a', 1.)
|
28
29
|
delay.register_entry('b', 2.)
|
29
30
|
delay.register_entry('c', None)
|
@@ -31,10 +32,10 @@ class TestVarDelay(unittest.TestCase):
|
|
31
32
|
delay.init_state()
|
32
33
|
with self.assertRaises(KeyError):
|
33
34
|
delay.register_entry('c', 10.)
|
34
|
-
|
35
|
+
bst.util.clear_buffer_memory()
|
35
36
|
|
36
37
|
def test_rotation_delay(self):
|
37
|
-
rotation_delay =
|
38
|
+
rotation_delay = bst.Delay(jnp.ones((1,)))
|
38
39
|
t0 = 0.
|
39
40
|
t1, n1 = 1., 10
|
40
41
|
t2, n2 = 2., 20
|
@@ -51,16 +52,16 @@ class TestVarDelay(unittest.TestCase):
|
|
51
52
|
# print(rotation_delay.max_length)
|
52
53
|
|
53
54
|
for i in range(100):
|
54
|
-
|
55
|
+
bst.environ.set(i=i)
|
55
56
|
rotation_delay(jnp.ones((1,)) * i)
|
56
57
|
# print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c2'), rotation_delay.at('c'))
|
57
58
|
self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
|
58
59
|
self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
|
59
60
|
self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
|
60
|
-
|
61
|
+
bst.util.clear_buffer_memory()
|
61
62
|
|
62
63
|
def test_concat_delay(self):
|
63
|
-
rotation_delay =
|
64
|
+
rotation_delay = bst.Delay(jnp.ones([1]), delay_method='concat')
|
64
65
|
t0 = 0.
|
65
66
|
t1, n1 = 1., 10
|
66
67
|
t2, n2 = 2., 20
|
@@ -73,17 +74,91 @@ class TestVarDelay(unittest.TestCase):
|
|
73
74
|
|
74
75
|
print()
|
75
76
|
for i in range(100):
|
76
|
-
|
77
|
+
bst.environ.set(i=i)
|
77
78
|
rotation_delay(jnp.ones((1,)) * i)
|
78
79
|
print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
|
79
80
|
self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
|
80
81
|
self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
|
81
82
|
self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
|
82
|
-
|
83
|
+
bst.util.clear_buffer_memory()
|
84
|
+
|
85
|
+
def test_jit_erro(self):
|
86
|
+
rotation_delay = bst.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
|
87
|
+
rotation_delay.init_state()
|
88
|
+
|
89
|
+
with bst.environ.context(i=0, t=0):
|
90
|
+
rotation_delay.retrieve_at_time(-2.0)
|
91
|
+
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
|
92
|
+
rotation_delay.retrieve_at_time(-2.1)
|
93
|
+
rotation_delay.retrieve_at_time(-2.01)
|
94
|
+
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
|
95
|
+
rotation_delay.retrieve_at_time(-2.09)
|
96
|
+
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
|
97
|
+
rotation_delay.retrieve_at_time(0.1)
|
98
|
+
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
|
99
|
+
rotation_delay.retrieve_at_time(0.01)
|
100
|
+
|
101
|
+
def test_round_interp(self):
|
102
|
+
for shape in [(1,), (1, 1), (1, 1, 1)]:
|
103
|
+
for delay_method in ['rotation', 'concat']:
|
104
|
+
rotation_delay = bst.Delay(jnp.ones(shape), time=2., delay_method=delay_method, interp_method='round')
|
105
|
+
t0, n1 = 0.01, 0
|
106
|
+
t1, n1 = 1.04, 10
|
107
|
+
t2, n2 = 1.06, 11
|
108
|
+
rotation_delay.init_state()
|
109
|
+
|
110
|
+
@bst.transform.jit
|
111
|
+
def retrieve(td, i):
|
112
|
+
with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
|
113
|
+
return rotation_delay.retrieve_at_time(td)
|
114
|
+
|
115
|
+
print()
|
116
|
+
for i in range(100):
|
117
|
+
t = i * bst.environ.get_dt()
|
118
|
+
with bst.environ.context(i=i, t=t):
|
119
|
+
rotation_delay(jnp.ones(shape) * i)
|
120
|
+
print(i,
|
121
|
+
retrieve(t - t0, i),
|
122
|
+
retrieve(t - t1, i),
|
123
|
+
retrieve(t - t2, i))
|
124
|
+
self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.ones(shape) * i))
|
125
|
+
self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.)))
|
126
|
+
self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
|
127
|
+
bst.util.clear_buffer_memory()
|
128
|
+
|
129
|
+
def test_linear_interp(self):
|
130
|
+
for shape in [(1,), (1, 1), (1, 1, 1)]:
|
131
|
+
for delay_method in ['rotation', 'concat']:
|
132
|
+
print(shape, delay_method)
|
133
|
+
|
134
|
+
rotation_delay = bst.Delay(jnp.ones(shape), time=2., delay_method=delay_method, interp_method='linear_interp')
|
135
|
+
t0, n0 = 0.01, 0.1
|
136
|
+
t1, n1 = 1.04, 10.4
|
137
|
+
t2, n2 = 1.06, 10.6
|
138
|
+
rotation_delay.init_state()
|
139
|
+
|
140
|
+
@bst.transform.jit
|
141
|
+
def retrieve(td, i):
|
142
|
+
with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
|
143
|
+
return rotation_delay.retrieve_at_time(td)
|
144
|
+
|
145
|
+
print()
|
146
|
+
for i in range(100):
|
147
|
+
t = i * bst.environ.get_dt()
|
148
|
+
with bst.environ.context(i=i, t=t):
|
149
|
+
rotation_delay(jnp.ones(shape) * i)
|
150
|
+
print(i,
|
151
|
+
retrieve(t - t0, i),
|
152
|
+
retrieve(t - t1, i),
|
153
|
+
retrieve(t - t2, i))
|
154
|
+
self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.maximum(jnp.ones(shape) * i - n0, 0.)))
|
155
|
+
self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.)))
|
156
|
+
self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
|
157
|
+
bst.util.clear_buffer_memory()
|
83
158
|
|
84
159
|
def test_rotation_and_concat_delay(self):
|
85
|
-
rotation_delay =
|
86
|
-
concat_delay =
|
160
|
+
rotation_delay = bst.Delay(jnp.ones((1,)))
|
161
|
+
concat_delay = bst.Delay(jnp.ones([1]), delay_method='concat')
|
87
162
|
t0 = 0.
|
88
163
|
t1, n1 = 1., 10
|
89
164
|
t2, n2 = 2., 20
|
@@ -100,29 +175,29 @@ class TestVarDelay(unittest.TestCase):
|
|
100
175
|
|
101
176
|
print()
|
102
177
|
for i in range(100):
|
103
|
-
|
178
|
+
bst.environ.set(i=i)
|
104
179
|
new = jnp.ones((1,)) * i
|
105
180
|
rotation_delay(new)
|
106
181
|
concat_delay(new)
|
107
182
|
self.assertTrue(jnp.allclose(rotation_delay.at('a'), concat_delay.at('a'), ))
|
108
183
|
self.assertTrue(jnp.allclose(rotation_delay.at('b'), concat_delay.at('b'), ))
|
109
184
|
self.assertTrue(jnp.allclose(rotation_delay.at('c'), concat_delay.at('c'), ))
|
110
|
-
|
185
|
+
bst.util.clear_buffer_memory()
|
111
186
|
|
112
187
|
|
113
188
|
class TestModule(unittest.TestCase):
|
114
189
|
def test_states(self):
|
115
|
-
class A(
|
190
|
+
class A(bst.Module):
|
116
191
|
def __init__(self):
|
117
192
|
super().__init__()
|
118
|
-
self.a =
|
119
|
-
self.b =
|
193
|
+
self.a = bst.State(bst.random.random(10, 20))
|
194
|
+
self.b = bst.State(bst.random.random(10, 20))
|
120
195
|
|
121
|
-
class B(
|
196
|
+
class B(bst.Module):
|
122
197
|
def __init__(self):
|
123
198
|
super().__init__()
|
124
199
|
self.a = A()
|
125
|
-
self.b =
|
200
|
+
self.b = bst.State(bst.random.random(10, 20))
|
126
201
|
|
127
202
|
b = B()
|
128
203
|
print()
|
@@ -130,4 +205,3 @@ class TestModule(unittest.TestCase):
|
|
130
205
|
print(b.states())
|
131
206
|
print(b.states(level=0))
|
132
207
|
print(b.states(level=0))
|
133
|
-
|
brainstate/_state.py
CHANGED
@@ -59,6 +59,23 @@ _global_context_to_check_state_tree = [False]
|
|
59
59
|
def check_state_value_tree() -> None:
|
60
60
|
"""
|
61
61
|
The context manager to check whether the tree structure of the state value stays consistent.
|
62
|
+
|
63
|
+
Once a :py:class:`~.State` is created, the tree structure of the value is fixed. By default,
|
64
|
+
the tree structure of the value is not checked, to avoid repeated evaluation.
|
65
|
+
If you want to check the tree structure of the value once the new value is assigned,
|
66
|
+
you can use this context manager.
|
67
|
+
|
68
|
+
Example::
|
69
|
+
|
70
|
+
```python
|
71
|
+
state = brainstate.ShortTermState(jnp.zeros((2, 3)))
|
72
|
+
with check_state_value_tree():
|
73
|
+
state.value = jnp.zeros((2, 3))
|
74
|
+
|
75
|
+
# The following code will raise an error.
|
76
|
+
state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))
|
77
|
+
```
|
78
|
+
|
62
79
|
"""
|
63
80
|
try:
|
64
81
|
_global_context_to_check_state_tree.append(True)
|