brainstate 0.0.1.post20240612__py2.py3-none-any.whl → 0.0.1.post20240623__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +4 -5
- brainstate/_module.py +147 -42
- brainstate/_module_test.py +95 -21
- brainstate/environ.py +0 -1
- brainstate/functional/__init__.py +2 -2
- brainstate/functional/_activations.py +7 -26
- brainstate/functional/_spikes.py +0 -1
- brainstate/mixin.py +2 -2
- brainstate/nn/_elementwise.py +5 -4
- brainstate/nn/_misc.py +4 -3
- brainstate/nn/_others.py +3 -2
- brainstate/nn/_poolings.py +21 -20
- brainstate/nn/_poolings_test.py +4 -4
- brainstate/optim/__init__.py +0 -1
- brainstate/optim/_sgd_optimizer.py +18 -17
- brainstate/transform/__init__.py +2 -3
- brainstate/transform/_autograd.py +1 -1
- brainstate/transform/_autograd_test.py +0 -2
- brainstate/transform/_jit_test.py +0 -3
- brainstate/transform/_make_jaxpr.py +0 -1
- brainstate/transform/_make_jaxpr_test.py +0 -2
- brainstate/transform/_progress_bar.py +1 -3
- brainstate/util.py +0 -1
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/METADATA +2 -12
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/RECORD +28 -35
- brainstate/math/__init__.py +0 -21
- brainstate/math/_einops.py +0 -787
- brainstate/math/_einops_parsing.py +0 -169
- brainstate/math/_einops_parsing_test.py +0 -126
- brainstate/math/_einops_test.py +0 -346
- brainstate/math/_misc.py +0 -298
- brainstate/math/_misc_test.py +0 -58
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/LICENSE +0 -0
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/WHEEL +0 -0
- {brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/top_level.txt +0 -0
brainstate/mixin.py
CHANGED
@@ -68,7 +68,7 @@ class DelayedInit(Mixin):
|
|
68
68
|
Note this Mixin can be applied in any Python object.
|
69
69
|
"""
|
70
70
|
|
71
|
-
|
71
|
+
non_hashable_params: Optional[Sequence[str]] = None
|
72
72
|
|
73
73
|
@classmethod
|
74
74
|
def delayed(cls, *args, **kwargs) -> 'DelayedInitializer':
|
@@ -94,7 +94,7 @@ class DelayedInitializer(metaclass=NoSubclassMeta):
|
|
94
94
|
"""
|
95
95
|
|
96
96
|
def __init__(self, cls: T, *desc_tuple, **desc_dict):
|
97
|
-
self.cls = cls
|
97
|
+
self.cls: type = cls
|
98
98
|
|
99
99
|
# arguments
|
100
100
|
self.args = desc_tuple
|
brainstate/nn/_elementwise.py
CHANGED
@@ -19,11 +19,12 @@ from __future__ import annotations
|
|
19
19
|
|
20
20
|
from typing import Optional
|
21
21
|
|
22
|
+
import brainunit as bu
|
22
23
|
import jax.numpy as jnp
|
23
24
|
import jax.typing
|
24
25
|
|
25
26
|
from ._base import ElementWiseBlock
|
26
|
-
from .. import
|
27
|
+
from .. import environ, random, functional as F
|
27
28
|
from .._module import Module
|
28
29
|
from .._state import ParamState
|
29
30
|
from ..mixin import Mode
|
@@ -82,7 +83,7 @@ class Threshold(Module, ElementWiseBlock):
|
|
82
83
|
self.value = value
|
83
84
|
|
84
85
|
def __call__(self, x: ArrayLike) -> ArrayLike:
|
85
|
-
dtype = math.get_dtype(x)
|
86
|
+
dtype = bu.math.get_dtype(x)
|
86
87
|
return jnp.where(x > jnp.asarray(self.threshold, dtype=dtype),
|
87
88
|
x,
|
88
89
|
jnp.asarray(self.value, dtype=dtype))
|
@@ -1142,7 +1143,7 @@ class Dropout(Module, ElementWiseBlock):
|
|
1142
1143
|
self.prob = prob
|
1143
1144
|
|
1144
1145
|
def __call__(self, x):
|
1145
|
-
dtype = math.get_dtype(x)
|
1146
|
+
dtype = bu.math.get_dtype(x)
|
1146
1147
|
fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
|
1147
1148
|
if fit_phase:
|
1148
1149
|
keep_mask = random.bernoulli(self.prob, x.shape)
|
@@ -1172,7 +1173,7 @@ class _DropoutNd(Module, ElementWiseBlock):
|
|
1172
1173
|
self.channel_axis = channel_axis
|
1173
1174
|
|
1174
1175
|
def __call__(self, x):
|
1175
|
-
dtype = math.get_dtype(x)
|
1176
|
+
dtype = bu.math.get_dtype(x)
|
1176
1177
|
# get fit phase
|
1177
1178
|
fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
|
1178
1179
|
|
brainstate/nn/_misc.py
CHANGED
@@ -20,9 +20,10 @@ from enum import Enum
|
|
20
20
|
from functools import wraps
|
21
21
|
from typing import Sequence, Callable
|
22
22
|
|
23
|
+
import brainunit as bu
|
23
24
|
import jax.numpy as jnp
|
24
25
|
|
25
|
-
from .. import environ
|
26
|
+
from .. import environ
|
26
27
|
from .._state import State
|
27
28
|
from ..transform import vector_grad
|
28
29
|
|
@@ -96,7 +97,7 @@ def exp_euler(fun):
|
|
96
97
|
)
|
97
98
|
dt = environ.get('dt')
|
98
99
|
linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
|
99
|
-
phi = math.exprel(dt * linear)
|
100
|
+
phi = bu.math.exprel(dt * linear)
|
100
101
|
return args[0] + dt * phi * derivative
|
101
102
|
|
102
103
|
return integral
|
@@ -128,5 +129,5 @@ def exp_euler_step(fun: Callable, *args, **kwargs):
|
|
128
129
|
)
|
129
130
|
dt = environ.get('dt')
|
130
131
|
linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
|
131
|
-
phi = math.exprel(dt * linear)
|
132
|
+
phi = bu.math.exprel(dt * linear)
|
132
133
|
return args[0] + dt * phi * derivative
|
brainstate/nn/_others.py
CHANGED
@@ -19,10 +19,11 @@ from __future__ import annotations
|
|
19
19
|
from functools import partial
|
20
20
|
from typing import Optional
|
21
21
|
|
22
|
+
import brainunit as bu
|
22
23
|
import jax.numpy as jnp
|
23
24
|
|
24
25
|
from ._base import DnnLayer
|
25
|
-
from .. import random,
|
26
|
+
from .. import random, environ, typing, init
|
26
27
|
from ..mixin import Mode
|
27
28
|
|
28
29
|
__all__ = [
|
@@ -88,7 +89,7 @@ class DropoutFixed(DnnLayer):
|
|
88
89
|
self.mask = init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size)
|
89
90
|
|
90
91
|
def update(self, x):
|
91
|
-
dtype = math.get_dtype(x)
|
92
|
+
dtype = bu.math.get_dtype(x)
|
92
93
|
fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
|
93
94
|
if fit_phase:
|
94
95
|
assert self.mask.shape == x.shape, (f"Input shape {x.shape} does not match the mask shape {self.mask.shape}. "
|
brainstate/nn/_poolings.py
CHANGED
@@ -21,12 +21,13 @@ import functools
|
|
21
21
|
from typing import Sequence, Optional
|
22
22
|
from typing import Union, Tuple, Callable, List
|
23
23
|
|
24
|
+
import brainunit as bu
|
24
25
|
import jax
|
25
26
|
import jax.numpy as jnp
|
26
27
|
import numpy as np
|
27
28
|
|
28
29
|
from ._base import DnnLayer, ExplicitInOutSize
|
29
|
-
from .. import environ
|
30
|
+
from .. import environ
|
30
31
|
from ..mixin import Mode
|
31
32
|
from ..typing import Size
|
32
33
|
|
@@ -53,8 +54,8 @@ class Flatten(DnnLayer, ExplicitInOutSize):
|
|
53
54
|
|
54
55
|
Args:
|
55
56
|
in_size: Sequence of int. The shape of the input tensor.
|
56
|
-
|
57
|
-
|
57
|
+
start_axis: first dim to flatten (default = 1).
|
58
|
+
end_axis: last dim to flatten (default = -1).
|
58
59
|
|
59
60
|
Examples::
|
60
61
|
>>> import brainstate as bst
|
@@ -74,36 +75,36 @@ class Flatten(DnnLayer, ExplicitInOutSize):
|
|
74
75
|
|
75
76
|
def __init__(
|
76
77
|
self,
|
77
|
-
|
78
|
-
|
78
|
+
start_axis: int = 0,
|
79
|
+
end_axis: int = -1,
|
79
80
|
in_size: Optional[Size] = None
|
80
81
|
) -> None:
|
81
82
|
super().__init__()
|
82
|
-
self.
|
83
|
-
self.
|
83
|
+
self.start_axis = start_axis
|
84
|
+
self.end_axis = end_axis
|
84
85
|
|
85
86
|
if in_size is not None:
|
86
87
|
self.in_size = tuple(in_size)
|
87
|
-
y = jax.eval_shape(functools.partial(math.flatten,
|
88
|
+
y = jax.eval_shape(functools.partial(bu.math.flatten, start_axis=start_axis, end_axis=end_axis),
|
88
89
|
jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
|
89
90
|
self.out_size = y.shape
|
90
91
|
|
91
92
|
def update(self, x):
|
92
93
|
if self._in_size is None:
|
93
|
-
|
94
|
+
start_axis = self.start_axis if self.start_axis >= 0 else x.ndim + self.start_axis
|
94
95
|
else:
|
95
96
|
assert x.ndim >= len(self.in_size), 'Input tensor has fewer dimensions than the expected shape.'
|
96
97
|
dim_diff = x.ndim - len(self.in_size)
|
97
98
|
if self.in_size != x.shape[dim_diff:]:
|
98
99
|
raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
|
99
|
-
if self.
|
100
|
-
|
100
|
+
if self.start_axis >= 0:
|
101
|
+
start_axis = self.start_axis + dim_diff
|
101
102
|
else:
|
102
|
-
|
103
|
-
return math.flatten(x,
|
103
|
+
start_axis = x.ndim + self.start_axis
|
104
|
+
return bu.math.flatten(x, start_axis, self.end_axis)
|
104
105
|
|
105
106
|
def __repr__(self) -> str:
|
106
|
-
return f'{self.__class__.__name__}(
|
107
|
+
return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'
|
107
108
|
|
108
109
|
|
109
110
|
class Unflatten(DnnLayer, ExplicitInOutSize):
|
@@ -124,7 +125,7 @@ class Unflatten(DnnLayer, ExplicitInOutSize):
|
|
124
125
|
:math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.
|
125
126
|
|
126
127
|
Args:
|
127
|
-
|
128
|
+
axis: int, Dimension to be unflattened.
|
128
129
|
sizes: Sequence of int. New shape of the unflattened dimension.
|
129
130
|
in_size: Sequence of int. The shape of the input tensor.
|
130
131
|
"""
|
@@ -132,7 +133,7 @@ class Unflatten(DnnLayer, ExplicitInOutSize):
|
|
132
133
|
|
133
134
|
def __init__(
|
134
135
|
self,
|
135
|
-
|
136
|
+
axis: int,
|
136
137
|
sizes: Size,
|
137
138
|
mode: Mode = None,
|
138
139
|
name: str = None,
|
@@ -140,7 +141,7 @@ class Unflatten(DnnLayer, ExplicitInOutSize):
|
|
140
141
|
) -> None:
|
141
142
|
super().__init__(mode=mode, name=name)
|
142
143
|
|
143
|
-
self.
|
144
|
+
self.axis = axis
|
144
145
|
self.sizes = sizes
|
145
146
|
if isinstance(sizes, (tuple, list)):
|
146
147
|
for idx, elem in enumerate(sizes):
|
@@ -152,15 +153,15 @@ class Unflatten(DnnLayer, ExplicitInOutSize):
|
|
152
153
|
|
153
154
|
if in_size is not None:
|
154
155
|
self.in_size = tuple(in_size)
|
155
|
-
y = jax.eval_shape(functools.partial(math.unflatten,
|
156
|
+
y = jax.eval_shape(functools.partial(bu.math.unflatten, axis=axis, sizes=sizes),
|
156
157
|
jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
|
157
158
|
self.out_size = y.shape
|
158
159
|
|
159
160
|
def update(self, x):
|
160
|
-
return math.unflatten(x, self.
|
161
|
+
return bu.math.unflatten(x, self.axis, self.sizes)
|
161
162
|
|
162
163
|
def __repr__(self):
|
163
|
-
return f'{self.__class__.__name__}(
|
164
|
+
return f'{self.__class__.__name__}(axis={self.axis}, sizes={self.sizes})'
|
164
165
|
|
165
166
|
|
166
167
|
class _MaxPool(DnnLayer, ExplicitInOutSize):
|
brainstate/nn/_poolings_test.py
CHANGED
@@ -18,7 +18,7 @@ class TestFlatten(parameterized.TestCase):
|
|
18
18
|
(10, 20, 30),
|
19
19
|
]:
|
20
20
|
arr = bst.random.rand(*size)
|
21
|
-
f = nn.Flatten(
|
21
|
+
f = nn.Flatten(start_axis=0)
|
22
22
|
out = f(arr)
|
23
23
|
self.assertTrue(out.shape == (np.prod(size),))
|
24
24
|
|
@@ -29,21 +29,21 @@ class TestFlatten(parameterized.TestCase):
|
|
29
29
|
(10, 20, 30),
|
30
30
|
]:
|
31
31
|
arr = bst.random.rand(*size)
|
32
|
-
f = nn.Flatten(
|
32
|
+
f = nn.Flatten(start_axis=1)
|
33
33
|
out = f(arr)
|
34
34
|
self.assertTrue(out.shape == (size[0], np.prod(size[1:])))
|
35
35
|
|
36
36
|
def test_flatten3(self):
|
37
37
|
size = (16, 32, 32, 8)
|
38
38
|
arr = bst.random.rand(*size)
|
39
|
-
f = nn.Flatten(
|
39
|
+
f = nn.Flatten(start_axis=0, in_size=(32, 8))
|
40
40
|
out = f(arr)
|
41
41
|
self.assertTrue(out.shape == (16, 32, 32 * 8))
|
42
42
|
|
43
43
|
def test_flatten4(self):
|
44
44
|
size = (16, 32, 32, 8)
|
45
45
|
arr = bst.random.rand(*size)
|
46
|
-
f = nn.Flatten(
|
46
|
+
f = nn.Flatten(start_axis=1, in_size=(32, 32, 8))
|
47
47
|
out = f(arr)
|
48
48
|
self.assertTrue(out.shape == (16, 32, 32 * 8))
|
49
49
|
|
brainstate/optim/__init__.py
CHANGED
@@ -18,11 +18,12 @@
|
|
18
18
|
import functools
|
19
19
|
from typing import Union, Dict, Optional, Tuple, Any, TypeVar
|
20
20
|
|
21
|
+
import brainunit as bu
|
21
22
|
import jax
|
22
23
|
import jax.numpy as jnp
|
23
24
|
|
24
25
|
from ._lr_scheduler import make_schedule, LearningRateScheduler
|
25
|
-
from .. import environ
|
26
|
+
from .. import environ
|
26
27
|
from .._module import Module
|
27
28
|
from .._state import State, LongTermState, StateDictManager, visible_state_dict
|
28
29
|
|
@@ -282,7 +283,7 @@ class Momentum(_WeightDecayOptimizer):
|
|
282
283
|
for k, v in train_states.items():
|
283
284
|
assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
|
284
285
|
self.weight_states.add_unique_elem(k, v)
|
285
|
-
self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
|
286
|
+
self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
|
286
287
|
|
287
288
|
def update(self, grads: dict):
|
288
289
|
lr = self.lr()
|
@@ -349,7 +350,7 @@ class MomentumNesterov(_WeightDecayOptimizer):
|
|
349
350
|
for k, v in train_states.items():
|
350
351
|
assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
|
351
352
|
self.weight_states.add_unique_elem(k, v)
|
352
|
-
self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
|
353
|
+
self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
|
353
354
|
|
354
355
|
def update(self, grads: dict):
|
355
356
|
lr = self.lr()
|
@@ -417,7 +418,7 @@ class Adagrad(_WeightDecayOptimizer):
|
|
417
418
|
for k, v in train_states.items():
|
418
419
|
assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
|
419
420
|
self.weight_states.add_unique_elem(k, v)
|
420
|
-
self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
|
421
|
+
self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
|
421
422
|
|
422
423
|
def update(self, grads: dict):
|
423
424
|
lr = self.lr()
|
@@ -500,8 +501,8 @@ class Adadelta(_WeightDecayOptimizer):
|
|
500
501
|
for k, v in train_states.items():
|
501
502
|
assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
|
502
503
|
self.weight_states.add_unique_elem(k, v)
|
503
|
-
self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
|
504
|
-
self.delta_states[k] = OptimState(math.tree_zeros_like(v.value))
|
504
|
+
self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
|
505
|
+
self.delta_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
|
505
506
|
|
506
507
|
def update(self, grads: dict):
|
507
508
|
weight_values, grad_values, cache_values, delta_values = to_same_dict_tree(
|
@@ -574,7 +575,7 @@ class RMSProp(_WeightDecayOptimizer):
|
|
574
575
|
for k, v in train_states.items():
|
575
576
|
assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
|
576
577
|
self.weight_states.add_unique_elem(k, v)
|
577
|
-
self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
|
578
|
+
self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
|
578
579
|
|
579
580
|
def update(self, grads: dict):
|
580
581
|
lr = self.lr()
|
@@ -647,8 +648,8 @@ class Adam(_WeightDecayOptimizer):
|
|
647
648
|
for k, v in train_states.items():
|
648
649
|
assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
|
649
650
|
self.weight_states.add_unique_elem(k, v)
|
650
|
-
self.m1_states[k] = OptimState(math.tree_zeros_like(v.value))
|
651
|
-
self.m2_states[k] = OptimState(math.tree_zeros_like(v.value))
|
651
|
+
self.m1_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
|
652
|
+
self.m2_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
|
652
653
|
|
653
654
|
def update(self, grads: dict):
|
654
655
|
lr = self.lr()
|
@@ -730,7 +731,7 @@ class LARS(_WeightDecayOptimizer):
|
|
730
731
|
for k, v in train_states.items():
|
731
732
|
assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
|
732
733
|
self.weight_states.add_unique_elem(k, v)
|
733
|
-
self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
|
734
|
+
self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
|
734
735
|
|
735
736
|
def update(self, grads: dict):
|
736
737
|
lr = self.lr()
|
@@ -835,10 +836,10 @@ class Adan(_WeightDecayOptimizer):
|
|
835
836
|
for k, v in train_states.items():
|
836
837
|
assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
|
837
838
|
self.weight_states.add_unique_elem(k, v)
|
838
|
-
self.exp_avg_states[k] = OptimState(math.tree_zeros_like(v.value))
|
839
|
-
self.exp_avg_sq_states[k] = OptimState(math.tree_zeros_like(v.value))
|
840
|
-
self.exp_avg_diff_states[k] = OptimState(math.tree_zeros_like(v.value))
|
841
|
-
self.pre_grad_states[k] = OptimState(math.tree_zeros_like(v.value))
|
839
|
+
self.exp_avg_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
|
840
|
+
self.exp_avg_sq_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
|
841
|
+
self.exp_avg_diff_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
|
842
|
+
self.pre_grad_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
|
842
843
|
|
843
844
|
def update(self, grads: dict):
|
844
845
|
lr = self.lr()
|
@@ -989,10 +990,10 @@ class AdamW(_WeightDecayOptimizer):
|
|
989
990
|
for k, v in train_states.items():
|
990
991
|
assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
|
991
992
|
self.weight_states.add_unique_elem(k, v)
|
992
|
-
self.m1_states[k] = OptimState(math.tree_zeros_like(v.value))
|
993
|
-
self.m2_states[k] = OptimState(math.tree_zeros_like(v.value))
|
993
|
+
self.m1_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
|
994
|
+
self.m2_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
|
994
995
|
if self.amsgrad:
|
995
|
-
self.vmax_states[k] = OptimState(math.tree_zeros_like(v.value))
|
996
|
+
self.vmax_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
|
996
997
|
|
997
998
|
def update(self, grads: dict):
|
998
999
|
lr_old = self.lr()
|
brainstate/transform/__init__.py
CHANGED
@@ -17,10 +17,10 @@
|
|
17
17
|
This module contains the functions for the transformation of the brain data.
|
18
18
|
"""
|
19
19
|
|
20
|
-
from ._control import *
|
21
|
-
from ._control import __all__ as _controls_all
|
22
20
|
from ._autograd import *
|
23
21
|
from ._autograd import __all__ as _gradients_all
|
22
|
+
from ._control import *
|
23
|
+
from ._control import __all__ as _controls_all
|
24
24
|
from ._jit import *
|
25
25
|
from ._jit import __all__ as _jit_all
|
26
26
|
from ._jit_error import *
|
@@ -33,4 +33,3 @@ from ._progress_bar import __all__ as _progress_bar_all
|
|
33
33
|
__all__ = _gradients_all + _jit_error_all + _controls_all + _make_jaxpr_all + _jit_all + _progress_bar_all
|
34
34
|
|
35
35
|
del _gradients_all, _jit_error_all, _controls_all, _make_jaxpr_all, _jit_all, _progress_bar_all
|
36
|
-
|
@@ -25,8 +25,8 @@ from jax._src.api import _vjp
|
|
25
25
|
from jax.api_util import argnums_partial
|
26
26
|
from jax.extend import linear_util
|
27
27
|
|
28
|
-
from brainstate._utils import set_module_as
|
29
28
|
from brainstate._state import State, StateTrace, StateDictManager
|
29
|
+
from brainstate._utils import set_module_as
|
30
30
|
|
31
31
|
__all__ = [
|
32
32
|
'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
|
@@ -537,7 +537,6 @@ class TestPureFuncJacobian(unittest.TestCase):
|
|
537
537
|
|
538
538
|
def test_jacrev_return_aux1(self):
|
539
539
|
with bc.environ.context(precision=64):
|
540
|
-
|
541
540
|
def f1(x, y):
|
542
541
|
a = 4 * x[1] ** 2 - 2 * x[2]
|
543
542
|
r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], a, x[2] * jnp.sin(x[0])])
|
@@ -564,7 +563,6 @@ class TestPureFuncJacobian(unittest.TestCase):
|
|
564
563
|
assert (vec == _r).all()
|
565
564
|
|
566
565
|
|
567
|
-
|
568
566
|
class TestClassFuncJacobian(unittest.TestCase):
|
569
567
|
def test_jacrev1(self):
|
570
568
|
def f1(x, y):
|
@@ -16,7 +16,6 @@
|
|
16
16
|
import unittest
|
17
17
|
|
18
18
|
import jax.numpy as jnp
|
19
|
-
import jax.stages
|
20
19
|
|
21
20
|
import brainstate as bc
|
22
21
|
|
@@ -90,7 +89,6 @@ class TestJIT(unittest.TestCase):
|
|
90
89
|
self.assertTrue(len(compiling) == 2)
|
91
90
|
|
92
91
|
def test_jit_attribute_origin_fun(self):
|
93
|
-
|
94
92
|
def fun1(x):
|
95
93
|
return x
|
96
94
|
|
@@ -99,4 +97,3 @@ class TestJIT(unittest.TestCase):
|
|
99
97
|
self.assertTrue(isinstance(jitted_fun.stateful_fun, bc.transform.StatefulFunction))
|
100
98
|
self.assertTrue(callable(jitted_fun.jitted_fun))
|
101
99
|
self.assertTrue(callable(jitted_fun.clear_cache))
|
102
|
-
|
@@ -14,13 +14,12 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
|
+
|
17
18
|
import copy
|
18
19
|
from typing import Optional
|
19
20
|
|
20
21
|
import jax
|
21
22
|
|
22
|
-
from brainstate import environ
|
23
|
-
|
24
23
|
try:
|
25
24
|
from tqdm.auto import tqdm
|
26
25
|
except (ImportError, ModuleNotFoundError):
|
@@ -95,7 +94,6 @@ class ProgressBarRunner(object):
|
|
95
94
|
self.tqdm_bars[0].close()
|
96
95
|
|
97
96
|
def __call__(self, iter_num, *args, **kwargs):
|
98
|
-
|
99
97
|
_ = jax.lax.cond(
|
100
98
|
iter_num == 0,
|
101
99
|
lambda: jax.debug.callback(self._define_tqdm),
|
brainstate/util.py
CHANGED
{brainstate-0.0.1.post20240612.dist-info → brainstate-0.0.1.post20240623.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.0.1.
|
3
|
+
Version: 0.0.1.post20240623
|
4
4
|
Summary: A State-based Transformation System for Brain Dynamics Programming.
|
5
5
|
Home-page: https://github.com/brainpy/brainstate
|
6
6
|
Author: BrainPy Team
|
@@ -31,21 +31,11 @@ License-File: LICENSE
|
|
31
31
|
Requires-Dist: jax
|
32
32
|
Requires-Dist: jaxlib
|
33
33
|
Requires-Dist: numpy
|
34
|
+
Requires-Dist: brainunit
|
34
35
|
Provides-Extra: cpu
|
35
36
|
Requires-Dist: jaxlib ; extra == 'cpu'
|
36
|
-
Requires-Dist: brainpylib ; extra == 'cpu'
|
37
|
-
Provides-Extra: cpu_mini
|
38
|
-
Requires-Dist: jaxlib ; extra == 'cpu_mini'
|
39
|
-
Provides-Extra: cuda11
|
40
|
-
Requires-Dist: jaxlib[cuda11_pip] ; extra == 'cuda11'
|
41
|
-
Requires-Dist: brainpylib ; extra == 'cuda11'
|
42
|
-
Provides-Extra: cuda11_mini
|
43
|
-
Requires-Dist: jaxlib[cuda11_pip] ; extra == 'cuda11_mini'
|
44
37
|
Provides-Extra: cuda12
|
45
38
|
Requires-Dist: jaxlib[cuda12_pip] ; extra == 'cuda12'
|
46
|
-
Requires-Dist: brainpylib ; extra == 'cuda12'
|
47
|
-
Provides-Extra: cuda12_mini
|
48
|
-
Requires-Dist: jaxlib[cuda12_pip] ; extra == 'cuda12_mini'
|
49
39
|
Provides-Extra: testing
|
50
40
|
Requires-Dist: pytest ; extra == 'testing'
|
51
41
|
Provides-Extra: tpu
|
@@ -1,45 +1,38 @@
|
|
1
|
-
brainstate/__init__.py,sha256=
|
2
|
-
brainstate/_module.py,sha256=
|
3
|
-
brainstate/_module_test.py,sha256=
|
1
|
+
brainstate/__init__.py,sha256=DwgnJOghZ_qeFh0a_roiaMCDH-V_F6Ve7by3xjSVrwk,1408
|
2
|
+
brainstate/_module.py,sha256=RN02rAqgsVVAHeZpXIpZEZSLbfo5YOstmTLlD-JcnN4,52625
|
3
|
+
brainstate/_module_test.py,sha256=TJlxR4R5bf621y68hTgzTaf0PBN9YmVhwoGKNcpXbpE,7821
|
4
4
|
brainstate/_state.py,sha256=RWnLjMeaidxWXNAA0X-8mxj4i61j3T8w5KhugACUYhI,11422
|
5
5
|
brainstate/_state_test.py,sha256=HDdipndRLhEHWEdTmyT1ayEBkbv6qJKykfCWKI6yJ_E,1253
|
6
6
|
brainstate/_utils.py,sha256=RLorgGJkt2BhbX4C-ygd-PPG0wfcGCghjSP93sRvzqM,833
|
7
|
-
brainstate/environ.py,sha256=
|
8
|
-
brainstate/mixin.py,sha256=
|
7
|
+
brainstate/environ.py,sha256=LwRwnFaTbv8l7nHRIbSV46WzcN7pGLQFhT_xDUox2yA,10240
|
8
|
+
brainstate/mixin.py,sha256=x4WIYMTCFZgtTp-uiZeNI5J4Qd2BYaV0Ccm_EMdzl9c,10748
|
9
9
|
brainstate/mixin_test.py,sha256=qDYqhHbHw3aBFW8aHQdPhov29013Eo9TJDF7RW2dapE,2919
|
10
10
|
brainstate/random.py,sha256=Mi5i0kAsR8C-VoI8LMuIbPPr6YFzq6NBxhJ5K0w2qW4,186392
|
11
11
|
brainstate/random_test.py,sha256=cCeuYvlZkCS2_RgG0vipZFNSHG8b-uJ7SXM9SZDCYQM,17866
|
12
12
|
brainstate/surrogate.py,sha256=1kgbn82GSlpReIytIVl29yozk75gkdZv0gTBlixQ4C4,43798
|
13
13
|
brainstate/typing.py,sha256=Ooweu7c17nYP686fyIeKNomChodSxx_OEpu8QRoB9cY,2180
|
14
|
-
brainstate/util.py,sha256=
|
15
|
-
brainstate/functional/__init__.py,sha256=
|
16
|
-
brainstate/functional/_activations.py,sha256=
|
14
|
+
brainstate/util.py,sha256=y-6eX1z3EMyg6pfZt4YdDalOnJ3HDAT1IPBCJDp-gQI,19876
|
15
|
+
brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
|
16
|
+
brainstate/functional/_activations.py,sha256=xlwvYG8qvpkfMEZTFxD_4amW63ZfEa8x3vzVH2hDgeY,17791
|
17
17
|
brainstate/functional/_normalization.py,sha256=IxE580waloZylZVXcpUUK4bWQdlE6oSPfafaKYfDkbg,2169
|
18
18
|
brainstate/functional/_others.py,sha256=ifB-l82y7ZB632yLUJOEcpkRY-yOoiJ0mtDOxNilp4M,1711
|
19
|
-
brainstate/functional/_spikes.py,sha256=
|
19
|
+
brainstate/functional/_spikes.py,sha256=70qGvo4B--QtxfJMjLwGmk9pVsf2x2YNEEgjT-il_Jw,2574
|
20
20
|
brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
|
21
21
|
brainstate/init/_base.py,sha256=jRTmfoUsH_315vW9YMZzyIn2KDAAsv56SplBnvOyBW0,1148
|
22
22
|
brainstate/init/_generic.py,sha256=OJFS7DHYmZV0JogdsgjnUseUfvTUrAUYiXZynCQqmG4,5163
|
23
23
|
brainstate/init/_random_inits.py,sha256=STbX-mrHwNuICXkw7EldtJLdUUsWOAcGkEzx2ycV-Yc,15321
|
24
24
|
brainstate/init/_regular_inits.py,sha256=n-vF-51FM1UcUh-8h5lUk5Jhjrn04KPcGXgGhUGFAAk,3065
|
25
|
-
brainstate/math/__init__.py,sha256=meQnO6k1EzMRMhO3x_22oj4-LVo_KevHK4L04bmHZPo,873
|
26
|
-
brainstate/math/_einops.py,sha256=Lwi8AGKNPb-x1To0dDQYHbKwUOrO6pPL23qdg28-nB0,31726
|
27
|
-
brainstate/math/_einops_parsing.py,sha256=zjTJdJlEBRS0y02PgKoZ8Y6bv54B4Axzk4AtPQOo934,6805
|
28
|
-
brainstate/math/_einops_parsing_test.py,sha256=JPn73yld300481J6E9cL7jHWn63Vr21VV8k1jJxAK4A,4888
|
29
|
-
brainstate/math/_einops_test.py,sha256=xj-DDTL0EsW1Obm64KCnT7eqELWjjj04Ozdwk0839Tw,13289
|
30
|
-
brainstate/math/_misc.py,sha256=jDtREP4ojxHyj6lXcLcYLGVsLA0HFZcrs8cdlnA7aK8,7863
|
31
|
-
brainstate/math/_misc_test.py,sha256=V41YV-RiEbukKQlzq54174cpSalOhMjaHOoVH8o82eI,2443
|
32
25
|
brainstate/nn/__init__.py,sha256=YJHoI8cXKVRS8f2vUl3Zegp5wm0svMz3qo9JmQJiMQk,2162
|
33
26
|
brainstate/nn/_base.py,sha256=lzbZpku3Q2arH6ZaAwjs6bhbV0RcFChxo2UcpnX5t84,8481
|
34
27
|
brainstate/nn/_connections.py,sha256=GSOW2IbpJRHdPyF4nFJ2RPgO8y6SVHT1Gn-pbri9pMk,22970
|
35
28
|
brainstate/nn/_dynamics.py,sha256=OeYYXv1dqjUDcCsRhZo1XS7SP2li1vlH9uhME_PE9v0,13205
|
36
|
-
brainstate/nn/_elementwise.py,sha256=
|
29
|
+
brainstate/nn/_elementwise.py,sha256=6BTqSvSnaHhldwB5ol5OV0hPJ5yJ-Jpm4WSrtFKMNoQ,43579
|
37
30
|
brainstate/nn/_embedding.py,sha256=WbgrIaM_14abN8zBDr0xipBOsFc8dXP2m7Z_aRLAfmU,2249
|
38
|
-
brainstate/nn/_misc.py,sha256=
|
31
|
+
brainstate/nn/_misc.py,sha256=Xc4U4NLmvfnKdBNDayFrRBPAy3p0beS6T9C59rIDP00,3790
|
39
32
|
brainstate/nn/_normalizations.py,sha256=9yVDORAEpqEkL9MYSPU4m7C4q8Qj5UNsPh9sKmIt5gQ,14329
|
40
|
-
brainstate/nn/_others.py,sha256=
|
41
|
-
brainstate/nn/_poolings.py,sha256=
|
42
|
-
brainstate/nn/_poolings_test.py,sha256=
|
33
|
+
brainstate/nn/_others.py,sha256=AYyrbbdKZj16kT0cVITnoZHck4xcccM1W3LX5XM5Z3Q,4513
|
34
|
+
brainstate/nn/_poolings.py,sha256=wO1Q4s8blsLLv4CMlkrvZm0ravdL3dFGyOcg2QDendI,45754
|
35
|
+
brainstate/nn/_poolings_test.py,sha256=Mj4gO86Xl4JS5hHNR_CgeUdZQIqAxUoeBldS-eoZoBg,7264
|
43
36
|
brainstate/nn/_rate_rnns.py,sha256=Cebhy57UWzfwrCfq0v2qLDegmb__mXL5ht750y4aTro,14457
|
44
37
|
brainstate/nn/_readout.py,sha256=jsQwhVnrJICKw4wFq-Du2AORPb_XXz_tZ4cURcckU-E,4240
|
45
38
|
brainstate/nn/_synouts.py,sha256=gi3EyKlzt4UoyghwvNIr03r7YabZyl1idbq9aYG8zYM,4379
|
@@ -49,23 +42,23 @@ brainstate/nn/_projection/_align_pre.py,sha256=R2U6_RQ_o8y6PWXpozeWE2cx_oQ7WMhhr
|
|
49
42
|
brainstate/nn/_projection/_delta.py,sha256=KT8ySo3n_Q_7swzOH-ISDf0x9rjMkiv99H-vqeQZDR8,7122
|
50
43
|
brainstate/nn/_projection/_utils.py,sha256=UcmELOqsINgqJr7eC5BSNNteyZ--1lyGjhUTJfxyMmA,813
|
51
44
|
brainstate/nn/_projection/_vanilla.py,sha256=_bh_DLtF0o33SBtj6IGL8CTanFEtJwfjBrgxBEAmIlg,3397
|
52
|
-
brainstate/optim/__init__.py,sha256=
|
45
|
+
brainstate/optim/__init__.py,sha256=1L6x_qZprw3PJYddB1nX-uTFGUl6_Qt3PM0OdY6g968,917
|
53
46
|
brainstate/optim/_lr_scheduler.py,sha256=emKnA52UVqOfUcX7LJqwP-FVDVlGGzTQi2djYmbCWUo,15627
|
54
47
|
brainstate/optim/_lr_scheduler_test.py,sha256=OwF8Iz-PorEbO0gO--A7IIgQEytqEfYWbPucAgzqL90,1598
|
55
|
-
brainstate/optim/_sgd_optimizer.py,sha256=
|
56
|
-
brainstate/transform/__init__.py,sha256=
|
57
|
-
brainstate/transform/_autograd.py,sha256=
|
58
|
-
brainstate/transform/_autograd_test.py,sha256=
|
48
|
+
brainstate/optim/_sgd_optimizer.py,sha256=JiK_AVGregL0wn8uHhRQvK9Qq7Qja7dEyLW6Aan7b70,45826
|
49
|
+
brainstate/transform/__init__.py,sha256=my2X4ZW0uKZRfN82zyGEPizWNJ0fsSP2akvmkjn43ck,1458
|
50
|
+
brainstate/transform/_autograd.py,sha256=Pj_YxpU52guaxQs1NcB6qDtXgkvaPcoJbuvIF8T-Wmk,23964
|
51
|
+
brainstate/transform/_autograd_test.py,sha256=RWriMemIF9FVFUjQh4IHzLhT9LGyd1JXpjXfFZKHn10,38654
|
59
52
|
brainstate/transform/_control.py,sha256=NWceTIuLlj2uGTdNcqBAXgnaLuChOGgAtIXtFn5vdLU,26837
|
60
53
|
brainstate/transform/_controls_test.py,sha256=mPUa_qmXXVxDziAJrPWRBwsGnc3cHR9co08eJB_fJwA,7648
|
61
54
|
brainstate/transform/_jit.py,sha256=sjQHFV8Tt75fpdl12jjPRDPT92_IZxBBJAG4gapdbNQ,11471
|
62
55
|
brainstate/transform/_jit_error.py,sha256=lO_e5AdhkjozHjM10q0b57OaXbeZ9gQkVmZMN6VQVCw,4450
|
63
|
-
brainstate/transform/_jit_test.py,sha256=
|
64
|
-
brainstate/transform/_make_jaxpr.py,sha256=
|
65
|
-
brainstate/transform/_make_jaxpr_test.py,sha256=
|
66
|
-
brainstate/transform/_progress_bar.py,sha256=
|
67
|
-
brainstate-0.0.1.
|
68
|
-
brainstate-0.0.1.
|
69
|
-
brainstate-0.0.1.
|
70
|
-
brainstate-0.0.1.
|
71
|
-
brainstate-0.0.1.
|
56
|
+
brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJRLb4,2868
|
57
|
+
brainstate/transform/_make_jaxpr.py,sha256=q3OPy-1Gg0mVaB9pgSTWzzP8FSCAgquSjP-pDEw3Tpg,29970
|
58
|
+
brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
|
59
|
+
brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
|
60
|
+
brainstate-0.0.1.post20240623.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
61
|
+
brainstate-0.0.1.post20240623.dist-info/METADATA,sha256=ezSoiXSzav8KX7NPa8jfR9aWqVPWwh6d46RZnac_Mdg,3814
|
62
|
+
brainstate-0.0.1.post20240623.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
63
|
+
brainstate-0.0.1.post20240623.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
64
|
+
brainstate-0.0.1.post20240623.dist-info/RECORD,,
|
brainstate/math/__init__.py
DELETED
@@ -1,21 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
from ._einops import *
|
17
|
-
from ._einops import __all__ as _einops_all
|
18
|
-
from ._misc import *
|
19
|
-
from ._misc import __all__ as _misc_all
|
20
|
-
|
21
|
-
__all__ = _misc_all + _einops_all
|