brainstate-0.1.10-py2.py3-none-any.whl → brainstate-0.2.1-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/util/struct.py CHANGED
@@ -1,523 +1,910 @@
1
- # The file is adapted from the Flax library (https://github.com/google/flax).
2
- # The credit should go to the Flax authors.
3
- #
4
- # Copyright 2024 The Flax Authors.
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
-
18
- """Utilities for defining custom classes that can be used with jax transformations."""
19
-
20
- import collections
21
- import dataclasses
22
- from collections.abc import Hashable, Mapping
23
- from types import MappingProxyType
24
- from typing import Any, TypeVar
25
-
26
- import jax
27
- from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet
28
-
29
- __all__ = [
30
- 'dataclass',
31
- 'field',
32
- 'PyTreeNode',
33
- 'FrozenDict',
34
- ]
35
-
36
- K = TypeVar('K')
37
- V = TypeVar('V')
38
- T = TypeVar('T')
39
-
40
-
41
- def field(pytree_node=True, *, metadata=None, **kwargs):
42
- return dataclasses.field(metadata=(metadata or {}) | {'pytree_node': pytree_node}, **kwargs)
43
-
44
-
45
- @dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required]
46
- def dataclass(clz: T, **kwargs) -> T:
47
- """
48
- Create a class which can be passed to functional transformations.
49
-
50
- .. note::
51
- Inherit from ``PyTreeNode`` instead to avoid type checking issues when
52
- using PyType.
53
-
54
- Jax transformations such as ``jax.jit`` and ``jax.grad`` require objects that are
55
- immutable and can be mapped over using the ``jax.tree_util`` methods.
56
- The ``dataclass`` decorator makes it easy to define custom classes that can be
57
- passed safely to Jax. For example::
58
-
59
- >>> import brainstate as brainstate
60
- >>> import jax
61
- >>> from typing import Any, Callable
62
-
63
- >>> @brainstate.util.dataclass
64
- ... class Model:
65
- ... params: Any
66
- ... # use pytree_node=False to indicate an attribute should not be touched
67
- ... # by Jax transformations.
68
- ... apply_fn: Callable = brainstate.util.field(pytree_node=False)
69
-
70
- ... def __apply__(self, *args):
71
- ... return self.apply_fn(*args)
72
-
73
- >>> params = {}
74
- >>> params_b = {}
75
- >>> apply_fn = lambda v, x: x
76
- >>> model = Model(params, apply_fn)
77
-
78
- >>> # model.params = params_b # Model is immutable. This will raise an error.
79
- >>> model_b = model.replace(params=params_b) # Use the replace method instead.
80
-
81
- >>> # This class can now be used safely in Jax to compute gradients w.r.t. the
82
- >>> # parameters.
83
- >>> model = Model(params, apply_fn)
84
- >>> loss_fn = lambda model: 3.
85
- >>> model_grad = jax.grad(loss_fn)(model)
86
-
87
- Note that dataclasses have an auto-generated ``__init__`` where
88
- the arguments of the constructor and the attributes of the created
89
- instance match 1:1. This correspondence is what makes these objects
90
- valid containers that work with JAX transformations and
91
- more generally the ``jax.tree_util`` library.
92
-
93
- Sometimes a "smart constructor" is desired, for example because
94
- some of the attributes can be (optionally) derived from others.
95
- The way to do this with Flax dataclasses is to make a static or
96
- class method that provides the smart constructor.
97
- This way the simple constructor used by ``jax.tree_util`` is
98
- preserved. Consider the following example::
99
-
100
- >>> @brainstate.util.dataclass
101
- ... class DirectionAndScaleKernel:
102
- ... direction: jax.Array
103
- ... scale: jax.Array
104
-
105
- ... @classmethod
106
- ... def create(cls, kernel):
107
- ... scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True)
108
- ... direction = direction / scale
109
- ... return cls(direction, scale)
110
-
111
- Args:
112
- clz: the class that will be transformed by the decorator.
113
- Returns:
114
- The new class.
115
- """
116
- # check if already a flax dataclass
117
- if '_brainstate_dataclass' in clz.__dict__:
118
- return clz
119
-
120
- if 'frozen' not in kwargs.keys():
121
- kwargs['frozen'] = True
122
- data_clz = dataclasses.dataclass(**kwargs)(clz) # type: ignore
123
- meta_fields = []
124
- data_fields = []
125
- for field_info in dataclasses.fields(data_clz):
126
- is_pytree_node = field_info.metadata.get('pytree_node', True)
127
- if is_pytree_node:
128
- data_fields.append(field_info.name)
129
- else:
130
- meta_fields.append(field_info.name)
131
-
132
- def replace(self, **updates):
133
- """ "Returns a new object replacing the specified fields with new values."""
134
- return dataclasses.replace(self, **updates)
135
-
136
- data_clz.replace = replace
137
-
138
- # Remove this guard once minimux JAX version is >0.4.26.
139
- try:
140
- if hasattr(jax.tree_util, 'register_dataclass'):
141
- jax.tree_util.register_dataclass(
142
- data_clz, data_fields, meta_fields
143
- )
144
- else:
145
- raise NotImplementedError
146
- except NotImplementedError:
147
-
148
- def iterate_clz(x):
149
- meta = tuple(getattr(x, name) for name in meta_fields)
150
- data = tuple(getattr(x, name) for name in data_fields)
151
- return data, meta
152
-
153
- def iterate_clz_with_keys(x):
154
- meta = tuple(getattr(x, name) for name in meta_fields)
155
- data = tuple(
156
- (jax.tree_util.GetAttrKey(name), getattr(x, name))
157
- for name in data_fields
158
- )
159
- return data, meta
160
-
161
- def clz_from_iterable(meta, data):
162
- meta_args = tuple(zip(meta_fields, meta))
163
- data_args = tuple(zip(data_fields, data))
164
- kwargs = dict(meta_args + data_args)
165
- return data_clz(**kwargs)
166
-
167
- jax.tree_util.register_pytree_with_keys(
168
- data_clz,
169
- iterate_clz_with_keys,
170
- clz_from_iterable,
171
- iterate_clz,
172
- )
173
-
174
- # add a _brainstate_dataclass flag to distinguish from regular dataclasses
175
- data_clz._brainstate_dataclass = True # type: ignore[attr-defined]
176
-
177
- return data_clz # type: ignore
178
-
179
-
180
- TNode = TypeVar('TNode', bound='PyTreeNode')
181
-
182
-
183
- @dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required]
184
- class PyTreeNode:
185
- """Base class for dataclasses that should act like a JAX pytree node.
186
-
187
- See ``flax.struct.dataclass`` for the ``jax.tree_util`` behavior.
188
- This base class additionally avoids type checking errors when using PyType.
189
-
190
- Example::
191
-
192
- >>> import brainstate as brainstate
193
- >>> import jax
194
- >>> from typing import Any, Callable
195
-
196
- >>> class Model(brainstate.util.PyTreeNode):
197
- ... params: Any
198
- ... # use pytree_node=False to indicate an attribute should not be touched
199
- ... # by Jax transformations.
200
- ... apply_fn: Callable = brainstate.util.field(pytree_node=False)
201
-
202
- ... def __apply__(self, *args):
203
- ... return self.apply_fn(*args)
204
-
205
- >>> params = {}
206
- >>> params_b = {}
207
- >>> apply_fn = lambda v, x: x
208
- >>> model = Model(params, apply_fn)
209
-
210
- >>> # model.params = params_b # Model is immutable. This will raise an error.
211
- >>> model_b = model.replace(params=params_b) # Use the replace method instead.
212
-
213
- >>> # This class can now be used safely in Jax to compute gradients w.r.t. the
214
- >>> # parameters.
215
- >>> model = Model(params, apply_fn)
216
- >>> loss_fn = lambda model: 3.
217
- >>> model_grad = jax.grad(loss_fn)(model)
218
- """
219
-
220
- def __init_subclass__(cls, **kwargs):
221
- dataclass(cls, **kwargs) # pytype: disable=wrong-arg-types
222
-
223
- def __init__(self, *args, **kwargs):
224
- # stub for pytype
225
- raise NotImplementedError
226
-
227
- def replace(self: TNode, **overrides) -> TNode:
228
- # stub for pytype
229
- raise NotImplementedError
230
-
231
-
232
- def _indent(x, num_spaces):
233
- indent_str = ' ' * num_spaces
234
- lines = x.split('\n')
235
- assert not lines[-1]
236
- # skip the final line because it's empty and should not be indented.
237
- return '\n'.join(indent_str + line for line in lines[:-1]) + '\n'
238
-
239
-
240
- @jax.tree_util.register_pytree_with_keys_class
241
- class FrozenDict(Mapping[K, V]):
242
- """
243
- An immutable variant of the Python dict.
244
- """
245
-
246
- __slots__ = ('_dict', '_hash')
247
-
248
- def __init__(self, *args, __unsafe_skip_copy__=False, **kwargs): # pylint: disable=invalid-name
249
- # make sure the dict is as
250
- xs = dict(*args, **kwargs)
251
- if __unsafe_skip_copy__:
252
- self._dict = xs
253
- else:
254
- self._dict = _prepare_freeze(xs)
255
-
256
- self._hash = None
257
-
258
- def __getitem__(self, key):
259
- v = self._dict[key]
260
- if isinstance(v, dict):
261
- return FrozenDict(v)
262
- return v
263
-
264
- def __setitem__(self, key, value):
265
- raise ValueError('FrozenDict is immutable.')
266
-
267
- def __contains__(self, key):
268
- return key in self._dict
269
-
270
- def __iter__(self):
271
- return iter(self._dict)
272
-
273
- def __len__(self):
274
- return len(self._dict)
275
-
276
- def __repr__(self):
277
- return self.pretty_repr()
278
-
279
- def __reduce__(self):
280
- return FrozenDict, (self.unfreeze(),)
281
-
282
- def pretty_repr(self, num_spaces=4):
283
- """Returns an indented representation of the nested dictionary."""
284
-
285
- def pretty_dict(x):
286
- if not isinstance(x, dict):
287
- return repr(x)
288
- rep = ''
289
- for key, val in x.items():
290
- rep += f'{key}: {pretty_dict(val)},\n'
291
- if rep:
292
- return '{\n' + _indent(rep, num_spaces) + '}'
293
- else:
294
- return '{}'
295
-
296
- return f'FrozenDict({pretty_dict(self._dict)})'
297
-
298
- def __hash__(self):
299
- if self._hash is None:
300
- h = 0
301
- for key, value in self.items():
302
- h ^= hash((key, value))
303
- self._hash = h
304
- return self._hash
305
-
306
- def copy(
307
- self,
308
- add_or_replace: Mapping[K, V] = MappingProxyType({})
309
- ) -> 'FrozenDict[K, V]':
310
- """Create a new FrozenDict with additional or replaced entries."""
311
- return type(self)({**self, **unfreeze(add_or_replace)}) # type: ignore[arg-type]
312
-
313
- def keys(self):
314
- return FrozenKeysView(self)
315
-
316
- def values(self):
317
- return FrozenValuesView(self)
318
-
319
- def items(self):
320
- for key in self._dict:
321
- yield (key, self[key])
322
-
323
- def pop(self, key: K) -> tuple['FrozenDict[K, V]', V]:
324
- """Create a new FrozenDict where one entry is removed.
325
-
326
- Example::
327
-
328
- >>> variables = FrozenDict({'params': {...}, 'batch_stats': {...}})
329
- >>> new_variables, params = variables.pop('params')
330
-
331
- Args:
332
- key: the key to remove from the dict
333
- Returns:
334
- A pair with the new FrozenDict and the removed value.
335
- """
336
- value = self[key]
337
- new_dict = dict(self._dict)
338
- new_dict.pop(key)
339
- new_self = type(self)(new_dict)
340
- return new_self, value
341
-
342
- def unfreeze(self) -> dict[K, V]:
343
- """Unfreeze this FrozenDict.
344
-
345
- Returns:
346
- An unfrozen version of this FrozenDict instance.
347
- """
348
- return unfreeze(self)
349
-
350
- def tree_flatten_with_keys(self) -> tuple[tuple[Any, ...], Hashable]:
351
- """Flattens this FrozenDict.
352
-
353
- Returns:
354
- A flattened version of this FrozenDict instance.
355
- """
356
- sorted_keys = sorted(self._dict)
357
- return tuple(
358
- [(jax.tree_util.DictKey(k), self._dict[k]) for k in sorted_keys]
359
- ), tuple(sorted_keys)
360
-
361
- @classmethod
362
- def tree_unflatten(cls, keys, values):
363
- # data is already deep copied due to tree map mechanism
364
- # we can skip the deep copy in the constructor
365
- return cls({k: v for k, v in zip(keys, values)}, __unsafe_skip_copy__=True)
366
-
367
-
368
- def _prepare_freeze(xs: Any) -> Any:
369
- """Deep copy unfrozen dicts to make the dictionary FrozenDict safe."""
370
- if isinstance(xs, FrozenDict):
371
- # we can safely ref share the internal state of a FrozenDict
372
- # because it is immutable.
373
- return xs._dict # pylint: disable=protected-access
374
- if not isinstance(xs, dict):
375
- # return a leaf as is.
376
- return xs
377
- # recursively copy dictionary to avoid ref sharing
378
- return {key: _prepare_freeze(val) for key, val in xs.items()}
379
-
380
-
381
- def freeze(xs: Mapping[Any, Any]) -> FrozenDict[Any, Any]:
382
- """Freeze a nested dict.
383
-
384
- Makes a nested ``dict`` immutable by transforming it into ``FrozenDict``.
385
-
386
- Args:
387
- xs: Dictionary to freeze (a regualr Python dict).
388
- Returns:
389
- The frozen dictionary.
390
- """
391
- return FrozenDict(xs)
392
-
393
-
394
- def unfreeze(x: FrozenDict | dict[str, Any]) -> dict[Any, Any]:
395
- """Unfreeze a FrozenDict.
396
-
397
- Makes a mutable copy of a ``FrozenDict`` mutable by transforming
398
- it into (nested) dict.
399
-
400
- Args:
401
- x: Frozen dictionary to unfreeze.
402
- Returns:
403
- The unfrozen dictionary (a regular Python dict).
404
- """
405
- if isinstance(x, FrozenDict):
406
- # deep copy internal state of a FrozenDict
407
- # the dict branch would also work here but
408
- # it is much less performant because jax.tree_util.tree_map
409
- # uses an optimized C implementation.
410
- return jax.tree_util.tree_map(lambda y: y, x._dict) # type: ignore
411
- elif isinstance(x, dict):
412
- ys = {}
413
- for key, value in x.items():
414
- ys[key] = unfreeze(value)
415
- return ys
416
- else:
417
- return x
418
-
419
-
420
- def copy(
421
- x: FrozenDict | dict[str, Any],
422
- add_or_replace: FrozenDict[str, Any] | dict[str, Any] = FrozenDict({}),
423
- ) -> FrozenDict | dict[str, Any]:
424
- """Create a new dict with additional and/or replaced entries. This is a utility
425
- function that can act on either a FrozenDict or regular dict and mimics the
426
- behavior of ``FrozenDict.copy``.
427
-
428
- Example::
429
-
430
- >>> variables = FrozenDict({'params': {...}, 'batch_stats': {...}})
431
- >>> new_variables = copy(variables, {'additional_entries': 1})
432
-
433
- Args:
434
- x: the dictionary to be copied and updated
435
- add_or_replace: dictionary of key-value pairs to add or replace in the dict x
436
- Returns:
437
- A new dict with the additional and/or replaced entries.
438
- """
439
-
440
- if isinstance(x, FrozenDict):
441
- return x.copy(add_or_replace)
442
- elif isinstance(x, dict):
443
- new_dict = jax.tree_util.tree_map(
444
- lambda x: x, x
445
- ) # make a deep copy of dict x
446
- new_dict.update(add_or_replace)
447
- return new_dict
448
- raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')
449
-
450
-
451
- def pop(
452
- x: FrozenDict | dict[str, Any], key: str
453
- ) -> tuple[FrozenDict | dict[str, Any], Any]:
454
- """Create a new dict where one entry is removed. This is a utility
455
- function that can act on either a FrozenDict or regular dict and
456
- mimics the behavior of ``FrozenDict.pop``.
457
-
458
- Example::
459
-
460
- >>> variables = FrozenDict({'params': {...}, 'batch_stats': {...}})
461
- >>> new_variables, params = pop(variables, 'params')
462
-
463
- Args:
464
- x: the dictionary to remove the entry from
465
- key: the key to remove from the dict
466
- Returns:
467
- A pair with the new dict and the removed value.
468
- """
469
-
470
- if isinstance(x, FrozenDict):
471
- return x.pop(key)
472
- elif isinstance(x, dict):
473
- new_dict = jax.tree_util.tree_map(
474
- lambda x: x, x
475
- ) # make a deep copy of dict x
476
- value = new_dict.pop(key)
477
- return new_dict, value
478
- raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')
479
-
480
-
481
- def pretty_repr(x: Any, num_spaces: int = 4) -> str:
482
- """Returns an indented representation of the nested dictionary.
483
- This is a utility function that can act on either a FrozenDict or
484
- regular dict and mimics the behavior of ``FrozenDict.pretty_repr``.
485
- If x is any other dtype, this function will return ``repr(x)``.
486
-
487
- Args:
488
- x: the dictionary to be represented
489
- num_spaces: the number of space characters in each indentation level
490
- Returns:
491
- An indented string representation of the nested dictionary.
492
- """
493
-
494
- if isinstance(x, FrozenDict):
495
- return x.pretty_repr()
496
- else:
497
-
498
- def pretty_dict(x):
499
- if not isinstance(x, dict):
500
- return repr(x)
501
- rep = ''
502
- for key, val in x.items():
503
- rep += f'{key}: {pretty_dict(val)},\n'
504
- if rep:
505
- return '{\n' + _indent(rep, num_spaces) + '}'
506
- else:
507
- return '{}'
508
-
509
- return pretty_dict(x)
510
-
511
-
512
- class FrozenKeysView(collections.abc.KeysView):
513
- """A wrapper for a more useful repr of the keys in a frozen dict."""
514
-
515
- def __repr__(self):
516
- return f'frozen_dict_keys({list(self)})'
517
-
518
-
519
- class FrozenValuesView(collections.abc.ValuesView):
520
- """A wrapper for a more useful repr of the values in a frozen dict."""
521
-
522
- def __repr__(self):
523
- return f'frozen_dict_values({list(self)})'
1
+ # The file is adapted from the Flax library (https://github.com/google/flax).
2
+ # The credit should go to the Flax authors.
3
+ #
4
+ # Copyright 2024 The Flax Authors.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Custom data structures that work seamlessly with JAX transformations.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import dataclasses
25
+ from collections.abc import Mapping, KeysView, ValuesView, ItemsView
26
+ from typing import Any, TypeVar, Generic, Iterator, overload
27
+
28
+ import jax
29
+ import jax.tree_util
30
+ from typing_extensions import dataclass_transform
31
+
32
+ __all__ = [
33
+ 'field',
34
+ 'dataclass',
35
+ 'PyTreeNode',
36
+ 'FrozenDict',
37
+ 'freeze',
38
+ 'unfreeze',
39
+ 'copy',
40
+ 'pop',
41
+ 'pretty_repr',
42
+ ]
43
+
44
+ # Type variables
45
+ K = TypeVar('K')
46
+ V = TypeVar('V')
47
+ T = TypeVar('T')
48
+ TNode = TypeVar('TNode', bound='PyTreeNode')
49
+
50
+
51
+ def field(pytree_node: bool = True, **kwargs) -> dataclasses.Field:
52
+ """
53
+ Create a dataclass field with JAX pytree metadata.
54
+
55
+ Parameters
56
+ ----------
57
+ pytree_node : bool, optional
58
+ If True (default), this field will be treated as part of the pytree.
59
+ If False, it will be treated as metadata and not be touched
60
+ by JAX transformations.
61
+ **kwargs
62
+ Additional arguments to pass to dataclasses.field().
63
+
64
+ Returns
65
+ -------
66
+ dataclasses.Field
67
+ A dataclass field with the appropriate metadata.
68
+
69
+ Examples
70
+ --------
71
+ .. code-block:: python
72
+
73
+ >>> import jax.numpy as jnp
74
+ >>> from brainstate.util import dataclass, field
75
+
76
+ >>> @dataclass
77
+ ... class Model:
78
+ ... weights: jnp.ndarray
79
+ ... bias: jnp.ndarray
80
+ ... # This field won't be affected by JAX transformations
81
+ ... name: str = field(pytree_node=False, default="model")
82
+ """
83
+ metadata = kwargs.pop('metadata', {})
84
+ metadata['pytree_node'] = pytree_node
85
+ return dataclasses.field(metadata=metadata, **kwargs)
86
+
87
+
88
+ @dataclass_transform(field_specifiers=(field,))
89
+ def dataclass(cls: type[T], **kwargs) -> type[T]:
90
+ """
91
+ Create a dataclass that works with JAX transformations.
92
+
93
+ This decorator creates immutable dataclasses that can be used safely
94
+ with JAX transformations like jit, grad, vmap, etc. The created class
95
+ will be registered as a JAX pytree node.
96
+
97
+ Parameters
98
+ ----------
99
+ cls : type
100
+ The class to decorate.
101
+ **kwargs
102
+ Additional arguments for dataclasses.dataclass().
103
+ If 'frozen' is not specified, it defaults to True.
104
+
105
+ Returns
106
+ -------
107
+ type
108
+ The decorated class as an immutable JAX-compatible dataclass.
109
+
110
+ See Also
111
+ --------
112
+ PyTreeNode : Base class for creating JAX-compatible pytree nodes.
113
+ field : Create dataclass fields with pytree metadata.
114
+
115
+ Notes
116
+ -----
117
+ The decorated class will be frozen (immutable) by default to ensure
118
+ compatibility with JAX's functional programming paradigm.
119
+
120
+ Examples
121
+ --------
122
+ .. code-block:: python
123
+
124
+ >>> import jax
125
+ >>> import jax.numpy as jnp
126
+ >>> from brainstate.util import dataclass, field
127
+
128
+ >>> @dataclass
129
+ ... class Model:
130
+ ... weights: jax.Array
131
+ ... bias: jax.Array
132
+ ... name: str = field(pytree_node=False, default="model")
133
+
134
+ >>> model = Model(weights=jnp.ones((3, 3)), bias=jnp.zeros(3))
135
+
136
+ >>> # JAX transformations will only apply to weights and bias, not name
137
+ >>> grad_fn = jax.grad(lambda m: jnp.sum(m.weights))
138
+ >>> grads = grad_fn(model)
139
+
140
+ >>> # Use replace to create modified copies
141
+ >>> model2 = model.replace(weights=jnp.ones((3, 3)) * 2)
142
+ """
143
+ # Check if already converted
144
+ if hasattr(cls, '_brainstate_dataclass'):
145
+ return cls
146
+
147
+ # Default to frozen for immutability
148
+ kwargs.setdefault('frozen', True)
149
+
150
+ # Apply standard dataclass decorator
151
+ cls = dataclasses.dataclass(**kwargs)(cls)
152
+
153
+ # Separate fields into pytree and metadata
154
+ pytree_fields = []
155
+ meta_fields = []
156
+
157
+ for field_info in dataclasses.fields(cls):
158
+ if field_info.metadata.get('pytree_node', True):
159
+ pytree_fields.append(field_info.name)
160
+ else:
161
+ meta_fields.append(field_info.name)
162
+
163
+ # Add replace method
164
+ def replace(self: T, **updates) -> T:
165
+ """Replace specified fields with new values."""
166
+ return dataclasses.replace(self, **updates)
167
+
168
+ cls.replace = replace
169
+
170
+ # Register with JAX
171
+ _register_pytree(cls, pytree_fields, meta_fields)
172
+
173
+ # Mark as BrainState dataclass
174
+ cls._brainstate_dataclass = True
175
+
176
+ return cls
177
+
178
+
179
+ def _register_pytree(cls: type, pytree_fields: list[str], meta_fields: list[str]) -> None:
180
+ """Register a class as a JAX pytree."""
181
+
182
+ def flatten_fn(obj):
183
+ pytree_data = tuple(getattr(obj, name) for name in pytree_fields)
184
+ metadata = tuple(getattr(obj, name) for name in meta_fields)
185
+ return pytree_data, metadata
186
+
187
+ def flatten_with_keys_fn(obj):
188
+ pytree_data = tuple(
189
+ (jax.tree_util.GetAttrKey(name), getattr(obj, name))
190
+ for name in pytree_fields
191
+ )
192
+ metadata = tuple(getattr(obj, name) for name in meta_fields)
193
+ return pytree_data, metadata
194
+
195
+ def unflatten_fn(metadata, pytree_data):
196
+ kwargs = {}
197
+ for name, value in zip(meta_fields, metadata):
198
+ kwargs[name] = value
199
+ for name, value in zip(pytree_fields, pytree_data):
200
+ kwargs[name] = value
201
+ return cls(**kwargs)
202
+
203
+ # Use new API if available, otherwise fall back
204
+ if hasattr(jax.tree_util, 'register_dataclass'):
205
+ jax.tree_util.register_dataclass(cls, pytree_fields, meta_fields)
206
+ else:
207
+ jax.tree_util.register_pytree_with_keys(
208
+ cls,
209
+ flatten_with_keys_fn,
210
+ unflatten_fn,
211
+ flatten_fn
212
+ )
213
+
214
+
215
@dataclass_transform(field_specifiers=(field,))
class PyTreeNode:
    """
    Convenience base class for immutable, JAX-compatible pytree dataclasses.

    Writing ``class Foo(PyTreeNode): ...`` has the same effect as decorating
    ``Foo`` with :func:`dataclass`. The subclass becomes a dataclass (frozen
    unless ``frozen=False`` is requested) and gains a ``replace`` method. It is
    also registered with ``jax.tree_util`` so instances can be passed through
    JAX transformations.

    See Also
    --------
    dataclass : Decorator for creating JAX-compatible dataclasses.
    field : Create dataclass fields with pytree metadata.

    Notes
    -----
    Every field is a pytree child by default. Declare a field with
    ``field(pytree_node=False)`` to keep it as static metadata instead.

    Keyword arguments given in the class statement are forwarded unchanged to
    :func:`dataclass`, e.g. ``class Foo(PyTreeNode, frozen=False)``.

    Examples
    --------
    .. code-block:: python

        >>> import jax
        >>> import jax.numpy as jnp
        >>> from brainstate.util import PyTreeNode, field

        >>> class Layer(PyTreeNode):
        ...     weights: jax.Array
        ...     bias: jax.Array
        ...     activation: str = field(pytree_node=False, default="relu")

        >>> layer = Layer(weights=jnp.ones((4, 4)), bias=jnp.zeros(4))

        >>> # Can be used in JAX transformations
        >>> def loss_fn(layer):
        ...     return jnp.sum(layer.weights ** 2)
        >>> grads = jax.grad(loss_fn)(layer)

        >>> # Create modified copies with replace
        >>> layer2 = layer.replace(bias=jnp.ones(4))
    """

    # NOTE(review): `dataclass` returns early whenever
    # `hasattr(cls, '_brainstate_dataclass')` is true. That marker is set as a
    # class attribute, so subclasses inherit it. A subclass of an existing
    # PyTreeNode subclass therefore looks like it skips conversion: new fields
    # are not added and the class is not registered as a pytree. Confirm
    # whether this is intended.
    def __init_subclass__(cls, **kwargs):
        """Turn every subclass into a BrainState dataclass at definition time."""
        dataclass(cls, **kwargs)

    def __init__(self, *args, **kwargs):
        """
        Placeholder constructor for static type checkers.

        Subclasses receive a generated ``__init__`` from the dataclass
        machinery, so this one only runs if ``PyTreeNode`` itself is
        instantiated.
        """
        raise NotImplementedError("PyTreeNode is a base class")

    def replace(self: TNode, **updates) -> TNode:
        """
        Return a copy of this node with some fields replaced.

        Parameters
        ----------
        **updates
            Mapping of field name to new value.

        Returns
        -------
        TNode
            A new instance of the same class carrying the updated fields.

        Notes
        -----
        This stub exists for type checkers. :func:`dataclass` attaches the
        real implementation to every subclass.
        """
        raise NotImplementedError("Implemented by dataclass decorator")
