brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/util/struct.py CHANGED
@@ -15,509 +15,896 @@
15
15
  # See the License for the specific language governing permissions and
16
16
  # limitations under the License.
17
17
 
18
- """Utilities for defining custom classes that can be used with jax transformations."""
18
+ """
19
+ Custom data structures that work seamlessly with JAX transformations.
20
+ """
21
+
22
+ from __future__ import annotations
19
23
 
20
- import collections
21
24
  import dataclasses
22
- from collections.abc import Hashable, Mapping
23
- from types import MappingProxyType
24
- from typing import Any, TypeVar
25
+ from collections.abc import Mapping, KeysView, ValuesView, ItemsView
26
+ from typing import Any, TypeVar, Generic, Iterator, overload
25
27
 
26
28
  import jax
27
- from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet
29
+ import jax.tree_util
30
+ from typing_extensions import dataclass_transform
28
31
 
29
32
  __all__ = [
30
- 'dataclass',
31
33
  'field',
34
+ 'dataclass',
32
35
  'PyTreeNode',
33
36
  'FrozenDict',
37
+ 'freeze',
38
+ 'unfreeze',
39
+ 'copy',
40
+ 'pop',
41
+ 'pretty_repr',
34
42
  ]
35
43
 
44
+ # Type variables
36
45
  K = TypeVar('K')
37
46
  V = TypeVar('V')
38
47
  T = TypeVar('T')
48
+ TNode = TypeVar('TNode', bound='PyTreeNode')
39
49
 
40
50
 
41
- def field(pytree_node=True, *, metadata=None, **kwargs):
42
- return dataclasses.field(metadata=(metadata or {}) | {'pytree_node': pytree_node}, **kwargs)
51
+ def field(pytree_node: bool = True, **kwargs) -> dataclasses.Field:
52
+ """
53
+ Create a dataclass field with JAX pytree metadata.
54
+
55
+ Parameters
56
+ ----------
57
+ pytree_node : bool, optional
58
+ If True (default), this field will be treated as part of the pytree.
59
+ If False, it will be treated as metadata and not be touched
60
+ by JAX transformations.
61
+ **kwargs
62
+ Additional arguments to pass to dataclasses.field().
63
+
64
+ Returns
65
+ -------
66
+ dataclasses.Field
67
+ A dataclass field with the appropriate metadata.
68
+
69
+ Examples
70
+ --------
71
+ .. code-block:: python
72
+
73
+ >>> import jax.numpy as jnp
74
+ >>> from brainstate.util import dataclass, field
75
+
76
+ >>> @dataclass
77
+ ... class Model:
78
+ ... weights: jnp.ndarray
79
+ ... bias: jnp.ndarray
80
+ ... # This field won't be affected by JAX transformations
81
+ ... name: str = field(pytree_node=False, default="model")
82
+ """
83
+ metadata = kwargs.pop('metadata', {})
84
+ metadata['pytree_node'] = pytree_node
85
+ return dataclasses.field(metadata=metadata, **kwargs)
43
86
 
44
87
 
45
- @dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required]
46
- def dataclass(clz: T, **kwargs) -> T:
88
+ @dataclass_transform(field_specifiers=(field,))
89
+ def dataclass(cls: type[T], **kwargs) -> type[T]:
47
90
  """
48
- Create a class which can be passed to functional transformations.
49
-
50
- .. note::
51
- Inherit from ``PyTreeNode`` instead to avoid type checking issues when
52
- using PyType.
53
-
54
- Jax transformations such as ``jax.jit`` and ``jax.grad`` require objects that are
55
- immutable and can be mapped over using the ``jax.tree_util`` methods.
56
- The ``dataclass`` decorator makes it easy to define custom classes that can be
57
- passed safely to Jax. For example::
58
-
59
- >>> import brainstate as brainstate
60
- >>> import jax
61
- >>> from typing import Any, Callable
62
-
63
- >>> @brainstate.util.dataclass
64
- ... class Model:
65
- ... params: Any
66
- ... # use pytree_node=False to indicate an attribute should not be touched
67
- ... # by Jax transformations.
68
- ... apply_fn: Callable = brainstate.util.field(pytree_node=False)
69
-
70
- ... def __apply__(self, *args):
71
- ... return self.apply_fn(*args)
72
-
73
- >>> params = {}
74
- >>> params_b = {}
75
- >>> apply_fn = lambda v, x: x
76
- >>> model = Model(params, apply_fn)
77
-
78
- >>> # model.params = params_b # Model is immutable. This will raise an error.
79
- >>> model_b = model.replace(params=params_b) # Use the replace method instead.
80
-
81
- >>> # This class can now be used safely in Jax to compute gradients w.r.t. the
82
- >>> # parameters.
83
- >>> model = Model(params, apply_fn)
84
- >>> loss_fn = lambda model: 3.
85
- >>> model_grad = jax.grad(loss_fn)(model)
86
-
87
- Note that dataclasses have an auto-generated ``__init__`` where
88
- the arguments of the constructor and the attributes of the created
89
- instance match 1:1. This correspondence is what makes these objects
90
- valid containers that work with JAX transformations and
91
- more generally the ``jax.tree_util`` library.
92
-
93
- Sometimes a "smart constructor" is desired, for example because
94
- some of the attributes can be (optionally) derived from others.
95
- The way to do this with Flax dataclasses is to make a static or
96
- class method that provides the smart constructor.
97
- This way the simple constructor used by ``jax.tree_util`` is
98
- preserved. Consider the following example::
99
-
100
- >>> @brainstate.util.dataclass
101
- ... class DirectionAndScaleKernel:
102
- ... direction: jax.Array
103
- ... scale: jax.Array
104
-
105
- ... @classmethod
106
- ... def create(cls, kernel):
107
- ... scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True)
108
- ... direction = direction / scale
109
- ... return cls(direction, scale)
110
-
111
- Args:
112
- clz: the class that will be transformed by the decorator.
113
- Returns:
114
- The new class.
91
+ Create a dataclass that works with JAX transformations.
92
+
93
+ This decorator creates immutable dataclasses that can be used safely
94
+ with JAX transformations like jit, grad, vmap, etc. The created class
95
+ will be registered as a JAX pytree node.
96
+
97
+ Parameters
98
+ ----------
99
+ cls : type
100
+ The class to decorate.
101
+ **kwargs
102
+ Additional arguments for dataclasses.dataclass().
103
+ If 'frozen' is not specified, it defaults to True.
104
+
105
+ Returns
106
+ -------
107
+ type
108
+ The decorated class as an immutable JAX-compatible dataclass.
109
+
110
+ See Also
111
+ --------
112
+ PyTreeNode : Base class for creating JAX-compatible pytree nodes.
113
+ field : Create dataclass fields with pytree metadata.
114
+
115
+ Notes
116
+ -----
117
+ The decorated class will be frozen (immutable) by default to ensure
118
+ compatibility with JAX's functional programming paradigm.
119
+
120
+ Examples
121
+ --------
122
+ .. code-block:: python
123
+
124
+ >>> import jax
125
+ >>> import jax.numpy as jnp
126
+ >>> from brainstate.util import dataclass, field
127
+
128
+ >>> @dataclass
129
+ ... class Model:
130
+ ... weights: jax.Array
131
+ ... bias: jax.Array
132
+ ... name: str = field(pytree_node=False, default="model")
133
+
134
+ >>> model = Model(weights=jnp.ones((3, 3)), bias=jnp.zeros(3))
135
+
136
+ >>> # JAX transformations will only apply to weights and bias, not name
137
+ >>> grad_fn = jax.grad(lambda m: jnp.sum(m.weights))
138
+ >>> grads = grad_fn(model)
139
+
140
+ >>> # Use replace to create modified copies
141
+ >>> model2 = model.replace(weights=jnp.ones((3, 3)) * 2)
115
142
  """
