brainstate 0.1.0.post20250420__py2.py3-none-any.whl → 0.1.0.post20250422__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. brainstate/_compatible_import.py +15 -0
  2. brainstate/_state.py +5 -4
  3. brainstate/_state_test.py +2 -1
  4. brainstate/augment/_autograd_test.py +3 -2
  5. brainstate/augment/_eval_shape.py +2 -1
  6. brainstate/augment/_mapping.py +0 -1
  7. brainstate/augment/_mapping_test.py +1 -0
  8. brainstate/compile/_ad_checkpoint.py +2 -1
  9. brainstate/compile/_conditions.py +3 -3
  10. brainstate/compile/_conditions_test.py +2 -1
  11. brainstate/compile/_error_if.py +2 -1
  12. brainstate/compile/_error_if_test.py +2 -1
  13. brainstate/compile/_jit.py +3 -2
  14. brainstate/compile/_jit_test.py +2 -1
  15. brainstate/compile/_loop_collect_return.py +2 -2
  16. brainstate/compile/_loop_collect_return_test.py +2 -1
  17. brainstate/compile/_loop_no_collection.py +1 -1
  18. brainstate/compile/_make_jaxpr.py +2 -2
  19. brainstate/compile/_make_jaxpr_test.py +2 -1
  20. brainstate/compile/_progress_bar.py +2 -1
  21. brainstate/compile/_unvmap.py +1 -2
  22. brainstate/environ.py +4 -4
  23. brainstate/environ_test.py +2 -1
  24. brainstate/functional/_activations.py +2 -1
  25. brainstate/functional/_activations_test.py +1 -1
  26. brainstate/functional/_normalization.py +2 -1
  27. brainstate/functional/_others.py +2 -1
  28. brainstate/graph/_graph_operation.py +3 -2
  29. brainstate/graph/_graph_operation_test.py +4 -3
  30. brainstate/init/_base.py +2 -1
  31. brainstate/init/_generic.py +2 -1
  32. brainstate/nn/__init__.py +4 -0
  33. brainstate/nn/_collective_ops.py +1 -0
  34. brainstate/nn/_collective_ops_test.py +0 -4
  35. brainstate/nn/_common.py +0 -1
  36. brainstate/nn/_dyn_impl/__init__.py +0 -4
  37. brainstate/nn/_dyn_impl/_dynamics_neuron.py +431 -13
  38. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +2 -1
  39. brainstate/nn/_dyn_impl/_dynamics_synapse.py +405 -103
  40. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +2 -1
  41. brainstate/nn/_dyn_impl/_inputs.py +236 -29
  42. brainstate/nn/_dyn_impl/_rate_rnns.py +238 -82
  43. brainstate/nn/_dyn_impl/_rate_rnns_test.py +2 -1
  44. brainstate/nn/_dyn_impl/_readout.py +91 -8
  45. brainstate/nn/_dyn_impl/_readout_test.py +2 -1
  46. brainstate/nn/_dynamics/_dynamics_base.py +676 -96
  47. brainstate/nn/_dynamics/_dynamics_base_test.py +2 -1
  48. brainstate/nn/_dynamics/_projection_base.py +29 -30
  49. brainstate/nn/_dynamics/_state_delay.py +3 -3
  50. brainstate/nn/_dynamics/_synouts_test.py +2 -1
  51. brainstate/nn/_elementwise/_dropout.py +3 -2
  52. brainstate/nn/_elementwise/_dropout_test.py +2 -1
  53. brainstate/nn/_elementwise/_elementwise.py +2 -1
  54. brainstate/nn/{_dyn_impl/_projection_alignpost.py → _event/__init__.py} +8 -7
  55. brainstate/nn/_event/_fixedprob_mv.py +169 -0
  56. brainstate/nn/_event/_fixedprob_mv_test.py +115 -0
  57. brainstate/nn/_event/_linear_mv.py +85 -0
  58. brainstate/nn/_event/_linear_mv_test.py +121 -0
  59. brainstate/nn/_exp_euler.py +2 -1
  60. brainstate/nn/_exp_euler_test.py +2 -1
  61. brainstate/nn/_interaction/_conv.py +2 -1
  62. brainstate/nn/_interaction/_linear.py +2 -1
  63. brainstate/nn/_interaction/_linear_test.py +2 -1
  64. brainstate/nn/_interaction/_normalizations.py +3 -2
  65. brainstate/nn/_interaction/_poolings.py +4 -3
  66. brainstate/nn/_module_test.py +2 -1
  67. brainstate/nn/metrics.py +4 -3
  68. brainstate/optim/_lr_scheduler.py +2 -1
  69. brainstate/optim/_lr_scheduler_test.py +2 -1
  70. brainstate/optim/_optax_optimizer_test.py +2 -1
  71. brainstate/optim/_sgd_optimizer.py +3 -2
  72. brainstate/random/_rand_funs.py +2 -1
  73. brainstate/random/_rand_funs_test.py +3 -2
  74. brainstate/random/_rand_seed.py +3 -2
  75. brainstate/random/_rand_seed_test.py +2 -1
  76. brainstate/random/_rand_state.py +4 -3
  77. brainstate/surrogate.py +1 -2
  78. brainstate/typing.py +4 -4
  79. brainstate/util/_caller.py +2 -1
  80. brainstate/util/_others.py +4 -4
  81. brainstate/util/_pretty_pytree.py +1 -1
  82. brainstate/util/_pretty_pytree_test.py +2 -1
  83. brainstate/util/_pretty_table.py +43 -43
  84. brainstate/util/_struct.py +2 -1
  85. brainstate/util/filter.py +0 -1
  86. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/METADATA +3 -3
  87. brainstate-0.1.0.post20250422.dist-info/RECORD +133 -0
  88. brainstate-0.1.0.post20250420.dist-info/RECORD +0 -129
  89. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/LICENSE +0 -0
  90. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/WHEEL +0 -0
  91. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
