brainstate 0.0.1.post20240612__py2.py3-none-any.whl → 0.0.1.post20240623__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. brainstate/__init__.py +4 -5
  2. brainstate/_module.py +147 -42
  3. brainstate/_module_test.py +95 -21
  4. brainstate/environ.py +0 -1
  5. brainstate/functional/__init__.py +2 -2
  6. brainstate/functional/_activations.py +7 -26
  7. brainstate/functional/_spikes.py +0 -1
  8. brainstate/mixin.py +2 -2
  9. brainstate/nn/_elementwise.py +5 -4
  10. brainstate/nn/_misc.py +4 -3
  11. brainstate/nn/_others.py +3 -2
  12. brainstate/nn/_poolings.py +21 -20
  13. brainstate/nn/_poolings_test.py +4 -4
  14. brainstate/optim/__init__.py +0 -1
  15. brainstate/optim/_sgd_optimizer.py +18 -17
  16. brainstate/transform/__init__.py +2 -3
  17. brainstate/transform/_autograd.py +1 -1
  18. brainstate/transform/_autograd_test.py +0 -2
  19. brainstate/transform/_jit_test.py +0 -3
  20. brainstate/transform/_make_jaxpr.py +0 -1
  21. brainstate/transform/_make_jaxpr_test.py +0 -2
  22. brainstate/transform/_progress_bar.py +1 -3
  23. brainstate/util.py +0 -1
  24. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/METADATA +2 -12
  25. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/RECORD +28 -35
  26. brainstate/math/__init__.py +0 -21
  27. brainstate/math/_einops.py +0 -787
  28. brainstate/math/_einops_parsing.py +0 -169
  29. brainstate/math/_einops_parsing_test.py +0 -126
  30. brainstate/math/_einops_test.py +0 -346
  31. brainstate/math/_misc.py +0 -298
  32. brainstate/math/_misc_test.py +0 -58
  33. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/LICENSE +0 -0
  34. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/WHEEL +0 -0
  35. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/top_level.txt +0 -0
brainstate/__init__.py CHANGED
@@ -20,17 +20,16 @@ A ``State``-based Transformation System for Brain Dynamics Programming
20
20
  __version__ = "0.0.1"
21
21
 
22
22
  from . import environ
23
- from . import math
23
+ from . import functional
24
+ from . import init
24
25
  from . import mixin
25
26
  from . import nn
26
27
  from . import optim
27
28
  from . import random
29
+ from . import surrogate
28
30
  from . import transform
29
31
  from . import typing
30
32
  from . import util
31
- from . import surrogate
32
- from . import functional
33
- from . import init
34
33
  from ._module import *
35
34
  from ._module import __all__ as _module_all
36
35
  from ._state import *
@@ -39,7 +38,7 @@ from ._state import __all__ as _state_all
39
38
  __all__ = (
40
39
  ['environ', 'share', 'nn', 'optim', 'random',
41
40
  'surrogate', 'functional', 'init',
42
- 'mixin', 'math', 'transform', 'util', 'typing'] +
41
+ 'mixin', 'transform', 'util', 'typing'] +
43
42
  _module_all + _state_all
44
43
  )
45
44
  del _module_all, _state_all
brainstate/_module.py CHANGED
@@ -59,9 +59,8 @@ import numpy as np
59
59
  from . import environ
60
60
  from ._state import State, StateDictManager, visible_state_dict
61
61
  from ._utils import set_module_as
62
- from .math import get_dtype
63
62
  from .mixin import Mixin, Mode, DelayedInit, AllOfTypes, Batching, UpdateReturn
64
- from .transform._jit_error import jit_error
63
+ from .transform import jit_error
65
64
  from .util import unique_name, DictManager, get_unique_name
66
65
 
67
66
  Shape = Union[int, Sequence[int]]
@@ -69,8 +68,10 @@ PyTree = Any
69
68
  ArrayLike = jax.typing.ArrayLike
70
69
 
71
70
  delay_identifier = '_*_delay_of_'