116
- # check if already a flax dataclass
117
- if '_brainstate_dataclass' in clz.__dict__:
118
- return clz
143
+ # Check if already converted
144
+ if hasattr(cls, '_brainstate_dataclass'):
145
+ return cls
146
+
147
+ # Default to frozen for immutability
148
+ kwargs.setdefault('frozen', True)
119
149
 
120
- if 'frozen' not in kwargs.keys():
121
- kwargs['frozen'] = True
122
- data_clz = dataclasses.dataclass(**kwargs)(clz) # type: ignore
150
+ # Apply standard dataclass decorator
151
+ cls = dataclasses.dataclass(**kwargs)(cls)
152
+
153
+ # Separate fields into pytree and metadata
154
+ pytree_fields = []
123
155
  meta_fields = []
124
- data_fields = []
125
- for field_info in dataclasses.fields(data_clz):
126
- is_pytree_node = field_info.metadata.get('pytree_node', True)
127
- if is_pytree_node:
128
- data_fields.append(field_info.name)
156
+
157
+ for field_info in dataclasses.fields(cls):
158
+ if field_info.metadata.get('pytree_node', True):
159
+ pytree_fields.append(field_info.name)
129
160
  else:
130
161
  meta_fields.append(field_info.name)
131
162
 
132
- def replace(self, **updates):
133
- """ "Returns a new object replacing the specified fields with new values."""
163
+ # Add replace method
164
+ def replace(self: T, **updates) -> T:
165
+ """Replace specified fields with new values."""
134
166
  return dataclasses.replace(self, **updates)
135
167
 
136
- data_clz.replace = replace
168
+ cls.replace = replace
137
169
 
138
- # Remove this guard once minimux JAX version is >0.4.26.
139
- try:
140
- if hasattr(jax.tree_util, 'register_dataclass'):
141
- jax.tree_util.register_dataclass(
142
- data_clz, data_fields, meta_fields
143
- )
144
- else:
145
- raise NotImplementedError
146
- except NotImplementedError:
147
-
148
- def iterate_clz(x):
149
- meta = tuple(getattr(x, name) for name in meta_fields)
150
- data = tuple(getattr(x, name) for name in data_fields)
151
- return data, meta
152
-
153
- def iterate_clz_with_keys(x):
154
- meta = tuple(getattr(x, name) for name in meta_fields)
155
- data = tuple(
156
- (jax.tree_util.GetAttrKey(name), getattr(x, name))
157
- for name in data_fields
158
- )
159
- return data, meta
160
-
161
- def clz_from_iterable(meta, data):
162
- meta_args = tuple(zip(meta_fields, meta))
163
- data_args = tuple(zip(data_fields, data))
164
- kwargs = dict(meta_args + data_args)
165
- return data_clz(**kwargs)
170
+ # Register with JAX
171
+ _register_pytree(cls, pytree_fields, meta_fields)
166
172
 
167
- jax.tree_util.register_pytree_with_keys(
168
- data_clz,
169
- iterate_clz_with_keys,
170
- clz_from_iterable,
171
- iterate_clz,
172
- )
173
+ # Mark as BrainState dataclass
174
+ cls._brainstate_dataclass = True
173
175
 
174
- # add a _brainstate_dataclass flag to distinguish from regular dataclasses
175
- data_clz._brainstate_dataclass = True # type: ignore[attr-defined]
176
+ return cls
176
177
 
177
- return data_clz # type: ignore
178
178
 
179
+ def _register_pytree(cls: type, pytree_fields: list[str], meta_fields: list[str]) -> None:
180
+ """Register a class as a JAX pytree."""
179
181
 
