brainstate 0.1.0.post20250501__py2.py3-none-any.whl → 0.1.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +8 -19
- brainstate/_state.py +177 -177
- brainstate/_utils.py +0 -1
- brainstate/augment/_autograd.py +0 -2
- brainstate/augment/_eval_shape.py +0 -2
- brainstate/augment/_mapping.py +2 -3
- brainstate/augment/_random.py +0 -2
- brainstate/compile/_ad_checkpoint.py +0 -2
- brainstate/compile/_conditions.py +0 -2
- brainstate/compile/_error_if.py +0 -2
- brainstate/compile/_jit.py +9 -8
- brainstate/compile/_loop_collect_return.py +0 -2
- brainstate/compile/_loop_no_collection.py +0 -2
- brainstate/compile/_make_jaxpr.py +4 -6
- brainstate/compile/_progress_bar.py +0 -1
- brainstate/compile/_unvmap.py +0 -1
- brainstate/compile/_util.py +0 -2
- brainstate/environ.py +0 -2
- brainstate/functional/_activations.py +0 -2
- brainstate/functional/_normalization.py +0 -2
- brainstate/functional/_others.py +0 -2
- brainstate/functional/_spikes.py +0 -1
- brainstate/graph/_graph_node.py +1 -3
- brainstate/graph/_graph_operation.py +4 -2
- brainstate/init/_base.py +0 -2
- brainstate/init/_generic.py +0 -1
- brainstate/init/_random_inits.py +0 -1
- brainstate/init/_regular_inits.py +0 -2
- brainstate/mixin.py +0 -2
- brainstate/nn/_collective_ops.py +0 -3
- brainstate/nn/_common.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
- brainstate/nn/_dyn_impl/_inputs.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
- brainstate/nn/_dyn_impl/_readout.py +0 -1
- brainstate/nn/_dynamics/_dynamics_base.py +0 -1
- brainstate/nn/_dynamics/_projection_base.py +0 -1
- brainstate/nn/_dynamics/_state_delay.py +0 -2
- brainstate/nn/_dynamics/_synouts.py +0 -2
- brainstate/nn/_elementwise/_dropout.py +0 -2
- brainstate/nn/_elementwise/_elementwise.py +0 -2
- brainstate/nn/_event/_fixedprob_mv.py +0 -1
- brainstate/nn/_event/_linear_mv.py +0 -2
- brainstate/nn/_exp_euler.py +0 -2
- brainstate/nn/_interaction/_conv.py +0 -2
- brainstate/nn/_interaction/_embedding.py +0 -1
- brainstate/nn/_interaction/_linear.py +0 -2
- brainstate/nn/_interaction/_normalizations.py +0 -2
- brainstate/nn/_interaction/_poolings.py +0 -2
- brainstate/nn/_module.py +0 -1
- brainstate/nn/metrics.py +0 -2
- brainstate/optim/_base.py +0 -2
- brainstate/optim/_lr_scheduler.py +0 -1
- brainstate/optim/_optax_optimizer.py +0 -2
- brainstate/optim/_sgd_optimizer.py +0 -1
- brainstate/random/_rand_funs.py +0 -1
- brainstate/random/_rand_seed.py +0 -1
- brainstate/random/_rand_state.py +0 -1
- brainstate/surrogate.py +0 -1
- brainstate/typing.py +0 -2
- brainstate/util/_caller.py +4 -6
- brainstate/util/_others.py +0 -2
- brainstate/util/_pretty_pytree.py +201 -150
- brainstate/util/_pretty_repr.py +0 -2
- brainstate/util/_pretty_table.py +57 -3
- brainstate/util/_scaling.py +0 -2
- brainstate/util/_struct.py +0 -2
- brainstate/util/filter.py +0 -2
- {brainstate-0.1.0.post20250501.dist-info → brainstate-0.1.1.dist-info}/METADATA +16 -6
- brainstate-0.1.1.dist-info/RECORD +133 -0
- brainstate-0.1.0.post20250501.dist-info/RECORD +0 -133
- {brainstate-0.1.0.post20250501.dist-info → brainstate-0.1.1.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250501.dist-info → brainstate-0.1.1.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250501.dist-info → brainstate-0.1.1.dist-info}/top_level.txt +0 -0
@@ -15,8 +15,6 @@
|
|
15
15
|
# See the License for the specific language governing permissions and
|
16
16
|
# limitations under the License.
|
17
17
|
|
18
|
-
from __future__ import annotations
|
19
|
-
|
20
18
|
from collections import abc
|
21
19
|
from typing import TypeVar, Hashable, Union, Iterable, Any, Optional, Tuple, Dict
|
22
20
|
|
@@ -46,66 +44,6 @@ ExtractValueFn = abc.Callable[[Any], Any]
|
|
46
44
|
SetValueFn = abc.Callable[[V, Any], V]
|
47
45
|
|
48
46
|
|
49
|
-
def _repr_object_general(node: PrettyDict):
|
50
|
-
"""
|
51
|
-
Generate a general representation of a PrettyDict object.
|
52
|
-
|
53
|
-
This function is used to create a pretty representation of a PrettyDict
|
54
|
-
object, which includes the type of the object and its value separator.
|
55
|
-
|
56
|
-
Args:
|
57
|
-
node (PrettyDict): The PrettyDict object to be represented.
|
58
|
-
|
59
|
-
Yields:
|
60
|
-
PrettyType: A PrettyType object representing the type of the node,
|
61
|
-
with specified value separator, start, and end characters.
|
62
|
-
"""
|
63
|
-
yield PrettyType(type(node), value_sep='=', start='(', end=')')
|
64
|
-
|
65
|
-
|
66
|
-
def _repr_attribute_general(node):
|
67
|
-
"""
|
68
|
-
Generate a pretty representation of the attributes of a node.
|
69
|
-
|
70
|
-
This function iterates over the attributes of a given node and attempts
|
71
|
-
to generate a pretty representation for each attribute. It handles
|
72
|
-
conversion of lists and dictionaries to their pretty representation
|
73
|
-
counterparts and yields a PrettyAttr object for each attribute.
|
74
|
-
|
75
|
-
Args:
|
76
|
-
node: The object whose attributes are to be represented.
|
77
|
-
|
78
|
-
Yields:
|
79
|
-
PrettyAttr: A PrettyAttr object representing the key and value of
|
80
|
-
each attribute in a pretty format.
|
81
|
-
"""
|
82
|
-
for k, v in vars(node).items():
|
83
|
-
try:
|
84
|
-
res = node.__pretty_repr_item__(k, v)
|
85
|
-
if res is None:
|
86
|
-
continue
|
87
|
-
k, v = res
|
88
|
-
except AttributeError:
|
89
|
-
pass
|
90
|
-
|
91
|
-
if k is None:
|
92
|
-
continue
|
93
|
-
|
94
|
-
# convert list to PrettyList
|
95
|
-
if isinstance(v, list):
|
96
|
-
v = PrettyList(v)
|
97
|
-
|
98
|
-
# convert dict to PrettyDict
|
99
|
-
if isinstance(v, dict):
|
100
|
-
v = PrettyDict(v)
|
101
|
-
|
102
|
-
# convert PrettyDict to NestedStateRepr
|
103
|
-
if isinstance(v, PrettyDict):
|
104
|
-
v = NestedStateRepr(v)
|
105
|
-
|
106
|
-
yield PrettyAttr(k, v)
|
107
|
-
|
108
|
-
|
109
47
|
class PrettyObject(PrettyRepr):
|
110
48
|
"""
|
111
49
|
A class for generating a pretty representation of a tree-like structure.
|
@@ -284,37 +222,15 @@ def _default_process(x):
|
|
284
222
|
return id(x)
|
285
223
|
|
286
224
|
|
287
|
-
class NestedStateRepr(PrettyRepr):
|
288
|
-
def __init__(self, state: PrettyDict):
|
289
|
-
self.state = state
|
290
|
-
|
291
|
-
def __pretty_repr__(self):
|
292
|
-
yield PrettyType('', value_sep=': ', start='{', end='}')
|
293
|
-
|
294
|
-
for r in self.state.__pretty_repr__():
|
295
|
-
if isinstance(r, PrettyType):
|
296
|
-
continue
|
297
|
-
yield r
|
298
|
-
|
299
|
-
def __treescope_repr__(self, path, subtree_renderer):
|
300
|
-
children = {}
|
301
|
-
for k, v in self.state.items():
|
302
|
-
if isinstance(v, PrettyDict):
|
303
|
-
v = NestedStateRepr(v)
|
304
|
-
children[k] = v
|
305
|
-
# Render as the dictionary itself at the same path.
|
306
|
-
return subtree_renderer(children, path=path)
|
307
|
-
|
308
|
-
|
309
225
|
class PrettyDict(dict, PrettyRepr):
|
310
226
|
__module__ = 'brainstate.util'
|
311
227
|
|
312
|
-
def __getattr__(self, key: K)
|
228
|
+
def __getattr__(self, key: K): # type: ignore[misc]
|
313
229
|
return self[key]
|
314
230
|
|
315
231
|
def treefy_state(self):
|
316
232
|
"""
|
317
|
-
Convert the
|
233
|
+
Convert the :class:`State` objects to a reference tree of the state.
|
318
234
|
"""
|
319
235
|
from brainstate._state import State
|
320
236
|
leaves, treedef = jax.tree.flatten(self)
|
@@ -323,7 +239,7 @@ class PrettyDict(dict, PrettyRepr):
|
|
323
239
|
|
324
240
|
def to_dict(self) -> Dict[K, Dict[K, Any] | V]:
|
325
241
|
"""
|
326
|
-
Convert the
|
242
|
+
Convert the :class:`PrettyDict` to a dictionary.
|
327
243
|
|
328
244
|
Returns:
|
329
245
|
The dictionary.
|
@@ -337,24 +253,46 @@ class PrettyDict(dict, PrettyRepr):
|
|
337
253
|
def __pretty_repr__(self):
|
338
254
|
yield from yield_unique_pretty_repr_items(self, _default_repr_object, _default_repr_attr)
|
339
255
|
|
340
|
-
def split(self, *filters) -> Union[PrettyDict[K, V], Tuple[PrettyDict[K, V], ...]]:
|
256
|
+
def split(self, *filters) -> Union['PrettyDict[K, V]', Tuple['PrettyDict[K, V]', ...]]:
|
341
257
|
raise NotImplementedError
|
342
258
|
|
343
|
-
def filter(self, *filters) -> Union[PrettyDict[K, V], Tuple[PrettyDict[K, V], ...]]:
|
259
|
+
def filter(self, *filters) -> Union['PrettyDict[K, V]', Tuple['PrettyDict[K, V]', ...]]:
|
344
260
|
raise NotImplementedError
|
345
261
|
|
346
|
-
def merge(self, *states) -> PrettyDict[K, V]:
|
262
|
+
def merge(self, *states) -> 'PrettyDict[K, V]':
|
347
263
|
raise NotImplementedError
|
348
264
|
|
349
|
-
def subset(self, *filters) -> Union[PrettyDict[K, V], Tuple[PrettyDict[K, V], ...]]:
|
265
|
+
def subset(self, *filters) -> Union['PrettyDict[K, V]', Tuple['PrettyDict[K, V]', ...]]:
|
350
266
|
"""
|
351
|
-
Subset a
|
352
|
-
|
353
|
-
:class:`State` types in the
|
267
|
+
Subset a :class:`PrettyDict` into one or more :class:`PrettyDict`'s. The user must pass at least one
|
268
|
+
`:class:`Filter` (i.e. :class:`State`), and the filters must be exhaustive (i.e. they must cover all
|
269
|
+
:class:`State` types in the :class:`PrettyDict`).
|
354
270
|
"""
|
355
271
|
return self.filter(*filters)
|
356
272
|
|
357
273
|
|
274
|
+
class NestedStateRepr(PrettyRepr):
|
275
|
+
def __init__(self, state: PrettyDict):
|
276
|
+
self.state = state
|
277
|
+
|
278
|
+
def __pretty_repr__(self):
|
279
|
+
yield PrettyType('', value_sep=': ', start='{', end='}')
|
280
|
+
|
281
|
+
for r in self.state.__pretty_repr__():
|
282
|
+
if isinstance(r, PrettyType):
|
283
|
+
continue
|
284
|
+
yield r
|
285
|
+
|
286
|
+
def __treescope_repr__(self, path, subtree_renderer):
|
287
|
+
children = {}
|
288
|
+
for k, v in self.state.items():
|
289
|
+
if isinstance(v, PrettyDict):
|
290
|
+
v = NestedStateRepr(v)
|
291
|
+
children[k] = v
|
292
|
+
# Render as the dictionary itself at the same path.
|
293
|
+
return subtree_renderer(children, path=path)
|
294
|
+
|
295
|
+
|
358
296
|
def _default_repr_object(node: PrettyDict):
|
359
297
|
yield PrettyType('', value_sep=': ', start='{', end='}')
|
360
298
|
|
@@ -375,20 +313,20 @@ def _default_repr_attr(node):
|
|
375
313
|
|
376
314
|
class NestedDict(PrettyDict):
|
377
315
|
"""
|
378
|
-
A pytree-like structure that contains a
|
316
|
+
A pytree-like structure that contains a :class:`Mapping` from strings or integers to leaves.
|
379
317
|
|
380
318
|
A valid leaf type is either :class:`State`, ``jax.Array``, ``numpy.ndarray`` or nested
|
381
|
-
|
319
|
+
:class:`NestedDict` and :class:`FlattedDict`.
|
382
320
|
"""
|
383
321
|
__module__ = 'brainstate.util'
|
384
322
|
|
385
|
-
def __or__(self, other: NestedDict[K, V]) -> NestedDict[K, V]:
|
323
|
+
def __or__(self, other: 'NestedDict[K, V]') -> 'NestedDict[K, V]':
|
386
324
|
if not other:
|
387
325
|
return self
|
388
326
|
assert isinstance(other, NestedDict), f'expected NestedDict; got {type(other).__qualname__}'
|
389
327
|
return NestedDict.merge(self, other)
|
390
328
|
|
391
|
-
def __sub__(self, other: NestedDict[K, V]) -> NestedDict[K, V]:
|
329
|
+
def __sub__(self, other: 'NestedDict[K, V]') -> 'NestedDict[K, V]':
|
392
330
|
if not other:
|
393
331
|
return self
|
394
332
|
|
@@ -398,25 +336,25 @@ class NestedDict(PrettyDict):
|
|
398
336
|
diff = {k: v for k, v in self_flat.items() if k not in other_flat}
|
399
337
|
return NestedDict.from_flat(diff)
|
400
338
|
|
401
|
-
def to_flat(self) -> FlattedDict:
|
339
|
+
def to_flat(self) -> 'FlattedDict':
|
402
340
|
"""
|
403
341
|
Flatten the nested mapping into a flat mapping.
|
404
342
|
|
405
343
|
Returns:
|
406
|
-
|
344
|
+
The flattened mapping.
|
407
345
|
"""
|
408
346
|
return flat_mapping(self)
|
409
347
|
|
410
348
|
@classmethod
|
411
|
-
def from_flat(cls, flat_dict: abc.Mapping[PathParts, V] | Iterable[tuple[PathParts, V]]) -> NestedDict:
|
349
|
+
def from_flat(cls, flat_dict: abc.Mapping[PathParts, V] | Iterable[tuple[PathParts, V]]) -> 'NestedDict':
|
412
350
|
"""
|
413
|
-
Create a
|
351
|
+
Create a :class:`NestedDict` from a flat mapping.
|
414
352
|
|
415
353
|
Args:
|
416
354
|
flat_dict: The flat mapping.
|
417
355
|
|
418
356
|
Returns:
|
419
|
-
The
|
357
|
+
The :class:`NestedDict`.
|
420
358
|
"""
|
421
359
|
nested_state = nest_mapping(dict(flat_dict))
|
422
360
|
return cls(nested_state)
|
@@ -426,12 +364,12 @@ class NestedDict(PrettyDict):
|
|
426
364
|
first: Filter,
|
427
365
|
/,
|
428
366
|
*filters: Filter
|
429
|
-
) -> Union[NestedDict[K, V], Tuple[NestedDict[K, V], ...]]:
|
367
|
+
) -> Union['NestedDict[K, V]', Tuple['NestedDict[K, V]', ...]]:
|
430
368
|
"""
|
431
|
-
Split a
|
432
|
-
user must pass at least one
|
369
|
+
Split a :class:`NestedDict` into one or more :class:`NestedDict`'s. The
|
370
|
+
user must pass at least one `:class:`Filter` (i.e. :class:`State`),
|
433
371
|
and the filters must be exhaustive (i.e. they must cover all
|
434
|
-
:class:`State` types in the
|
372
|
+
:class:`State` types in the :class:`NestedDict`).
|
435
373
|
|
436
374
|
Example usage::
|
437
375
|
|
@@ -474,10 +412,10 @@ class NestedDict(PrettyDict):
|
|
474
412
|
first: Filter,
|
475
413
|
/,
|
476
414
|
*filters: Filter,
|
477
|
-
) -> Union[NestedDict[K, V], Tuple[NestedDict[K, V], ...]]:
|
415
|
+
) -> Union['NestedDict[K, V]', Tuple['NestedDict[K, V]', ...]]:
|
478
416
|
"""
|
479
|
-
Filter a
|
480
|
-
user must pass at least one
|
417
|
+
Filter a :class:`NestedDict` into one or more :class:`NestedDict`'s. The
|
418
|
+
user must pass at least one `:class:`Filter` (i.e. :class:`State`).
|
481
419
|
This method is similar to :meth:`split() <flax.nnx.NestedDict.state.split>`,
|
482
420
|
except the filters can be non-exhaustive.
|
483
421
|
|
@@ -498,21 +436,21 @@ class NestedDict(PrettyDict):
|
|
498
436
|
|
499
437
|
@staticmethod
|
500
438
|
def merge(
|
501
|
-
state: NestedDict[K, V]
|
439
|
+
state: Union['NestedDict[K, V]', 'FlattedDict[K, V]'],
|
502
440
|
/,
|
503
|
-
*states: NestedDict[K, V]
|
504
|
-
) -> NestedDict[K, V]:
|
441
|
+
*states: Union['NestedDict[K, V]', 'FlattedDict[K, V]']
|
442
|
+
) -> 'NestedDict[K, V]':
|
505
443
|
"""
|
506
444
|
The inverse of :meth:`split()`.
|
507
445
|
|
508
|
-
``merge`` takes one or more
|
446
|
+
``merge`` takes one or more :class:`PrettyDict`'s and creates a new :class:`PrettyDict`.
|
509
447
|
|
510
448
|
Args:
|
511
|
-
state: A
|
512
|
-
*states: Additional
|
449
|
+
state: A :class:`PrettyDict` object.
|
450
|
+
*states: Additional :class:`PrettyDict` objects.
|
513
451
|
|
514
452
|
Returns:
|
515
|
-
The merged
|
453
|
+
The merged :class:`PrettyDict`.
|
516
454
|
"""
|
517
455
|
if not states:
|
518
456
|
return state
|
@@ -548,11 +486,11 @@ class NestedDict(PrettyDict):
|
|
548
486
|
|
549
487
|
class FlattedDict(PrettyDict):
|
550
488
|
"""
|
551
|
-
A pytree-like structure that contains a
|
489
|
+
A pytree-like structure that contains a :class:`Mapping` from strings or integers to leaves.
|
552
490
|
|
553
491
|
A valid leaf type is either :class:`State`, ``jax.Array``, ``numpy.ndarray`` or Python variables.
|
554
492
|
|
555
|
-
A
|
493
|
+
A :class:`NestedDict` can be generated by either calling :func:`states()` or
|
556
494
|
:func:`nodes()` on the :class:`Module`.
|
557
495
|
|
558
496
|
Example usage::
|
@@ -631,13 +569,13 @@ class FlattedDict(PrettyDict):
|
|
631
569
|
"""
|
632
570
|
__module__ = 'brainstate.util'
|
633
571
|
|
634
|
-
def __or__(self, other: FlattedDict[K, V]) -> FlattedDict[K, V]:
|
572
|
+
def __or__(self, other: 'FlattedDict[K, V]') -> 'FlattedDict[K, V]':
|
635
573
|
if not other:
|
636
574
|
return self
|
637
575
|
assert isinstance(other, FlattedDict), f'expected NestedDict; got {type(other).__qualname__}'
|
638
576
|
return FlattedDict.merge(self, other)
|
639
577
|
|
640
|
-
def __sub__(self, other: FlattedDict[K, V]) -> FlattedDict[K, V]:
|
578
|
+
def __sub__(self, other: 'FlattedDict[K, V]') -> 'FlattedDict[K, V]':
|
641
579
|
if not other:
|
642
580
|
return self
|
643
581
|
assert isinstance(other, FlattedDict), f'expected NestedDict; got {type(other).__qualname__}'
|
@@ -649,22 +587,22 @@ class FlattedDict(PrettyDict):
|
|
649
587
|
Unflatten the flat mapping into a nested mapping.
|
650
588
|
|
651
589
|
Returns:
|
652
|
-
|
590
|
+
The nested mapping.
|
653
591
|
"""
|
654
592
|
return nest_mapping(self)
|
655
593
|
|
656
594
|
@classmethod
|
657
595
|
def from_nest(
|
658
596
|
cls, nested_dict: abc.Mapping[PathParts, V] | Iterable[tuple[PathParts, V]],
|
659
|
-
) -> FlattedDict:
|
597
|
+
) -> 'FlattedDict':
|
660
598
|
"""
|
661
|
-
Create a
|
599
|
+
Create a :class:`NestedDict` from a flat mapping.
|
662
600
|
|
663
601
|
Args:
|
664
602
|
nested_dict: The flat mapping.
|
665
603
|
|
666
604
|
Returns:
|
667
|
-
The
|
605
|
+
The :class:`NestedDict`.
|
668
606
|
"""
|
669
607
|
return flat_mapping(nested_dict)
|
670
608
|
|
@@ -673,19 +611,19 @@ class FlattedDict(PrettyDict):
|
|
673
611
|
first: Filter,
|
674
612
|
/,
|
675
613
|
*filters: Filter
|
676
|
-
) -> Union[FlattedDict[K, V], tuple[FlattedDict[K, V], ...]]:
|
614
|
+
) -> Union['FlattedDict[K, V]', tuple['FlattedDict[K, V]', ...]]:
|
677
615
|
"""
|
678
|
-
Split a
|
679
|
-
user must pass at least one
|
616
|
+
Split a :class:`FlattedDict` into one or more :class:`FlattedDict`'s. The
|
617
|
+
user must pass at least one `:class:`Filter` (i.e. :class:`State`),
|
680
618
|
and the filters must be exhaustive (i.e. they must cover all
|
681
|
-
:class:`State` types in the
|
619
|
+
:class:`State` types in the :class:`NestedDict`).
|
682
620
|
|
683
621
|
Arguments:
|
684
|
-
|
685
|
-
|
622
|
+
first: The first filter
|
623
|
+
*filters: The optional, additional filters to group the state into mutually exclusive substates.
|
686
624
|
|
687
625
|
Returns:
|
688
|
-
|
626
|
+
One or more ``States`` equal to the number of filters passed.
|
689
627
|
"""
|
690
628
|
filters = (first, *filters)
|
691
629
|
*states_, rest = _split_flatted_mapping(self, *filters)
|
@@ -705,19 +643,19 @@ class FlattedDict(PrettyDict):
|
|
705
643
|
first: Filter,
|
706
644
|
/,
|
707
645
|
*filters: Filter,
|
708
|
-
) -> Union[FlattedDict[K, V], Tuple[FlattedDict[K, V], ...]]:
|
646
|
+
) -> Union['FlattedDict[K, V]', Tuple['FlattedDict[K, V]', ...]]:
|
709
647
|
"""
|
710
|
-
Filter a
|
711
|
-
user must pass at least one
|
648
|
+
Filter a :class:`FlattedDict` into one or more :class:`FlattedDict`'s. The
|
649
|
+
user must pass at least one `:class:`Filter` (i.e. :class:`State`).
|
712
650
|
This method is similar to :meth:`split() <flax.nnx.NestedDict.state.split>`,
|
713
651
|
except the filters can be non-exhaustive.
|
714
652
|
|
715
653
|
Arguments:
|
716
|
-
|
717
|
-
|
654
|
+
first: The first filter
|
655
|
+
*filters: The optional, additional filters to group the state into mutually exclusive substates.
|
718
656
|
|
719
657
|
Returns:
|
720
|
-
|
658
|
+
One or more ``States`` equal to the number of filters passed.
|
721
659
|
"""
|
722
660
|
*states_, _rest = _split_flatted_mapping(self, first, *filters)
|
723
661
|
assert len(states_) == len(filters) + 1, f'Expected {len(filters) + 1} states, got {len(states_)}'
|
@@ -729,21 +667,21 @@ class FlattedDict(PrettyDict):
|
|
729
667
|
|
730
668
|
@staticmethod
|
731
669
|
def merge(
|
732
|
-
state: FlattedDict[K, V]
|
670
|
+
state: Union['FlattedDict[K, V]', 'NestedDict[K, V]'],
|
733
671
|
/,
|
734
|
-
*states: FlattedDict[K, V]
|
735
|
-
) -> FlattedDict[K, V]:
|
672
|
+
*states: Union['FlattedDict[K, V]', 'NestedDict[K, V]']
|
673
|
+
) -> 'FlattedDict[K, V]':
|
736
674
|
"""
|
737
675
|
The inverse of :meth:`split()`.
|
738
676
|
|
739
|
-
``merge`` takes one or more
|
677
|
+
``merge`` takes one or more :class:`FlattedDict`'s and creates a new :class:`FlattedDict`.
|
740
678
|
|
741
679
|
Args:
|
742
|
-
state: A
|
743
|
-
*states: Additional
|
680
|
+
state: A :class:`PrettyDict` object.
|
681
|
+
*states: Additional :class:`PrettyDict` objects.
|
744
682
|
|
745
683
|
Returns:
|
746
|
-
The merged
|
684
|
+
The merged :class:`PrettyDict`.
|
747
685
|
"""
|
748
686
|
if not states:
|
749
687
|
return state
|
@@ -759,14 +697,67 @@ class FlattedDict(PrettyDict):
|
|
759
697
|
return FlattedDict(new_state)
|
760
698
|
|
761
699
|
def to_dict_values(self):
|
700
|
+
"""
|
701
|
+
Convert a FlattedDict containing State objects to a plain dictionary of values.
|
702
|
+
|
703
|
+
This method extracts the underlying values from any State objects in the FlattedDict,
|
704
|
+
creating a new dictionary with the same keys but where each State object is replaced
|
705
|
+
by its value attribute. Non-State objects are kept as is.
|
706
|
+
|
707
|
+
Returns:
|
708
|
+
dict: A dictionary with the same keys as the FlattedDict, but where each State
|
709
|
+
object is replaced by its value attribute. Non-State objects remain unchanged.
|
710
|
+
|
711
|
+
Example:
|
712
|
+
>>> flat_dict = FlattedDict({('model', 'layer1', 'weight'): ParamState(value=jnp.ones((10, 5)))})
|
713
|
+
>>> flat_dict.to_dict_values()
|
714
|
+
{('model', 'layer1', 'weight'): Array([[1., 1., ...]], dtype=float32)}
|
715
|
+
"""
|
762
716
|
from brainstate._state import State
|
763
|
-
return {
|
717
|
+
return {
|
718
|
+
k: v.value if isinstance(v, State) else v
|
719
|
+
for k, v in self.items()
|
720
|
+
}
|
721
|
+
|
722
|
+
def assign_dict_values(self, data: dict):
|
723
|
+
"""
|
724
|
+
Assign values from a dictionary to this FlattedDict.
|
725
|
+
|
726
|
+
This method updates the values in the FlattedDict with values from the provided
|
727
|
+
dictionary. For keys that correspond to State objects, the value attribute of
|
728
|
+
the State is updated. For other keys, the value in the FlattedDict is directly
|
729
|
+
replaced with the new value.
|
730
|
+
|
731
|
+
The method requires that all keys in the FlattedDict exist in the provided
|
732
|
+
dictionary, otherwise a KeyError is raised.
|
733
|
+
|
734
|
+
Args:
|
735
|
+
data (dict): A dictionary containing the values to assign, where keys
|
736
|
+
must match those in the FlattedDict.
|
737
|
+
|
738
|
+
Raises:
|
739
|
+
KeyError: If a key in the FlattedDict is not present in the provided dictionary.
|
740
|
+
|
741
|
+
Example:
|
742
|
+
>>> flat_dict = FlattedDict({('model', 'weight'): ParamState(value=jnp.zeros((5, 5)))})
|
743
|
+
>>> flat_dict.assign_dict_values({('model', 'weight'): jnp.ones((5, 5))})
|
744
|
+
# The ParamState's value is now an array of ones
|
745
|
+
"""
|
746
|
+
from brainstate._state import State
|
747
|
+
for k in self.keys():
|
748
|
+
if k not in data:
|
749
|
+
raise KeyError(f'Invalid key: {k!r}')
|
750
|
+
val = self[k]
|
751
|
+
if isinstance(val, State):
|
752
|
+
val.value = data[k]
|
753
|
+
else:
|
754
|
+
self[k] = data[k]
|
764
755
|
|
765
756
|
|
766
757
|
def _split_nested_mapping(
|
767
|
-
mapping: NestedDict[K, V],
|
758
|
+
mapping: 'NestedDict[K, V]',
|
768
759
|
*filters: Filter,
|
769
|
-
) -> Tuple[NestedDict[K, V], ...]:
|
760
|
+
) -> Tuple['NestedDict[K, V]', ...]:
|
770
761
|
# check if the filters are exhaustive
|
771
762
|
for i, filter_ in enumerate(filters):
|
772
763
|
if filter_ in (..., True) and i != len(filters) - 1:
|
@@ -828,7 +819,7 @@ def _split_flatted_mapping(
|
|
828
819
|
return tuple(FlattedDict(flat_state) for flat_state in flat_states)
|
829
820
|
|
830
821
|
|
831
|
-
# register
|
822
|
+
# register :class:`NestedDict` as a pytree
|
832
823
|
def _nest_flatten_with_keys(x: NestedDict):
|
833
824
|
items = sorted(x.items())
|
834
825
|
children = tuple((jax.tree_util.DictKey(key), value) for key, value in items)
|
@@ -847,7 +838,7 @@ jax.tree_util.register_pytree_with_keys(NestedDict,
|
|
847
838
|
_nest_unflatten) # type: ignore[arg-type]
|
848
839
|
|
849
840
|
|
850
|
-
# register
|
841
|
+
# register :class:`FlattedDict` as a pytree
|
851
842
|
|
852
843
|
def _flat_unflatten(
|
853
844
|
static: Tuple[K, ...],
|
@@ -892,3 +883,63 @@ def _list_repr_attr(node: PrettyList):
|
|
892
883
|
|
893
884
|
def _list_repr_object(node: PrettyDict):
|
894
885
|
yield PrettyType('', value_sep='', start='[', end=']')
|
886
|
+
|
887
|
+
|
888
|
+
def _repr_object_general(node: PrettyDict):
|
889
|
+
"""
|
890
|
+
Generate a general representation of a PrettyDict object.
|
891
|
+
|
892
|
+
This function is used to create a pretty representation of a PrettyDict
|
893
|
+
object, which includes the type of the object and its value separator.
|
894
|
+
|
895
|
+
Args:
|
896
|
+
node (PrettyDict): The PrettyDict object to be represented.
|
897
|
+
|
898
|
+
Yields:
|
899
|
+
PrettyType: A PrettyType object representing the type of the node,
|
900
|
+
with specified value separator, start, and end characters.
|
901
|
+
"""
|
902
|
+
yield PrettyType(type(node), value_sep='=', start='(', end=')')
|
903
|
+
|
904
|
+
|
905
|
+
def _repr_attribute_general(node):
|
906
|
+
"""
|
907
|
+
Generate a pretty representation of the attributes of a node.
|
908
|
+
|
909
|
+
This function iterates over the attributes of a given node and attempts
|
910
|
+
to generate a pretty representation for each attribute. It handles
|
911
|
+
conversion of lists and dictionaries to their pretty representation
|
912
|
+
counterparts and yields a PrettyAttr object for each attribute.
|
913
|
+
|
914
|
+
Args:
|
915
|
+
node: The object whose attributes are to be represented.
|
916
|
+
|
917
|
+
Yields:
|
918
|
+
PrettyAttr: A PrettyAttr object representing the key and value of
|
919
|
+
each attribute in a pretty format.
|
920
|
+
"""
|
921
|
+
for k, v in vars(node).items():
|
922
|
+
try:
|
923
|
+
res = node.__pretty_repr_item__(k, v)
|
924
|
+
if res is None:
|
925
|
+
continue
|
926
|
+
k, v = res
|
927
|
+
except AttributeError:
|
928
|
+
pass
|
929
|
+
|
930
|
+
if k is None:
|
931
|
+
continue
|
932
|
+
|
933
|
+
# convert list to PrettyList
|
934
|
+
if isinstance(v, list):
|
935
|
+
v = PrettyList(v)
|
936
|
+
|
937
|
+
# convert dict to PrettyDict
|
938
|
+
if isinstance(v, dict):
|
939
|
+
v = PrettyDict(v)
|
940
|
+
|
941
|
+
# convert PrettyDict to NestedStateRepr
|
942
|
+
if isinstance(v, PrettyDict):
|
943
|
+
v = NestedStateRepr(v)
|
944
|
+
|
945
|
+
yield PrettyAttr(k, v)
|
brainstate/util/_pretty_repr.py
CHANGED
brainstate/util/_pretty_table.py
CHANGED
@@ -469,7 +469,7 @@ class PrettyTable:
|
|
469
469
|
else:
|
470
470
|
raise AttributeError(name)
|
471
471
|
|
472
|
-
def __getitem__(self, index: int | slice) -> PrettyTable:
|
472
|
+
def __getitem__(self, index: int | slice) -> 'PrettyTable':
|
473
473
|
new = PrettyTable()
|
474
474
|
new.field_names = self.field_names
|
475
475
|
for attr in self._options:
|
@@ -2713,10 +2713,64 @@ class PrettyTable:
|
|
2713
2713
|
##############################
|
2714
2714
|
|
2715
2715
|
|
2716
|
+
# def _str_block_width(val: str) -> int:
|
2717
|
+
# import wcwidth # type: ignore[import-untyped]
|
2718
|
+
#
|
2719
|
+
# return wcwidth.wcswidth(_re.sub("", val))
|
2720
|
+
|
2716
2721
|
def _str_block_width(val: str) -> int:
|
2717
|
-
|
2722
|
+
"""Calculate the visual width of a string, accounting for wide Unicode characters.
|
2723
|
+
|
2724
|
+
This is a custom implementation to replace the wcwidth dependency.
|
2725
|
+
|
2726
|
+
Args:
|
2727
|
+
val: The string to measure
|
2728
|
+
|
2729
|
+
Returns:
|
2730
|
+
The visual width of the string when displayed in a monospace context
|
2731
|
+
"""
|
2732
|
+
# Remove ANSI escape sequences
|
2733
|
+
val = _re.sub("", val)
|
2734
|
+
|
2735
|
+
# Fast path for ASCII-only strings
|
2736
|
+
if all(ord(c) < 128 for c in val):
|
2737
|
+
return len(val)
|
2718
2738
|
|
2719
|
-
|
2739
|
+
width = 0
|
2740
|
+
for char in val:
|
2741
|
+
width += _char_width(char)
|
2742
|
+
|
2743
|
+
return width
|
2744
|
+
|
2745
|
+
|
2746
|
+
def _char_width(char: str) -> int:
|
2747
|
+
"""Calculate the display width of a single character.
|
2748
|
+
|
2749
|
+
Args:
|
2750
|
+
char: A single Unicode character
|
2751
|
+
|
2752
|
+
Returns:
|
2753
|
+
0 for control characters
|
2754
|
+
2 for wide characters (CJK, emoji, etc.)
|
2755
|
+
1 for all other characters
|
2756
|
+
"""
|
2757
|
+
code = ord(char)
|
2758
|
+
|
2759
|
+
# Control characters and empty space
|
2760
|
+
if code == 0 or code == 0x034F or (0x200B <= code <= 0x200F) or code == 0x2028 or code == 0x2029 or (
|
2761
|
+
0x202A <= code <= 0x202E) or (0x2060 <= code <= 0x2063):
|
2762
|
+
return 0
|
2763
|
+
|
2764
|
+
# Wide characters: CJK, emoji, etc.
|
2765
|
+
if (0x1100 <= code <= 0x115F) or (0x2329 <= code <= 0x232A) or (0x2E80 <= code <= 0x303E) or (
|
2766
|
+
0x3040 <= code <= 0x4DBF) or (0x4E00 <= code <= 0x9FFF) or (0xA000 <= code <= 0xA4CF) or (
|
2767
|
+
0xAC00 <= code <= 0xD7A3) or (0xF900 <= code <= 0xFAFF) or (0xFE10 <= code <= 0xFE19) or (
|
2768
|
+
0xFE30 <= code <= 0xFE6F) or (0xFF00 <= code <= 0xFF60) or (0xFFE0 <= code <= 0xFFE6) or (
|
2769
|
+
0x1F300 <= code <= 0x1F64F) or (0x1F900 <= code <= 0x1F9FF):
|
2770
|
+
return 2
|
2771
|
+
|
2772
|
+
# Default single-width character
|
2773
|
+
return 1
|
2720
2774
|
|
2721
2775
|
|
2722
2776
|
##############################
|
brainstate/util/_scaling.py
CHANGED
brainstate/util/_struct.py
CHANGED