brainstate 0.1.0.post20250211__py2.py3-none-any.whl → 0.1.0.post20250216__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +875 -93
- brainstate/_state_test.py +1 -3
- brainstate/augment/__init__.py +2 -2
- brainstate/augment/_autograd.py +257 -115
- brainstate/augment/_autograd_test.py +2 -3
- brainstate/augment/_eval_shape.py +3 -4
- brainstate/augment/_mapping.py +582 -62
- brainstate/augment/_mapping_test.py +114 -30
- brainstate/augment/_random.py +61 -7
- brainstate/compile/_ad_checkpoint.py +2 -3
- brainstate/compile/_conditions.py +4 -5
- brainstate/compile/_conditions_test.py +1 -2
- brainstate/compile/_error_if.py +1 -2
- brainstate/compile/_error_if_test.py +1 -2
- brainstate/compile/_jit.py +23 -16
- brainstate/compile/_jit_test.py +1 -2
- brainstate/compile/_loop_collect_return.py +18 -10
- brainstate/compile/_loop_collect_return_test.py +1 -1
- brainstate/compile/_loop_no_collection.py +5 -5
- brainstate/compile/_make_jaxpr.py +23 -21
- brainstate/compile/_make_jaxpr_test.py +1 -2
- brainstate/compile/_progress_bar.py +1 -2
- brainstate/compile/_unvmap.py +1 -0
- brainstate/compile/_util.py +4 -2
- brainstate/environ.py +4 -4
- brainstate/environ_test.py +1 -2
- brainstate/functional/_activations.py +1 -2
- brainstate/functional/_activations_test.py +1 -1
- brainstate/functional/_normalization.py +1 -2
- brainstate/functional/_others.py +1 -2
- brainstate/functional/_spikes.py +136 -20
- brainstate/graph/_graph_node.py +2 -43
- brainstate/graph/_graph_operation.py +4 -20
- brainstate/graph/_graph_operation_test.py +3 -4
- brainstate/init/_base.py +1 -2
- brainstate/init/_generic.py +1 -2
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_collective_ops.py +351 -48
- brainstate/nn/_collective_ops_test.py +36 -0
- brainstate/nn/_common.py +194 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
- brainstate/nn/_dyn_impl/_inputs.py +1 -2
- brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
- brainstate/nn/_dyn_impl/_readout.py +2 -3
- brainstate/nn/_dyn_impl/_readout_test.py +1 -2
- brainstate/nn/_dynamics/_dynamics_base.py +2 -3
- brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
- brainstate/nn/_dynamics/_state_delay.py +3 -3
- brainstate/nn/_dynamics/_synouts_test.py +1 -2
- brainstate/nn/_elementwise/_dropout.py +6 -7
- brainstate/nn/_elementwise/_dropout_test.py +1 -2
- brainstate/nn/_elementwise/_elementwise.py +1 -2
- brainstate/nn/_exp_euler.py +1 -2
- brainstate/nn/_exp_euler_test.py +1 -2
- brainstate/nn/_interaction/_conv.py +1 -2
- brainstate/nn/_interaction/_conv_test.py +1 -0
- brainstate/nn/_interaction/_linear.py +1 -2
- brainstate/nn/_interaction/_linear_test.py +1 -2
- brainstate/nn/_interaction/_normalizations.py +1 -2
- brainstate/nn/_interaction/_poolings.py +3 -4
- brainstate/nn/_module.py +63 -19
- brainstate/nn/_module_test.py +1 -2
- brainstate/nn/metrics.py +3 -4
- brainstate/optim/_lr_scheduler.py +1 -2
- brainstate/optim/_lr_scheduler_test.py +2 -3
- brainstate/optim/_optax_optimizer_test.py +1 -2
- brainstate/optim/_sgd_optimizer.py +2 -3
- brainstate/random/_rand_funs.py +1 -2
- brainstate/random/_rand_funs_test.py +2 -3
- brainstate/random/_rand_seed.py +2 -3
- brainstate/random/_rand_seed_test.py +1 -2
- brainstate/random/_rand_state.py +3 -4
- brainstate/surrogate.py +183 -35
- brainstate/transform.py +0 -3
- brainstate/typing.py +28 -25
- brainstate/util/__init__.py +9 -7
- brainstate/util/_caller.py +1 -2
- brainstate/util/_error.py +27 -0
- brainstate/util/_others.py +60 -15
- brainstate/util/{_dict.py → _pretty_pytree.py} +108 -29
- brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
- brainstate/util/_pretty_repr.py +128 -10
- brainstate/util/_pretty_table.py +2900 -0
- brainstate/util/_struct.py +11 -11
- brainstate/util/filter.py +472 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/METADATA +2 -2
- brainstate-0.1.0.post20250216.dist-info/RECORD +127 -0
- brainstate/util/_filter.py +0 -178
- brainstate-0.1.0.post20250211.dist-info/RECORD +0 -124
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/top_level.txt +0 -0
@@ -18,18 +18,23 @@
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
20
|
from collections import abc
|
21
|
-
from typing import TypeVar, Hashable, Union, Iterable, Any, Optional, Tuple, Dict
|
22
21
|
|
23
22
|
import jax
|
23
|
+
from typing import TypeVar, Hashable, Union, Iterable, Any, Optional, Tuple, Dict
|
24
24
|
|
25
25
|
from brainstate.typing import Filter, PathParts
|
26
|
-
from .
|
27
|
-
from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, yield_unique_pretty_repr_items, pretty_repr
|
26
|
+
from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, yield_unique_pretty_repr_items, pretty_repr_object
|
28
27
|
from ._struct import dataclass
|
28
|
+
from .filter import to_predicate
|
29
29
|
|
30
30
|
__all__ = [
|
31
|
-
'PrettyDict',
|
32
|
-
'
|
31
|
+
'PrettyDict',
|
32
|
+
'NestedDict',
|
33
|
+
'FlattedDict',
|
34
|
+
'flat_mapping',
|
35
|
+
'nest_mapping',
|
36
|
+
'PrettyList',
|
37
|
+
'PrettyObject',
|
33
38
|
]
|
34
39
|
|
35
40
|
A = TypeVar('A')
|
@@ -41,42 +46,117 @@ ExtractValueFn = abc.Callable[[Any], Any]
|
|
41
46
|
SetValueFn = abc.Callable[[V, Any], V]
|
42
47
|
|
43
48
|
|
49
|
+
def _repr_object_general(node: PrettyDict):
|
50
|
+
"""
|
51
|
+
Generate a general representation of a PrettyDict object.
|
52
|
+
|
53
|
+
This function is used to create a pretty representation of a PrettyDict
|
54
|
+
object, which includes the type of the object and its value separator.
|
44
55
|
|
56
|
+
Args:
|
57
|
+
node (PrettyDict): The PrettyDict object to be represented.
|
45
58
|
|
46
|
-
|
59
|
+
Yields:
|
60
|
+
PrettyType: A PrettyType object representing the type of the node,
|
61
|
+
with specified value separator, start, and end characters.
|
47
62
|
"""
|
48
|
-
|
63
|
+
yield PrettyType(type(node), value_sep='=', start='(', end=')')
|
64
|
+
|
65
|
+
|
66
|
+
def _repr_attribute_general(node):
|
67
|
+
"""
|
68
|
+
Generate a pretty representation of the attributes of a node.
|
69
|
+
|
70
|
+
This function iterates over the attributes of a given node and attempts
|
71
|
+
to generate a pretty representation for each attribute. It handles
|
72
|
+
conversion of lists and dictionaries to their pretty representation
|
73
|
+
counterparts and yields a PrettyAttr object for each attribute.
|
74
|
+
|
75
|
+
Args:
|
76
|
+
node: The object whose attributes are to be represented.
|
77
|
+
|
78
|
+
Yields:
|
79
|
+
PrettyAttr: A PrettyAttr object representing the key and value of
|
80
|
+
each attribute in a pretty format.
|
81
|
+
"""
|
82
|
+
for k, v in vars(node).items():
|
83
|
+
try:
|
84
|
+
res = node.__pretty_repr_item__(k, v)
|
85
|
+
if res is None:
|
86
|
+
continue
|
87
|
+
k, v = res
|
88
|
+
except AttributeError:
|
89
|
+
pass
|
90
|
+
|
91
|
+
if k is None:
|
92
|
+
continue
|
93
|
+
|
94
|
+
# convert list to PrettyList
|
95
|
+
if isinstance(v, list):
|
96
|
+
v = PrettyList(v)
|
97
|
+
|
98
|
+
# convert dict to PrettyDict
|
99
|
+
if isinstance(v, dict):
|
100
|
+
v = PrettyDict(v)
|
101
|
+
|
102
|
+
# convert PrettyDict to NestedStateRepr
|
103
|
+
if isinstance(v, PrettyDict):
|
104
|
+
v = NestedStateRepr(v)
|
105
|
+
|
106
|
+
yield PrettyAttr(k, v)
|
107
|
+
|
108
|
+
|
109
|
+
class PrettyObject(PrettyRepr):
|
110
|
+
"""
|
111
|
+
A class for generating a pretty representation of a tree-like structure.
|
112
|
+
|
113
|
+
This class extends the PrettyRepr class to provide a mechanism for
|
114
|
+
generating a human-readable, pretty representation of tree-like data
|
115
|
+
structures. It utilizes custom functions to represent the object and
|
116
|
+
its attributes in a structured and visually appealing format.
|
117
|
+
|
118
|
+
Methods
|
119
|
+
-------
|
120
|
+
__pretty_repr__: Generates a sequence of pretty representation items
|
121
|
+
for the object.
|
122
|
+
__pretty_repr_item__: Returns a tuple of the key and value for pretty
|
123
|
+
representation of an item in the data structure.
|
49
124
|
"""
|
50
125
|
|
51
126
|
def __pretty_repr__(self):
|
52
|
-
|
127
|
+
"""
|
128
|
+
Generates a pretty representation of the object.
|
129
|
+
|
130
|
+
This method yields a sequence of pretty representation items for the object,
|
131
|
+
using specified functions to represent the object and its attributes.
|
132
|
+
|
133
|
+
Yields:
|
134
|
+
Pretty representation items generated by `yield_unique_pretty_repr_items`.
|
135
|
+
"""
|
136
|
+
yield from yield_unique_pretty_repr_items(
|
53
137
|
self,
|
54
|
-
repr_object=
|
55
|
-
repr_attr=
|
138
|
+
repr_object=_repr_object_general,
|
139
|
+
repr_attr=_repr_attribute_general,
|
56
140
|
)
|
57
141
|
|
58
142
|
def __pretty_repr_item__(self, k, v):
|
59
|
-
|
60
|
-
|
61
|
-
def _repr_object(self, node: PrettyDict):
|
62
|
-
yield PrettyType(type(node), value_sep=': ', start='({', end='})')
|
143
|
+
"""
|
144
|
+
Returns a tuple of the key and value for pretty representation.
|
63
145
|
|
64
|
-
|
65
|
-
|
66
|
-
k, v = self.__pretty_repr_item__(k, v)
|
67
|
-
if k is None:
|
68
|
-
continue
|
146
|
+
This method is used to generate a pretty representation of an item
|
147
|
+
in a data structure, typically for debugging or logging purposes.
|
69
148
|
|
70
|
-
|
71
|
-
|
149
|
+
Args:
|
150
|
+
k: The key of the item.
|
151
|
+
v: The value of the item.
|
72
152
|
|
73
|
-
|
74
|
-
|
153
|
+
Returns:
|
154
|
+
A tuple containing the key and value.
|
155
|
+
"""
|
156
|
+
return k, v
|
75
157
|
|
76
|
-
if isinstance(v, PrettyDict):
|
77
|
-
v = NestedStateRepr(v)
|
78
158
|
|
79
|
-
|
159
|
+
PrettyReprTree = PrettyObject
|
80
160
|
|
81
161
|
|
82
162
|
# the empty node is a struct.dataclass to be compatible with JAX.
|
@@ -252,7 +332,7 @@ class PrettyDict(dict, PrettyRepr):
|
|
252
332
|
|
253
333
|
def __repr__(self) -> str:
|
254
334
|
# repr the individual object with the pretty representation
|
255
|
-
return
|
335
|
+
return pretty_repr_object(self)
|
256
336
|
|
257
337
|
def __pretty_repr__(self):
|
258
338
|
yield from yield_unique_pretty_repr_items(self, _default_repr_object, _default_repr_attr)
|
@@ -789,7 +869,7 @@ class PrettyList(list, PrettyRepr):
|
|
789
869
|
yield from yield_unique_pretty_repr_items(self, _list_repr_object, _list_repr_attr)
|
790
870
|
|
791
871
|
def __repr__(self):
|
792
|
-
return
|
872
|
+
return pretty_repr_object(self)
|
793
873
|
|
794
874
|
def tree_flatten(self):
|
795
875
|
return list(self), ()
|
@@ -812,4 +892,3 @@ def _list_repr_attr(node: PrettyList):
|
|
812
892
|
|
813
893
|
def _list_repr_object(node: PrettyDict):
|
814
894
|
yield PrettyType('', value_sep='', start='[', end=']')
|
815
|
-
|
brainstate/util/_pretty_repr.py
CHANGED
@@ -21,11 +21,10 @@ import dataclasses
|
|
21
21
|
import threading
|
22
22
|
from abc import ABC, abstractmethod
|
23
23
|
from functools import partial
|
24
|
-
from typing import Any, Iterator, Mapping, TypeVar, Union, Callable, Optional
|
24
|
+
from typing import Any, Iterator, Mapping, TypeVar, Union, Callable, Optional
|
25
25
|
|
26
26
|
__all__ = [
|
27
27
|
'yield_unique_pretty_repr_items',
|
28
|
-
'pretty_repr',
|
29
28
|
'PrettyType',
|
30
29
|
'PrettyAttr',
|
31
30
|
'PrettyRepr',
|
@@ -82,10 +81,37 @@ class PrettyRepr(ABC):
|
|
82
81
|
|
83
82
|
def __repr__(self) -> str:
|
84
83
|
# repr the individual object with the pretty representation
|
85
|
-
return
|
84
|
+
return pretty_repr_object(self)
|
86
85
|
|
87
86
|
|
88
|
-
def
|
87
|
+
def pretty_repr_elem(obj: PrettyType, elem: Any) -> str:
|
88
|
+
"""
|
89
|
+
Constructs a string representation of a single element within a pretty representation.
|
90
|
+
|
91
|
+
This function takes a `PrettyType` object and an element, which must be an instance
|
92
|
+
of `PrettyAttr`, and generates a formatted string that represents the element. The
|
93
|
+
formatting is based on the configuration provided by the `PrettyType` object.
|
94
|
+
|
95
|
+
Parameters
|
96
|
+
----------
|
97
|
+
obj : PrettyType
|
98
|
+
The configuration object that defines how the element should be formatted.
|
99
|
+
It includes details such as indentation, separators, and surrounding characters.
|
100
|
+
elem : Any
|
101
|
+
The element to be represented. It must be an instance of `PrettyAttr`, which
|
102
|
+
contains the key and value to be formatted.
|
103
|
+
|
104
|
+
Returns
|
105
|
+
-------
|
106
|
+
str
|
107
|
+
A string that represents the element in a formatted manner, adhering to the
|
108
|
+
configuration specified by the `PrettyType` object.
|
109
|
+
|
110
|
+
Raises
|
111
|
+
------
|
112
|
+
TypeError
|
113
|
+
If the provided element is not an instance of `PrettyAttr`.
|
114
|
+
"""
|
89
115
|
if not isinstance(elem, PrettyAttr):
|
90
116
|
raise TypeError(f'Item must be Elem, got {type(elem).__name__}')
|
91
117
|
|
@@ -95,9 +121,32 @@ def _repr_elem(obj: PrettyType, elem: Any) -> str:
|
|
95
121
|
return f'{obj.elem_indent}{elem.start}{elem.key}{obj.value_sep}{value}{elem.end}'
|
96
122
|
|
97
123
|
|
98
|
-
def
|
124
|
+
def pretty_repr_object(obj: PrettyRepr) -> str:
|
99
125
|
"""
|
100
|
-
|
126
|
+
Generates a pretty string representation of an object that implements the PrettyRepr interface.
|
127
|
+
|
128
|
+
This function utilizes the __pretty_repr__ method of the PrettyRepr interface to obtain
|
129
|
+
a structured representation of the object, which includes both the type and attributes
|
130
|
+
of the object in a human-readable format.
|
131
|
+
|
132
|
+
Parameters
|
133
|
+
----------
|
134
|
+
obj : PrettyRepr
|
135
|
+
The object for which the pretty representation is to be generated. The object must
|
136
|
+
implement the PrettyRepr interface.
|
137
|
+
|
138
|
+
Returns
|
139
|
+
-------
|
140
|
+
str
|
141
|
+
A string that represents the object in a pretty format, including its type and attributes.
|
142
|
+
The format is determined by the PrettyType and PrettyAttr instances yielded by the
|
143
|
+
__pretty_repr__ method of the object.
|
144
|
+
|
145
|
+
Raises
|
146
|
+
------
|
147
|
+
TypeError
|
148
|
+
If the provided object does not implement the PrettyRepr interface or if the first item
|
149
|
+
yielded by the __pretty_repr__ method is not an instance of PrettyType.
|
101
150
|
"""
|
102
151
|
if not isinstance(obj, PrettyRepr):
|
103
152
|
raise TypeError(f'Object {obj!r} is not representable')
|
@@ -110,7 +159,7 @@ def pretty_repr(obj: PrettyRepr) -> str:
|
|
110
159
|
raise TypeError(f'First item must be PrettyType, got {type(obj_repr).__name__}')
|
111
160
|
|
112
161
|
# repr attributes
|
113
|
-
elem_reprs = tuple(map(partial(
|
162
|
+
elem_reprs = tuple(map(partial(pretty_repr_elem, obj_repr), iterator))
|
114
163
|
elems = ',\n'.join(elem_reprs)
|
115
164
|
if elems:
|
116
165
|
elems = '\n' + elems + '\n'
|
@@ -153,7 +202,20 @@ class PrettyMapping(PrettyRepr):
|
|
153
202
|
|
154
203
|
@dataclasses.dataclass
|
155
204
|
class PrettyReprContext(threading.local):
|
156
|
-
|
205
|
+
"""
|
206
|
+
A thread-local context for managing the state of pretty representation.
|
207
|
+
|
208
|
+
This class is used to keep track of objects that have been seen during
|
209
|
+
the generation of pretty representations, preventing infinite recursion
|
210
|
+
in cases of circular references.
|
211
|
+
|
212
|
+
Attributes
|
213
|
+
----------
|
214
|
+
seen_modules_repr : dict[int, Any] | None
|
215
|
+
A dictionary mapping object IDs to objects that have been seen
|
216
|
+
during the pretty representation process. This is used to avoid
|
217
|
+
representing the same object multiple times.
|
218
|
+
"""
|
157
219
|
seen_modules_repr: dict[int, Any] | None = None
|
158
220
|
|
159
221
|
|
@@ -161,10 +223,47 @@ CONTEXT = PrettyReprContext()
|
|
161
223
|
|
162
224
|
|
163
225
|
def _default_repr_object(node):
|
226
|
+
"""
|
227
|
+
Generates a default pretty representation for an object.
|
228
|
+
|
229
|
+
This function yields a `PrettyType` instance that represents the type
|
230
|
+
of the given object. It is used as a default method for representing
|
231
|
+
objects when no custom representation function is provided.
|
232
|
+
|
233
|
+
Parameters
|
234
|
+
----------
|
235
|
+
node : Any
|
236
|
+
The object for which the pretty representation is to be generated.
|
237
|
+
|
238
|
+
Yields
|
239
|
+
------
|
240
|
+
PrettyType
|
241
|
+
An instance of `PrettyType` that contains the type information of
|
242
|
+
the object.
|
243
|
+
"""
|
164
244
|
yield PrettyType(type=type(node))
|
165
245
|
|
166
246
|
|
167
247
|
def _default_repr_attr(node):
|
248
|
+
"""
|
249
|
+
Generates a default pretty representation for the attributes of an object.
|
250
|
+
|
251
|
+
This function iterates over the attributes of the given object and yields
|
252
|
+
a `PrettyAttr` instance for each attribute that does not start with an
|
253
|
+
underscore. The `PrettyAttr` instances contain the attribute name and its
|
254
|
+
string representation.
|
255
|
+
|
256
|
+
Parameters
|
257
|
+
----------
|
258
|
+
node : Any
|
259
|
+
The object whose attributes are to be represented.
|
260
|
+
|
261
|
+
Yields
|
262
|
+
------
|
263
|
+
PrettyAttr
|
264
|
+
An instance of `PrettyAttr` for each non-private attribute of the object,
|
265
|
+
containing the attribute name and its string representation.
|
266
|
+
"""
|
168
267
|
for name, value in vars(node).items():
|
169
268
|
if name.startswith('_'):
|
170
269
|
continue
|
@@ -177,7 +276,27 @@ def yield_unique_pretty_repr_items(
|
|
177
276
|
repr_attr: Optional[Callable] = None
|
178
277
|
):
|
179
278
|
"""
|
180
|
-
|
279
|
+
Generates a pretty representation of an object while avoiding duplicate representations.
|
280
|
+
|
281
|
+
This function is designed to yield a structured representation of an object,
|
282
|
+
using custom or default methods for representing the object itself and its attributes.
|
283
|
+
It ensures that each object is only represented once to prevent infinite recursion
|
284
|
+
in cases of circular references.
|
285
|
+
|
286
|
+
Parameters:
|
287
|
+
node : Any
|
288
|
+
The object to be represented.
|
289
|
+
repr_object : Optional[Callable], optional
|
290
|
+
A callable that yields the representation of the object itself.
|
291
|
+
If not provided, a default representation function is used.
|
292
|
+
repr_attr : Optional[Callable], optional
|
293
|
+
A callable that yields the representation of the object's attributes.
|
294
|
+
If not provided, a default attribute representation function is used.
|
295
|
+
|
296
|
+
Yields:
|
297
|
+
Union[PrettyType, PrettyAttr]
|
298
|
+
The pretty representation of the object and its attributes,
|
299
|
+
avoiding duplicates by tracking seen objects.
|
181
300
|
"""
|
182
301
|
if repr_object is None:
|
183
302
|
repr_object = _default_repr_object
|
@@ -209,4 +328,3 @@ def yield_unique_pretty_repr_items(
|
|
209
328
|
finally:
|
210
329
|
if clear_seen:
|
211
330
|
CONTEXT.seen_modules_repr = None
|
212
|
-
|