72
- ROTATE_UPDATE = 'rotation'
73
- CONCAT_UPDATE = 'concat'
71
+ _DELAY_ROTATE = 'rotation'
72
+ _DELAY_CONCAT = 'concat'
73
+ _INTERP_LINEAR = 'linear_interp'
74
+ _INTERP_ROUND = 'round'
74
75
 
75
76
  StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
76
77
 
@@ -257,9 +258,8 @@ class Module(object):
257
258
  --------
258
259
 
259
260
  >>> import brainstate as bst
260
- >>> import brainscale as nn # noqa
261
261
  >>> x = bst.random.rand((10, 10))
262
- >>> l = nn.Activation(jax.numpy.tanh)
262
+ >>> l = bst.nn.Activation(jax.numpy.tanh)
263
263
  >>> y = x >> l
264
264
  """
265
265
  return self.__call__(other)
@@ -401,8 +401,8 @@ class visible_module_list(list):
401
401
  retieved when using :py:func:`~.nodes()` function.
402
402
 
403
403
  >>> import brainstate as bst
404
- >>> l = bst.visible_module_list([bp.dnn.Dense(1, 2),
405
- >>> bp.dnn.LSTMCell(2, 3)])
404
+ >>> l = bst.visible_module_list([bst.nn.Linear(1, 2),
405
+ >>> bst.nn.LSTMCell(2, 3)])
406
406
  """
407
407
 
408
408
  __module__ = 'brainstate'
@@ -1038,14 +1038,14 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1038
1038
  delay = length data ]
1039
1039
  entries: optional, dict. The delay access entries.
1040
1040
  name: str. The delay name.
1041
- method: str. The method used for updating delay. Default None.
1041
+ delay_method: str. The method used for updating delay. Default None.
1042
1042
  mode: Mode. The computing mode. Default None.
