brainstate 0.1.0.post20250212__py2.py3-none-any.whl → 0.1.0.post20250217__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. brainstate/_state.py +853 -90
  2. brainstate/_state_test.py +1 -3
  3. brainstate/augment/__init__.py +2 -2
  4. brainstate/augment/_autograd.py +257 -115
  5. brainstate/augment/_autograd_test.py +2 -3
  6. brainstate/augment/_eval_shape.py +3 -4
  7. brainstate/augment/_mapping.py +582 -62
  8. brainstate/augment/_mapping_test.py +114 -30
  9. brainstate/augment/_random.py +61 -7
  10. brainstate/compile/_ad_checkpoint.py +2 -3
  11. brainstate/compile/_conditions.py +4 -5
  12. brainstate/compile/_conditions_test.py +1 -2
  13. brainstate/compile/_error_if.py +1 -2
  14. brainstate/compile/_error_if_test.py +1 -2
  15. brainstate/compile/_jit.py +23 -16
  16. brainstate/compile/_jit_test.py +1 -2
  17. brainstate/compile/_loop_collect_return.py +18 -10
  18. brainstate/compile/_loop_collect_return_test.py +1 -1
  19. brainstate/compile/_loop_no_collection.py +5 -5
  20. brainstate/compile/_make_jaxpr.py +23 -21
  21. brainstate/compile/_make_jaxpr_test.py +1 -2
  22. brainstate/compile/_progress_bar.py +1 -2
  23. brainstate/compile/_unvmap.py +1 -0
  24. brainstate/compile/_util.py +4 -2
  25. brainstate/environ.py +4 -4
  26. brainstate/environ_test.py +1 -2
  27. brainstate/functional/_activations.py +1 -2
  28. brainstate/functional/_activations_test.py +1 -1
  29. brainstate/functional/_normalization.py +1 -2
  30. brainstate/functional/_others.py +1 -2
  31. brainstate/functional/_spikes.py +136 -20
  32. brainstate/graph/_graph_node.py +2 -43
  33. brainstate/graph/_graph_operation.py +4 -20
  34. brainstate/graph/_graph_operation_test.py +3 -4
  35. brainstate/init/_base.py +1 -2
  36. brainstate/init/_generic.py +1 -2
  37. brainstate/nn/__init__.py +8 -0
  38. brainstate/nn/_collective_ops.py +351 -48
  39. brainstate/nn/_collective_ops_test.py +36 -0
  40. brainstate/nn/_common.py +193 -0
  41. brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
  42. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
  43. brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
  44. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
  45. brainstate/nn/_dyn_impl/_inputs.py +1 -2
  46. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
  47. brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
  48. brainstate/nn/_dyn_impl/_readout.py +2 -3
  49. brainstate/nn/_dyn_impl/_readout_test.py +1 -2
  50. brainstate/nn/_dynamics/_dynamics_base.py +6 -1
  51. brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
  52. brainstate/nn/_dynamics/_state_delay.py +3 -3
  53. brainstate/nn/_dynamics/_synouts_test.py +1 -2
  54. brainstate/nn/_elementwise/_dropout.py +6 -7
  55. brainstate/nn/_elementwise/_dropout_test.py +1 -2
  56. brainstate/nn/_elementwise/_elementwise.py +1 -2
  57. brainstate/nn/_exp_euler.py +1 -2
  58. brainstate/nn/_exp_euler_test.py +1 -2
  59. brainstate/nn/_interaction/_conv.py +1 -2
  60. brainstate/nn/_interaction/_conv_test.py +1 -0
  61. brainstate/nn/_interaction/_linear.py +1 -2
  62. brainstate/nn/_interaction/_linear_test.py +1 -2
  63. brainstate/nn/_interaction/_normalizations.py +1 -2
  64. brainstate/nn/_interaction/_poolings.py +3 -4
  65. brainstate/nn/_module.py +68 -19
  66. brainstate/nn/_module_test.py +1 -2
  67. brainstate/nn/_utils.py +89 -0
  68. brainstate/nn/metrics.py +3 -4
  69. brainstate/optim/_lr_scheduler.py +1 -2
  70. brainstate/optim/_lr_scheduler_test.py +2 -3
  71. brainstate/optim/_optax_optimizer_test.py +1 -2
  72. brainstate/optim/_sgd_optimizer.py +2 -3
  73. brainstate/random/_rand_funs.py +1 -2
  74. brainstate/random/_rand_funs_test.py +2 -3
  75. brainstate/random/_rand_seed.py +2 -3
  76. brainstate/random/_rand_seed_test.py +1 -2
  77. brainstate/random/_rand_state.py +3 -4
  78. brainstate/surrogate.py +5 -5
  79. brainstate/transform.py +0 -3
  80. brainstate/typing.py +28 -25
  81. brainstate/util/__init__.py +9 -7
  82. brainstate/util/_caller.py +1 -2
  83. brainstate/util/_error.py +27 -0
  84. brainstate/util/_others.py +60 -15
  85. brainstate/util/{_dict.py → _pretty_pytree.py} +2 -2
  86. brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
  87. brainstate/util/_pretty_repr.py +1 -2
  88. brainstate/util/_pretty_table.py +2900 -0
  89. brainstate/util/_struct.py +11 -11
  90. brainstate/util/filter.py +472 -0
  91. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/METADATA +2 -2
  92. brainstate-0.1.0.post20250217.dist-info/RECORD +128 -0
  93. brainstate/util/_filter.py +0 -178
  94. brainstate-0.1.0.post20250212.dist-info/RECORD +0 -124
  95. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/LICENSE +0 -0
  96. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/WHEEL +0 -0
  97. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/top_level.txt +0 -0
