brainstate 0.1.0.post20250420__py2.py3-none-any.whl → 0.1.0.post20250423__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91):
  1. brainstate/_compatible_import.py +15 -0
  2. brainstate/_state.py +5 -4
  3. brainstate/_state_test.py +2 -1
  4. brainstate/augment/_autograd_test.py +3 -2
  5. brainstate/augment/_eval_shape.py +2 -1
  6. brainstate/augment/_mapping.py +0 -1
  7. brainstate/augment/_mapping_test.py +1 -0
  8. brainstate/compile/_ad_checkpoint.py +2 -1
  9. brainstate/compile/_conditions.py +3 -3
  10. brainstate/compile/_conditions_test.py +2 -1
  11. brainstate/compile/_error_if.py +2 -1
  12. brainstate/compile/_error_if_test.py +2 -1
  13. brainstate/compile/_jit.py +3 -2
  14. brainstate/compile/_jit_test.py +2 -1
  15. brainstate/compile/_loop_collect_return.py +2 -2
  16. brainstate/compile/_loop_collect_return_test.py +2 -1
  17. brainstate/compile/_loop_no_collection.py +1 -1
  18. brainstate/compile/_make_jaxpr.py +2 -2
  19. brainstate/compile/_make_jaxpr_test.py +2 -1
  20. brainstate/compile/_progress_bar.py +2 -1
  21. brainstate/compile/_unvmap.py +1 -2
  22. brainstate/environ.py +4 -4
  23. brainstate/environ_test.py +2 -1
  24. brainstate/functional/_activations.py +2 -1
  25. brainstate/functional/_activations_test.py +1 -1
  26. brainstate/functional/_normalization.py +2 -1
  27. brainstate/functional/_others.py +2 -1
  28. brainstate/graph/_graph_operation.py +3 -2
  29. brainstate/graph/_graph_operation_test.py +4 -3
  30. brainstate/init/_base.py +2 -1
  31. brainstate/init/_generic.py +2 -1
  32. brainstate/nn/__init__.py +4 -0
  33. brainstate/nn/_collective_ops.py +1 -0
  34. brainstate/nn/_collective_ops_test.py +0 -4
  35. brainstate/nn/_common.py +0 -1
  36. brainstate/nn/_dyn_impl/__init__.py +0 -4
  37. brainstate/nn/_dyn_impl/_dynamics_neuron.py +431 -13
  38. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +2 -1
  39. brainstate/nn/_dyn_impl/_dynamics_synapse.py +405 -103
  40. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +2 -1
  41. brainstate/nn/_dyn_impl/_inputs.py +236 -29
  42. brainstate/nn/_dyn_impl/_rate_rnns.py +238 -82
  43. brainstate/nn/_dyn_impl/_rate_rnns_test.py +2 -1
  44. brainstate/nn/_dyn_impl/_readout.py +91 -8
  45. brainstate/nn/_dyn_impl/_readout_test.py +2 -1
  46. brainstate/nn/_dynamics/_dynamics_base.py +676 -96
  47. brainstate/nn/_dynamics/_dynamics_base_test.py +2 -1
  48. brainstate/nn/_dynamics/_projection_base.py +29 -30
  49. brainstate/nn/_dynamics/_state_delay.py +3 -3
  50. brainstate/nn/_dynamics/_synouts_test.py +2 -1
  51. brainstate/nn/_elementwise/_dropout.py +3 -2
  52. brainstate/nn/_elementwise/_dropout_test.py +2 -1
  53. brainstate/nn/_elementwise/_elementwise.py +2 -1
  54. brainstate/nn/{_dyn_impl/_projection_alignpost.py → _event/__init__.py} +8 -7
  55. brainstate/nn/_event/_fixedprob_mv.py +169 -0
  56. brainstate/nn/_event/_fixedprob_mv_test.py +115 -0
  57. brainstate/nn/_event/_linear_mv.py +85 -0
  58. brainstate/nn/_event/_linear_mv_test.py +121 -0
  59. brainstate/nn/_exp_euler.py +2 -1
  60. brainstate/nn/_exp_euler_test.py +2 -1
  61. brainstate/nn/_interaction/_conv.py +2 -1
  62. brainstate/nn/_interaction/_linear.py +2 -1
  63. brainstate/nn/_interaction/_linear_test.py +2 -1
  64. brainstate/nn/_interaction/_normalizations.py +3 -2
  65. brainstate/nn/_interaction/_poolings.py +4 -3
  66. brainstate/nn/_module_test.py +2 -1
  67. brainstate/nn/metrics.py +4 -3
  68. brainstate/optim/_lr_scheduler.py +2 -1
  69. brainstate/optim/_lr_scheduler_test.py +2 -1
  70. brainstate/optim/_optax_optimizer_test.py +2 -1
  71. brainstate/optim/_sgd_optimizer.py +3 -2
  72. brainstate/random/_rand_funs.py +2 -1
  73. brainstate/random/_rand_funs_test.py +3 -2
  74. brainstate/random/_rand_seed.py +3 -2
  75. brainstate/random/_rand_seed_test.py +2 -1
  76. brainstate/random/_rand_state.py +4 -3
  77. brainstate/surrogate.py +1 -2
  78. brainstate/typing.py +4 -4
  79. brainstate/util/_caller.py +2 -1
  80. brainstate/util/_others.py +4 -4
  81. brainstate/util/_pretty_pytree.py +1 -1
  82. brainstate/util/_pretty_pytree_test.py +2 -1
  83. brainstate/util/_pretty_table.py +43 -43
  84. brainstate/util/_struct.py +2 -1
  85. brainstate/util/filter.py +0 -1
  86. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250423.dist-info}/METADATA +3 -3
  87. brainstate-0.1.0.post20250423.dist-info/RECORD +133 -0
  88. brainstate-0.1.0.post20250420.dist-info/RECORD +0 -129
  89. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250423.dist-info}/LICENSE +0 -0
  90. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250423.dist-info}/WHEEL +0 -0
  91. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250423.dist-info}/top_level.txt +0 -0