18
 
19
+ import importlib.util
19
20
  from contextlib import contextmanager
20
21
  from typing import Iterable, Hashable
21
22
 
@@ -29,8 +30,11 @@ __all__ = [
29
30
  'get_aval',
30
31
  'Tracer',
31
32
  'to_concrete_aval',
33
+ 'brainevent',
32
34
  ]
33
35
 
36
+ brainevent_installed = importlib.util.find_spec('brainevent') is not None
37
+
34
38
  from jax.core import get_aval, Tracer
35
39
 
36
40
  if jax.__version_info__ < (0, 4, 38):
@@ -56,3 +60,14 @@ def to_concrete_aval(aval):
56
60
  return aval.to_concrete_value()
57
61
  return aval
58
62
 
63
+
64
+ if brainevent_installed:
65
+ import brainevent
66
+ else:
67
+
68
+ class BrainEvent:
69
+ def __getattr__(self, item):
70
+ raise ImportError('brainevent is not installed, please install brainevent first.')
71
+
72
+
73
+ brainevent = BrainEvent()
brainstate/_state.py CHANGED
@@ -17,12 +17,8 @@ from __future__ import annotations
17
17
 
18
18
  import contextlib
19
19
  import dataclasses
20
- import jax
21
- import numpy as np
22
20
  import threading
23
21
  from functools import wraps, partial
