brainstate 0.1.0.post20250501__py2.py3-none-any.whl → 0.1.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +8 -19
  3. brainstate/_state.py +177 -177
  4. brainstate/_utils.py +0 -1
  5. brainstate/augment/_autograd.py +0 -2
  6. brainstate/augment/_eval_shape.py +0 -2
  7. brainstate/augment/_mapping.py +2 -3
  8. brainstate/augment/_random.py +0 -2
  9. brainstate/compile/_ad_checkpoint.py +0 -2
  10. brainstate/compile/_conditions.py +0 -2
  11. brainstate/compile/_error_if.py +0 -2
  12. brainstate/compile/_jit.py +9 -8
  13. brainstate/compile/_loop_collect_return.py +0 -2
  14. brainstate/compile/_loop_no_collection.py +0 -2
  15. brainstate/compile/_make_jaxpr.py +4 -6
  16. brainstate/compile/_progress_bar.py +0 -1
  17. brainstate/compile/_unvmap.py +0 -1
  18. brainstate/compile/_util.py +0 -2
  19. brainstate/environ.py +0 -2
  20. brainstate/functional/_activations.py +0 -2
  21. brainstate/functional/_normalization.py +0 -2
  22. brainstate/functional/_others.py +0 -2
  23. brainstate/functional/_spikes.py +0 -1
  24. brainstate/graph/_graph_node.py +1 -3
  25. brainstate/graph/_graph_operation.py +4 -2
  26. brainstate/init/_base.py +0 -2
  27. brainstate/init/_generic.py +0 -1
  28. brainstate/init/_random_inits.py +0 -1
  29. brainstate/init/_regular_inits.py +0 -2
  30. brainstate/mixin.py +0 -2
  31. brainstate/nn/_collective_ops.py +0 -3
  32. brainstate/nn/_common.py +0 -2
  33. brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
  34. brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
  35. brainstate/nn/_dyn_impl/_inputs.py +0 -1
  36. brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
  37. brainstate/nn/_dyn_impl/_readout.py +0 -1
  38. brainstate/nn/_dynamics/_dynamics_base.py +0 -1
  39. brainstate/nn/_dynamics/_projection_base.py +0 -1
  40. brainstate/nn/_dynamics/_state_delay.py +0 -2
  41. brainstate/nn/_dynamics/_synouts.py +0 -2
  42. brainstate/nn/_elementwise/_dropout.py +0 -2
  43. brainstate/nn/_elementwise/_elementwise.py +0 -2
  44. brainstate/nn/_event/_fixedprob_mv.py +0 -1
  45. brainstate/nn/_event/_linear_mv.py +0 -2
  46. brainstate/nn/_exp_euler.py +0 -2
  47. brainstate/nn/_interaction/_conv.py +0 -2
  48. brainstate/nn/_interaction/_embedding.py +0 -1
  49. brainstate/nn/_interaction/_linear.py +0 -2
  50. brainstate/nn/_interaction/_normalizations.py +0 -2
  51. brainstate/nn/_interaction/_poolings.py +0 -2
  52. brainstate/nn/_module.py +0 -1
  53. brainstate/nn/metrics.py +0 -2
  54. brainstate/optim/_base.py +0 -2
  55. brainstate/optim/_lr_scheduler.py +0 -1
  56. brainstate/optim/_optax_optimizer.py +0 -2
  57. brainstate/optim/_sgd_optimizer.py +0 -1
  58. brainstate/random/_rand_funs.py +0 -1
  59. brainstate/random/_rand_seed.py +0 -1
  60. brainstate/random/_rand_state.py +0 -1
  61. brainstate/surrogate.py +0 -1
  62. brainstate/typing.py +0 -2
  63. brainstate/util/_caller.py +4 -6
  64. brainstate/util/_others.py +0 -2
  65. brainstate/util/_pretty_pytree.py +201 -150
  66. brainstate/util/_pretty_repr.py +0 -2
  67. brainstate/util/_pretty_table.py +57 -3
  68. brainstate/util/_scaling.py +0 -2
  69. brainstate/util/_struct.py +0 -2
  70. brainstate/util/filter.py +0 -2
  71. {brainstate-0.1.0.post20250501.dist-info → brainstate-0.1.1.dist-info}/METADATA +16 -6
  72. brainstate-0.1.1.dist-info/RECORD +133 -0
  73. brainstate-0.1.0.post20250501.dist-info/RECORD +0 -133
  74. {brainstate-0.1.0.post20250501.dist-info → brainstate-0.1.1.dist-info}/LICENSE +0 -0
  75. {brainstate-0.1.0.post20250501.dist-info → brainstate-0.1.1.dist-info}/WHEEL +0 -0
  76. {brainstate-0.1.0.post20250501.dist-info → brainstate-0.1.1.dist-info}/top_level.txt +0 -0