180
- TNode = TypeVar('TNode', bound='PyTreeNode')
182
+ def flatten_fn(obj):
183
+ pytree_data = tuple(getattr(obj, name) for name in pytree_fields)
184
+ metadata = tuple(getattr(obj, name) for name in meta_fields)
185
+ return pytree_data, metadata
186
+
187
+ def flatten_with_keys_fn(obj):
188
+ pytree_data = tuple(
189
+ (jax.tree_util.GetAttrKey(name), getattr(obj, name))
190
+ for name in pytree_fields
191
+ )
192
+ metadata = tuple(getattr(obj, name) for name in meta_fields)
193
+ return pytree_data, metadata
194
+
195
+ def unflatten_fn(metadata, pytree_data):
196
+ kwargs = {}
197
+ for name, value in zip(meta_fields, metadata):
198
+ kwargs[name] = value
199
+ for name, value in zip(pytree_fields, pytree_data):
200
+ kwargs[name] = value
201
+ return cls(**kwargs)
202
+
203
+ # Use new API if available, otherwise fall back
204
+ if hasattr(jax.tree_util, 'register_dataclass'):
205
+ jax.tree_util.register_dataclass(cls, pytree_fields, meta_fields)
206
+ else:
207
+ jax.tree_util.register_pytree_with_keys(
208
+ cls,
209
+ flatten_with_keys_fn,
210
+ unflatten_fn,
211
+ flatten_fn
212
+ )
181
213
 
182
214
 
183
- @dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required]
215
+ @dataclass_transform(field_specifiers=(field,))
184
216
  class PyTreeNode:
185
- """Base class for dataclasses that should act like a JAX pytree node.
217
+ """
218
+ Base class for creating JAX-compatible pytree nodes.
186
219
 
187
- See ``flax.struct.dataclass`` for the ``jax.tree_util`` behavior.
188
- This base class additionally avoids type checking errors when using PyType.
220
+ Subclasses of PyTreeNode are automatically converted to immutable
221
+ dataclasses that work with JAX transformations.
189
222
 
190
- Example::
223
+ See Also
224
+ --------
225
+ dataclass : Decorator for creating JAX-compatible dataclasses.
226
+ field : Create dataclass fields with pytree metadata.
191
227
 
192
- >>> import brainstate as brainstate
193
- >>> import jax
194
- >>> from typing import Any, Callable
228
+ Notes
229
+ -----
230
+ When subclassing PyTreeNode, all fields are automatically treated as
231
+ part of the pytree unless explicitly marked with ``pytree_node=False``
232
+ using the field() function.
195
233
 
196
- >>> class Model(brainstate.util.PyTreeNode):
197
- ... params: Any
198
- ... # use pytree_node=False to indicate an attribute should not be touched
199
- ... # by Jax transformations.
200
- ... apply_fn: Callable = brainstate.util.field(pytree_node=False)
234
+ Examples
235
+ --------
236
+ .. code-block:: python
201
237
 
202
- ... def __apply__(self, *args):
203
- ... return self.apply_fn(*args)
238
+ >>> import jax
239
+ >>> import jax.numpy as jnp
240
+ >>> from brainstate.util import PyTreeNode, field
204
241
 
205
- >>> params = {}
206
- >>> params_b = {}
207
- >>> apply_fn = lambda v, x: x
208
- >>> model = Model(params, apply_fn)
242
+ >>> class Layer(PyTreeNode):
243
+ ... weights: jax.Array
244
+ ... bias: jax.Array
245
+ ... activation: str = field(pytree_node=False, default="relu")
209
246
 
210
- >>> # model.params = params_b # Model is immutable. This will raise an error.
211
- >>> model_b = model.replace(params=params_b) # Use the replace method instead.
247
+ >>> layer = Layer(weights=jnp.ones((4, 4)), bias=jnp.zeros(4))
212
248
 
213
- >>> # This class can now be used safely in Jax to compute gradients w.r.t. the
214
- >>> # parameters.
215
- >>> model = Model(params, apply_fn)
216
- >>> loss_fn = lambda model: 3.
217
- >>> model_grad = jax.grad(loss_fn)(model)
249
+ >>> # Can be used in JAX transformations
250
+ >>> def loss_fn(layer):
251
+ ... return jnp.sum(layer.weights ** 2)
252
+ >>> grad_fn = jax.grad(loss_fn)
253
+ >>> grads = grad_fn(layer)
254
+
255
+ >>> # Create modified copies with replace
256
+ >>> layer2 = layer.replace(bias=jnp.ones(4))
218
257
  """
219
258
 
220
259
  def __init_subclass__(cls, **kwargs):
221
- dataclass(cls, **kwargs) # pytype: disable=wrong-arg-types
260
+ """Automatically apply dataclass decorator to subclasses."""
261
+ dataclass(cls, **kwargs)
222
262
 
223
263
  def __init__(self, *args, **kwargs):
224
- # stub for pytype
225
- raise NotImplementedError
264
+ """Stub for type checkers."""
265
+ raise NotImplementedError("PyTreeNode is a base class")
226
266
 
227
- def replace(self: TNode, **overrides) -> TNode:
228
- # stub for pytype
229
- raise NotImplementedError
267
+ def replace(self: TNode, **updates) -> TNode:
268
+ """
269
+ Replace specified fields with new values.
230
270
 
271
+ Parameters
272
+ ----------
273
+ **updates
274
+ Field names and their new values.
231
275
 
232
- def _indent(x, num_spaces):
233
- indent_str = ' ' * num_spaces
234
- lines = x.split('\n')
235
- assert not lines[-1]
236
- # skip the final line because it's empty and should not be indented.
237
- return '\n'.join(indent_str + line for line in lines[:-1]) + '\n'
276
+ Returns
277
+ -------
278
+ TNode
279
+ A new instance with updated fields.
280
+ """
281
+ raise NotImplementedError("Implemented by dataclass decorator")
238
282
 
239
283
 
240
284
  @jax.tree_util.register_pytree_with_keys_class
241
- class FrozenDict(Mapping[K, V]):
285
+ class FrozenDict(Mapping[K, V], Generic[K, V]):
242
286
  """
