brainstate 0.0.1.post20240612__py2.py3-none-any.whl → 0.0.1.post20240623__py2.py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
- brainstate/__init__.py +4 -5
- brainstate/_module.py +147 -42
- brainstate/_module_test.py +95 -21
- brainstate/environ.py +0 -1
- brainstate/functional/__init__.py +2 -2
- brainstate/functional/_activations.py +7 -26
- brainstate/functional/_spikes.py +0 -1
- brainstate/mixin.py +2 -2
- brainstate/nn/_elementwise.py +5 -4
- brainstate/nn/_misc.py +4 -3
- brainstate/nn/_others.py +3 -2
- brainstate/nn/_poolings.py +21 -20
- brainstate/nn/_poolings_test.py +4 -4
- brainstate/optim/__init__.py +0 -1
- brainstate/optim/_sgd_optimizer.py +18 -17
- brainstate/transform/__init__.py +2 -3
- brainstate/transform/_autograd.py +1 -1
- brainstate/transform/_autograd_test.py +0 -2
- brainstate/transform/_jit_test.py +0 -3
- brainstate/transform/_make_jaxpr.py +0 -1
- brainstate/transform/_make_jaxpr_test.py +0 -2
- brainstate/transform/_progress_bar.py +1 -3
- brainstate/util.py +0 -1
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/METADATA +2 -12
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/RECORD +28 -35
- brainstate/math/__init__.py +0 -21
- brainstate/math/_einops.py +0 -787
- brainstate/math/_einops_parsing.py +0 -169
- brainstate/math/_einops_parsing_test.py +0 -126
- brainstate/math/_einops_test.py +0 -346
- brainstate/math/_misc.py +0 -298
- brainstate/math/_misc_test.py +0 -58
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/LICENSE +0 -0
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/WHEEL +0 -0
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
@@ -20,17 +20,16 @@ A ``State``-based Transformation System for Brain Dynamics Programming
 __version__ = "0.0.1"

 from . import environ
-from . import math
+from . import functional
+from . import init
 from . import mixin
 from . import nn
 from . import optim
 from . import random
+from . import surrogate
 from . import transform
 from . import typing
 from . import util
-from . import surrogate
-from . import functional
-from . import init
 from ._module import *
 from ._module import __all__ as _module_all
 from ._state import *
@@ -39,7 +38,7 @@ from ._state import __all__ as _state_all
 __all__ = (
   ['environ', 'share', 'nn', 'optim', 'random',
    'surrogate', 'functional', 'init',
-   'mixin', 'math', 'transform', 'util', 'typing'] +
+   'mixin', 'transform', 'util', 'typing'] +
   _module_all + _state_all
 )
 del _module_all, _state_all
brainstate/_module.py
CHANGED
@@ -59,9 +59,8 @@ import numpy as np
 from . import environ
 from ._state import State, StateDictManager, visible_state_dict
 from ._utils import set_module_as
-from .math import get_dtype
 from .mixin import Mixin, Mode, DelayedInit, AllOfTypes, Batching, UpdateReturn
-from .transform
+from .transform import jit_error
 from .util import unique_name, DictManager, get_unique_name

 Shape = Union[int, Sequence[int]]
@@ -69,8 +68,10 @@ PyTree = Any
 ArrayLike = jax.typing.ArrayLike

 delay_identifier = '_*_delay_of_'
-
-
+_DELAY_ROTATE = 'rotation'
+_DELAY_CONCAT = 'concat'
+_INTERP_LINEAR = 'linear_interp'
+_INTERP_ROUND = 'round'

 StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])

@@ -257,9 +258,8 @@ class Module(object):
 --------

 >>> import brainstate as bst
->>> import brainscale as nn # noqa
 >>> x = bst.random.rand((10, 10))
->>> l = nn.Activation(jax.numpy.tanh)
+>>> l = bst.nn.Activation(jax.numpy.tanh)
 >>> y = x >> l
 """
 return self.__call__(other)
@@ -401,8 +401,8 @@ class visible_module_list(list):
 retieved when using :py:func:`~.nodes()` function.

 >>> import brainstate as bst
->>> l = bst.visible_module_list([
->>>
+>>> l = bst.visible_module_list([bst.nn.Linear(1, 2),
+>>>                              bst.nn.LSTMCell(2, 3)])
 """

 __module__ = 'brainstate'
@@ -1038,14 +1038,14 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
   delay = length data ]
   entries: optional, dict. The delay access entries.
   name: str. The delay name.
-
+  delay_method: str. The method used for updating delay. Default None.
   mode: Mode. The computing mode. Default None.
 """

 __module__ = 'brainstate'

-
-max_time: float
+non_hashable_params = ('time', 'entries', 'name')
+max_time: float  #
 max_length: int
 history: Optional[State]

@@ -1053,20 +1053,27 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
 self,
 target_info: PyTree,
 time: Optional[Union[int, float]] = None,  # delay time
-init: Optional[Union[ArrayLike, Callable]] = None,  # delay data
+init: Optional[Union[ArrayLike, Callable]] = None,  # delay data before t0
 entries: Optional[Dict] = None,  # delay access entry
-
+delay_method: Optional[str] = _DELAY_ROTATE,  # delay method
+interp_method: str = _INTERP_LINEAR,  # interpolation method
 # others
 name: Optional[str] = None,
 mode: Optional[Mode] = None,
 ):

 # target information
-self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape,
+self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)

 # delay method
-assert
-
+assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (f'Un-supported delay method {delay_method}. '
+                                                        f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}')
+self.delay_method = delay_method
+
+# interp method
+assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (f'Un-supported interpolation method {interp_method}. '
+                                                          f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
+self.interp_method = interp_method

 # delay length and time
 self.max_time, delay_length = _get_delay(time, None)
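Taken together, the new `delay_method` and `interp_method` arguments are exercised by the updated tests further down; a minimal construction sketch using those same values:

import jax.numpy as jnp
import brainstate as bst

# rotation ring buffer with linear interpolation (the defaults above)
d1 = bst.Delay(jnp.ones((1,)), time=2.)
d1.init_state()

# concatenation buffer that rounds a requested time to the nearest step
d2 = bst.Delay(jnp.ones((1,)), time=2., delay_method='concat', interp_method='round')
d2.init_state()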
@@ -1076,7 +1083,8 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):

 # delay data
 if init is not None:
-
+  if not isinstance(init, (numbers.Number, jax.Array, np.ndarray, Callable)):
+    raise TypeError(f'init should be Array, Callable, or None. But got {init}')
 self._init = init
 self._history = None

@@ -1090,7 +1098,11 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):

 def __repr__(self):
   name = self.__class__.__name__
-  return f'{name}(
+  return (f'{name}('
+          f'delay_length={self.max_length}, '
+          f'target_info={self.target_info}, '
+          f'delay_method="{self.delay_method}", '
+          f'interp_method="{self.interp_method}")')

 @property
 def history(self):
@@ -1105,7 +1117,7 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
 if batch_size is not None:
   shape.insert(self.mode.batch_axis, batch_size)
 shape.insert(0, length)
-if isinstance(self._init, (jax.Array, numbers.Number)):
+if isinstance(self._init, (jax.Array, np.ndarray, numbers.Number)):
   data = jnp.broadcast_to(jnp.asarray(self._init, a.dtype), shape)
 elif callable(self._init):
   data = self._init(shape, dtype=a.dtype)
@@ -1132,7 +1144,8 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
 delay_time: Optional[Union[int, float]] = None,
 delay_step: Optional[int] = None,
 ) -> 'Delay':
-  """
+  """
+  Register an entry to access the delay data.

   Args:
     entry: str. The entry to access the delay data.
@@ -1162,7 +1175,8 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
 return self

 def at(self, entry: str, *indices) -> ArrayLike:
-  """
+  """
+  Get the data at the given entry.

   Args:
     entry: str. The entry to access the data.
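Entries name fixed delay offsets: `register_entry` converts a delay time into a step count once, and `at` reads the buffered value back. A minimal sketch of the pattern the tests below use (assuming the default `dt = 0.1`, so a delay of 1.0 is 10 steps):

import jax.numpy as jnp
import brainstate as bst

delay = bst.Delay(jnp.ones((1,)))
delay.register_entry('a', 1.)        # read back with a 1.0-time delay
delay.init_state()

for i in range(30):
    bst.environ.set(i=i)             # rotation retrieval reads the step index
    delay(jnp.ones((1,)) * i)        # push the value for step i
    val = delay.at('a')              # value from ten steps back (or the initial fill)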
@@ -1178,15 +1192,23 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
 delay_step = self._registered_entries[entry]
 if delay_step is None:
   delay_step = 0
-return self.
+return self.retrieve_at_step(delay_step, *indices)

-def
-  """
+def retrieve_at_step(self, delay_step, *indices) -> PyTree:
+  """
+  Retrieve the delay data at the given delay time step (the integer to indicate the time step).

   Parameters
   ----------
-  delay_step:
-
+  delay_step: int_like
+    Retrieve the data at the given time step.
+  indices: tuple
+    The indices to slice the data.
+
+  Returns
+  -------
+  delay_data: The delay data at the given delay step.
+
   """
  assert self.history is not None, 'The delay history is not initialized.'
  assert delay_step is not None, 'The delay step should be given.'
@@ -1199,17 +1221,17 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
 jit_error(delay_step >= self.max_length, _check_delay, delay_step)

 # rotation method
-if self.
-  i = environ.get(environ.I)
+if self.delay_method == _DELAY_ROTATE:
+  i = environ.get(environ.I, desc='The time step index.')
   di = i - delay_step
   delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
   delay_idx = jax.lax.stop_gradient(delay_idx)

-elif self.
+elif self.delay_method == _DELAY_CONCAT:
   delay_idx = delay_step

 else:
-  raise ValueError(f'Unknown updating method "{self.
+  raise ValueError(f'Unknown delay updating method "{self.delay_method}"')

 # the delay index
 if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
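The rotation branch is plain ring-buffer arithmetic: the value written at step `i` lives in slot `i % max_length`, so the value from `delay_step` steps ago sits in slot `(i - delay_step) % max_length`. A standalone sketch of that indexing on a toy buffer (hypothetical values, not the class itself):

import jax.numpy as jnp

max_length = 4
buf = jnp.zeros((max_length,))
for i in range(10):
    buf = buf.at[i % max_length].set(float(i))   # step i writes slot i % max_length

i, delay_step = 9, 2
assert float(buf[(i - delay_step) % max_length]) == 7.0   # value from 2 steps ago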
@@ -1219,6 +1241,81 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
 # the delay data
 return jax.tree.map(lambda a: a[indices], self.history.value)

+def retrieve_at_time(self, delay_time, *indices) -> PyTree:
+  """
+  Retrieve the delay data at the given delay time step (the integer to indicate the time step).
+
+  Parameters
+  ----------
+  delay_time: float
+    Retrieve the data at the given time.
+  indices: tuple
+    The indices to slice the data.
+
+  Returns
+  -------
+  delay_data: The delay data at the given delay step.
+
+  """
+  assert self.history is not None, 'The delay history is not initialized.'
+  assert delay_time is not None, 'The delay time should be given.'
+
+  current_time = environ.get(environ.T, desc='The current time.')
+  dt = environ.get_dt()
+
+  if environ.get(environ.JIT_ERROR_CHECK, False):
+    def _check_delay(args):
+      t_now, t_delay = args
+      raise ValueError(f'The request delay time should be within '
+                       f'[{t_now - self.max_time - dt}, {t_now}], '
+                       f'but we got {t_delay}')
+
+    jit_error(jnp.logical_or(delay_time > current_time,
+                             delay_time < current_time - self.max_time - dt),
+              _check_delay,
+              (current_time, delay_time))
+
+  diff = current_time - delay_time
+  float_time_step = diff / dt
+
+  if self.interp_method == _INTERP_LINEAR:  # "linear" interpolation
+    # def _interp(target):
+    #   if len(indices) > 0:
+    #     raise NotImplementedError('The slicing indices are not supported in the linear interpolation.')
+    #   if self.delay_method == _DELAY_ROTATE:
+    #     i = environ.get(environ.I, desc='The time step index.')
+    #     _interp_fun = partial(jnp.interp, period=self.max_length)
+    #     for dim in range(1, target.ndim, 1):
+    #       _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
+    #     di = i - jnp.arange(self.max_length)
+    #     delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
+    #     return _interp_fun(float_time_step, delay_idx, target)
+    #
+    #   elif self.delay_method == _DELAY_CONCAT:
+    #     _interp_fun = partial(jnp.interp, period=self.max_length)
+    #     for dim in range(1, target.ndim, 1):
+    #       _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
+    #     return _interp_fun(float_time_step, jnp.arange(self.max_length), target)
+    #
+    #   else:
+    #     raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
+    # return jax.tree.map(_interp, self.history.value)

+    data_at_t0 = self.retrieve_at_step(jnp.asarray(jnp.floor(float_time_step), dtype=jnp.int32), *indices)
+    data_at_t1 = self.retrieve_at_step(jnp.asarray(jnp.ceil(float_time_step), dtype=jnp.int32), *indices)
+    t_diff = float_time_step - jnp.floor(float_time_step)
+    return jax.tree.map(lambda a, b: a * (1 - t_diff) + b * t_diff, data_at_t0, data_at_t1)
+
+  elif self.interp_method == _INTERP_ROUND:  # "round" interpolation
+    return self.retrieve_at_step(
+      jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32),
+      *indices
+    )
+
+  else:  # raise error
+    raise ValueError(f'Un-supported interpolation method {self.interp_method}, '
+                     f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
+
 def update(self, current: PyTree) -> None:
   """
   Update delay variable with the new data.
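The linear branch blends the two neighbouring buffered steps. With `dt = 0.1`, a request 1.04 time units back gives `float_time_step = 10.4`, so the result is `0.6 * (step 10) + 0.4 * (step 11)`: exactly the numbers the new `test_linear_interp` asserts. A standalone check of the blend, assuming the buffered value at step `k` is `i - k`:

import jax.numpy as jnp

i, dt, delay_time = 50.0, 0.1, 1.04
s = delay_time / dt                          # 10.4 fractional steps back
lo, hi = jnp.floor(s), jnp.ceil(s)           # steps 10 and 11
w = s - lo                                   # 0.4
value = (i - lo) * (1 - w) + (i - hi) * w    # blend of the two neighbouring steps
assert jnp.allclose(value, i - s)            # reproduces i - 10.4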
@@ -1226,25 +1323,29 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
 assert self.history is not None, 'The delay history is not initialized.'

 # update the delay data at the rotation index
-if self.
+if self.delay_method == _DELAY_ROTATE:
   i = environ.get(environ.I)
   idx = jnp.asarray(i % self.max_length, dtype=environ.dutype())
   idx = jax.lax.stop_gradient(idx)
-  self.history.value = jax.tree.map(
-
-
+  self.history.value = jax.tree.map(
+    lambda hist, cur: hist.at[idx].set(cur),
+    self.history.value,
+    current
+  )
 # update the delay data at the first position
-elif self.
+elif self.delay_method == _DELAY_CONCAT:
   current = jax.tree.map(lambda a: jnp.expand_dims(a, 0), current)
   if self.max_length > 1:
-    self.history.value = jax.tree.map(
-
-
+    self.history.value = jax.tree.map(
+      lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
+      self.history.value,
+      current
+    )
   else:
     self.history.value = current

 else:
-  raise ValueError(f'Unknown updating method "{self.
+  raise ValueError(f'Unknown updating method "{self.delay_method}"')


 class _StateDelay(Delay):
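The two update branches keep the same logical history in different layouts: 'rotation' overwrites a single slot per step and needs the step index `i` to decode positions, while 'concat' shifts the whole buffer so index 0 is always the newest entry. A toy comparison (hypothetical values):

import jax.numpy as jnp

max_length = 3
rot = jnp.zeros((max_length,))   # 'rotation': write slot i % max_length
cat = jnp.zeros((max_length,))   # 'concat': newest value kept at index 0

for i in range(5):
    rot = rot.at[i % max_length].set(float(i))
    cat = jnp.concatenate([jnp.array([float(i)]), cat[:-1]])

# after step 4: rot == [3, 4, 2], cat == [4, 3, 2]
assert float(rot[(4 - 1) % max_length]) == 3.0   # rotation decodes via the step index
assert float(cat[1]) == 3.0                      # concat: index equals delay steps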
@@ -1265,14 +1366,18 @@ class _StateDelay(Delay):
 time: Optional[Union[int, float]] = None,  # delay time
 init: Optional[Union[ArrayLike, Callable]] = None,  # delay data init
 entries: Optional[Dict] = None,  # delay access entry
-
+delay_method: Optional[str] = _DELAY_ROTATE,  # delay method
 # others
 name: Optional[str] = None,
 mode: Optional[Mode] = None,
 ):
   super().__init__(target_info=target.value,
-                   time=time,
-
+                   time=time,
+                   init=init,
+                   entries=entries,
+                   delay_method=delay_method,
+                   name=name,
+                   mode=mode)
   self.target = target

 def update(self, *args, **kwargs):
brainstate/_module_test.py
CHANGED
@@ -16,14 +16,15 @@
 import unittest

 import jax.numpy as jnp
+import jaxlib.xla_extension

-import brainstate as
+import brainstate as bst


-class
+class TestDelay(unittest.TestCase):
   def test_delay1(self):
-    a =
-    delay =
+    a = bst.State(bst.random.random(10, 20))
+    delay = bst.Delay(a.value)
     delay.register_entry('a', 1.)
     delay.register_entry('b', 2.)
     delay.register_entry('c', None)
@@ -31,10 +32,10 @@ class TestVarDelay(unittest.TestCase):
 delay.init_state()
 with self.assertRaises(KeyError):
   delay.register_entry('c', 10.)
-
+bst.util.clear_buffer_memory()

 def test_rotation_delay(self):
-  rotation_delay =
+  rotation_delay = bst.Delay(jnp.ones((1,)))
   t0 = 0.
   t1, n1 = 1., 10
   t2, n2 = 2., 20
@@ -51,16 +52,16 @@ class TestVarDelay(unittest.TestCase):
 # print(rotation_delay.max_length)

 for i in range(100):
-
+  bst.environ.set(i=i)
   rotation_delay(jnp.ones((1,)) * i)
   # print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c2'), rotation_delay.at('c'))
   self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
   self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
   self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
-
+bst.util.clear_buffer_memory()

 def test_concat_delay(self):
-  rotation_delay =
+  rotation_delay = bst.Delay(jnp.ones([1]), delay_method='concat')
   t0 = 0.
   t1, n1 = 1., 10
   t2, n2 = 2., 20
@@ -73,17 +74,91 @@ class TestVarDelay(unittest.TestCase):

 print()
 for i in range(100):
-
+  bst.environ.set(i=i)
   rotation_delay(jnp.ones((1,)) * i)
   print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
   self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
   self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
   self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
-
+bst.util.clear_buffer_memory()
+
+def test_jit_erro(self):
+  rotation_delay = bst.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
+  rotation_delay.init_state()
+
+  with bst.environ.context(i=0, t=0):
+    rotation_delay.retrieve_at_time(-2.0)
+    with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
+      rotation_delay.retrieve_at_time(-2.1)
+    rotation_delay.retrieve_at_time(-2.01)
+    with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
+      rotation_delay.retrieve_at_time(-2.09)
+    with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
+      rotation_delay.retrieve_at_time(0.1)
+    with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
+      rotation_delay.retrieve_at_time(0.01)
+
+def test_round_interp(self):
+  for shape in [(1,), (1, 1), (1, 1, 1)]:
+    for delay_method in ['rotation', 'concat']:
+      rotation_delay = bst.Delay(jnp.ones(shape), time=2., delay_method=delay_method, interp_method='round')
+      t0, n1 = 0.01, 0
+      t1, n1 = 1.04, 10
+      t2, n2 = 1.06, 11
+      rotation_delay.init_state()
+
+      @bst.transform.jit
+      def retrieve(td, i):
+        with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
+          return rotation_delay.retrieve_at_time(td)
+
+      print()
+      for i in range(100):
+        t = i * bst.environ.get_dt()
+        with bst.environ.context(i=i, t=t):
+          rotation_delay(jnp.ones(shape) * i)
+          print(i,
+                retrieve(t - t0, i),
+                retrieve(t - t1, i),
+                retrieve(t - t2, i))
+          self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.ones(shape) * i))
+          self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.)))
+          self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
+      bst.util.clear_buffer_memory()
+
+def test_linear_interp(self):
+  for shape in [(1,), (1, 1), (1, 1, 1)]:
+    for delay_method in ['rotation', 'concat']:
+      print(shape, delay_method)
+
+      rotation_delay = bst.Delay(jnp.ones(shape), time=2., delay_method=delay_method, interp_method='linear_interp')
+      t0, n0 = 0.01, 0.1
+      t1, n1 = 1.04, 10.4
+      t2, n2 = 1.06, 10.6
+      rotation_delay.init_state()
+
+      @bst.transform.jit
+      def retrieve(td, i):
+        with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
+          return rotation_delay.retrieve_at_time(td)
+
+      print()
+      for i in range(100):
+        t = i * bst.environ.get_dt()
+        with bst.environ.context(i=i, t=t):
+          rotation_delay(jnp.ones(shape) * i)
+          print(i,
+                retrieve(t - t0, i),
+                retrieve(t - t1, i),
+                retrieve(t - t2, i))
+          self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.maximum(jnp.ones(shape) * i - n0, 0.)))
+          self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.)))
+          self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
+      bst.util.clear_buffer_memory()

 def test_rotation_and_concat_delay(self):
-  rotation_delay =
-  concat_delay =
+  rotation_delay = bst.Delay(jnp.ones((1,)))
+  concat_delay = bst.Delay(jnp.ones([1]), delay_method='concat')
   t0 = 0.
   t1, n1 = 1., 10
   t2, n2 = 2., 20
@@ -100,29 +175,29 @@ class TestVarDelay(unittest.TestCase):

 print()
 for i in range(100):
-
+  bst.environ.set(i=i)
   new = jnp.ones((1,)) * i
   rotation_delay(new)
   concat_delay(new)
   self.assertTrue(jnp.allclose(rotation_delay.at('a'), concat_delay.at('a'), ))
   self.assertTrue(jnp.allclose(rotation_delay.at('b'), concat_delay.at('b'), ))
   self.assertTrue(jnp.allclose(rotation_delay.at('c'), concat_delay.at('c'), ))
-
+bst.util.clear_buffer_memory()


 class TestModule(unittest.TestCase):
   def test_states(self):
-    class A(
+    class A(bst.Module):
       def __init__(self):
         super().__init__()
-        self.a =
-        self.b =
+        self.a = bst.State(bst.random.random(10, 20))
+        self.b = bst.State(bst.random.random(10, 20))

-    class B(
+    class B(bst.Module):
       def __init__(self):
         super().__init__()
         self.a = A()
-        self.b =
+        self.b = bst.State(bst.random.random(10, 20))

     b = B()
     print()
@@ -130,4 +205,3 @@ class TestModule(unittest.TestCase):
 print(b.states())
 print(b.states(level=0))
 print(b.states(level=0))
-
brainstate/functional/__init__.py
CHANGED
@@ -18,9 +18,9 @@ from ._activations import *
 from ._activations import __all__ as __activations_all__
 from ._normalization import *
 from ._normalization import __all__ as __others_all__
-from ._spikes import *
-from ._spikes import __all__ as __spikes_all__
 from ._others import *
 from ._others import __all__ as __others_all__
+from ._spikes import *
+from ._spikes import __all__ as __spikes_all__

 __all__ = __spikes_all__ + __others_all__ + __activations_all__ + __others_all__
brainstate/functional/_activations.py
CHANGED
@@ -27,7 +27,7 @@ import jax.numpy as jnp
 from jax.scipy.special import logsumexp
 from jax.typing import ArrayLike

-from .. import
+from .. import random

 __all__ = [
   "tanh",
@@ -136,10 +136,7 @@ def prelu(x, a=0.25):
 parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
 a separate :math:`a` is used for each input channel.
 """
-
-return jnp.where(x >= jnp.asarray(0., dtype),
-                 x,
-                 jnp.asarray(a, dtype) * x)
+return jnp.where(x >= 0., x, a * x)


 def soft_shrink(x, lambd=0.5):
@@ -161,11 +158,7 @@ def soft_shrink(x, lambd=0.5):
 - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
 - Output: :math:`(*)`, same shape as the input.
 """
-
-lambd = jnp.asarray(lambd, dtype)
-return jnp.where(x > lambd,
-                 x - lambd,
-                 jnp.where(x < -lambd, x + lambd, jnp.asarray(0., dtype)))
+return jnp.where(x > lambd, x - lambd, jnp.where(x < -lambd, x + lambd, 0.))


 def mish(x):
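The simplified `soft_shrink` is standard soft-thresholding: values inside `[-lambd, lambd]` collapse to zero and everything else moves toward zero by `lambd`. A quick numeric check (hypothetical inputs):

import jax.numpy as jnp

lambd = 0.5
x = jnp.array([1.2, -1.2, 0.3])
y = jnp.where(x > lambd, x - lambd, jnp.where(x < -lambd, x + lambd, 0.))
assert jnp.allclose(y, jnp.array([0.7, -0.7, 0.0]))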
@@ -217,9 +210,8 @@ def rrelu(x, lower=0.125, upper=0.3333333333333333):
 .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
     https://arxiv.org/abs/1505.00853
 """
-
-
-return jnp.where(x >= jnp.asarray(0., dtype), x, jnp.asarray(a, dtype) * x)
+a = random.uniform(lower, upper, size=jnp.shape(x), dtype=x.dtype)
+return jnp.where(x >= 0., x, a * x)


 def hard_shrink(x, lambd=0.5):
@@ -243,11 +235,7 @@ def hard_shrink(x, lambd=0.5):
 - Output: :math:`(*)`, same shape as the input.

 """
-
-lambd = jnp.asarray(lambd, dtype)
-return jnp.where(x > lambd,
-                 x,
-                 jnp.where(x < -lambd, x, jnp.asarray(0., dtype)))
+return jnp.where(x > lambd, x, jnp.where(x < -lambd, x, 0.))


 def relu(x: ArrayLike) -> jax.Array:
@@ -298,8 +286,7 @@ def squareplus(x: ArrayLike, b: ArrayLike = 4) -> jax.Array:
 x : input array
 b : smoothness parameter
 """
-
-return jax.nn.squareplus(x, jnp.asarray(b, dtype))
+return jax.nn.squareplus(x, b)


 def softplus(x: ArrayLike) -> jax.Array:
@@ -417,8 +404,6 @@ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
 See also:
   :func:`selu`
 """
-dtype = math.get_dtype(x)
-alpha = jnp.asarray(alpha, dtype)
 return jax.nn.elu(x, alpha)

@@ -445,8 +430,6 @@ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> jax.Array:
 See also:
   :func:`relu`
 """
-dtype = math.get_dtype(x)
-negative_slope = jnp.asarray(negative_slope, dtype)
 return jax.nn.leaky_relu(x, negative_slope=negative_slope)

@@ -493,8 +476,6 @@ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
 Returns:
   An array.
 """
-dtype = math.get_dtype(x)
-alpha = jnp.asarray(alpha, dtype)
 return jax.nn.celu(x, alpha)

brainstate/functional/_spikes.py
CHANGED