brainstate 0.0.1__py2.py3-none-any.whl → 0.0.1.post20240622__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. brainstate/__init__.py +4 -5
  2. brainstate/_module.py +191 -48
  3. brainstate/_module_test.py +95 -21
  4. brainstate/_state.py +17 -0
  5. brainstate/environ.py +2 -2
  6. brainstate/functional/__init__.py +3 -2
  7. brainstate/functional/_activations.py +7 -26
  8. brainstate/functional/_normalization.py +3 -0
  9. brainstate/functional/_others.py +49 -0
  10. brainstate/functional/_spikes.py +0 -1
  11. brainstate/mixin.py +2 -2
  12. brainstate/nn/__init__.py +4 -0
  13. brainstate/nn/_base.py +10 -7
  14. brainstate/nn/_dynamics.py +20 -0
  15. brainstate/nn/_elementwise.py +5 -4
  16. brainstate/nn/_embedding.py +66 -0
  17. brainstate/nn/_misc.py +4 -3
  18. brainstate/nn/_others.py +3 -2
  19. brainstate/nn/_poolings.py +21 -20
  20. brainstate/nn/_poolings_test.py +4 -4
  21. brainstate/nn/_rate_rnns.py +17 -0
  22. brainstate/nn/_readout.py +6 -0
  23. brainstate/optim/__init__.py +0 -1
  24. brainstate/optim/_lr_scheduler_test.py +13 -0
  25. brainstate/optim/_sgd_optimizer.py +18 -17
  26. brainstate/transform/__init__.py +2 -3
  27. brainstate/transform/_autograd.py +1 -1
  28. brainstate/transform/_autograd_test.py +0 -2
  29. brainstate/transform/_jit.py +47 -21
  30. brainstate/transform/_jit_test.py +0 -3
  31. brainstate/transform/_make_jaxpr.py +164 -3
  32. brainstate/transform/_make_jaxpr_test.py +0 -2
  33. brainstate/transform/_progress_bar.py +1 -3
  34. brainstate/util.py +0 -1
  35. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/METADATA +9 -17
  36. brainstate-0.0.1.post20240622.dist-info/RECORD +64 -0
  37. brainstate/math/__init__.py +0 -21
  38. brainstate/math/_einops.py +0 -787
  39. brainstate/math/_einops_parsing.py +0 -169
  40. brainstate/math/_einops_parsing_test.py +0 -126
  41. brainstate/math/_einops_test.py +0 -346
  42. brainstate/math/_misc.py +0 -298
  43. brainstate/math/_misc_test.py +0 -58
  44. brainstate/nn/functional/__init__.py +0 -25
  45. brainstate/nn/functional/_activations.py +0 -754
  46. brainstate/nn/functional/_normalization.py +0 -69
  47. brainstate/nn/functional/_spikes.py +0 -90
  48. brainstate/nn/init/__init__.py +0 -26
  49. brainstate/nn/init/_base.py +0 -36
  50. brainstate/nn/init/_generic.py +0 -175
  51. brainstate/nn/init/_random_inits.py +0 -489
  52. brainstate/nn/init/_regular_inits.py +0 -109
  53. brainstate/nn/surrogate.py +0 -1740
  54. brainstate-0.0.1.dist-info/RECORD +0 -79
  55. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/LICENSE +0 -0
  56. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/WHEEL +0 -0
  57. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/top_level.txt +0 -0
brainstate/__init__.py CHANGED
@@ -20,17 +20,16 @@ A ``State``-based Transformation System for Brain Dynamics Programming
20
20
  __version__ = "0.0.1"
21
21
 
22
22
  from . import environ
23
- from . import math
23
+ from . import functional
24
+ from . import init
24
25
  from . import mixin
25
26
  from . import nn
26
27
  from . import optim
27
28
  from . import random
29
+ from . import surrogate
28
30
  from . import transform
29
31
  from . import typing
30
32
  from . import util
31
- from . import surrogate
32
- from . import functional
33
- from . import init
34
33
  from ._module import *
35
34
  from ._module import __all__ as _module_all
36
35
  from ._state import *