24
- from jax.api_util import shaped_abstractify
25
- from jax.extend import source_info_util
26
22
  from typing import (
27
23
  Any,
28
24
  Union,
@@ -39,6 +35,11 @@ from typing import (
39
35
  Generator,
40
36
  )
41
37
 
38
+ import jax
39
+ import numpy as np
40
+ from jax.api_util import shaped_abstractify
41
+ from jax.extend import source_info_util
42
+
42
43
  from brainstate.typing import ArrayLike, PyTree, Missing, Filter
43
44
  from brainstate.util import DictManager, PrettyObject
44
45
  from brainstate.util.filter import Nothing
brainstate/_state_test.py CHANGED
@@ -14,9 +14,10 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- import jax.numpy as jnp
18
17
  import unittest
19
18
 
19
+ import jax.numpy as jnp
20
+
20
21
  import brainstate as bst
21
22
 
22
23
 
@@ -16,12 +16,13 @@
16
16
  # -*- coding: utf-8 -*-
17
17
  from __future__ import annotations
18
18
 
19
+ import unittest
20
+ from pprint import pprint
21
+
19
22
  import brainunit as u
20
23
  import jax
21
24
  import jax.numpy as jnp
22
25
  import pytest
23
- import unittest
24
- from pprint import pprint
25
26
 
26
27
  import brainstate as bst
27
28
  from brainstate.augment._autograd import _jacfwd
@@ -16,9 +16,10 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import functools
19
- import jax
20
19
  from typing import Any, TypeVar, Callable, Sequence, Union
21
20
 
21
+ import jax
22
+
22
23
  from brainstate import random
23
24
  from brainstate.graph import Node, flatten, unflatten
24
25
  from ._random import restore_rngs
@@ -955,7 +955,6 @@ def _vmap_new_states_transform(
955
955
  if isinstance(axis_size, int) and axis_size <= 0:
956
956
  raise ValueError(f"axis_size must be greater than 0, got {axis_size}.")
957
957
 
958
-
959
958
  @vmap(
960
959
  in_axes=in_axes,
961
960
  out_axes=out_axes,
@@ -425,6 +425,7 @@ class TestVMAPNewStatesEdgeCases(unittest.TestCase):
425
425
 
426
426
  def test_incompatible_shapes(self):
427
427
  foo = brainstate.nn.LIF(3)
428
+
428
429
  # Simulate an incompatible shapes scenario:
429
430
  # We intentionally assign a state with a different shape than expected.
430
431
  @bst.augment.vmap_new_states(state_tag='new1', axis_size=5)
@@ -16,9 +16,10 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import functools
19
- import jax
20
19
  from typing import Callable, Tuple, Union
21
20
 
21
+ import jax
22
+
22
23
  from brainstate.typing import Missing
23
24
  from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
24
25
  from ._util import write_back_state_values
@@ -15,17 +15,17 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
+ from collections.abc import Callable, Sequence
19
+
18
20
  import jax
19
21
  import jax.numpy as jnp
20
22
  import numpy as np
21
- from collections.abc import Callable, Sequence
22
23
 
24
+ from brainstate._compatible_import import to_concrete_aval, Tracer
23
25
  from brainstate._utils import set_module_as
24
26
  from ._error_if import jit_error_if
25
27
  from ._make_jaxpr import StatefulFunction
26
28
  from ._util import wrap_single_fun_in_multi_branches, write_back_state_values
27
- from brainstate._compatible_import import to_concrete_aval, Tracer
28
-
29
29
 
30
30
  __all__ = [
31
31
  'cond', 'switch', 'ifelse',
@@ -14,9 +14,10 @@
14
14
  # ==============================================================================
15
15
  from __future__ import annotations
16
16
 
17
+ import unittest
18
+
17
19
  import jax
18
20
  import jax.numpy as jnp
19
- import unittest
20
21
 
21
22
  import brainstate as bst
22
23
 
@@ -16,10 +16,11 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import functools
19
- import jax
20
19
  from functools import partial
21
20
  from typing import Callable, Union
22
21
 
22
+ import jax
23
+
23
24
  from brainstate._utils import set_module_as
24
25
  from ._unvmap import unvmap
25
26
 
@@ -15,10 +15,11 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
+ import unittest
19
+
18
20
  import jax
19
21
  import jax.numpy as jnp
20
22
  import jaxlib.xla_extension
21
- import unittest
22
23
 
23
24
  import brainstate as bst
24
25
 
@@ -16,11 +16,12 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import functools
19
- import jax
20
19
  from collections.abc import Iterable, Sequence
20
+ from typing import (Any, Callable, Union)
21
+
22
+ import jax
21
23
  from jax._src import sharding_impls
22
24
  from jax.lib import xla_client as xc
23
- from typing import (Any, Callable, Union)
24
25
 
25
26
  from brainstate._utils import set_module_as
26
27
  from brainstate.typing import Missing
@@ -15,9 +15,10 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import jax.numpy as jnp
19
18
  import unittest
20
19
 
20
+ import jax.numpy as jnp
21
+
21
22
  import brainstate as bst
22
23
 
23
24
 
@@ -16,11 +16,11 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import math
19
+ from functools import wraps
20
+ from typing import Callable, Optional, TypeVar, Tuple, Any
19
21
 
20
22
  import jax
21
23
  import jax.numpy as jnp
22
- from functools import wraps
23
- from typing import Callable, Optional, TypeVar, Tuple, Any
24
24
 
25
25
  from brainstate._utils import set_module_as
26
26
  from ._make_jaxpr import StatefulFunction
@@ -15,9 +15,10 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
+ import unittest
19
+
18
20
  import jax.numpy as jnp
19
21
  import numpy as np
20
- import unittest
21
22
 
22
23
  import brainstate as bst
23
24
 
@@ -16,9 +16,9 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import math
19
+ from typing import Any, Callable, TypeVar
19
20
 
20
21
  import jax
21
- from typing import Any, Callable, TypeVar
22
22
 
23
23
  from brainstate._utils import set_module_as
24
24
  from ._loop_collect_return import _bounded_while_loop
@@ -57,7 +57,7 @@ import functools
57
57
  import inspect
58
58
  import operator
59
59
  from collections.abc import Hashable, Iterable, Sequence
60
- from contextlib import ExitStack, contextmanager
60
+ from contextlib import ExitStack
61
61
  from typing import Any, Callable, Tuple, Union, Dict, Optional
62
62
 
63
63
  import jax
@@ -69,11 +69,11 @@ from jax.extend.linear_util import transformation_with_aux, wrap_init
69
69
  from jax.interpreters import partial_eval as pe
70
70
  from jax.util import wraps
71
71
 
72
+ from brainstate._compatible_import import ClosedJaxpr, extend_axis_env_nd
72
73
  from brainstate._state import State, StateTraceStack
73
74
  from brainstate._utils import set_module_as
74
75
  from brainstate.typing import PyTree
75
76
  from brainstate.util import PrettyObject
76
- from brainstate._compatible_import import ClosedJaxpr, extend_axis_env_nd
77
77
 
78
78
  AxisName = Hashable
79
79
 
@@ -15,10 +15,11 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
+ import unittest
19
+
18
20
  import jax
19
21
  import jax.numpy as jnp
20
22
  import pytest
21
- import unittest
22
23
 
23
24
  import brainstate as bst
24
25
  from brainstate._compatible_import import jaxpr_as_fun
@@ -17,9 +17,10 @@ from __future__ import annotations
17
17
 
18
18
  import copy
19
19
  import importlib.util
20
- import jax
21
20
  from typing import Optional, Callable, Any, Tuple, Dict
22
21
 
22
+ import jax
23
+
23
24
  tqdm_installed = importlib.util.find_spec('tqdm') is not None
24
25
 
25
26
  __all__ = [
@@ -20,9 +20,8 @@ import jax.interpreters.batching as batching
20
20
  import jax.interpreters.mlir as mlir
21
21
  import jax.numpy as jnp
22
22
 
23
- from brainstate._utils import set_module_as
24
23
  from brainstate._compatible_import import Primitive
25
-
24
+ from brainstate._utils import set_module_as
26
25
 
27
26
  __all__ = [
28
27
  "unvmap",
brainstate/environ.py CHANGED
@@ -17,18 +17,18 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- from collections import defaultdict
21
-
22
20
  import contextlib
23
21
  import dataclasses
24
22
  import functools
25
- import numpy as np
26
23
  import os
27
24
  import re
28
25
  import threading
26
+ from collections import defaultdict
27
+ from typing import Any, Callable, Dict, Hashable
28
+
29
+ import numpy as np
29
30
  from jax import config, devices, numpy as jnp
30
31
  from jax.typing import DTypeLike
31
- from typing import Any, Callable, Dict, Hashable
32
32
 
33
33
  from .mixin import Mode
34
34
 
@@ -14,9 +14,10 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- import jax.numpy as jnp
18
17
  import unittest
19
18
 
19
+ import jax.numpy as jnp
20
+
20
21
  import brainstate as bst
21
22
 
22
23
 
@@ -20,10 +20,11 @@ Shared neural network activations and other functions.
20
20
 
21
21
  from __future__ import annotations
22
22
 
23
+ from typing import Any, Union, Sequence
24
+
23
25
  import brainunit as u
24
26
  import jax
25
27
  from jax.scipy.special import logsumexp
26
- from typing import Any, Union, Sequence
27
28
 
28
29
  from brainstate import random
29
30
  from brainstate.typing import ArrayLike
@@ -16,12 +16,12 @@
16
16
  """Tests for nn module."""
17
17
 
18
18
  import itertools
19
+ from functools import partial
19
20
 
20
21
  import jax
21
22
  import jax.numpy as jnp
22
23
  import scipy.stats
23
24
  from absl.testing import parameterized
24
- from functools import partial
25
25
  from jax._src import test_util as jtu
26
26
  from jax.test_util import check_grads
27
27
 
@@ -15,9 +15,10 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
+ from typing import Optional, Union
19
+
18
20
  import brainunit as u
19
21
  import jax
20
- from typing import Optional, Union
21
22
 
22
23
  from brainstate._utils import set_module_as
23
24
  from brainstate.typing import ArrayLike
@@ -15,9 +15,10 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
+ from functools import partial
19
+
18
20
  import jax
19
21
  import jax.numpy as jnp
20
- from functools import partial
21
22
 
22
23
  from brainstate.typing import PyTree
23
24
 
@@ -18,10 +18,11 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  import dataclasses
21
- import jax
22
- import numpy as np
23
21
  from typing import (Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
24
22
  Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional, overload)
23
+
24
+ import jax
25
+ import numpy as np
25
26
  from typing_extensions import TypeGuard, Unpack
26
27
 
27
28
  from brainstate._state import State, TreefyState
@@ -15,13 +15,14 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import jax
19
- import jax.numpy as jnp
20
18
  import unittest
21
- from absl.testing import absltest, parameterized
22
19
  from collections.abc import Callable
23
20
  from threading import Thread
24
21
 
22
+ import jax
23
+ import jax.numpy as jnp
24
+ from absl.testing import absltest, parameterized
25
+
25
26
  import brainstate as bst
26
27
 
27
28
 
brainstate/init/_base.py CHANGED
@@ -15,9 +15,10 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import numpy as np
19
18
  from typing import Optional, Tuple
20
19
 
20
+ import numpy as np
21
+
21
22
  from brainstate.util import PrettyRepr, PrettyType, PrettyAttr
22
23
 
23
24
  __all__ = ['Initializer', 'to_size']
@@ -16,10 +16,11 @@
16
16
  # -*- coding: utf-8 -*-
17
17
  from __future__ import annotations
18
18
 
19
+ from typing import Union, Callable, Optional, Sequence
20
+
19
21
  import brainunit as bu
20
22
  import jax
21
23
  import numpy as np
22
- from typing import Union, Callable, Optional, Sequence
23
24
 
24
25
  from brainstate._state import State
25
26
  from brainstate._utils import set_module_as
brainstate/nn/__init__.py CHANGED
@@ -25,6 +25,8 @@ from ._dynamics import *
25
25
  from ._dynamics import __all__ as dynamics_all
26
26
  from ._elementwise import *
27
27
  from ._elementwise import __all__ as elementwise_all
28
+ from ._event import *
29
+ from ._event import __all__ as _event_all
28
30
  from ._exp_euler import *
29
31
  from ._exp_euler import __all__ as exp_euler_all
30
32
  from ._interaction import *
@@ -45,6 +47,7 @@ __all__ = (
45
47
  + exp_euler_all
46
48
  + interaction_all
47
49
  + utils_all
50
+ + _event_all
48
51
  )
49
52
 
50
53
  del (
@@ -57,4 +60,5 @@ del (
57
60
  exp_euler_all,
58
61
  interaction_all,
59
62
  utils_all,
63
+ _event_all,
60
64
  )
@@ -19,6 +19,7 @@ from collections import namedtuple
19
19
  from typing import Callable, TypeVar, Tuple, Any, Dict
20
20
 
21
21
  import jax
22
+ from typing import Callable, TypeVar, Tuple, Any, Dict
22
23
 
23
24
  from brainstate._state import catch_new_states
24
25
  from brainstate._utils import set_module_as
@@ -41,7 +41,3 @@ class Test_init_all_states:
41
41
  gru = bst.nn.GRUCell(1, 2)
42
42
  bst.nn.init_all_states(gru, batch_size=10)
43
43
  print(gru)
44
-
45
-
46
-
47
-
brainstate/nn/_common.py CHANGED
@@ -18,7 +18,6 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  from collections import defaultdict
21
-
22
21
  from typing import Any, Sequence, Hashable, Dict
23
22
 
24
23
  from brainstate import environ
@@ -20,8 +20,6 @@ from ._dynamics_synapse import *
20
20
  from ._dynamics_synapse import __all__ as dyn_synapse_all
21
21
  from ._inputs import *
22
22
  from ._inputs import __all__ as inputs_all
23
- from ._projection_alignpost import *
24
- from ._projection_alignpost import __all__ as alignpost_all
25
23
  from ._rate_rnns import *
26
24
  from ._rate_rnns import __all__ as rate_rnns
27
25
  from ._readout import *
@@ -31,7 +29,6 @@ __all__ = (
31
29
  dyn_neuron_all
32
30
  + dyn_synapse_all
33
31
  + inputs_all
34
- + alignpost_all
35
32
  + rate_rnns
36
33
  + readout_all
37
34
  )
@@ -41,6 +38,5 @@ del (
41
38
  dyn_synapse_all,
42
39
  inputs_all,
43
40
  readout_all,
44
- alignpost_all,
45
41
  rate_rnns,
46
42
  )