282
+
283
+
284
@jax.tree_util.register_pytree_with_keys_class
class FrozenDict(Mapping[K, V], Generic[K, V]):
    """
    An immutable dictionary that works as a JAX pytree.

    FrozenDict provides an immutable mapping interface that can be used
    safely with JAX transformations. It supports all standard dictionary
    operations in an immutable fashion.

    Parameters
    ----------
    *args
        Positional arguments for dict construction.
    **kwargs
        Keyword arguments for dict construction.

    Attributes
    ----------
    _data : dict
        Internal storage. Nested mappings are stored as plain ``dict`` objects.
    _hash : int or None
        Cached hash value (computed lazily on first ``hash()`` call).

    See Also
    --------
    freeze : Convert a mapping to a FrozenDict.
    unfreeze : Convert a FrozenDict to a regular dict.

    Notes
    -----
    FrozenDict is immutable: every operation that would modify the dictionary
    returns a new FrozenDict instance with the changes applied instead.

    Nested mappings are kept internally as plain dicts. They are wrapped in a
    fresh FrozenDict each time they are read through ``__getitem__``,
    ``get``, ``items`` or ``values``.

    A FrozenDict is hashable only if all of its values are hashable.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict

        >>> # Create a FrozenDict
        >>> fd = FrozenDict({'a': 1, 'b': 2})
        >>> fd['a']
        1

        >>> # Copy with updates (returns new FrozenDict)
        >>> new_fd = fd.copy({'c': 3})
        >>> new_fd['c']
        3

        >>> # Pop an item (returns new dict and popped value)
        >>> new_fd, value = fd.pop('b')
        >>> value
        2
        >>> 'b' in new_fd
        False

        >>> # Nested dictionaries are automatically frozen
        >>> fd = FrozenDict({'x': {'y': 1}})
        >>> isinstance(fd['x'], FrozenDict)
        True
    """

    __slots__ = ('_data', '_hash')

    def __init__(self, *args, **kwargs):
        """Initialize a FrozenDict from the same arguments ``dict()`` accepts."""
        data = dict(*args, **kwargs)
        self._data = self._deep_freeze(data)
        self._hash = None

    @staticmethod
    def _deep_freeze(obj: Any) -> Any:
        """
        Normalize nested mappings for internal storage.

        FrozenDict values are unwrapped to their underlying ``_data`` dict, and
        plain dicts are copied recursively. Any other object is returned
        unchanged.
        """
        if isinstance(obj, FrozenDict):
            return obj._data
        elif isinstance(obj, dict):
            return {k: FrozenDict._deep_freeze(v) for k, v in obj.items()}
        else:
            return obj

    def __getitem__(self, key: K) -> V:
        """Get an item; nested dicts are returned wrapped as FrozenDict."""
        value = self._data[key]
        if isinstance(value, dict):
            return FrozenDict(value)
        return value

    def __setitem__(self, key: K, value: V) -> None:
        """Raise an error - FrozenDict is immutable."""
        raise TypeError("FrozenDict does not support item assignment")

    def __delitem__(self, key: K) -> None:
        """Raise an error - FrozenDict is immutable."""
        raise TypeError("FrozenDict does not support item deletion")

    def __contains__(self, key: object) -> bool:
        """Check if a key is in the dictionary."""
        return key in self._data

    def __iter__(self) -> Iterator[K]:
        """Iterate over keys."""
        return iter(self._data)

    def __len__(self) -> int:
        """Return the number of items."""
        return len(self._data)

    def __repr__(self) -> str:
        """Return a string representation (see :meth:`pretty_repr`)."""
        return self.pretty_repr()

    def __hash__(self) -> int:
        """
        Return a hash of the dictionary.

        The hash is built from a ``frozenset`` of ``(key, value)`` pairs. This
        makes it independent of insertion order, and it works even when keys
        are of mutually unorderable types such as ``1`` and ``'b'``. Nested
        dictionaries contribute through their own FrozenDict hash. The result
        is cached after the first call.

        Raises
        ------
        TypeError
            If any value is unhashable.
        """
        if self._hash is None:
            # Sorting the items (the previous approach) raised TypeError for
            # mixed-type keys; a frozenset needs no ordering.
            self._hash = hash(frozenset(self.items()))
        return self._hash

    def __eq__(self, other: object) -> bool:
        """Compare equal to a FrozenDict or dict with the same contents."""
        if not isinstance(other, (FrozenDict, dict)):
            return NotImplemented
        if isinstance(other, FrozenDict):
            return self._data == other._data
        return self._data == other

    def __reduce__(self):
        """Support for pickling; preserves the concrete (sub)class."""
        return type(self), (self.unfreeze(),)

    def keys(self) -> KeysView[K]:
        """
        Return a view of the keys.

        Returns
        -------
        KeysView
            A view object of the dictionary's keys.
        """
        return FrozenKeysView(self)

    def values(self) -> ValuesView[V]:
        """
        Return a view of the values.

        Returns
        -------
        ValuesView
            A view object of the dictionary's values.
        """
        return FrozenValuesView(self)

    def items(self) -> ItemsView[K, V]:
        """
        Return a view of the items.

        Returns
        -------
        ItemsView
            A re-iterable view of ``(key, value)`` pairs that supports
            ``len()`` and membership tests. Nested dicts appear as FrozenDict
            values.
        """
        # Mapping.items() returns ItemsView(self), which yields (key, self[key]).
        return super().items()

    def get(self, key: K, default: V | None = None) -> V | None:
        """
        Get a value with a default.

        Parameters
        ----------
        key : K
            The key to look up.
        default : V or None, optional
            The default value to return if key is not found.

        Returns
        -------
        V or None
            The value associated with the key, or default.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def copy(self, add_or_replace: Mapping[K, V] | None = None) -> FrozenDict[K, V]:
        """
        Create a new FrozenDict with added or replaced entries.

        Parameters
        ----------
        add_or_replace : Mapping or None, optional
            Entries to add or replace in the new dictionary.

        Returns
        -------
        FrozenDict
            A new FrozenDict with the updates applied.

        Examples
        --------
        .. code-block:: python

            >>> fd = FrozenDict({'a': 1, 'b': 2})
            >>> fd2 = fd.copy({'b': 3, 'c': 4})
            >>> fd2['b'], fd2['c']
            (3, 4)
        """
        if add_or_replace is None:
            add_or_replace = {}
        new_data = dict(self._data)
        new_data.update(add_or_replace)
        return type(self)(new_data)

    def pop(self, key: K) -> tuple[FrozenDict[K, V], V]:
        """
        Create a new FrozenDict with one entry removed.

        Parameters
        ----------
        key : K
            The key to remove.

        Returns
        -------
        tuple
            A tuple of (new FrozenDict without the key, removed value).

        Raises
        ------
        KeyError
            If the key is not found in the dictionary.

        Examples
        --------
        .. code-block:: python

            >>> fd = FrozenDict({'a': 1, 'b': 2})
            >>> fd2, value = fd.pop('a')
            >>> value
            1
            >>> 'a' in fd2
            False
        """
        if key not in self._data:
            raise KeyError(key)
        value = self[key]
        new_data = dict(self._data)
        del new_data[key]
        return type(self)(new_data), value

    def unfreeze(self) -> dict[K, V]:
        """
        Convert to a regular mutable dictionary.

        Returns
        -------
        dict
            A mutable dict with the same contents.

        Examples
        --------
        .. code-block:: python

            >>> fd = FrozenDict({'a': 1, 'b': {'c': 2}})
            >>> d = fd.unfreeze()
            >>> isinstance(d, dict)
            True
            >>> isinstance(d['b'], dict)  # Nested dicts also unfrozen
            True
        """
        return unfreeze(self)

    def pretty_repr(self, indent: int = 2) -> str:
        """
        Return a pretty-printed representation.

        Parameters
        ----------
        indent : int, optional
            Number of spaces per indentation level (default 2).

        Returns
        -------
        str
            A formatted string representation of the FrozenDict.
        """

        def format_value(v, level):
            if isinstance(v, dict):
                if not v:
                    return '{}'
                items = []
                for k, val in v.items():
                    formatted_val = format_value(val, level + 1)
                    items.append(f'{" " * (level + 1) * indent}{k!r}: {formatted_val}')
                return '{\n' + ',\n'.join(items) + f'\n{" " * level * indent}}}'
            else:
                return repr(v)

        if not self._data:
            return 'FrozenDict({})'

        return f'FrozenDict({format_value(self._data, 0)})'

    def tree_flatten_with_keys(self) -> tuple[list[tuple[Any, Any]], tuple[Any, ...]]:
        """
        Flatten for JAX pytree with keys.

        Children are emitted in sorted key order, and the sorted key tuple is
        the auxiliary data. Nested values are stored as plain dicts, so JAX
        flattens them with its built-in dict handling.
        """
        sorted_keys = sorted(self._data.keys())
        values_with_keys = [
            (jax.tree_util.DictKey(k), self._data[k])
            for k in sorted_keys
        ]
        return values_with_keys, tuple(sorted_keys)

    @classmethod
    def tree_unflatten(cls, keys: tuple[Any, ...], values: list[Any]) -> FrozenDict:
        """Unflatten from JAX pytree by pairing the saved keys with the leaves."""
        return cls(dict(zip(keys, values)))
606
+
607
+
608
class FrozenKeysView(KeysView[K]):
    """Read-only view over the keys of a FrozenDict, with a FrozenDict-flavoured repr."""

    def __repr__(self) -> str:
        key_list = list(self)
        return f'FrozenDict.keys({key_list})'
613
+
614
+
615
class FrozenValuesView(ValuesView[V]):
    """Read-only view over the values of a FrozenDict, with a FrozenDict-flavoured repr."""

    def __repr__(self) -> str:
        value_list = list(self)
        return f'FrozenDict.values({value_list})'
620
+
621
+
622
def freeze(x: Mapping[K, V]) -> FrozenDict[K, V]:
    """
    Return an immutable FrozenDict version of a mapping.

    Parameters
    ----------
    x : Mapping
        The mapping to freeze (dict, FrozenDict, ...).

    Returns
    -------
    FrozenDict
        ``x`` itself when it is already a FrozenDict (no copy is made);
        otherwise a new FrozenDict built from ``x``.

    See Also
    --------
    unfreeze : Convert a FrozenDict to a regular dict.
    FrozenDict : The immutable dictionary class.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import freeze

        >>> d = {'a': 1, 'b': {'c': 2}}
        >>> fd = freeze(d)
        >>> isinstance(fd, FrozenDict)
        True
        >>> isinstance(fd['b'], FrozenDict)  # Nested dicts are frozen
        True
    """
    return x if isinstance(x, FrozenDict) else FrozenDict(x)
657
+
658
+
659
def unfreeze(x: FrozenDict[K, V] | dict[K, V]) -> dict[K, V]:
    """
    Recursively convert a FrozenDict (or dict) into a plain, mutable dict.

    Every nested FrozenDict or dict is converted to a new plain ``dict``. Any
    value that is neither is returned unchanged, which is also what ends the
    recursion.

    Parameters
    ----------
    x : FrozenDict or dict
        The mapping to unfreeze.

    Returns
    -------
    dict
        A freshly built mutable dictionary.

    See Also
    --------
    freeze : Convert a mapping to a FrozenDict.
    FrozenDict : The immutable dictionary class.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict, unfreeze

        >>> fd = FrozenDict({'a': 1, 'b': {'c': 2}})
        >>> d = unfreeze(fd)
        >>> isinstance(d, dict)
        True
        >>> isinstance(d['b'], dict)  # Nested FrozenDicts are unfrozen
        True
        >>> d['a'] = 10  # Can modify the result
    """
    if isinstance(x, FrozenDict):
        entries = x._data
    elif isinstance(x, dict):
        entries = x
    else:
        # Non-mapping leaf: nothing to convert.
        return x
    return {key: unfreeze(value) for key, value in entries.items()}
706
+
707
+
708
@overload
def copy(x: FrozenDict[K, V], add_or_replace: Mapping[K, V] | None = None) -> FrozenDict[K, V]:
    ...


@overload
def copy(x: dict[K, V], add_or_replace: Mapping[K, V] | None = None) -> dict[K, V]:
    ...


def copy(x, add_or_replace=None):
    """
    Return a shallow copy of a dictionary with optional updates applied.

    Works for both FrozenDict and plain dict. The input is never modified, and
    the result has the same kind as the input.

    Parameters
    ----------
    x : FrozenDict or dict
        The dictionary to copy.
    add_or_replace : Mapping or None, optional
        Entries to add or overwrite in the copy. ``None`` means no updates.

    Returns
    -------
    FrozenDict or dict
        A FrozenDict for FrozenDict input, a plain dict for dict input.

    Raises
    ------
    TypeError
        If ``x`` is neither a FrozenDict nor a dict.

    See Also
    --------
    FrozenDict.copy : Copy method for FrozenDict.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict, copy

        >>> # Works with FrozenDict
        >>> fd = FrozenDict({'a': 1})
        >>> fd2 = copy(fd, {'b': 2})
        >>> isinstance(fd2, FrozenDict)
        True
        >>> fd2['b']
        2

        >>> # Also works with regular dict
        >>> d = {'a': 1}
        >>> d2 = copy(d, {'b': 2})
        >>> isinstance(d2, dict)
        True
        >>> d2['b']
        2
    """
    updates = {} if add_or_replace is None else add_or_replace

    if isinstance(x, FrozenDict):
        return x.copy(updates)
    if not isinstance(x, dict):
        raise TypeError(f"Expected FrozenDict or dict, got {type(x)}")

    merged = dict(x)
    merged.update(updates)
    return merged
778
+
779
+
780
@overload
def pop(x: FrozenDict[K, V], key: K) -> tuple[FrozenDict[K, V], V]:
    ...


@overload
def pop(x: dict[K, V], key: K) -> tuple[dict[K, V], V]:
    ...


def pop(x, key):
    """
    Return a copy of a dictionary without ``key``, together with its value.

    The input is left untouched. FrozenDict input yields a FrozenDict; dict
    input yields a plain dict.

    Parameters
    ----------
    x : FrozenDict or dict
        The dictionary to remove the key from.
    key : hashable
        The key to remove.

    Returns
    -------
    tuple
        ``(remaining_dictionary, removed_value)``.

    Raises
    ------
    TypeError
        If ``x`` is neither a FrozenDict nor a dict.
    KeyError
        If ``key`` is not present.

    See Also
    --------
    FrozenDict.pop : Pop method for FrozenDict.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict, pop

        >>> # Works with FrozenDict
        >>> fd = FrozenDict({'a': 1, 'b': 2})
        >>> fd2, value = pop(fd, 'a')
        >>> value
        1
        >>> 'a' in fd2
        False

        >>> # Also works with regular dict
        >>> d = {'a': 1, 'b': 2}
        >>> d2, value = pop(d, 'a')
        >>> value
        1
        >>> 'a' in d2
        False
    """
    if isinstance(x, FrozenDict):
        return x.pop(key)
    if not isinstance(x, dict):
        raise TypeError(f"Expected FrozenDict or dict, got {type(x)}")

    remaining = dict(x)
    removed = remaining.pop(key)
    return remaining, removed
850
+
851
+
852
def pretty_repr(x: Any, indent: int = 2) -> str:
    """
    Create a pretty string representation.

    Parameters
    ----------
    x : any
        Object to represent. A dict or FrozenDict is pretty-printed with
        indentation; any other object is rendered with ``repr(x)``.
    indent : int, optional
        Number of spaces per indentation level (default 2).

    Returns
    -------
    str
        A formatted string representation.

    See Also
    --------
    FrozenDict.pretty_repr : Pretty representation for FrozenDict.

    Notes
    -----
    A FrozenDict nested inside a plain dict is rendered as
    ``FrozenDict({...})`` and indented at its nesting depth. Previously its
    standalone multi-line ``repr`` was embedded, so its inner lines started at
    column 0.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import pretty_repr

        >>> d = {'a': 1, 'b': {'c': 2, 'd': 3}}
        >>> print(pretty_repr(d))
        {
          'a': 1,
          'b': {
            'c': 2,
            'd': 3
          }
        }

        >>> # Non-dict objects return normal repr
        >>> pretty_repr([1, 2, 3])
        '[1, 2, 3]'
    """
    if isinstance(x, FrozenDict):
        return x.pretty_repr(indent)
    if not isinstance(x, dict):
        return repr(x)

    def format_dict(d, level):
        if not d:
            return '{}'
        items = []
        for k, v in d.items():
            if isinstance(v, dict):
                formatted = format_dict(v, level + 1)
            elif isinstance(v, FrozenDict):
                # Format nested FrozenDicts at the current depth rather than
                # embedding their own repr, whose lines start at column 0.
                formatted = f'FrozenDict({format_dict(v.unfreeze(), level + 1)})'
            else:
                formatted = repr(v)
            items.append(f'{" " * (level + 1) * indent}{k!r}: {formatted}')
        return '{\n' + ',\n'.join(items) + f'\n{" " * level * indent}}}'

    return format_dict(x, 0)