243
- An immutable variant of the Python dict.
287
+ An immutable dictionary that works as a JAX pytree.
288
+
289
+ FrozenDict provides an immutable mapping interface that can be used
290
+ safely with JAX transformations. It supports all standard dictionary
291
+ operations in an immutable fashion.
292
+
293
+ Parameters
294
+ ----------
295
+ *args
296
+ Positional arguments for dict construction.
297
+ **kwargs
298
+ Keyword arguments for dict construction.
299
+
300
+ Attributes
301
+ ----------
302
+ _data : dict
303
+ Internal immutable data storage.
304
+ _hash : int or None
305
+ Cached hash value.
306
+
307
+ See Also
308
+ --------
309
+ freeze : Convert a mapping to a FrozenDict.
310
+ unfreeze : Convert a FrozenDict to a regular dict.
311
+
312
+ Notes
313
+ -----
314
+ FrozenDict is immutable - all operations that would modify the dictionary
315
+ instead return a new FrozenDict instance with the changes applied.
316
+
317
+ Examples
318
+ --------
319
+ .. code-block:: python
320
+
321
+ >>> from brainstate.util import FrozenDict
322
+
323
+ >>> # Create a FrozenDict
324
+ >>> fd = FrozenDict({'a': 1, 'b': 2})
325
+ >>> fd['a']
326
+ 1
327
+
328
+ >>> # Copy with updates (returns new FrozenDict)
329
+ >>> new_fd = fd.copy({'c': 3})
330
+ >>> new_fd['c']
331
+ 3
332
+
333
+ >>> # Pop an item (returns new dict and popped value)
334
+ >>> new_fd, value = fd.pop('b')
335
+ >>> value
336
+ 2
337
+ >>> 'b' in new_fd
338
+ False
339
+
340
+ >>> # Nested dictionaries are automatically frozen
341
+ >>> fd = FrozenDict({'x': {'y': 1}})
342
+ >>> isinstance(fd['x'], FrozenDict)
343
+ True
244
344
  """
245
345
 
246
- __slots__ = ('_dict', '_hash')
247
-
248
- def __init__(self, *args, __unsafe_skip_copy__=False, **kwargs): # pylint: disable=invalid-name
249
- # make sure the dict is as
250
- xs = dict(*args, **kwargs)
251
- if __unsafe_skip_copy__:
252
- self._dict = xs
253
- else:
254
- self._dict = _prepare_freeze(xs)
346
+ __slots__ = ('_data', '_hash')
255
347
 
348
+ def __init__(self, *args, **kwargs):
349
+ """Initialize a FrozenDict."""
350
+ data = dict(*args, **kwargs)
351
+ self._data = self._deep_freeze(data)
256
352
  self._hash = None
257
353
 
258
- def __getitem__(self, key):
259
- v = self._dict[key]
260
- if isinstance(v, dict):
261
- return FrozenDict(v)
262
- return v
354
+ @staticmethod
355
+ def _deep_freeze(obj: Any) -> Any:
356
+ """Recursively freeze nested dictionaries."""
357
+ if isinstance(obj, FrozenDict):
358
+ return obj._data
359
+ elif isinstance(obj, dict):
360
+ return {k: FrozenDict._deep_freeze(v) for k, v in obj.items()}
361
+ else:
362
+ return obj
263
363
 
264
- def __setitem__(self, key, value):
265
- raise ValueError('FrozenDict is immutable.')
364
+ def __getitem__(self, key: K) -> V:
365
+ """Get an item from the dictionary."""
366
+ value = self._data[key]
367
+ if isinstance(value, dict):
368
+ return FrozenDict(value)
369
+ return value
266
370
 
267
- def __contains__(self, key):
268
- return key in self._dict
371
+ def __setitem__(self, key: K, value: V) -> None:
372
+ """Raise an error - FrozenDict is immutable."""
373
+ raise TypeError("FrozenDict does not support item assignment")
269
374
 
270
- def __iter__(self):
271
- return iter(self._dict)
375
+ def __delitem__(self, key: K) -> None:
376
+ """Raise an error - FrozenDict is immutable."""
377
+ raise TypeError("FrozenDict does not support item deletion")
272
378
 
273
- def __len__(self):
274
- return len(self._dict)
379
+ def __contains__(self, key: object) -> bool:
380
+ """Check if a key is in the dictionary."""
381
+ return key in self._data
275
382
 
276
- def __repr__(self):
277
- return self.pretty_repr()
383
+ def __iter__(self) -> Iterator[K]:
384
+ """Iterate over keys."""
385
+ return iter(self._data)
278
386
 
279
- def __reduce__(self):
280
- return FrozenDict, (self.unfreeze(),)
281
-
282
- def pretty_repr(self, num_spaces=4):
283
- """Returns an indented representation of the nested dictionary."""
284
-
285
- def pretty_dict(x):
286
- if not isinstance(x, dict):
287
- return repr(x)
288
- rep = ''
289
- for key, val in x.items():
290
- rep += f'{key}: {pretty_dict(val)},\n'
291
- if rep:
292
- return '{\n' + _indent(rep, num_spaces) + '}'
293
- else:
294
- return '{}'
387
+ def __len__(self) -> int:
388
+ """Return the number of items."""
389
+ return len(self._data)
295
390
 
296
- return f'FrozenDict({pretty_dict(self._dict)})'
391
+ def __repr__(self) -> str:
392
+ """Return a string representation."""
393
+ return self.pretty_repr()
297
394
 
298
- def __hash__(self):
395
+ def __hash__(self) -> int:
396
+ """Return a hash of the dictionary."""
299
397
  if self._hash is None:
300
- h = 0
398
+ items = []
301
399
  for key, value in self.items():
302
- h ^= hash((key, value))
303
- self._hash = h
400
+ if isinstance(value, dict):
401
+ value = FrozenDict(value)
402
+ items.append((key, value))
403
+ self._hash = hash(tuple(sorted(items)))
304
404
  return self._hash
305
405
 
306
- def copy(
307
- self,
308
- add_or_replace: Mapping[K, V] = MappingProxyType({})
309
- ) -> 'FrozenDict[K, V]':
310
- """Create a new FrozenDict with additional or replaced entries."""
311
- return type(self)({**self, **unfreeze(add_or_replace)}) # type: ignore[arg-type]
406
+ def __eq__(self, other: object) -> bool:
407
+ """Check equality with another object."""
408
+ if not isinstance(other, (FrozenDict, dict)):
409
+ return NotImplemented
410
+ if isinstance(other, FrozenDict):
411
+ return self._data == other._data
412
+ return self._data == other
413
+
414
+ def __reduce__(self):
415
+ """Support for pickling."""
416
+ return FrozenDict, (self.unfreeze(),)
417
+
418
+ def keys(self) -> KeysView[K]:
419
+ """
420
+ Return a view of the keys.
312
421
 
