brainstate 0.0.1.1.post20240708__py2.py3-none-any.whl → 0.0.1.1.post20240804__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainstate/_module.py CHANGED
@@ -61,11 +61,9 @@ from ._state import State, StateDictManager, visible_state_dict
61
61
  from ._utils import set_module_as
62
62
  from .mixin import Mixin, Mode, DelayedInit, JointTypes, Batching, UpdateReturn
63
63
  from .transform import jit_error
64
+ from .typing import Size, ArrayLike, PyTree
64
65
  from .util import unique_name, DictManager, get_unique_name
65
66
 
66
- Shape = Union[int, Sequence[int]]
67
- PyTree = Any
68
- ArrayLike = jax.typing.ArrayLike
69
67
 
70
68
  delay_identifier = '_*_delay_of_'
71
69
  _DELAY_ROTATE = 'rotation'
@@ -805,7 +803,7 @@ class Dynamics(ExtendedUpdateWithBA, ReceiveInputProj, UpdateReturn):
805
803
 
806
804
  def __init__(
807
805
  self,
808
- size: Shape,
806
+ size: Size,
809
807
  keep_size: bool = False,
810
808
  name: Optional[str] = None,
811
809
  mode: Optional[Mode] = None,
@@ -1275,25 +1273,25 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1275
1273
 
1276
1274
  if self.interp_method == _INTERP_LINEAR: # "linear" interpolation
1277
1275
  # def _interp(target):
1278
- # if len(indices) > 0:
1279
- # raise NotImplementedError('The slicing indices are not supported in the linear interpolation.')
1280
- # if self.delay_method == _DELAY_ROTATE:
1281
- # i = environ.get(environ.I, desc='The time step index.')
1282
- # _interp_fun = partial(jnp.interp, period=self.max_length)
1283
- # for dim in range(1, target.ndim, 1):
1284
- # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
1285
- # di = i - jnp.arange(self.max_length)
1286
- # delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
1287
- # return _interp_fun(float_time_step, delay_idx, target)
1288
- #
1289
- # elif self.delay_method == _DELAY_CONCAT:
1290
- # _interp_fun = partial(jnp.interp, period=self.max_length)
1291
- # for dim in range(1, target.ndim, 1):
1292
- # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
1293
- # return _interp_fun(float_time_step, jnp.arange(self.max_length), target)
1294
- #
1295
- # else:
1296
- # raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
1276
+ # if len(indices) > 0:
1277
+ # raise NotImplementedError('The slicing indices are not supported in the linear interpolation.')
1278
+ # if self.delay_method == _DELAY_ROTATE:
1279
+ # i = environ.get(environ.I, desc='The time step index.')
1280
+ # _interp_fun = partial(jnp.interp, period=self.max_length)
1281
+ # for dim in range(1, target.ndim, 1):
1282
+ # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
1283
+ # di = i - jnp.arange(self.max_length)
1284
+ # delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
1285
+ # return _interp_fun(float_time_step, delay_idx, target)
1286
+ #
1287
+ # elif self.delay_method == _DELAY_CONCAT:
1288
+ # _interp_fun = partial(jnp.interp, period=self.max_length)
1289
+ # for dim in range(1, target.ndim, 1):
1290
+ # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
1291
+ # return _interp_fun(float_time_step, jnp.arange(self.max_length), target)
1292
+ #
1293
+ # else:
1294
+ # raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
1297
1295
  # return jax.tree.map(_interp, self.history.value)
1298
1296
 
1299
1297
  data_at_t0 = self.retrieve_at_step(jnp.asarray(jnp.floor(float_time_step), dtype=jnp.int32), *indices)
brainstate/_state.py CHANGED
@@ -15,18 +15,16 @@
15
15
 
16
16
  import contextlib
17
17
  import threading
18
- from typing import Any, Tuple, Dict, List, Callable
18
+ from typing import Any, Tuple, Dict, List, Callable, Optional
19
19
 
20
20
  import jax
21
21
  import numpy as np
22
22
  from jax.api_util import shaped_abstractify
23
23
  from jax.extend import source_info_util
24
24
 
25
+ from .typing import ArrayLike, PyTree
25
26
  from .util import DictManager
26
27
 
27
- PyTree = Any
28
- max_int = np.iinfo(np.int32)
29
-
30
28
  __all__ = [
31
29
  'State', 'ShortTermState', 'LongTermState', 'ParamState',
32
30
  'StateDictManager',
@@ -36,6 +34,7 @@ __all__ = [
36
34
  ]
37
35
 
38
36
  _pytree_registered_objects = set()
37
+ max_int = np.iinfo(np.int32)
39
38
 
40
39
 
41
40
  def _register_pytree_cls(cls):
@@ -108,9 +107,9 @@ class State(object):
108
107
  value: PyTree. It can be anything as a pyTree.
109
108
  """
110
109
  __module__ = 'brainstate'
111
- __slots__ = ('_value', '_tree', '_level', '_source_info', '_check_tree')
110
+ __slots__ = ('_value', '_name', '_tree', '_level', '_source_info', '_check_tree')
112
111
 
113
- def __init__(self, value: PyTree):
112
+ def __init__(self, value: PyTree[ArrayLike], name: Optional[str] = None):
114
113
  if isinstance(value, State):
115
114
  value = value.value
116
115
  self._value = value
@@ -118,9 +117,24 @@ class State(object):
118
117
  self._check_tree = False
119
118
  self._level = len(thread_local_stack.stack)
120
119
  self._source_info = source_info_util.current()
120
+ self._name = name
121
+
122
+ @property
123
+ def name(self) -> Optional[str]:
124
+ """
125
+ The name of the state.
126
+ """
127
+ return self._name
128
+
129
+ @name.setter
130
+ def name(self, name: str) -> None:
131
+ """
132
+ Set the name of the state.
133
+ """
134
+ self._name = name
121
135
 
122
136
  @property
123
- def value(self) -> PyTree:
137
+ def value(self) -> PyTree[ArrayLike]:
124
138
  """
125
139
  The data and its value.
126
140
  """
@@ -210,7 +224,10 @@ class State(object):
210
224
  leaves, tree = jax.tree.flatten(self._value)
211
225
  leaves_info = [ShapeDtype(leaf.shape, leaf.dtype) for leaf in leaves]
212
226
  tree_info = jax.tree.unflatten(tree, leaves_info)
213
- return f'{self.__class__.__name__}({tree_info})'
227
+ if self.name is None:
228
+ return f'{self.__class__.__name__}({tree_info})'
229
+ else:
230
+ return f'{self.__class__.__name__}({self.name}: {tree_info})'
214
231
 
215
232
 
216
233
  class ShapeDtype:
@@ -25,8 +25,8 @@ from typing import Any, Union, Sequence
25
25
  import jax
26
26
  import jax.numpy as jnp
27
27
  from jax.scipy.special import logsumexp
28
- from jax.typing import ArrayLike
29
28
 
29
+ from brainstate.typing import ArrayLike
30
30
  from .. import random
31
31
 
32
32
  __all__ = [
@@ -16,12 +16,11 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  from functools import partial
19
- from typing import Any
20
19
 
21
20
  import jax
22
21
  import jax.numpy as jnp
23
22
 
24
- PyTree = Any
23
+ from brainstate.typing import PyTree
25
24
 
26
25
  __all__ = [
27
26
  'clip_grad_norm',
@@ -22,8 +22,8 @@ import jax
22
22
  import numpy as np
23
23
 
24
24
  from brainstate._state import State
25
+ from brainstate.typing import ArrayLike
25
26
  from ._base import to_size
26
- from ..typing import ArrayLike
27
27
 
28
28
  __all__ = [
29
29
  'param',
@@ -83,7 +83,7 @@ def _expand_params_to_match_sizes(params, sizes):
83
83
 
84
84
 
85
85
  def param(
86
- parameter: Union[Callable, ArrayLike],
86
+ parameter: Union[Callable, ArrayLike, State],
87
87
  sizes: Union[int, Sequence[int]],
88
88
  batch_size: Optional[int] = None,
89
89
  allow_none: bool = True,
@@ -22,8 +22,8 @@ import jax.numpy as jnp
22
22
  import numpy as np
23
23
 
24
24
  from brainstate import environ, random
25
+ from brainstate.typing import ArrayLike
25
26
  from ._base import Initializer, to_size
26
- from ..typing import ArrayLike
27
27
 
28
28
  __all__ = [
29
29
  'Normal',
brainstate/mixin.py CHANGED
@@ -15,14 +15,15 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from typing import Sequence, Optional, TypeVar, Any
18
+ from typing import Sequence, Optional, TypeVar
19
19
  from typing import (_SpecialForm, _type_check, _remove_dups_flatten, _UnionGenericAlias)
20
20
 
21
- T = TypeVar('T')
22
- PyTree = Any
21
+ from .typing import PyTree
23
22
 
23
+ T = TypeVar('T')
24
24
  State = None
25
25
 
26
+
26
27
  __all__ = [
27
28
  'Mixin',
28
29
  'DelayedInit',
@@ -207,7 +208,7 @@ class _JointGenericAlias(_UnionGenericAlias, _root=True):
207
208
 
208
209
  @_SpecialForm
209
210
  def JointTypes(self, parameters):
210
- """All of types; AllOfTypes[X, Y] means both X and Y.
211
+ """Joint types; JointTypes[X, Y] means both X and Y.
211
212
 
212
213
  To define a union, use e.g. Union[int, str].
213
214
 
@@ -216,28 +217,28 @@ def JointTypes(self, parameters):
216
217
  - None as an argument is a special case and is replaced by `type(None)`.
217
218
  - Unions of unions are flattened, e.g.::
218
219
 
219
- AllOfTypes[AllOfTypes[int, str], float] == AllOfTypes[int, str, float]
220
+ JointTypes[JointTypes[int, str], float] == JointTypes[int, str, float]
220
221
 
221
222
  - Unions of a single argument vanish, e.g.::
222
223
 
223
- AllOfTypes[int] == int # The constructor actually returns int
224
+ JointTypes[int] == int # The constructor actually returns int
224
225
 
225
226
  - Redundant arguments are skipped, e.g.::
226
227
 
227
- AllOfTypes[int, str, int] == AllOfTypes[int, str]
228
+ JointTypes[int, str, int] == JointTypes[int, str]
228
229
 
229
230
  - When comparing unions, the argument order is ignored, e.g.::
230
231
 
231
- AllOfTypes[int, str] == AllOfTypes[str, int]
232
+ JointTypes[int, str] == JointTypes[str, int]
232
233
 
233
- - You cannot subclass or instantiate a AllOfTypes.
234
- - You can use Optional[X] as a shorthand for AllOfTypes[X, None].
234
+ - You cannot subclass or instantiate a JointTypes.
235
+ - You can use Optional[X] as a shorthand for JointTypes[X, None].
235
236
  """
236
237
  if parameters == ():
237
238
  raise TypeError("Cannot take a Joint of no types.")
238
239
  if not isinstance(parameters, tuple):
239
240
  parameters = (parameters,)
240
- msg = "AllOfTypes[arg, ...]: each arg must be a type."
241
+ msg = "JointTypes[arg, ...]: each arg must be a type."
241
242
  parameters = tuple(_type_check(p, msg) for p in parameters)
242
243
  parameters = _remove_dups_flatten(parameters)
243
244
  if len(parameters) == 1:
brainstate/mixin_test.py CHANGED
@@ -30,6 +30,8 @@ class TestMixin(unittest.TestCase):
30
30
  self.assertTrue(bc.mixin.Training)
31
31
 
32
32
 
33
+
34
+
33
35
  class TestMode(unittest.TestCase):
34
36
  def test_JointMode(self):
35
37
  a = bc.mixin.JointMode(bc.mixin.Batching(), bc.mixin.Training())
@@ -1139,13 +1139,13 @@ class Dropout(Module, ElementWiseBlock):
1139
1139
  name: Optional[str] = None
1140
1140
  ) -> None:
1141
1141
  super().__init__(mode=mode, name=name)
1142
- assert 0. <= prob < 1., f"Dropout probability must be in the range [0, 1). But got {prob}."
1142
+ assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
1143
1143
  self.prob = prob
1144
1144
 
1145
1145
  def __call__(self, x):
1146
1146
  dtype = bu.math.get_dtype(x)
1147
1147
  fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
1148
- if fit_phase:
1148
+ if fit_phase and self.prob < 1.:
1149
1149
  keep_mask = random.bernoulli(self.prob, x.shape)
1150
1150
  return jnp.where(keep_mask,
1151
1151
  jnp.asarray(x / self.prob, dtype=dtype),
brainstate/random.py CHANGED
@@ -1167,23 +1167,32 @@ def default_rng(seed_or_key=None, clone: bool = True) -> RandomState:
1167
1167
  return RandomState(seed_or_key)
1168
1168
 
1169
1169
 
1170
- def seed(seed: int = None):
1170
+ def seed(seed_or_key: int = None):
1171
1171
  """Sets a new random seed.
1172
1172
 
1173
1173
  Parameters
1174
1174
  ----------
1175
- seed: int, optional
1176
- The random seed.
1175
+ seed_or_key: int, optional
1176
+ The random seed (an integer) or jax random key.
1177
1177
  """
1178
1178
  with jax.ensure_compile_time_eval():
1179
- if seed is None:
1180
- seed = np.random.randint(0, 100000)
1181
- np.random.seed(seed)
1182
- DEFAULT.seed(seed)
1179
+ if seed_or_key is None:
1180
+ seed_or_key = np.random.randint(0, 100000)
1181
+
1182
+ # numpy random seed
1183
+ if np.size(seed_or_key) == 1: # seed
1184
+ np.random.seed(seed_or_key)
1185
+ elif np.size(seed_or_key) == 2: # jax random key
1186
+ np.random.seed(seed_or_key[0])
1187
+ else:
1188
+ raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
1189
+
1190
+ # jax random seed
1191
+ DEFAULT.seed(seed_or_key)
1183
1192
 
1184
1193
 
1185
1194
  @contextmanager
1186
- def seed_context(seed: int):
1195
+ def seed_context(seed_or_key: SeedOrKey):
1187
1196
  """
1188
1197
  A context manager that sets the random seed for the duration of the block.
1189
1198
 
@@ -1206,16 +1215,19 @@ def seed_context(seed: int):
1206
1215
  The context manager does not only set the seed for the JAX random state, but also for the numpy random state.
1207
1216
 
1208
1217
  Args:
1209
- seed: The seed (an integer).
1218
+ seed_or_key: The seed (an integer) or jax random key.
1210
1219
 
1211
- Returns:
1212
- The random state.
1213
1220
  """
1214
1221
  old_jrand_key = DEFAULT.value
1215
1222
  old_np_state = np.random.get_state()
1216
1223
  try:
1217
- np.random.seed(seed)
1218
- DEFAULT.seed(seed)
1224
+ if np.size(seed_or_key) == 1: # seed
1225
+ np.random.seed(seed_or_key)
1226
+ elif np.size(seed_or_key) == 2: # jax random key
1227
+ np.random.seed(seed_or_key[0])
1228
+ else:
1229
+ raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
1230
+ DEFAULT.seed(seed_or_key)
1219
1231
  yield
1220
1232
  finally:
1221
1233
  np.random.set_state(old_np_state)
@@ -1223,7 +1235,8 @@ def seed_context(seed: int):
1223
1235
 
1224
1236
 
1225
1237
  def rand(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
1226
- r"""Random values in a given shape.
1238
+ r"""
1239
+ Random values in a given shape.
1227
1240
 
1228
1241
  .. note::
1229
1242
  This is a convenience function for users porting code from Matlab,
@@ -4796,10 +4809,9 @@ def _size2shape(size):
4796
4809
 
4797
4810
 
4798
4811
  def _check_shape(name, shape, *param_shapes):
4799
- shape = core.as_named_shape(shape)
4800
4812
  if param_shapes:
4801
- shape_ = lax.broadcast_shapes(shape.positional, *param_shapes)
4802
- if shape.positional != shape_:
4813
+ shape_ = lax.broadcast_shapes(shape, *param_shapes)
4814
+ if shape != shape_:
4803
4815
  msg = ("{} parameter shapes must be broadcast-compatible with shape "
4804
4816
  "argument, and the result of broadcasting the shapes must equal "
4805
4817
  "the shape argument, but got result {} for shape argument {}.")
@@ -25,7 +25,7 @@ import jax.numpy as jnp
25
25
  import numpy as np
26
26
 
27
27
  from brainstate._utils import set_module_as
28
- from ._jit_error import jit_error
28
+ from ._jit_error import jit_error, remove_vmap
29
29
  from ._make_jaxpr import StatefulFunction, _assign_state_values
30
30
  from ._progress_bar import ProgressBar
31
31
 
@@ -347,7 +347,7 @@ def _wrap_fun_with_pbar(fun, pbar_runner):
347
347
  def new_fun(new_carry, inputs):
348
348
  i, old_carry = new_carry
349
349
  old_carry, old_outputs = fun(old_carry, inputs)
350
- pbar_runner(i)
350
+ pbar_runner(remove_vmap(i, op='none'))
351
351
  return (i + 1, old_carry), old_outputs
352
352
 
353
353
  return new_fun
@@ -16,8 +16,8 @@
16
16
  import unittest
17
17
 
18
18
  import jax
19
- import jaxlib.xla_extension
20
19
  import jax.numpy as jnp
20
+ import jaxlib.xla_extension
21
21
 
22
22
  import brainstate as bst
23
23
 
@@ -71,8 +71,8 @@ from jax.util import wraps
71
71
 
72
72
  from brainstate._state import State, StateTrace
73
73
  from brainstate._utils import set_module_as
74
+ from brainstate.typing import PyTree
74
75
 
75
- PyTree = Any
76
76
  AxisName = Hashable
77
77
 
78
78
  __all__ = [
brainstate/typing.py CHANGED
@@ -14,13 +14,16 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- from typing import Any, Sequence, Protocol, Union
17
+ import functools as ft
18
+ import typing
19
+ from typing import Sequence, Protocol, Union, Any, Generic, TypeVar
18
20
 
19
21
  import brainunit as bu
20
22
  import jax
21
23
  import numpy as np
22
24
 
23
25
  __all__ = [
26
+ 'PyTree',
24
27
  'Size',
25
28
  'Axes',
26
29
  'SeedOrKey',
@@ -29,6 +32,151 @@ __all__ = [
29
32
  'DTypeLike',
30
33
  ]
31
34
 
35
+ _T = TypeVar("_T")
36
+
37
+
38
+ class _FakePyTree(Generic[_T]):
39
+ pass
40
+
41
+
42
+ _FakePyTree.__name__ = "PyTree"
43
+ _FakePyTree.__qualname__ = "PyTree"
44
+ _FakePyTree.__module__ = "builtins"
45
+
46
+
47
+ class _MetaPyTree(type):
48
+ def __call__(self, *args, **kwargs):
49
+ raise RuntimeError("PyTree cannot be instantiated")
50
+
51
+ # Can't return a generic (e.g. _FakePyTree[item]) because generic aliases don't do
52
+ # the custom __instancecheck__ that we want.
53
+ # We can't add that __instancecheck__ via subclassing, e.g.
54
+ # type("PyTree", (Generic[_T],), {}), because dynamic subclassing of typeforms
55
+ # isn't allowed.
56
+ # Likewise we can't do types.new_class("PyTree", (Generic[_T],), {}) because that
57
+ # has __module__ "types", e.g. we get types.PyTree[int].
58
+ @ft.lru_cache(maxsize=None)
59
+ def __getitem__(cls, item):
60
+ if isinstance(item, tuple):
61
+ if len(item) == 2:
62
+
63
+ class X(PyTree):
64
+ leaftype = item[0]
65
+ structure = item[1].strip()
66
+
67
+ if not isinstance(X.structure, str):
68
+ raise ValueError(
69
+ "The structure annotation `struct` in "
70
+ "`brainstate.typing.PyTree[leaftype, struct]` must be be a string, "
71
+ f"e.g. `brainstate.typing.PyTree[leaftype, 'T']`. Got '{X.structure}'."
72
+ )
73
+ pieces = X.structure.split()
74
+ if len(pieces) == 0:
75
+ raise ValueError(
76
+ "The string `struct` in `brainstate.typing.PyTree[leaftype, struct]` "
77
+ "cannot be the empty string."
78
+ )
79
+ for piece_index, piece in enumerate(pieces):
80
+ if (piece_index == 0) or (piece_index == len(pieces) - 1):
81
+ if piece == "...":
82
+ continue
83
+ if not piece.isidentifier():
84
+ raise ValueError(
85
+ "The string `struct` in "
86
+ "`brainstate.typing.PyTree[leaftype, struct]` must be be a "
87
+ "whitespace-separated sequence of identifiers, e.g. "
88
+ "`brainstate.typing.PyTree[leaftype, 'T']` or "
89
+ "`brainstate.typing.PyTree[leaftype, 'foo bar']`.\n"
90
+ "(Here, 'identifier' is used in the same sense as in "
91
+ "regular Python, i.e. a valid variable name.)\n"
92
+ f"Got piece '{piece}' in overall structure '{X.structure}'."
93
+ )
94
+ name = str(_FakePyTree[item[0]])[:-1] + ', "' + item[1].strip() + '"]'
95
+ else:
96
+ raise ValueError(
97
+ "The subscript `foo` in `brainstate.typing.PyTree[foo]` must either be a "
98
+ "leaf type, e.g. `PyTree[int]`, or a 2-tuple of leaf and "
99
+ "structure, e.g. `PyTree[int, 'T']`. Received a tuple of length "
100
+ f"{len(item)}."
101
+ )
102
+ else:
103
+ name = str(_FakePyTree[item])
104
+
105
+ class X(PyTree):
106
+ leaftype = item
107
+ structure = None
108
+
109
+ X.__name__ = name
110
+ X.__qualname__ = name
111
+ if getattr(typing, "GENERATING_DOCUMENTATION", False):
112
+ X.__module__ = "builtins"
113
+ else:
114
+ X.__module__ = "brainstate.typing"
115
+ return X
116
+
117
+
118
+ # Can't do `class PyTree(Generic[_T]): ...` because we need to override the
119
+ # instancecheck for PyTree[foo], but subclassing
120
+ # `type(Generic[int])`, i.e. `typing._GenericAlias` is disallowed.
121
+ PyTree = _MetaPyTree("PyTree", (), {})
122
+ if getattr(typing, "GENERATING_DOCUMENTATION", False):
123
+ PyTree.__module__ = "builtins"
124
+ else:
125
+ PyTree.__module__ = "brainstate.typing"
126
+ PyTree.__doc__ = """Represents a PyTree.
127
+
128
+ Annotations of the following sorts are supported:
129
+ ```python
130
+ a: PyTree
131
+ b: PyTree[LeafType]
132
+ c: PyTree[LeafType, "T"]
133
+ d: PyTree[LeafType, "S T"]
134
+ e: PyTree[LeafType, "... T"]
135
+ f: PyTree[LeafType, "T ..."]
136
+ ```
137
+
138
+ These correspond to:
139
+
140
+ a. A plain `PyTree` can be used as an annotation, in which case `PyTree` is simply a
141
+ suggestively-named alternative to `Any`.
142
+ ([By definition all types are PyTrees.](https://jax.readthedocs.io/en/latest/pytrees.html))
143
+
144
+ b. `PyTree[LeafType]` denotes a PyTree all of whose leaves match `LeafType`. For
145
+ example, `PyTree[int]` or `PyTree[Union[str, Float32[Array, "b c"]]]`.
146
+
147
+ c. A structure name can also be passed. In this case
148
+ `jax.tree_util.tree_structure(...)` will be called, and bound to the structure name.
149
+ This can be used to mark that multiple PyTrees all have the same structure:
150
+ ```python
151
+ def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
152
+ ...
153
+ ```
154
+
155
+ d. A composite structure can be declared. In this case the variable must have a PyTree
156
+ structure equal to the composition of multiple previously-bound PyTree structures.
157
+ For example:
158
+ ```python
159
+ def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
160
+ ...
161
+
162
+ x = (1, 2)
163
+ y = {"key": 3}
164
+ z = {"key": (4, 5)} # structure is the composition of the structures of `y` and `x`
165
+ f(x, y, z)
166
+ ```
167
+ When performing runtime type-checking, all the individual pieces must have already
168
+ been bound to structures, otherwise the composite structure check will throw an error.
169
+
170
+ e. A structure can begin with a `...`, to denote that the lower levels of the PyTree
171
+ must match the declared structure, but the upper levels can be arbitrary. As in the
172
+ previous case, all named pieces must already have been seen and their structures
173
+ bound.
174
+
175
+ f. A structure can end with a `...`, to denote that the PyTree must be a prefix of the
176
+ declared structure, but the lower levels can be arbitrary. As in the previous two
177
+ cases, all named pieces must already have been seen and their structures bound.
178
+ """ # noqa: E501
179
+
32
180
  Size = Union[int, Sequence[int]]
33
181
  Axes = Union[int, Sequence[int]]
34
182
  SeedOrKey = Union[int, jax.Array, np.ndarray]
@@ -44,7 +192,7 @@ ArrayLike = Union[
44
192
  np.ndarray, # NumPy array type
45
193
  np.bool_, np.number, # NumPy scalar types
46
194
  bool, int, float, complex, # Python scalar types
47
- bu.Quantity, # quantity
195
+ bu.Quantity, # Quantity
48
196
  ]
49
197
 
50
198
  # --- Dtype --- #
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.0.1.1.post20240708
3
+ Version: 0.0.1.1.post20240804
4
4
  Summary: A State-based Transformation System for Brain Dynamics Programming.
5
5
  Home-page: https://github.com/brainpy/brainstate
6
- Author: BrainPy Team
6
+ Author: BDP
7
7
  Author-email: BrainPy Team <chao.brain@qq.com>
8
8
  License: Apache-2.0 license
9
9
  Project-URL: homepage, http://github.com/brainpy
@@ -1,33 +1,33 @@
1
1
  brainstate/__init__.py,sha256=oxslZrm6wxtBQDqwJFb2BaAKZFmnp4d_esDkaeuGMWE,1410
2
- brainstate/_module.py,sha256=UjhfmY26VHQ-kF6U4l68AslGU6vhfD-dR7gF-4Io5ic,52520
2
+ brainstate/_module.py,sha256=YJDp9aD38wBa_lY6BojWjWV9LJ2aFMAMYh-KZe5a4eM,52443
3
3
  brainstate/_module_test.py,sha256=oQaoaZBTo1o3wHrMEJTInQCc7RdcVs1gcfQGvdSb1SI,7843
4
4
  brainstate/_random_for_unit.py,sha256=eW4NJkX27VCCNWUwAlyt2otkeEthGKOpUoX6XJ6i95Y,1946
5
- brainstate/_state.py,sha256=ykQluBkdKIcQsd7-pU_UlpnByWFof_8TYQl1hGn7HS8,11629
5
+ brainstate/_state.py,sha256=C0widCOj_ca6zfqh95jzFXf_G5vi0hJyuQ5GIqEqOUs,12102
6
6
  brainstate/_state_test.py,sha256=HDdipndRLhEHWEdTmyT1ayEBkbv6qJKykfCWKI6yJ_E,1253
7
7
  brainstate/_utils.py,sha256=RLorgGJkt2BhbX4C-ygd-PPG0wfcGCghjSP93sRvzqM,833
8
8
  brainstate/environ.py,sha256=LwRwnFaTbv8l7nHRIbSV46WzcN7pGLQFhT_xDUox2yA,10240
9
- brainstate/mixin.py,sha256=f0XovWlyTOQKC63WTtzSbztRSld3pKEHy9-I1gVZWLg,10748
10
- brainstate/mixin_test.py,sha256=qFLw9Bq8TkOMg8M8CcI92BYvFekXkjyCC9lXSEVa8Ck,2919
11
- brainstate/random.py,sha256=OMa4739GbrQpKspEx0TbeBqjA4yvwwJzdtpr-kJZN6s,187841
9
+ brainstate/mixin.py,sha256=OumTTSVyYSbtudjfS_MRThsBaeVJ_0JggeMClY7xtBA,10758
10
+ brainstate/mixin_test.py,sha256=-Ej9oUOu8O1M4oy37SVMj7xNRYhHHyAHwrjS_aISayo,2923
11
+ brainstate/random.py,sha256=rqwSsiUoeZwxhk9ot8NnOJA8iWMdZB0HaHOVuweJdZQ,188387
12
12
  brainstate/random_test.py,sha256=cCeuYvlZkCS2_RgG0vipZFNSHG8b-uJ7SXM9SZDCYQM,17866
13
13
  brainstate/surrogate.py,sha256=1kgbn82GSlpReIytIVl29yozk75gkdZv0gTBlixQ4C4,43798
14
- brainstate/typing.py,sha256=tZyXpopNtp1UvqbntnmlUP6OBTM_ywyV9QBQ6Gu3IdU,2232
14
+ brainstate/typing.py,sha256=6BlkLSN5TiaNO49q8b0OYyzcuSxmdoG3noIJTbyhE3s,7895
15
15
  brainstate/util.py,sha256=y-6eX1z3EMyg6pfZt4YdDalOnJ3HDAT1IPBCJDp-gQI,19876
16
16
  brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
17
- brainstate/functional/_activations.py,sha256=xlwvYG8qvpkfMEZTFxD_4amW63ZfEa8x3vzVH2hDgeY,17791
17
+ brainstate/functional/_activations.py,sha256=gvZ9E1-TsEUlyO7Om0eYzlM9DF-14_A32-gta1mjGo4,17798
18
18
  brainstate/functional/_normalization.py,sha256=IxE580waloZylZVXcpUUK4bWQdlE6oSPfafaKYfDkbg,2169
19
- brainstate/functional/_others.py,sha256=ifB-l82y7ZB632yLUJOEcpkRY-yOoiJ0mtDOxNilp4M,1711
19
+ brainstate/functional/_others.py,sha256=1Epp75RkGYobMc2kHISZuS-_xnTFk3zHb1UHadwugCo,1711
20
20
  brainstate/functional/_spikes.py,sha256=70qGvo4B--QtxfJMjLwGmk9pVsf2x2YNEEgjT-il_Jw,2574
21
21
  brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
22
22
  brainstate/init/_base.py,sha256=jRTmfoUsH_315vW9YMZzyIn2KDAAsv56SplBnvOyBW0,1148
23
- brainstate/init/_generic.py,sha256=yLkmYRWpGHUK_nYx_fWPe74s_Cd-JDdQ6mUqpi4yqcc,7308
24
- brainstate/init/_random_inits.py,sha256=ycxH9WyKvPhRbk9PkamCLlc9aX5YokREBfSCbrbFQQ4,16021
23
+ brainstate/init/_generic.py,sha256=LB7IQfswOG6X-q0QX5N8T5vZmxdygetsSBQ6iXlZ0oU,7324
24
+ brainstate/init/_random_inits.py,sha256=LsfvKSX4wsR7Kh5jgKgdyXTCEEa5Nn_iYcp_2GgLQKY,16030
25
25
  brainstate/init/_regular_inits.py,sha256=u77aSM0BkK9VULFJQZ1lIEYA_sJJzEZBTEttBSJ79RI,3090
26
26
  brainstate/nn/__init__.py,sha256=YJHoI8cXKVRS8f2vUl3Zegp5wm0svMz3qo9JmQJiMQk,2162
27
27
  brainstate/nn/_base.py,sha256=lzbZpku3Q2arH6ZaAwjs6bhbV0RcFChxo2UcpnX5t84,8481
28
28
  brainstate/nn/_connections.py,sha256=GSOW2IbpJRHdPyF4nFJ2RPgO8y6SVHT1Gn-pbri9pMk,22970
29
29
  brainstate/nn/_dynamics.py,sha256=OeYYXv1dqjUDcCsRhZo1XS7SP2li1vlH9uhME_PE9v0,13205
30
- brainstate/nn/_elementwise.py,sha256=6BTqSvSnaHhldwB5ol5OV0hPJ5yJ-Jpm4WSrtFKMNoQ,43579
30
+ brainstate/nn/_elementwise.py,sha256=Br2yd1kdr06iWGSvpoebWWO6suXFDiF8PQv_hOX9kZQ,43599
31
31
  brainstate/nn/_embedding.py,sha256=WbgrIaM_14abN8zBDr0xipBOsFc8dXP2m7Z_aRLAfmU,2249
32
32
  brainstate/nn/_misc.py,sha256=Xc4U4NLmvfnKdBNDayFrRBPAy3p0beS6T9C59rIDP00,3790
33
33
  brainstate/nn/_normalizations.py,sha256=9yVDORAEpqEkL9MYSPU4m7C4q8Qj5UNsPh9sKmIt5gQ,14329
@@ -50,17 +50,17 @@ brainstate/optim/_sgd_optimizer.py,sha256=JiK_AVGregL0wn8uHhRQvK9Qq7Qja7dEyLW6Aa
50
50
  brainstate/transform/__init__.py,sha256=my2X4ZW0uKZRfN82zyGEPizWNJ0fsSP2akvmkjn43ck,1458
51
51
  brainstate/transform/_autograd.py,sha256=Pj_YxpU52guaxQs1NcB6qDtXgkvaPcoJbuvIF8T-Wmk,23964
52
52
  brainstate/transform/_autograd_test.py,sha256=RWriMemIF9FVFUjQh4IHzLhT9LGyd1JXpjXfFZKHn10,38654
53
- brainstate/transform/_control.py,sha256=NWceTIuLlj2uGTdNcqBAXgnaLuChOGgAtIXtFn5vdLU,26837
53
+ brainstate/transform/_control.py,sha256=0NFUGLIenqKuBhBiTmY0YgCrl2GI1ZbuWMW0DSOolpE,26874
54
54
  brainstate/transform/_controls_test.py,sha256=mPUa_qmXXVxDziAJrPWRBwsGnc3cHR9co08eJB_fJwA,7648
55
55
  brainstate/transform/_jit.py,sha256=sjQHFV8Tt75fpdl12jjPRDPT92_IZxBBJAG4gapdbNQ,11471
56
56
  brainstate/transform/_jit_error.py,sha256=8rGRx8dtvmPWmHVOsfz30EUMXSix-m2PKM3Ni_9-_7I,4829
57
- brainstate/transform/_jit_error_test.py,sha256=MGFSCih6GVXTmLh0K-m8RX8N4x2_CAgtcGOIWfzol88,1862
57
+ brainstate/transform/_jit_error_test.py,sha256=GAVGL0eNJ5Fu0lHABCGc-nLfa_0x0tw_VPfURB-nhLc,1862
58
58
  brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJRLb4,2868
59
- brainstate/transform/_make_jaxpr.py,sha256=vMPruKfp5Ugv8RL-9wGfQdSumLZdLtThZvv3sU9MDjE,30426
59
+ brainstate/transform/_make_jaxpr.py,sha256=ZkrOZu4_0xcILuPUA3RFEkorJ-xbDuDtXorJI_qVThE,30450
60
60
  brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
61
61
  brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
62
- brainstate-0.0.1.1.post20240708.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
63
- brainstate-0.0.1.1.post20240708.dist-info/METADATA,sha256=HN33Hoom47puNsauzxdmXvsmcP5RqrUpeR_2tVf7Y5U,3816
64
- brainstate-0.0.1.1.post20240708.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
65
- brainstate-0.0.1.1.post20240708.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
66
- brainstate-0.0.1.1.post20240708.dist-info/RECORD,,
62
+ brainstate-0.0.1.1.post20240804.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
63
+ brainstate-0.0.1.1.post20240804.dist-info/METADATA,sha256=RTuqQrR0-syn5SyxoyEbfbdAUpXBRxNMzpaqnVM2cqQ,3807
64
+ brainstate-0.0.1.1.post20240804.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
65
+ brainstate-0.0.1.1.post20240804.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
66
+ brainstate-0.0.1.1.post20240804.dist-info/RECORD,,