@@ -16,9 +16,10 @@
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
+ from typing import Callable
20
+
19
21
  import brainunit as u
20
22
  import jax.numpy as jnp
21
- from typing import Callable
22
23
 
23
24
  from brainstate import environ, random
24
25
  from brainstate.augment import vector_grad
@@ -15,9 +15,10 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import brainunit as u
19
18
  import unittest
20
19
 
20
+ import brainunit as u
21
+
21
22
  import brainstate as bst
22
23
 
23
24
 
@@ -18,9 +18,10 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  import collections.abc
21
+ from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
22
+
21
23
  import jax
22
24
  import jax.numpy as jnp
23
- from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
24
25
 
25
26
  from brainstate import init, functional
26
27
  from brainstate._state import ParamState
@@ -17,9 +17,10 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
+ from typing import Callable, Union, Optional
21
+
20
22
  import brainunit as u
21
23
  import jax.numpy as jnp
22
- from typing import Callable, Union, Optional
23
24
 
24
25
  from brainstate import init, functional
25
26
  from brainstate._state import ParamState
@@ -16,8 +16,9 @@
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
- import brainunit as u
20
19
  import unittest
20
+
21
+ import brainunit as u
21
22
  from absl.testing import parameterized
22
23
 
23
24
  import brainstate as bst
@@ -17,9 +17,10 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
+ from typing import Callable, Union, Sequence, Optional, Any
21
+
20
22
  import jax
21
23
  import jax.numpy as jnp
22
- from typing import Callable, Union, Sequence, Optional, Any
23
24
 
24
25
  from brainstate import environ, init
25
26
  from brainstate._state import ParamState, BatchState
