brainstate 0.1.0.post20250212__py2.py3-none-any.whl → 0.1.0.post20250216__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. brainstate/_state.py +853 -90
  2. brainstate/_state_test.py +1 -3
  3. brainstate/augment/__init__.py +2 -2
  4. brainstate/augment/_autograd.py +257 -115
  5. brainstate/augment/_autograd_test.py +2 -3
  6. brainstate/augment/_eval_shape.py +3 -4
  7. brainstate/augment/_mapping.py +582 -62
  8. brainstate/augment/_mapping_test.py +114 -30
  9. brainstate/augment/_random.py +61 -7
  10. brainstate/compile/_ad_checkpoint.py +2 -3
  11. brainstate/compile/_conditions.py +4 -5
  12. brainstate/compile/_conditions_test.py +1 -2
  13. brainstate/compile/_error_if.py +1 -2
  14. brainstate/compile/_error_if_test.py +1 -2
  15. brainstate/compile/_jit.py +23 -16
  16. brainstate/compile/_jit_test.py +1 -2
  17. brainstate/compile/_loop_collect_return.py +18 -10
  18. brainstate/compile/_loop_collect_return_test.py +1 -1
  19. brainstate/compile/_loop_no_collection.py +5 -5
  20. brainstate/compile/_make_jaxpr.py +23 -21
  21. brainstate/compile/_make_jaxpr_test.py +1 -2
  22. brainstate/compile/_progress_bar.py +1 -2
  23. brainstate/compile/_unvmap.py +1 -0
  24. brainstate/compile/_util.py +4 -2
  25. brainstate/environ.py +4 -4
  26. brainstate/environ_test.py +1 -2
  27. brainstate/functional/_activations.py +1 -2
  28. brainstate/functional/_activations_test.py +1 -1
  29. brainstate/functional/_normalization.py +1 -2
  30. brainstate/functional/_others.py +1 -2
  31. brainstate/functional/_spikes.py +136 -20
  32. brainstate/graph/_graph_node.py +2 -43
  33. brainstate/graph/_graph_operation.py +4 -20
  34. brainstate/graph/_graph_operation_test.py +3 -4
  35. brainstate/init/_base.py +1 -2
  36. brainstate/init/_generic.py +1 -2
  37. brainstate/nn/__init__.py +4 -0
  38. brainstate/nn/_collective_ops.py +351 -48
  39. brainstate/nn/_collective_ops_test.py +36 -0
  40. brainstate/nn/_common.py +194 -0
  41. brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
  42. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
  43. brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
  44. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
  45. brainstate/nn/_dyn_impl/_inputs.py +1 -2
  46. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
  47. brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
  48. brainstate/nn/_dyn_impl/_readout.py +2 -3
  49. brainstate/nn/_dyn_impl/_readout_test.py +1 -2
  50. brainstate/nn/_dynamics/_dynamics_base.py +2 -3
  51. brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
  52. brainstate/nn/_dynamics/_state_delay.py +3 -3
  53. brainstate/nn/_dynamics/_synouts_test.py +1 -2
  54. brainstate/nn/_elementwise/_dropout.py +6 -7
  55. brainstate/nn/_elementwise/_dropout_test.py +1 -2
  56. brainstate/nn/_elementwise/_elementwise.py +1 -2
  57. brainstate/nn/_exp_euler.py +1 -2
  58. brainstate/nn/_exp_euler_test.py +1 -2
  59. brainstate/nn/_interaction/_conv.py +1 -2
  60. brainstate/nn/_interaction/_conv_test.py +1 -0
  61. brainstate/nn/_interaction/_linear.py +1 -2
  62. brainstate/nn/_interaction/_linear_test.py +1 -2
  63. brainstate/nn/_interaction/_normalizations.py +1 -2
  64. brainstate/nn/_interaction/_poolings.py +3 -4
  65. brainstate/nn/_module.py +63 -19
  66. brainstate/nn/_module_test.py +1 -2
  67. brainstate/nn/metrics.py +3 -4
  68. brainstate/optim/_lr_scheduler.py +1 -2
  69. brainstate/optim/_lr_scheduler_test.py +2 -3
  70. brainstate/optim/_optax_optimizer_test.py +1 -2
  71. brainstate/optim/_sgd_optimizer.py +2 -3
  72. brainstate/random/_rand_funs.py +1 -2
  73. brainstate/random/_rand_funs_test.py +2 -3
  74. brainstate/random/_rand_seed.py +2 -3
  75. brainstate/random/_rand_seed_test.py +1 -2
  76. brainstate/random/_rand_state.py +3 -4
  77. brainstate/surrogate.py +5 -2
  78. brainstate/transform.py +0 -3
  79. brainstate/typing.py +28 -25
  80. brainstate/util/__init__.py +9 -7
  81. brainstate/util/_caller.py +1 -2
  82. brainstate/util/_error.py +27 -0
  83. brainstate/util/_others.py +60 -15
  84. brainstate/util/{_dict.py → _pretty_pytree.py} +2 -2
  85. brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
  86. brainstate/util/_pretty_repr.py +1 -2
  87. brainstate/util/_pretty_table.py +2900 -0
  88. brainstate/util/_struct.py +11 -11
  89. brainstate/util/filter.py +472 -0
  90. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250216.dist-info}/METADATA +2 -2
  91. brainstate-0.1.0.post20250216.dist-info/RECORD +127 -0
  92. brainstate/util/_filter.py +0 -178
  93. brainstate-0.1.0.post20250212.dist-info/RECORD +0 -124
  94. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250216.dist-info}/LICENSE +0 -0
  95. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250216.dist-info}/WHEEL +0 -0
  96. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250216.dist-info}/top_level.txt +0 -0
