brainstate 0.0.1.post20240612__py2.py3-none-any.whl → 0.0.1.post20240622__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. brainstate/__init__.py +4 -5
  2. brainstate/_module.py +148 -43
  3. brainstate/_module_test.py +95 -21
  4. brainstate/environ.py +0 -1
  5. brainstate/functional/__init__.py +2 -2
  6. brainstate/functional/_activations.py +7 -26
  7. brainstate/functional/_spikes.py +0 -1
  8. brainstate/mixin.py +2 -2
  9. brainstate/nn/_elementwise.py +5 -4
  10. brainstate/nn/_misc.py +4 -3
  11. brainstate/nn/_others.py +3 -2
  12. brainstate/nn/_poolings.py +21 -20
  13. brainstate/nn/_poolings_test.py +4 -4
  14. brainstate/optim/__init__.py +0 -1
  15. brainstate/optim/_sgd_optimizer.py +18 -17
  16. brainstate/transform/__init__.py +2 -3
  17. brainstate/transform/_autograd.py +1 -1
  18. brainstate/transform/_autograd_test.py +0 -2
  19. brainstate/transform/_jit_test.py +0 -3
  20. brainstate/transform/_make_jaxpr.py +0 -1
  21. brainstate/transform/_make_jaxpr_test.py +0 -2
  22. brainstate/transform/_progress_bar.py +1 -3
  23. brainstate/util.py +0 -1
  24. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/METADATA +2 -12
  25. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/RECORD +28 -35
  26. brainstate/math/__init__.py +0 -21
  27. brainstate/math/_einops.py +0 -787
  28. brainstate/math/_einops_parsing.py +0 -169
  29. brainstate/math/_einops_parsing_test.py +0 -126
  30. brainstate/math/_einops_test.py +0 -346
  31. brainstate/math/_misc.py +0 -298
  32. brainstate/math/_misc_test.py +0 -58
  33. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/LICENSE +0 -0
  34. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/WHEEL +0 -0
  35. {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/top_level.txt +0 -0
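The dominant change in this release is the removal of the `brainstate.math` subpackage (entries 26–32 above) in favour of the new `brainunit` dependency declared in METADATA; the touched modules simply re-point their helper calls. A minimal before/after sketch of a call site, assuming `brainunit` exposes the helpers the updated sources import (`bu.math.get_dtype`, `bu.math.flatten`):

```python
import brainunit as bu
import jax.numpy as jnp

x = jnp.ones((2, 3, 4))

# Before (<= 0.0.1.post20240612):
#   from brainstate import math
#   dtype = math.get_dtype(x)
#   flat = math.flatten(x, start_dim=1, end_dim=-1)

# After (>= 0.0.1.post20240622): the same helpers come from brainunit,
# and the dim-style keywords are renamed to axis-style ones.
dtype = bu.math.get_dtype(x)
flat = bu.math.flatten(x, start_axis=1, end_axis=-1)
print(dtype, flat.shape)  # expected: float32 (2, 12)
```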
brainstate/functional/_spikes.py CHANGED
@@ -87,4 +87,3 @@ def spike_bitwise(x, y, op: str):
  return spike_bitwise_ixor(x, y)
  else:
  raise NotImplementedError(f"Unsupported bitwise operation: {op}.")
-
brainstate/mixin.py CHANGED
@@ -68,7 +68,7 @@ class DelayedInit(Mixin):
  Note this Mixin can be applied in any Python object.
  """

- non_hash_params: Optional[Sequence[str]] = None
+ non_hashable_params: Optional[Sequence[str]] = None

  @classmethod
  def delayed(cls, *args, **kwargs) -> 'DelayedInitializer':
@@ -94,7 +94,7 @@ class DelayedInitializer(metaclass=NoSubclassMeta):
  """

  def __init__(self, cls: T, *desc_tuple, **desc_dict):
- self.cls = cls
+ self.cls: type = cls

  # arguments
  self.args = desc_tuple
brainstate/nn/_elementwise.py CHANGED
@@ -19,11 +19,12 @@ from __future__ import annotations

  from typing import Optional

+ import brainunit as bu
  import jax.numpy as jnp
  import jax.typing

  from ._base import ElementWiseBlock
- from .. import math, environ, random, functional as F
+ from .. import environ, random, functional as F
  from .._module import Module
  from .._state import ParamState
  from ..mixin import Mode
@@ -82,7 +83,7 @@ class Threshold(Module, ElementWiseBlock):
  self.value = value

  def __call__(self, x: ArrayLike) -> ArrayLike:
- dtype = math.get_dtype(x)
+ dtype = bu.math.get_dtype(x)
  return jnp.where(x > jnp.asarray(self.threshold, dtype=dtype),
  x,
  jnp.asarray(self.value, dtype=dtype))
@@ -1142,7 +1143,7 @@ class Dropout(Module, ElementWiseBlock):
  self.prob = prob

  def __call__(self, x):
- dtype = math.get_dtype(x)
+ dtype = bu.math.get_dtype(x)
  fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
  if fit_phase:
  keep_mask = random.bernoulli(self.prob, x.shape)
@@ -1172,7 +1173,7 @@ class _DropoutNd(Module, ElementWiseBlock):
  self.channel_axis = channel_axis

  def __call__(self, x):
- dtype = math.get_dtype(x)
+ dtype = bu.math.get_dtype(x)
  # get fit phase
  fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')

brainstate/nn/_misc.py CHANGED
@@ -20,9 +20,10 @@ from enum import Enum
  from functools import wraps
  from typing import Sequence, Callable

+ import brainunit as bu
  import jax.numpy as jnp

- from .. import environ, math
+ from .. import environ
  from .._state import State
  from ..transform import vector_grad

@@ -96,7 +97,7 @@ def exp_euler(fun):
  )
  dt = environ.get('dt')
  linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
- phi = math.exprel(dt * linear)
+ phi = bu.math.exprel(dt * linear)
  return args[0] + dt * phi * derivative

  return integral
@@ -128,5 +129,5 @@ def exp_euler_step(fun: Callable, *args, **kwargs):
  )
  dt = environ.get('dt')
  linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
- phi = math.exprel(dt * linear)
+ phi = bu.math.exprel(dt * linear)
  return args[0] + dt * phi * derivative
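For context on the `exprel` call that now comes from `brainunit`: `exp_euler_step` applies an exponential Euler update, x_new = x + dt * phi(dt * J) * f(x) with phi(z) = (exp(z) - 1) / z, where J is the derivative of `fun` with respect to its first argument (which `vector_grad(..., return_value=True)` returns alongside the value). A rough standalone sketch for a scalar state, not the library implementation:

```python
import jax
import jax.numpy as jnp

def exprel(z):
    # phi(z) = (exp(z) - 1) / z, taking the z -> 0 limit (= 1.0) explicitly;
    # bu.math.exprel plays this role in the code above.
    safe_z = jnp.where(z == 0.0, 1.0, z)
    return jnp.where(z == 0.0, 1.0, jnp.expm1(safe_z) / safe_z)

def exp_euler_sketch(f, x, dt):
    # One exponential Euler step for dx/dt = f(x), scalar x for illustration.
    dfdx = jax.grad(f)(x)                    # the "linear" term
    return x + dt * exprel(dt * dfdx) * f(x)

# dx/dt = -x over dt = 0.1: the step reproduces the exact decay exp(-0.1)
print(exp_euler_sketch(lambda x: -x, 1.0, 0.1))  # ~0.904837
```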
brainstate/nn/_others.py CHANGED
@@ -19,10 +19,11 @@ from __future__ import annotations
  from functools import partial
  from typing import Optional

+ import brainunit as bu
  import jax.numpy as jnp

  from ._base import DnnLayer
- from .. import random, math, environ, typing, init
+ from .. import random, environ, typing, init
  from ..mixin import Mode

  __all__ = [
@@ -88,7 +89,7 @@ class DropoutFixed(DnnLayer):
  self.mask = init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size)

  def update(self, x):
- dtype = math.get_dtype(x)
+ dtype = bu.math.get_dtype(x)
  fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
  if fit_phase:
  assert self.mask.shape == x.shape, (f"Input shape {x.shape} does not match the mask shape {self.mask.shape}. "
brainstate/nn/_poolings.py CHANGED
@@ -21,12 +21,13 @@ import functools
  from typing import Sequence, Optional
  from typing import Union, Tuple, Callable, List

+ import brainunit as bu
  import jax
  import jax.numpy as jnp
  import numpy as np

  from ._base import DnnLayer, ExplicitInOutSize
- from .. import environ, math
+ from .. import environ
  from ..mixin import Mode
  from ..typing import Size

@@ -53,8 +54,8 @@ class Flatten(DnnLayer, ExplicitInOutSize):

  Args:
  in_size: Sequence of int. The shape of the input tensor.
- start_dim: first dim to flatten (default = 1).
- end_dim: last dim to flatten (default = -1).
+ start_axis: first dim to flatten (default = 1).
+ end_axis: last dim to flatten (default = -1).

  Examples::
  >>> import brainstate as bst
@@ -74,36 +75,36 @@ class Flatten(DnnLayer, ExplicitInOutSize):

  def __init__(
  self,
- start_dim: int = 0,
- end_dim: int = -1,
+ start_axis: int = 0,
+ end_axis: int = -1,
  in_size: Optional[Size] = None
  ) -> None:
  super().__init__()
- self.start_dim = start_dim
- self.end_dim = end_dim
+ self.start_axis = start_axis
+ self.end_axis = end_axis

  if in_size is not None:
  self.in_size = tuple(in_size)
- y = jax.eval_shape(functools.partial(math.flatten, start_dim=start_dim, end_dim=end_dim),
+ y = jax.eval_shape(functools.partial(bu.math.flatten, start_axis=start_axis, end_axis=end_axis),
  jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
  self.out_size = y.shape

  def update(self, x):
  if self._in_size is None:
- start_dim = self.start_dim if self.start_dim >= 0 else x.ndim + self.start_dim
+ start_axis = self.start_axis if self.start_axis >= 0 else x.ndim + self.start_axis
  else:
  assert x.ndim >= len(self.in_size), 'Input tensor has fewer dimensions than the expected shape.'
  dim_diff = x.ndim - len(self.in_size)
  if self.in_size != x.shape[dim_diff:]:
  raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
- if self.start_dim >= 0:
- start_dim = self.start_dim + dim_diff
+ if self.start_axis >= 0:
+ start_axis = self.start_axis + dim_diff
  else:
- start_dim = x.ndim + self.start_dim
- return math.flatten(x, start_dim, self.end_dim)
+ start_axis = x.ndim + self.start_axis
+ return bu.math.flatten(x, start_axis, self.end_axis)

  def __repr__(self) -> str:
- return f'{self.__class__.__name__}(start_dim={self.start_dim}, end_dim={self.end_dim})'
+ return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'


  class Unflatten(DnnLayer, ExplicitInOutSize):
@@ -124,7 +125,7 @@ class Unflatten(DnnLayer, ExplicitInOutSize):
  :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.

  Args:
- dim: int, Dimension to be unflattened.
+ axis: int, Dimension to be unflattened.
  sizes: Sequence of int. New shape of the unflattened dimension.
  in_size: Sequence of int. The shape of the input tensor.
  """
@@ -132,7 +133,7 @@ class Unflatten(DnnLayer, ExplicitInOutSize):

  def __init__(
  self,
- dim: int,
+ axis: int,
  sizes: Size,
  mode: Mode = None,
  name: str = None,
@@ -140,7 +141,7 @@ class Unflatten(DnnLayer, ExplicitInOutSize):
  ) -> None:
  super().__init__(mode=mode, name=name)

- self.dim = dim
+ self.axis = axis
  self.sizes = sizes
  if isinstance(sizes, (tuple, list)):
  for idx, elem in enumerate(sizes):
@@ -152,15 +153,15 @@ class Unflatten(DnnLayer, ExplicitInOutSize):

  if in_size is not None:
  self.in_size = tuple(in_size)
- y = jax.eval_shape(functools.partial(math.unflatten, dim=dim, sizes=sizes),
+ y = jax.eval_shape(functools.partial(bu.math.unflatten, axis=axis, sizes=sizes),
  jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
  self.out_size = y.shape

  def update(self, x):
- return math.unflatten(x, self.dim, self.sizes)
+ return bu.math.unflatten(x, self.axis, self.sizes)

  def __repr__(self):
- return f'{self.__class__.__name__}(dim={self.dim}, sizes={self.sizes})'
+ return f'{self.__class__.__name__}(axis={self.axis}, sizes={self.sizes})'


  class _MaxPool(DnnLayer, ExplicitInOutSize):
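The test changes below exercise the renamed `Flatten` keywords; the analogous `Unflatten` rename (`dim` → `axis`) is illustrated here. A hypothetical usage sketch, assuming `Unflatten` is called the same way the tests call `Flatten`:

```python
import brainstate as bst
import brainstate.nn as nn

x = bst.random.rand(16, 32, 32, 8)

flat = nn.Flatten(start_axis=1)                    # was nn.Flatten(start_dim=1)
y = flat(x)                                        # shape (16, 8192)

unflat = nn.Unflatten(axis=1, sizes=(32, 32, 8))   # was nn.Unflatten(dim=1, ...)
z = unflat(y)                                      # back to (16, 32, 32, 8)
print(y.shape, z.shape)
```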
brainstate/nn/_poolings_test.py CHANGED
@@ -18,7 +18,7 @@ class TestFlatten(parameterized.TestCase):
  (10, 20, 30),
  ]:
  arr = bst.random.rand(*size)
- f = nn.Flatten(start_dim=0)
+ f = nn.Flatten(start_axis=0)
  out = f(arr)
  self.assertTrue(out.shape == (np.prod(size),))

@@ -29,21 +29,21 @@ class TestFlatten(parameterized.TestCase):
  (10, 20, 30),
  ]:
  arr = bst.random.rand(*size)
- f = nn.Flatten(start_dim=1)
+ f = nn.Flatten(start_axis=1)
  out = f(arr)
  self.assertTrue(out.shape == (size[0], np.prod(size[1:])))

  def test_flatten3(self):
  size = (16, 32, 32, 8)
  arr = bst.random.rand(*size)
- f = nn.Flatten(start_dim=0, in_size=(32, 8))
+ f = nn.Flatten(start_axis=0, in_size=(32, 8))
  out = f(arr)
  self.assertTrue(out.shape == (16, 32, 32 * 8))

  def test_flatten4(self):
  size = (16, 32, 32, 8)
  arr = bst.random.rand(*size)
- f = nn.Flatten(start_dim=1, in_size=(32, 32, 8))
+ f = nn.Flatten(start_axis=1, in_size=(32, 32, 8))
  out = f(arr)
  self.assertTrue(out.shape == (16, 32, 32 * 8))

brainstate/optim/__init__.py CHANGED
@@ -20,4 +20,3 @@ from ._sgd_optimizer import *
  from ._sgd_optimizer import __all__ as optimizer_all

  __all__ = scheduler_all + optimizer_all
-
brainstate/optim/_sgd_optimizer.py CHANGED
@@ -18,11 +18,12 @@
  import functools
  from typing import Union, Dict, Optional, Tuple, Any, TypeVar

+ import brainunit as bu
  import jax
  import jax.numpy as jnp

  from ._lr_scheduler import make_schedule, LearningRateScheduler
- from .. import environ, math
+ from .. import environ
  from .._module import Module
  from .._state import State, LongTermState, StateDictManager, visible_state_dict

@@ -282,7 +283,7 @@ class Momentum(_WeightDecayOptimizer):
  for k, v in train_states.items():
  assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
  self.weight_states.add_unique_elem(k, v)
- self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
+ self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

  def update(self, grads: dict):
  lr = self.lr()
@@ -349,7 +350,7 @@ class MomentumNesterov(_WeightDecayOptimizer):
  for k, v in train_states.items():
  assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
  self.weight_states.add_unique_elem(k, v)
- self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
+ self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

  def update(self, grads: dict):
  lr = self.lr()
@@ -417,7 +418,7 @@ class Adagrad(_WeightDecayOptimizer):
  for k, v in train_states.items():
  assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
  self.weight_states.add_unique_elem(k, v)
- self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
+ self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

  def update(self, grads: dict):
  lr = self.lr()
@@ -500,8 +501,8 @@ class Adadelta(_WeightDecayOptimizer):
  for k, v in train_states.items():
  assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
  self.weight_states.add_unique_elem(k, v)
- self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
- self.delta_states[k] = OptimState(math.tree_zeros_like(v.value))
+ self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+ self.delta_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

  def update(self, grads: dict):
  weight_values, grad_values, cache_values, delta_values = to_same_dict_tree(
@@ -574,7 +575,7 @@ class RMSProp(_WeightDecayOptimizer):
  for k, v in train_states.items():
  assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
  self.weight_states.add_unique_elem(k, v)
- self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
+ self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

  def update(self, grads: dict):
  lr = self.lr()
@@ -647,8 +648,8 @@ class Adam(_WeightDecayOptimizer):
  for k, v in train_states.items():
  assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
  self.weight_states.add_unique_elem(k, v)
- self.m1_states[k] = OptimState(math.tree_zeros_like(v.value))
- self.m2_states[k] = OptimState(math.tree_zeros_like(v.value))
+ self.m1_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+ self.m2_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

  def update(self, grads: dict):
  lr = self.lr()
@@ -730,7 +731,7 @@ class LARS(_WeightDecayOptimizer):
  for k, v in train_states.items():
  assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
  self.weight_states.add_unique_elem(k, v)
- self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
+ self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

  def update(self, grads: dict):
  lr = self.lr()
@@ -835,10 +836,10 @@ class Adan(_WeightDecayOptimizer):
  for k, v in train_states.items():
  assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
  self.weight_states.add_unique_elem(k, v)
- self.exp_avg_states[k] = OptimState(math.tree_zeros_like(v.value))
- self.exp_avg_sq_states[k] = OptimState(math.tree_zeros_like(v.value))
- self.exp_avg_diff_states[k] = OptimState(math.tree_zeros_like(v.value))
- self.pre_grad_states[k] = OptimState(math.tree_zeros_like(v.value))
+ self.exp_avg_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+ self.exp_avg_sq_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+ self.exp_avg_diff_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+ self.pre_grad_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

  def update(self, grads: dict):
  lr = self.lr()
@@ -989,10 +990,10 @@ class AdamW(_WeightDecayOptimizer):
  for k, v in train_states.items():
  assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
  self.weight_states.add_unique_elem(k, v)
- self.m1_states[k] = OptimState(math.tree_zeros_like(v.value))
- self.m2_states[k] = OptimState(math.tree_zeros_like(v.value))
+ self.m1_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+ self.m2_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
  if self.amsgrad:
- self.vmax_states[k] = OptimState(math.tree_zeros_like(v.value))
+ self.vmax_states[k] = OptimState(bu.math.tree_zeros_like(v.value))

  def update(self, grads: dict):
  lr_old = self.lr()
brainstate/transform/__init__.py CHANGED
@@ -17,10 +17,10 @@
  This module contains the functions for the transformation of the brain data.
  """

- from ._control import *
- from ._control import __all__ as _controls_all
  from ._autograd import *
  from ._autograd import __all__ as _gradients_all
+ from ._control import *
+ from ._control import __all__ as _controls_all
  from ._jit import *
  from ._jit import __all__ as _jit_all
  from ._jit_error import *
@@ -33,4 +33,3 @@ from ._progress_bar import __all__ as _progress_bar_all
  __all__ = _gradients_all + _jit_error_all + _controls_all + _make_jaxpr_all + _jit_all + _progress_bar_all

  del _gradients_all, _jit_error_all, _controls_all, _make_jaxpr_all, _jit_all, _progress_bar_all
-
brainstate/transform/_autograd.py CHANGED
@@ -25,8 +25,8 @@ from jax._src.api import _vjp
  from jax.api_util import argnums_partial
  from jax.extend import linear_util

- from brainstate._utils import set_module_as
  from brainstate._state import State, StateTrace, StateDictManager
+ from brainstate._utils import set_module_as

  __all__ = [
  'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
brainstate/transform/_autograd_test.py CHANGED
@@ -537,7 +537,6 @@ class TestPureFuncJacobian(unittest.TestCase):

  def test_jacrev_return_aux1(self):
  with bc.environ.context(precision=64):
-
  def f1(x, y):
  a = 4 * x[1] ** 2 - 2 * x[2]
  r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], a, x[2] * jnp.sin(x[0])])
@@ -564,7 +563,6 @@ class TestPureFuncJacobian(unittest.TestCase):
  assert (vec == _r).all()


-
  class TestClassFuncJacobian(unittest.TestCase):
  def test_jacrev1(self):
  def f1(x, y):
brainstate/transform/_jit_test.py CHANGED
@@ -16,7 +16,6 @@
  import unittest

  import jax.numpy as jnp
- import jax.stages

  import brainstate as bc

@@ -90,7 +89,6 @@ class TestJIT(unittest.TestCase):
  self.assertTrue(len(compiling) == 2)

  def test_jit_attribute_origin_fun(self):
-
  def fun1(x):
  return x

@@ -99,4 +97,3 @@ class TestJIT(unittest.TestCase):
  self.assertTrue(isinstance(jitted_fun.stateful_fun, bc.transform.StatefulFunction))
  self.assertTrue(callable(jitted_fun.jitted_fun))
  self.assertTrue(callable(jitted_fun.clear_cache))
-
brainstate/transform/_make_jaxpr.py CHANGED
@@ -75,7 +75,6 @@ from brainstate._utils import set_module_as
  PyTree = Any
  AxisName = Hashable

-
  __all__ = [
  "StatefulFunction",
  "make_jaxpr",
brainstate/transform/_make_jaxpr_test.py CHANGED
@@ -129,5 +129,3 @@ def test_return_states():

  with pytest.raises(ValueError):
  f()
-
-
brainstate/transform/_progress_bar.py CHANGED
@@ -14,13 +14,12 @@
  # ==============================================================================

  from __future__ import annotations
+
  import copy
  from typing import Optional

  import jax

- from brainstate import environ
-
  try:
  from tqdm.auto import tqdm
  except (ImportError, ModuleNotFoundError):
@@ -95,7 +94,6 @@ class ProgressBarRunner(object):
  self.tqdm_bars[0].close()

  def __call__(self, iter_num, *args, **kwargs):
-
  _ = jax.lax.cond(
  iter_num == 0,
  lambda: jax.debug.callback(self._define_tqdm),
brainstate/util.py CHANGED
@@ -26,7 +26,6 @@ from jax.lib import xla_bridge

  from ._utils import set_module_as

-
  __all__ = [
  'unique_name',
  'clear_buffer_memory',
{brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: brainstate
- Version: 0.0.1.post20240612
+ Version: 0.0.1.post20240622
  Summary: A State-based Transformation System for Brain Dynamics Programming.
  Home-page: https://github.com/brainpy/brainstate
  Author: BrainPy Team
@@ -31,21 +31,11 @@ License-File: LICENSE
  Requires-Dist: jax
  Requires-Dist: jaxlib
  Requires-Dist: numpy
+ Requires-Dist: brainunit
  Provides-Extra: cpu
  Requires-Dist: jaxlib ; extra == 'cpu'
- Requires-Dist: brainpylib ; extra == 'cpu'
- Provides-Extra: cpu_mini
- Requires-Dist: jaxlib ; extra == 'cpu_mini'
- Provides-Extra: cuda11
- Requires-Dist: jaxlib[cuda11_pip] ; extra == 'cuda11'
- Requires-Dist: brainpylib ; extra == 'cuda11'
- Provides-Extra: cuda11_mini
- Requires-Dist: jaxlib[cuda11_pip] ; extra == 'cuda11_mini'
  Provides-Extra: cuda12
  Requires-Dist: jaxlib[cuda12_pip] ; extra == 'cuda12'
- Requires-Dist: brainpylib ; extra == 'cuda12'
- Provides-Extra: cuda12_mini
- Requires-Dist: jaxlib[cuda12_pip] ; extra == 'cuda12_mini'
  Provides-Extra: testing
  Requires-Dist: pytest ; extra == 'testing'
  Provides-Extra: tpu
{brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240622.dist-info}/RECORD CHANGED
@@ -1,45 +1,38 @@
- brainstate/__init__.py,sha256=3R0I9oLpIS6Z9iwmx5ODzSlZGX3MbYWvCWsFHaCiaG4,1436
- brainstate/_module.py,sha256=R3pBeNvqR_mEfquZU60uWj7JmopOCUciF2BgcJyA0aw,48151
- brainstate/_module_test.py,sha256=4tqtp2-j5mSoUmCITY0mVZEcXzxXCWJ_02Jdt1fxYJg,4502
+ brainstate/__init__.py,sha256=DwgnJOghZ_qeFh0a_roiaMCDH-V_F6Ve7by3xjSVrwk,1408
+ brainstate/_module.py,sha256=lIfRRev49QQfq55PBm9YohtvPsMWdLHNhrfBjMjD6c8,52623
+ brainstate/_module_test.py,sha256=TJlxR4R5bf621y68hTgzTaf0PBN9YmVhwoGKNcpXbpE,7821
  brainstate/_state.py,sha256=RWnLjMeaidxWXNAA0X-8mxj4i61j3T8w5KhugACUYhI,11422
  brainstate/_state_test.py,sha256=HDdipndRLhEHWEdTmyT1ayEBkbv6qJKykfCWKI6yJ_E,1253
  brainstate/_utils.py,sha256=RLorgGJkt2BhbX4C-ygd-PPG0wfcGCghjSP93sRvzqM,833
- brainstate/environ.py,sha256=RMDUACuixwk2ZTHf0UGLhcd5DCraW-l1j9T3wc2wcFc,10242
- brainstate/mixin.py,sha256=V75vjMTzYcCMlPo5wekgRZZ9o6-xN8kJocQgEliu5yI,10738
+ brainstate/environ.py,sha256=LwRwnFaTbv8l7nHRIbSV46WzcN7pGLQFhT_xDUox2yA,10240
+ brainstate/mixin.py,sha256=x4WIYMTCFZgtTp-uiZeNI5J4Qd2BYaV0Ccm_EMdzl9c,10748
  brainstate/mixin_test.py,sha256=qDYqhHbHw3aBFW8aHQdPhov29013Eo9TJDF7RW2dapE,2919
  brainstate/random.py,sha256=Mi5i0kAsR8C-VoI8LMuIbPPr6YFzq6NBxhJ5K0w2qW4,186392
  brainstate/random_test.py,sha256=cCeuYvlZkCS2_RgG0vipZFNSHG8b-uJ7SXM9SZDCYQM,17866
  brainstate/surrogate.py,sha256=1kgbn82GSlpReIytIVl29yozk75gkdZv0gTBlixQ4C4,43798
  brainstate/typing.py,sha256=Ooweu7c17nYP686fyIeKNomChodSxx_OEpu8QRoB9cY,2180
- brainstate/util.py,sha256=FrBN_OZAPlWxfNK8c9Z1d-bbIa8qwMrcOsSJZJS8xOE,19878
- brainstate/functional/__init__.py,sha256=Z-43coOHFAsQK0u5amlr4l0fNNPc7dVcuKXfNY4Gj_s,1107
- brainstate/functional/_activations.py,sha256=IfZ6Zy8SAwyxo166E3NmCZMUHnG_rBFAUaLTyxG5FgA,18490
+ brainstate/util.py,sha256=y-6eX1z3EMyg6pfZt4YdDalOnJ3HDAT1IPBCJDp-gQI,19876
+ brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
+ brainstate/functional/_activations.py,sha256=xlwvYG8qvpkfMEZTFxD_4amW63ZfEa8x3vzVH2hDgeY,17791
  brainstate/functional/_normalization.py,sha256=IxE580waloZylZVXcpUUK4bWQdlE6oSPfafaKYfDkbg,2169
  brainstate/functional/_others.py,sha256=ifB-l82y7ZB632yLUJOEcpkRY-yOoiJ0mtDOxNilp4M,1711
- brainstate/functional/_spikes.py,sha256=uAln_Q87pr1codLxeDck3PUA9jpk7S5LifNps1kdyrU,2576
+ brainstate/functional/_spikes.py,sha256=70qGvo4B--QtxfJMjLwGmk9pVsf2x2YNEEgjT-il_Jw,2574
  brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
  brainstate/init/_base.py,sha256=jRTmfoUsH_315vW9YMZzyIn2KDAAsv56SplBnvOyBW0,1148
  brainstate/init/_generic.py,sha256=OJFS7DHYmZV0JogdsgjnUseUfvTUrAUYiXZynCQqmG4,5163
  brainstate/init/_random_inits.py,sha256=STbX-mrHwNuICXkw7EldtJLdUUsWOAcGkEzx2ycV-Yc,15321
  brainstate/init/_regular_inits.py,sha256=n-vF-51FM1UcUh-8h5lUk5Jhjrn04KPcGXgGhUGFAAk,3065
- brainstate/math/__init__.py,sha256=meQnO6k1EzMRMhO3x_22oj4-LVo_KevHK4L04bmHZPo,873
- brainstate/math/_einops.py,sha256=Lwi8AGKNPb-x1To0dDQYHbKwUOrO6pPL23qdg28-nB0,31726
- brainstate/math/_einops_parsing.py,sha256=zjTJdJlEBRS0y02PgKoZ8Y6bv54B4Axzk4AtPQOo934,6805
- brainstate/math/_einops_parsing_test.py,sha256=JPn73yld300481J6E9cL7jHWn63Vr21VV8k1jJxAK4A,4888
- brainstate/math/_einops_test.py,sha256=xj-DDTL0EsW1Obm64KCnT7eqELWjjj04Ozdwk0839Tw,13289
- brainstate/math/_misc.py,sha256=jDtREP4ojxHyj6lXcLcYLGVsLA0HFZcrs8cdlnA7aK8,7863
- brainstate/math/_misc_test.py,sha256=V41YV-RiEbukKQlzq54174cpSalOhMjaHOoVH8o82eI,2443
  brainstate/nn/__init__.py,sha256=YJHoI8cXKVRS8f2vUl3Zegp5wm0svMz3qo9JmQJiMQk,2162
  brainstate/nn/_base.py,sha256=lzbZpku3Q2arH6ZaAwjs6bhbV0RcFChxo2UcpnX5t84,8481
  brainstate/nn/_connections.py,sha256=GSOW2IbpJRHdPyF4nFJ2RPgO8y6SVHT1Gn-pbri9pMk,22970
  brainstate/nn/_dynamics.py,sha256=OeYYXv1dqjUDcCsRhZo1XS7SP2li1vlH9uhME_PE9v0,13205
- brainstate/nn/_elementwise.py,sha256=T1oCu47t11Ki7LPaL-hHk4W8bKP_Q3HLJcGngcmGK0Y,43552
+ brainstate/nn/_elementwise.py,sha256=6BTqSvSnaHhldwB5ol5OV0hPJ5yJ-Jpm4WSrtFKMNoQ,43579
  brainstate/nn/_embedding.py,sha256=WbgrIaM_14abN8zBDr0xipBOsFc8dXP2m7Z_aRLAfmU,2249
- brainstate/nn/_misc.py,sha256=Z7gdJraJ18gVNHyNOk_KmE67M3OM4z3QT4RN6al5JMc,3766
+ brainstate/nn/_misc.py,sha256=Xc4U4NLmvfnKdBNDayFrRBPAy3p0beS6T9C59rIDP00,3790
  brainstate/nn/_normalizations.py,sha256=9yVDORAEpqEkL9MYSPU4m7C4q8Qj5UNsPh9sKmIt5gQ,14329
- brainstate/nn/_others.py,sha256=8PYmCiUNzru4kmm58HY0RzCs-32dnwNFDZdTTPixaqo,4492
- brainstate/nn/_poolings.py,sha256=cNZ1PyMIaViP-_AUkEbpy3ZfHo--ib1hAhL0bEAmXIQ,45688
- brainstate/nn/_poolings_test.py,sha256=iE0NgvOIWVgwmcvP4wazhGG4RJQdU2eeagdJ1sDXIBQ,7260
+ brainstate/nn/_others.py,sha256=AYyrbbdKZj16kT0cVITnoZHck4xcccM1W3LX5XM5Z3Q,4513
+ brainstate/nn/_poolings.py,sha256=wO1Q4s8blsLLv4CMlkrvZm0ravdL3dFGyOcg2QDendI,45754
+ brainstate/nn/_poolings_test.py,sha256=Mj4gO86Xl4JS5hHNR_CgeUdZQIqAxUoeBldS-eoZoBg,7264
  brainstate/nn/_rate_rnns.py,sha256=Cebhy57UWzfwrCfq0v2qLDegmb__mXL5ht750y4aTro,14457
  brainstate/nn/_readout.py,sha256=jsQwhVnrJICKw4wFq-Du2AORPb_XXz_tZ4cURcckU-E,4240
  brainstate/nn/_synouts.py,sha256=gi3EyKlzt4UoyghwvNIr03r7YabZyl1idbq9aYG8zYM,4379
@@ -49,23 +42,23 @@ brainstate/nn/_projection/_align_pre.py,sha256=R2U6_RQ_o8y6PWXpozeWE2cx_oQ7WMhhr
  brainstate/nn/_projection/_delta.py,sha256=KT8ySo3n_Q_7swzOH-ISDf0x9rjMkiv99H-vqeQZDR8,7122
  brainstate/nn/_projection/_utils.py,sha256=UcmELOqsINgqJr7eC5BSNNteyZ--1lyGjhUTJfxyMmA,813
  brainstate/nn/_projection/_vanilla.py,sha256=_bh_DLtF0o33SBtj6IGL8CTanFEtJwfjBrgxBEAmIlg,3397
- brainstate/optim/__init__.py,sha256=1xH5_peSWKuZ4tOU295r9EKAv0a-cBMABx6XV3faDJI,919
+ brainstate/optim/__init__.py,sha256=1L6x_qZprw3PJYddB1nX-uTFGUl6_Qt3PM0OdY6g968,917
  brainstate/optim/_lr_scheduler.py,sha256=emKnA52UVqOfUcX7LJqwP-FVDVlGGzTQi2djYmbCWUo,15627
  brainstate/optim/_lr_scheduler_test.py,sha256=OwF8Iz-PorEbO0gO--A7IIgQEytqEfYWbPucAgzqL90,1598
- brainstate/optim/_sgd_optimizer.py,sha256=7-jMfP_Hol0XGEA6_4wVqygpLTqI1646F6eeLtwtNFY,45760
- brainstate/transform/__init__.py,sha256=9S9TLp1sF6nWRmW6jFtu6_dLmOc43V88Ruh073Z8I50,1460
- brainstate/transform/_autograd.py,sha256=sFGJ6oAhlSr54Hb1c1aNc5Q2St7eIr_X77lupc31YXg,23964
- brainstate/transform/_autograd_test.py,sha256=epQ2z97fAp_dQ_CwWGZD7sgw-p9o9fGfSeOUAJiiDY0,38658
+ brainstate/optim/_sgd_optimizer.py,sha256=JiK_AVGregL0wn8uHhRQvK9Qq7Qja7dEyLW6Aan7b70,45826
+ brainstate/transform/__init__.py,sha256=my2X4ZW0uKZRfN82zyGEPizWNJ0fsSP2akvmkjn43ck,1458
+ brainstate/transform/_autograd.py,sha256=Pj_YxpU52guaxQs1NcB6qDtXgkvaPcoJbuvIF8T-Wmk,23964
+ brainstate/transform/_autograd_test.py,sha256=RWriMemIF9FVFUjQh4IHzLhT9LGyd1JXpjXfFZKHn10,38654
  brainstate/transform/_control.py,sha256=NWceTIuLlj2uGTdNcqBAXgnaLuChOGgAtIXtFn5vdLU,26837
  brainstate/transform/_controls_test.py,sha256=mPUa_qmXXVxDziAJrPWRBwsGnc3cHR9co08eJB_fJwA,7648
  brainstate/transform/_jit.py,sha256=sjQHFV8Tt75fpdl12jjPRDPT92_IZxBBJAG4gapdbNQ,11471
  brainstate/transform/_jit_error.py,sha256=lO_e5AdhkjozHjM10q0b57OaXbeZ9gQkVmZMN6VQVCw,4450
- brainstate/transform/_jit_test.py,sha256=lVXvScfXExhXwFi8jnvEY6stNVulZHCzriamajFqzrY,2891
- brainstate/transform/_make_jaxpr.py,sha256=MTeBpPO1thu5yDytWoJijySHV7-nWmUoBMC0RCbdzcY,29972
- brainstate/transform/_make_jaxpr_test.py,sha256=4nEwZv_ebgUZgV86vOJFO_qC69mw2F3rogViF2SC1Qs,3823
- brainstate/transform/_progress_bar.py,sha256=myrAkBcUfuVGFLVwFzeSe5vdg1z49ARKqTlccG92maA,3536
- brainstate-0.0.1.post20240612.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
- brainstate-0.0.1.post20240612.dist-info/METADATA,sha256=VRHXnO0TBRcoo_M4iFHsywjCJbhonpStzYSmRXkR_wM,4254
- brainstate-0.0.1.post20240612.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
- brainstate-0.0.1.post20240612.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
- brainstate-0.0.1.post20240612.dist-info/RECORD,,
+ brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJRLb4,2868
+ brainstate/transform/_make_jaxpr.py,sha256=q3OPy-1Gg0mVaB9pgSTWzzP8FSCAgquSjP-pDEw3Tpg,29970
+ brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
+ brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
+ brainstate-0.0.1.post20240622.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
+ brainstate-0.0.1.post20240622.dist-info/METADATA,sha256=N-I84Xg5_9MT9p4LjmQjpk4GcPV5iDR3J8LDa9ppKnM,3814
+ brainstate-0.0.1.post20240622.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
+ brainstate-0.0.1.post20240622.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
+ brainstate-0.0.1.post20240622.dist-info/RECORD,,
brainstate/math/__init__.py DELETED
@@ -1,21 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- from ._einops import *
- from ._einops import __all__ as _einops_all
- from ._misc import *
- from ._misc import __all__ as _misc_all
-
- __all__ = _misc_all + _einops_all