@@ -39,7 +38,7 @@ from ._state import __all__ as _state_all
39
38
  __all__ = (
40
39
  ['environ', 'share', 'nn', 'optim', 'random',
41
40
  'surrogate', 'functional', 'init',
42
- 'mixin', 'math', 'transform', 'util', 'typing'] +
41
+ 'mixin', 'transform', 'util', 'typing'] +
43
42
  _module_all + _state_all
44
43
  )
45
44
  del _module_all, _state_all
brainstate/_module.py CHANGED
@@ -46,7 +46,6 @@ For handling the delays:
46
46
 
47
47
  """
48
48
 
49
- import inspect
50
49
  import math
51
50
  import numbers
52
51
  from collections import namedtuple
@@ -58,20 +57,21 @@ import jax.numpy as jnp
58
57
  import numpy as np
59
58
 
60
59
  from . import environ
61
- from ._utils import set_module_as
62
60
  from ._state import State, StateDictManager, visible_state_dict
63
- from .util import unique_name, DictManager, get_unique_name, DotDict
64
- from .math import get_dtype
61
+ from ._utils import set_module_as
65
62
  from .mixin import Mixin, Mode, DelayedInit, AllOfTypes, Batching, UpdateReturn
66
- from .transform._jit_error import jit_error
63
+ from .transform import jit_error
64
+ from .util import unique_name, DictManager, get_unique_name
67
65
 
68
66
  Shape = Union[int, Sequence[int]]
69
67
  PyTree = Any
70
68
  ArrayLike = jax.typing.ArrayLike
71
69
 
72
70
  delay_identifier = '_*_delay_of_'
73
- ROTATE_UPDATE = 'rotation'
74
- CONCAT_UPDATE = 'concat'
71
+ _DELAY_ROTATE = 'rotation'
72
+ _DELAY_CONCAT = 'concat'
73
+ _INTERP_LINEAR = 'linear_interp'
74
+ _INTERP_ROUND = 'round'
75
75
 
76
76
  StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
77
77
 
@@ -92,7 +92,7 @@ __all__ = [
92
92
  'call_order',
93
93
 
94
94
  # state processing
95
- 'init_states', 'load_states', 'save_states', 'assign_state_values',
95
+ 'init_states', 'reset_states', 'load_states', 'save_states', 'assign_state_values',
96
96
  ]
97
97
 
98
98
 
@@ -258,9 +258,8 @@ class Module(object):
258
258
  --------
259
259
 
260
260
  >>> import brainstate as bst
261
- >>> import brainscale as nn # noqa
262
261
  >>> x = bst.random.rand((10, 10))
263
- >>> l = nn.Activation(jax.numpy.tanh)
262
+ >>> l = bst.nn.Activation(jax.numpy.tanh)
264
263
  >>> y = x >> l
265
264
  """
266
265
  return self.__call__(other)
@@ -271,6 +270,12 @@ class Module(object):
271
270
  """
272
271
  pass
273
272
 
273
+ def reset_state(self, *args, **kwargs):
274
+ """
275
+ State resetting function.
276
+ """
277
+ pass
278
+
274
279
  def save_state(self, **kwargs) -> Dict:
275
280
  """Save states as a dictionary. """
276
281
  return self.states(include_self=True, level=0, method='absolute')
@@ -396,8 +401,8 @@ class visible_module_list(list):
396
401
  retieved when using :py:func:`~.nodes()` function.
397
402
 
398
403
  >>> import brainstate as bst
399
- >>> l = bst.visible_module_list([bp.dnn.Dense(1, 2),
400
- >>> bp.dnn.LSTMCell(2, 3)])
404
+ >>> l = bst.visible_module_list([bst.nn.Linear(1, 2),
405
+ >>> bst.nn.LSTMCell(2, 3)])
401
406
  """
402
407
 
403
408
  __module__ = 'brainstate'
@@ -1033,14 +1038,14 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1033
1038
  delay = length data ]
1034
1039
  entries: optional, dict. The delay access entries.
1035
1040
  name: str. The delay name.
1036
- method: str. The method used for updating delay. Default None.
1041
+ delay_method: str. The method used for updating delay. Default None.
1037
1042
  mode: Mode. The computing mode. Default None.