brainstate/typing.py CHANGED
@@ -16,13 +16,17 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import builtins
19
+
20
+ import brainunit as u
19
21
  import functools as ft
20
22
  import importlib
21
23
  import inspect
22
-
23
- import brainunit as u
24
24
  import jax
25
25
  import numpy as np
26
+ from typing import (
27
+ Any, Callable, Hashable, List, Protocol, Tuple, TypeVar, Union,
28
+ runtime_checkable, TYPE_CHECKING, Generic, Sequence
29
+ )
26
30
 
27
31
  tp = importlib.import_module("typing")
28
32
 
@@ -41,35 +45,35 @@ __all__ = [
41
45
  'Missing',
42
46
  ]
43
47
 
44
- K = tp.TypeVar('K')
48
+ K = TypeVar('K')
45
49
 
46
50
 
47
- @tp.runtime_checkable
48
- class Key(tp.Hashable, tp.Protocol):
51
+ @runtime_checkable
52
+ class Key(Hashable, Protocol):
49
53
  def __lt__(self: K, value: K, /) -> bool:
50
54
  ...
51
55
 
52
56
 
53
- Ellipsis = builtins.ellipsis if tp.TYPE_CHECKING else tp.Any
57
+ Ellipsis = builtins.ellipsis if TYPE_CHECKING else Any
54
58
 
55
- PathParts = tp.Tuple[Key, ...]
56
- Predicate = tp.Callable[[PathParts, tp.Any], bool]
57
- FilterLiteral = tp.Union[type, str, Predicate, bool, Ellipsis, None]
58
- Filter = tp.Union[FilterLiteral, tp.Tuple['Filter', ...], tp.List['Filter']]
59
+ PathParts = Tuple[Key, ...]
60
+ Predicate = Callable[[PathParts, Any], bool]
61
+ FilterLiteral = Union[type, str, Predicate, bool, Ellipsis, None]
62
+ Filter = Union[FilterLiteral, Tuple['Filter', ...], List['Filter']]
59
63
 
60
- _T = tp.TypeVar("_T")
64
+ _T = TypeVar("_T")
61
65
 