@@ -15,8 +15,6 @@
15
15
  # See the License for the specific language governing permissions and
16
16
  # limitations under the License.
17
17
 
18
- from __future__ import annotations
19
-
20
18
  from collections import abc
21
19
  from typing import TypeVar, Hashable, Union, Iterable, Any, Optional, Tuple, Dict
22
20
 
@@ -46,66 +44,6 @@ ExtractValueFn = abc.Callable[[Any], Any]
46
44
  SetValueFn = abc.Callable[[V, Any], V]
47
45
 
48
46
 
49
- def _repr_object_general(node: PrettyDict):
50
- """
51
- Generate a general representation of a PrettyDict object.
52
-
53
- This function is used to create a pretty representation of a PrettyDict
54
- object, which includes the type of the object and its value separator.
55
-
56
- Args:
57
- node (PrettyDict): The PrettyDict object to be represented.
58
-
59
- Yields:
60
- PrettyType: A PrettyType object representing the type of the node,
61
- with specified value separator, start, and end characters.
62
- """
63
- yield PrettyType(type(node), value_sep='=', start='(', end=')')
64
-
65
-
66
- def _repr_attribute_general(node):
67
- """
68
- Generate a pretty representation of the attributes of a node.
69
-
70
- This function iterates over the attributes of a given node and attempts
71
- to generate a pretty representation for each attribute. It handles
72
- conversion of lists and dictionaries to their pretty representation
73
- counterparts and yields a PrettyAttr object for each attribute.
74
-
75
- Args:
76
- node: The object whose attributes are to be represented.
77
-
78
- Yields:
79
- PrettyAttr: A PrettyAttr object representing the key and value of
80
- each attribute in a pretty format.
81
- """
82
- for k, v in vars(node).items():
83
- try:
84
- res = node.__pretty_repr_item__(k, v)
85
- if res is None:
86
- continue
87
- k, v = res
88
- except AttributeError:
89
- pass
90
-
91
- if k is None:
92
- continue
93
-
94
- # convert list to PrettyList
95
- if isinstance(v, list):
96
- v = PrettyList(v)
97
-
98
- # convert dict to PrettyDict
99
- if isinstance(v, dict):
100
- v = PrettyDict(v)
101
-
102
- # convert PrettyDict to NestedStateRepr
103
- if isinstance(v, PrettyDict):
104
- v = NestedStateRepr(v)
105
-
106
- yield PrettyAttr(k, v)
107
-
108
-
109
47
  class PrettyObject(PrettyRepr):
110
48
  """
111
49
  A class for generating a pretty representation of a tree-like structure.
@@ -284,37 +222,15 @@ def _default_process(x):
284
222
  return id(x)
285
223
 
286
224
 
287
- class NestedStateRepr(PrettyRepr):
288
- def __init__(self, state: PrettyDict):
289
- self.state = state
290
-
291
- def __pretty_repr__(self):
292
- yield PrettyType('', value_sep=': ', start='{', end='}')
293
-
294
- for r in self.state.__pretty_repr__():
295
- if isinstance(r, PrettyType):
296
- continue
297
- yield r
298
-
299
- def __treescope_repr__(self, path, subtree_renderer):
300
- children = {}
301
- for k, v in self.state.items():
302
- if isinstance(v, PrettyDict):
303
- v = NestedStateRepr(v)
304
- children[k] = v
305
- # Render as the dictionary itself at the same path.
306
- return subtree_renderer(children, path=path)
307
-
308
-
309
225
  class PrettyDict(dict, PrettyRepr):
310
226
  __module__ = 'brainstate.util'
311
227
 
312
- def __getattr__(self, key: K) -> NestedMapping | V: # type: ignore[misc]
228
+ def __getattr__(self, key: K): # type: ignore[misc]
313
229
  return self[key]
314
230
 
315
231
  def treefy_state(self):
316
232
  """
317
- Convert the ``State`` objects to a reference tree of the state.
233
+ Convert the :class:`State` objects to a reference tree of the state.
318
234
  """
319
235
  from brainstate._state import State
320
236
  leaves, treedef = jax.tree.flatten(self)
@@ -323,7 +239,7 @@ class PrettyDict(dict, PrettyRepr):
323
239
 
324
240
  def to_dict(self) -> Dict[K, Dict[K, Any] | V]:
325
241
  """
