brainstate 0.1.0.post20250211__py2.py3-none-any.whl → 0.1.0.post20250216__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96):
  1. brainstate/_state.py +875 -93
  2. brainstate/_state_test.py +1 -3
  3. brainstate/augment/__init__.py +2 -2
  4. brainstate/augment/_autograd.py +257 -115
  5. brainstate/augment/_autograd_test.py +2 -3
  6. brainstate/augment/_eval_shape.py +3 -4
  7. brainstate/augment/_mapping.py +582 -62
  8. brainstate/augment/_mapping_test.py +114 -30
  9. brainstate/augment/_random.py +61 -7
  10. brainstate/compile/_ad_checkpoint.py +2 -3
  11. brainstate/compile/_conditions.py +4 -5
  12. brainstate/compile/_conditions_test.py +1 -2
  13. brainstate/compile/_error_if.py +1 -2
  14. brainstate/compile/_error_if_test.py +1 -2
  15. brainstate/compile/_jit.py +23 -16
  16. brainstate/compile/_jit_test.py +1 -2
  17. brainstate/compile/_loop_collect_return.py +18 -10
  18. brainstate/compile/_loop_collect_return_test.py +1 -1
  19. brainstate/compile/_loop_no_collection.py +5 -5
  20. brainstate/compile/_make_jaxpr.py +23 -21
  21. brainstate/compile/_make_jaxpr_test.py +1 -2
  22. brainstate/compile/_progress_bar.py +1 -2
  23. brainstate/compile/_unvmap.py +1 -0
  24. brainstate/compile/_util.py +4 -2
  25. brainstate/environ.py +4 -4
  26. brainstate/environ_test.py +1 -2
  27. brainstate/functional/_activations.py +1 -2
  28. brainstate/functional/_activations_test.py +1 -1
  29. brainstate/functional/_normalization.py +1 -2
  30. brainstate/functional/_others.py +1 -2
  31. brainstate/functional/_spikes.py +136 -20
  32. brainstate/graph/_graph_node.py +2 -43
  33. brainstate/graph/_graph_operation.py +4 -20
  34. brainstate/graph/_graph_operation_test.py +3 -4
  35. brainstate/init/_base.py +1 -2
  36. brainstate/init/_generic.py +1 -2
  37. brainstate/nn/__init__.py +4 -0
  38. brainstate/nn/_collective_ops.py +351 -48
  39. brainstate/nn/_collective_ops_test.py +36 -0
  40. brainstate/nn/_common.py +194 -0
  41. brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
  42. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
  43. brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
  44. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
  45. brainstate/nn/_dyn_impl/_inputs.py +1 -2
  46. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
  47. brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
  48. brainstate/nn/_dyn_impl/_readout.py +2 -3
  49. brainstate/nn/_dyn_impl/_readout_test.py +1 -2
  50. brainstate/nn/_dynamics/_dynamics_base.py +2 -3
  51. brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
  52. brainstate/nn/_dynamics/_state_delay.py +3 -3
  53. brainstate/nn/_dynamics/_synouts_test.py +1 -2
  54. brainstate/nn/_elementwise/_dropout.py +6 -7
  55. brainstate/nn/_elementwise/_dropout_test.py +1 -2
  56. brainstate/nn/_elementwise/_elementwise.py +1 -2
  57. brainstate/nn/_exp_euler.py +1 -2
  58. brainstate/nn/_exp_euler_test.py +1 -2
  59. brainstate/nn/_interaction/_conv.py +1 -2
  60. brainstate/nn/_interaction/_conv_test.py +1 -0
  61. brainstate/nn/_interaction/_linear.py +1 -2
  62. brainstate/nn/_interaction/_linear_test.py +1 -2
  63. brainstate/nn/_interaction/_normalizations.py +1 -2
  64. brainstate/nn/_interaction/_poolings.py +3 -4
  65. brainstate/nn/_module.py +63 -19
  66. brainstate/nn/_module_test.py +1 -2
  67. brainstate/nn/metrics.py +3 -4
  68. brainstate/optim/_lr_scheduler.py +1 -2
  69. brainstate/optim/_lr_scheduler_test.py +2 -3
  70. brainstate/optim/_optax_optimizer_test.py +1 -2
  71. brainstate/optim/_sgd_optimizer.py +2 -3
  72. brainstate/random/_rand_funs.py +1 -2
  73. brainstate/random/_rand_funs_test.py +2 -3
  74. brainstate/random/_rand_seed.py +2 -3
  75. brainstate/random/_rand_seed_test.py +1 -2
  76. brainstate/random/_rand_state.py +3 -4
  77. brainstate/surrogate.py +183 -35
  78. brainstate/transform.py +0 -3
  79. brainstate/typing.py +28 -25
  80. brainstate/util/__init__.py +9 -7
  81. brainstate/util/_caller.py +1 -2
  82. brainstate/util/_error.py +27 -0
  83. brainstate/util/_others.py +60 -15
  84. brainstate/util/{_dict.py → _pretty_pytree.py} +108 -29
  85. brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
  86. brainstate/util/_pretty_repr.py +128 -10
  87. brainstate/util/_pretty_table.py +2900 -0
  88. brainstate/util/_struct.py +11 -11
  89. brainstate/util/filter.py +472 -0
  90. {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/METADATA +2 -2
  91. brainstate-0.1.0.post20250216.dist-info/RECORD +127 -0
  92. brainstate/util/_filter.py +0 -178
  93. brainstate-0.1.0.post20250211.dist-info/RECORD +0 -124
  94. {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/LICENSE +0 -0
  95. {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/WHEEL +0 -0
  96. {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/top_level.txt +0 -0
@@ -18,18 +18,23 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  from collections import abc
21
- from typing import TypeVar, Hashable, Union, Iterable, Any, Optional, Tuple, Dict
22
21
 
23
22
  import jax
23
+ from typing import TypeVar, Hashable, Union, Iterable, Any, Optional, Tuple, Dict
24
24
 
25
25
  from brainstate.typing import Filter, PathParts
26
- from ._filter import to_predicate
27
- from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, yield_unique_pretty_repr_items, pretty_repr
26
+ from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, yield_unique_pretty_repr_items, pretty_repr_object
28
27
  from ._struct import dataclass
28
+ from .filter import to_predicate
29
29
 
30
30
  __all__ = [
31
- 'PrettyDict', 'NestedDict', 'FlattedDict', 'flat_mapping', 'nest_mapping',
32
- 'PrettyList', 'PrettyReprTree',
31
+ 'PrettyDict',
32
+ 'NestedDict',
33
+ 'FlattedDict',
34
+ 'flat_mapping',
35
+ 'nest_mapping',
36
+ 'PrettyList',
37
+ 'PrettyObject',
33
38
  ]
34
39
 
35
40
  A = TypeVar('A')
@@ -41,42 +46,117 @@ ExtractValueFn = abc.Callable[[Any], Any]
41
46
  SetValueFn = abc.Callable[[V, Any], V]
42
47
 
43
48
 
49
+ def _repr_object_general(node: PrettyDict):
50
+ """
51
+ Generate a general representation of a PrettyDict object.
52
+
53
+ This function is used to create a pretty representation of a PrettyDict
54
+ object, which includes the type of the object and its value separator.
44
55
 
56
+ Args:
57
+ node (PrettyDict): The PrettyDict object to be represented.
45
58
 
46
- class PrettyReprTree(PrettyRepr):
59
+ Yields:
60
+ PrettyType: A PrettyType object representing the type of the node,
61
+ with specified value separator, start, and end characters.
47
62
  """
48
- Pretty representation of a tree.
63
+ yield PrettyType(type(node), value_sep='=', start='(', end=')')
64
+
65
+
66
+ def _repr_attribute_general(node):
67
+ """
68
+ Generate a pretty representation of the attributes of a node.
69
+
70
+ This function iterates over the attributes of a given node and attempts
71
+ to generate a pretty representation for each attribute. It handles
72
+ conversion of lists and dictionaries to their pretty representation
73
+ counterparts and yields a PrettyAttr object for each attribute.
74
+
75
+ Args:
76
+ node: The object whose attributes are to be represented.
77
+
78
+ Yields:
79
+ PrettyAttr: A PrettyAttr object representing the key and value of
80
+ each attribute in a pretty format.
81
+ """
82
+ for k, v in vars(node).items():
83
+ try:
84
+ res = node.__pretty_repr_item__(k, v)
85
+ if res is None:
86
+ continue
87
+ k, v = res
88
+ except AttributeError:
89
+ pass
90
+
91
+ if k is None:
92
+ continue
93
+
94
+ # convert list to PrettyList
95
+ if isinstance(v, list):
96
+ v = PrettyList(v)
97
+
98
+ # convert dict to PrettyDict
99
+ if isinstance(v, dict):
100
+ v = PrettyDict(v)
101
+
102
+ # convert PrettyDict to NestedStateRepr
103
+ if isinstance(v, PrettyDict):
104
+ v = NestedStateRepr(v)
105
+
106
+ yield PrettyAttr(k, v)
107
+
108
+
109
+ class PrettyObject(PrettyRepr):
110
+ """
111
+ A class for generating a pretty representation of a tree-like structure.
112
+
113
+ This class extends the PrettyRepr class to provide a mechanism for
114
+ generating a human-readable, pretty representation of tree-like data
115
+ structures. It utilizes custom functions to represent the object and
116
+ its attributes in a structured and visually appealing format.
117
+
118
+ Methods
119
+ -------
120
+ __pretty_repr__: Generates a sequence of pretty representation items
121
+ for the object.
122
+ __pretty_repr_item__: Returns a tuple of the key and value for pretty
123
+ representation of an item in the data structure.
49
124
  """
50
125
 
51
126
  def __pretty_repr__(self):
52
- return yield_unique_pretty_repr_items(
127
+ """
128
+ Generates a pretty representation of the object.
129
+
130
+ This method yields a sequence of pretty representation items for the object,
131
+ using specified functions to represent the object and its attributes.
132
+
133
+ Yields:
134
+ Pretty representation items generated by `yield_unique_pretty_repr_items`.
135
+ """
136
+ yield from yield_unique_pretty_repr_items(
53
137
  self,
54
- repr_object=self._repr_object,
55
- repr_attr=self._repr_attr,
138
+ repr_object=_repr_object_general,
139
+ repr_attr=_repr_attribute_general,
56
140
  )
57
141
 
58
142
  def __pretty_repr_item__(self, k, v):
59
- return k, v
60
-
61
- def _repr_object(self, node: PrettyDict):
62
- yield PrettyType(type(node), value_sep=': ', start='({', end='})')
143
+ """
144
+ Returns a tuple of the key and value for pretty representation.
63
145
 
64
- def _repr_attr(self, node):
65
- for k, v in vars(node).items():
66
- k, v = self.__pretty_repr_item__(k, v)
67
- if k is None:
68
- continue
146
+ This method is used to generate a pretty representation of an item
147
+ in a data structure, typically for debugging or logging purposes.
69
148
 
70
- if isinstance(v, list):
71
- v = PrettyList(v)
149
+ Args:
150
+ k: The key of the item.
151
+ v: The value of the item.
72
152
 
73
- if isinstance(v, dict):
74
- v = PrettyDict(v)
153
+ Returns:
154
+ A tuple containing the key and value.
155
+ """
156
+ return k, v
75
157
 
76
- if isinstance(v, PrettyDict):
77
- v = NestedStateRepr(v)
78
158
 
79
- yield PrettyAttr(repr(k), v)
159
+ PrettyReprTree = PrettyObject
80
160
 
81
161
 
82
162
  # the empty node is a struct.dataclass to be compatible with JAX.
@@ -252,7 +332,7 @@ class PrettyDict(dict, PrettyRepr):
252
332
 
253
333
  def __repr__(self) -> str:
254
334
  # repr the individual object with the pretty representation
255
- return pretty_repr(self)
335
+ return pretty_repr_object(self)
256
336
 
257
337
  def __pretty_repr__(self):
258
338
  yield from yield_unique_pretty_repr_items(self, _default_repr_object, _default_repr_attr)
@@ -789,7 +869,7 @@ class PrettyList(list, PrettyRepr):
789
869
  yield from yield_unique_pretty_repr_items(self, _list_repr_object, _list_repr_attr)
790
870
 
791
871
  def __repr__(self):
792
- return pretty_repr(self)
872
+ return pretty_repr_object(self)
793
873
 
794
874
  def tree_flatten(self):
795
875
  return list(self), ()
@@ -812,4 +892,3 @@ def _list_repr_attr(node: PrettyList):
812
892
 
813
893
  def _list_repr_object(node: PrettyDict):
814
894
  yield PrettyType('', value_sep='', start='[', end=']')
815
-
@@ -15,9 +15,8 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import jax
19
+ import unittest
21
20
  from absl.testing import absltest
22
21
 
23
22
  import brainstate as bst
@@ -21,11 +21,10 @@ import dataclasses
21
21
  import threading
22
22
  from abc import ABC, abstractmethod
23
23
  from functools import partial
24
- from typing import Any, Iterator, Mapping, TypeVar, Union, Callable, Optional, Sequence
24
+ from typing import Any, Iterator, Mapping, TypeVar, Union, Callable, Optional
25
25
 
26
26
  __all__ = [
27
27
  'yield_unique_pretty_repr_items',
28
- 'pretty_repr',
29
28
  'PrettyType',
30
29
  'PrettyAttr',
31
30
  'PrettyRepr',
@@ -82,10 +81,37 @@ class PrettyRepr(ABC):
82
81
 
83
82
  def __repr__(self) -> str:
84
83
  # repr the individual object with the pretty representation
85
- return pretty_repr(self)
84
+ return pretty_repr_object(self)
86
85
 
87
86
 
88
- def _repr_elem(obj: PrettyType, elem: Any) -> str:
87
+ def pretty_repr_elem(obj: PrettyType, elem: Any) -> str:
88
+ """
89
+ Constructs a string representation of a single element within a pretty representation.
90
+
91
+ This function takes a `PrettyType` object and an element, which must be an instance
92
+ of `PrettyAttr`, and generates a formatted string that represents the element. The
93
+ formatting is based on the configuration provided by the `PrettyType` object.
94
+
95
+ Parameters
96
+ ----------
97
+ obj : PrettyType
98
+ The configuration object that defines how the element should be formatted.
99
+ It includes details such as indentation, separators, and surrounding characters.
100
+ elem : Any
101
+ The element to be represented. It must be an instance of `PrettyAttr`, which
102
+ contains the key and value to be formatted.
103
+
104
+ Returns
105
+ -------
106
+ str
107
+ A string that represents the element in a formatted manner, adhering to the
108
+ configuration specified by the `PrettyType` object.
109
+
110
+ Raises
111
+ ------
112
+ TypeError
113
+ If the provided element is not an instance of `PrettyAttr`.
114
+ """
89
115
  if not isinstance(elem, PrettyAttr):
90
116
  raise TypeError(f'Item must be Elem, got {type(elem).__name__}')
91
117
 
@@ -95,9 +121,32 @@ def _repr_elem(obj: PrettyType, elem: Any) -> str:
95
121
  return f'{obj.elem_indent}{elem.start}{elem.key}{obj.value_sep}{value}{elem.end}'
96
122
 
97
123
 
98
- def pretty_repr(obj: PrettyRepr) -> str:
124
+ def pretty_repr_object(obj: PrettyRepr) -> str:
99
125
  """
100
- Get the pretty representation of an object.
126
+ Generates a pretty string representation of an object that implements the PrettyRepr interface.
127
+
128
+ This function utilizes the __pretty_repr__ method of the PrettyRepr interface to obtain
129
+ a structured representation of the object, which includes both the type and attributes
130
+ of the object in a human-readable format.
131
+
132
+ Parameters
133
+ ----------
134
+ obj : PrettyRepr
135
+ The object for which the pretty representation is to be generated. The object must
136
+ implement the PrettyRepr interface.
137
+
138
+ Returns
139
+ -------
140
+ str
141
+ A string that represents the object in a pretty format, including its type and attributes.
142
+ The format is determined by the PrettyType and PrettyAttr instances yielded by the
143
+ __pretty_repr__ method of the object.
144
+
145
+ Raises
146
+ ------
147
+ TypeError
148
+ If the provided object does not implement the PrettyRepr interface or if the first item
149
+ yielded by the __pretty_repr__ method is not an instance of PrettyType.
101
150
  """
102
151
  if not isinstance(obj, PrettyRepr):
103
152
  raise TypeError(f'Object {obj!r} is not representable')
@@ -110,7 +159,7 @@ def pretty_repr(obj: PrettyRepr) -> str:
110
159
  raise TypeError(f'First item must be PrettyType, got {type(obj_repr).__name__}')
111
160
 
112
161
  # repr attributes
113
- elem_reprs = tuple(map(partial(_repr_elem, obj_repr), iterator))
162
+ elem_reprs = tuple(map(partial(pretty_repr_elem, obj_repr), iterator))
114
163
  elems = ',\n'.join(elem_reprs)
115
164
  if elems:
116
165
  elems = '\n' + elems + '\n'
@@ -153,7 +202,20 @@ class PrettyMapping(PrettyRepr):
153
202
 
154
203
  @dataclasses.dataclass
155
204
  class PrettyReprContext(threading.local):
156
- # seen_modules_repr: set[int] | None = None
205
+ """
206
+ A thread-local context for managing the state of pretty representation.
207
+
208
+ This class is used to keep track of objects that have been seen during
209
+ the generation of pretty representations, preventing infinite recursion
210
+ in cases of circular references.
211
+
212
+ Attributes
213
+ ----------
214
+ seen_modules_repr : dict[int, Any] | None
215
+ A dictionary mapping object IDs to objects that have been seen
216
+ during the pretty representation process. This is used to avoid
217
+ representing the same object multiple times.
218
+ """
157
219
  seen_modules_repr: dict[int, Any] | None = None
158
220
 
159
221
 
@@ -161,10 +223,47 @@ CONTEXT = PrettyReprContext()
161
223
 
162
224
 
163
225
  def _default_repr_object(node):
226
+ """
227
+ Generates a default pretty representation for an object.
228
+
229
+ This function yields a `PrettyType` instance that represents the type
230
+ of the given object. It is used as a default method for representing
231
+ objects when no custom representation function is provided.
232
+
233
+ Parameters
234
+ ----------
235
+ node : Any
236
+ The object for which the pretty representation is to be generated.
237
+
238
+ Yields
239
+ ------
240
+ PrettyType
241
+ An instance of `PrettyType` that contains the type information of
242
+ the object.
243
+ """
164
244
  yield PrettyType(type=type(node))
165
245
 
166
246
 
167
247
  def _default_repr_attr(node):
248
+ """
249
+ Generates a default pretty representation for the attributes of an object.
250
+
251
+ This function iterates over the attributes of the given object and yields
252
+ a `PrettyAttr` instance for each attribute that does not start with an
253
+ underscore. The `PrettyAttr` instances contain the attribute name and its
254
+ string representation.
255
+
256
+ Parameters
257
+ ----------
258
+ node : Any
259
+ The object whose attributes are to be represented.
260
+
261
+ Yields
262
+ ------
263
+ PrettyAttr
264
+ An instance of `PrettyAttr` for each non-private attribute of the object,
265
+ containing the attribute name and its string representation.
266
+ """
168
267
  for name, value in vars(node).items():
169
268
  if name.startswith('_'):
170
269
  continue
@@ -177,7 +276,27 @@ def yield_unique_pretty_repr_items(
177
276
  repr_attr: Optional[Callable] = None
178
277
  ):
179
278
  """
180
- Pretty representation of an object avoiding duplicate representations.
279
+ Generates a pretty representation of an object while avoiding duplicate representations.
280
+
281
+ This function is designed to yield a structured representation of an object,
282
+ using custom or default methods for representing the object itself and its attributes.
283
+ It ensures that each object is only represented once to prevent infinite recursion
284
+ in cases of circular references.
285
+
286
+ Parameters:
287
+ node : Any
288
+ The object to be represented.
289
+ repr_object : Optional[Callable], optional
290
+ A callable that yields the representation of the object itself.
291
+ If not provided, a default representation function is used.
292
+ repr_attr : Optional[Callable], optional
293
+ A callable that yields the representation of the object's attributes.
294
+ If not provided, a default attribute representation function is used.
295
+
296
+ Yields:
297
+ Union[PrettyType, PrettyAttr]
298
+ The pretty representation of the object and its attributes,
299
+ avoiding duplicates by tracking seen objects.
181
300
  """
182
301
  if repr_object is None:
183
302
  repr_object = _default_repr_object
@@ -209,4 +328,3 @@ def yield_unique_pretty_repr_items(
209
328
  finally:
210
329
  if clear_seen:
211
330
  CONTEXT.seen_modules_repr = None
212
-