62
- _Annotation = tp.TypeVar("_Annotation")
66
+ _Annotation = TypeVar("_Annotation")
63
67
 
64
68
 
65
- class _Array(tp.Generic[_Annotation]):
69
+ class _Array(Generic[_Annotation]):
66
70
  pass
67
71
 
68
72
 
69
73
  _Array.__module__ = "builtins"
70
74
 
71
75
 
72
- def _item_to_str(item: tp.Union[str, type, slice]) -> str:
76
+ def _item_to_str(item: Union[str, type, slice]) -> str:
73
77
  if isinstance(item, slice):
74
78
  if item.step is not None:
75
79
  raise NotImplementedError
@@ -83,7 +87,7 @@ def _item_to_str(item: tp.Union[str, type, slice]) -> str:
83
87
 
84
88
 
85
89
  def _maybe_tuple_to_str(
86
- item: tp.Union[str, type, slice, tp.Tuple[tp.Union[str, type, slice], ...]]
90
+ item: Union[str, type, slice, Tuple[Union[str, type, slice], ...]]
87
91
  ) -> str:
88
92
  if isinstance(item, tuple):
89
93
  if len(item) == 0:
@@ -113,7 +117,7 @@ class Array:
113
117
  Array.__module__ = "builtins"
114
118
 
115
119
 
116
- class _FakePyTree(tp.Generic[_T]):
120
+ class _FakePyTree(Generic[_T]):
117
121
  pass
118
122
 
119
123
 
@@ -255,11 +259,10 @@ f. A structure can end with a `...`, to denote that the PyTree must be a prefix
255
259
  cases, all named pieces must already have been seen and their structures bound.
