brainstate 0.1.0.post20250503__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +10 -3
- brainstate/_state.py +178 -178
- brainstate/_utils.py +0 -1
- brainstate/augment/_autograd.py +0 -2
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape.py +0 -2
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping.py +2 -3
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/augment/_random.py +0 -2
- brainstate/compile/_ad_checkpoint.py +0 -2
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions.py +0 -2
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if.py +0 -2
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_jit.py +9 -8
- brainstate/compile/_loop_collect_return.py +0 -2
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection.py +0 -2
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +30 -17
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/compile/_progress_bar.py +0 -1
- brainstate/compile/_unvmap.py +0 -1
- brainstate/compile/_util.py +0 -2
- brainstate/environ.py +0 -2
- brainstate/functional/_activations.py +0 -2
- brainstate/functional/_activations_test.py +61 -61
- brainstate/functional/_normalization.py +0 -2
- brainstate/functional/_others.py +0 -2
- brainstate/functional/_spikes.py +0 -1
- brainstate/graph/_graph_node.py +1 -3
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation.py +4 -2
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_base.py +0 -2
- brainstate/init/_generic.py +0 -1
- brainstate/init/_random_inits.py +0 -1
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits.py +0 -2
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/mixin.py +0 -2
- brainstate/nn/_collective_ops.py +0 -3
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/_common.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
- brainstate/nn/_dyn_impl/_inputs.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
- brainstate/nn/_dyn_impl/_readout.py +0 -1
- brainstate/nn/_dyn_impl/_readout_test.py +9 -10
- brainstate/nn/_dynamics/_dynamics_base.py +0 -1
- brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
- brainstate/nn/_dynamics/_projection_base.py +0 -1
- brainstate/nn/_dynamics/_state_delay.py +0 -2
- brainstate/nn/_dynamics/_synouts.py +0 -2
- brainstate/nn/_dynamics/_synouts_test.py +4 -5
- brainstate/nn/_elementwise/_dropout.py +0 -2
- brainstate/nn/_elementwise/_dropout_test.py +9 -9
- brainstate/nn/_elementwise/_elementwise.py +0 -2
- brainstate/nn/_elementwise/_elementwise_test.py +57 -59
- brainstate/nn/_event/_fixedprob_mv.py +0 -1
- brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
- brainstate/nn/_event/_linear_mv.py +0 -2
- brainstate/nn/_event/_linear_mv_test.py +0 -1
- brainstate/nn/_exp_euler.py +0 -2
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/_interaction/_conv.py +0 -2
- brainstate/nn/_interaction/_conv_test.py +31 -33
- brainstate/nn/_interaction/_embedding.py +0 -1
- brainstate/nn/_interaction/_linear.py +0 -2
- brainstate/nn/_interaction/_linear_test.py +15 -17
- brainstate/nn/_interaction/_normalizations.py +0 -2
- brainstate/nn/_interaction/_normalizations_test.py +10 -12
- brainstate/nn/_interaction/_poolings.py +0 -2
- brainstate/nn/_interaction/_poolings_test.py +19 -21
- brainstate/nn/_module.py +0 -1
- brainstate/nn/_module_test.py +34 -37
- brainstate/nn/metrics.py +0 -2
- brainstate/optim/_base.py +0 -2
- brainstate/optim/_lr_scheduler.py +0 -1
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer.py +0 -2
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/optim/_sgd_optimizer.py +0 -1
- brainstate/random/_rand_funs.py +0 -1
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed.py +0 -1
- brainstate/random/_rand_seed_test.py +10 -12
- brainstate/random/_rand_state.py +0 -1
- brainstate/surrogate.py +0 -1
- brainstate/typing.py +0 -2
- brainstate/util/_caller.py +4 -6
- brainstate/util/_others.py +0 -2
- brainstate/util/_pretty_pytree.py +201 -150
- brainstate/util/_pretty_repr.py +0 -2
- brainstate/util/_pretty_table.py +57 -3
- brainstate/util/_scaling.py +0 -2
- brainstate/util/_struct.py +0 -2
- brainstate/util/filter.py +0 -2
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/METADATA +11 -5
- brainstate-0.1.2.dist-info/RECORD +133 -0
- brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
brainstate/random/_rand_seed.py
CHANGED
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from __future__ import annotations
|
16
15
|
|
17
16
|
from contextlib import contextmanager
|
18
17
|
from typing import Optional
|
@@ -13,14 +13,12 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
import unittest
|
19
17
|
|
20
18
|
import jax.numpy as jnp
|
21
19
|
import jax.random
|
22
20
|
|
23
|
-
import brainstate
|
21
|
+
import brainstate
|
24
22
|
|
25
23
|
|
26
24
|
class TestRandom(unittest.TestCase):
|
@@ -28,23 +26,23 @@ class TestRandom(unittest.TestCase):
|
|
28
26
|
def test_seed2(self):
|
29
27
|
test_seed = 299
|
30
28
|
key = jax.random.PRNGKey(test_seed)
|
31
|
-
|
29
|
+
brainstate.random.seed(key)
|
32
30
|
|
33
31
|
@jax.jit
|
34
32
|
def jit_seed(key):
|
35
|
-
|
36
|
-
with
|
37
|
-
print(
|
33
|
+
brainstate.random.seed(key)
|
34
|
+
with brainstate.random.seed_context(key):
|
35
|
+
print(brainstate.random.DEFAULT.value)
|
38
36
|
|
39
37
|
jit_seed(key)
|
40
38
|
jit_seed(1)
|
41
39
|
jit_seed(None)
|
42
|
-
|
40
|
+
brainstate.random.seed(1)
|
43
41
|
|
44
42
|
def test_seed(self):
|
45
43
|
test_seed = 299
|
46
|
-
|
47
|
-
a =
|
48
|
-
|
49
|
-
b =
|
44
|
+
brainstate.random.seed(test_seed)
|
45
|
+
a = brainstate.random.rand(3)
|
46
|
+
brainstate.random.seed(test_seed)
|
47
|
+
b = brainstate.random.rand(3)
|
50
48
|
self.assertTrue(jnp.array_equal(a, b))
|
brainstate/random/_rand_state.py
CHANGED
brainstate/surrogate.py
CHANGED
brainstate/typing.py
CHANGED
brainstate/util/_caller.py
CHANGED
@@ -15,8 +15,6 @@
|
|
15
15
|
# See the License for the specific language governing permissions and
|
16
16
|
# limitations under the License.
|
17
17
|
|
18
|
-
from __future__ import annotations
|
19
|
-
|
20
18
|
import dataclasses
|
21
19
|
from typing import Any, TypeVar, Protocol, Generic
|
22
20
|
|
@@ -82,18 +80,18 @@ class CallableProxy:
|
|
82
80
|
def __call__(self, *args, **kwargs):
|
83
81
|
return self._callable(self._accessor, *args, **kwargs)
|
84
82
|
|
85
|
-
def __getattr__(self, name) -> CallableProxy:
|
83
|
+
def __getattr__(self, name) -> 'CallableProxy':
|
86
84
|
return CallableProxy(self._callable, getattr(self._accessor, name))
|
87
85
|
|
88
|
-
def __getitem__(self, key) -> CallableProxy:
|
86
|
+
def __getitem__(self, key) -> 'CallableProxy':
|
89
87
|
return CallableProxy(self._callable, self._accessor[key])
|
90
88
|
|
91
89
|
|
92
90
|
class ApplyCaller(Protocol, Generic[A]):
|
93
|
-
def __getattr__(self, __name) -> ApplyCaller[A]:
|
91
|
+
def __getattr__(self, __name) -> 'ApplyCaller[A]':
|
94
92
|
...
|
95
93
|
|
96
|
-
def __getitem__(self, __name) -> ApplyCaller[A]:
|
94
|
+
def __getitem__(self, __name) -> 'ApplyCaller[A]':
|
97
95
|
...
|
98
96
|
|
99
97
|
def __call__(self, *args, **kwargs) -> tuple[Any, A]:
|