313
- def keys(self):
422
+ Returns
423
+ -------
424
+ KeysView
425
+ A view object of the dictionary's keys.
426
+ """
314
427
  return FrozenKeysView(self)
315
428
 
316
- def values(self):
429
+ def values(self) -> ValuesView[V]:
430
+ """
431
+ Return a view of the values.
432
+
433
+ Returns
434
+ -------
435
+ ValuesView
436
+ A view object of the dictionary's values.
437
+ """
317
438
  return FrozenValuesView(self)
318
439
 
319
- def items(self):
320
- for key in self._dict:
321
- yield (key, self[key])
440
+ def items(self) -> ItemsView[K, V]:
441
+ """
442
+ Return a view of the items.
322
443
 
323
- def pop(self, key: K) -> tuple['FrozenDict[K, V]', V]:
324
- """Create a new FrozenDict where one entry is removed.
444
+ Yields
445
+ ------
446
+ tuple
447
+ Key-value pairs from the dictionary.
448
+ """
449
+ for key in self._data:
450
+ yield (key, self[key])
325
451
 
326
- Example::
452
+ def get(self, key: K, default: V | None = None) -> V | None:
453
+ """
454
+ Get a value with a default.
455
+
456
+ Parameters
457
+ ----------
458
+ key : K
459
+ The key to look up.
460
+ default : V or None, optional
461
+ The default value to return if key is not found.
462
+
463
+ Returns
464
+ -------
465
+ V or None
466
+ The value associated with the key, or default.
467
+ """
468
+ try:
469
+ return self[key]
470
+ except KeyError:
471
+ return default
327
472
 
328
- >>> variables = FrozenDict({'params': {...}, 'batch_stats': {...}})
329
- >>> new_variables, params = variables.pop('params')
473
+ def copy(self, add_or_replace: Mapping[K, V] | None = None) -> FrozenDict[K, V]:
474
+ """
475
+ Create a new FrozenDict with added or replaced entries.
476
+
477
+ Parameters
478
+ ----------
479
+ add_or_replace : Mapping or None, optional
480
+ Entries to add or replace in the new dictionary.
481
+
482
+ Returns
483
+ -------
484
+ FrozenDict
485
+ A new FrozenDict with the updates applied.
486
+
487
+ Examples
488
+ --------
489
+ .. code-block:: python
490
+
491
+ >>> fd = FrozenDict({'a': 1, 'b': 2})
492
+ >>> fd2 = fd.copy({'b': 3, 'c': 4})
493
+ >>> fd2['b'], fd2['c']
494
+ (3, 4)
495
+ """
496
+ if add_or_replace is None:
497
+ add_or_replace = {}
498
+ new_data = dict(self._data)
499
+ new_data.update(add_or_replace)
500
+ return type(self)(new_data)
330
501
 
331
- Args:
332
- key: the key to remove from the dict
333
- Returns:
334
- A pair with the new FrozenDict and the removed value.
502
+ def pop(self, key: K) -> tuple[FrozenDict[K, V], V]:
335
503
  """
504
+ Create a new FrozenDict with one entry removed.
505
+
506
+ Parameters
507
+ ----------
508
+ key : K
509
+ The key to remove.
510
+
511
+ Returns
512
+ -------
513
+ tuple
514
+ A tuple of (new FrozenDict without the key, removed value).
515
+
516
+ Raises
517
+ ------
518
+ KeyError
519
+ If the key is not found in the dictionary.
520
+
521
+ Examples
522
+ --------
523
+ .. code-block:: python
524
+
525
+ >>> fd = FrozenDict({'a': 1, 'b': 2})
526
+ >>> fd2, value = fd.pop('a')
527
+ >>> value
528
+ 1
529
+ >>> 'a' in fd2
530
+ False
531
+ """
532
+ if key not in self._data:
533
+ raise KeyError(key)
336
534
  value = self[key]
337
- new_dict = dict(self._dict)
338
- new_dict.pop(key)
339
- new_self = type(self)(new_dict)
340
- return new_self, value
535
+ new_data = dict(self._data)
536
+ del new_data[key]
537
+ return type(self)(new_data), value
341
538
 
342
539
  def unfreeze(self) -> dict[K, V]:
343
- """Unfreeze this FrozenDict.
344
-
345
- Returns:
346
- An unfrozen version of this FrozenDict instance.
540
+ """
541
+ Convert to a regular mutable dictionary.
542
+
543
+ Returns
544
+ -------
545
+ dict
546
+ A mutable dict with the same contents.
547
+
548
+ Examples
549
+ --------
550
+ .. code-block:: python
551
+
552
+ >>> fd = FrozenDict({'a': 1, 'b': {'c': 2}})
553
+ >>> d = fd.unfreeze()
554
+ >>> isinstance(d, dict)
555
+ True
556
+ >>> isinstance(d['b'], dict) # Nested dicts also unfrozen
557
+ True
347
558
  """
348
559
  return unfreeze(self)
349
560
 
350
- def tree_flatten_with_keys(self) -> tuple[tuple[Any, ...], Hashable]:
351
- """Flattens this FrozenDict.
561
+ def pretty_repr(self, indent: int = 2) -> str:
562
+ """
563
+ Return a pretty-printed representation.
352
564
 
353
- Returns:
354
- A flattened version of this FrozenDict instance.
565
+ Parameters
566
+ ----------
567
+ indent : int, optional
568
+ Number of spaces per indentation level (default 2).
569
+
570
+ Returns
571
+ -------
572
+ str
573
+ A formatted string representation of the FrozenDict.
355
574
  """
