brainstate 0.1.0.post20250208__py2.py3-none-any.whl → 0.1.0.post20250209__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainstate/_state.py CHANGED
@@ -19,8 +19,10 @@ import contextlib
19
19
  import dataclasses
20
20
  import threading
21
21
  from functools import wraps, partial
22
- from typing import (Any, Union, Callable, Generic, Mapping,
23
- TypeVar, Optional, TYPE_CHECKING, Tuple, Dict, List, Sequence)
22
+ from typing import (
23
+ Any, Union, Callable, Generic, Mapping,
24
+ TypeVar, Optional, TYPE_CHECKING, Tuple, Dict, List, Sequence
25
+ )
24
26
 
25
27
  import jax
26
28
  import numpy as np
@@ -28,7 +30,7 @@ from jax.api_util import shaped_abstractify
28
30
  from jax.extend import source_info_util
29
31
 
30
32
  from brainstate.typing import ArrayLike, PyTree, Missing
31
- from brainstate.util import DictManager, PrettyRepr, PrettyType, PrettyAttr
33
+ from brainstate.util import DictManager, PrettyReprTree
32
34
 
33
35
  __all__ = [
34
36
  'State', 'ShortTermState', 'LongTermState', 'HiddenState', 'ParamState', 'TreefyState',
@@ -184,7 +186,7 @@ def _get_trace_stack_level() -> int:
184
186
  return len(TRACE_CONTEXT.state_stack)
185
187
 
186
188
 
187
- class State(Generic[A], PrettyRepr):
189
+ class State(Generic[A], PrettyReprTree):
188
190
  """
189
191
  The pointer to specify the dynamical data.
190
192
 
@@ -427,45 +429,25 @@ class State(Generic[A], PrettyRepr):
427
429
  del metadata['_value']
428
430
  return TreefyState(type(self), self._value, **metadata)
429
431
 
430
- def __pretty_repr__(self):
431
- yield PrettyType(type=type(self))
432
- for name, value in vars(self).items():
433
- if name == '_value':
434
- name = 'value'
435
- if name == '_name':
436
- if value is None:
437
- continue
438
- else:
439
- name = 'name'
440
- if name == 'tag' and value is None:
441
- continue
442
- if name in ['_level', '_source_info', '_been_writen']:
443
- continue
444
- yield PrettyAttr(name, repr(value))
445
-
446
- def __treescope_repr__(self, path, subtree_renderer):
447
- children = {}
448
- for name, value in vars(self).items():
449
- if name == '_value':
450
- name = 'value'
451
- if name == '_name':
452
- if value is None:
453
- continue
454
- else:
455
- name = 'name'
456
- if name == 'tag' and value is None:
457
- continue
458
- if name in ['_level', '_source_info', '_been_writen']:
459
- continue
460
- children[name] = value
461
-
462
- import treescope # type: ignore[import-not-found,import-untyped]
463
- return treescope.repr_lib.render_object_constructor(
464
- object_type=type(self),
465
- attributes=children,
466
- path=path,
467
- subtree_renderer=subtree_renderer,
468
- )
432
+ def __pretty_repr_item__(self, k, v):
433
+ if k in ['_level', '_source_info', '_been_writen']:
434
+ return None, None
435
+ if k == '_value':
436
+ return 'value', v
437
+
438
+ if k == '_name':
439
+ if self.name is None:
440
+ return None, None
441
+ else:
442
+ return 'name', v
443
+
444
+ if k == 'tag':
445
+ if self.tag is None:
446
+ return None, None
447
+ else:
448
+ return 'tag', v
449
+
450
+ return k, v
469
451
 
470
452
  def __eq__(self, other: object) -> bool:
471
453
  return type(self) is type(other) and vars(other) == vars(self)
@@ -802,7 +784,7 @@ class StateTraceStack(Generic[A]):
802
784
  return StateTraceStack().merge(self, other)
803
785
 
804
786
 
805
- class TreefyState(Generic[A], PrettyRepr):
787
+ class TreefyState(Generic[A], PrettyReprTree):
806
788
  """
807
789
  The state as a pytree.
808
790
  """
@@ -824,35 +806,19 @@ class TreefyState(Generic[A], PrettyRepr):
824
806
 
825
807
  def __delattr__(self, name: str) -> None: ...
826
808
 
827
- def __pretty_repr__(self):
828
- yield PrettyType(type=type(self))
829
- yield PrettyAttr('type', self.type.__name__)
830
- for name, value in vars(self).items():
831
- if name == '_value':
832
- name = 'value'
833
- if name == '_name':
834
- if value is None:
835
- continue
836
- else:
837
- name = 'name'
838
- if name in ['_level', '_source_info', 'type']:
839
- continue
840
- yield PrettyAttr(name, repr(value))
841
-
842
- def __treescope_repr__(self, path, subtree_renderer):
843
- children = {'type': self.type}
844
- for name, value in vars(self).items():
845
- if name == 'type':
846
- continue
847
- children[name] = value
848
-
849
- import treescope # type: ignore[import-not-found,import-untyped]
850
- return treescope.repr_lib.render_object_constructor(
851
- object_type=type(self),
852
- attributes=children,
853
- path=path,
854
- subtree_renderer=subtree_renderer,
855
- )
809
+ def __pretty_repr_item__(self, k, v):
810
+ if k in ['_level', '_source_info', '_been_writen']:
811
+ return None, None
812
+ if k == '_value':
813
+ return 'value', v
814
+
815
+ if k == '_name':
816
+ if self.name is None:
817
+ return None, None
818
+ else:
819
+ return 'name', v
820
+
821
+ return k, v
856
822
 
857
823
  def replace(self, value: B) -> TreefyState[B]:
858
824
  """
@@ -206,7 +206,7 @@ class StatefulFunction(object):
206
206
  self._cached_state_trace: Dict[Any, StateTraceStack] = dict()
207
207
 
208
208
  def __repr__(self) -> str:
209
- return (f"{self.__class__.__name__}({self.fun}, "
209
+ return (f"{self.__class__.__name__}("
210
210
  f"static_argnums={self.static_argnums}, "
211
211
  f"axis_env={self.axis_env}, "
212
212
  f"abstracted_axes={self.abstracted_axes}, "
@@ -27,7 +27,7 @@ import numpy as np
27
27
 
28
28
  from brainstate._state import State, TreefyState
29
29
  from brainstate.typing import Key
30
- from brainstate.util._pretty_repr import PrettyRepr, pretty_repr_avoid_duplicate, PrettyType, PrettyAttr
30
+ from brainstate.util._pretty_repr import PrettyRepr, yield_unique_pretty_repr_items, PrettyType, PrettyAttr
31
31
  from ._graph_operation import register_graph_node_type
32
32
 
33
33
  __all__ = [
@@ -88,7 +88,7 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
88
88
  """
89
89
  Pretty repr for the object.
90
90
  """
91
- yield from pretty_repr_avoid_duplicate(self, _default_repr_object, _default_repr_attr)
91
+ yield from yield_unique_pretty_repr_items(self, _default_repr_object, _default_repr_attr)
92
92
 
93
93
  def __treescope_repr__(self, path, subtree_renderer):
94
94
  """
brainstate/util/_dict.py CHANGED
@@ -24,11 +24,12 @@ import jax
24
24
 
25
25
  from brainstate.typing import Filter, PathParts
26
26
  from ._filter import to_predicate
27
- from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, pretty_repr_avoid_duplicate, get_repr
27
+ from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, yield_unique_pretty_repr_items, pretty_repr
28
28
  from ._struct import dataclass
29
29
 
30
30
  __all__ = [
31
- 'NestedDict', 'FlattedDict', 'flat_mapping', 'nest_mapping',
31
+ 'PrettyDict', 'NestedDict', 'FlattedDict', 'flat_mapping', 'nest_mapping',
32
+ 'PrettyList', 'PrettyReprTree',
32
33
  ]
33
34
 
34
35
  A = TypeVar('A')
@@ -40,6 +41,44 @@ ExtractValueFn = abc.Callable[[Any], Any]
40
41
  SetValueFn = abc.Callable[[V, Any], V]
41
42
 
42
43
 
44
+
45
+
46
+ class PrettyReprTree(PrettyRepr):
47
+ """
48
+ Pretty representation of a tree.
49
+ """
50
+
51
+ def __pretty_repr__(self):
52
+ return yield_unique_pretty_repr_items(
53
+ self,
54
+ repr_object=self._repr_object,
55
+ repr_attr=self._repr_attr,
56
+ )
57
+
58
+ def __pretty_repr_item__(self, k, v):
59
+ return k, v
60
+
61
+ def _repr_object(self, node: PrettyDict):
62
+ yield PrettyType(type(node), value_sep=': ', start='({', end='})')
63
+
64
+ def _repr_attr(self, node):
65
+ for k, v in vars(node).items():
66
+ k, v = self.__pretty_repr_item__(k, v)
67
+ if k is None:
68
+ continue
69
+
70
+ if isinstance(v, list):
71
+ v = PrettyList(v)
72
+
73
+ if isinstance(v, dict):
74
+ v = PrettyDict(v)
75
+
76
+ if isinstance(v, PrettyDict):
77
+ v = NestedStateRepr(v)
78
+
79
+ yield PrettyAttr(repr(k), v)
80
+
81
+
43
82
  # the empty node is a struct.dataclass to be compatible with JAX.
44
83
  @dataclass
45
84
  class _EmptyNode:
@@ -213,10 +252,10 @@ class PrettyDict(dict, PrettyRepr):
213
252
 
214
253
  def __repr__(self) -> str:
215
254
  # repr the individual object with the pretty representation
216
- return get_repr(self)
255
+ return pretty_repr(self)
217
256
 
218
257
  def __pretty_repr__(self):
219
- yield from pretty_repr_avoid_duplicate(self, _default_repr_object, _default_repr_attr)
258
+ yield from yield_unique_pretty_repr_items(self, _default_repr_object, _default_repr_attr)
220
259
 
221
260
  def split(self, *filters) -> Union[PrettyDict[K, V], Tuple[PrettyDict[K, V], ...]]:
222
261
  raise NotImplementedError
@@ -237,15 +276,20 @@ class PrettyDict(dict, PrettyRepr):
237
276
 
238
277
 
239
278
  def _default_repr_object(node: PrettyDict):
240
- yield PrettyType(type(node), value_sep=': ', start='({', end='})')
279
+ yield PrettyType('', value_sep=': ', start='{', end='}')
241
280
 
242
281
 
243
- def _default_repr_attr(node: PrettyDict):
282
+ def _default_repr_attr(node):
244
283
  for k, v in node.items():
284
+ if isinstance(v, list):
285
+ v = PrettyList(v)
286
+
245
287
  if isinstance(v, dict):
246
288
  v = PrettyDict(v)
289
+
247
290
  if isinstance(v, PrettyDict):
248
291
  v = NestedStateRepr(v)
292
+
249
293
  yield PrettyAttr(repr(k), v)
250
294
 
251
295
 
@@ -735,3 +779,37 @@ def _flat_unflatten(
735
779
  jax.tree_util.register_pytree_with_keys(FlattedDict,
736
780
  _nest_flatten_with_keys,
737
781
  _flat_unflatten) # type: ignore[arg-type]
782
+
783
+
784
+ @jax.tree_util.register_pytree_node_class
785
+ class PrettyList(list, PrettyRepr):
786
+ __module__ = 'brainstate.util'
787
+
788
+ def __pretty_repr__(self):
789
+ yield from yield_unique_pretty_repr_items(self, _list_repr_object, _list_repr_attr)
790
+
791
+ def __repr__(self):
792
+ return pretty_repr(self)
793
+
794
+ def tree_flatten(self):
795
+ return list(self), ()
796
+
797
+ @classmethod
798
+ def tree_unflatten(cls, aux_data, children):
799
+ return cls(children)
800
+
801
+
802
+ def _list_repr_attr(node: PrettyList):
803
+ for v in node:
804
+ if isinstance(v, list):
805
+ v = PrettyList(v)
806
+ if isinstance(v, dict):
807
+ v = PrettyDict(v)
808
+ if isinstance(v, PrettyDict):
809
+ v = NestedStateRepr(v)
810
+ yield PrettyAttr('', v)
811
+
812
+
813
+ def _list_repr_object(node: PrettyDict):
814
+ yield PrettyType('', value_sep='', start='[', end=']')
815
+
@@ -21,9 +21,11 @@ import dataclasses
21
21
  import threading
22
22
  from abc import ABC, abstractmethod
23
23
  from functools import partial
24
- from typing import Any, Iterator, Mapping, TypeVar, Union, Callable, Optional
24
+ from typing import Any, Iterator, Mapping, TypeVar, Union, Callable, Optional, Sequence
25
25
 
26
26
  __all__ = [
27
+ 'yield_unique_pretty_repr_items',
28
+ 'pretty_repr',
27
29
  'PrettyType',
28
30
  'PrettyAttr',
29
31
  'PrettyRepr',
@@ -80,7 +82,7 @@ class PrettyRepr(ABC):
80
82
 
81
83
  def __repr__(self) -> str:
82
84
  # repr the individual object with the pretty representation
83
- return get_repr(self)
85
+ return pretty_repr(self)
84
86
 
85
87
 
86
88
  def _repr_elem(obj: PrettyType, elem: Any) -> str:
@@ -93,7 +95,7 @@ def _repr_elem(obj: PrettyType, elem: Any) -> str:
93
95
  return f'{obj.elem_indent}{elem.start}{elem.key}{obj.value_sep}{value}{elem.end}'
94
96
 
95
97
 
96
- def get_repr(obj: PrettyRepr) -> str:
98
+ def pretty_repr(obj: PrettyRepr) -> str:
97
99
  """
98
100
  Get the pretty representation of an object.
99
101
  """
@@ -140,9 +142,10 @@ class PrettyMapping(PrettyRepr):
140
142
  Pretty representation of a mapping.
141
143
  """
142
144
  mapping: Mapping
145
+ type_name: str = ''
143
146
 
144
147
  def __pretty_repr__(self):
145
- yield PrettyType(type='', value_sep=': ', start='{', end='}')
148
+ yield PrettyType(type=self.type_name, value_sep=': ', start='{', end='}')
146
149
 
147
150
  for key, value in self.mapping.items():
148
151
  yield PrettyAttr(repr(key), value)
@@ -168,7 +171,7 @@ def _default_repr_attr(node):
168
171
  yield PrettyAttr(name, repr(value))
169
172
 
170
173
 
171
- def pretty_repr_avoid_duplicate(
174
+ def yield_unique_pretty_repr_items(
172
175
  node,
173
176
  repr_object: Optional[Callable] = None,
174
177
  repr_attr: Optional[Callable] = None
@@ -206,3 +209,4 @@ def pretty_repr_avoid_duplicate(
206
209
  finally:
207
210
  if clear_seen:
208
211
  CONTEXT.seen_modules_repr = None
212
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.0.post20250208
3
+ Version: 0.1.0.post20250209
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -1,5 +1,5 @@
1
1
  brainstate/__init__.py,sha256=AkZyyFkn4fB8g2aT6Rc2MO1xICPpUZuDtdze-eUQNc0,1496
2
- brainstate/_state.py,sha256=GZ46liHZSHbAHQEuELvOeoJ27P9xiZDz06G2AASjAjA,29142
2
+ brainstate/_state.py,sha256=W1Q_RAL01rUSLZuOARMuX9I-26tBuIl_VzNWAziz6A8,27518
3
3
  brainstate/_state_test.py,sha256=rJUFRSXEqrrl4qANRewY9mnDlzSbtHwBIGeZ0ku-8Dg,1650
4
4
  brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
5
5
  brainstate/environ.py,sha256=PZnVFWPioUBuWmwCO8wwCKrHQfP3BR-5lYPRl5i5GDA,17698
@@ -30,7 +30,7 @@ brainstate/compile/_loop_collect_return.py,sha256=TrKBZhtQecTtuiVz_HOeyepde-znzj
30
30
  brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
31
31
  brainstate/compile/_loop_no_collection.py,sha256=qto2__Zt2PJntkjB9AXEgraGLvNUJS483BhCXjJyqv0,7495
32
32
  brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
33
- brainstate/compile/_make_jaxpr.py,sha256=J4oWoPBwG-fdJvNhBEtNgmo3rXrIWCoajELhaIumgPU,33309
33
+ brainstate/compile/_make_jaxpr.py,sha256=MuAa9LjXi29DjYgDUrK0WaomkjbhHZk9mWW04XGcV98,33297
34
34
  brainstate/compile/_make_jaxpr_test.py,sha256=3gwdiutn_PJyiweu3oPEXumxEVHKaE2xDGvkwZy2GEo,4367
35
35
  brainstate/compile/_progress_bar.py,sha256=5pCMCEmbTO5XmKtzRUJGA178tuBznWKuh9Kw00wAL1I,7524
36
36
  brainstate/compile/_unvmap.py,sha256=CJA6D9lUcBfvdLrpFVvC2AdTJqe9uY0Ht6PltQJyr4U,4228
@@ -42,7 +42,7 @@ brainstate/functional/_normalization.py,sha256=i2EV7hSsqcNdcYRX2wAxjq8doHwyN9eNJ
42
42
  brainstate/functional/_others.py,sha256=_u_Ys-LiLzDAP4zJggVwaVvirgoS3jvhXMREoS6JOkM,1737
43
43
  brainstate/functional/_spikes.py,sha256=QY-2ayJkgkGELcq-bftPEaf_hJptVf_SP3fY36QvlZc,2678
44
44
  brainstate/graph/__init__.py,sha256=noo4TjBg6iEhjjwk0sAGUhR7Ge-z8Vnc2rLYUvnqttw,1295
45
- brainstate/graph/_graph_node.py,sha256=swAokZLKswSTaq2WEhyLIs38sy_67C6maHI6T3e1hvY,8339
45
+ brainstate/graph/_graph_node.py,sha256=XwzOuaZG9x4eZknQjzJoTnnYAy7wcKD5Vox1VkYr8GM,8345
46
46
  brainstate/graph/_graph_node_test.py,sha256=BFGfdzZFDHI0XK7hHotSVWKt3em1taGvn8FHF9NCXx8,2702
47
47
  brainstate/graph/_graph_operation.py,sha256=cIwGo3ICgtce2fmdn917r81evMFjJIKeW9doaQK4DD8,64111
48
48
  brainstate/graph/_graph_operation_test.py,sha256=zjvpKjQAFWtw8YZuqOk_jmlZNb_-E8oPyNx57dyc8jI,18556
@@ -109,16 +109,16 @@ brainstate/random/_rand_state.py,sha256=nuoQ8GU1MfJPRNN-ZmRQsggVjoyPhaEdZmwM7_4-
109
109
  brainstate/random/_random_for_unit.py,sha256=kGp4EUX19MXJ9Govoivbg8N0bddqOldKEI2h_TbdONY,2057
110
110
  brainstate/util/__init__.py,sha256=-FWEuSKXG3mWxYphGFAy3UEuVe39lFs1GruluzdXDoI,1502
111
111
  brainstate/util/_caller.py,sha256=T3bzu7-09r-6EOrU6Muca_aMXSQua_X2lXjEqb-w39w,2782
112
- brainstate/util/_dict.py,sha256=Yapug-_RZQYjvd8cZ3v90_MX7rUYJDBzBnZJT6a0NXY,26178
112
+ brainstate/util/_dict.py,sha256=tb5nPrTKJe4G_BDv33XYTUaYQDz6od-5psG4TKemc7A,28111
113
113
  brainstate/util/_dict_test.py,sha256=Dn0TdjX6wLBXaTD4jfYTu6cKfFHwKSxi4_3bX7kB_IA,5621
114
114
  brainstate/util/_error.py,sha256=eyZ8PGFixqe2K5OEfjSDzI-2tU0ieYQoUpBP7yStlPQ,878
115
115
  brainstate/util/_filter.py,sha256=1-bvFHdjeehvXeHTrCEp8xr25lopKe8d3XZGCNegq0s,4970
116
116
  brainstate/util/_others.py,sha256=jsPZwP-v_5HRV-LB5F0NUsiqr04y8bmGIsu_JMyVcbQ,14762
117
- brainstate/util/_pretty_repr.py,sha256=bDpU4gbkS4B8cXBkiN8kBQNmruxiJzDRF-eIqzyeYnM,5716
117
+ brainstate/util/_pretty_repr.py,sha256=vNwRlj4sI4QJ_koyIs7eKdUMeB_QWwzRYsE8PpAWN3g,5833
118
118
  brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
119
119
  brainstate/util/_struct.py,sha256=KMMHcshOM20gYhSahNzWLxsTt-Rt3AeX3Uz26-rP9vI,17619
120
- brainstate-0.1.0.post20250208.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
121
- brainstate-0.1.0.post20250208.dist-info/METADATA,sha256=bPJy4z_tBevWkxdyS5QXYtcXxvYwGYs44SuYIkqq4Ns,3585
122
- brainstate-0.1.0.post20250208.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
123
- brainstate-0.1.0.post20250208.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
124
- brainstate-0.1.0.post20250208.dist-info/RECORD,,
120
+ brainstate-0.1.0.post20250209.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
121
+ brainstate-0.1.0.post20250209.dist-info/METADATA,sha256=vc9kKmrq5JM9Os6brL4zecy55nEpd9ASK9GZNJBQV9g,3585
122
+ brainstate-0.1.0.post20250209.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
123
+ brainstate-0.1.0.post20250209.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
124
+ brainstate-0.1.0.post20250209.dist-info/RECORD,,