256
260
  """ # noqa: E501
257
261
 
258
- Size = tp.Union[int, tp.Sequence[int]]
259
- Axes = tp.Union[int, tp.Sequence[int]]
260
- SeedOrKey = tp.Union[int, jax.Array, np.ndarray]
261
- Shape = tp.Sequence[int]
262
-
262
+ Size = Union[int, Sequence[int]]
263
+ Axes = Union[int, Sequence[int]]
264
+ SeedOrKey = Union[int, jax.Array, np.ndarray]
265
+ Shape = Sequence[int]
263
266
 
264
267
  # --- Array --- #
265
268
 
@@ -267,7 +270,7 @@ Shape = tp.Sequence[int]
267
270
  # standard JAX array (i.e. not including future non-standard array types like
268
271
  # KeyArray and BInt). It's different than np.typing.ArrayLike in that it doesn't
269
272
  # accept arbitrary sequences, nor does it accept string data.
270
- ArrayLike = tp.Union[
273
+ ArrayLike = Union[
271
274
  jax.Array, # JAX array type
272
275
  np.ndarray, # NumPy array type
273
276
  np.bool_, np.number, # NumPy scalar types
@@ -281,7 +284,7 @@ ArrayLike = tp.Union[
281
284
  DType = np.dtype
282
285
 
283
286
 
284
- class SupportsDType(tp.Protocol):
287
+ class SupportsDType(Protocol):
285
288
  @property
286
289
  def dtype(self) -> DType: ...
287
290
 
@@ -291,9 +294,9 @@ class SupportsDType(tp.Protocol):
291
294
  # because JAX doesn't support objects or structured dtypes.
292
295
  # Unlike np.typing.DTypeLike, we exclude None, and instead require
293
296
  # explicit annotations when None is acceptable.
294
- DTypeLike = tp.Union[
297
+ DTypeLike = Union[
295
298
  str, # like 'float32', 'int32'
296
- type[tp.Any], # like np.float32, np.int32, float, int
299
+ type[Any], # like np.float32, np.int32, float, int
297
300
  np.dtype, # like np.dtype('float32'), np.dtype('int32')
298
301
  SupportsDType, # like jnp.float32, jnp.int32
299
302
  ]
@@ -13,36 +13,38 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from ._dict import *
17
- from ._dict import __all__ as _mapping_all
16
+ from . import filter
18
17
  from ._error import *
19
18
  from ._error import __all__ as _error_all
20
- from ._filter import *
21
- from ._filter import __all__ as _filter_all
22
19
  from ._others import *
23
20
  from ._others import __all__ as _others_all
21
+ from ._pretty_pytree import *
22
+ from ._pretty_pytree import __all__ as _mapping_all
24
23
  from ._pretty_repr import *
25
24
  from ._pretty_repr import __all__ as _pretty_repr_all
25
+ from ._pretty_table import *
26
+ from ._pretty_table import __all__ as _table_all
26
27
  from ._scaling import *
27
28
  from ._scaling import __all__ as _mem_scale_all
28
29
  from ._struct import *
29
30
  from ._struct import __all__ as _struct_all
30
31
 
31
32
  __all__ = (
32
- _others_all
33
+ ['filter']
34
+ + _others_all
33
35
  + _mem_scale_all
34
- + _filter_all
35
36
  + _pretty_repr_all
36
37
  + _struct_all
37
38
  + _error_all
38
39
  + _mapping_all
40
+ + _table_all
39
41
  )
40
42
  del (
41
43
  _others_all,
42
44
  _mem_scale_all,
43
- _filter_all,
44
45
  _pretty_repr_all,
45
46
  _struct_all,
46
47
  _error_all,
47
48
  _mapping_all,
49
+ _table_all,
48
50
  )
@@ -18,9 +18,8 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  import dataclasses
21
- from typing import Any, TypeVar, Protocol, Generic
22
-
23
21
  import jax
22
+ from typing import Any, TypeVar, Protocol, Generic
24
23
 
25
24
  __all__ = [
26
25
  'DelayedAccessor',
brainstate/util/_error.py CHANGED
@@ -21,8 +21,35 @@ __all__ = [
21
21
 
22
22
 
23
23
  class BrainStateError(Exception):
24
+ """
25
+ A custom exception class for BrainState-related errors.
26
+
27
+ This exception is raised when a BrainState-specific error occurs during
28
+ the execution of the program. It serves as a base class for more specific
29
+ BrainState exceptions.
30
+
31
+ Attributes:
32
+ Inherits all attributes from the built-in Exception class.
33
+
34
+ Usage::
35
+
36
+ raise BrainStateError("A BrainState-specific error occurred.")
37
+ """
24
38
  pass
25
39
 
26
40
 
27
41
  class TraceContextError(BrainStateError):
42
+ """
43
+ A custom exception class for trace context-related errors in BrainState.
44
+
45
+ This exception is raised when an error occurs specifically related to
46
+ trace context operations or manipulations within the BrainState framework.
47
+
48
+ Attributes:
49
+ Inherits all attributes from the BrainStateError class.
50
+
51
+ Usage::
52
+
53
+ raise TraceContextError("An error occurred while handling trace context.")
54
+ """
28
55
  pass
@@ -15,20 +15,21 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
+ import gc
19
+
18
20
  import copy
19
21
  import functools
20
- import gc
22
+ import jax
21
23
  import threading
22
24
  import types
23
25
  from collections.abc import Iterable
24
- from typing import Any, Callable, Tuple, Union, Dict
25
-
26
- import jax
27
26
  from jax.lib import xla_bridge
27
+ from typing import Any, Callable, Tuple, Union, Dict
28
28
 
29
29
  from brainstate._utils import set_module_as
30
30
 
31
31
  __all__ = [
32
+ 'split_total',
32
33
  'clear_buffer_memory',
33
34
  'not_instance_eval',
34
35
  'is_instance_eval',
@@ -37,6 +38,61 @@ __all__ = [
37
38
  ]
38
39
 
39
40
 
41
+ def split_total(
42
+ total: int,
43
+ fraction: Union[int, float],
44
+ ) -> int:
45
+ """
46
+ Calculate the number of epochs for simulation based on a total and a fraction.
47
+
48
+ This function determines the number of epochs to simulate given a total number
49
+ of epochs and either a fraction or a specific number of epochs to run.
50
+
51
+ Parameters:
52
+ -----------
53
+ total : int
54
+ The total number of epochs. Must be a positive integer.
55
+ fraction : Union[int, float]
56
+ If ``float``: A value between 0 and 1 representing the fraction of total epochs to run.
57
+ If ``int``: The specific number of epochs to run, must not exceed the total.
58
+
59
+ Returns:
60
+ --------
61
+ int
62
+ The calculated number of epochs to simulate.
63
+
64
+ Raises:
65
+ -------
66
+ ValueError
67
+ If total is not positive, fraction is negative, or if fraction as float is > 1
68
+ or as int is > total.
69
+ AssertionError
70
+ If total is not an integer.
71
+ """
72
+ assert isinstance(total, int), "Length must be an integer."
73
+ if total <= 0:
74
+ raise ValueError("'total' must be a positive integer.")
75
+ if fraction < 0:
76
+ raise ValueError("'fraction' value cannot be negative.")
77
+
78
+ if isinstance(fraction, float):
79
+ if fraction < 0:
80
+ raise ValueError("'fraction' value cannot be negative.")
81
+ if fraction > 1:
82
+ raise ValueError("'fraction' value cannot be greater than 1.")
83
+ return int(total * fraction)
84
+
85
+ elif isinstance(fraction, int):
86
+ if fraction < 0:
87
+ raise ValueError("'fraction' value cannot be negative.")
88
+ if fraction > total:
89
+ raise ValueError("'fraction' value cannot be greater than total.")
90
+ return fraction
91
+
92
+ else:
93
+ raise ValueError("'fraction' must be an integer or float.")
94
+
95
+
40
96
  class NameContext(threading.local):
41
97
  def __init__(self):
42
98
  self.typed_names: Dict[str, int] = {}
@@ -249,17 +305,6 @@ class DictManager(dict):
249
305
  else:
250
306
  raise ValueError(f'Unsupported method: {by}')
251
307
 
252
- def union_by_value_ids(self, other: dict):
253
- """
254
- Union the stack by the value ids.
255
-
256
- Args:
257
- other:
258
-
259
- Returns:
260
-
261
- """
262
-
263
308
  def __add__(self, other: dict):
264
309
  """