326
- Convert the ``PrettyDict`` to a dictionary.
242
+ Convert the :class:`PrettyDict` to a dictionary.
327
243
 
328
244
  Returns:
329
245
  The dictionary.
@@ -337,24 +253,46 @@ class PrettyDict(dict, PrettyRepr):
337
253
  def __pretty_repr__(self):
338
254
  yield from yield_unique_pretty_repr_items(self, _default_repr_object, _default_repr_attr)
339
255
 
340
- def split(self, *filters) -> Union[PrettyDict[K, V], Tuple[PrettyDict[K, V], ...]]:
256
+ def split(self, *filters) -> Union['PrettyDict[K, V]', Tuple['PrettyDict[K, V]', ...]]:
341
257
  raise NotImplementedError
342
258
 
343
- def filter(self, *filters) -> Union[PrettyDict[K, V], Tuple[PrettyDict[K, V], ...]]:
259
+ def filter(self, *filters) -> Union['PrettyDict[K, V]', Tuple['PrettyDict[K, V]', ...]]:
344
260
  raise NotImplementedError
345
261
 
346
- def merge(self, *states) -> PrettyDict[K, V]:
262
+ def merge(self, *states) -> 'PrettyDict[K, V]':
347
263
  raise NotImplementedError
348
264
 
349
- def subset(self, *filters) -> Union[PrettyDict[K, V], Tuple[PrettyDict[K, V], ...]]:
265
+ def subset(self, *filters) -> Union['PrettyDict[K, V]', Tuple['PrettyDict[K, V]', ...]]:
350
266
  """
351
- Subset a ``PrettyDict`` into one or more ``PrettyDict``'s. The user must pass at least one
352
- ``Filter`` (i.e. :class:`State`), and the filters must be exhaustive (i.e. they must cover all
353
- :class:`State` types in the ``PrettyDict``).
267
+ Subset a :class:`PrettyDict` into one or more :class:`PrettyDict`'s. The user must pass at least one
268
+ `:class:`Filter` (i.e. :class:`State`), and the filters must be exhaustive (i.e. they must cover all
269
+ :class:`State` types in the :class:`PrettyDict`).
354
270
  """
355
271
  return self.filter(*filters)
356
272
 
357
273
 
274
+ class NestedStateRepr(PrettyRepr):
275
+ def __init__(self, state: PrettyDict):
276
+ self.state = state
277
+
278
+ def __pretty_repr__(self):
279
+ yield PrettyType('', value_sep=': ', start='{', end='}')
280
+
281
+ for r in self.state.__pretty_repr__():
282
+ if isinstance(r, PrettyType):
283
+ continue
284
+ yield r
285
+
286
+ def __treescope_repr__(self, path, subtree_renderer):
287
+ children = {}
288
+ for k, v in self.state.items():
289
+ if isinstance(v, PrettyDict):
290
+ v = NestedStateRepr(v)
291
+ children[k] = v
292
+ # Render as the dictionary itself at the same path.
293
+ return subtree_renderer(children, path=path)
294
+
295
+
358
296
  def _default_repr_object(node: PrettyDict):
359
297
  yield PrettyType('', value_sep=': ', start='{', end='}')
360
298
 
@@ -375,20 +313,20 @@ def _default_repr_attr(node):
375
313
 
376
314
  class NestedDict(PrettyDict):
377
315
  """
378
- A pytree-like structure that contains a ``Mapping`` from strings or integers to leaves.
316
+ A pytree-like structure that contains a :class:`Mapping` from strings or integers to leaves.
379
317
 
380
318
  A valid leaf type is either :class:`State`, ``jax.Array``, ``numpy.ndarray`` or nested
381
- ``NestedDict`` and ``FlattedDict``.
319
+ :class:`NestedDict` and :class:`FlattedDict`.
382
320
  """
383
321
  __module__ = 'brainstate.util'
384
322
 
385
- def __or__(self, other: NestedDict[K, V]) -> NestedDict[K, V]:
323
+ def __or__(self, other: 'NestedDict[K, V]') -> 'NestedDict[K, V]':
386
324
  if not other:
387
325
  return self
388
326
  assert isinstance(other, NestedDict), f'expected NestedDict; got {type(other).__qualname__}'
389
327
  return NestedDict.merge(self, other)
390
328
 