356
- sorted_keys = sorted(self._dict)
357
- return tuple(
358
- [(jax.tree_util.DictKey(k), self._dict[k]) for k in sorted_keys]
359
- ), tuple(sorted_keys)
575
+
576
+ def format_value(v, level):
577
+ if isinstance(v, dict):
578
+ if not v:
579
+ return '{}'
580
+ items = []
581
+ for k, val in v.items():
582
+ formatted_val = format_value(val, level + 1)
583
+ items.append(f'{" " * (level + 1) * indent}{k!r}: {formatted_val}')
584
+ return '{\n' + ',\n'.join(items) + f'\n{" " * level * indent}}}'
585
+ else:
586
+ return repr(v)
587
+
588
+ if not self._data:
589
+ return 'FrozenDict({})'
590
+
591
+ return f'FrozenDict({format_value(self._data, 0)})'
592
+
593
+ def tree_flatten_with_keys(self) -> tuple[list[tuple[Any, Any]], tuple[Any, ...]]:
594
+ """Flatten for JAX pytree with keys."""
595
+ sorted_keys = sorted(self._data.keys())
596
+ values_with_keys = [
597
+ (jax.tree_util.DictKey(k), self._data[k])
598
+ for k in sorted_keys
599
+ ]
600
+ return values_with_keys, tuple(sorted_keys)
360
601
 
361
602
  @classmethod
362
- def tree_unflatten(cls, keys, values):
363
- # data is already deep copied due to tree map mechanism
364
- # we can skip the deep copy in the constructor
365
- return cls({k: v for k, v in zip(keys, values)}, __unsafe_skip_copy__=True)
366
-
367
-
368
- def _prepare_freeze(xs: Any) -> Any:
369
- """Deep copy unfrozen dicts to make the dictionary FrozenDict safe."""
370
- if isinstance(xs, FrozenDict):
371
- # we can safely ref share the internal state of a FrozenDict
372
- # because it is immutable.
373
- return xs._dict # pylint: disable=protected-access
374
- if not isinstance(xs, dict):
375
- # return a leaf as is.
376
- return xs
377
- # recursively copy dictionary to avoid ref sharing
378
- return {key: _prepare_freeze(val) for key, val in xs.items()}
379
-
380
-
381
- def freeze(xs: Mapping[Any, Any]) -> FrozenDict[Any, Any]:
382
- """Freeze a nested dict.
383
-
384
- Makes a nested ``dict`` immutable by transforming it into ``FrozenDict``.
385
-
386
- Args:
387
- xs: Dictionary to freeze (a regualr Python dict).
388
- Returns:
389
- The frozen dictionary.
390
- """
391
- return FrozenDict(xs)
603
+ def tree_unflatten(cls, keys: tuple[Any, ...], values: list[Any]) -> FrozenDict:
604
+ """Unflatten from JAX pytree."""
605
+ return cls(dict(zip(keys, values)))
606
+
607
+
608
+ class FrozenKeysView(KeysView[K]):
609
+ """View of keys in a FrozenDict."""
392
610
 
611
+ def __repr__(self) -> str:
612
+ return f'FrozenDict.keys({list(self)})'
393
613
 
