brainstate 0.0.1.1.post20240708__py2.py3-none-any.whl → 0.0.1.1.post20240804__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_module.py +21 -23
- brainstate/_state.py +25 -8
- brainstate/functional/_activations.py +1 -1
- brainstate/functional/_others.py +1 -2
- brainstate/init/_generic.py +2 -2
- brainstate/init/_random_inits.py +1 -1
- brainstate/mixin.py +12 -11
- brainstate/mixin_test.py +2 -0
- brainstate/nn/_elementwise.py +2 -2
- brainstate/random.py +29 -17
- brainstate/transform/_control.py +2 -2
- brainstate/transform/_jit_error_test.py +1 -1
- brainstate/transform/_make_jaxpr.py +1 -1
- brainstate/typing.py +150 -2
- {brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240804.dist-info}/METADATA +2 -2
- {brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240804.dist-info}/RECORD +19 -19
- {brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240804.dist-info}/LICENSE +0 -0
- {brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240804.dist-info}/WHEEL +0 -0
- {brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240804.dist-info}/top_level.txt +0 -0
brainstate/_module.py
CHANGED
@@ -61,11 +61,9 @@ from ._state import State, StateDictManager, visible_state_dict
|
|
61
61
|
from ._utils import set_module_as
|
62
62
|
from .mixin import Mixin, Mode, DelayedInit, JointTypes, Batching, UpdateReturn
|
63
63
|
from .transform import jit_error
|
64
|
+
from .typing import Size, ArrayLike, PyTree
|
64
65
|
from .util import unique_name, DictManager, get_unique_name
|
65
66
|
|
66
|
-
Shape = Union[int, Sequence[int]]
|
67
|
-
PyTree = Any
|
68
|
-
ArrayLike = jax.typing.ArrayLike
|
69
67
|
|
70
68
|
delay_identifier = '_*_delay_of_'
|
71
69
|
_DELAY_ROTATE = 'rotation'
|
@@ -805,7 +803,7 @@ class Dynamics(ExtendedUpdateWithBA, ReceiveInputProj, UpdateReturn):
|
|
805
803
|
|
806
804
|
def __init__(
|
807
805
|
self,
|
808
|
-
size:
|
806
|
+
size: Size,
|
809
807
|
keep_size: bool = False,
|
810
808
|
name: Optional[str] = None,
|
811
809
|
mode: Optional[Mode] = None,
|
@@ -1275,25 +1273,25 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
|
|
1275
1273
|
|
1276
1274
|
if self.interp_method == _INTERP_LINEAR: # "linear" interpolation
|
1277
1275
|
# def _interp(target):
|
1278
|
-
|
1279
|
-
|
1280
|
-
|
1281
|
-
|
1282
|
-
|
1283
|
-
|
1284
|
-
|
1285
|
-
|
1286
|
-
|
1287
|
-
|
1288
|
-
|
1289
|
-
|
1290
|
-
|
1291
|
-
|
1292
|
-
|
1293
|
-
|
1294
|
-
|
1295
|
-
|
1296
|
-
|
1276
|
+
# if len(indices) > 0:
|
1277
|
+
# raise NotImplementedError('The slicing indices are not supported in the linear interpolation.')
|
1278
|
+
# if self.delay_method == _DELAY_ROTATE:
|
1279
|
+
# i = environ.get(environ.I, desc='The time step index.')
|
1280
|
+
# _interp_fun = partial(jnp.interp, period=self.max_length)
|
1281
|
+
# for dim in range(1, target.ndim, 1):
|
1282
|
+
# _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
|
1283
|
+
# di = i - jnp.arange(self.max_length)
|
1284
|
+
# delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
|
1285
|
+
# return _interp_fun(float_time_step, delay_idx, target)
|
1286
|
+
#
|
1287
|
+
# elif self.delay_method == _DELAY_CONCAT:
|
1288
|
+
# _interp_fun = partial(jnp.interp, period=self.max_length)
|
1289
|
+
# for dim in range(1, target.ndim, 1):
|
1290
|
+
# _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
|
1291
|
+
# return _interp_fun(float_time_step, jnp.arange(self.max_length), target)
|
1292
|
+
#
|
1293
|
+
# else:
|
1294
|
+
# raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
|
1297
1295
|
# return jax.tree.map(_interp, self.history.value)
|
1298
1296
|
|
1299
1297
|
data_at_t0 = self.retrieve_at_step(jnp.asarray(jnp.floor(float_time_step), dtype=jnp.int32), *indices)
|
brainstate/_state.py
CHANGED
@@ -15,18 +15,16 @@
|
|
15
15
|
|
16
16
|
import contextlib
|
17
17
|
import threading
|
18
|
-
from typing import Any, Tuple, Dict, List, Callable
|
18
|
+
from typing import Any, Tuple, Dict, List, Callable, Optional
|
19
19
|
|
20
20
|
import jax
|
21
21
|
import numpy as np
|
22
22
|
from jax.api_util import shaped_abstractify
|
23
23
|
from jax.extend import source_info_util
|
24
24
|
|
25
|
+
from .typing import ArrayLike, PyTree
|
25
26
|
from .util import DictManager
|
26
27
|
|
27
|
-
PyTree = Any
|
28
|
-
max_int = np.iinfo(np.int32)
|
29
|
-
|
30
28
|
__all__ = [
|
31
29
|
'State', 'ShortTermState', 'LongTermState', 'ParamState',
|
32
30
|
'StateDictManager',
|
@@ -36,6 +34,7 @@ __all__ = [
|
|
36
34
|
]
|
37
35
|
|
38
36
|
_pytree_registered_objects = set()
|
37
|
+
max_int = np.iinfo(np.int32)
|
39
38
|
|
40
39
|
|
41
40
|
def _register_pytree_cls(cls):
|
@@ -108,9 +107,9 @@ class State(object):
|
|
108
107
|
value: PyTree. It can be anything as a pyTree.
|
109
108
|
"""
|
110
109
|
__module__ = 'brainstate'
|
111
|
-
__slots__ = ('_value', '_tree', '_level', '_source_info', '_check_tree')
|
110
|
+
__slots__ = ('_value', '_name', '_tree', '_level', '_source_info', '_check_tree')
|
112
111
|
|
113
|
-
def __init__(self, value: PyTree):
|
112
|
+
def __init__(self, value: PyTree[ArrayLike], name: Optional[str] = None):
|
114
113
|
if isinstance(value, State):
|
115
114
|
value = value.value
|
116
115
|
self._value = value
|
@@ -118,9 +117,24 @@ class State(object):
|
|
118
117
|
self._check_tree = False
|
119
118
|
self._level = len(thread_local_stack.stack)
|
120
119
|
self._source_info = source_info_util.current()
|
120
|
+
self._name = name
|
121
|
+
|
122
|
+
@property
|
123
|
+
def name(self) -> Optional[str]:
|
124
|
+
"""
|
125
|
+
The name of the state.
|
126
|
+
"""
|
127
|
+
return self._name
|
128
|
+
|
129
|
+
@name.setter
|
130
|
+
def name(self, name: str) -> None:
|
131
|
+
"""
|
132
|
+
Set the name of the state.
|
133
|
+
"""
|
134
|
+
self._name = name
|
121
135
|
|
122
136
|
@property
|
123
|
-
def value(self) -> PyTree:
|
137
|
+
def value(self) -> PyTree[ArrayLike]:
|
124
138
|
"""
|
125
139
|
The data and its value.
|
126
140
|
"""
|
@@ -210,7 +224,10 @@ class State(object):
|
|
210
224
|
leaves, tree = jax.tree.flatten(self._value)
|
211
225
|
leaves_info = [ShapeDtype(leaf.shape, leaf.dtype) for leaf in leaves]
|
212
226
|
tree_info = jax.tree.unflatten(tree, leaves_info)
|
213
|
-
|
227
|
+
if self.name is None:
|
228
|
+
return f'{self.__class__.__name__}({tree_info})'
|
229
|
+
else:
|
230
|
+
return f'{self.__class__.__name__}({self.name}: {tree_info})'
|
214
231
|
|
215
232
|
|
216
233
|
class ShapeDtype:
|
brainstate/functional/_others.py
CHANGED
brainstate/init/_generic.py
CHANGED
@@ -22,8 +22,8 @@ import jax
|
|
22
22
|
import numpy as np
|
23
23
|
|
24
24
|
from brainstate._state import State
|
25
|
+
from brainstate.typing import ArrayLike
|
25
26
|
from ._base import to_size
|
26
|
-
from ..typing import ArrayLike
|
27
27
|
|
28
28
|
__all__ = [
|
29
29
|
'param',
|
@@ -83,7 +83,7 @@ def _expand_params_to_match_sizes(params, sizes):
|
|
83
83
|
|
84
84
|
|
85
85
|
def param(
|
86
|
-
parameter: Union[Callable, ArrayLike],
|
86
|
+
parameter: Union[Callable, ArrayLike, State],
|
87
87
|
sizes: Union[int, Sequence[int]],
|
88
88
|
batch_size: Optional[int] = None,
|
89
89
|
allow_none: bool = True,
|
brainstate/init/_random_inits.py
CHANGED
brainstate/mixin.py
CHANGED
@@ -15,14 +15,15 @@
|
|
15
15
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
|
-
from typing import Sequence, Optional, TypeVar
|
18
|
+
from typing import Sequence, Optional, TypeVar
|
19
19
|
from typing import (_SpecialForm, _type_check, _remove_dups_flatten, _UnionGenericAlias)
|
20
20
|
|
21
|
-
|
22
|
-
PyTree = Any
|
21
|
+
from .typing import PyTree
|
23
22
|
|
23
|
+
T = TypeVar('T')
|
24
24
|
State = None
|
25
25
|
|
26
|
+
|
26
27
|
__all__ = [
|
27
28
|
'Mixin',
|
28
29
|
'DelayedInit',
|
@@ -207,7 +208,7 @@ class _JointGenericAlias(_UnionGenericAlias, _root=True):
|
|
207
208
|
|
208
209
|
@_SpecialForm
|
209
210
|
def JointTypes(self, parameters):
|
210
|
-
"""
|
211
|
+
"""Joint types; JointTypes[X, Y] means both X and Y.
|
211
212
|
|
212
213
|
To define a union, use e.g. Union[int, str].
|
213
214
|
|
@@ -216,28 +217,28 @@ def JointTypes(self, parameters):
|
|
216
217
|
- None as an argument is a special case and is replaced by `type(None)`.
|
217
218
|
- Unions of unions are flattened, e.g.::
|
218
219
|
|
219
|
-
|
220
|
+
JointTypes[JointTypes[int, str], float] == JointTypes[int, str, float]
|
220
221
|
|
221
222
|
- Unions of a single argument vanish, e.g.::
|
222
223
|
|
223
|
-
|
224
|
+
JointTypes[int] == int # The constructor actually returns int
|
224
225
|
|
225
226
|
- Redundant arguments are skipped, e.g.::
|
226
227
|
|
227
|
-
|
228
|
+
JointTypes[int, str, int] == JointTypes[int, str]
|
228
229
|
|
229
230
|
- When comparing unions, the argument order is ignored, e.g.::
|
230
231
|
|
231
|
-
|
232
|
+
JointTypes[int, str] == JointTypes[str, int]
|
232
233
|
|
233
|
-
- You cannot subclass or instantiate a
|
234
|
-
- You can use Optional[X] as a shorthand for
|
234
|
+
- You cannot subclass or instantiate a JointTypes.
|
235
|
+
- You can use Optional[X] as a shorthand for JointTypes[X, None].
|
235
236
|
"""
|
236
237
|
if parameters == ():
|
237
238
|
raise TypeError("Cannot take a Joint of no types.")
|
238
239
|
if not isinstance(parameters, tuple):
|
239
240
|
parameters = (parameters,)
|
240
|
-
msg = "
|
241
|
+
msg = "JointTypes[arg, ...]: each arg must be a type."
|
241
242
|
parameters = tuple(_type_check(p, msg) for p in parameters)
|
242
243
|
parameters = _remove_dups_flatten(parameters)
|
243
244
|
if len(parameters) == 1:
|
brainstate/mixin_test.py
CHANGED
brainstate/nn/_elementwise.py
CHANGED
@@ -1139,13 +1139,13 @@ class Dropout(Module, ElementWiseBlock):
|
|
1139
1139
|
name: Optional[str] = None
|
1140
1140
|
) -> None:
|
1141
1141
|
super().__init__(mode=mode, name=name)
|
1142
|
-
assert 0. <= prob
|
1142
|
+
assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
|
1143
1143
|
self.prob = prob
|
1144
1144
|
|
1145
1145
|
def __call__(self, x):
|
1146
1146
|
dtype = bu.math.get_dtype(x)
|
1147
1147
|
fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
|
1148
|
-
if fit_phase
|
1148
|
+
if fit_phase and self.prob < 1.:
|
1149
1149
|
keep_mask = random.bernoulli(self.prob, x.shape)
|
1150
1150
|
return jnp.where(keep_mask,
|
1151
1151
|
jnp.asarray(x / self.prob, dtype=dtype),
|
brainstate/random.py
CHANGED
@@ -1167,23 +1167,32 @@ def default_rng(seed_or_key=None, clone: bool = True) -> RandomState:
|
|
1167
1167
|
return RandomState(seed_or_key)
|
1168
1168
|
|
1169
1169
|
|
1170
|
-
def seed(
|
1170
|
+
def seed(seed_or_key: int = None):
|
1171
1171
|
"""Sets a new random seed.
|
1172
1172
|
|
1173
1173
|
Parameters
|
1174
1174
|
----------
|
1175
|
-
|
1176
|
-
The random seed.
|
1175
|
+
seed_or_key: int, optional
|
1176
|
+
The random seed (an integer) or jax random key.
|
1177
1177
|
"""
|
1178
1178
|
with jax.ensure_compile_time_eval():
|
1179
|
-
if
|
1180
|
-
|
1181
|
-
|
1182
|
-
|
1179
|
+
if seed_or_key is None:
|
1180
|
+
seed_or_key = np.random.randint(0, 100000)
|
1181
|
+
|
1182
|
+
# numpy random seed
|
1183
|
+
if np.size(seed_or_key) == 1: # seed
|
1184
|
+
np.random.seed(seed_or_key)
|
1185
|
+
elif np.size(seed_or_key) == 2: # jax random key
|
1186
|
+
np.random.seed(seed_or_key[0])
|
1187
|
+
else:
|
1188
|
+
raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
|
1189
|
+
|
1190
|
+
# jax random seed
|
1191
|
+
DEFAULT.seed(seed_or_key)
|
1183
1192
|
|
1184
1193
|
|
1185
1194
|
@contextmanager
|
1186
|
-
def seed_context(
|
1195
|
+
def seed_context(seed_or_key: SeedOrKey):
|
1187
1196
|
"""
|
1188
1197
|
A context manager that sets the random seed for the duration of the block.
|
1189
1198
|
|
@@ -1206,16 +1215,19 @@ def seed_context(seed: int):
|
|
1206
1215
|
The context manager does not only set the seed for the AX random state, but also for the numpy random state.
|
1207
1216
|
|
1208
1217
|
Args:
|
1209
|
-
|
1218
|
+
seed_or_key: The seed (an integer) or jax random key.
|
1210
1219
|
|
1211
|
-
Returns:
|
1212
|
-
The random state.
|
1213
1220
|
"""
|
1214
1221
|
old_jrand_key = DEFAULT.value
|
1215
1222
|
old_np_state = np.random.get_state()
|
1216
1223
|
try:
|
1217
|
-
np.
|
1218
|
-
|
1224
|
+
if np.size(seed_or_key) == 1: # seed
|
1225
|
+
np.random.seed(seed_or_key)
|
1226
|
+
elif np.size(seed_or_key) == 2: # jax random key
|
1227
|
+
np.random.seed(seed_or_key[0])
|
1228
|
+
else:
|
1229
|
+
raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
|
1230
|
+
DEFAULT.seed(seed_or_key)
|
1219
1231
|
yield
|
1220
1232
|
finally:
|
1221
1233
|
np.random.set_state(old_np_state)
|
@@ -1223,7 +1235,8 @@ def seed_context(seed: int):
|
|
1223
1235
|
|
1224
1236
|
|
1225
1237
|
def rand(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
|
1226
|
-
r"""
|
1238
|
+
r"""
|
1239
|
+
Random values in a given shape.
|
1227
1240
|
|
1228
1241
|
.. note::
|
1229
1242
|
This is a convenience function for users porting code from Matlab,
|
@@ -4796,10 +4809,9 @@ def _size2shape(size):
|
|
4796
4809
|
|
4797
4810
|
|
4798
4811
|
def _check_shape(name, shape, *param_shapes):
|
4799
|
-
shape = core.as_named_shape(shape)
|
4800
4812
|
if param_shapes:
|
4801
|
-
shape_ = lax.broadcast_shapes(shape
|
4802
|
-
if shape
|
4813
|
+
shape_ = lax.broadcast_shapes(shape, *param_shapes)
|
4814
|
+
if shape != shape_:
|
4803
4815
|
msg = ("{} parameter shapes must be broadcast-compatible with shape "
|
4804
4816
|
"argument, and the result of broadcasting the shapes must equal "
|
4805
4817
|
"the shape argument, but got result {} for shape argument {}.")
|
brainstate/transform/_control.py
CHANGED
@@ -25,7 +25,7 @@ import jax.numpy as jnp
|
|
25
25
|
import numpy as np
|
26
26
|
|
27
27
|
from brainstate._utils import set_module_as
|
28
|
-
from ._jit_error import jit_error
|
28
|
+
from ._jit_error import jit_error, remove_vmap
|
29
29
|
from ._make_jaxpr import StatefulFunction, _assign_state_values
|
30
30
|
from ._progress_bar import ProgressBar
|
31
31
|
|
@@ -347,7 +347,7 @@ def _wrap_fun_with_pbar(fun, pbar_runner):
|
|
347
347
|
def new_fun(new_carry, inputs):
|
348
348
|
i, old_carry = new_carry
|
349
349
|
old_carry, old_outputs = fun(old_carry, inputs)
|
350
|
-
pbar_runner(i)
|
350
|
+
pbar_runner(remove_vmap(i, op='none'))
|
351
351
|
return (i + 1, old_carry), old_outputs
|
352
352
|
|
353
353
|
return new_fun
|
brainstate/typing.py
CHANGED
@@ -14,13 +14,16 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
|
17
|
-
|
17
|
+
import functools as ft
|
18
|
+
import typing
|
19
|
+
from typing import Sequence, Protocol, Union, Any, Generic, TypeVar
|
18
20
|
|
19
21
|
import brainunit as bu
|
20
22
|
import jax
|
21
23
|
import numpy as np
|
22
24
|
|
23
25
|
__all__ = [
|
26
|
+
'PyTree',
|
24
27
|
'Size',
|
25
28
|
'Axes',
|
26
29
|
'SeedOrKey',
|
@@ -29,6 +32,151 @@ __all__ = [
|
|
29
32
|
'DTypeLike',
|
30
33
|
]
|
31
34
|
|
35
|
+
_T = TypeVar("_T")
|
36
|
+
|
37
|
+
|
38
|
+
class _FakePyTree(Generic[_T]):
|
39
|
+
pass
|
40
|
+
|
41
|
+
|
42
|
+
_FakePyTree.__name__ = "PyTree"
|
43
|
+
_FakePyTree.__qualname__ = "PyTree"
|
44
|
+
_FakePyTree.__module__ = "builtins"
|
45
|
+
|
46
|
+
|
47
|
+
class _MetaPyTree(type):
|
48
|
+
def __call__(self, *args, **kwargs):
|
49
|
+
raise RuntimeError("PyTree cannot be instantiated")
|
50
|
+
|
51
|
+
# Can't return a generic (e.g. _FakePyTree[item]) because generic aliases don't do
|
52
|
+
# the custom __instancecheck__ that we want.
|
53
|
+
# We can't add that __instancecheck__ via subclassing, e.g.
|
54
|
+
# type("PyTree", (Generic[_T],), {}), because dynamic subclassing of typeforms
|
55
|
+
# isn't allowed.
|
56
|
+
# Likewise we can't do types.new_class("PyTree", (Generic[_T],), {}) because that
|
57
|
+
# has __module__ "types", e.g. we get types.PyTree[int].
|
58
|
+
@ft.lru_cache(maxsize=None)
|
59
|
+
def __getitem__(cls, item):
|
60
|
+
if isinstance(item, tuple):
|
61
|
+
if len(item) == 2:
|
62
|
+
|
63
|
+
class X(PyTree):
|
64
|
+
leaftype = item[0]
|
65
|
+
structure = item[1].strip()
|
66
|
+
|
67
|
+
if not isinstance(X.structure, str):
|
68
|
+
raise ValueError(
|
69
|
+
"The structure annotation `struct` in "
|
70
|
+
"`brainstate.typing.PyTree[leaftype, struct]` must be be a string, "
|
71
|
+
f"e.g. `brainstate.typing.PyTree[leaftype, 'T']`. Got '{X.structure}'."
|
72
|
+
)
|
73
|
+
pieces = X.structure.split()
|
74
|
+
if len(pieces) == 0:
|
75
|
+
raise ValueError(
|
76
|
+
"The string `struct` in `brainstate.typing.PyTree[leaftype, struct]` "
|
77
|
+
"cannot be the empty string."
|
78
|
+
)
|
79
|
+
for piece_index, piece in enumerate(pieces):
|
80
|
+
if (piece_index == 0) or (piece_index == len(pieces) - 1):
|
81
|
+
if piece == "...":
|
82
|
+
continue
|
83
|
+
if not piece.isidentifier():
|
84
|
+
raise ValueError(
|
85
|
+
"The string `struct` in "
|
86
|
+
"`brainstate.typing.PyTree[leaftype, struct]` must be be a "
|
87
|
+
"whitespace-separated sequence of identifiers, e.g. "
|
88
|
+
"`brainstate.typing.PyTree[leaftype, 'T']` or "
|
89
|
+
"`brainstate.typing.PyTree[leaftype, 'foo bar']`.\n"
|
90
|
+
"(Here, 'identifier' is used in the same sense as in "
|
91
|
+
"regular Python, i.e. a valid variable name.)\n"
|
92
|
+
f"Got piece '{piece}' in overall structure '{X.structure}'."
|
93
|
+
)
|
94
|
+
name = str(_FakePyTree[item[0]])[:-1] + ', "' + item[1].strip() + '"]'
|
95
|
+
else:
|
96
|
+
raise ValueError(
|
97
|
+
"The subscript `foo` in `brainstate.typing.PyTree[foo]` must either be a "
|
98
|
+
"leaf type, e.g. `PyTree[int]`, or a 2-tuple of leaf and "
|
99
|
+
"structure, e.g. `PyTree[int, 'T']`. Received a tuple of length "
|
100
|
+
f"{len(item)}."
|
101
|
+
)
|
102
|
+
else:
|
103
|
+
name = str(_FakePyTree[item])
|
104
|
+
|
105
|
+
class X(PyTree):
|
106
|
+
leaftype = item
|
107
|
+
structure = None
|
108
|
+
|
109
|
+
X.__name__ = name
|
110
|
+
X.__qualname__ = name
|
111
|
+
if getattr(typing, "GENERATING_DOCUMENTATION", False):
|
112
|
+
X.__module__ = "builtins"
|
113
|
+
else:
|
114
|
+
X.__module__ = "brainstate.typing"
|
115
|
+
return X
|
116
|
+
|
117
|
+
|
118
|
+
# Can't do `class PyTree(Generic[_T]): ...` because we need to override the
|
119
|
+
# instancecheck for PyTree[foo], but subclassing
|
120
|
+
# `type(Generic[int])`, i.e. `typing._GenericAlias` is disallowed.
|
121
|
+
PyTree = _MetaPyTree("PyTree", (), {})
|
122
|
+
if getattr(typing, "GENERATING_DOCUMENTATION", False):
|
123
|
+
PyTree.__module__ = "builtins"
|
124
|
+
else:
|
125
|
+
PyTree.__module__ = "brainstate.typing"
|
126
|
+
PyTree.__doc__ = """Represents a PyTree.
|
127
|
+
|
128
|
+
Annotations of the following sorts are supported:
|
129
|
+
```python
|
130
|
+
a: PyTree
|
131
|
+
b: PyTree[LeafType]
|
132
|
+
c: PyTree[LeafType, "T"]
|
133
|
+
d: PyTree[LeafType, "S T"]
|
134
|
+
e: PyTree[LeafType, "... T"]
|
135
|
+
f: PyTree[LeafType, "T ..."]
|
136
|
+
```
|
137
|
+
|
138
|
+
These correspond to:
|
139
|
+
|
140
|
+
a. A plain `PyTree` can be used an annotation, in which case `PyTree` is simply a
|
141
|
+
suggestively-named alternative to `Any`.
|
142
|
+
([By definition all types are PyTrees.](https://jax.readthedocs.io/en/latest/pytrees.html))
|
143
|
+
|
144
|
+
b. `PyTree[LeafType]` denotes a PyTree all of whose leaves match `LeafType`. For
|
145
|
+
example, `PyTree[int]` or `PyTree[Union[str, Float32[Array, "b c"]]]`.
|
146
|
+
|
147
|
+
c. A structure name can also be passed. In this case
|
148
|
+
`jax.tree_util.tree_structure(...)` will be called, and bound to the structure name.
|
149
|
+
This can be used to mark that multiple PyTrees all have the same structure:
|
150
|
+
```python
|
151
|
+
def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
|
152
|
+
...
|
153
|
+
```
|
154
|
+
|
155
|
+
d. A composite structure can be declared. In this case the variable must have a PyTree
|
156
|
+
structure each to the composition of multiple previously-bound PyTree structures.
|
157
|
+
For example:
|
158
|
+
```python
|
159
|
+
def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
|
160
|
+
...
|
161
|
+
|
162
|
+
x = (1, 2)
|
163
|
+
y = {"key": 3}
|
164
|
+
z = {"key": (4, 5)} # structure is the composition of the structures of `y` and `z`
|
165
|
+
f(x, y, z)
|
166
|
+
```
|
167
|
+
When performing runtime type-checking, all the individual pieces must have already
|
168
|
+
been bound to structures, otherwise the composite structure check will throw an error.
|
169
|
+
|
170
|
+
e. A structure can begin with a `...`, to denote that the lower levels of the PyTree
|
171
|
+
must match the declared structure, but the upper levels can be arbitrary. As in the
|
172
|
+
previous case, all named pieces must already have been seen and their structures
|
173
|
+
bound.
|
174
|
+
|
175
|
+
f. A structure can end with a `...`, to denote that the PyTree must be a prefix of the
|
176
|
+
declared structure, but the lower levels can be arbitrary. As in the previous two
|
177
|
+
cases, all named pieces must already have been seen and their structures bound.
|
178
|
+
""" # noqa: E501
|
179
|
+
|
32
180
|
Size = Union[int, Sequence[int]]
|
33
181
|
Axes = Union[int, Sequence[int]]
|
34
182
|
SeedOrKey = Union[int, jax.Array, np.ndarray]
|
@@ -44,7 +192,7 @@ ArrayLike = Union[
|
|
44
192
|
np.ndarray, # NumPy array type
|
45
193
|
np.bool_, np.number, # NumPy scalar types
|
46
194
|
bool, int, float, complex, # Python scalar types
|
47
|
-
bu.Quantity, #
|
195
|
+
bu.Quantity, # Quantity
|
48
196
|
]
|
49
197
|
|
50
198
|
# --- Dtype --- #
|
{brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240804.dist-info}/METADATA
RENAMED
@@ -1,9 +1,9 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.0.1.1.
|
3
|
+
Version: 0.0.1.1.post20240804
|
4
4
|
Summary: A State-based Transformation System for Brain Dynamics Programming.
|
5
5
|
Home-page: https://github.com/brainpy/brainstate
|
6
|
-
Author:
|
6
|
+
Author: BDP
|
7
7
|
Author-email: BrainPy Team <chao.brain@qq.com>
|
8
8
|
License: Apache-2.0 license
|
9
9
|
Project-URL: homepage, http://github.com/brainpy
|
{brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240804.dist-info}/RECORD
RENAMED
@@ -1,33 +1,33 @@
|
|
1
1
|
brainstate/__init__.py,sha256=oxslZrm6wxtBQDqwJFb2BaAKZFmnp4d_esDkaeuGMWE,1410
|
2
|
-
brainstate/_module.py,sha256=
|
2
|
+
brainstate/_module.py,sha256=YJDp9aD38wBa_lY6BojWjWV9LJ2aFMAMYh-KZe5a4eM,52443
|
3
3
|
brainstate/_module_test.py,sha256=oQaoaZBTo1o3wHrMEJTInQCc7RdcVs1gcfQGvdSb1SI,7843
|
4
4
|
brainstate/_random_for_unit.py,sha256=eW4NJkX27VCCNWUwAlyt2otkeEthGKOpUoX6XJ6i95Y,1946
|
5
|
-
brainstate/_state.py,sha256=
|
5
|
+
brainstate/_state.py,sha256=C0widCOj_ca6zfqh95jzFXf_G5vi0hJyuQ5GIqEqOUs,12102
|
6
6
|
brainstate/_state_test.py,sha256=HDdipndRLhEHWEdTmyT1ayEBkbv6qJKykfCWKI6yJ_E,1253
|
7
7
|
brainstate/_utils.py,sha256=RLorgGJkt2BhbX4C-ygd-PPG0wfcGCghjSP93sRvzqM,833
|
8
8
|
brainstate/environ.py,sha256=LwRwnFaTbv8l7nHRIbSV46WzcN7pGLQFhT_xDUox2yA,10240
|
9
|
-
brainstate/mixin.py,sha256=
|
10
|
-
brainstate/mixin_test.py,sha256
|
11
|
-
brainstate/random.py,sha256=
|
9
|
+
brainstate/mixin.py,sha256=OumTTSVyYSbtudjfS_MRThsBaeVJ_0JggeMClY7xtBA,10758
|
10
|
+
brainstate/mixin_test.py,sha256=-Ej9oUOu8O1M4oy37SVMj7xNRYhHHyAHwrjS_aISayo,2923
|
11
|
+
brainstate/random.py,sha256=rqwSsiUoeZwxhk9ot8NnOJA8iWMdZB0HaHOVuweJdZQ,188387
|
12
12
|
brainstate/random_test.py,sha256=cCeuYvlZkCS2_RgG0vipZFNSHG8b-uJ7SXM9SZDCYQM,17866
|
13
13
|
brainstate/surrogate.py,sha256=1kgbn82GSlpReIytIVl29yozk75gkdZv0gTBlixQ4C4,43798
|
14
|
-
brainstate/typing.py,sha256=
|
14
|
+
brainstate/typing.py,sha256=6BlkLSN5TiaNO49q8b0OYyzcuSxmdoG3noIJTbyhE3s,7895
|
15
15
|
brainstate/util.py,sha256=y-6eX1z3EMyg6pfZt4YdDalOnJ3HDAT1IPBCJDp-gQI,19876
|
16
16
|
brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
|
17
|
-
brainstate/functional/_activations.py,sha256=
|
17
|
+
brainstate/functional/_activations.py,sha256=gvZ9E1-TsEUlyO7Om0eYzlM9DF-14_A32-gta1mjGo4,17798
|
18
18
|
brainstate/functional/_normalization.py,sha256=IxE580waloZylZVXcpUUK4bWQdlE6oSPfafaKYfDkbg,2169
|
19
|
-
brainstate/functional/_others.py,sha256=
|
19
|
+
brainstate/functional/_others.py,sha256=1Epp75RkGYobMc2kHISZuS-_xnTFk3zHb1UHadwugCo,1711
|
20
20
|
brainstate/functional/_spikes.py,sha256=70qGvo4B--QtxfJMjLwGmk9pVsf2x2YNEEgjT-il_Jw,2574
|
21
21
|
brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
|
22
22
|
brainstate/init/_base.py,sha256=jRTmfoUsH_315vW9YMZzyIn2KDAAsv56SplBnvOyBW0,1148
|
23
|
-
brainstate/init/_generic.py,sha256=
|
24
|
-
brainstate/init/_random_inits.py,sha256=
|
23
|
+
brainstate/init/_generic.py,sha256=LB7IQfswOG6X-q0QX5N8T5vZmxdygetsSBQ6iXlZ0oU,7324
|
24
|
+
brainstate/init/_random_inits.py,sha256=LsfvKSX4wsR7Kh5jgKgdyXTCEEa5Nn_iYcp_2GgLQKY,16030
|
25
25
|
brainstate/init/_regular_inits.py,sha256=u77aSM0BkK9VULFJQZ1lIEYA_sJJzEZBTEttBSJ79RI,3090
|
26
26
|
brainstate/nn/__init__.py,sha256=YJHoI8cXKVRS8f2vUl3Zegp5wm0svMz3qo9JmQJiMQk,2162
|
27
27
|
brainstate/nn/_base.py,sha256=lzbZpku3Q2arH6ZaAwjs6bhbV0RcFChxo2UcpnX5t84,8481
|
28
28
|
brainstate/nn/_connections.py,sha256=GSOW2IbpJRHdPyF4nFJ2RPgO8y6SVHT1Gn-pbri9pMk,22970
|
29
29
|
brainstate/nn/_dynamics.py,sha256=OeYYXv1dqjUDcCsRhZo1XS7SP2li1vlH9uhME_PE9v0,13205
|
30
|
-
brainstate/nn/_elementwise.py,sha256=
|
30
|
+
brainstate/nn/_elementwise.py,sha256=Br2yd1kdr06iWGSvpoebWWO6suXFDiF8PQv_hOX9kZQ,43599
|
31
31
|
brainstate/nn/_embedding.py,sha256=WbgrIaM_14abN8zBDr0xipBOsFc8dXP2m7Z_aRLAfmU,2249
|
32
32
|
brainstate/nn/_misc.py,sha256=Xc4U4NLmvfnKdBNDayFrRBPAy3p0beS6T9C59rIDP00,3790
|
33
33
|
brainstate/nn/_normalizations.py,sha256=9yVDORAEpqEkL9MYSPU4m7C4q8Qj5UNsPh9sKmIt5gQ,14329
|
@@ -50,17 +50,17 @@ brainstate/optim/_sgd_optimizer.py,sha256=JiK_AVGregL0wn8uHhRQvK9Qq7Qja7dEyLW6Aa
|
|
50
50
|
brainstate/transform/__init__.py,sha256=my2X4ZW0uKZRfN82zyGEPizWNJ0fsSP2akvmkjn43ck,1458
|
51
51
|
brainstate/transform/_autograd.py,sha256=Pj_YxpU52guaxQs1NcB6qDtXgkvaPcoJbuvIF8T-Wmk,23964
|
52
52
|
brainstate/transform/_autograd_test.py,sha256=RWriMemIF9FVFUjQh4IHzLhT9LGyd1JXpjXfFZKHn10,38654
|
53
|
-
brainstate/transform/_control.py,sha256=
|
53
|
+
brainstate/transform/_control.py,sha256=0NFUGLIenqKuBhBiTmY0YgCrl2GI1ZbuWMW0DSOolpE,26874
|
54
54
|
brainstate/transform/_controls_test.py,sha256=mPUa_qmXXVxDziAJrPWRBwsGnc3cHR9co08eJB_fJwA,7648
|
55
55
|
brainstate/transform/_jit.py,sha256=sjQHFV8Tt75fpdl12jjPRDPT92_IZxBBJAG4gapdbNQ,11471
|
56
56
|
brainstate/transform/_jit_error.py,sha256=8rGRx8dtvmPWmHVOsfz30EUMXSix-m2PKM3Ni_9-_7I,4829
|
57
|
-
brainstate/transform/_jit_error_test.py,sha256=
|
57
|
+
brainstate/transform/_jit_error_test.py,sha256=GAVGL0eNJ5Fu0lHABCGc-nLfa_0x0tw_VPfURB-nhLc,1862
|
58
58
|
brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJRLb4,2868
|
59
|
-
brainstate/transform/_make_jaxpr.py,sha256=
|
59
|
+
brainstate/transform/_make_jaxpr.py,sha256=ZkrOZu4_0xcILuPUA3RFEkorJ-xbDuDtXorJI_qVThE,30450
|
60
60
|
brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
|
61
61
|
brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
|
62
|
-
brainstate-0.0.1.1.
|
63
|
-
brainstate-0.0.1.1.
|
64
|
-
brainstate-0.0.1.1.
|
65
|
-
brainstate-0.0.1.1.
|
66
|
-
brainstate-0.0.1.1.
|
62
|
+
brainstate-0.0.1.1.post20240804.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
63
|
+
brainstate-0.0.1.1.post20240804.dist-info/METADATA,sha256=RTuqQrR0-syn5SyxoyEbfbdAUpXBRxNMzpaqnVM2cqQ,3807
|
64
|
+
brainstate-0.0.1.1.post20240804.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
65
|
+
brainstate-0.0.1.1.post20240804.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
66
|
+
brainstate-0.0.1.1.post20240804.dist-info/RECORD,,
|
{brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240804.dist-info}/LICENSE
RENAMED
File without changes
|
{brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240804.dist-info}/WHEEL
RENAMED
File without changes
|
File without changes
|