1038
1043
  """
1039
1044
 
1040
1045
  __module__ = 'brainstate'
1041
1046
 
1042
- non_hash_params = ('time', 'entries', 'name')
1043
- max_time: float
1047
+ non_hashable_params = ('time', 'entries', 'name')
1048
+ max_time: float #
1044
1049
  max_length: int
1045
1050
  history: Optional[State]
1046
1051
 
@@ -1048,20 +1053,27 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1048
1053
  self,
1049
1054
  target_info: PyTree,
1050
1055
  time: Optional[Union[int, float]] = None, # delay time
1051
- init: Optional[Union[ArrayLike, Callable]] = None, # delay data init
1056
+ init: Optional[Union[ArrayLike, Callable]] = None, # delay data before t0
1052
1057
  entries: Optional[Dict] = None, # delay access entry
1053
- method: Optional[str] = ROTATE_UPDATE, # delay method
1058
+ delay_method: Optional[str] = _DELAY_ROTATE, # delay method
1059
+ interp_method: str = _INTERP_LINEAR, # interpolation method
1054
1060
  # others
1055
1061
  name: Optional[str] = None,
1056
1062
  mode: Optional[Mode] = None,
1057
1063
  ):
1058
1064
 
1059
1065
  # target information
1060
- self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, get_dtype(a)), target_info)
1066
+ self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)
1061
1067
 
1062
1068
  # delay method
1063
- assert method in [ROTATE_UPDATE, CONCAT_UPDATE]
1064
- self.method = method
1069
+ assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (f'Un-supported delay method {delay_method}. '
1070
+ f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}')
1071
+ self.delay_method = delay_method
1072
+
1073
+ # interp method
1074
+ assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (f'Un-supported interpolation method {interp_method}. '
1075
+ f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
1076
+ self.interp_method = interp_method
1065
1077
 
1066
1078
  # delay length and time
1067
1079
  self.max_time, delay_length = _get_delay(time, None)
@@ -1071,7 +1083,8 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1071
1083
 
1072
1084
  # delay data
1073
1085
  if init is not None:
1074
- assert isinstance(init, (numbers.Number, jax.Array, Callable))
1086
+ if not isinstance(init, (numbers.Number, jax.Array, np.ndarray, Callable)):
1087
+ raise TypeError(f'init should be Array, Callable, or None. But got {init}')
1075
1088
  self._init = init
1076
1089
  self._history = None
1077
1090
 
@@ -1085,7 +1098,11 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1085
1098
 
1086
1099
  def __repr__(self):
1087
1100
  name = self.__class__.__name__
1088
- return f'{name}(delay_length={self.max_length}, target_info={self.target_info}, method="{self.method}")'
1101
+ return (f'{name}('
1102
+ f'delay_length={self.max_length}, '
1103
+ f'target_info={self.target_info}, '
1104
+ f'delay_method="{self.delay_method}", '
1105
+ f'interp_method="{self.interp_method}")')
1089
1106
 
1090
1107
  @property
1091
1108
  def history(self):
@@ -1100,7 +1117,7 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1100
1117
  if batch_size is not None:
1101
1118
  shape.insert(self.mode.batch_axis, batch_size)
1102
1119
  shape.insert(0, length)
1103
- if isinstance(self._init, (jax.Array, numbers.Number)):
1120
+ if isinstance(self._init, (jax.Array, np.ndarray, numbers.Number)):
1104
1121
  data = jnp.broadcast_to(jnp.asarray(self._init, a.dtype), shape)
1105
1122
  elif callable(self._init):
1106
1123
  data = self._init(shape, dtype=a.dtype)
@@ -1115,13 +1132,20 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1115
1132
  fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
1116
1133
  self.history = State(jax.tree.map(fun, self.target_info))
1117
1134
 
1135
+ def reset_state(self, batch_size: int = None, **kwargs):
1136
+ if batch_size is not None:
1137
+ assert self.mode.has(Batching), 'The mode should have Batching behavior when batch_size is not None.'
1138
+ fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
1139
+ self.history.value = jax.tree.map(fun, self.target_info)
1140
+
1118
1141
  def register_entry(
1119
1142
  self,
1120
1143
  entry: str,
1121
1144
  delay_time: Optional[Union[int, float]] = None,
1122
1145
  delay_step: Optional[int] = None,
1123
1146
  ) -> 'Delay':
1124
- """Register an entry to access the data.
1147
+ """
1148
+ Register an entry to access the delay data.
1125
1149
 
1126
1150
  Args:
1127
1151
  entry: str. The entry to access the delay data.
@@ -1151,7 +1175,8 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1151
1175
  return self
1152
1176
 
1153
1177
  def at(self, entry: str, *indices) -> ArrayLike:
1154
- """Get the data at the given entry.
1178
+ """
1179
+ Get the data at the given entry.
1155
1180
 
1156
1181
  Args:
1157
1182
  entry: str. The entry to access the data.
@@ -1167,20 +1192,28 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1167
1192
  delay_step = self._registered_entries[entry]
1168
1193
  if delay_step is None:
1169
1194
  delay_step = 0
1170
- return self.retrieve(delay_step, *indices)
1195
+ return self.retrieve_at_step(delay_step, *indices)
1171
1196
 
1172
- def retrieve(self, delay_step, *indices):
1173
- """Retrieve the delay data according to the delay length.
1197
+ def retrieve_at_step(self, delay_step, *indices) -> PyTree:
1198
+ """
1199
+ Retrieve the delay data at the given delay time step (the integer to indicate the time step).
1174
1200
 
1175
1201
  Parameters
1176
1202
  ----------
1177
- delay_step: int
1178
- The delay length used to retrieve the data.
1203
+ delay_step: int_like
1204
+ Retrieve the data at the given time step.
1205
+ indices: tuple
1206
+ The indices to slice the data.
1207
+
1208
+ Returns
1209
+ -------
1210
+ delay_data: The delay data at the given delay step.
1211
+
1179
1212
  """