391
- def __sub__(self, other: NestedDict[K, V]) -> NestedDict[K, V]:
329
+ def __sub__(self, other: 'NestedDict[K, V]') -> 'NestedDict[K, V]':
392
330
  if not other:
393
331
  return self
394
332
 
@@ -398,25 +336,25 @@ class NestedDict(PrettyDict):
398
336
  diff = {k: v for k, v in self_flat.items() if k not in other_flat}
399
337
  return NestedDict.from_flat(diff)
400
338
 
401
- def to_flat(self) -> FlattedDict:
339
+ def to_flat(self) -> 'FlattedDict':
402
340
  """
403
341
  Flatten the nested mapping into a flat mapping.
404
342
 
405
343
  Returns:
406
- The flattened mapping.
344
+ The flattened mapping.
407
345
  """
408
346
  return flat_mapping(self)
409
347
 
410
348
  @classmethod
411
- def from_flat(cls, flat_dict: abc.Mapping[PathParts, V] | Iterable[tuple[PathParts, V]]) -> NestedDict:
349
+ def from_flat(cls, flat_dict: abc.Mapping[PathParts, V] | Iterable[tuple[PathParts, V]]) -> 'NestedDict':
412
350
  """
413
- Create a ``NestedDict`` from a flat mapping.
351
+ Create a :class:`NestedDict` from a flat mapping.
414
352
 
415
353
  Args:
416
354
  flat_dict: The flat mapping.
417
355
 
418
356
  Returns:
419
- The ``NestedDict``.
357
+ The :class:`NestedDict`.
420
358
  """
421
359
  nested_state = nest_mapping(dict(flat_dict))
422
360
  return cls(nested_state)
@@ -426,12 +364,12 @@ class NestedDict(PrettyDict):
426
364
  first: Filter,
427
365
  /,
428
366
  *filters: Filter
429
- ) -> Union[NestedDict[K, V], Tuple[NestedDict[K, V], ...]]:
367
+ ) -> Union['NestedDict[K, V]', Tuple['NestedDict[K, V]', ...]]:
430
368
  """
431
- Split a ``NestedDict`` into one or more ``NestedDict``'s. The
432
- user must pass at least one ``Filter`` (i.e. :class:`State`),
369
+ Split a :class:`NestedDict` into one or more :class:`NestedDict`'s. The
370
+ user must pass at least one `:class:`Filter` (i.e. :class:`State`),
433
371
  and the filters must be exhaustive (i.e. they must cover all
434
- :class:`State` types in the ``NestedDict``).
372
+ :class:`State` types in the :class:`NestedDict`).
435
373
 
436
374
  Example usage::
437
375
 
@@ -474,10 +412,10 @@ class NestedDict(PrettyDict):
474
412
  first: Filter,
475
413
  /,
476
414
  *filters: Filter,
477
- ) -> Union[NestedDict[K, V], Tuple[NestedDict[K, V], ...]]:
415
+ ) -> Union['NestedDict[K, V]', Tuple['NestedDict[K, V]', ...]]:
478
416
  """