@@ -55,12 +55,10 @@ from __future__ import annotations
55
55
 
56
56
  import functools
57
57
  import inspect
58
+ import jax
58
59
  import operator
59
60
  from collections.abc import Hashable, Iterable, Sequence
60
61
  from contextlib import ExitStack
61
- from typing import Any, Callable, Tuple, Union, Dict, Optional
62
-
63
- import jax
64
62
  from jax._src import source_info_util
65
63
  from jax._src.linear_util import annotate
66
64
  from jax._src.traceback_util import api_boundary
@@ -68,18 +66,18 @@ from jax.api_util import shaped_abstractify
68
66
  from jax.extend.linear_util import transformation_with_aux, wrap_init
69
67
  from jax.interpreters import partial_eval as pe
70
68
  from jax.util import wraps
69
+ from typing import Any, Callable, Tuple, Union, Dict, Optional
71
70
 
72
71
  from brainstate._state import State, StateTraceStack
73
72
  from brainstate._utils import set_module_as
74
73
  from brainstate.typing import PyTree
75
-
74
+ from brainstate.util import PrettyObject
76
75
 
77
76
  if jax.__version_info__ < (0, 4, 38):
78
77
  from jax.core import ClosedJaxpr
79
78
  else:
80
79
  from jax.extend.core import ClosedJaxpr
81
80
 
82
-
83
81
  AxisName = Hashable
84
82
 