1180
1213
  assert self.history is not None, 'The delay history is not initialized.'
1181
1214
  assert delay_step is not None, 'The delay step should be given.'
1182
1215
 
1183
- if environ.get(environ.JIT_ERROR_CHECK, False):
1216
+ if environ.get(environ.JIT_ERROR_CHECK, True):
1184
1217
  def _check_delay(delay_len):
1185
1218
  raise ValueError(f'The request delay length should be less than the '
1186
1219
  f'maximum delay {self.max_length}. But we got {delay_len}')
@@ -1188,17 +1221,17 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1188
1221
  jit_error(delay_step >= self.max_length, _check_delay, delay_step)
1189
1222
 
1190
1223
  # rotation method
1191
- if self.method == ROTATE_UPDATE:
1192
- i = environ.get(environ.I)
1224
+ if self.delay_method == _DELAY_ROTATE:
1225
+ i = environ.get(environ.I, desc='The time step index.')
1193
1226
  di = i - delay_step
1194
1227
  delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
1195
1228
  delay_idx = jax.lax.stop_gradient(delay_idx)
1196
1229
 
1197
- elif self.method == CONCAT_UPDATE:
1230
+ elif self.delay_method == _DELAY_CONCAT:
1198
1231
  delay_idx = delay_step
1199
1232
 
1200
1233
  else:
1201
- raise ValueError(f'Unknown updating method "{self.method}"')
1234
+ raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
1202
1235
 
1203
1236
  # the delay index
1204
1237
  if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
@@ -1208,6 +1241,81 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1208
1241
  # the delay data
1209
1242
  return jax.tree.map(lambda a: a[indices], self.history.value)
1210
1243
 