1043
1043
  """
1044
1044
 
1045
1045
  __module__ = 'brainstate'
1046
1046
 
1047
- non_hash_params = ('time', 'entries', 'name')
1048
- max_time: float
1047
+ non_hashable_params = ('time', 'entries', 'name')
1048
+ max_time: float #
1049
1049
  max_length: int
1050
1050
  history: Optional[State]
1051
1051
 
@@ -1053,20 +1053,27 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1053
1053
  self,
1054
1054
  target_info: PyTree,
1055
1055
  time: Optional[Union[int, float]] = None, # delay time
1056
- init: Optional[Union[ArrayLike, Callable]] = None, # delay data init
1056
+ init: Optional[Union[ArrayLike, Callable]] = None, # delay data before t0
1057
1057
  entries: Optional[Dict] = None, # delay access entry
1058
- method: Optional[str] = ROTATE_UPDATE, # delay method
1058
+ delay_method: Optional[str] = _DELAY_ROTATE, # delay method
1059
+ interp_method: str = _INTERP_LINEAR, # interpolation method
1059
1060
  # others
1060
1061
  name: Optional[str] = None,
1061
1062
  mode: Optional[Mode] = None,
1062
1063
  ):
1063
1064
 
1064
1065
  # target information
1065
- self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, get_dtype(a)), target_info)
1066
+ self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)
1066
1067
 
1067
1068
  # delay method
1068
- assert method in [ROTATE_UPDATE, CONCAT_UPDATE]
1069
- self.method = method
1069
+ assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (f'Un-supported delay method {delay_method}. '
1070
+ f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}')
1071
+ self.delay_method = delay_method
1072
+
1073
+ # interp method
1074
+ assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (f'Un-supported interpolation method {interp_method}. '
1075
+ f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
1076
+ self.interp_method = interp_method
1070
1077
 
1071
1078
  # delay length and time
1072
1079
  self.max_time, delay_length = _get_delay(time, None)
@@ -1076,7 +1083,8 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1076
1083
 
1077
1084
  # delay data
1078
1085
  if init is not None:
1079
- assert isinstance(init, (numbers.Number, jax.Array, Callable))
1086
+ if not isinstance(init, (numbers.Number, jax.Array, np.ndarray, Callable)):
1087
+ raise TypeError(f'init should be Array, Callable, or None. But got {init}')
1080
1088
  self._init = init
1081
1089
  self._history = None
1082
1090
 
@@ -1090,7 +1098,11 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1090
1098
 
1091
1099
  def __repr__(self):
1092
1100
  name = self.__class__.__name__
1093
- return f'{name}(delay_length={self.max_length}, target_info={self.target_info}, method="{self.method}")'
1101
+ return (f'{name}('
1102
+ f'delay_length={self.max_length}, '
1103
+ f'target_info={self.target_info}, '
1104
+ f'delay_method="{self.delay_method}", '
1105
+ f'interp_method="{self.interp_method}")')
1094
1106
 
1095
1107
  @property
1096
1108
  def history(self):
@@ -1105,7 +1117,7 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1105
1117
  if batch_size is not None:
1106
1118
  shape.insert(self.mode.batch_axis, batch_size)
1107
1119
  shape.insert(0, length)
1108
- if isinstance(self._init, (jax.Array, numbers.Number)):
1120
+ if isinstance(self._init, (jax.Array, np.ndarray, numbers.Number)):
1109
1121
  data = jnp.broadcast_to(jnp.asarray(self._init, a.dtype), shape)
1110
1122
  elif callable(self._init):
1111
1123
  data = self._init(shape, dtype=a.dtype)
@@ -1132,7 +1144,8 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1132
1144
  delay_time: Optional[Union[int, float]] = None,
1133
1145
  delay_step: Optional[int] = None,
1134
1146
  ) -> 'Delay':
1135
- """Register an entry to access the data.
1147
+ """
1148
+ Register an entry to access the delay data.
1136
1149
 
1137
1150
  Args:
1138
1151
  entry: str. The entry to access the delay data.
@@ -1162,7 +1175,8 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1162
1175
  return self
1163
1176
 
1164
1177
  def at(self, entry: str, *indices) -> ArrayLike:
1165
- """Get the data at the given entry.
1178
+ """
1179
+ Get the data at the given entry.
1166
1180
 
1167
1181
  Args:
1168
1182
  entry: str. The entry to access the data.
@@ -1178,15 +1192,23 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1178
1192
  delay_step = self._registered_entries[entry]
1179
1193
  if delay_step is None:
1180
1194
  delay_step = 0
1181
- return self.retrieve(delay_step, *indices)
1195
+ return self.retrieve_at_step(delay_step, *indices)
1182
1196
 
1183
- def retrieve(self, delay_step, *indices):
1184
- """Retrieve the delay data according to the delay length.
1197
+ def retrieve_at_step(self, delay_step, *indices) -> PyTree:
1198
+ """
1199
+ Retrieve the delay data at the given delay time step (the integer to indicate the time step).
1185
1200
 
1186
1201
  Parameters
1187
1202
  ----------
1188
- delay_step: int
1189
- The delay length used to retrieve the data.
1203
+ delay_step: int_like
1204
+ Retrieve the data at the given time step.
1205
+ indices: tuple
1206
+ The indices to slice the data.
1207
+
1208
+ Returns
1209
+ -------
1210
+ delay_data: The delay data at the given delay step.
1211
+
1190
1212
  """
1191
1213
  assert self.history is not None, 'The delay history is not initialized.'
1192
1214
  assert delay_step is not None, 'The delay step should be given.'
@@ -1199,17 +1221,17 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1199
1221
  jit_error(delay_step >= self.max_length, _check_delay, delay_step)
1200
1222
 
1201
1223
  # rotation method
1202
- if self.method == ROTATE_UPDATE:
1203
- i = environ.get(environ.I)
1224
+ if self.delay_method == _DELAY_ROTATE:
1225
+ i = environ.get(environ.I, desc='The time step index.')
1204
1226
  di = i - delay_step
1205
1227
  delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
1206
1228
  delay_idx = jax.lax.stop_gradient(delay_idx)
1207
1229
 
1208
- elif self.method == CONCAT_UPDATE:
1230
+ elif self.delay_method == _DELAY_CONCAT:
1209
1231
  delay_idx = delay_step
1210
1232
 
1211
1233
  else:
1212
- raise ValueError(f'Unknown updating method "{self.method}"')
1234
+ raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
1213
1235
 
1214
1236
  # the delay index