85
83
  __all__ = [
@@ -125,8 +123,8 @@ def _new_jax_trace():
125
123
  return frame, trace
126
124
 
127
125
 
128
- def _init_state_trace_stack() -> StateTraceStack:
129
- state_trace: StateTraceStack = StateTraceStack()
126
+ def _init_state_trace_stack(name) -> StateTraceStack:
127
+ state_trace: StateTraceStack = StateTraceStack(name=name)
130
128
 
131
129
  if jax.__version_info__ < (0, 4, 36):
132
130
  # Should be within the calling of ``jax.make_jaxpr()``
@@ -141,7 +139,7 @@ def _init_state_trace_stack() -> StateTraceStack:
141
139
  return state_trace
142
140
 
143
141
 
144
- class StatefulFunction(object):
142
+ class StatefulFunction(PrettyObject):
145
143
  """
146
144
  A wrapper class for a function that collects the states that are read and written by the function. The states are
147
145
  collected by the function and returned as a StateDictManager instance. The StateDictManager instance can be used to
@@ -189,6 +187,7 @@ class StatefulFunction(object):
189
187
  abstracted_axes: Optional[Any] = None,
190
188
  state_returns: Union[str, Tuple[str, ...]] = ('read', 'write'),
191
189
  cache_type: Optional[str] = None,
190
+ name: Optional[str] = None,
192
191
  ):
193
192
  # explicit parameters
194
193
  self.fun = fun
@@ -197,6 +196,7 @@ class StatefulFunction(object):
197
196
  self.abstracted_axes = abstracted_axes
198
197
  self.state_returns = tuple(state_returns) if isinstance(state_returns, (tuple, list)) else (state_returns,)
199
198
  assert cache_type in [None, 'jit']
199
+ self.name = name
200
200
 
201
201
  # implicit parameters
202
202
  self.cache_type = cache_type
@@ -205,12 +205,10 @@ class StatefulFunction(object):
205
205
  self._cached_jaxpr_out_tree: Dict[Any, PyTree] = dict()
206
206
  self._cached_state_trace: Dict[Any, StateTraceStack] = dict()
207
207
 
208
- def __repr__(self) -> str:
209
- return (f"{self.__class__.__name__}("
210
- f"static_argnums={self.static_argnums}, "
211
- f"axis_env={self.axis_env}, "
212
- f"abstracted_axes={self.abstracted_axes}, "
213
- f"state_returns={self.state_returns})")
208
+ def __pretty_repr_item__(self, k, v):
209
+ if k.startswith('_'):
210
+ return None
211
+ return k, v
214
212
 
215
213
  def get_jaxpr(self, cache_key: Hashable = ()) -> jax.core.ClosedJaxpr:
216
214
  """
@@ -388,7 +386,7 @@ class StatefulFunction(object):
388
386
  A tuple of the states that are read and written by the function and the output of the function.
389
387
  """
390
388
  # state trace
391
- state_trace = _init_state_trace_stack()
389
+ state_trace = _init_state_trace_stack(self.name)
392
390
  self._cached_state_trace[cache_key] = state_trace
393
391
  with state_trace:
394
392
  out = self.fun(*args, **kwargs)
@@ -497,11 +495,7 @@ class StatefulFunction(object):
497
495
  """
498
496
  state_trace = self.get_state_trace(self.get_arg_cache_key(*args, **kwargs))
499
497
  state_vals, out = self.jaxpr_call([st.value for st in state_trace.states], *args, **kwargs)
500
- for st, written, val in zip(state_trace.states, state_trace.been_writen, state_vals):
501
- if written:
502
- st.value = val
503
- else:
504
- st.restore_value(val)
498
+ state_trace.assign_state_vals(state_vals)
505
499
  return out
506
500
 
507
501
 
@@ -592,7 +586,15 @@ def make_jaxpr(
592
586
  in (g,) }
593
587
  """
594
588
 
595
- stateful_fun = StatefulFunction(fun, static_argnums, axis_env, abstracted_axes, state_returns)
589
+ stateful_fun = StatefulFunction(
590
+ fun,
591
+ static_argnums=static_argnums,
592
+ axis_env=axis_env,
593
+ abstracted_axes=abstracted_axes,
594
+ state_returns=state_returns,
595
+ name='make_jaxpr'
596
+
597
+ )
596
598
 
597
599
  @wraps(fun)
598
600
  def make_jaxpr_f(*args, **kwargs):
@@ -15,11 +15,10 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import jax
21
19
  import jax.numpy as jnp
22
20
  import pytest
21
+ import unittest
23
22
 
24
23
  import brainstate as bst
25
24
 
@@ -17,9 +17,8 @@ from __future__ import annotations
17
17
 
18
18
  import copy
19
19
  import importlib.util
20
- from typing import Optional, Callable, Any, Tuple, Dict
21
-
22
20
  import jax
21
+ from typing import Optional, Callable, Any, Tuple, Dict
23
22
 
24
23
  tqdm_installed = importlib.util.find_spec('tqdm') is not None
25
24
 
@@ -19,6 +19,7 @@ import jax.core
19
19
  import jax.interpreters.batching as batching
20
20
  import jax.interpreters.mlir as mlir
21
21
  import jax.numpy as jnp
22
+
22
23
  from brainstate._utils import set_module_as
23
24
 
24
25
  if jax.__version_info__ < (0, 4, 38):
@@ -132,8 +132,10 @@ def wrap_single_fun(
132
132
  assert len(been_writen) == len(writen_state_vals) == len(read_state_vals)
133
133
 
134
134
  # collect all written and read states
135
- state_vals = [written_val if written else read_val
136
- for written, written_val, read_val in zip(been_writen, writen_state_vals, read_state_vals)]
135
+ state_vals = [
136
+ written_val if written else read_val
137
+ for written, written_val, read_val in zip(been_writen, writen_state_vals, read_state_vals)
138
+ ]
137
139
 
138
140
  # call the jaxpr
139
141
  state_vals, (carry, out) = stateful_fun.jaxpr_call(state_vals, carry, inputs)
brainstate/environ.py CHANGED
@@ -17,18 +17,18 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
+ from collections import defaultdict
21
+
20
22
  import contextlib
21
23
  import dataclasses
22
24
  import functools
25
+ import numpy as np
23
26
  import os
24
27
  import re
25
28
  import threading
26
- from collections import defaultdict
27
- from typing import Any, Callable, Dict, Hashable
28
-
29
- import numpy as np
30
29
  from jax import config, devices, numpy as jnp
31
30
  from jax.typing import DTypeLike
31
+ from typing import Any, Callable, Dict, Hashable
32
32
 
33
33
  from .mixin import Mode
34
34
 
@@ -14,9 +14,8 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- import unittest
18
-
19
17
  import jax.numpy as jnp
18
+ import unittest
20
19
 
21
20
  import brainstate as bst
22
21
 
@@ -20,11 +20,10 @@ Shared neural network activations and other functions.
20
20
 
21
21
  from __future__ import annotations
22
22
 
23
- from typing import Any, Union, Sequence
24
-
25
23
  import brainunit as u
26
24
  import jax
27
25
  from jax.scipy.special import logsumexp
26
+ from typing import Any, Union, Sequence
28
27
 
29
28
  from brainstate import random
30
29
  from brainstate.typing import ArrayLike
@@ -16,12 +16,12 @@
16
16
  """Tests for nn module."""
17
17
 
18
18
  import itertools
19
- from functools import partial
20
19
 
21
20
  import jax
22
21
  import jax.numpy as jnp
23
22
  import scipy.stats
24
23
  from absl.testing import parameterized
24
+ from functools import partial
25
25
  from jax._src import test_util as jtu
26
26
  from jax.test_util import check_grads
27
27
 
@@ -15,10 +15,9 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- from typing import Optional, Union
19
-
20
18
  import brainunit as u
21
19
  import jax
20
+ from typing import Optional, Union
22
21
 
23
22
  from brainstate._utils import set_module_as
24
23
  from brainstate.typing import ArrayLike
@@ -15,10 +15,9 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- from functools import partial
19
-
20
18
  import jax
21
19
  import jax.numpy as jnp
20
+ from functools import partial
22
21
 
23
22
  from brainstate.typing import PyTree
24
23
 
@@ -27,53 +27,169 @@ __all__ = [
27
27
 
28
28
 
29
29
  def spike_bitwise_or(x, y):
30
- """Bitwise OR operation for spike tensors."""
30
+ """
31
+ Perform a bitwise OR operation on spike tensors.
32
+
33
+ This function computes the OR operation between two spike tensors.
34
+ The OR operation is implemented using the formula: x + y - x * y,
35
+ which is equivalent to the OR operation for binary values.
36
+
37
+ Args:
38
+ x (Tensor): The first input spike tensor.
39
+ y (Tensor): The second input spike tensor.
40
+
41
+ Returns:
42
+ Tensor: The result of the bitwise OR operation applied to the input tensors.
43
+ The output tensor has the same shape as the input tensors.
44
+
45
+ Note:
46
+ This operation assumes that the input tensors contain binary (0 or 1) values.
47
+ For non-binary inputs, the behavior may not correspond to a true bitwise OR.
48
+ """
31
49
  return x + y - x * y
32
50
 
33
51
 
34
52
  def spike_bitwise_and(x, y):
35
- """Bitwise AND operation for spike tensors."""
53
+ """
54
+ Perform a bitwise AND operation on spike tensors.
55
+
56
+ This function computes the AND operation between two spike tensors.
57
+ The AND operation is equivalent to element-wise multiplication for binary values.
58
+
59
+ Args:
60
+ x (Tensor): The first input spike tensor.
61
+ y (Tensor): The second input spike tensor.
62
+
63
+ Returns:
64
+ Tensor: The result of the bitwise AND operation applied to the input tensors.
65
+ The output tensor has the same shape as the input tensors.
66
+
67
+ Note:
68
+ This operation is implemented using element-wise multiplication (x * y),
69
+ which is equivalent to the AND operation for binary values.
70
+ """
36
71
  return x * y
37
72
 
38
73
 
39
74
  def spike_bitwise_iand(x, y):
40
- """Bitwise IAND operation for spike tensors."""
75
+ """
76
+ Perform a bitwise IAND (Inverse AND) operation on spike tensors.
77
+
78
+ This function computes the Inverse AND (IAND) operation between two spike tensors.
79
+ IAND is defined as (NOT x) AND y.
80
+
81
+ Args:
82
+ x (Tensor): The first input spike tensor.
83
+ y (Tensor): The second input spike tensor.
84
+
85
+ Returns:
86
+ Tensor: The result of the bitwise IAND operation applied to the input tensors.
87
+ The output tensor has the same shape as the input tensors.
88
+
89
+ Note:
90
+ This operation is implemented using the formula: (1 - x) * y,
91
+ which is equivalent to the IAND operation for binary values.
92
+ """
41
93
  return (1 - x) * y
42
94
 
43
95
 
44
96
  def spike_bitwise_not(x):
45
- """Bitwise NOT operation for spike tensors."""
97
+ """
98
+ Perform a bitwise NOT operation on spike tensors.
99
+
100
+ This function computes the NOT operation on a spike tensor.
101
+ The NOT operation inverts the binary values in the tensor.
102
+
103
+ Args:
104
+ x (Tensor): The input spike tensor.
105
+
106
+ Returns:
107
+ Tensor: The result of the bitwise NOT operation applied to the input tensor.
108
+ The output tensor has the same shape as the input tensor.
109
+
110
+ Note:
111
+ This operation is implemented using the formula: 1 - x,
112
+ which is equivalent to the NOT operation for binary values.
113
+ """
46
114
  return 1 - x
47
115
 
48
116
 
49
117
  def spike_bitwise_xor(x, y):
50
- """Bitwise XOR operation for spike tensors."""
118
+ """
119
+ Perform a bitwise XOR operation on spike tensors.
120
+
121
+ This function computes the XOR operation between two spike tensors.
122
+ XOR is defined as (x OR y) AND NOT (x AND y).
123
+
124
+ Args:
125
+ x (Tensor): The first input spike tensor.
126
+ y (Tensor): The second input spike tensor.
127
+
128
+ Returns:
129
+ Tensor: The result of the bitwise XOR operation applied to the input tensors.
130
+ The output tensor has the same shape as the input tensors.
131
+
132
+ Note:
133
+ This operation is implemented using the formula: x + y - 2 * x * y,
134
+ which is equivalent to the XOR operation for binary values.
135
+ """
51
136
  return x + y - 2 * x * y
52
137
 
53
138
 
54
139
  def spike_bitwise_ixor(x, y):
55
- """Bitwise IXOR operation for spike tensors."""
140
+ """
141
+ Perform a bitwise IXOR (Inverse XOR) operation on spike tensors.
142
+
143
+ This function computes the Inverse XOR (IXOR) operation between two spike tensors.
144
+ IXOR is defined as (x AND NOT y) OR (NOT x AND y).
145
+
146
+ Args:
147
+ x (Tensor): The first input spike tensor.
148
+ y (Tensor): The second input spike tensor.
149
+
150
+ Returns:
151
+ Tensor: The result of the bitwise IXOR operation applied to the input tensors.
152
+ The output tensor has the same shape as the input tensors.
153
+
154
+ Note:
155
+ This operation is implemented using the formula: x * (1 - y) + (1 - x) * y,
156
+ which is equivalent to the IXOR operation for binary values.
157
+ """
56
158
  return x * (1 - y) + (1 - x) * y
57
159
 
58
160
 
59
161
  def spike_bitwise(x, y, op: str):
60
- r"""Bitwise operation for spike tensors.
61
-
62
- .. math::
162
+ """
163
+ Perform bitwise operations on spike tensors.
63
164
 
64
- \begin{array}{ccc}
65
- \hline \text { Mode } & \text { Expression for } \mathrm{g}(\mathrm{x}, \mathrm{y}) & \text { Code for } \mathrm{g}(\mathrm{x}, \mathrm{y}) \\
66
- \hline \text { ADD } & x+y & x+y \\
67
- \text { AND } & x \cap y & x \cdot y \\
68
- \text { IAND } & (\neg x) \cap y & (1-x) \cdot y \\
69
- \text { OR } & x \cup y & (x+y)-(x \cdot y) \\
70
- \hline
71
- \end{array}
165
+ This function applies various bitwise operations on spike tensors based on the specified operation.
166
+ It supports 'or', 'and', 'iand', 'xor', and 'ixor' operations.
72
167
 
73
168
  Args:
74
- x: A spike tensor.
75
- y: A spike tensor.
76
- op: A string indicating the bitwise operation to perform.
169
+ x (Tensor): The first input spike tensor.
170
+ y (Tensor): The second input spike tensor.
171
+ op (str): A string indicating the bitwise operation to perform.
172
+ Supported operations are 'or', 'and', 'iand', 'xor', and 'ixor'.
173
+
174
+ Returns:
175
+ Tensor: The result of the bitwise operation applied to the input tensors.
176
+
177
+ Raises:
178
+ NotImplementedError: If an unsupported bitwise operation is specified.
179
+
180
+ Note:
181
+ The function uses the following mathematical expressions for different operations:
182
+
183
+ .. math::
184
+
185
+ \begin{array}{ccc}
186
+ \hline \text { Mode } & \text { Expression for } \mathrm{g}(\mathrm{x}, \mathrm{y}) & \text { Code for } \mathrm{g}(\mathrm{x}, \mathrm{y}) \\
187
+ \hline \text { ADD } & x+y & x+y \\
188
+ \text { AND } & x \cap y & x \cdot y \\
189
+ \text { IAND } & (\neg x) \cap y & (1-x) \cdot y \\
190
+ \text { OR } & x \cup y & (x+y)-(x \cdot y) \\
191
+ \hline
192
+ \end{array}
77
193
  """
78
194
  if op == 'or':
79
195
  return spike_bitwise_or(x, y)
@@ -27,7 +27,7 @@ import numpy as np
27
27
 
28
28
  from brainstate._state import State, TreefyState
29
29
  from brainstate.typing import Key
30
- from brainstate.util._pretty_repr import PrettyRepr, yield_unique_pretty_repr_items, PrettyType, PrettyAttr
30
+ from brainstate.util._pretty_pytree import PrettyObject
31
31
  from ._graph_operation import register_graph_node_type
32
32
 
33
33
  __all__ = [
@@ -46,7 +46,7 @@ class GraphNodeMeta(ABCMeta):
46
46
  return node
47
47
 
48
48
 
49
- class Node(PrettyRepr, metaclass=GraphNodeMeta):
49
+ class Node(PrettyObject, metaclass=GraphNodeMeta):
50
50
  """
51
51
  Base class for all graph nodes.
52
52
 
@@ -84,47 +84,6 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
84
84
  state = deepcopy(state)
85
85
  return treefy_merge(graphdef, state)
86
86
 
87
- def __pretty_repr__(self):
88
- """
89
- Pretty repr for the object.
90
- """
91
- yield from yield_unique_pretty_repr_items(self, _default_repr_object, _default_repr_attr)
92
-
93
- def __treescope_repr__(self, path, subtree_renderer):
94
- """
95
- Treescope repr for the object.
96
- """
97
- children = {}
98
- for name, value in vars(self).items():
99
- name, value = self.__leaf_fn__(name, value)
100
- if name.startswith('_'):
101
- continue
102
- children[name] = value
103
- import treescope # type: ignore[import-not-found,import-untyped]
104
- return treescope.repr_lib.render_object_constructor(
105
- object_type=type(self),
106
- attributes=children,
107
- path=path,
108
- subtree_renderer=subtree_renderer,
109
- color=treescope.formatting_util.color_from_string(type(self).__qualname__)
110
- )
111
-
112
- def __leaf_fn__(self, leaf, value):
113
- return leaf, value
114
-
115
-
116
- def _default_repr_object(node: Node):
117
- yield PrettyType(type=type(node))
118
-
119
-
120
- def _default_repr_attr(node: Node):
121
- for name, value in vars(node).items():
122
- name, value = node.__leaf_fn__(name, value)
123
- if name.startswith('_'):
124
- continue
125
- # value = jax.tree.map(_to_shape_dtype, value, is_leaf=lambda x: isinstance(x, u.Quantity))
126
- yield PrettyAttr(name, repr(value))
127
-
128
87
 
129
88
  class String:
130
89
  def __init__(self, msg):
@@ -18,21 +18,20 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  import dataclasses
21
- from typing import (Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
22
- Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional, overload)
23
-
24
21
  import jax
25
22
  import numpy as np
23
+ from typing import (Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
24
+ Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional, overload)
26
25
  from typing_extensions import TypeGuard, Unpack
27
26
 
28
27
  from brainstate._state import State, TreefyState
29
28
  from brainstate._utils import set_module_as
30
29
  from brainstate.typing import PathParts, Filter, Predicate, Key
31
30
  from brainstate.util._caller import ApplyCaller, CallableProxy, DelayedAccessor
32
- from brainstate.util._dict import NestedDict, FlattedDict, PrettyDict
33
- from brainstate.util._filter import to_predicate
31
+ from brainstate.util._pretty_pytree import NestedDict, FlattedDict, PrettyDict
34
32
  from brainstate.util._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
35
33
  from brainstate.util._struct import FrozenDict
34
+ from brainstate.util.filter import to_predicate
36
35
 
37
36
  _max_int = np.iinfo(np.int32).max
38
37
 
@@ -347,21 +346,6 @@ class NodeDef(GraphDef[Node], PrettyRepr):
347
346
  yield PrettyAttr('metadata', self.metadata)
348
347
  yield PrettyAttr('index_mapping', PrettyMapping(self.index_mapping) if self.index_mapping is not None else None)
349
348
 
350
- def __treescope_repr__(self, path, subtree_renderer):
351
- import treescope # type: ignore[import-not-found,import-untyped]
352
- return treescope.repr_lib.render_object_constructor(
353
- object_type=type(self),
354
- attributes={'type': self.type,
355
- 'index': self.index,
356
- 'attributes': self.attributes,
357
- 'subgraphs': dict(self.subgraphs),
358
- 'static_fields': dict(self.static_fields),
359
- 'leaves': dict(self.leaves),
360
- 'metadata': self.metadata, },
361
- path=path,
362
- subtree_renderer=subtree_renderer,
363
- )
364
-
365
349
  def apply(
366
350
  self,
367
351
  state_map: GraphStateMapping,
@@ -15,13 +15,12 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
- from collections.abc import Callable
20
- from threading import Thread
21
-
22
18
  import jax
23
19
  import jax.numpy as jnp
20
+ import unittest
24
21
  from absl.testing import absltest, parameterized
22
+ from collections.abc import Callable
23
+ from threading import Thread
25
24
 
26
25
  import brainstate as bst
27
26
 
brainstate/init/_base.py CHANGED
@@ -15,9 +15,8 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- from typing import Optional, Tuple
19
-
20
18
  import numpy as np
19
+ from typing import Optional, Tuple
21
20
 
22
21
  from brainstate.util import PrettyRepr, PrettyType, PrettyAttr
23
22
 
@@ -16,11 +16,10 @@
16
16
  # -*- coding: utf-8 -*-
17
17
  from __future__ import annotations
18
18
 
19
- from typing import Union, Callable, Optional, Sequence
20
-
21
19
  import brainunit as bu
22
20
  import jax
23
21
  import numpy as np
22
+ from typing import Union, Callable, Optional, Sequence
24
23
 
25
24
  from brainstate._state import State
26
25
  from brainstate._utils import set_module_as
brainstate/nn/__init__.py CHANGED
@@ -17,6 +17,8 @@
17
17
  from . import metrics
18
18
  from ._collective_ops import *
19
19
  from ._collective_ops import __all__ as collective_ops_all
20
+ from ._common import *
21
+ from ._common import __all__ as common_all
20
22
  from ._dyn_impl import *
21
23
  from ._dyn_impl import __all__ as dyn_impl_all
22
24
  from ._dynamics import *
@@ -29,24 +31,30 @@ from ._interaction import *
29
31
  from ._interaction import __all__ as interaction_all
30
32
  from ._module import *
31
33
  from ._module import __all__ as module_all
34
+ from ._utils import *
35
+ from ._utils import __all__ as utils_all
32
36
 
33
37
  __all__ = (
34
38
  ['metrics']
35
39
  + collective_ops_all
40
+ + common_all
36
41
  + dyn_impl_all
37
42
  + dynamics_all
38
43
  + elementwise_all
39
44
  + module_all
40
45
  + exp_euler_all
41
46
  + interaction_all
47
+ + utils_all
42
48
  )
43
49
 
44
50
  del (
45
51
  collective_ops_all,
52
+ common_all,
46
53
  dyn_impl_all,
47
54
  dynamics_all,
48
55
  elementwise_all,
49
56
  module_all,
50
57
  exp_euler_all,
51
58
  interaction_all,
59
+ utils_all,
52
60
  )