394
- def unfreeze(x: FrozenDict | dict[str, Any]) -> dict[Any, Any]:
395
- """Unfreeze a FrozenDict.
396
614
 
397
- Makes a mutable copy of a ``FrozenDict`` mutable by transforming
398
- it into (nested) dict.
615
+ class FrozenValuesView(ValuesView[V]):
616
+ """View of values in a FrozenDict."""
399
617
 
400
- Args:
401
- x: Frozen dictionary to unfreeze.
402
- Returns:
403
- The unfrozen dictionary (a regular Python dict).
618
+ def __repr__(self) -> str:
619
+ return f'FrozenDict.values({list(self)})'
620
+
621
+
622
+ def freeze(x: Mapping[K, V]) -> FrozenDict[K, V]:
623
+ """
624
+ Convert a mapping to a FrozenDict.
625
+
626
+ Parameters
627
+ ----------
628
+ x : Mapping
629
+ A mapping (dict, FrozenDict, etc.) to freeze.
630
+
631
+ Returns
632
+ -------
633
+ FrozenDict
634
+ An immutable FrozenDict.
635
+
636
+ See Also
637
+ --------
638
+ unfreeze : Convert a FrozenDict to a regular dict.
639
+ FrozenDict : The immutable dictionary class.
640
+
641
+ Examples
642
+ --------
643
+ .. code-block:: python
644
+
645
+ >>> from brainstate.util import freeze
646
+
647
+ >>> d = {'a': 1, 'b': {'c': 2}}
648
+ >>> fd = freeze(d)
649
+ >>> isinstance(fd, FrozenDict)
650
+ True
651
+ >>> isinstance(fd['b'], FrozenDict) # Nested dicts are frozen
652
+ True
653
+ """
654
+ if isinstance(x, FrozenDict):
655
+ return x
656
+ return FrozenDict(x)
657
+
658
+
659
+ def unfreeze(x: FrozenDict[K, V] | dict[K, V]) -> dict[K, V]:
660
+ """
661
+ Convert a FrozenDict to a regular dict.
662
+
663
+ Recursively converts FrozenDict instances to mutable dicts.
664
+
665
+ Parameters
666
+ ----------
667
+ x : FrozenDict or dict
668
+ A FrozenDict or dict to unfreeze.
669
+
670
+ Returns
671
+ -------
672
+ dict
673
+ A mutable dictionary.
674
+
675
+ See Also
676
+ --------
677
+ freeze : Convert a mapping to a FrozenDict.
678
+ FrozenDict : The immutable dictionary class.
679
+
680
+ Examples
681
+ --------
682
+ .. code-block:: python
683
+
684
+ >>> from brainstate.util import FrozenDict, unfreeze
685
+
686
+ >>> fd = FrozenDict({'a': 1, 'b': {'c': 2}})
687
+ >>> d = unfreeze(fd)
688
+ >>> isinstance(d, dict)
689
+ True
690
+ >>> isinstance(d['b'], dict) # Nested FrozenDicts are unfrozen
691
+ True
692
+ >>> d['a'] = 10 # Can modify the result
404
693
  """
405
694
  if isinstance(x, FrozenDict):
406
- # deep copy internal state of a FrozenDict
407
- # the dict branch would also work here but
408
- # it is much less performant because jax.tree_util.tree_map
409
- # uses an optimized C implementation.
410
- return jax.tree_util.tree_map(lambda y: y, x._dict) # type: ignore
695
+ result = {}
696
+ for key, value in x._data.items():
697
+ result[key] = unfreeze(value)
698
+ return result
411
699
  elif isinstance(x, dict):
412
- ys = {}
700
+ result = {}
413
701
  for key, value in x.items():
414
- ys[key] = unfreeze(value)
415
- return ys
702
+ result[key] = unfreeze(value)
703
+ return result
416
704
  else:
417
705
  return x
418
706
 
419
707
 
420
- def copy(
421
- x: FrozenDict | dict[str, Any],
422
- add_or_replace: FrozenDict[str, Any] | dict[str, Any] = FrozenDict({}),
423
- ) -> FrozenDict | dict[str, Any]:
424
- """Create a new dict with additional and/or replaced entries. This is a utility
425
- function that can act on either a FrozenDict or regular dict and mimics the
426
- behavior of ``FrozenDict.copy``.
708
+ @overload
709
+ def copy(x: FrozenDict[K, V], add_or_replace: Mapping[K, V] | None = None) -> FrozenDict[K, V]:
710
+ ...
427
711
 
428
- Example::
429
712
 
430
- >>> variables = FrozenDict({'params': {...}, 'batch_stats': {...}})
431
- >>> new_variables = copy(variables, {'additional_entries': 1})
713
+ @overload
714
+ def copy(x: dict[K, V], add_or_replace: Mapping[K, V] | None = None) -> dict[K, V]:
715
+ ...
432
716
 
433
- Args:
434
- x: the dictionary to be copied and updated
435
- add_or_replace: dictionary of key-value pairs to add or replace in the dict x
436
- Returns:
437
- A new dict with the additional and/or replaced entries.
717
+
718
+ def copy(x, add_or_replace=None):
719
+ """
720
+ Copy a dictionary with optional updates.
721
+
722
+ Works with both FrozenDict and regular dict.
723
+
724
+ Parameters
725
+ ----------
726
+ x : FrozenDict or dict
727
+ Dictionary to copy.
728
+ add_or_replace : Mapping or None, optional
729
+ Entries to add or replace in the copy.
730
+
731
+ Returns
732
+ -------
733
+ FrozenDict or dict
734
+ A copy of the same type as the input with updates applied.
735
+
736
+ Raises
737
+ ------
738
+ TypeError
739
+ If x is not a FrozenDict or dict.
740
+
741
+ See Also
742
+ --------
743
+ FrozenDict.copy : Copy method for FrozenDict.
744
+
745
+ Examples
746
+ --------
747
+ .. code-block:: python
748
+
749
+ >>> from brainstate.util import FrozenDict, copy
750
+
751
+ >>> # Works with FrozenDict
752
+ >>> fd = FrozenDict({'a': 1})
753
+ >>> fd2 = copy(fd, {'b': 2})
754
+ >>> isinstance(fd2, FrozenDict)
755
+ True
756
+ >>> fd2['b']
757
+ 2
758
+
759
+ >>> # Also works with regular dict
760
+ >>> d = {'a': 1}
761
+ >>> d2 = copy(d, {'b': 2})
762
+ >>> isinstance(d2, dict)
763
+ True
764
+ >>> d2['b']
765
+ 2
438
766
  """
767
+ if add_or_replace is None:
768
+ add_or_replace = {}
439
769
 
440
770
  if isinstance(x, FrozenDict):
441
771
  return x.copy(add_or_replace)
442
772
  elif isinstance(x, dict):
443
- new_dict = jax.tree_util.tree_map(
444
- lambda x: x, x
445
- ) # make a deep copy of dict x
446
- new_dict.update(add_or_replace)
447
- return new_dict
448
- raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')
449
-
450
-
451
- def pop(
452
- x: FrozenDict | dict[str, Any], key: str
453
- ) -> tuple[FrozenDict | dict[str, Any], Any]:
454
- """Create a new dict where one entry is removed. This is a utility
455
- function that can act on either a FrozenDict or regular dict and
456
- mimics the behavior of ``FrozenDict.pop``.
457
-
458
- Example::
459
-
460
- >>> variables = FrozenDict({'params': {...}, 'batch_stats': {...}})
461
- >>> new_variables, params = pop(variables, 'params')
462
-
463
- Args:
464
- x: the dictionary to remove the entry from
465
- key: the key to remove from the dict
466
- Returns:
467
- A pair with the new dict and the removed value.
468
- """
773
+ result = dict(x)
774
+ result.update(add_or_replace)
775
+ return result
776
+ else:
777
+ raise TypeError(f"Expected FrozenDict or dict, got {type(x)}")
778
+
779
+
780
+ @overload
781
+ def pop(x: FrozenDict[K, V], key: K) -> tuple[FrozenDict[K, V], V]:
782
+ ...
783
+
784
+
785
+ @overload
786
+ def pop(x: dict[K, V], key: K) -> tuple[dict[K, V], V]:
787
+ ...
788
+
469
789
 
