brainstate 0.1.0.post20250420__py2.py3-none-any.whl → 0.1.0.post20250423__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_compatible_import.py +15 -0
- brainstate/_state.py +5 -4
- brainstate/_state_test.py +2 -1
- brainstate/augment/_autograd_test.py +3 -2
- brainstate/augment/_eval_shape.py +2 -1
- brainstate/augment/_mapping.py +0 -1
- brainstate/augment/_mapping_test.py +1 -0
- brainstate/compile/_ad_checkpoint.py +2 -1
- brainstate/compile/_conditions.py +3 -3
- brainstate/compile/_conditions_test.py +2 -1
- brainstate/compile/_error_if.py +2 -1
- brainstate/compile/_error_if_test.py +2 -1
- brainstate/compile/_jit.py +3 -2
- brainstate/compile/_jit_test.py +2 -1
- brainstate/compile/_loop_collect_return.py +2 -2
- brainstate/compile/_loop_collect_return_test.py +2 -1
- brainstate/compile/_loop_no_collection.py +1 -1
- brainstate/compile/_make_jaxpr.py +2 -2
- brainstate/compile/_make_jaxpr_test.py +2 -1
- brainstate/compile/_progress_bar.py +2 -1
- brainstate/compile/_unvmap.py +1 -2
- brainstate/environ.py +4 -4
- brainstate/environ_test.py +2 -1
- brainstate/functional/_activations.py +2 -1
- brainstate/functional/_activations_test.py +1 -1
- brainstate/functional/_normalization.py +2 -1
- brainstate/functional/_others.py +2 -1
- brainstate/graph/_graph_operation.py +3 -2
- brainstate/graph/_graph_operation_test.py +4 -3
- brainstate/init/_base.py +2 -1
- brainstate/init/_generic.py +2 -1
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_collective_ops.py +1 -0
- brainstate/nn/_collective_ops_test.py +0 -4
- brainstate/nn/_common.py +0 -1
- brainstate/nn/_dyn_impl/__init__.py +0 -4
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +431 -13
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +2 -1
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +405 -103
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +2 -1
- brainstate/nn/_dyn_impl/_inputs.py +236 -29
- brainstate/nn/_dyn_impl/_rate_rnns.py +238 -82
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +2 -1
- brainstate/nn/_dyn_impl/_readout.py +91 -8
- brainstate/nn/_dyn_impl/_readout_test.py +2 -1
- brainstate/nn/_dynamics/_dynamics_base.py +676 -96
- brainstate/nn/_dynamics/_dynamics_base_test.py +2 -1
- brainstate/nn/_dynamics/_projection_base.py +29 -30
- brainstate/nn/_dynamics/_state_delay.py +3 -3
- brainstate/nn/_dynamics/_synouts_test.py +2 -1
- brainstate/nn/_elementwise/_dropout.py +3 -2
- brainstate/nn/_elementwise/_dropout_test.py +2 -1
- brainstate/nn/_elementwise/_elementwise.py +2 -1
- brainstate/nn/{_dyn_impl/_projection_alignpost.py → _event/__init__.py} +8 -7
- brainstate/nn/_event/_fixedprob_mv.py +169 -0
- brainstate/nn/_event/_fixedprob_mv_test.py +115 -0
- brainstate/nn/_event/_linear_mv.py +85 -0
- brainstate/nn/_event/_linear_mv_test.py +121 -0
- brainstate/nn/_exp_euler.py +2 -1
- brainstate/nn/_exp_euler_test.py +2 -1
- brainstate/nn/_interaction/_conv.py +2 -1
- brainstate/nn/_interaction/_linear.py +2 -1
- brainstate/nn/_interaction/_linear_test.py +2 -1
- brainstate/nn/_interaction/_normalizations.py +3 -2
- brainstate/nn/_interaction/_poolings.py +4 -3
- brainstate/nn/_module_test.py +2 -1
- brainstate/nn/metrics.py +4 -3
- brainstate/optim/_lr_scheduler.py +2 -1
- brainstate/optim/_lr_scheduler_test.py +2 -1
- brainstate/optim/_optax_optimizer_test.py +2 -1
- brainstate/optim/_sgd_optimizer.py +3 -2
- brainstate/random/_rand_funs.py +2 -1
- brainstate/random/_rand_funs_test.py +3 -2
- brainstate/random/_rand_seed.py +3 -2
- brainstate/random/_rand_seed_test.py +2 -1
- brainstate/random/_rand_state.py +4 -3
- brainstate/surrogate.py +1 -2
- brainstate/typing.py +4 -4
- brainstate/util/_caller.py +2 -1
- brainstate/util/_others.py +4 -4
- brainstate/util/_pretty_pytree.py +1 -1
- brainstate/util/_pretty_pytree_test.py +2 -1
- brainstate/util/_pretty_table.py +43 -43
- brainstate/util/_struct.py +2 -1
- brainstate/util/filter.py +0 -1
- {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250423.dist-info}/METADATA +3 -3
- brainstate-0.1.0.post20250423.dist-info/RECORD +133 -0
- brainstate-0.1.0.post20250420.dist-info/RECORD +0 -129
- {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250423.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250423.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250423.dist-info}/top_level.txt +0 -0
brainstate/_compatible_import.py
CHANGED
@@ -16,6 +16,7 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
18
|
|
19
|
+
import importlib.util
|
19
20
|
from contextlib import contextmanager
|
20
21
|
from typing import Iterable, Hashable
|
21
22
|
|
@@ -29,8 +30,11 @@ __all__ = [
|
|
29
30
|
'get_aval',
|
30
31
|
'Tracer',
|
31
32
|
'to_concrete_aval',
|
33
|
+
'brainevent',
|
32
34
|
]
|
33
35
|
|
36
|
+
brainevent_installed = importlib.util.find_spec('brainevent') is not None
|
37
|
+
|
34
38
|
from jax.core import get_aval, Tracer
|
35
39
|
|
36
40
|
if jax.__version_info__ < (0, 4, 38):
|
@@ -56,3 +60,14 @@ def to_concrete_aval(aval):
|
|
56
60
|
return aval.to_concrete_value()
|
57
61
|
return aval
|
58
62
|
|
63
|
+
|
64
|
+
if brainevent_installed:
|
65
|
+
import brainevent
|
66
|
+
else:
|
67
|
+
|
68
|
+
class BrainEvent:
|
69
|
+
def __getattr__(self, item):
|
70
|
+
raise ImportError('brainevent is not installed, please install brainevent first.')
|
71
|
+
|
72
|
+
|
73
|
+
brainevent = BrainEvent()
|
brainstate/_state.py
CHANGED
@@ -17,12 +17,8 @@ from __future__ import annotations
|
|
17
17
|
|
18
18
|
import contextlib
|
19
19
|
import dataclasses
|
20
|
-
import jax
|
21
|
-
import numpy as np
|
22
20
|
import threading
|
23
21
|
from functools import wraps, partial
|
24
|
-
from jax.api_util import shaped_abstractify
|
25
|
-
from jax.extend import source_info_util
|
26
22
|
from typing import (
|
27
23
|
Any,
|
28
24
|
Union,
|
@@ -39,6 +35,11 @@ from typing import (
|
|
39
35
|
Generator,
|
40
36
|
)
|
41
37
|
|
38
|
+
import jax
|
39
|
+
import numpy as np
|
40
|
+
from jax.api_util import shaped_abstractify
|
41
|
+
from jax.extend import source_info_util
|
42
|
+
|
42
43
|
from brainstate.typing import ArrayLike, PyTree, Missing, Filter
|
43
44
|
from brainstate.util import DictManager, PrettyObject
|
44
45
|
from brainstate.util.filter import Nothing
|
brainstate/_state_test.py
CHANGED
@@ -16,12 +16,13 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
+
import unittest
|
20
|
+
from pprint import pprint
|
21
|
+
|
19
22
|
import brainunit as u
|
20
23
|
import jax
|
21
24
|
import jax.numpy as jnp
|
22
25
|
import pytest
|
23
|
-
import unittest
|
24
|
-
from pprint import pprint
|
25
26
|
|
26
27
|
import brainstate as bst
|
27
28
|
from brainstate.augment._autograd import _jacfwd
|
@@ -16,9 +16,10 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import functools
|
19
|
-
import jax
|
20
19
|
from typing import Any, TypeVar, Callable, Sequence, Union
|
21
20
|
|
21
|
+
import jax
|
22
|
+
|
22
23
|
from brainstate import random
|
23
24
|
from brainstate.graph import Node, flatten, unflatten
|
24
25
|
from ._random import restore_rngs
|
brainstate/augment/_mapping.py
CHANGED
@@ -425,6 +425,7 @@ class TestVMAPNewStatesEdgeCases(unittest.TestCase):
|
|
425
425
|
|
426
426
|
def test_incompatible_shapes(self):
|
427
427
|
foo = brainstate.nn.LIF(3)
|
428
|
+
|
428
429
|
# Simulate an incompatible shapes scenario:
|
429
430
|
# We intentionally assign a state with a different shape than expected.
|
430
431
|
@bst.augment.vmap_new_states(state_tag='new1', axis_size=5)
|
@@ -16,9 +16,10 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import functools
|
19
|
-
import jax
|
20
19
|
from typing import Callable, Tuple, Union
|
21
20
|
|
21
|
+
import jax
|
22
|
+
|
22
23
|
from brainstate.typing import Missing
|
23
24
|
from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
|
24
25
|
from ._util import write_back_state_values
|
@@ -15,17 +15,17 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
+
from collections.abc import Callable, Sequence
|
19
|
+
|
18
20
|
import jax
|
19
21
|
import jax.numpy as jnp
|
20
22
|
import numpy as np
|
21
|
-
from collections.abc import Callable, Sequence
|
22
23
|
|
24
|
+
from brainstate._compatible_import import to_concrete_aval, Tracer
|
23
25
|
from brainstate._utils import set_module_as
|
24
26
|
from ._error_if import jit_error_if
|
25
27
|
from ._make_jaxpr import StatefulFunction
|
26
28
|
from ._util import wrap_single_fun_in_multi_branches, write_back_state_values
|
27
|
-
from brainstate._compatible_import import to_concrete_aval, Tracer
|
28
|
-
|
29
29
|
|
30
30
|
__all__ = [
|
31
31
|
'cond', 'switch', 'ifelse',
|
brainstate/compile/_error_if.py
CHANGED
brainstate/compile/_jit.py
CHANGED
@@ -16,11 +16,12 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import functools
|
19
|
-
import jax
|
20
19
|
from collections.abc import Iterable, Sequence
|
20
|
+
from typing import (Any, Callable, Union)
|
21
|
+
|
22
|
+
import jax
|
21
23
|
from jax._src import sharding_impls
|
22
24
|
from jax.lib import xla_client as xc
|
23
|
-
from typing import (Any, Callable, Union)
|
24
25
|
|
25
26
|
from brainstate._utils import set_module_as
|
26
27
|
from brainstate.typing import Missing
|
brainstate/compile/_jit_test.py
CHANGED
@@ -16,11 +16,11 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import math
|
19
|
+
from functools import wraps
|
20
|
+
from typing import Callable, Optional, TypeVar, Tuple, Any
|
19
21
|
|
20
22
|
import jax
|
21
23
|
import jax.numpy as jnp
|
22
|
-
from functools import wraps
|
23
|
-
from typing import Callable, Optional, TypeVar, Tuple, Any
|
24
24
|
|
25
25
|
from brainstate._utils import set_module_as
|
26
26
|
from ._make_jaxpr import StatefulFunction
|
@@ -16,9 +16,9 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import math
|
19
|
+
from typing import Any, Callable, TypeVar
|
19
20
|
|
20
21
|
import jax
|
21
|
-
from typing import Any, Callable, TypeVar
|
22
22
|
|
23
23
|
from brainstate._utils import set_module_as
|
24
24
|
from ._loop_collect_return import _bounded_while_loop
|
@@ -57,7 +57,7 @@ import functools
|
|
57
57
|
import inspect
|
58
58
|
import operator
|
59
59
|
from collections.abc import Hashable, Iterable, Sequence
|
60
|
-
from contextlib import ExitStack
|
60
|
+
from contextlib import ExitStack
|
61
61
|
from typing import Any, Callable, Tuple, Union, Dict, Optional
|
62
62
|
|
63
63
|
import jax
|
@@ -69,11 +69,11 @@ from jax.extend.linear_util import transformation_with_aux, wrap_init
|
|
69
69
|
from jax.interpreters import partial_eval as pe
|
70
70
|
from jax.util import wraps
|
71
71
|
|
72
|
+
from brainstate._compatible_import import ClosedJaxpr, extend_axis_env_nd
|
72
73
|
from brainstate._state import State, StateTraceStack
|
73
74
|
from brainstate._utils import set_module_as
|
74
75
|
from brainstate.typing import PyTree
|
75
76
|
from brainstate.util import PrettyObject
|
76
|
-
from brainstate._compatible_import import ClosedJaxpr, extend_axis_env_nd
|
77
77
|
|
78
78
|
AxisName = Hashable
|
79
79
|
|
@@ -17,9 +17,10 @@ from __future__ import annotations
|
|
17
17
|
|
18
18
|
import copy
|
19
19
|
import importlib.util
|
20
|
-
import jax
|
21
20
|
from typing import Optional, Callable, Any, Tuple, Dict
|
22
21
|
|
22
|
+
import jax
|
23
|
+
|
23
24
|
tqdm_installed = importlib.util.find_spec('tqdm') is not None
|
24
25
|
|
25
26
|
__all__ = [
|
brainstate/compile/_unvmap.py
CHANGED
@@ -20,9 +20,8 @@ import jax.interpreters.batching as batching
|
|
20
20
|
import jax.interpreters.mlir as mlir
|
21
21
|
import jax.numpy as jnp
|
22
22
|
|
23
|
-
from brainstate._utils import set_module_as
|
24
23
|
from brainstate._compatible_import import Primitive
|
25
|
-
|
24
|
+
from brainstate._utils import set_module_as
|
26
25
|
|
27
26
|
__all__ = [
|
28
27
|
"unvmap",
|
brainstate/environ.py
CHANGED
@@ -17,18 +17,18 @@
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
-
from collections import defaultdict
|
21
|
-
|
22
20
|
import contextlib
|
23
21
|
import dataclasses
|
24
22
|
import functools
|
25
|
-
import numpy as np
|
26
23
|
import os
|
27
24
|
import re
|
28
25
|
import threading
|
26
|
+
from collections import defaultdict
|
27
|
+
from typing import Any, Callable, Dict, Hashable
|
28
|
+
|
29
|
+
import numpy as np
|
29
30
|
from jax import config, devices, numpy as jnp
|
30
31
|
from jax.typing import DTypeLike
|
31
|
-
from typing import Any, Callable, Dict, Hashable
|
32
32
|
|
33
33
|
from .mixin import Mode
|
34
34
|
|
brainstate/environ_test.py
CHANGED
@@ -20,10 +20,11 @@ Shared neural network activations and other functions.
|
|
20
20
|
|
21
21
|
from __future__ import annotations
|
22
22
|
|
23
|
+
from typing import Any, Union, Sequence
|
24
|
+
|
23
25
|
import brainunit as u
|
24
26
|
import jax
|
25
27
|
from jax.scipy.special import logsumexp
|
26
|
-
from typing import Any, Union, Sequence
|
27
28
|
|
28
29
|
from brainstate import random
|
29
30
|
from brainstate.typing import ArrayLike
|
@@ -16,12 +16,12 @@
|
|
16
16
|
"""Tests for nn module."""
|
17
17
|
|
18
18
|
import itertools
|
19
|
+
from functools import partial
|
19
20
|
|
20
21
|
import jax
|
21
22
|
import jax.numpy as jnp
|
22
23
|
import scipy.stats
|
23
24
|
from absl.testing import parameterized
|
24
|
-
from functools import partial
|
25
25
|
from jax._src import test_util as jtu
|
26
26
|
from jax.test_util import check_grads
|
27
27
|
|
@@ -15,9 +15,10 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
+
from typing import Optional, Union
|
19
|
+
|
18
20
|
import brainunit as u
|
19
21
|
import jax
|
20
|
-
from typing import Optional, Union
|
21
22
|
|
22
23
|
from brainstate._utils import set_module_as
|
23
24
|
from brainstate.typing import ArrayLike
|
brainstate/functional/_others.py
CHANGED
@@ -18,10 +18,11 @@
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
20
|
import dataclasses
|
21
|
-
import jax
|
22
|
-
import numpy as np
|
23
21
|
from typing import (Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
|
24
22
|
Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional, overload)
|
23
|
+
|
24
|
+
import jax
|
25
|
+
import numpy as np
|
25
26
|
from typing_extensions import TypeGuard, Unpack
|
26
27
|
|
27
28
|
from brainstate._state import State, TreefyState
|
@@ -15,13 +15,14 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
-
import jax
|
19
|
-
import jax.numpy as jnp
|
20
18
|
import unittest
|
21
|
-
from absl.testing import absltest, parameterized
|
22
19
|
from collections.abc import Callable
|
23
20
|
from threading import Thread
|
24
21
|
|
22
|
+
import jax
|
23
|
+
import jax.numpy as jnp
|
24
|
+
from absl.testing import absltest, parameterized
|
25
|
+
|
25
26
|
import brainstate as bst
|
26
27
|
|
27
28
|
|
brainstate/init/_base.py
CHANGED
brainstate/init/_generic.py
CHANGED
@@ -16,10 +16,11 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
+
from typing import Union, Callable, Optional, Sequence
|
20
|
+
|
19
21
|
import brainunit as bu
|
20
22
|
import jax
|
21
23
|
import numpy as np
|
22
|
-
from typing import Union, Callable, Optional, Sequence
|
23
24
|
|
24
25
|
from brainstate._state import State
|
25
26
|
from brainstate._utils import set_module_as
|
brainstate/nn/__init__.py
CHANGED
@@ -25,6 +25,8 @@ from ._dynamics import *
|
|
25
25
|
from ._dynamics import __all__ as dynamics_all
|
26
26
|
from ._elementwise import *
|
27
27
|
from ._elementwise import __all__ as elementwise_all
|
28
|
+
from ._event import *
|
29
|
+
from ._event import __all__ as _event_all
|
28
30
|
from ._exp_euler import *
|
29
31
|
from ._exp_euler import __all__ as exp_euler_all
|
30
32
|
from ._interaction import *
|
@@ -45,6 +47,7 @@ __all__ = (
|
|
45
47
|
+ exp_euler_all
|
46
48
|
+ interaction_all
|
47
49
|
+ utils_all
|
50
|
+
+ _event_all
|
48
51
|
)
|
49
52
|
|
50
53
|
del (
|
@@ -57,4 +60,5 @@ del (
|
|
57
60
|
exp_euler_all,
|
58
61
|
interaction_all,
|
59
62
|
utils_all,
|
63
|
+
_event_all,
|
60
64
|
)
|
brainstate/nn/_collective_ops.py
CHANGED
@@ -19,6 +19,7 @@ from collections import namedtuple
|
|
19
19
|
from typing import Callable, TypeVar, Tuple, Any, Dict
|
20
20
|
|
21
21
|
import jax
|
22
|
+
from typing import Callable, TypeVar, Tuple, Any, Dict
|
22
23
|
|
23
24
|
from brainstate._state import catch_new_states
|
24
25
|
from brainstate._utils import set_module_as
|
brainstate/nn/_common.py
CHANGED
@@ -20,8 +20,6 @@ from ._dynamics_synapse import *
|
|
20
20
|
from ._dynamics_synapse import __all__ as dyn_synapse_all
|
21
21
|
from ._inputs import *
|
22
22
|
from ._inputs import __all__ as inputs_all
|
23
|
-
from ._projection_alignpost import *
|
24
|
-
from ._projection_alignpost import __all__ as alignpost_all
|
25
23
|
from ._rate_rnns import *
|
26
24
|
from ._rate_rnns import __all__ as rate_rnns
|
27
25
|
from ._readout import *
|
@@ -31,7 +29,6 @@ __all__ = (
|
|
31
29
|
dyn_neuron_all
|
32
30
|
+ dyn_synapse_all
|
33
31
|
+ inputs_all
|
34
|
-
+ alignpost_all
|
35
32
|
+ rate_rnns
|
36
33
|
+ readout_all
|
37
34
|
)
|
@@ -41,6 +38,5 @@ del (
|
|
41
38
|
dyn_synapse_all,
|
42
39
|
inputs_all,
|
43
40
|
readout_all,
|
44
|
-
alignpost_all,
|
45
41
|
rate_rnns,
|
46
42
|
)
|