brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,13 @@
|
|
15
15
|
# See the License for the specific language governing permissions and
|
16
16
|
# limitations under the License.
|
17
17
|
|
18
|
+
"""
|
19
|
+
Pretty representation utilities for creating human-readable string representations.
|
20
|
+
|
21
|
+
This module provides utilities for creating customizable pretty representations of
|
22
|
+
objects, with support for nested structures and circular reference detection.
|
23
|
+
"""
|
24
|
+
|
18
25
|
import dataclasses
|
19
26
|
import threading
|
20
27
|
from abc import ABC, abstractmethod
|
@@ -38,6 +45,21 @@ B = TypeVar('B')
|
|
38
45
|
class PrettyType:
|
39
46
|
"""
|
40
47
|
Configuration for pretty representation of objects.
|
48
|
+
|
49
|
+
Attributes
|
50
|
+
----------
|
51
|
+
type : Union[str, type]
|
52
|
+
The type name or type object to display.
|
53
|
+
start : str, default='('
|
54
|
+
The opening delimiter for the representation.
|
55
|
+
end : str, default=')'
|
56
|
+
The closing delimiter for the representation.
|
57
|
+
value_sep : str, default='='
|
58
|
+
The separator between keys and values.
|
59
|
+
elem_indent : str, default=' '
|
60
|
+
The indentation for nested elements.
|
61
|
+
empty_repr : str, default=''
|
62
|
+
The representation for empty objects.
|
41
63
|
"""
|
42
64
|
type: Union[str, type]
|
43
65
|
start: str = '('
|
@@ -51,6 +73,17 @@ class PrettyType:
|
|
51
73
|
class PrettyAttr:
|
52
74
|
"""
|
53
75
|
Configuration for pretty representation of attributes.
|
76
|
+
|
77
|
+
Attributes
|
78
|
+
----------
|
79
|
+
key : str
|
80
|
+
The attribute name or key.
|
81
|
+
value : Union[str, Any]
|
82
|
+
The attribute value.
|
83
|
+
start : str, default=''
|
84
|
+
Optional prefix for the attribute.
|
85
|
+
end : str, default=''
|
86
|
+
Optional suffix for the attribute.
|
54
87
|
"""
|
55
88
|
key: str
|
56
89
|
value: Union[str, Any]
|
@@ -62,23 +95,54 @@ class PrettyRepr(ABC):
|
|
62
95
|
"""
|
63
96
|
Interface for pretty representation of objects.
|
64
97
|
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
98
|
+
This abstract base class provides a framework for creating custom
|
99
|
+
pretty representations of objects by yielding PrettyType and PrettyAttr
|
100
|
+
instances.
|
101
|
+
|
102
|
+
Examples
|
103
|
+
--------
|
104
|
+
.. code-block:: python
|
105
|
+
|
106
|
+
>>> class MyObject(PrettyRepr):
|
107
|
+
... def __init__(self, key, value):
|
108
|
+
... self.key = key
|
109
|
+
... self.value = value
|
110
|
+
...
|
111
|
+
... def __pretty_repr__(self):
|
112
|
+
... yield PrettyType(type='MyObject', start='{', end='}')
|
113
|
+
... yield PrettyAttr('key', self.key)
|
114
|
+
... yield PrettyAttr('value', self.value)
|
115
|
+
...
|
116
|
+
>>> obj = MyObject('foo', 42)
|
117
|
+
>>> print(obj)
|
118
|
+
MyObject{
|
119
|
+
key=foo,
|
120
|
+
value=42
|
121
|
+
}
|
73
122
|
"""
|
74
123
|
__slots__ = ()
|
75
124
|
|
76
125
|
@abstractmethod
|
77
126
|
def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
|
127
|
+
"""
|
128
|
+
Generate the pretty representation of the object.
|
129
|
+
|
130
|
+
Yields
|
131
|
+
------
|
132
|
+
Union[PrettyType, PrettyAttr]
|
133
|
+
First yield should be PrettyType, followed by PrettyAttr instances.
|
134
|
+
"""
|
78
135
|
raise NotImplementedError
|
79
136
|
|
80
137
|
def __repr__(self) -> str:
|
81
|
-
|
138
|
+
"""
|
139
|
+
Generate string representation using pretty representation.
|
140
|
+
|
141
|
+
Returns
|
142
|
+
-------
|
143
|
+
str
|
144
|
+
The formatted string representation of the object.
|
145
|
+
"""
|
82
146
|
return pretty_repr_object(self)
|
83
147
|
|
84
148
|
|
@@ -174,9 +238,33 @@ def pretty_repr_object(obj: PrettyRepr) -> str:
|
|
174
238
|
class MappingReprMixin(Mapping[A, B]):
|
175
239
|
"""
|
176
240
|
Mapping mixin for pretty representation.
|
241
|
+
|
242
|
+
This mixin provides a default pretty representation for mapping-like objects.
|
243
|
+
|
244
|
+
Examples
|
245
|
+
--------
|
246
|
+
.. code-block:: python
|
247
|
+
|
248
|
+
>>> class MyMapping(dict, MappingReprMixin):
|
249
|
+
... pass
|
250
|
+
...
|
251
|
+
>>> m = MyMapping({'a': 1, 'b': 2})
|
252
|
+
>>> print(m)
|
253
|
+
{
|
254
|
+
'a': 1,
|
255
|
+
'b': 2
|
256
|
+
}
|
177
257
|
"""
|
178
258
|
|
179
|
-
def __pretty_repr__(self):
|
259
|
+
def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
|
260
|
+
"""
|
261
|
+
Generate pretty representation for mapping.
|
262
|
+
|
263
|
+
Yields
|
264
|
+
------
|
265
|
+
Union[PrettyType, PrettyAttr]
|
266
|
+
PrettyType followed by PrettyAttr for each key-value pair.
|
267
|
+
"""
|
180
268
|
yield PrettyType(type='', value_sep=': ', start='{', end='}')
|
181
269
|
|
182
270
|
for key, value in self.items():
|
@@ -187,11 +275,37 @@ class MappingReprMixin(Mapping[A, B]):
|
|
187
275
|
class PrettyMapping(PrettyRepr):
|
188
276
|
"""
|
189
277
|
Pretty representation of a mapping.
|
278
|
+
|
279
|
+
Attributes
|
280
|
+
----------
|
281
|
+
mapping : Mapping
|
282
|
+
The mapping to represent.
|
283
|
+
type_name : str, default=''
|
284
|
+
Optional type name to display.
|
285
|
+
|
286
|
+
Examples
|
287
|
+
--------
|
288
|
+
.. code-block:: python
|
289
|
+
|
290
|
+
>>> m = PrettyMapping({'a': 1, 'b': 2}, type_name='MyDict')
|
291
|
+
>>> print(m)
|
292
|
+
MyDict{
|
293
|
+
'a': 1,
|
294
|
+
'b': 2
|
295
|
+
}
|
190
296
|
"""
|
191
297
|
mapping: Mapping
|
192
298
|
type_name: str = ''
|
193
299
|
|
194
|
-
def __pretty_repr__(self):
|
300
|
+
def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
|
301
|
+
"""
|
302
|
+
Generate pretty representation for the mapping.
|
303
|
+
|
304
|
+
Yields
|
305
|
+
------
|
306
|
+
Union[PrettyType, PrettyAttr]
|
307
|
+
PrettyType followed by PrettyAttr for each key-value pair.
|
308
|
+
"""
|
195
309
|
yield PrettyType(type=self.type_name, value_sep=': ', start='{', end='}')
|
196
310
|
|
197
311
|
for key, value in self.mapping.items():
|
@@ -220,9 +334,9 @@ class PrettyReprContext(threading.local):
|
|
220
334
|
CONTEXT = PrettyReprContext()
|
221
335
|
|
222
336
|
|
223
|
-
def _default_repr_object(node):
|
337
|
+
def _default_repr_object(node: Any) -> Iterator[PrettyType]:
|
224
338
|
"""
|
225
|
-
|
339
|
+
Generate a default pretty representation for an object.
|
226
340
|
|
227
341
|
This function yields a `PrettyType` instance that represents the type
|
228
342
|
of the given object. It is used as a default method for representing
|
@@ -242,9 +356,9 @@ def _default_repr_object(node):
|
|
242
356
|
yield PrettyType(type=type(node))
|
243
357
|
|
244
358
|
|
245
|
-
def _default_repr_attr(node):
|
359
|
+
def _default_repr_attr(node: Any) -> Iterator[PrettyAttr]:
|
246
360
|
"""
|
247
|
-
|
361
|
+
Generate a default pretty representation for the attributes of an object.
|
248
362
|
|
249
363
|
This function iterates over the attributes of the given object and yields
|
250
364
|
a `PrettyAttr` instance for each attribute that does not start with an
|
@@ -269,32 +383,52 @@ def _default_repr_attr(node):
|
|
269
383
|
|
270
384
|
|
271
385
|
def yield_unique_pretty_repr_items(
|
272
|
-
node,
|
386
|
+
node: Any,
|
273
387
|
repr_object: Optional[Callable] = None,
|
274
388
|
repr_attr: Optional[Callable] = None
|
275
|
-
):
|
389
|
+
) -> Iterator[Union[PrettyType, PrettyAttr]]:
|
276
390
|
"""
|
277
|
-
|
391
|
+
Generate a pretty representation of an object while avoiding duplicate representations.
|
278
392
|
|
279
|
-
This function
|
280
|
-
|
281
|
-
|
282
|
-
|
393
|
+
This function yields a structured representation of an object, using custom or default
|
394
|
+
methods for representing the object itself and its attributes. It ensures that each
|
395
|
+
object is only represented once to prevent infinite recursion in cases of circular
|
396
|
+
references.
|
283
397
|
|
284
|
-
Parameters
|
398
|
+
Parameters
|
399
|
+
----------
|
285
400
|
node : Any
|
286
401
|
The object to be represented.
|
287
402
|
repr_object : Optional[Callable], optional
|
288
|
-
A callable that yields the representation of the object itself.
|
403
|
+
A callable that yields the representation of the object itself.
|
289
404
|
If not provided, a default representation function is used.
|
290
405
|
repr_attr : Optional[Callable], optional
|
291
|
-
A callable that yields the representation of the object's attributes.
|
406
|
+
A callable that yields the representation of the object's attributes.
|
292
407
|
If not provided, a default attribute representation function is used.
|
293
408
|
|
294
|
-
Yields
|
409
|
+
Yields
|
410
|
+
------
|
295
411
|
Union[PrettyType, PrettyAttr]
|
296
|
-
The pretty representation of the object and its attributes,
|
412
|
+
The pretty representation of the object and its attributes,
|
297
413
|
avoiding duplicates by tracking seen objects.
|
414
|
+
|
415
|
+
Examples
|
416
|
+
--------
|
417
|
+
.. code-block:: python
|
418
|
+
|
419
|
+
>>> class Node:
|
420
|
+
... def __init__(self, value, next=None):
|
421
|
+
... self.value = value
|
422
|
+
... self.next = next
|
423
|
+
...
|
424
|
+
>>> # Create circular reference
|
425
|
+
>>> node1 = Node(1)
|
426
|
+
>>> node2 = Node(2, node1)
|
427
|
+
>>> node1.next = node2
|
428
|
+
...
|
429
|
+
>>> # This will handle circular reference gracefully
|
430
|
+
>>> for item in yield_unique_pretty_repr_items(node1):
|
431
|
+
... print(item)
|
298
432
|
"""
|
299
433
|
if repr_object is None:
|
300
434
|
repr_object = _default_repr_object
|