479
- Filter a ``NestedDict`` into one or more ``NestedDict``'s. The
480
- user must pass at least one ``Filter`` (i.e. :class:`State`).
417
+ Filter a :class:`NestedDict` into one or more :class:`NestedDict`'s. The
418
+ user must pass at least one `:class:`Filter` (i.e. :class:`State`).
481
419
  This method is similar to :meth:`split() <flax.nnx.NestedDict.state.split>`,
482
420
  except the filters can be non-exhaustive.
483
421
 
@@ -498,21 +436,21 @@ class NestedDict(PrettyDict):
498
436
 
499
437
  @staticmethod
500
438
  def merge(
501
- state: NestedDict[K, V] | FlattedDict[K, V],
439
+ state: Union['NestedDict[K, V]', 'FlattedDict[K, V]'],
502
440
  /,
503
- *states: NestedDict[K, V] | FlattedDict[K, V]
504
- ) -> NestedDict[K, V]:
441
+ *states: Union['NestedDict[K, V]', 'FlattedDict[K, V]']
442
+ ) -> 'NestedDict[K, V]':
505
443
  """
506
444
  The inverse of :meth:`split()`.
507
445
 
508
- ``merge`` takes one or more ``PrettyDict``'s and creates a new ``PrettyDict``.
446
+ ``merge`` takes one or more :class:`PrettyDict`'s and creates a new :class:`PrettyDict`.
509
447
 
510
448
  Args:
511
- state: A ``PrettyDict`` object.
512
- *states: Additional ``PrettyDict`` objects.
449
+ state: A :class:`PrettyDict` object.
450
+ *states: Additional :class:`PrettyDict` objects.
513
451
 
514
452
  Returns:
515
- The merged ``PrettyDict``.
453
+ The merged :class:`PrettyDict`.
516
454
  """
517
455
  if not states:
518
456
  return state
@@ -548,11 +486,11 @@ class NestedDict(PrettyDict):
548
486
 
549
487
  class FlattedDict(PrettyDict):
550
488
  """
551
- A pytree-like structure that contains a ``Mapping`` from strings or integers to leaves.
489
+ A pytree-like structure that contains a :class:`Mapping` from strings or integers to leaves.
552
490
 
553
491
  A valid leaf type is either :class:`State`, ``jax.Array``, ``numpy.ndarray`` or Python variables.
554
492
 
555
- A ``NestedDict`` can be generated by either calling :func:`states()` or
493
+ A :class:`NestedDict` can be generated by either calling :func:`states()` or
556
494
  :func:`nodes()` on the :class:`Module`.
557
495
 
558
496
  Example usage::
@@ -631,13 +569,13 @@ class FlattedDict(PrettyDict):
631
569
  """
632
570
  __module__ = 'brainstate.util'
633
571
 
634
- def __or__(self, other: FlattedDict[K, V]) -> FlattedDict[K, V]:
572
+ def __or__(self, other: 'FlattedDict[K, V]') -> 'FlattedDict[K, V]':
635
573
  if not other:
636
574
  return self
637
575
  assert isinstance(other, FlattedDict), f'expected NestedDict; got {type(other).__qualname__}'
638
576
  return FlattedDict.merge(self, other)
639
577
 
640
- def __sub__(self, other: FlattedDict[K, V]) -> FlattedDict[K, V]:
578
+ def __sub__(self, other: 'FlattedDict[K, V]') -> 'FlattedDict[K, V]':
641
579
  if not other:
642
580
  return self
643
581
  assert isinstance(other, FlattedDict), f'expected NestedDict; got {type(other).__qualname__}'
@@ -649,22 +587,22 @@ class FlattedDict(PrettyDict):
649
587
  Unflatten the flat mapping into a nested mapping.
650
588
 
651
589
  Returns:
652
- The nested mapping.
590
+ The nested mapping.
653
591
  """
654
592
  return nest_mapping(self)
655
593
 
656
594
  @classmethod
657
595
  def from_nest(
658
596
  cls, nested_dict: abc.Mapping[PathParts, V] | Iterable[tuple[PathParts, V]],
659
- ) -> FlattedDict:
597
+ ) -> 'FlattedDict':
660
598
  """
661
- Create a ``NestedDict`` from a flat mapping.
599
+ Create a :class:`NestedDict` from a flat mapping.
662
600
 
663
601
  Args:
664
602
  nested_dict: The flat mapping.
665
603
 
666
604
  Returns:
667
- The ``NestedDict``.
605
+ The :class:`NestedDict`.
668
606
  """
669
607
  return flat_mapping(nested_dict)
670
608
 
@@ -673,19 +611,19 @@ class FlattedDict(PrettyDict):
673
611
  first: Filter,
674
612
  /,
675
613
  *filters: Filter
676
- ) -> Union[FlattedDict[K, V], tuple[FlattedDict[K, V], ...]]:
614
+ ) -> Union['FlattedDict[K, V]', tuple['FlattedDict[K, V]', ...]]:
677
615
  """
678
- Split a ``FlattedDict`` into one or more ``FlattedDict``'s. The
679
- user must pass at least one ``Filter`` (i.e. :class:`State`),
616
+ Split a :class:`FlattedDict` into one or more :class:`FlattedDict`'s. The
617
+ user must pass at least one `:class:`Filter` (i.e. :class:`State`),
680
618
  and the filters must be exhaustive (i.e. they must cover all
681
- :class:`State` types in the ``NestedDict``).
619
+ :class:`State` types in the :class:`NestedDict`).
682
620
 
683
621
  Arguments:
684
- first: The first filter
685
- *filters: The optional, additional filters to group the state into mutually exclusive substates.
622
+ first: The first filter
623
+ *filters: The optional, additional filters to group the state into mutually exclusive substates.
686
624
 
687
625
  Returns:
688
- One or more ``States`` equal to the number of filters passed.
626
+ One or more ``States`` equal to the number of filters passed.
689
627
  """
690
628
  filters = (first, *filters)
691
629
  *states_, rest = _split_flatted_mapping(self, *filters)
@@ -705,19 +643,19 @@ class FlattedDict(PrettyDict):
705
643
  first: Filter,
706
644
  /,
707
645
  *filters: Filter,
708
- ) -> Union[FlattedDict[K, V], Tuple[FlattedDict[K, V], ...]]:
646
+ ) -> Union['FlattedDict[K, V]', Tuple['FlattedDict[K, V]', ...]]:
709
647
  """
710
- Filter a ``FlattedDict`` into one or more ``FlattedDict``'s. The
711
- user must pass at least one ``Filter`` (i.e. :class:`State`).
648
+ Filter a :class:`FlattedDict` into one or more :class:`FlattedDict`'s. The
649
+ user must pass at least one `:class:`Filter` (i.e. :class:`State`).
712
650
  This method is similar to :meth:`split() <flax.nnx.NestedDict.state.split>`,
713
651
  except the filters can be non-exhaustive.
714
652
 
715
653
  Arguments:
716
- first: The first filter
717
- *filters: The optional, additional filters to group the state into mutually exclusive substates.
654
+ first: The first filter
655
+ *filters: The optional, additional filters to group the state into mutually exclusive substates.
718
656
 
719
657
  Returns:
720
- One or more ``States`` equal to the number of filters passed.
658
+ One or more ``States`` equal to the number of filters passed.
721
659
  """
722
660
  *states_, _rest = _split_flatted_mapping(self, first, *filters)
723
661
  assert len(states_) == len(filters) + 1, f'Expected {len(filters) + 1} states, got {len(states_)}'
@@ -729,21 +667,21 @@ class FlattedDict(PrettyDict):
729
667
 
730
668
  @staticmethod
731
669
  def merge(
732
- state: FlattedDict[K, V] | NestedDict[K, V],
670
+ state: Union['FlattedDict[K, V]', 'NestedDict[K, V]'],
733
671
  /,
734
- *states: FlattedDict[K, V] | NestedDict[K, V]
735
- ) -> FlattedDict[K, V]:
672
+ *states: Union['FlattedDict[K, V]', 'NestedDict[K, V]']
673
+ ) -> 'FlattedDict[K, V]':
736
674
  """
737
675
  The inverse of :meth:`split()`.
738
676
 
739
- ``merge`` takes one or more ``FlattedDict``'s and creates a new ``FlattedDict``.
677
+ ``merge`` takes one or more :class:`FlattedDict`'s and creates a new :class:`FlattedDict`.
740
678
 
741
679
  Args:
742
- state: A ``PrettyDict`` object.
743
- *states: Additional ``PrettyDict`` objects.
680
+ state: A :class:`PrettyDict` object.
681
+ *states: Additional :class:`PrettyDict` objects.
744
682
 
745
683
  Returns:
746
- The merged ``PrettyDict``.
684
+ The merged :class:`PrettyDict`.
747
685
  """
748
686
  if not states:
749
687
  return state
@@ -759,14 +697,67 @@ class FlattedDict(PrettyDict):
759
697
  return FlattedDict(new_state)
760
698
 
761
699
  def to_dict_values(self):
700
+ """
701
+ Convert a FlattedDict containing State objects to a plain dictionary of values.
702
+
703
+ This method extracts the underlying values from any State objects in the FlattedDict,
704
+ creating a new dictionary with the same keys but where each State object is replaced
705
+ by its value attribute. Non-State objects are kept as is.
706
+
707
+ Returns:
708
+ dict: A dictionary with the same keys as the FlattedDict, but where each State
709
+ object is replaced by its value attribute. Non-State objects remain unchanged.
710
+
711
+ Example:
712
+ >>> flat_dict = FlattedDict({('model', 'layer1', 'weight'): ParamState(value=jnp.ones((10, 5)))})
713
+ >>> flat_dict.to_dict_values()
714
+ {('model', 'layer1', 'weight'): Array([[1., 1., ...]], dtype=float32)}
715
+ """
762
716
  from brainstate._state import State
763
- return {k: v.value if isinstance(v, State) else v for k, v in self.items()}
717
+ return {
718
+ k: v.value if isinstance(v, State) else v
719
+ for k, v in self.items()
720
+ }
721
+
722
+ def assign_dict_values(self, data: dict):
723
+ """
724
+ Assign values from a dictionary to this FlattedDict.
725
+
726
+ This method updates the values in the FlattedDict with values from the provided
727
+ dictionary. For keys that correspond to State objects, the value attribute of
728
+ the State is updated. For other keys, the value in the FlattedDict is directly
729
+ replaced with the new value.
730
+
731
+ The method requires that all keys in the FlattedDict exist in the provided
732
+ dictionary, otherwise a KeyError is raised.
733
+
734
+ Args:
735
+ data (dict): A dictionary containing the values to assign, where keys
736
+ must match those in the FlattedDict.
737
+
738
+ Raises:
739
+ KeyError: If a key in the FlattedDict is not present in the provided dictionary.
740
+
741
+ Example:
742
+ >>> flat_dict = FlattedDict({('model', 'weight'): ParamState(value=jnp.zeros((5, 5)))})
743
+ >>> flat_dict.assign_dict_values({('model', 'weight'): jnp.ones((5, 5))})
744
+ # The ParamState's value is now an array of ones
745
+ """
746
+ from brainstate._state import State
747
+ for k in self.keys():
748
+ if k not in data:
749
+ raise KeyError(f'Invalid key: {k!r}')
750
+ val = self[k]
751
+ if isinstance(val, State):
752
+ val.value = data[k]
753
+ else:
754
+ self[k] = data[k]
764
755
 
765
756
 
766
757
  def _split_nested_mapping(
767
- mapping: NestedDict[K, V],
758
+ mapping: 'NestedDict[K, V]',
768
759
  *filters: Filter,
769
- ) -> Tuple[NestedDict[K, V], ...]:
760
+ ) -> Tuple['NestedDict[K, V]', ...]:
770
761
  # check if the filters are exhaustive
771
762
  for i, filter_ in enumerate(filters):
772
763
  if filter_ in (..., True) and i != len(filters) - 1:
@@ -828,7 +819,7 @@ def _split_flatted_mapping(
828
819
  return tuple(FlattedDict(flat_state) for flat_state in flat_states)
829
820
 
830
821
 
831
- # register ``NestedDict`` as a pytree
822
+ # register :class:`NestedDict` as a pytree
832
823
  def _nest_flatten_with_keys(x: NestedDict):
833
824
  items = sorted(x.items())
834
825
  children = tuple((jax.tree_util.DictKey(key), value) for key, value in items)
@@ -847,7 +838,7 @@ jax.tree_util.register_pytree_with_keys(NestedDict,
847
838
  _nest_unflatten) # type: ignore[arg-type]
848
839
 
849
840
 
850
- # register ``FlattedDict`` as a pytree
841
+ # register :class:`FlattedDict` as a pytree
851
842
 
852
843
  def _flat_unflatten(
853
844
  static: Tuple[K, ...],
@@ -892,3 +883,63 @@ def _list_repr_attr(node: PrettyList):
892
883
 
893
884
  def _list_repr_object(node: PrettyDict):
894
885
  yield PrettyType('', value_sep='', start='[', end=']')
886
+
887
+
888
+ def _repr_object_general(node: PrettyDict):
889
+ """
890
+ Generate a general representation of a PrettyDict object.
891
+
892
+ This function is used to create a pretty representation of a PrettyDict
893
+ object, which includes the type of the object and its value separator.
894
+
895
+ Args:
896
+ node (PrettyDict): The PrettyDict object to be represented.
897
+
898
+ Yields:
899
+ PrettyType: A PrettyType object representing the type of the node,
900
+ with specified value separator, start, and end characters.
901
+ """
902
+ yield PrettyType(type(node), value_sep='=', start='(', end=')')
903
+
904
+
905
+ def _repr_attribute_general(node):
906
+ """
907
+ Generate a pretty representation of the attributes of a node.
908
+
909
+ This function iterates over the attributes of a given node and attempts
910
+ to generate a pretty representation for each attribute. It handles
911
+ conversion of lists and dictionaries to their pretty representation
912
+ counterparts and yields a PrettyAttr object for each attribute.
913
+
914
+ Args:
915
+ node: The object whose attributes are to be represented.
916
+
917
+ Yields:
918
+ PrettyAttr: A PrettyAttr object representing the key and value of
919
+ each attribute in a pretty format.
920
+ """
921
+ for k, v in vars(node).items():
922
+ try:
923
+ res = node.__pretty_repr_item__(k, v)
924
+ if res is None:
925
+ continue
926
+ k, v = res
927
+ except AttributeError:
928
+ pass
929
+
930
+ if k is None:
931
+ continue
932
+
933
+ # convert list to PrettyList
934
+ if isinstance(v, list):
935
+ v = PrettyList(v)
936
+
937
+ # convert dict to PrettyDict
938
+ if isinstance(v, dict):
939
+ v = PrettyDict(v)
940
+
941
+ # convert PrettyDict to NestedStateRepr
942
+ if isinstance(v, PrettyDict):
943
+ v = NestedStateRepr(v)
944
+
945
+ yield PrettyAttr(k, v)
@@ -15,8 +15,6 @@
15
15
  # See the License for the specific language governing permissions and
16
16
  # limitations under the License.
17
17
 
18
- from __future__ import annotations
19
-
20
18
  import dataclasses
21
19
  import threading
22
20
  from abc import ABC, abstractmethod
@@ -469,7 +469,7 @@ class PrettyTable:
469
469
  else:
470
470
  raise AttributeError(name)
471
471
 
472
- def __getitem__(self, index: int | slice) -> PrettyTable:
472
+ def __getitem__(self, index: int | slice) -> 'PrettyTable':
473
473
  new = PrettyTable()
474
474
  new.field_names = self.field_names
475
475
  for attr in self._options:
@@ -2713,10 +2713,64 @@ class PrettyTable:
2713
2713
  ##############################
2714
2714
 
2715
2715
 
2716
+ # def _str_block_width(val: str) -> int:
2717
+ # import wcwidth # type: ignore[import-untyped]
2718
+ #
2719
+ # return wcwidth.wcswidth(_re.sub("", val))
2720
+
2716
2721
  def _str_block_width(val: str) -> int:
2717
- import wcwidth # type: ignore[import-untyped]
2722
+ """Calculate the visual width of a string, accounting for wide Unicode characters.
2723
+
2724
+ This is a custom implementation to replace the wcwidth dependency.
2725
+
2726
+ Args:
2727
+ val: The string to measure
2728
+
2729
+ Returns:
2730
+ The visual width of the string when displayed in a monospace context
2731
+ """
2732
+ # Remove ANSI escape sequences
2733
+ val = _re.sub("", val)
2734
+
2735
+ # Fast path for ASCII-only strings
2736
+ if all(ord(c) < 128 for c in val):
2737
+ return len(val)
2718
2738
 
2719
- return wcwidth.wcswidth(_re.sub("", val))
2739
+ width = 0
2740
+ for char in val:
2741
+ width += _char_width(char)
2742
+
2743
+ return width
2744
+
2745
+
2746
+ def _char_width(char: str) -> int:
2747
+ """Calculate the display width of a single character.
2748
+
2749
+ Args:
2750
+ char: A single Unicode character
2751
+
2752
+ Returns:
2753
+ 0 for control characters
2754
+ 2 for wide characters (CJK, emoji, etc.)
2755
+ 1 for all other characters
2756
+ """
2757
+ code = ord(char)
2758
+
2759
+ # Control characters and empty space
2760
+ if code == 0 or code == 0x034F or (0x200B <= code <= 0x200F) or code == 0x2028 or code == 0x2029 or (
2761
+ 0x202A <= code <= 0x202E) or (0x2060 <= code <= 0x2063):
2762
+ return 0
2763
+
2764
+ # Wide characters: CJK, emoji, etc.
2765
+ if (0x1100 <= code <= 0x115F) or (0x2329 <= code <= 0x232A) or (0x2E80 <= code <= 0x303E) or (
2766
+ 0x3040 <= code <= 0x4DBF) or (0x4E00 <= code <= 0x9FFF) or (0xA000 <= code <= 0xA4CF) or (
2767
+ 0xAC00 <= code <= 0xD7A3) or (0xF900 <= code <= 0xFAFF) or (0xFE10 <= code <= 0xFE19) or (
2768
+ 0xFE30 <= code <= 0xFE6F) or (0xFF00 <= code <= 0xFF60) or (0xFFE0 <= code <= 0xFFE6) or (
2769
+ 0x1F300 <= code <= 0x1F64F) or (0x1F900 <= code <= 0x1F9FF):
2770
+ return 2
2771
+
2772
+ # Default single-width character
2773
+ return 1
2720
2774
 
2721
2775
 
2722
2776
  ##############################
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  from typing import Union, Sequence
19
17
 
20
18
  __all__ = [
@@ -17,8 +17,6 @@
17
17
 
18
18
  """Utilities for defining custom classes that can be used with jax transformations."""
19
19
 
20
- from __future__ import annotations
21
-
22
20
  import collections
23
21
  import dataclasses
24
22
  from collections.abc import Hashable, Mapping