790
+ def pop(x, key):
791
+ """
792
+ Remove and return an item from a dictionary.
793
+
794
+ Works with both FrozenDict and regular dict, returning a new
795
+ dictionary without the specified key along with the popped value.
796
+
797
+ Parameters
798
+ ----------
799
+ x : FrozenDict or dict
800
+ Dictionary to pop from.
801
+ key : hashable
802
+ Key to remove.
803
+
804
+ Returns
805
+ -------
806
+ tuple
807
+ A tuple of (new dictionary without the key, popped value).
808
+
809
+ Raises
810
+ ------
811
+ TypeError
812
+ If x is not a FrozenDict or dict.
813
+ KeyError
814
+ If the key is not found in the dictionary.
815
+
816
+ See Also
817
+ --------
818
+ FrozenDict.pop : Pop method for FrozenDict.
819
+
820
+ Examples
821
+ --------
822
+ .. code-block:: python
823
+
824
+ >>> from brainstate.util import FrozenDict, pop
825
+
826
+ >>> # Works with FrozenDict
827
+ >>> fd = FrozenDict({'a': 1, 'b': 2})
828
+ >>> fd2, value = pop(fd, 'a')
829
+ >>> value
830
+ 1
831
+ >>> 'a' in fd2
832
+ False
833
+
834
+ >>> # Also works with regular dict
835
+ >>> d = {'a': 1, 'b': 2}
836
+ >>> d2, value = pop(d, 'a')
837
+ >>> value
838
+ 1
839
+ >>> 'a' in d2
840
+ False
841
+ """
470
842
  if isinstance(x, FrozenDict):
471
843
  return x.pop(key)
472
844
  elif isinstance(x, dict):
473
- new_dict = jax.tree_util.tree_map(
474
- lambda x: x, x
475
- ) # make a deep copy of dict x
845
+ new_dict = dict(x)
476
846
  value = new_dict.pop(key)
477
847
  return new_dict, value
478
- raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')
479
-
848
+ else:
849
+ raise TypeError(f"Expected FrozenDict or dict, got {type(x)}")
480
850
 
481
- def pretty_repr(x: Any, num_spaces: int = 4) -> str:
482
- """Returns an indented representation of the nested dictionary.
483
- This is a utility function that can act on either a FrozenDict or
484
- regular dict and mimics the behavior of ``FrozenDict.pretty_repr``.
485
- If x is any other dtype, this function will return ``repr(x)``.
486
851
 
487
- Args:
488
- x: the dictionary to be represented
489
- num_spaces: the number of space characters in each indentation level
490
- Returns:
491
- An indented string representation of the nested dictionary.
852
+ def pretty_repr(x: Any, indent: int = 2) -> str:
853
+ """
854
+ Create a pretty string representation.
855
+
856
+ Parameters
857
+ ----------
858
+ x : any
859
+ Object to represent. If a dict or FrozenDict, will be
860
+ pretty-printed with indentation. Otherwise, returns repr(x).
861
+ indent : int, optional
862
+ Number of spaces per indentation level (default 2).
863
+
864
+ Returns
865
+ -------
866
+ str
867
+ A formatted string representation.
868
+
869
+ See Also
870
+ --------
871
+ FrozenDict.pretty_repr : Pretty representation for FrozenDict.
872
+
873
+ Examples
874
+ --------
875
+ .. code-block:: python
876
+
877
+ >>> from brainstate.util import pretty_repr
878
+
879
+ >>> d = {'a': 1, 'b': {'c': 2, 'd': 3}}
880
+ >>> print(pretty_repr(d))
881
+ {
882
+ 'a': 1,
883
+ 'b': {
884
+ 'c': 2,
885
+ 'd': 3
886
+ }
887
+ }
888
+
889
+ >>> # Non-dict objects return normal repr
890
+ >>> pretty_repr([1, 2, 3])
891
+ '[1, 2, 3]'
492
892
  """
493
-
494
893
  if isinstance(x, FrozenDict):
495
- return x.pretty_repr()
496
- else:
497
-
498
- def pretty_dict(x):
499
- if not isinstance(x, dict):
500
- return repr(x)
501
- rep = ''
502
- for key, val in x.items():
503
- rep += f'{key}: {pretty_dict(val)},\n'
504
- if rep:
505
- return '{\n' + _indent(rep, num_spaces) + '}'
506
- else:
894
+ return x.pretty_repr(indent)
895
+ elif isinstance(x, dict):
896
+ def format_dict(d, level):
897
+ if not d:
507
898
  return '{}'
508
-
509
- return pretty_dict(x)
510
-
511
-
512
- class FrozenKeysView(collections.abc.KeysView):
513
- """A wrapper for a more useful repr of the keys in a frozen dict."""
514
-
515
- def __repr__(self):
516
- return f'frozen_dict_keys({list(self)})'
517
-
518
-
519
- class FrozenValuesView(collections.abc.ValuesView):
520
- """A wrapper for a more useful repr of the values in a frozen dict."""
521
-
522
- def __repr__(self):
523
- return f'frozen_dict_values({list(self)})'
899
+ items = []
900
+ for k, v in d.items():
901
+ if isinstance(v, dict):
902
+ formatted = format_dict(v, level + 1)
903
+ else:
904
+ formatted = repr(v)
905
+ items.append(f'{" " * (level + 1) * indent}{k!r}: {formatted}')
906
+ return '{\n' + ',\n'.join(items) + f'\n{" " * level * indent}}}'
907
+
908
+ return format_dict(x, 0)
909
+ else:
910
+ return repr(x)