@@ -230,7 +231,7 @@ def _normalize(
230
231
  y = y * mul
231
232
  if weights is not None:
232
233
  y = weights.execute(y)
233
- dtype = canonicalize_dtype(x, *jax.tree.leaves(weights.value), dtype=dtype)
234
+ dtype = canonicalize_dtype(x, *jax.tree.leaves(weights.value), dtype=dtype)
234
235
  else:
235
236
  assert var is None, 'mean and val must be both None or not None.'
236
237
  assert weights is None, 'scale and bias are not supported without mean and val'
@@ -17,13 +17,14 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- import brainunit as u
21
20
  import functools
21
+ from typing import Sequence, Optional
22
+ from typing import Union, Tuple, Callable, List
23
+
24
+ import brainunit as u
22
25
  import jax
23
26
  import jax.numpy as jnp
24
27
  import numpy as np
25
- from typing import Sequence, Optional
26
- from typing import Union, Tuple, Callable, List
27
28
 
28
29
  from brainstate import environ
29
30
  from brainstate.nn._module import Module
@@ -15,9 +15,10 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
+ import unittest
19
+
18
20
  import jax.numpy as jnp
19
21
  import jaxlib.xla_extension
20
- import unittest
21
22
 
22
23
  import brainstate as bst
23
24
 
brainstate/nn/metrics.py CHANGED
@@ -16,13 +16,14 @@
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
- import jax
20
- import jax.numpy as jnp
21
- import numpy as np
22
19
  import typing as tp
23
20
  from dataclasses import dataclass
24
21
  from functools import partial
25
22
 
23
+ import jax
24
+ import jax.numpy as jnp
25
+ import numpy as np
26
+
26
27
  from brainstate._state import State
27
28
 
28
29
  __all__ = [
@@ -16,10 +16,11 @@
16
16
  # -*- coding: utf-8 -*-
17
17
  from __future__ import annotations
18
18
 
19
+ from typing import Sequence, Union
20
+
19
21
  import jax
20
22
  import jax.numpy as jnp
21
23
  import numpy as np
22
- from typing import Sequence, Union
23
24
 
24
25
  from brainstate import environ
25
26
  from brainstate._state import State, LongTermState
@@ -15,9 +15,10 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import jax.numpy as jnp
19
18
  import unittest
20
19
 
20
+ import jax.numpy as jnp
21
+
21
22
  import brainstate as bst
22
23
 
23
24
 
@@ -15,9 +15,10 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
+ import unittest
19
+
18
20
  import jax
19
21
  import optax
20
- import unittest
21
22
 
22
23
  import brainstate as bst
23
24
 
@@ -16,11 +16,12 @@
16
16
  # -*- coding: utf-8 -*-
17
17
  from __future__ import annotations
18
18
 
19
- import brainunit as u
20
19
  import functools
20
+ from typing import Union, Dict, Optional, Tuple, Any, TypeVar
21
+
22
+ import brainunit as u
21
23
  import jax
22
24
  import jax.numpy as jnp
23
- from typing import Union, Dict, Optional, Tuple, Any, TypeVar
24
25
 
25
26
  from brainstate import environ
26
27
  from brainstate._state import State, LongTermState, StateDictManager
@@ -17,9 +17,10 @@
17
17
  # -*- coding: utf-8 -*-
18
18
  from __future__ import annotations
19
19
 
20
- import numpy as np
21
20
  from typing import Optional
22
21
 
22
+ import numpy as np
23
+
23
24
  from brainstate.typing import DTypeLike, Size, SeedOrKey
24
25
  from ._rand_state import RandomState, DEFAULT
25
26
 
@@ -15,12 +15,13 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
+ import platform
19
+ import unittest
20
+
18
21
  import jax.numpy as jnp
19
22
  import jax.random as jr
20
23
  import numpy as np
21
- import platform
22
24
  import pytest
23
- import unittest
24
25
 
25
26
  import brainstate as bst
26
27
 
@@ -14,11 +14,12 @@
14
14
  # ==============================================================================
15
15
  from __future__ import annotations
16
16
 
17
- import jax
18
- import numpy as np
19
17
  from contextlib import contextmanager
20
18
  from typing import Optional
21
19
 
20
+ import jax
21
+ import numpy as np
22
+
22
23
  from brainstate.typing import SeedOrKey
23
24
  from ._rand_state import RandomState, DEFAULT, use_prng_key
24
25
 
@@ -15,9 +15,10 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
+ import unittest
19
+
18
20
  import jax.numpy as jnp
19
21
  import jax.random
20
- import unittest
21
22
 
22
23
  import brainstate as bst
23
24
 
@@ -16,16 +16,17 @@
16
16
  # -*- coding: utf-8 -*-
17
17
  from __future__ import annotations
18
18
 
19
+ from functools import partial
20
+ from operator import index
21
+ from typing import Optional
22
+
19
23
  import brainunit as u
20
24
  import jax
21
25
  import jax.numpy as jnp
22
26
  import jax.random as jr
23
27
  import numpy as np
24
- from functools import partial
25
28
  from jax import jit, vmap
26
29
  from jax import lax, core, dtypes
27
- from operator import index
28
- from typing import Optional
29
30
 
30
31
  from brainstate import environ
31
32
  from brainstate._state import State
brainstate/surrogate.py CHANGED
@@ -21,9 +21,8 @@ import jax.numpy as jnp
21
21
  import jax.scipy as sci
22
22
  from jax.interpreters import batching, ad, mlir
23
23
 
24
- from brainstate.util._pretty_pytree import PrettyObject
25
24
  from brainstate._compatible_import import Primitive
26
-
25
+ from brainstate.util._pretty_pytree import PrettyObject
27
26
 
28
27
  __all__ = [
29
28
  'Surrogate',
brainstate/typing.py CHANGED
@@ -16,18 +16,18 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import builtins
19
-
20
- import brainunit as u
21
19
  import functools as ft
22
20
  import importlib
23
21
  import inspect
24
- import jax
25
- import numpy as np
26
22
  from typing import (
27
23
  Any, Callable, Hashable, List, Protocol, Tuple, TypeVar, Union,
28
24
  runtime_checkable, TYPE_CHECKING, Generic, Sequence
29
25
  )
30
26
 
27
+ import brainunit as u
28
+ import jax
29
+ import numpy as np
30
+
31
31
  tp = importlib.import_module("typing")
32
32
 
33
33
  __all__ = [
@@ -18,9 +18,10 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  import dataclasses
21
- import jax
22
21
  from typing import Any, TypeVar, Protocol, Generic
23
22
 
23
+ import jax
24
+
24
25
  __all__ = [
25
26
  'DelayedAccessor',
26
27
  'CallableProxy',
@@ -15,17 +15,17 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import gc
19
-
20
18
  import copy
21
19
  import functools
22
- import jax
20
+ import gc
23
21
  import threading
24
22
  import types
25
23
  from collections.abc import Iterable
26
- from jax.lib import xla_bridge
27
24
  from typing import Any, Callable, Tuple, Union, Dict
28
25
 
26
+ import jax
27
+ from jax.lib import xla_bridge
28
+
29
29
  from brainstate._utils import set_module_as
30
30
 
31
31
  __all__ = [
@@ -18,9 +18,9 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  from collections import abc
21
+ from typing import TypeVar, Hashable, Union, Iterable, Any, Optional, Tuple, Dict
21
22
 
22
23
  import jax
23
- from typing import TypeVar, Hashable, Union, Iterable, Any, Optional, Tuple, Dict
24
24
 
25
25
  from brainstate.typing import Filter, PathParts
26
26
  from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, yield_unique_pretty_repr_items, pretty_repr_object
@@ -15,8 +15,9 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import jax
19
18
  import unittest
19
+
20
+ import jax
20
21
  from absl.testing import absltest
21
22
 
22
23
  import brainstate as bst
@@ -422,14 +422,14 @@ class PrettyTable:
422
422
  def _column_specific_args(self):
423
423
  # Column specific arguments, use property.setters
424
424
  for attr in (
425
- "align",
426
- "valign",
427
- "max_width",
428
- "min_width",
429
- "int_format",
430
- "float_format",
431
- "custom_format",
432
- "none_format",
425
+ "align",
426
+ "valign",
427
+ "max_width",
428
+ "min_width",
429
+ "int_format",
430
+ "float_format",
431
+ "custom_format",
432
+ "none_format",
433
433
  ):
434
434
  setattr(
435
435
  self, attr, (self._kwargs[attr] or {}) if attr in self._kwargs else {}
@@ -517,15 +517,15 @@ class PrettyTable:
517
517
  elif option == "none_format":
518
518
  self._validate_none_format(val)
519
519
  elif option in (
520
- "start",
521
- "end",
522
- "max_width",
523
- "min_width",
524
- "min_table_width",
525
- "max_table_width",
526
- "padding_width",
527
- "left_padding_width",
528
- "right_padding_width",
520
+ "start",
521
+ "end",
522
+ "max_width",
523
+ "min_width",
524
+ "min_table_width",
525
+ "max_table_width",
526
+ "padding_width",
527
+ "left_padding_width",
528
+ "right_padding_width",
529
529
  ):
530
530
  self._validate_nonnegative_int(option, val)
531
531
  elif option == "sortby":
@@ -539,16 +539,16 @@ class PrettyTable:
539
539
  elif option == "fields":
540
540
  self._validate_all_field_names(option, val)
541
541
  elif option in (
542
- "header",
543
- "border",
544
- "preserve_internal_border",
545
- "reversesort",
546
- "xhtml",
547
- "format",
548
- "print_empty",
549
- "oldsortslice",
550
- "escape_header",
551
- "escape_data",
542
+ "header",
543
+ "border",
544
+ "preserve_internal_border",
545
+ "reversesort",
546
+ "xhtml",
547
+ "format",
548
+ "print_empty",
549
+ "oldsortslice",
550
+ "escape_header",
551
+ "escape_data",
552
552
  ):
553
553
  self._validate_true_or_false(option, val)
554
554
  elif option == "header_style":
@@ -561,18 +561,18 @@ class PrettyTable:
561
561
  for k, formatter in val.items():
562
562
  self._validate_function(f"{option}.{k}", formatter)
563
563
  elif option in (
564
- "vertical_char",
565
- "horizontal_char",
566
- "horizontal_align_char",
567
- "junction_char",
568
- "top_junction_char",
569
- "bottom_junction_char",
570
- "right_junction_char",
571
- "left_junction_char",
572
- "top_right_junction_char",
573
- "top_left_junction_char",
574
- "bottom_right_junction_char",
575
- "bottom_left_junction_char",
564
+ "vertical_char",
565
+ "horizontal_char",
566
+ "horizontal_align_char",
567
+ "junction_char",
568
+ "top_junction_char",
569
+ "bottom_junction_char",
570
+ "right_junction_char",
571
+ "left_junction_char",
572
+ "top_right_junction_char",
573
+ "top_left_junction_char",
574
+ "bottom_right_junction_char",
575
+ "bottom_left_junction_char",
576
576
  ):
577
577
  self._validate_single_char(option, val)
578
578
  elif option == "attributes":
@@ -1998,8 +1998,8 @@ class PrettyTable:
1998
1998
  if options["header"]:
1999
1999
  lines.append(self._stringify_header(options))
2000
2000
  elif options["border"] and options["hrules"] in (
2001
- HRuleStyle.ALL,
2002
- HRuleStyle.FRAME,
2001
+ HRuleStyle.ALL,
2002
+ HRuleStyle.FRAME,
2003
2003
  ):
2004
2004
  lines.append(self._stringify_hrule(options, where="top_"))
2005
2005
  if title and options["vrules"] in (VRuleStyle.ALL, VRuleStyle.FRAME):
@@ -2115,8 +2115,8 @@ class PrettyTable:
2115
2115
  if options["hrules"] in (HRuleStyle.ALL, HRuleStyle.FRAME):
2116
2116
  bits.append(self._stringify_hrule(options, "top_"))
2117
2117
  if options["title"] and options["vrules"] in (
2118
- VRuleStyle.ALL,
2119
- VRuleStyle.FRAME,
2118
+ VRuleStyle.ALL,
2119
+ VRuleStyle.FRAME,
2120
2120
  ):
2121
2121
  left_j_len = len(self.left_junction_char)
2122
2122
  right_j_len = len(self.right_junction_char)
@@ -21,10 +21,11 @@ from __future__ import annotations
21
21
 
22
22
  import collections
23
23
  import dataclasses
24
- import jax
25
24
  from collections.abc import Hashable, Mapping
26
25
  from types import MappingProxyType
27
26
  from typing import Any, TypeVar
27
+
28
+ import jax
28
29
  from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet
29
30
 
30
31
  __all__ = [
brainstate/util/filter.py CHANGED
@@ -18,7 +18,6 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  import builtins
21
-
22
21
  import dataclasses
23
22
  import typing
24
23
  from typing import TYPE_CHECKING
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.0.post20250420
3
+ Version: 0.1.0.post20250423
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -15,7 +15,6 @@ Classifier: Development Status :: 4 - Beta
15
15
  Classifier: Intended Audience :: Developers
16
16
  Classifier: Intended Audience :: Science/Research
17
17
  Classifier: Programming Language :: Python :: 3
18
- Classifier: Programming Language :: Python :: 3.9
19
18
  Classifier: Programming Language :: Python :: 3.10
20
19
  Classifier: Programming Language :: Python :: 3.11
21
20
  Classifier: Programming Language :: Python :: 3.12
@@ -33,6 +32,7 @@ Requires-Dist: jax
33
32
  Requires-Dist: jaxlib
34
33
  Requires-Dist: numpy
35
34
  Requires-Dist: brainunit (>=0.0.4)
35
+ Requires-Dist: brainevent
36
36
  Provides-Extra: cpu
37
37
  Requires-Dist: jaxlib ; extra == 'cpu'
38
38
  Provides-Extra: cuda12
@@ -56,7 +56,7 @@ Requires-Dist: jaxlib[tpu] ; extra == 'tpu'
56
56
  <p align="center">
57
57
  <a href="https://pypi.org/project/brainstate/"><img alt="Supported Python Version" src="https://img.shields.io/pypi/pyversions/brainstate"></a>
58
58
  <a href="https://github.com/chaobrain/brainstate/blob/main/LICENSE"><img alt="LICENSE" src="https://img.shields.io/badge/License-Apache%202.0-blue.svg"></a>
59
- <a href='https://brainstate.readthedocs.io/en/latest/?badge=latest'>
59
+ <a href='https://brainstate.readthedocs.io/?badge=latest'>
60
60
  <img src='https://readthedocs.org/projects/brainstate/badge/?version=latest' alt='Documentation Status' />
61
61
  </a>
62
62
  <a href="https://badge.fury.io/py/brainstate"><img alt="PyPI version" src="https://badge.fury.io/py/brainstate.svg"></a>