265
310
  Compose other instance of dict.
@@ -18,14 +18,14 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  from collections import abc
21
- from typing import TypeVar, Hashable, Union, Iterable, Any, Optional, Tuple, Dict
22
21
 
23
22
  import jax
23
+ from typing import TypeVar, Hashable, Union, Iterable, Any, Optional, Tuple, Dict
24
24
 
25
25
  from brainstate.typing import Filter, PathParts
26
- from ._filter import to_predicate
27
26
  from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, yield_unique_pretty_repr_items, pretty_repr_object
28
27
  from ._struct import dataclass
28
+ from .filter import to_predicate
29
29
 
30
30
  __all__ = [
31
31
  'PrettyDict',
@@ -15,9 +15,8 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import jax
19
+ import unittest
21
20
  from absl.testing import absltest
22
21
 
23
22
  import brainstate as bst
@@ -21,7 +21,7 @@ import dataclasses
21
21
  import threading
22
22
  from abc import ABC, abstractmethod
23
23
  from functools import partial
24
- from typing import Any, Iterator, Mapping, TypeVar, Union, Callable, Optional, Sequence
24
+ from typing import Any, Iterator, Mapping, TypeVar, Union, Callable, Optional
25
25
 
26
26
  __all__ = [
27
27
  'yield_unique_pretty_repr_items',
@@ -328,4 +328,3 @@ def yield_unique_pretty_repr_items(
328
328
  finally:
329
329
  if clear_seen:
330
330
  CONTEXT.seen_modules_repr = None
331
-