1215
1237
  if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
@@ -1219,6 +1241,81 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1219
1241
  # the delay data
1220
1242
  return jax.tree.map(lambda a: a[indices], self.history.value)
1221
1243
 
1244
+ def retrieve_at_time(self, delay_time, *indices) -> PyTree:
1245
+ """
1246
+ Retrieve the delay data at the given delay time step (the integer to indicate the time step).
1247
+
1248
+ Parameters
1249
+ ----------
1250
+ delay_time: float
1251
+ Retrieve the data at the given time.
1252
+ indices: tuple
1253
+ The indices to slice the data.
1254
+
1255
+ Returns
1256
+ -------
1257
+ delay_data: The delay data at the given delay step.
1258
+
1259
+ """
1260
+ assert self.history is not None, 'The delay history is not initialized.'
1261
+ assert delay_time is not None, 'The delay time should be given.'
1262
+
1263
+ current_time = environ.get(environ.T, desc='The current time.')
1264
+ dt = environ.get_dt()
1265
+
1266
+ if environ.get(environ.JIT_ERROR_CHECK, False):
1267
+ def _check_delay(args):
1268
+ t_now, t_delay = args
1269
+ raise ValueError(f'The request delay time should be within '
1270
+ f'[{t_now - self.max_time - dt}, {t_now}], '
1271
+ f'but we got {t_delay}')
1272
+
1273
+ jit_error(jnp.logical_or(delay_time > current_time,
1274
+ delay_time < current_time - self.max_time - dt),
1275
+ _check_delay,
1276
+ (current_time, delay_time))
1277
+
1278
+ diff = current_time - delay_time
1279
+ float_time_step = diff / dt
1280
+
1281
+ if self.interp_method == _INTERP_LINEAR: # "linear" interpolation
1282
+ # def _interp(target):
1283
+ # if len(indices) > 0:
1284
+ # raise NotImplementedError('The slicing indices are not supported in the linear interpolation.')
1285
+ # if self.delay_method == _DELAY_ROTATE:
1286
+ # i = environ.get(environ.I, desc='The time step index.')
1287
+ # _interp_fun = partial(jnp.interp, period=self.max_length)
1288
+ # for dim in range(1, target.ndim, 1):
1289
+ # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
1290
+ # di = i - jnp.arange(self.max_length)
1291
+ # delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
1292
+ # return _interp_fun(float_time_step, delay_idx, target)
1293
+ #
1294
+ # elif self.delay_method == _DELAY_CONCAT:
1295
+ # _interp_fun = partial(jnp.interp, period=self.max_length)
1296
+ # for dim in range(1, target.ndim, 1):
1297
+ # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
1298
+ # return _interp_fun(float_time_step, jnp.arange(self.max_length), target)
1299
+ #
1300
+ # else:
1301
+ # raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
1302
+ # return jax.tree.map(_interp, self.history.value)
1303
+
1304
+ data_at_t0 = self.retrieve_at_step(jnp.asarray(jnp.floor(float_time_step), dtype=jnp.int32), *indices)
1305
+ data_at_t1 = self.retrieve_at_step(jnp.asarray(jnp.ceil(float_time_step), dtype=jnp.int32), *indices)
1306
+ t_diff = float_time_step - jnp.floor(float_time_step)
1307
+ return jax.tree.map(lambda a, b: a * (1 - t_diff) + b * t_diff, data_at_t0, data_at_t1)
1308
+
1309
+ elif self.interp_method == _INTERP_ROUND: # "round" interpolation
1310
+ return self.retrieve_at_step(
1311
+ jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32),
1312
+ *indices
1313
+ )
1314
+
1315
+ else: # raise error
1316
+ raise ValueError(f'Un-supported interpolation method {self.interp_method}, '
1317
+ f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
1318
+
1222
1319
  def update(self, current: PyTree) -> None:
1223
1320
  """
1224
1321
  Update delay variable with the new data.
@@ -1226,25 +1323,29 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1226
1323
  assert self.history is not None, 'The delay history is not initialized.'
1227
1324
 
1228
1325
  # update the delay data at the rotation index
1229
- if self.method == ROTATE_UPDATE:
1326
+ if self.delay_method == _DELAY_ROTATE:
1230
1327
  i = environ.get(environ.I)
1231
1328
  idx = jnp.asarray(i % self.max_length, dtype=environ.dutype())
1232
1329
  idx = jax.lax.stop_gradient(idx)
1233
- self.history.value = jax.tree.map(lambda hist, cur: hist.at[idx].set(cur),
1234
- self.history.value,
1235
- current)
1330
+ self.history.value = jax.tree.map(
1331
+ lambda hist, cur: hist.at[idx].set(cur),
1332
+ self.history.value,
1333
+ current
1334
+ )
1236
1335
  # update the delay data at the first position
1237
- elif self.method == CONCAT_UPDATE:
1336
+ elif self.delay_method == _DELAY_CONCAT:
1238
1337
  current = jax.tree.map(lambda a: jnp.expand_dims(a, 0), current)
1239
1338
  if self.max_length > 1:
1240
- self.history.value = jax.tree.map(lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
1241
- self.history.value,
1242
- current)
1339
+ self.history.value = jax.tree.map(
1340
+ lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
1341
+ self.history.value,
1342
+ current
1343
+ )
1243
1344
  else:
1244
1345
  self.history.value = current
1245
1346
 
1246
1347
  else:
1247
- raise ValueError(f'Unknown updating method "{self.method}"')
1348
+ raise ValueError(f'Unknown updating method "{self.delay_method}"')
1248
1349
 
1249
1350
 
1250
1351
  class _StateDelay(Delay):
@@ -1265,14 +1366,18 @@ class _StateDelay(Delay):
1265
1366
  time: Optional[Union[int, float]] = None, # delay time
1266
1367
  init: Optional[Union[ArrayLike, Callable]] = None, # delay data init
1267
1368
  entries: Optional[Dict] = None, # delay access entry
1268
- method: Optional[str] = ROTATE_UPDATE, # delay method
1369
+ delay_method: Optional[str] = _DELAY_ROTATE, # delay method
1269
1370
  # others
1270
1371
  name: Optional[str] = None,
1271
1372
  mode: Optional[Mode] = None,
1272
1373
  ):
1273
1374
  super().__init__(target_info=target.value,
1274
- time=time, init=init, entries=entries,
1275
- method=method, name=name, mode=mode)
1375
+ time=time,
1376
+ init=init,
1377
+ entries=entries,
1378
+ delay_method=delay_method,
1379
+ name=name,
1380
+ mode=mode)
1276
1381
  self.target = target
1277
1382
 
1278
1383
  def update(self, *args, **kwargs):
@@ -16,14 +16,15 @@
16
16
  import unittest
17
17
 
18
18
  import jax.numpy as jnp
19
+ import jaxlib.xla_extension
19
20
 
20
- import brainstate as bc
21
+ import brainstate as bst
21
22
 
22
23
 
23
- class TestVarDelay(unittest.TestCase):
24
+ class TestDelay(unittest.TestCase):
24
25
  def test_delay1(self):
25
- a = bc.State(bc.random.random(10, 20))
26
- delay = bc.Delay(a.value)
26
+ a = bst.State(bst.random.random(10, 20))
27
+ delay = bst.Delay(a.value)
27
28
  delay.register_entry('a', 1.)
28
29
  delay.register_entry('b', 2.)
29
30
  delay.register_entry('c', None)
@@ -31,10 +32,10 @@ class TestVarDelay(unittest.TestCase):
31
32
  delay.init_state()
32
33
  with self.assertRaises(KeyError):
33
34
  delay.register_entry('c', 10.)
34
- bc.util.clear_buffer_memory()
35
+ bst.util.clear_buffer_memory()
35
36
 
36
37
  def test_rotation_delay(self):
37
- rotation_delay = bc.Delay(jnp.ones((1,)))
38
+ rotation_delay = bst.Delay(jnp.ones((1,)))
38
39
  t0 = 0.
39
40
  t1, n1 = 1., 10
40
41
  t2, n2 = 2., 20
@@ -51,16 +52,16 @@ class TestVarDelay(unittest.TestCase):
51
52
  # print(rotation_delay.max_length)
52
53
 
53
54
  for i in range(100):
54
- bc.environ.set(i=i)
55
+ bst.environ.set(i=i)
55
56
  rotation_delay(jnp.ones((1,)) * i)
56
57
  # print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c2'), rotation_delay.at('c'))
57
58
  self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
58
59
  self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
59
60
  self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
60
- bc.util.clear_buffer_memory()
61
+ bst.util.clear_buffer_memory()
61
62
 
62
63
  def test_concat_delay(self):
63
- rotation_delay = bc.Delay(jnp.ones([1]), method='concat')
64
+ rotation_delay = bst.Delay(jnp.ones([1]), delay_method='concat')
64
65
  t0 = 0.
65
66
  t1, n1 = 1., 10
66
67
  t2, n2 = 2., 20
@@ -73,17 +74,91 @@ class TestVarDelay(unittest.TestCase):
73
74
 
74
75
  print()
75
76
  for i in range(100):
76
- bc.environ.set(i=i)
77
+ bst.environ.set(i=i)
77
78
  rotation_delay(jnp.ones((1,)) * i)
78
79
  print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
79
80
  self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
80
81
  self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
81
82
  self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
82
- bc.util.clear_buffer_memory()
83
+ bst.util.clear_buffer_memory()
84
+
85
+ def test_jit_erro(self):
86
+ rotation_delay = bst.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
87
+ rotation_delay.init_state()
88
+
89
+ with bst.environ.context(i=0, t=0):
90
+ rotation_delay.retrieve_at_time(-2.0)
91
+ with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
92
+ rotation_delay.retrieve_at_time(-2.1)
93
+ rotation_delay.retrieve_at_time(-2.01)
94
+ with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
95
+ rotation_delay.retrieve_at_time(-2.09)
96
+ with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
97
+ rotation_delay.retrieve_at_time(0.1)
98
+ with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
99
+ rotation_delay.retrieve_at_time(0.01)
100
+
101
+ def test_round_interp(self):
102
+ for shape in [(1,), (1, 1), (1, 1, 1)]:
103
+ for delay_method in ['rotation', 'concat']:
104
+ rotation_delay = bst.Delay(jnp.ones(shape), time=2., delay_method=delay_method, interp_method='round')
105
+ t0, n1 = 0.01, 0
106
+ t1, n1 = 1.04, 10
107
+ t2, n2 = 1.06, 11
108
+ rotation_delay.init_state()
109
+
110
+ @bst.transform.jit
111
+ def retrieve(td, i):
112
+ with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
113
+ return rotation_delay.retrieve_at_time(td)
114
+
115
+ print()
116
+ for i in range(100):
117
+ t = i * bst.environ.get_dt()
118
+ with bst.environ.context(i=i, t=t):
119
+ rotation_delay(jnp.ones(shape) * i)
120
+ print(i,
121
+ retrieve(t - t0, i),
122
+ retrieve(t - t1, i),
123
+ retrieve(t - t2, i))
124
+ self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.ones(shape) * i))
125
+ self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.)))
126
+ self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
127
+ bst.util.clear_buffer_memory()
128
+
129
+ def test_linear_interp(self):
130
+ for shape in [(1,), (1, 1), (1, 1, 1)]:
131
+ for delay_method in ['rotation', 'concat']:
132
+ print(shape, delay_method)
133
+
134
+ rotation_delay = bst.Delay(jnp.ones(shape), time=2., delay_method=delay_method, interp_method='linear_interp')
135
+ t0, n0 = 0.01, 0.1
136
+ t1, n1 = 1.04, 10.4
137
+ t2, n2 = 1.06, 10.6
138
+ rotation_delay.init_state()
139
+
140
+ @bst.transform.jit
141
+ def retrieve(td, i):
142
+ with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
143
+ return rotation_delay.retrieve_at_time(td)
144
+
145
+ print()
146
+ for i in range(100):
147
+ t = i * bst.environ.get_dt()
148
+ with bst.environ.context(i=i, t=t):
149
+ rotation_delay(jnp.ones(shape) * i)
150
+ print(i,
151
+ retrieve(t - t0, i),
152
+ retrieve(t - t1, i),
153
+ retrieve(t - t2, i))
154
+ self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.maximum(jnp.ones(shape) * i - n0, 0.)))
155
+ self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.)))
156
+ self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
157
+ bst.util.clear_buffer_memory()
83
158
 
84
159
  def test_rotation_and_concat_delay(self):
85
- rotation_delay = bc.Delay(jnp.ones((1,)))
86
- concat_delay = bc.Delay(jnp.ones([1]), method='concat')
160
+ rotation_delay = bst.Delay(jnp.ones((1,)))
161
+ concat_delay = bst.Delay(jnp.ones([1]), delay_method='concat')
87
162
  t0 = 0.
88
163
  t1, n1 = 1., 10
89
164
  t2, n2 = 2., 20
@@ -100,29 +175,29 @@ class TestVarDelay(unittest.TestCase):
100
175
 
101
176
  print()
102
177
  for i in range(100):
103
- bc.environ.set(i=i)
178
+ bst.environ.set(i=i)
104
179
  new = jnp.ones((1,)) * i
105
180
  rotation_delay(new)
106
181
  concat_delay(new)
107
182
  self.assertTrue(jnp.allclose(rotation_delay.at('a'), concat_delay.at('a'), ))
108
183
  self.assertTrue(jnp.allclose(rotation_delay.at('b'), concat_delay.at('b'), ))
109
184
  self.assertTrue(jnp.allclose(rotation_delay.at('c'), concat_delay.at('c'), ))
110
- bc.util.clear_buffer_memory()
185
+ bst.util.clear_buffer_memory()
111
186
 
112
187
 
113
188
  class TestModule(unittest.TestCase):
114
189
  def test_states(self):
115
- class A(bc.Module):
190
+ class A(bst.Module):
116
191
  def __init__(self):
117
192
  super().__init__()
118
- self.a = bc.State(bc.random.random(10, 20))
119
- self.b = bc.State(bc.random.random(10, 20))
193
+ self.a = bst.State(bst.random.random(10, 20))
194
+ self.b = bst.State(bst.random.random(10, 20))
120
195
 
121
- class B(bc.Module):
196
+ class B(bst.Module):
122
197
  def __init__(self):
123
198
  super().__init__()
124
199
  self.a = A()
125
- self.b = bc.State(bc.random.random(10, 20))
200
+ self.b = bst.State(bst.random.random(10, 20))
126
201
 
127
202
  b = B()
128
203
  print()
@@ -130,4 +205,3 @@ class TestModule(unittest.TestCase):
130
205
  print(b.states())
131
206
  print(b.states(level=0))
132
207
  print(b.states(level=0))
133
-
brainstate/environ.py CHANGED
@@ -24,7 +24,6 @@ __all__ = [
24
24
  'dftype', 'ditype', 'dutype', 'dctype',
25
25
  ]
26
26
 
27
-
28
27
  # Default, there are several shared arguments in the global context.
29
28
  I = 'i' # the index of the current computation.
30
29
  T = 't' # the current time of the current computation.
@@ -18,9 +18,9 @@ from ._activations import *
18
18
  from ._activations import __all__ as __activations_all__
19
19
  from ._normalization import *
20
20
  from ._normalization import __all__ as __others_all__
21
- from ._spikes import *
22
- from ._spikes import __all__ as __spikes_all__
23
21
  from ._others import *
24
22
  from ._others import __all__ as __others_all__
23
+ from ._spikes import *
24
+ from ._spikes import __all__ as __spikes_all__
25
25
 
26
26
  __all__ = __spikes_all__ + __others_all__ + __activations_all__ + __others_all__
@@ -27,7 +27,7 @@ import jax.numpy as jnp
27
27
  from jax.scipy.special import logsumexp
28
28
  from jax.typing import ArrayLike
29
29
 
30
- from .. import math, random
30
+ from .. import random
31
31
 
32
32
  __all__ = [
33
33
  "tanh",
@@ -136,10 +136,7 @@ def prelu(x, a=0.25):
136
136
  parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
137
137
  a separate :math:`a` is used for each input channel.
138
138
  """
139
- dtype = math.get_dtype(x)
140
- return jnp.where(x >= jnp.asarray(0., dtype),
141
- x,
142
- jnp.asarray(a, dtype) * x)
139
+ return jnp.where(x >= 0., x, a * x)
143
140
 
144
141
 
145
142
  def soft_shrink(x, lambd=0.5):
@@ -161,11 +158,7 @@ def soft_shrink(x, lambd=0.5):
161
158
  - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
162
159
  - Output: :math:`(*)`, same shape as the input.
163
160
  """
164
- dtype = math.get_dtype(x)
165
- lambd = jnp.asarray(lambd, dtype)
166
- return jnp.where(x > lambd,
167
- x - lambd,
168
- jnp.where(x < -lambd, x + lambd, jnp.asarray(0., dtype)))
161
+ return jnp.where(x > lambd, x - lambd, jnp.where(x < -lambd, x + lambd, 0.))
169
162
 
170
163
 
171
164
  def mish(x):
@@ -217,9 +210,8 @@ def rrelu(x, lower=0.125, upper=0.3333333333333333):
217
210
  .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
218
211
  https://arxiv.org/abs/1505.00853
219
212
  """
220
- dtype = math.get_dtype(x)
221
- a = random.uniform(lower, upper, size=jnp.shape(x), dtype=dtype)
222
- return jnp.where(x >= jnp.asarray(0., dtype), x, jnp.asarray(a, dtype) * x)
213
+ a = random.uniform(lower, upper, size=jnp.shape(x), dtype=x.dtype)
214
+ return jnp.where(x >= 0., x, a * x)
223
215
 
224
216
 
225
217
  def hard_shrink(x, lambd=0.5):
@@ -243,11 +235,7 @@ def hard_shrink(x, lambd=0.5):
243
235
  - Output: :math:`(*)`, same shape as the input.
244
236
 
245
237
  """
246
- dtype = math.get_dtype(x)
247
- lambd = jnp.asarray(lambd, dtype)
248
- return jnp.where(x > lambd,
249
- x,
250
- jnp.where(x < -lambd, x, jnp.asarray(0., dtype)))
238
+ return jnp.where(x > lambd, x, jnp.where(x < -lambd, x, 0.))
251
239
 
252
240
 
253
241
  def relu(x: ArrayLike) -> jax.Array:
@@ -298,8 +286,7 @@ def squareplus(x: ArrayLike, b: ArrayLike = 4) -> jax.Array:
298
286
  x : input array
299
287
  b : smoothness parameter
300
288
  """
301
- dtype = math.get_dtype(x)
302
- return jax.nn.squareplus(x, jnp.asarray(b, dtype))
289
+ return jax.nn.squareplus(x, b)
303
290
 
304
291
 
305
292
  def softplus(x: ArrayLike) -> jax.Array:
@@ -417,8 +404,6 @@ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
417
404
  See also:
418
405
  :func:`selu`
419
406
  """
420
- dtype = math.get_dtype(x)
421
- alpha = jnp.asarray(alpha, dtype)
422
407
  return jax.nn.elu(x, alpha)
423
408
 
424
409
 
@@ -445,8 +430,6 @@ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> jax.Array:
445
430
  See also:
446
431
  :func:`relu`
447
432
  """
448
- dtype = math.get_dtype(x)
449
- negative_slope = jnp.asarray(negative_slope, dtype)
450
433
  return jax.nn.leaky_relu(x, negative_slope=negative_slope)
451
434
 
452
435
 
@@ -493,8 +476,6 @@ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
493
476
  Returns:
494
477
  An array.
495
478
  """
496
- dtype = math.get_dtype(x)
497
- alpha = jnp.asarray(alpha, dtype)
498
479
  return jax.nn.celu(x, alpha)
499
480
 
500
481
 
@@ -87,4 +87,3 @@ def spike_bitwise(x, y, op: str):
87
87
  return spike_bitwise_ixor(x, y)
88
88
  else:
89
89
  raise NotImplementedError(f"Unsupported bitwise operation: {op}.")
90
-