1244
+ def retrieve_at_time(self, delay_time, *indices) -> PyTree:
1245
+ """
1246
+ Retrieve the delay data at the given delay time step (the integer to indicate the time step).
1247
+
1248
+ Parameters
1249
+ ----------
1250
+ delay_time: float
1251
+ Retrieve the data at the given time.
1252
+ indices: tuple
1253
+ The indices to slice the data.
1254
+
1255
+ Returns
1256
+ -------
1257
+ delay_data: The delay data at the given delay step.
1258
+
1259
+ """
1260
+ assert self.history is not None, 'The delay history is not initialized.'
1261
+ assert delay_time is not None, 'The delay time should be given.'
1262
+
1263
+ current_time = environ.get(environ.T, desc='The current time.')
1264
+ dt = environ.get_dt()
1265
+
1266
+ if environ.get(environ.JIT_ERROR_CHECK, True):
1267
+ def _check_delay(args):
1268
+ t_now, t_delay = args
1269
+ raise ValueError(f'The request delay time should be within '
1270
+ f'[{t_now - self.max_time - dt}, {t_now}], '
1271
+ f'but we got {t_delay}')
1272
+
1273
+ jit_error(jnp.logical_or(delay_time > current_time,
1274
+ delay_time < current_time - self.max_time - dt),
1275
+ _check_delay,
1276
+ (current_time, delay_time))
1277
+
1278
+ diff = current_time - delay_time
1279
+ float_time_step = diff / dt
1280
+
1281
+ if self.interp_method == _INTERP_LINEAR: # "linear" interpolation
1282
+ # def _interp(target):
1283
+ # if len(indices) > 0:
1284
+ # raise NotImplementedError('The slicing indices are not supported in the linear interpolation.')
1285
+ # if self.delay_method == _DELAY_ROTATE:
1286
+ # i = environ.get(environ.I, desc='The time step index.')
1287
+ # _interp_fun = partial(jnp.interp, period=self.max_length)
1288
+ # for dim in range(1, target.ndim, 1):
1289
+ # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
1290
+ # di = i - jnp.arange(self.max_length)
1291
+ # delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
1292
+ # return _interp_fun(float_time_step, delay_idx, target)
1293
+ #
1294
+ # elif self.delay_method == _DELAY_CONCAT:
1295
+ # _interp_fun = partial(jnp.interp, period=self.max_length)
1296
+ # for dim in range(1, target.ndim, 1):
1297
+ # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
1298
+ # return _interp_fun(float_time_step, jnp.arange(self.max_length), target)
1299
+ #
1300
+ # else:
1301
+ # raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
1302
+ # return jax.tree.map(_interp, self.history.value)
1303
+
1304
+ data_at_t0 = self.retrieve_at_step(jnp.asarray(jnp.floor(float_time_step), dtype=jnp.int32), *indices)
1305
+ data_at_t1 = self.retrieve_at_step(jnp.asarray(jnp.ceil(float_time_step), dtype=jnp.int32), *indices)
1306
+ t_diff = float_time_step - jnp.floor(float_time_step)
1307
+ return jax.tree.map(lambda a, b: a * (1 - t_diff) + b * t_diff, data_at_t0, data_at_t1)
1308
+
1309
+ elif self.interp_method == _INTERP_ROUND: # "round" interpolation
1310
+ return self.retrieve_at_step(
1311
+ jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32),
1312
+ *indices
1313
+ )
1314
+
1315
+ else: # raise error
1316
+ raise ValueError(f'Un-supported interpolation method {self.interp_method}, '
1317
+ f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
1318
+
1211
1319
  def update(self, current: PyTree) -> None:
1212
1320
  """
1213
1321
  Update delay variable with the new data.
@@ -1215,25 +1323,29 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1215
1323
  assert self.history is not None, 'The delay history is not initialized.'
1216
1324
 
1217
1325
  # update the delay data at the rotation index
1218
- if self.method == ROTATE_UPDATE:
1326
+ if self.delay_method == _DELAY_ROTATE:
1219
1327
  i = environ.get(environ.I)
1220
1328
  idx = jnp.asarray(i % self.max_length, dtype=environ.dutype())
1221
1329
  idx = jax.lax.stop_gradient(idx)
1222
- self.history.value = jax.tree.map(lambda hist, cur: hist.at[idx].set(cur),
1223
- self.history.value,
1224
- current)
1330
+ self.history.value = jax.tree.map(
1331
+ lambda hist, cur: hist.at[idx].set(cur),
1332
+ self.history.value,
1333
+ current
1334
+ )
1225
1335
  # update the delay data at the first position
1226
- elif self.method == CONCAT_UPDATE:
1336
+ elif self.delay_method == _DELAY_CONCAT:
1227
1337
  current = jax.tree.map(lambda a: jnp.expand_dims(a, 0), current)
1228
1338
  if self.max_length > 1:
1229
- self.history.value = jax.tree.map(lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
1230
- self.history.value,
1231
- current)
1339
+ self.history.value = jax.tree.map(
1340
+ lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
1341
+ self.history.value,
1342
+ current
1343
+ )
1232
1344
  else:
1233
1345
  self.history.value = current
1234
1346
 
1235
1347
  else:
1236
- raise ValueError(f'Unknown updating method "{self.method}"')
1348
+ raise ValueError(f'Unknown updating method "{self.delay_method}"')
1237
1349
 
1238
1350
 
1239
1351
  class _StateDelay(Delay):
@@ -1254,14 +1366,18 @@ class _StateDelay(Delay):
1254
1366
  time: Optional[Union[int, float]] = None, # delay time
1255
1367
  init: Optional[Union[ArrayLike, Callable]] = None, # delay data init
1256
1368
  entries: Optional[Dict] = None, # delay access entry
1257
- method: Optional[str] = ROTATE_UPDATE, # delay method
1369
+ delay_method: Optional[str] = _DELAY_ROTATE, # delay method
1258
1370
  # others
1259
1371
  name: Optional[str] = None,
1260
1372
  mode: Optional[Mode] = None,
1261
1373
  ):
1262
1374
  super().__init__(target_info=target.value,
1263
- time=time, init=init, entries=entries,
1264
- method=method, name=name, mode=mode)
1375
+ time=time,
1376
+ init=init,
1377
+ entries=entries,
1378
+ delay_method=delay_method,
1379
+ name=name,
1380
+ mode=mode)
1265
1381
  self.target = target
1266
1382
 
1267
1383
  def update(self, *args, **kwargs):
@@ -1344,7 +1460,7 @@ def call_order(level: int = 0):
1344
1460
  @set_module_as('brainstate')
1345
1461
  def init_states(target: Module, *args, **kwargs) -> Module:
1346
1462
  """
1347
- Reset states of all children nodes in the given target.
1463
+ Initialize states of all children nodes in the given target.
1348
1464
 
1349
1465
  Args:
1350
1466
  target: The target Module.
@@ -1368,6 +1484,33 @@ def init_states(target: Module, *args, **kwargs) -> Module:
1368
1484
  return target
1369
1485
 
1370
1486
 
1487
+ @set_module_as('brainstate')
1488
+ def reset_states(target: Module, *args, **kwargs) -> Module:
1489
+ """
1490
+ Reset states of all children nodes in the given target.
1491
+
1492
+ Args:
1493
+ target: The target Module.
1494
+
1495
+ Returns:
1496
+ The target Module.
1497
+ """
1498
+ nodes_with_order = []
1499
+
1500
+ # reset node whose `init_state` has no `call_order`
1501
+ for node in list(target.nodes().values()):
1502
+ if not hasattr(node.reset_state, 'call_order'):
1503
+ node.reset_state(*args, **kwargs)
1504
+ else:
1505
+ nodes_with_order.append(node)
1506
+
1507
+ # reset the node's states
1508
+ for node in sorted(nodes_with_order, key=lambda x: x.reset_state.call_order):
1509
+ node.reset_state(*args, **kwargs)
1510
+
1511
+ return target
1512
+
1513
+
1371
1514
  @set_module_as('brainstate')
1372
1515
  def load_states(target: Module, state_dict: Dict, **kwargs):
1373
1516
  """Copy parameters and buffers from :attr:`state_dict` into
@@ -16,14 +16,15 @@
16
16
  import unittest
17
17
 
18
18
  import jax.numpy as jnp
19
+ import jaxlib.xla_extension
19
20
 
20
- import brainstate as bc
21
+ import brainstate as bst
21
22
 
22
23
 
23
- class TestVarDelay(unittest.TestCase):
24
+ class TestDelay(unittest.TestCase):
24
25
  def test_delay1(self):
25
- a = bc.State(bc.random.random(10, 20))
26
- delay = bc.Delay(a.value)
26
+ a = bst.State(bst.random.random(10, 20))
27
+ delay = bst.Delay(a.value)
27
28
  delay.register_entry('a', 1.)
28
29
  delay.register_entry('b', 2.)
29
30
  delay.register_entry('c', None)
@@ -31,10 +32,10 @@ class TestVarDelay(unittest.TestCase):
31
32
  delay.init_state()
32
33
  with self.assertRaises(KeyError):
33
34
  delay.register_entry('c', 10.)
34
- bc.util.clear_buffer_memory()
35
+ bst.util.clear_buffer_memory()
35
36
 
36
37
  def test_rotation_delay(self):
37
- rotation_delay = bc.Delay(jnp.ones((1,)))
38
+ rotation_delay = bst.Delay(jnp.ones((1,)))
38
39
  t0 = 0.
39
40
  t1, n1 = 1., 10
40
41
  t2, n2 = 2., 20
@@ -51,16 +52,16 @@ class TestVarDelay(unittest.TestCase):
51
52
  # print(rotation_delay.max_length)
52
53
 
53
54
  for i in range(100):
54
- bc.environ.set(i=i)
55
+ bst.environ.set(i=i)
55
56
  rotation_delay(jnp.ones((1,)) * i)
56
57
  # print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c2'), rotation_delay.at('c'))
57
58
  self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
58
59
  self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
59
60
  self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
60
- bc.util.clear_buffer_memory()
61
+ bst.util.clear_buffer_memory()
61
62
 
62
63
  def test_concat_delay(self):
63
- rotation_delay = bc.Delay(jnp.ones([1]), method='concat')
64
+ rotation_delay = bst.Delay(jnp.ones([1]), delay_method='concat')
64
65
  t0 = 0.
65
66
  t1, n1 = 1., 10
66
67
  t2, n2 = 2., 20
@@ -73,17 +74,91 @@ class TestVarDelay(unittest.TestCase):
73
74
 
74
75
  print()
75
76
  for i in range(100):
76
- bc.environ.set(i=i)
77
+ bst.environ.set(i=i)
77
78
  rotation_delay(jnp.ones((1,)) * i)
78
79
  print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
79
80
  self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
80
81
  self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
81
82
  self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
82
- bc.util.clear_buffer_memory()
83
+ bst.util.clear_buffer_memory()
84
+
85
+ def test_jit_erro(self):
86
+ rotation_delay = bst.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
87
+ rotation_delay.init_state()
88
+
89
+ with bst.environ.context(i=0, t=0):
90
+ rotation_delay.retrieve_at_time(-2.0)
91
+ with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
92
+ rotation_delay.retrieve_at_time(-2.1)
93
+ rotation_delay.retrieve_at_time(-2.01)
94
+ with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
95
+ rotation_delay.retrieve_at_time(-2.09)
96
+ with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
97
+ rotation_delay.retrieve_at_time(0.1)
98
+ with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
99
+ rotation_delay.retrieve_at_time(0.01)
100
+
101
+ def test_round_interp(self):
102
+ for shape in [(1,), (1, 1), (1, 1, 1)]:
103
+ for delay_method in ['rotation', 'concat']:
104
+ rotation_delay = bst.Delay(jnp.ones(shape), time=2., delay_method=delay_method, interp_method='round')
105
+ t0, n1 = 0.01, 0
106
+ t1, n1 = 1.04, 10
107
+ t2, n2 = 1.06, 11
108
+ rotation_delay.init_state()
109
+
110
+ @bst.transform.jit
111
+ def retrieve(td, i):
112
+ with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
113
+ return rotation_delay.retrieve_at_time(td)
114
+
115
+ print()
116
+ for i in range(100):
117
+ t = i * bst.environ.get_dt()
118
+ with bst.environ.context(i=i, t=t):
119
+ rotation_delay(jnp.ones(shape) * i)
120
+ print(i,
121
+ retrieve(t - t0, i),
122
+ retrieve(t - t1, i),
123
+ retrieve(t - t2, i))
124
+ self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.ones(shape) * i))
125
+ self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.)))
126
+ self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
127
+ bst.util.clear_buffer_memory()
128
+
129
+ def test_linear_interp(self):
130
+ for shape in [(1,), (1, 1), (1, 1, 1)]:
131
+ for delay_method in ['rotation', 'concat']:
132
+ print(shape, delay_method)
133
+
134
+ rotation_delay = bst.Delay(jnp.ones(shape), time=2., delay_method=delay_method, interp_method='linear_interp')
135
+ t0, n0 = 0.01, 0.1
136
+ t1, n1 = 1.04, 10.4
137
+ t2, n2 = 1.06, 10.6
138
+ rotation_delay.init_state()
139
+
140
+ @bst.transform.jit
141
+ def retrieve(td, i):
142
+ with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
143
+ return rotation_delay.retrieve_at_time(td)
144
+
145
+ print()
146
+ for i in range(100):
147
+ t = i * bst.environ.get_dt()
148
+ with bst.environ.context(i=i, t=t):
149
+ rotation_delay(jnp.ones(shape) * i)
150
+ print(i,
151
+ retrieve(t - t0, i),
152
+ retrieve(t - t1, i),
153
+ retrieve(t - t2, i))
154
+ self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.maximum(jnp.ones(shape) * i - n0, 0.)))
155
+ self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.)))
156
+ self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
157
+ bst.util.clear_buffer_memory()
83
158
 
84
159
  def test_rotation_and_concat_delay(self):
85
- rotation_delay = bc.Delay(jnp.ones((1,)))
86
- concat_delay = bc.Delay(jnp.ones([1]), method='concat')
160
+ rotation_delay = bst.Delay(jnp.ones((1,)))
161
+ concat_delay = bst.Delay(jnp.ones([1]), delay_method='concat')
87
162
  t0 = 0.
88
163
  t1, n1 = 1., 10
89
164
  t2, n2 = 2., 20
@@ -100,29 +175,29 @@ class TestVarDelay(unittest.TestCase):
100
175
 
101
176
  print()
102
177
  for i in range(100):
103
- bc.environ.set(i=i)
178
+ bst.environ.set(i=i)
104
179
  new = jnp.ones((1,)) * i
105
180
  rotation_delay(new)
106
181
  concat_delay(new)
107
182
  self.assertTrue(jnp.allclose(rotation_delay.at('a'), concat_delay.at('a'), ))
108
183
  self.assertTrue(jnp.allclose(rotation_delay.at('b'), concat_delay.at('b'), ))
109
184
  self.assertTrue(jnp.allclose(rotation_delay.at('c'), concat_delay.at('c'), ))
110
- bc.util.clear_buffer_memory()
185
+ bst.util.clear_buffer_memory()
111
186
 
112
187
 
113
188
  class TestModule(unittest.TestCase):
114
189
  def test_states(self):
115
- class A(bc.Module):
190
+ class A(bst.Module):
116
191
  def __init__(self):
117
192
  super().__init__()
118
- self.a = bc.State(bc.random.random(10, 20))
119
- self.b = bc.State(bc.random.random(10, 20))
193
+ self.a = bst.State(bst.random.random(10, 20))
194
+ self.b = bst.State(bst.random.random(10, 20))
120
195
 
121
- class B(bc.Module):
196
+ class B(bst.Module):
122
197
  def __init__(self):
123
198
  super().__init__()
124
199
  self.a = A()
125
- self.b = bc.State(bc.random.random(10, 20))
200
+ self.b = bst.State(bst.random.random(10, 20))
126
201
 
127
202
  b = B()
128
203
  print()
@@ -130,4 +205,3 @@ class TestModule(unittest.TestCase):
130
205
  print(b.states())
131
206
  print(b.states(level=0))
132
207
  print(b.states(level=0))
133
-
brainstate/_state.py CHANGED
@@ -59,6 +59,23 @@ _global_context_to_check_state_tree = [False]
59
59
  def check_state_value_tree() -> None:
60
60
  """
61
61
  The contex manager to check weather the tree structure of the state value keeps consistently.
62
+
63
+ Once a :py:class:`~.State` is created, the tree structure of the value is fixed. In default,
64
+ the tree structure of the value is not checked to avoid off the repeated evaluation.
65
+ If you want to check the tree structure of the value once the new value is assigned,
66
+ you can use this context manager.
67
+
68
+ Example::
69
+
70
+ ```python
71
+ state = brainstate.ShortTermState(jnp.zeros((2, 3)))
72
+ with check_state_value_tree():
73
+ state.value = jnp.zeros((2, 3))
74
+
75
+ # The following code will raise an error.
76
+ state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))
77
+ ```
78
+
62
79
  """
63
80
  try:
64
81
  _global_context_to_check_state_tree.append(True)