brainstate 0.0.1.1.post20240802__py2.py3-none-any.whl → 0.0.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainstate/__init__.py CHANGED
@@ -17,7 +17,7 @@
17
17
  A ``State``-based Transformation System for Brain Dynamics Programming
18
18
  """
19
19
 
20
- __version__ = "0.0.1.1"
20
+ __version__ = "0.0.2"
21
21
 
22
22
  from . import environ
23
23
  from . import functional
brainstate/_module.py CHANGED
@@ -61,11 +61,9 @@ from ._state import State, StateDictManager, visible_state_dict
61
61
  from ._utils import set_module_as
62
62
  from .mixin import Mixin, Mode, DelayedInit, JointTypes, Batching, UpdateReturn
63
63
  from .transform import jit_error
64
+ from .typing import Size, ArrayLike, PyTree
64
65
  from .util import unique_name, DictManager, get_unique_name
65
66
 
66
- Shape = Union[int, Sequence[int]]
67
- PyTree = Any
68
- ArrayLike = jax.typing.ArrayLike
69
67
 
70
68
  delay_identifier = '_*_delay_of_'
71
69
  _DELAY_ROTATE = 'rotation'
@@ -805,7 +803,7 @@ class Dynamics(ExtendedUpdateWithBA, ReceiveInputProj, UpdateReturn):
805
803
 
806
804
  def __init__(
807
805
  self,
808
- size: Shape,
806
+ size: Size,
809
807
  keep_size: bool = False,
810
808
  name: Optional[str] = None,
811
809
  mode: Optional[Mode] = None,
@@ -1275,25 +1273,25 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1275
1273
 
1276
1274
  if self.interp_method == _INTERP_LINEAR: # "linear" interpolation
1277
1275
  # def _interp(target):
1278
- # if len(indices) > 0:
1279
- # raise NotImplementedError('The slicing indices are not supported in the linear interpolation.')
1280
- # if self.delay_method == _DELAY_ROTATE:
1281
- # i = environ.get(environ.I, desc='The time step index.')
1282
- # _interp_fun = partial(jnp.interp, period=self.max_length)
1283
- # for dim in range(1, target.ndim, 1):
1284
- # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
1285
- # di = i - jnp.arange(self.max_length)
1286
- # delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
1287
- # return _interp_fun(float_time_step, delay_idx, target)
1288
- #
1289
- # elif self.delay_method == _DELAY_CONCAT:
1290
- # _interp_fun = partial(jnp.interp, period=self.max_length)
1291
- # for dim in range(1, target.ndim, 1):
1292
- # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
1293
- # return _interp_fun(float_time_step, jnp.arange(self.max_length), target)
1294
- #
1295
- # else:
1296
- # raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
1276
+ # if len(indices) > 0:
1277
+ # raise NotImplementedError('The slicing indices are not supported in the linear interpolation.')
1278
+ # if self.delay_method == _DELAY_ROTATE:
1279
+ # i = environ.get(environ.I, desc='The time step index.')
1280
+ # _interp_fun = partial(jnp.interp, period=self.max_length)
1281
+ # for dim in range(1, target.ndim, 1):
1282
+ # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
1283
+ # di = i - jnp.arange(self.max_length)
1284
+ # delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
1285
+ # return _interp_fun(float_time_step, delay_idx, target)
1286
+ #
1287
+ # elif self.delay_method == _DELAY_CONCAT:
1288
+ # _interp_fun = partial(jnp.interp, period=self.max_length)
1289
+ # for dim in range(1, target.ndim, 1):
1290
+ # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
1291
+ # return _interp_fun(float_time_step, jnp.arange(self.max_length), target)
1292
+ #
1293
+ # else:
1294
+ # raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
1297
1295
  # return jax.tree.map(_interp, self.history.value)
1298
1296
 
1299
1297
  data_at_t0 = self.retrieve_at_step(jnp.asarray(jnp.floor(float_time_step), dtype=jnp.int32), *indices)
@@ -30,13 +30,14 @@ def uniform_for_unit(
30
30
  maxval: ArrayLike = 1.
31
31
  ) -> jax.Array | bu.Quantity:
32
32
  if isinstance(minval, bu.Quantity) and isinstance(maxval, bu.Quantity):
33
- return bu.Quantity(jr.uniform(key, shape, dtype, minval.value, maxval.value), dim=minval.dim)
33
+ maxval = maxval.in_unit(minval.unit)
34
+ return bu.Quantity(jr.uniform(key, shape, dtype, minval.mantissa, maxval.mantissa), unit=minval.unit)
34
35
  elif isinstance(minval, bu.Quantity):
35
36
  assert minval.is_unitless, f'minval must be unitless when maxval is not a Quantity, got {minval}'
36
- minval = minval.value
37
+ minval = minval.mantissa
37
38
  elif isinstance(maxval, bu.Quantity):
38
39
  assert maxval.is_unitless, f'maxval must be unitless when minval is not a Quantity, got {maxval}'
39
- maxval = maxval.value
40
+ maxval = maxval.mantissa
40
41
  return jr.uniform(key, shape, dtype, minval, maxval)
41
42
 
42
43
 
@@ -47,5 +48,5 @@ def permutation_for_unit(
47
48
  independent: bool = False
48
49
  ) -> jax.Array | bu.Quantity:
49
50
  if isinstance(x, bu.Quantity):
50
- return bu.Quantity(jr.permutation(key, x.value, axis, independent=independent), dim=x.dim)
51
+ return bu.Quantity(jr.permutation(key, x.mantissa, axis, independent=independent), unit=x.unit)
51
52
  return jr.permutation(key, x, axis, independent=independent)
brainstate/_state.py CHANGED
@@ -22,11 +22,9 @@ import numpy as np
22
22
  from jax.api_util import shaped_abstractify
23
23
  from jax.extend import source_info_util
24
24
 
25
+ from .typing import ArrayLike, PyTree
25
26
  from .util import DictManager
26
27
 
27
- PyTree = Any
28
- max_int = np.iinfo(np.int32)
29
-
30
28
  __all__ = [
31
29
  'State', 'ShortTermState', 'LongTermState', 'ParamState',
32
30
  'StateDictManager',
@@ -36,6 +34,7 @@ __all__ = [
36
34
  ]
37
35
 
38
36
  _pytree_registered_objects = set()
37
+ max_int = np.iinfo(np.int32)
39
38
 
40
39
 
41
40
  def _register_pytree_cls(cls):
@@ -110,7 +109,7 @@ class State(object):
110
109
  __module__ = 'brainstate'
111
110
  __slots__ = ('_value', '_name', '_tree', '_level', '_source_info', '_check_tree')
112
111
 
113
- def __init__(self, value: PyTree, name: Optional[str] = None):
112
+ def __init__(self, value: PyTree[ArrayLike], name: Optional[str] = None):
114
113
  if isinstance(value, State):
115
114
  value = value.value
116
115
  self._value = value
@@ -135,7 +134,7 @@ class State(object):
135
134
  self._name = name
136
135
 
137
136
  @property
138
- def value(self) -> PyTree:
137
+ def value(self) -> PyTree[ArrayLike]:
139
138
  """
