brainstate 0.1.0.post20250413__py2.py3-none-any.whl → 0.1.0.post20250422__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_compatible_import.py +73 -0
- brainstate/_state.py +5 -4
- brainstate/_state_test.py +2 -1
- brainstate/augment/_autograd_test.py +3 -2
- brainstate/augment/_eval_shape.py +2 -1
- brainstate/augment/_mapping.py +0 -1
- brainstate/augment/_mapping_test.py +1 -0
- brainstate/compile/_ad_checkpoint.py +2 -1
- brainstate/compile/_conditions.py +4 -2
- brainstate/compile/_conditions_test.py +2 -1
- brainstate/compile/_error_if.py +2 -1
- brainstate/compile/_error_if_test.py +2 -1
- brainstate/compile/_jit.py +3 -2
- brainstate/compile/_jit_test.py +2 -1
- brainstate/compile/_loop_collect_return.py +2 -2
- brainstate/compile/_loop_collect_return_test.py +2 -1
- brainstate/compile/_loop_no_collection.py +1 -1
- brainstate/compile/_make_jaxpr.py +10 -13
- brainstate/compile/_make_jaxpr_test.py +3 -6
- brainstate/compile/_progress_bar.py +2 -1
- brainstate/compile/_unvmap.py +1 -5
- brainstate/environ.py +4 -4
- brainstate/environ_test.py +2 -1
- brainstate/functional/_activations.py +2 -1
- brainstate/functional/_activations_test.py +1 -1
- brainstate/functional/_normalization.py +2 -1
- brainstate/functional/_others.py +2 -1
- brainstate/graph/_graph_operation.py +3 -2
- brainstate/graph/_graph_operation_test.py +4 -3
- brainstate/init/_base.py +2 -1
- brainstate/init/_generic.py +2 -1
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_collective_ops.py +1 -0
- brainstate/nn/_collective_ops_test.py +0 -4
- brainstate/nn/_common.py +0 -1
- brainstate/nn/_dyn_impl/__init__.py +0 -4
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +431 -13
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +2 -1
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +405 -103
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +2 -1
- brainstate/nn/_dyn_impl/_inputs.py +236 -29
- brainstate/nn/_dyn_impl/_rate_rnns.py +238 -82
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +2 -1
- brainstate/nn/_dyn_impl/_readout.py +91 -8
- brainstate/nn/_dyn_impl/_readout_test.py +2 -1
- brainstate/nn/_dynamics/_dynamics_base.py +676 -96
- brainstate/nn/_dynamics/_dynamics_base_test.py +2 -1
- brainstate/nn/_dynamics/_projection_base.py +29 -30
- brainstate/nn/_dynamics/_state_delay.py +3 -3
- brainstate/nn/_dynamics/_synouts_test.py +2 -1
- brainstate/nn/_elementwise/_dropout.py +3 -2
- brainstate/nn/_elementwise/_dropout_test.py +2 -1
- brainstate/nn/_elementwise/_elementwise.py +2 -1
- brainstate/nn/{_dyn_impl/_projection_alignpost.py → _event/__init__.py} +8 -7
- brainstate/nn/_event/_fixedprob_mv.py +169 -0
- brainstate/nn/_event/_fixedprob_mv_test.py +115 -0
- brainstate/nn/_event/_linear_mv.py +85 -0
- brainstate/nn/_event/_linear_mv_test.py +121 -0
- brainstate/nn/_exp_euler.py +2 -1
- brainstate/nn/_exp_euler_test.py +2 -1
- brainstate/nn/_interaction/_conv.py +2 -1
- brainstate/nn/_interaction/_linear.py +2 -1
- brainstate/nn/_interaction/_linear_test.py +2 -1
- brainstate/nn/_interaction/_normalizations.py +2 -1
- brainstate/nn/_interaction/_poolings.py +4 -3
- brainstate/nn/_module_test.py +2 -1
- brainstate/nn/metrics.py +4 -3
- brainstate/optim/_lr_scheduler.py +2 -1
- brainstate/optim/_lr_scheduler_test.py +2 -1
- brainstate/optim/_optax_optimizer_test.py +2 -1
- brainstate/optim/_sgd_optimizer.py +3 -2
- brainstate/random/_rand_funs.py +2 -1
- brainstate/random/_rand_funs_test.py +3 -2
- brainstate/random/_rand_seed.py +3 -2
- brainstate/random/_rand_seed_test.py +2 -1
- brainstate/random/_rand_state.py +4 -3
- brainstate/surrogate.py +1 -5
- brainstate/typing.py +4 -4
- brainstate/util/_caller.py +2 -1
- brainstate/util/_others.py +4 -4
- brainstate/util/_pretty_pytree.py +1 -1
- brainstate/util/_pretty_pytree_test.py +2 -1
- brainstate/util/_pretty_table.py +43 -43
- brainstate/util/_struct.py +2 -1
- brainstate/util/filter.py +0 -1
- {brainstate-0.1.0.post20250413.dist-info → brainstate-0.1.0.post20250422.dist-info}/METADATA +3 -3
- brainstate-0.1.0.post20250422.dist-info/RECORD +133 -0
- brainstate-0.1.0.post20250413.dist-info/RECORD +0 -128
- {brainstate-0.1.0.post20250413.dist-info → brainstate-0.1.0.post20250422.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250413.dist-info → brainstate-0.1.0.post20250422.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250413.dist-info → brainstate-0.1.0.post20250422.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,73 @@
|
|
1
|
+
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
|
19
|
+
import importlib.util
|
20
|
+
from contextlib import contextmanager
|
21
|
+
from typing import Iterable, Hashable
|
22
|
+
|
23
|
+
import jax
|
24
|
+
|
25
|
+
__all__ = [
|
26
|
+
'ClosedJaxpr',
|
27
|
+
'Primitive',
|
28
|
+
'extend_axis_env_nd',
|
29
|
+
'jaxpr_as_fun',
|
30
|
+
'get_aval',
|
31
|
+
'Tracer',
|
32
|
+
'to_concrete_aval',
|
33
|
+
'brainevent',
|
34
|
+
]
|
35
|
+
|
36
|
+
brainevent_installed = importlib.util.find_spec('brainevent') is not None
|
37
|
+
|
38
|
+
from jax.core import get_aval, Tracer
|
39
|
+
|
40
|
+
if jax.__version_info__ < (0, 4, 38):
|
41
|
+
from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
|
42
|
+
else:
|
43
|
+
from jax.extend.core import ClosedJaxpr, Primitive, jaxpr_as_fun
|
44
|
+
from jax.core import trace_ctx
|
45
|
+
|
46
|
+
|
47
|
+
@contextmanager
|
48
|
+
def extend_axis_env_nd(name_size_pairs: Iterable[tuple[Hashable, int]]):
|
49
|
+
prev = trace_ctx.axis_env
|
50
|
+
try:
|
51
|
+
trace_ctx.set_axis_env(prev.extend_pure(name_size_pairs))
|
52
|
+
yield
|
53
|
+
finally:
|
54
|
+
trace_ctx.set_axis_env(prev)
|
55
|
+
|
56
|
+
|
57
|
+
def to_concrete_aval(aval):
|
58
|
+
aval = get_aval(aval)
|
59
|
+
if isinstance(aval, Tracer):
|
60
|
+
return aval.to_concrete_value()
|
61
|
+
return aval
|
62
|
+
|
63
|
+
|
64
|
+
if brainevent_installed:
|
65
|
+
import brainevent
|
66
|
+
else:
|
67
|
+
|
68
|
+
class BrainEvent:
|
69
|
+
def __getattr__(self, item):
|
70
|
+
raise ImportError('brainevent is not installed, please install brainevent first.')
|
71
|
+
|
72
|
+
|
73
|
+
brainevent = BrainEvent()
|
brainstate/_state.py
CHANGED
@@ -17,12 +17,8 @@ from __future__ import annotations
|
|
17
17
|
|
18
18
|
import contextlib
|
19
19
|
import dataclasses
|
20
|
-
import jax
|
21
|
-
import numpy as np
|
22
20
|
import threading
|
23
21
|
from functools import wraps, partial
|
24
|
-
from jax.api_util import shaped_abstractify
|
25
|
-
from jax.extend import source_info_util
|
26
22
|
from typing import (
|
27
23
|
Any,
|
28
24
|
Union,
|
@@ -39,6 +35,11 @@ from typing import (
|
|
39
35
|
Generator,
|
40
36
|
)
|
41
37
|
|
38
|
+
import jax
|
39
|
+
import numpy as np
|
40
|
+
from jax.api_util import shaped_abstractify
|
41
|
+
from jax.extend import source_info_util
|
42
|
+
|
42
43
|
from brainstate.typing import ArrayLike, PyTree, Missing, Filter
|
43
44
|
from brainstate.util import DictManager, PrettyObject
|
44
45
|
from brainstate.util.filter import Nothing
|
brainstate/_state_test.py
CHANGED
@@ -16,12 +16,13 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
+
import unittest
|
20
|
+
from pprint import pprint
|
21
|
+
|
19
22
|
import brainunit as u
|
20
23
|
import jax
|
21
24
|
import jax.numpy as jnp
|
22
25
|
import pytest
|
23
|
-
import unittest
|
24
|
-
from pprint import pprint
|
25
26
|
|
26
27
|
import brainstate as bst
|
27
28
|
from brainstate.augment._autograd import _jacfwd
|
@@ -16,9 +16,10 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import functools
|
19
|
-
import jax
|
20
19
|
from typing import Any, TypeVar, Callable, Sequence, Union
|
21
20
|
|
21
|
+
import jax
|
22
|
+
|
22
23
|
from brainstate import random
|
23
24
|
from brainstate.graph import Node, flatten, unflatten
|
24
25
|
from ._random import restore_rngs
|
brainstate/augment/_mapping.py
CHANGED
@@ -425,6 +425,7 @@ class TestVMAPNewStatesEdgeCases(unittest.TestCase):
|
|
425
425
|
|
426
426
|
def test_incompatible_shapes(self):
|
427
427
|
foo = brainstate.nn.LIF(3)
|
428
|
+
|
428
429
|
# Simulate an incompatible shapes scenario:
|
429
430
|
# We intentionally assign a state with a different shape than expected.
|
430
431
|
@bst.augment.vmap_new_states(state_tag='new1', axis_size=5)
|
@@ -16,9 +16,10 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import functools
|
19
|
-
import jax
|
20
19
|
from typing import Callable, Tuple, Union
|
21
20
|
|
21
|
+
import jax
|
22
|
+
|
22
23
|
from brainstate.typing import Missing
|
23
24
|
from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
|
24
25
|
from ._util import write_back_state_values
|
@@ -15,11 +15,13 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
+
from collections.abc import Callable, Sequence
|
19
|
+
|
18
20
|
import jax
|
19
21
|
import jax.numpy as jnp
|
20
22
|
import numpy as np
|
21
|
-
from collections.abc import Callable, Sequence
|
22
23
|
|
24
|
+
from brainstate._compatible_import import to_concrete_aval, Tracer
|
23
25
|
from brainstate._utils import set_module_as
|
24
26
|
from ._error_if import jit_error_if
|
25
27
|
from ._make_jaxpr import StatefulFunction
|
@@ -86,7 +88,7 @@ def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
|
|
86
88
|
raise TypeError("Pred type must be either boolean or number, got {}.".format(pred_dtype))
|
87
89
|
|
88
90
|
# not jit
|
89
|
-
if jax.config.jax_disable_jit and isinstance(
|
91
|
+
if jax.config.jax_disable_jit and not isinstance(to_concrete_aval(pred), Tracer):
|
90
92
|
if pred:
|
91
93
|
return true_fun(*operands)
|
92
94
|
else:
|
brainstate/compile/_error_if.py
CHANGED
brainstate/compile/_jit.py
CHANGED
@@ -16,11 +16,12 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import functools
|
19
|
-
import jax
|
20
19
|
from collections.abc import Iterable, Sequence
|
20
|
+
from typing import (Any, Callable, Union)
|
21
|
+
|
22
|
+
import jax
|
21
23
|
from jax._src import sharding_impls
|
22
24
|
from jax.lib import xla_client as xc
|
23
|
-
from typing import (Any, Callable, Union)
|
24
25
|
|
25
26
|
from brainstate._utils import set_module_as
|
26
27
|
from brainstate.typing import Missing
|
brainstate/compile/_jit_test.py
CHANGED
@@ -16,11 +16,11 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import math
|
19
|
+
from functools import wraps
|
20
|
+
from typing import Callable, Optional, TypeVar, Tuple, Any
|
19
21
|
|
20
22
|
import jax
|
21
23
|
import jax.numpy as jnp
|
22
|
-
from functools import wraps
|
23
|
-
from typing import Callable, Optional, TypeVar, Tuple, Any
|
24
24
|
|
25
25
|
from brainstate._utils import set_module_as
|
26
26
|
from ._make_jaxpr import StatefulFunction
|
@@ -16,9 +16,9 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import math
|
19
|
+
from typing import Any, Callable, TypeVar
|
19
20
|
|
20
21
|
import jax
|
21
|
-
from typing import Any, Callable, TypeVar
|
22
22
|
|
23
23
|
from brainstate._utils import set_module_as
|
24
24
|
from ._loop_collect_return import _bounded_while_loop
|
@@ -55,10 +55,12 @@ from __future__ import annotations
|
|
55
55
|
|
56
56
|
import functools
|
57
57
|
import inspect
|
58
|
-
import jax
|
59
58
|
import operator
|
60
59
|
from collections.abc import Hashable, Iterable, Sequence
|
61
60
|
from contextlib import ExitStack
|
61
|
+
from typing import Any, Callable, Tuple, Union, Dict, Optional
|
62
|
+
|
63
|
+
import jax
|
62
64
|
from jax._src import source_info_util
|
63
65
|
from jax._src.linear_util import annotate
|
64
66
|
from jax._src.traceback_util import api_boundary
|
@@ -66,18 +68,13 @@ from jax.api_util import shaped_abstractify
|
|
66
68
|
from jax.extend.linear_util import transformation_with_aux, wrap_init
|
67
69
|
from jax.interpreters import partial_eval as pe
|
68
70
|
from jax.util import wraps
|
69
|
-
from typing import Any, Callable, Tuple, Union, Dict, Optional
|
70
71
|
|
72
|
+
from brainstate._compatible_import import ClosedJaxpr, extend_axis_env_nd
|
71
73
|
from brainstate._state import State, StateTraceStack
|
72
74
|
from brainstate._utils import set_module_as
|
73
75
|
from brainstate.typing import PyTree
|
74
76
|
from brainstate.util import PrettyObject
|
75
77
|
|
76
|
-
if jax.__version_info__ < (0, 4, 38):
|
77
|
-
from jax.core import ClosedJaxpr
|
78
|
-
else:
|
79
|
-
from jax.extend.core import ClosedJaxpr
|
80
|
-
|
81
78
|
AxisName = Hashable
|
82
79
|
|
83
80
|
__all__ = [
|
@@ -200,7 +197,7 @@ class StatefulFunction(PrettyObject):
|
|
200
197
|
|
201
198
|
# implicit parameters
|
202
199
|
self.cache_type = cache_type
|
203
|
-
self._cached_jaxpr: Dict[Any,
|
200
|
+
self._cached_jaxpr: Dict[Any, ClosedJaxpr] = dict()
|
204
201
|
self._cached_out_shapes: Dict[Any, PyTree] = dict()
|
205
202
|
self._cached_jaxpr_out_tree: Dict[Any, PyTree] = dict()
|
206
203
|
self._cached_state_trace: Dict[Any, StateTraceStack] = dict()
|
@@ -210,7 +207,7 @@ class StatefulFunction(PrettyObject):
|
|
210
207
|
return None
|
211
208
|
return k, v
|
212
209
|
|
213
|
-
def get_jaxpr(self, cache_key: Hashable = ()) ->
|
210
|
+
def get_jaxpr(self, cache_key: Hashable = ()) -> ClosedJaxpr:
|
214
211
|
"""
|
215
212
|
Read the JAX Jaxpr representation of the function.
|
216
213
|
|
@@ -507,8 +504,8 @@ def make_jaxpr(
|
|
507
504
|
return_shape: bool = False,
|
508
505
|
abstracted_axes: Optional[Any] = None,
|
509
506
|
state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
|
510
|
-
) -> Callable[..., (Tuple[
|
511
|
-
Tuple[
|
507
|
+
) -> Callable[..., (Tuple[ClosedJaxpr, Tuple[State, ...]] |
|
508
|
+
Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])]:
|
512
509
|
"""
|
513
510
|
Creates a function that produces its jaxpr given example args.
|
514
511
|
|
@@ -754,12 +751,12 @@ def _make_jaxpr(
|
|
754
751
|
debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
|
755
752
|
with ExitStack() as stack:
|
756
753
|
if axis_env is not None:
|
757
|
-
stack.enter_context(
|
754
|
+
stack.enter_context(extend_axis_env_nd(axis_env))
|
758
755
|
if jax.__version_info__ < (0, 5, 0):
|
759
756
|
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
|
760
757
|
else:
|
761
758
|
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
|
762
|
-
closed_jaxpr =
|
759
|
+
closed_jaxpr = ClosedJaxpr(jaxpr, consts)
|
763
760
|
if return_shape:
|
764
761
|
out_avals, _ = jax.util.unzip2(out_type)
|
765
762
|
out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
|
@@ -15,17 +15,14 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
+
import unittest
|
19
|
+
|
18
20
|
import jax
|
19
21
|
import jax.numpy as jnp
|
20
22
|
import pytest
|
21
|
-
import unittest
|
22
23
|
|
23
24
|
import brainstate as bst
|
24
|
-
|
25
|
-
if jax.__version_info__ < (0, 4, 38):
|
26
|
-
from jax.core import jaxpr_as_fun
|
27
|
-
else:
|
28
|
-
from jax.extend.core import jaxpr_as_fun
|
25
|
+
from brainstate._compatible_import import jaxpr_as_fun
|
29
26
|
|
30
27
|
|
31
28
|
class TestMakeJaxpr(unittest.TestCase):
|
@@ -17,9 +17,10 @@ from __future__ import annotations
|
|
17
17
|
|
18
18
|
import copy
|
19
19
|
import importlib.util
|
20
|
-
import jax
|
21
20
|
from typing import Optional, Callable, Any, Tuple, Dict
|
22
21
|
|
22
|
+
import jax
|
23
|
+
|
23
24
|
tqdm_installed = importlib.util.find_spec('tqdm') is not None
|
24
25
|
|
25
26
|
__all__ = [
|
brainstate/compile/_unvmap.py
CHANGED
@@ -20,13 +20,9 @@ import jax.interpreters.batching as batching
|
|
20
20
|
import jax.interpreters.mlir as mlir
|
21
21
|
import jax.numpy as jnp
|
22
22
|
|
23
|
+
from brainstate._compatible_import import Primitive
|
23
24
|
from brainstate._utils import set_module_as
|
24
25
|
|
25
|
-
if jax.__version_info__ < (0, 4, 38):
|
26
|
-
from jax.core import Primitive
|
27
|
-
else:
|
28
|
-
from jax.extend.core import Primitive
|
29
|
-
|
30
26
|
__all__ = [
|
31
27
|
"unvmap",
|
32
28
|
]
|
brainstate/environ.py
CHANGED
@@ -17,18 +17,18 @@
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
-
from collections import defaultdict
|
21
|
-
|
22
20
|
import contextlib
|
23
21
|
import dataclasses
|
24
22
|
import functools
|
25
|
-
import numpy as np
|
26
23
|
import os
|
27
24
|
import re
|
28
25
|
import threading
|
26
|
+
from collections import defaultdict
|
27
|
+
from typing import Any, Callable, Dict, Hashable
|
28
|
+
|
29
|
+
import numpy as np
|
29
30
|
from jax import config, devices, numpy as jnp
|
30
31
|
from jax.typing import DTypeLike
|
31
|
-
from typing import Any, Callable, Dict, Hashable
|
32
32
|
|
33
33
|
from .mixin import Mode
|
34
34
|
|
brainstate/environ_test.py
CHANGED
@@ -20,10 +20,11 @@ Shared neural network activations and other functions.
|
|
20
20
|
|
21
21
|
from __future__ import annotations
|
22
22
|
|
23
|
+
from typing import Any, Union, Sequence
|
24
|
+
|
23
25
|
import brainunit as u
|
24
26
|
import jax
|
25
27
|
from jax.scipy.special import logsumexp
|
26
|
-
from typing import Any, Union, Sequence
|
27
28
|
|
28
29
|
from brainstate import random
|
29
30
|
from brainstate.typing import ArrayLike
|
@@ -16,12 +16,12 @@
|
|
16
16
|
"""Tests for nn module."""
|
17
17
|
|
18
18
|
import itertools
|
19
|
+
from functools import partial
|
19
20
|
|
20
21
|
import jax
|
21
22
|
import jax.numpy as jnp
|
22
23
|
import scipy.stats
|
23
24
|
from absl.testing import parameterized
|
24
|
-
from functools import partial
|
25
25
|
from jax._src import test_util as jtu
|
26
26
|
from jax.test_util import check_grads
|
27
27
|
|
@@ -15,9 +15,10 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
+
from typing import Optional, Union
|
19
|
+
|
18
20
|
import brainunit as u
|
19
21
|
import jax
|
20
|
-
from typing import Optional, Union
|
21
22
|
|
22
23
|
from brainstate._utils import set_module_as
|
23
24
|
from brainstate.typing import ArrayLike
|
brainstate/functional/_others.py
CHANGED
@@ -18,10 +18,11 @@
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
20
|
import dataclasses
|
21
|
-
import jax
|
22
|
-
import numpy as np
|
23
21
|
from typing import (Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
|
24
22
|
Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional, overload)
|
23
|
+
|
24
|
+
import jax
|
25
|
+
import numpy as np
|
25
26
|
from typing_extensions import TypeGuard, Unpack
|
26
27
|
|
27
28
|
from brainstate._state import State, TreefyState
|
@@ -15,13 +15,14 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
-
import jax
|
19
|
-
import jax.numpy as jnp
|
20
18
|
import unittest
|
21
|
-
from absl.testing import absltest, parameterized
|
22
19
|
from collections.abc import Callable
|
23
20
|
from threading import Thread
|
24
21
|
|
22
|
+
import jax
|
23
|
+
import jax.numpy as jnp
|
24
|
+
from absl.testing import absltest, parameterized
|
25
|
+
|
25
26
|
import brainstate as bst
|
26
27
|
|
27
28
|
|
brainstate/init/_base.py
CHANGED
brainstate/init/_generic.py
CHANGED
@@ -16,10 +16,11 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
+
from typing import Union, Callable, Optional, Sequence
|
20
|
+
|
19
21
|
import brainunit as bu
|
20
22
|
import jax
|
21
23
|
import numpy as np
|
22
|
-
from typing import Union, Callable, Optional, Sequence
|
23
24
|
|
24
25
|
from brainstate._state import State
|
25
26
|
from brainstate._utils import set_module_as
|
brainstate/nn/__init__.py
CHANGED
@@ -25,6 +25,8 @@ from ._dynamics import *
|
|
25
25
|
from ._dynamics import __all__ as dynamics_all
|
26
26
|
from ._elementwise import *
|
27
27
|
from ._elementwise import __all__ as elementwise_all
|
28
|
+
from ._event import *
|
29
|
+
from ._event import __all__ as _event_all
|
28
30
|
from ._exp_euler import *
|
29
31
|
from ._exp_euler import __all__ as exp_euler_all
|
30
32
|
from ._interaction import *
|
@@ -45,6 +47,7 @@ __all__ = (
|
|
45
47
|
+ exp_euler_all
|
46
48
|
+ interaction_all
|
47
49
|
+ utils_all
|
50
|
+
+ _event_all
|
48
51
|
)
|
49
52
|
|
50
53
|
del (
|
@@ -57,4 +60,5 @@ del (
|
|
57
60
|
exp_euler_all,
|
58
61
|
interaction_all,
|
59
62
|
utils_all,
|
63
|
+
_event_all,
|
60
64
|
)
|
brainstate/nn/_collective_ops.py
CHANGED
@@ -19,6 +19,7 @@ from collections import namedtuple
|
|
19
19
|
from typing import Callable, TypeVar, Tuple, Any, Dict
|
20
20
|
|
21
21
|
import jax
|
22
|
+
from typing import Callable, TypeVar, Tuple, Any, Dict
|
22
23
|
|
23
24
|
from brainstate._state import catch_new_states
|
24
25
|
from brainstate._utils import set_module_as
|
brainstate/nn/_common.py
CHANGED
@@ -20,8 +20,6 @@ from ._dynamics_synapse import *
|
|
20
20
|
from ._dynamics_synapse import __all__ as dyn_synapse_all
|
21
21
|
from ._inputs import *
|
22
22
|
from ._inputs import __all__ as inputs_all
|
23
|
-
from ._projection_alignpost import *
|
24
|
-
from ._projection_alignpost import __all__ as alignpost_all
|
25
23
|
from ._rate_rnns import *
|
26
24
|
from ._rate_rnns import __all__ as rate_rnns
|
27
25
|
from ._readout import *
|
@@ -31,7 +29,6 @@ __all__ = (
|
|
31
29
|
dyn_neuron_all
|
32
30
|
+ dyn_synapse_all
|
33
31
|
+ inputs_all
|
34
|
-
+ alignpost_all
|
35
32
|
+ rate_rnns
|
36
33
|
+ readout_all
|
37
34
|
)
|
@@ -41,6 +38,5 @@ del (
|
|
41
38
|
dyn_synapse_all,
|
42
39
|
inputs_all,
|
43
40
|
readout_all,
|
44
|
-
alignpost_all,
|
45
41
|
rate_rnns,
|
46
42
|
)
|