140
139
  The data and its value.
141
140
  """
@@ -25,8 +25,8 @@ from typing import Any, Union, Sequence
25
25
  import jax
26
26
  import jax.numpy as jnp
27
27
  from jax.scipy.special import logsumexp
28
- from jax.typing import ArrayLike
29
28
 
29
+ from brainstate.typing import ArrayLike
30
30
  from .. import random
31
31
 
32
32
  __all__ = [
@@ -16,12 +16,11 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  from functools import partial
19
- from typing import Any
20
19
 
21
20
  import jax
22
21
  import jax.numpy as jnp
23
22
 
24
- PyTree = Any
23
+ from brainstate.typing import PyTree
25
24
 
26
25
  __all__ = [
27
26
  'clip_grad_norm',
@@ -22,8 +22,8 @@ import jax
22
22
  import numpy as np
23
23
 
24
24
  from brainstate._state import State
25
+ from brainstate.typing import ArrayLike
25
26
  from ._base import to_size
26
- from ..typing import ArrayLike
27
27
 
28
28
  __all__ = [
29
29
  'param',
@@ -22,8 +22,8 @@ import jax.numpy as jnp
22
22
  import numpy as np
23
23
 
24
24
  from brainstate import environ, random
25
+ from brainstate.typing import ArrayLike
25
26
  from ._base import Initializer, to_size
26
- from ..typing import ArrayLike
27
27
 
28
28
  __all__ = [
29
29
  'Normal',
brainstate/mixin.py CHANGED
@@ -15,14 +15,15 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from typing import Sequence, Optional, TypeVar, Any
18
+ from typing import Sequence, Optional, TypeVar
19
19
  from typing import (_SpecialForm, _type_check, _remove_dups_flatten, _UnionGenericAlias)
20
20
 
21
- T = TypeVar('T')
22
- PyTree = Any
21
+ from .typing import PyTree
23
22
 
23
+ T = TypeVar('T')
24
24
  State = None
25
25
 
26
+
26
27
  __all__ = [
27
28
  'Mixin',
28
29
  'DelayedInit',
brainstate/random.py CHANGED
@@ -490,14 +490,13 @@ class RandomState(State):
490
490
  upper = bu.math.asarray(upper, dtype=dtype)
491
491
  loc = bu.math.asarray(loc, dtype=dtype)
492
492
  scale = bu.math.asarray(scale, dtype=dtype)
493
- bu.fail_for_dimension_mismatch(lower, upper)
494
- bu.fail_for_dimension_mismatch(lower, loc)
495
- bu.fail_for_dimension_mismatch(lower, scale)
496
- dim = lower.dim if isinstance(lower, bu.Quantity) else bu.DIMENSIONLESS
497
- lower = lower.value if isinstance(lower, bu.Quantity) else lower
498
- upper = upper.value if isinstance(upper, bu.Quantity) else upper
499
- loc = loc.value if isinstance(loc, bu.Quantity) else loc
500
- scale = scale.value if isinstance(scale, bu.Quantity) else scale
493
+ unit = bu.get_unit(lower)
494
+ lower, upper, loc, scale = (
495
+ lower.mantissa if isinstance(lower, bu.Quantity) else lower,
496
+ bu.Quantity(upper).in_unit(unit).mantissa,
497
+ bu.Quantity(loc).in_unit(unit).mantissa,
498
+ bu.Quantity(scale).in_unit(unit).mantissa
499
+ )
501
500
 
502
501
  jit_error(
503
502
  bu.math.any(bu.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
@@ -535,10 +534,12 @@ class RandomState(State):
535
534
  out = out * scale * sqrt2 + loc
536
535
 
537
536
  # Clamp to ensure it's in the proper range
538
- out = jnp.clip(out,
539
- lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
540
- lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype)))
541
- return out if dim == bu.DIMENSIONLESS else bu.Quantity(out, dim=dim)
537
+ out = jnp.clip(
538
+ out,
539
+ lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
540
+ lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
541
+ )
542
+ return out if unit.is_unitless else bu.Quantity(out, unit=unit)
542
543
 
543
544
  def _check_p(self, p):
544
545
  raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
@@ -555,30 +556,33 @@ class RandomState(State):
555
556
  r = jr.bernoulli(key, p=p, shape=_size2shape(size))
556
557
  return r
557
558
 
558
- def lognormal(self,
559
- mean=None,
560
- sigma=None,
561
- size: Optional[Size] = None,
562
- key: Optional[SeedOrKey] = None,
563
- dtype: DTypeLike = None):
559
+ def lognormal(
560
+ self,
561
+ mean=None,
562
+ sigma=None,
563
+ size: Optional[Size] = None,
564
+ key: Optional[SeedOrKey] = None,
565
+ dtype: DTypeLike = None
566
+ ):
564
567
  mean = _check_py_seq(mean)
565
568
  sigma = _check_py_seq(sigma)
566
569
  mean = bu.math.asarray(mean, dtype=dtype)
567
570
  sigma = bu.math.asarray(sigma, dtype=dtype)
568
- bu.fail_for_dimension_mismatch(mean, sigma)
569
- dim = mean.dim if isinstance(mean, bu.Quantity) else bu.DIMENSIONLESS
570
- mean = mean.value if isinstance(mean, bu.Quantity) else mean
571
- sigma = sigma.value if isinstance(sigma, bu.Quantity) else sigma
571
+ unit = mean.unit if isinstance(mean, bu.Quantity) else bu.Unit()
572
+ mean = mean.mantissa if isinstance(mean, bu.Quantity) else mean
573
+ sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, bu.Quantity) else sigma
572
574
 
573
575
  if size is None:
574
- size = jnp.broadcast_shapes(jnp.shape(mean),
575
- jnp.shape(sigma))
576
+ size = jnp.broadcast_shapes(
577
+ jnp.shape(mean),
578
+ jnp.shape(sigma)
579
+ )
576
580
  key = self.split_key() if key is None else _formalize_key(key)
577
581
  dtype = dtype or environ.dftype()
578
582
  samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
579
583
  samples = _loc_scale(mean, sigma, samples)
580
584
  samples = jnp.exp(samples)
581
- return samples if dim == bu.DIMENSIONLESS else bu.Quantity(samples, dim=dim)
585
+ return samples if unit.is_unitless else bu.Quantity(samples, unit=unit)
582
586
 
583
587
  def binomial(self,
584
588
  n,
@@ -678,10 +682,10 @@ class RandomState(State):
678
682
  cov = bu.math.asarray(_check_py_seq(cov), dtype=dtype)
679
683
  if isinstance(mean, bu.Quantity):
680
684
  assert isinstance(cov, bu.Quantity)
681
- assert mean.dim ** 2 == cov.dim
682
- mean = mean.value if isinstance(mean, bu.Quantity) else mean
683
- cov = cov.value if isinstance(cov, bu.Quantity) else cov
684
- dim = mean.dim if isinstance(mean, bu.Quantity) else bu.DIMENSIONLESS
685
+ assert mean.unit ** 2 == cov.unit
686
+ mean = mean.mantissa if isinstance(mean, bu.Quantity) else mean
687
+ cov = cov.mantissa if isinstance(cov, bu.Quantity) else cov
688
+ unit = mean.unit if isinstance(mean, bu.Quantity) else bu.Unit()
685
689
 
686
690
  key = self.split_key() if key is None else _formalize_key(key)
687
691
  if not jnp.ndim(mean) >= 1:
@@ -708,7 +712,7 @@ class RandomState(State):
708
712
  factor = jnp.linalg.cholesky(cov)
709
713
  normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
710
714
  r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
711
- return r if dim == bu.DIMENSIONLESS else bu.Quantity(r, dim=dim)
715
+ return r if unit.is_unitless else bu.Quantity(r, unit=unit)
712
716
 
713
717
  def rayleigh(self,
714
718
  scale=1.0,
@@ -4809,10 +4813,9 @@ def _size2shape(size):
4809
4813
 
4810
4814
 
4811
4815
  def _check_shape(name, shape, *param_shapes):
4812
- shape = core.as_named_shape(shape)
4813
4816
  if param_shapes:
4814
- shape_ = lax.broadcast_shapes(shape.positional, *param_shapes)
4815
- if shape.positional != shape_:
4817
+ shape_ = lax.broadcast_shapes(shape, *param_shapes)
4818
+ if shape != shape_:
4816
4819
  msg = ("{} parameter shapes must be broadcast-compatible with shape "
4817
4820
  "argument, and the result of broadcasting the shapes must equal "
4818
4821
  "the shape argument, but got result {} for shape argument {}.")
@@ -16,8 +16,8 @@
16
16
  import unittest
17
17
 
18
18
  import jax
19
- import jaxlib.xla_extension
20
19
  import jax.numpy as jnp
20
+ import jaxlib.xla_extension
21
21
 
22
22
  import brainstate as bst
23
23
 
@@ -71,8 +71,8 @@ from jax.util import wraps
71
71
 
72
72
  from brainstate._state import State, StateTrace
73
73
  from brainstate._utils import set_module_as
74
+ from brainstate.typing import PyTree
74
75
 
75
- PyTree = Any
76
76
  AxisName = Hashable
77
77
 
78
78
  __all__ = [
brainstate/typing.py CHANGED
@@ -14,13 +14,16 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- from typing import Any, Sequence, Protocol, Union
17
+ import functools as ft
18
+ import typing
19
+ from typing import Sequence, Protocol, Union, Any, Generic, TypeVar
18
20
 
19
21
  import brainunit as bu
20
22
  import jax
21
23
  import numpy as np
22
24
 
23
25
  __all__ = [
26
+ 'PyTree',
24
27
  'Size',
25
28
  'Axes',
26
29
  'SeedOrKey',
@@ -29,6 +32,151 @@ __all__ = [
29
32
  'DTypeLike',
30
33
  ]
31
34
 
35
+ _T = TypeVar("_T")
36
+
37
+
38
+ class _FakePyTree(Generic[_T]):
39
+ pass
40
+
41
+
42
+ _FakePyTree.__name__ = "PyTree"
43
+ _FakePyTree.__qualname__ = "PyTree"
44
+ _FakePyTree.__module__ = "builtins"
45
+
46
+
47
+ class _MetaPyTree(type):
48
+ def __call__(self, *args, **kwargs):
49
+ raise RuntimeError("PyTree cannot be instantiated")
50
+
51
+ # Can't return a generic (e.g. _FakePyTree[item]) because generic aliases don't do
52
+ # the custom __instancecheck__ that we want.
53
+ # We can't add that __instancecheck__ via subclassing, e.g.
54
+ # type("PyTree", (Generic[_T],), {}), because dynamic subclassing of typeforms
55
+ # isn't allowed.
56
+ # Likewise we can't do types.new_class("PyTree", (Generic[_T],), {}) because that
57
+ # has __module__ "types", e.g. we get types.PyTree[int].
58
+ @ft.lru_cache(maxsize=None)
59
+ def __getitem__(cls, item):
60
+ if isinstance(item, tuple):
61
+ if len(item) == 2:
62
+
63
+ class X(PyTree):
64
+ leaftype = item[0]
65
+ structure = item[1].strip()
66
+
67
+ if not isinstance(X.structure, str):
68
+ raise ValueError(
69
+ "The structure annotation `struct` in "
70
+ "`brainstate.typing.PyTree[leaftype, struct]` must be be a string, "
71
+ f"e.g. `brainstate.typing.PyTree[leaftype, 'T']`. Got '{X.structure}'."
72
+ )
73
+ pieces = X.structure.split()
74
+ if len(pieces) == 0:
75
+ raise ValueError(
76
+ "The string `struct` in `brainstate.typing.PyTree[leaftype, struct]` "
77
+ "cannot be the empty string."
78
+ )
79
+ for piece_index, piece in enumerate(pieces):
80
+ if (piece_index == 0) or (piece_index == len(pieces) - 1):
81
+ if piece == "...":
82
+ continue
83
+ if not piece.isidentifier():
84
+ raise ValueError(
85
+ "The string `struct` in "
86
+ "`brainstate.typing.PyTree[leaftype, struct]` must be be a "
87
+ "whitespace-separated sequence of identifiers, e.g. "
88
+ "`brainstate.typing.PyTree[leaftype, 'T']` or "
89
+ "`brainstate.typing.PyTree[leaftype, 'foo bar']`.\n"
90
+ "(Here, 'identifier' is used in the same sense as in "
91
+ "regular Python, i.e. a valid variable name.)\n"
92
+ f"Got piece '{piece}' in overall structure '{X.structure}'."
93
+ )
94
+ name = str(_FakePyTree[item[0]])[:-1] + ', "' + item[1].strip() + '"]'
95
+ else:
96
+ raise ValueError(
97
+ "The subscript `foo` in `brainstate.typing.PyTree[foo]` must either be a "
98
+ "leaf type, e.g. `PyTree[int]`, or a 2-tuple of leaf and "
99
+ "structure, e.g. `PyTree[int, 'T']`. Received a tuple of length "
100
+ f"{len(item)}."
101
+ )
102
+ else:
103
+ name = str(_FakePyTree[item])
104
+
105
+ class X(PyTree):
106
+ leaftype = item
107
+ structure = None
108
+
109
+ X.__name__ = name
110
+ X.__qualname__ = name
111
+ if getattr(typing, "GENERATING_DOCUMENTATION", False):
112
+ X.__module__ = "builtins"
113
+ else:
114
+ X.__module__ = "brainstate.typing"
115
+ return X
116
+
117
+
118
+ # Can't do `class PyTree(Generic[_T]): ...` because we need to override the
119
+ # instancecheck for PyTree[foo], but subclassing
120
+ # `type(Generic[int])`, i.e. `typing._GenericAlias` is disallowed.
121
+ PyTree = _MetaPyTree("PyTree", (), {})
122
+ if getattr(typing, "GENERATING_DOCUMENTATION", False):
123
+ PyTree.__module__ = "builtins"
124
+ else:
125
+ PyTree.__module__ = "brainstate.typing"
126
+ PyTree.__doc__ = """Represents a PyTree.
127
+
128
+ Annotations of the following sorts are supported:
129
+ ```python
130
+ a: PyTree
131
+ b: PyTree[LeafType]
132
+ c: PyTree[LeafType, "T"]
133
+ d: PyTree[LeafType, "S T"]
134
+ e: PyTree[LeafType, "... T"]
135
+ f: PyTree[LeafType, "T ..."]
136
+ ```
137
+
138
+ These correspond to:
139
+
140
+ a. A plain `PyTree` can be used an annotation, in which case `PyTree` is simply a
141
+ suggestively-named alternative to `Any`.
142
+ ([By definition all types are PyTrees.](https://jax.readthedocs.io/en/latest/pytrees.html))
143
+
144
+ b. `PyTree[LeafType]` denotes a PyTree all of whose leaves match `LeafType`. For
145
+ example, `PyTree[int]` or `PyTree[Union[str, Float32[Array, "b c"]]]`.
146
+
147
+ c. A structure name can also be passed. In this case
148
+ `jax.tree_util.tree_structure(...)` will be called, and bound to the structure name.
149
+ This can be used to mark that multiple PyTrees all have the same structure:
150
+ ```python
151
+ def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
152
+ ...
153
+ ```
154
+
155
+ d. A composite structure can be declared. In this case the variable must have a PyTree
156
+ structure each to the composition of multiple previously-bound PyTree structures.
157
+ For example:
158
+ ```python
159
+ def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
160
+ ...
161
+
162
+ x = (1, 2)
163
+ y = {"key": 3}
164
+ z = {"key": (4, 5)} # structure is the composition of the structures of `y` and `z`
165
+ f(x, y, z)
166
+ ```
167
+ When performing runtime type-checking, all the individual pieces must have already
168
+ been bound to structures, otherwise the composite structure check will throw an error.
169
+
170
+ e. A structure can begin with a `...`, to denote that the lower levels of the PyTree
171
+ must match the declared structure, but the upper levels can be arbitrary. As in the
172
+ previous case, all named pieces must already have been seen and their structures
173
+ bound.
174
+
175
+ f. A structure can end with a `...`, to denote that the PyTree must be a prefix of the
176
+ declared structure, but the lower levels can be arbitrary. As in the previous two
177
+ cases, all named pieces must already have been seen and their structures bound.
178
+ """ # noqa: E501
179
+
32
180
  Size = Union[int, Sequence[int]]
33
181
  Axes = Union[int, Sequence[int]]
34
182
  SeedOrKey = Union[int, jax.Array, np.ndarray]
@@ -44,7 +192,7 @@ ArrayLike = Union[
44
192
  np.ndarray, # NumPy array type
45
193
  np.bool_, np.number, # NumPy scalar types
46
194
  bool, int, float, complex, # Python scalar types
47
- bu.Quantity, # quantity
195
+ bu.Quantity, # Quantity
48
196
  ]
49
197
 
50
198
  # --- Dtype --- #
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.0.1.1.post20240802
3
+ Version: 0.0.2
4
4
  Summary: A State-based Transformation System for Brain Dynamics Programming.
5
5
  Home-page: https://github.com/brainpy/brainstate
6
6
  Author: BDP
@@ -1,27 +1,27 @@
1
- brainstate/__init__.py,sha256=oxslZrm6wxtBQDqwJFb2BaAKZFmnp4d_esDkaeuGMWE,1410
2
- brainstate/_module.py,sha256=UjhfmY26VHQ-kF6U4l68AslGU6vhfD-dR7gF-4Io5ic,52520
1
+ brainstate/__init__.py,sha256=zipNSih9Tyvi4-5cXqNPGsDF7VeestkLp-lcjJ4-dA0,1408
2
+ brainstate/_module.py,sha256=YJDp9aD38wBa_lY6BojWjWV9LJ2aFMAMYh-KZe5a4eM,52443
3
3
  brainstate/_module_test.py,sha256=oQaoaZBTo1o3wHrMEJTInQCc7RdcVs1gcfQGvdSb1SI,7843
4
- brainstate/_random_for_unit.py,sha256=eW4NJkX27VCCNWUwAlyt2otkeEthGKOpUoX6XJ6i95Y,1946
5
- brainstate/_state.py,sha256=t4lEikvxTfeL2TW0chLUvsQuuRoJSO-iXylUydl1i7k,12057
4
+ brainstate/_random_for_unit.py,sha256=1rHr7gfH_bYrJfpxbDhQUk_j00Yosx-GzyZCXrLxsd0,2007
5
+ brainstate/_state.py,sha256=C0widCOj_ca6zfqh95jzFXf_G5vi0hJyuQ5GIqEqOUs,12102
6
6
  brainstate/_state_test.py,sha256=HDdipndRLhEHWEdTmyT1ayEBkbv6qJKykfCWKI6yJ_E,1253
7
7
  brainstate/_utils.py,sha256=RLorgGJkt2BhbX4C-ygd-PPG0wfcGCghjSP93sRvzqM,833
8
8
  brainstate/environ.py,sha256=LwRwnFaTbv8l7nHRIbSV46WzcN7pGLQFhT_xDUox2yA,10240
9
- brainstate/mixin.py,sha256=2f2toMUmgJIiovX1wi8OIBM8sbWH6s9Usa1ixL9J4tg,10747
9
+ brainstate/mixin.py,sha256=OumTTSVyYSbtudjfS_MRThsBaeVJ_0JggeMClY7xtBA,10758
10
10
  brainstate/mixin_test.py,sha256=-Ej9oUOu8O1M4oy37SVMj7xNRYhHHyAHwrjS_aISayo,2923
11
- brainstate/random.py,sha256=UbXfC0nrxk5FOsld0rCxBp2bIaeQHH5bj-NWQTR8bbQ,188447
11
+ brainstate/random.py,sha256=BqEBYVD9TGe8dSzp8U0suK0O4r6Ox59GCq0mwfUndVQ,188073
12
12
  brainstate/random_test.py,sha256=cCeuYvlZkCS2_RgG0vipZFNSHG8b-uJ7SXM9SZDCYQM,17866
13
13
  brainstate/surrogate.py,sha256=1kgbn82GSlpReIytIVl29yozk75gkdZv0gTBlixQ4C4,43798
14
- brainstate/typing.py,sha256=tZyXpopNtp1UvqbntnmlUP6OBTM_ywyV9QBQ6Gu3IdU,2232
14
+ brainstate/typing.py,sha256=6BlkLSN5TiaNO49q8b0OYyzcuSxmdoG3noIJTbyhE3s,7895
15
15
  brainstate/util.py,sha256=y-6eX1z3EMyg6pfZt4YdDalOnJ3HDAT1IPBCJDp-gQI,19876
16
16
  brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
17
- brainstate/functional/_activations.py,sha256=xlwvYG8qvpkfMEZTFxD_4amW63ZfEa8x3vzVH2hDgeY,17791
17
+ brainstate/functional/_activations.py,sha256=gvZ9E1-TsEUlyO7Om0eYzlM9DF-14_A32-gta1mjGo4,17798
18
18
  brainstate/functional/_normalization.py,sha256=IxE580waloZylZVXcpUUK4bWQdlE6oSPfafaKYfDkbg,2169
19
- brainstate/functional/_others.py,sha256=ifB-l82y7ZB632yLUJOEcpkRY-yOoiJ0mtDOxNilp4M,1711
19
+ brainstate/functional/_others.py,sha256=1Epp75RkGYobMc2kHISZuS-_xnTFk3zHb1UHadwugCo,1711
20
20
  brainstate/functional/_spikes.py,sha256=70qGvo4B--QtxfJMjLwGmk9pVsf2x2YNEEgjT-il_Jw,2574
21
21
  brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
22
22
  brainstate/init/_base.py,sha256=jRTmfoUsH_315vW9YMZzyIn2KDAAsv56SplBnvOyBW0,1148
23
- brainstate/init/_generic.py,sha256=G5WHSQs3M9htREIZ_OXqj1ffSF_BdYdPTaMqe9AKj-k,7315
24
- brainstate/init/_random_inits.py,sha256=ycxH9WyKvPhRbk9PkamCLlc9aX5YokREBfSCbrbFQQ4,16021
23
+ brainstate/init/_generic.py,sha256=LB7IQfswOG6X-q0QX5N8T5vZmxdygetsSBQ6iXlZ0oU,7324
24
+ brainstate/init/_random_inits.py,sha256=LsfvKSX4wsR7Kh5jgKgdyXTCEEa5Nn_iYcp_2GgLQKY,16030
25
25
  brainstate/init/_regular_inits.py,sha256=u77aSM0BkK9VULFJQZ1lIEYA_sJJzEZBTEttBSJ79RI,3090
26
26
  brainstate/nn/__init__.py,sha256=YJHoI8cXKVRS8f2vUl3Zegp5wm0svMz3qo9JmQJiMQk,2162
27
27
  brainstate/nn/_base.py,sha256=lzbZpku3Q2arH6ZaAwjs6bhbV0RcFChxo2UcpnX5t84,8481
@@ -54,13 +54,13 @@ brainstate/transform/_control.py,sha256=0NFUGLIenqKuBhBiTmY0YgCrl2GI1ZbuWMW0DSOo
54
54
  brainstate/transform/_controls_test.py,sha256=mPUa_qmXXVxDziAJrPWRBwsGnc3cHR9co08eJB_fJwA,7648
55
55
  brainstate/transform/_jit.py,sha256=sjQHFV8Tt75fpdl12jjPRDPT92_IZxBBJAG4gapdbNQ,11471
56
56
  brainstate/transform/_jit_error.py,sha256=8rGRx8dtvmPWmHVOsfz30EUMXSix-m2PKM3Ni_9-_7I,4829
57
- brainstate/transform/_jit_error_test.py,sha256=MGFSCih6GVXTmLh0K-m8RX8N4x2_CAgtcGOIWfzol88,1862
57
+ brainstate/transform/_jit_error_test.py,sha256=GAVGL0eNJ5Fu0lHABCGc-nLfa_0x0tw_VPfURB-nhLc,1862
58
58
  brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJRLb4,2868
59
- brainstate/transform/_make_jaxpr.py,sha256=vMPruKfp5Ugv8RL-9wGfQdSumLZdLtThZvv3sU9MDjE,30426
59
+ brainstate/transform/_make_jaxpr.py,sha256=ZkrOZu4_0xcILuPUA3RFEkorJ-xbDuDtXorJI_qVThE,30450
60
60
  brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
61
61
  brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
62
- brainstate-0.0.1.1.post20240802.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
63
- brainstate-0.0.1.1.post20240802.dist-info/METADATA,sha256=LYHl7AJ94js5MhJVghX5mJ1KB_VA4kWVc1bCmZ0O3GY,3807
64
- brainstate-0.0.1.1.post20240802.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
65
- brainstate-0.0.1.1.post20240802.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
66
- brainstate-0.0.1.1.post20240802.dist-info/RECORD,,
62
+ brainstate-0.0.2.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
63
+ brainstate-0.0.2.dist-info/METADATA,sha256=K6yiVOqGj3Qs_vKGgQmFXZtlu8cS4r7EZXl_iyCjwh0,3792
64
+ brainstate-0.0.2.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
65
+ brainstate-0.0.2.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
66
+ brainstate-0.0.2.dist-info/RECORD,,