brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/util/struct.py CHANGED
@@ -1,910 +1,910 @@
1
- # The file is adapted from the Flax library (https://github.com/google/flax).
2
- # The credit should go to the Flax authors.
3
- #
4
- # Copyright 2024 The Flax Authors.
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
-
18
- """
19
- Custom data structures that work seamlessly with JAX transformations.
20
- """
21
-
22
- from __future__ import annotations
23
-
24
- import dataclasses
25
- from collections.abc import Mapping, KeysView, ValuesView, ItemsView
26
- from typing import Any, TypeVar, Generic, Iterator, overload
27
-
28
- import jax
29
- import jax.tree_util
30
- from typing_extensions import dataclass_transform
31
-
32
# Public API of this module, grouped by concern:
# dataclass helpers, the FrozenDict type, and functional dict helpers.
__all__ = [
    'field', 'dataclass', 'PyTreeNode',
    'FrozenDict', 'freeze', 'unfreeze',
    'copy', 'pop', 'pretty_repr',
]

# Type variables used in the annotations below.
K = TypeVar('K')  # mapping key type
V = TypeVar('V')  # mapping value type
T = TypeVar('T')  # generic type used by `dataclass` / `replace`
TNode = TypeVar('TNode', bound='PyTreeNode')  # forward reference to PyTreeNode subclasses
49
-
50
-
51
def field(pytree_node: bool = True, **kwargs) -> dataclasses.Field:
    """
    Create a dataclass field with JAX pytree metadata.

    Parameters
    ----------
    pytree_node : bool, optional
        If True (default), this field will be treated as part of the pytree.
        If False, it will be treated as metadata and not be touched
        by JAX transformations.
    **kwargs
        Additional arguments to pass to dataclasses.field(). If a
        ``metadata`` mapping (or ``None``) is given, it is copied, never
        modified, and the ``'pytree_node'`` entry is added to the copy.

    Returns
    -------
    dataclasses.Field
        A dataclass field with the appropriate metadata.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> from brainstate.util import dataclass, field

        >>> @dataclass
        ... class Model:
        ...     weights: jnp.ndarray
        ...     bias: jnp.ndarray
        ...     # This field won't be affected by JAX transformations
        ...     name: str = field(pytree_node=False, default="model")
    """
    # Copy instead of mutating the caller's mapping in place. The same dict may
    # be shared by several fields, or be read-only (e.g. MappingProxyType).
    # ``None`` is accepted, as it is by ``dataclasses.field``.
    user_metadata = kwargs.pop('metadata', None)
    metadata = dict(user_metadata) if user_metadata is not None else {}
    metadata['pytree_node'] = pytree_node
    return dataclasses.field(metadata=metadata, **kwargs)
86
-
87
-
88
@dataclass_transform(field_specifiers=(field,))
def dataclass(cls: type[T], **kwargs) -> type[T]:
    """
    Create a dataclass that works with JAX transformations.

    This decorator creates immutable dataclasses that can be used safely
    with JAX transformations like jit, grad, vmap, etc. The created class
    will be registered as a JAX pytree node.

    Parameters
    ----------
    cls : type
        The class to decorate.
    **kwargs
        Additional arguments for dataclasses.dataclass().
        If 'frozen' is not specified, it defaults to True.

    Returns
    -------
    type
        The decorated class as an immutable JAX-compatible dataclass.

    See Also
    --------
    PyTreeNode : Base class for creating JAX-compatible pytree nodes.
    field : Create dataclass fields with pytree metadata.

    Notes
    -----
    The decorated class will be frozen (immutable) by default to ensure
    compatibility with JAX's functional programming paradigm.

    A class is skipped only if it was itself already converted. Subclasses
    of a converted class are converted and registered with JAX on their
    own, so fields they add become proper dataclass / pytree fields.

    Examples
    --------
    .. code-block:: python

        >>> import jax
        >>> import jax.numpy as jnp
        >>> from brainstate.util import dataclass, field

        >>> @dataclass
        ... class Model:
        ...     weights: jax.Array
        ...     bias: jax.Array
        ...     name: str = field(pytree_node=False, default="model")

        >>> model = Model(weights=jnp.ones((3, 3)), bias=jnp.zeros(3))

        >>> # JAX transformations will only apply to weights and bias, not name
        >>> grad_fn = jax.grad(lambda m: jnp.sum(m.weights))
        >>> grads = grad_fn(model)

        >>> # Use replace to create modified copies
        >>> model2 = model.replace(weights=jnp.ones((3, 3)) * 2)
    """
    # Skip classes that were already converted. Look only at the class's own
    # namespace. ``hasattr`` would also find the marker inherited from a
    # converted base class, and the subclass (plus any fields it adds) would
    # silently stay unconverted and unregistered with JAX.
    if '_brainstate_dataclass' in vars(cls):
        return cls

    # Default to frozen for immutability
    kwargs.setdefault('frozen', True)

    # Apply standard dataclass decorator
    cls = dataclasses.dataclass(**kwargs)(cls)

    # Separate fields into pytree children and static metadata. Fields
    # without an explicit 'pytree_node' entry count as pytree children.
    pytree_fields = []
    meta_fields = []
    for field_info in dataclasses.fields(cls):
        if field_info.metadata.get('pytree_node', True):
            pytree_fields.append(field_info.name)
        else:
            meta_fields.append(field_info.name)

    # Add replace method
    def replace(self: T, **updates) -> T:
        """Replace specified fields with new values."""
        return dataclasses.replace(self, **updates)

    cls.replace = replace

    # Register with JAX
    _register_pytree(cls, pytree_fields, meta_fields)

    # Mark this exact class as converted (checked above on re-entry)
    cls._brainstate_dataclass = True

    return cls
177
-
178
-
179
def _register_pytree(cls: type, pytree_fields: list[str], meta_fields: list[str]) -> None:
    """
    Register ``cls`` with JAX as a pytree node.

    Parameters
    ----------
    cls : type
        The dataclass to register.
    pytree_fields : list of str
        Names of the attributes that become pytree children.
    meta_fields : list of str
        Names of the attributes stored as static auxiliary data.

    Notes
    -----
    ``jax.tree_util.register_dataclass`` is used when the installed JAX
    provides it. Otherwise the class is registered through
    ``jax.tree_util.register_pytree_with_keys`` with hand-written
    flatten/unflatten functions.
    """
    tree_util = jax.tree_util

    # Preferred path: JAX knows how to handle the dataclass directly.
    if hasattr(tree_util, 'register_dataclass'):
        tree_util.register_dataclass(cls, pytree_fields, meta_fields)
        return

    # Fallback for JAX versions without register_dataclass.
    def _static_part(obj):
        return tuple(getattr(obj, attr) for attr in meta_fields)

    def flatten_fn(obj):
        children = tuple(getattr(obj, attr) for attr in pytree_fields)
        return children, _static_part(obj)

    def flatten_with_keys_fn(obj):
        keyed_children = tuple(
            (tree_util.GetAttrKey(attr), getattr(obj, attr))
            for attr in pytree_fields
        )
        return keyed_children, _static_part(obj)

    def unflatten_fn(metadata, pytree_data):
        # Metadata first, then children; both become constructor keywords.
        init_kwargs = dict(zip(meta_fields, metadata))
        init_kwargs.update(zip(pytree_fields, pytree_data))
        return cls(**init_kwargs)

    tree_util.register_pytree_with_keys(
        cls,
        flatten_with_keys_fn,
        unflatten_fn,
        flatten_fn,
    )
213
-
214
-
215
@dataclass_transform(field_specifiers=(field,))
class PyTreeNode:
    """
    Base class for creating JAX-compatible pytree nodes.

    Subclasses of PyTreeNode are automatically converted to immutable
    dataclasses that work with JAX transformations.

    See Also
    --------
    dataclass : Decorator for creating JAX-compatible dataclasses.
    field : Create dataclass fields with pytree metadata.

    Notes
    -----
    When subclassing PyTreeNode, all fields are automatically treated as
    part of the pytree unless explicitly marked with ``pytree_node=False``
    using the field() function.

    Keyword arguments in the class statement (for example
    ``class Foo(PyTreeNode, frozen=False)``) are forwarded to ``dataclass``.

    Examples
    --------
    .. code-block:: python

        >>> import jax
        >>> import jax.numpy as jnp
        >>> from brainstate.util import PyTreeNode, field

        >>> class Layer(PyTreeNode):
        ...     weights: jax.Array
        ...     bias: jax.Array
        ...     activation: str = field(pytree_node=False, default="relu")

        >>> layer = Layer(weights=jnp.ones((4, 4)), bias=jnp.zeros(4))

        >>> # Can be used in JAX transformations
        >>> def loss_fn(layer):
        ...     return jnp.sum(layer.weights ** 2)
        >>> grad_fn = jax.grad(loss_fn)
        >>> grads = grad_fn(layer)

        >>> # Create modified copies with replace
        >>> layer2 = layer.replace(bias=jnp.ones(4))
    """

    def __init_subclass__(cls, **kwargs):
        """
        Apply ``dataclass`` to each newly defined subclass.

        Parameters
        ----------
        **kwargs
            Class-statement keywords, forwarded to ``dataclass``.
        """
        # Stay cooperative so other bases in the MRO (e.g. mixins that define
        # ``__init_subclass__``) still run their hooks. The keyword arguments
        # are dataclass options and are consumed here, not forwarded.
        super().__init_subclass__()
        dataclass(cls, **kwargs)

    def __init__(self, *args, **kwargs):
        """Stub for type checkers."""
        raise NotImplementedError("PyTreeNode is a base class")

    def replace(self: TNode, **updates) -> TNode:
        """
        Replace specified fields with new values.

        Parameters
        ----------
        **updates
            Field names and their new values.

        Returns
        -------
        TNode
            A new instance with updated fields.
        """
        raise NotImplementedError("Implemented by dataclass decorator")
282
-
283
-
284
@jax.tree_util.register_pytree_with_keys_class
class FrozenDict(Mapping[K, V], Generic[K, V]):
    """
    An immutable dictionary that works as a JAX pytree.

    FrozenDict provides an immutable mapping interface that can be used
    safely with JAX transformations. It supports all standard dictionary
    operations in an immutable fashion.

    Parameters
    ----------
    *args
        Positional arguments for dict construction.
    **kwargs
        Keyword arguments for dict construction.

    Attributes
    ----------
    _data : dict
        Internal storage. Nested dictionaries are kept as plain dicts and
        wrapped in a FrozenDict whenever they are accessed.
    _hash : int or None
        Cached hash value (``None`` until first computed).

    See Also
    --------
    freeze : Convert a mapping to a FrozenDict.
    unfreeze : Convert a FrozenDict to a regular dict.

    Notes
    -----
    FrozenDict is immutable - all operations that would modify the dictionary
    instead return a new FrozenDict instance with the changes applied.

    As a JAX pytree, the values are the children (in sorted key order) and
    the sorted keys are the auxiliary data.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict

        >>> # Create a FrozenDict
        >>> fd = FrozenDict({'a': 1, 'b': 2})
        >>> fd['a']
        1

        >>> # Copy with updates (returns new FrozenDict)
        >>> new_fd = fd.copy({'c': 3})
        >>> new_fd['c']
        3

        >>> # Pop an item (returns new dict and popped value)
        >>> new_fd, value = fd.pop('b')
        >>> value
        2
        >>> 'b' in new_fd
        False

        >>> # Nested dictionaries are automatically frozen
        >>> fd = FrozenDict({'x': {'y': 1}})
        >>> isinstance(fd['x'], FrozenDict)
        True
    """

    __slots__ = ('_data', '_hash')

    def __init__(self, *args, **kwargs):
        """Initialize a FrozenDict from the same arguments ``dict`` accepts."""
        data = dict(*args, **kwargs)
        # Nested dicts / FrozenDicts are stored as plain dicts and re-wrapped
        # on access by __getitem__.
        self._data = self._deep_freeze(data)
        self._hash = None  # computed lazily by __hash__

    @staticmethod
    def _deep_freeze(obj: Any) -> Any:
        """Recursively convert nested dicts/FrozenDicts to plain-dict storage."""
        if isinstance(obj, FrozenDict):
            return obj._data
        elif isinstance(obj, dict):
            return {k: FrozenDict._deep_freeze(v) for k, v in obj.items()}
        else:
            return obj

    def __getitem__(self, key: K) -> V:
        """Get an item; nested dictionaries are returned as FrozenDict."""
        value = self._data[key]
        if isinstance(value, dict):
            return FrozenDict(value)
        return value

    def __setitem__(self, key: K, value: V) -> None:
        """Raise an error - FrozenDict is immutable."""
        raise TypeError("FrozenDict does not support item assignment")

    def __delitem__(self, key: K) -> None:
        """Raise an error - FrozenDict is immutable."""
        raise TypeError("FrozenDict does not support item deletion")

    def __contains__(self, key: object) -> bool:
        """Check if a key is in the dictionary."""
        return key in self._data

    def __iter__(self) -> Iterator[K]:
        """Iterate over keys."""
        return iter(self._data)

    def __len__(self) -> int:
        """Return the number of items."""
        return len(self._data)

    def __repr__(self) -> str:
        """Return a string representation."""
        return self.pretty_repr()

    def __hash__(self) -> int:
        """
        Return a hash of the dictionary (cached after the first call).

        The hash is order-independent and does not require the keys to be
        mutually comparable, so e.g. ``FrozenDict({1: 'a', 'b': 2})`` is
        hashable. All values must themselves be hashable.
        """
        if self._hash is None:
            # items() yields nested dicts already wrapped as FrozenDict (see
            # __getitem__), so they are hashable. A frozenset avoids the
            # TypeError that sorting mixed-type keys would raise.
            self._hash = hash(frozenset(self.items()))
        return self._hash

    def __eq__(self, other: object) -> bool:
        """Check equality with another FrozenDict or dict."""
        if not isinstance(other, (FrozenDict, dict)):
            return NotImplemented
        if isinstance(other, FrozenDict):
            return self._data == other._data
        return self._data == other

    def __reduce__(self):
        """Support for pickling (preserves the concrete subclass)."""
        return type(self), (self.unfreeze(),)

    def keys(self) -> KeysView[K]:
        """
        Return a view of the keys.

        Returns
        -------
        KeysView
            A view object of the dictionary's keys.
        """
        return FrozenKeysView(self)

    def values(self) -> ValuesView[V]:
        """
        Return a view of the values.

        Returns
        -------
        ValuesView
            A view object of the dictionary's values.
        """
        return FrozenValuesView(self)

    def items(self) -> ItemsView[K, V]:
        """
        Return a view of the items.

        Returns
        -------
        ItemsView
            A reusable, sized view of ``(key, value)`` pairs. Nested
            dictionaries appear as FrozenDict instances.
        """
        # A real view (instead of a one-shot generator) supports len(),
        # membership tests and repeated iteration, as the annotation promises.
        return ItemsView(self)

    def get(self, key: K, default: V | None = None) -> V | None:
        """
        Get a value with a default.

        Parameters
        ----------
        key : K
            The key to look up.
        default : V or None, optional
            The default value to return if key is not found.

        Returns
        -------
        V or None
            The value associated with the key, or default.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def copy(self, add_or_replace: Mapping[K, V] | None = None) -> FrozenDict[K, V]:
        """
        Create a new FrozenDict with added or replaced entries.

        Parameters
        ----------
        add_or_replace : Mapping or None, optional
            Entries to add or replace in the new dictionary.

        Returns
        -------
        FrozenDict
            A new FrozenDict with the updates applied.

        Examples
        --------
        .. code-block:: python

            >>> fd = FrozenDict({'a': 1, 'b': 2})
            >>> fd2 = fd.copy({'b': 3, 'c': 4})
            >>> fd2['b'], fd2['c']
            (3, 4)
        """
        if add_or_replace is None:
            add_or_replace = {}
        new_data = dict(self._data)
        new_data.update(add_or_replace)
        return type(self)(new_data)

    def pop(self, key: K) -> tuple[FrozenDict[K, V], V]:
        """
        Create a new FrozenDict with one entry removed.

        Parameters
        ----------
        key : K
            The key to remove.

        Returns
        -------
        tuple
            A tuple of (new FrozenDict without the key, removed value).

        Raises
        ------
        KeyError
            If the key is not found in the dictionary.

        Examples
        --------
        .. code-block:: python

            >>> fd = FrozenDict({'a': 1, 'b': 2})
            >>> fd2, value = fd.pop('a')
            >>> value
            1
            >>> 'a' in fd2
            False
        """
        if key not in self._data:
            raise KeyError(key)
        value = self[key]
        new_data = dict(self._data)
        del new_data[key]
        return type(self)(new_data), value

    def unfreeze(self) -> dict[K, V]:
        """
        Convert to a regular mutable dictionary.

        Returns
        -------
        dict
            A mutable dict with the same contents.

        Examples
        --------
        .. code-block:: python

            >>> fd = FrozenDict({'a': 1, 'b': {'c': 2}})
            >>> d = fd.unfreeze()
            >>> isinstance(d, dict)
            True
            >>> isinstance(d['b'], dict)  # Nested dicts also unfrozen
            True
        """
        return unfreeze(self)

    def pretty_repr(self, indent: int = 2) -> str:
        """
        Return a pretty-printed representation.

        Parameters
        ----------
        indent : int, optional
            Number of spaces per indentation level (default 2).

        Returns
        -------
        str
            A formatted string representation of the FrozenDict.
        """

        def format_value(v, level):
            if isinstance(v, dict):
                if not v:
                    return '{}'
                items = []
                for k, val in v.items():
                    formatted_val = format_value(val, level + 1)
                    items.append(f'{" " * (level + 1) * indent}{k!r}: {formatted_val}')
                return '{\n' + ',\n'.join(items) + f'\n{" " * level * indent}}}'
            else:
                return repr(v)

        if not self._data:
            return 'FrozenDict({})'

        return f'FrozenDict({format_value(self._data, 0)})'

    def tree_flatten_with_keys(self) -> tuple[list[tuple[Any, Any]], tuple[Any, ...]]:
        """Flatten for JAX pytree with keys (children in sorted key order)."""
        sorted_keys = sorted(self._data.keys())
        values_with_keys = [
            (jax.tree_util.DictKey(k), self._data[k])
            for k in sorted_keys
        ]
        return values_with_keys, tuple(sorted_keys)

    @classmethod
    def tree_unflatten(cls, keys: tuple[Any, ...], values: list[Any]) -> FrozenDict:
        """Unflatten from JAX pytree."""
        return cls(dict(zip(keys, values)))
606
-
607
-
608
class FrozenKeysView(KeysView[K]):
    """Keys view over a FrozenDict with a FrozenDict-specific repr."""

    def __repr__(self) -> str:
        listed = list(self)
        return f'FrozenDict.keys({listed})'
613
-
614
-
615
class FrozenValuesView(ValuesView[V]):
    """Values view over a FrozenDict with a FrozenDict-specific repr."""

    def __repr__(self) -> str:
        listed = list(self)
        return f'FrozenDict.values({listed})'
620
-
621
-
622
def freeze(x: Mapping[K, V]) -> FrozenDict[K, V]:
    """
    Convert a mapping to a FrozenDict.

    Parameters
    ----------
    x : Mapping
        A mapping (dict, FrozenDict, etc.) to freeze. A FrozenDict input
        is returned unchanged (the same object).

    Returns
    -------
    FrozenDict
        An immutable FrozenDict.

    See Also
    --------
    unfreeze : Convert a FrozenDict to a regular dict.
    FrozenDict : The immutable dictionary class.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import freeze

        >>> d = {'a': 1, 'b': {'c': 2}}
        >>> fd = freeze(d)
        >>> isinstance(fd, FrozenDict)
        True
        >>> isinstance(fd['b'], FrozenDict)  # Nested dicts are frozen
        True
    """
    # Already frozen: hand back the same object instead of copying it.
    return x if isinstance(x, FrozenDict) else FrozenDict(x)
657
-
658
-
659
def unfreeze(x: FrozenDict[K, V] | dict[K, V]) -> dict[K, V]:
    """
    Convert a FrozenDict to a regular dict.

    Recursively converts FrozenDict instances (and nested plain dicts) into
    freshly built mutable dicts. Any other value is returned as-is.

    Parameters
    ----------
    x : FrozenDict or dict
        A FrozenDict or dict to unfreeze.

    Returns
    -------
    dict
        A mutable dictionary.

    See Also
    --------
    freeze : Convert a mapping to a FrozenDict.
    FrozenDict : The immutable dictionary class.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict, unfreeze

        >>> fd = FrozenDict({'a': 1, 'b': {'c': 2}})
        >>> d = unfreeze(fd)
        >>> isinstance(d, dict)
        True
        >>> isinstance(d['b'], dict)  # Nested FrozenDicts are unfrozen
        True
        >>> d['a'] = 10  # Can modify the result
    """
    # Pick the underlying dict to walk; non-mapping leaves pass through.
    if isinstance(x, FrozenDict):
        source = x._data
    elif isinstance(x, dict):
        source = x
    else:
        return x
    return {key: unfreeze(value) for key, value in source.items()}
706
-
707
-
708
@overload
def copy(x: FrozenDict[K, V], add_or_replace: Mapping[K, V] | None = None) -> FrozenDict[K, V]:
    ...


@overload
def copy(x: dict[K, V], add_or_replace: Mapping[K, V] | None = None) -> dict[K, V]:
    ...


def copy(x, add_or_replace=None):
    """
    Copy a dictionary with optional updates.

    Works with both FrozenDict and regular dict; the input is never modified.

    Parameters
    ----------
    x : FrozenDict or dict
        Dictionary to copy.
    add_or_replace : Mapping or None, optional
        Entries to add or replace in the copy.

    Returns
    -------
    FrozenDict or dict
        A copy of the same type as the input with updates applied.

    Raises
    ------
    TypeError
        If x is not a FrozenDict or dict.

    See Also
    --------
    FrozenDict.copy : Copy method for FrozenDict.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict, copy

        >>> # Works with FrozenDict
        >>> fd = FrozenDict({'a': 1})
        >>> fd2 = copy(fd, {'b': 2})
        >>> isinstance(fd2, FrozenDict)
        True
        >>> fd2['b']
        2

        >>> # Also works with regular dict
        >>> d = {'a': 1}
        >>> d2 = copy(d, {'b': 2})
        >>> isinstance(d2, dict)
        True
        >>> d2['b']
        2
    """
    updates = {} if add_or_replace is None else add_or_replace

    if isinstance(x, FrozenDict):
        return x.copy(updates)
    if not isinstance(x, dict):
        raise TypeError(f"Expected FrozenDict or dict, got {type(x)}")

    duplicate = dict(x)
    duplicate.update(updates)
    return duplicate
778
-
779
-
780
@overload
def pop(x: FrozenDict[K, V], key: K) -> tuple[FrozenDict[K, V], V]:
    ...


@overload
def pop(x: dict[K, V], key: K) -> tuple[dict[K, V], V]:
    ...


def pop(x, key):
    """
    Remove and return an item from a dictionary.

    Works with both FrozenDict and regular dict, returning a new
    dictionary without the specified key along with the popped value.
    The input is never modified.

    Parameters
    ----------
    x : FrozenDict or dict
        Dictionary to pop from.
    key : hashable
        Key to remove.

    Returns
    -------
    tuple
        A tuple of (new dictionary without the key, popped value).

    Raises
    ------
    TypeError
        If x is not a FrozenDict or dict.
    KeyError
        If the key is not found in the dictionary.

    See Also
    --------
    FrozenDict.pop : Pop method for FrozenDict.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict, pop

        >>> # Works with FrozenDict
        >>> fd = FrozenDict({'a': 1, 'b': 2})
        >>> fd2, value = pop(fd, 'a')
        >>> value
        1
        >>> 'a' in fd2
        False

        >>> # Also works with regular dict
        >>> d = {'a': 1, 'b': 2}
        >>> d2, value = pop(d, 'a')
        >>> value
        1
        >>> 'a' in d2
        False
    """
    if isinstance(x, FrozenDict):
        return x.pop(key)
    if not isinstance(x, dict):
        raise TypeError(f"Expected FrozenDict or dict, got {type(x)}")

    # Pop from a shallow copy so the caller's dict is left untouched.
    remaining = dict(x)
    popped = remaining.pop(key)
    return remaining, popped
850
-
851
-
852
def pretty_repr(x: Any, indent: int = 2) -> str:
    """
    Create a pretty string representation.

    Parameters
    ----------
    x : any
        Object to represent. A FrozenDict delegates to its own
        ``pretty_repr``; a dict is printed one entry per line with nested
        dicts indented. Anything else returns ``repr(x)``.
    indent : int, optional
        Number of spaces per indentation level (default 2).

    Returns
    -------
    str
        A formatted string representation.

    See Also
    --------
    FrozenDict.pretty_repr : Pretty representation for FrozenDict.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import pretty_repr

        >>> d = {'a': 1, 'b': {'c': 2, 'd': 3}}
        >>> print(pretty_repr(d))
        {
          'a': 1,
          'b': {
            'c': 2,
            'd': 3
          }
        }

        >>> # Non-dict objects return normal repr
        >>> pretty_repr([1, 2, 3])
        '[1, 2, 3]'
    """
    if isinstance(x, FrozenDict):
        return x.pretty_repr(indent)
    if not isinstance(x, dict):
        return repr(x)

    def _render(mapping, depth):
        # Render one dict level; nested plain dicts recurse one level deeper.
        if not mapping:
            return '{}'
        child_pad = ' ' * ((depth + 1) * indent)
        lines = []
        for k, v in mapping.items():
            rendered = _render(v, depth + 1) if isinstance(v, dict) else repr(v)
            lines.append(f'{child_pad}{k!r}: {rendered}')
        closing_pad = ' ' * (depth * indent)
        return '{\n' + ',\n'.join(lines) + '\n' + closing_pad + '}'

    return _render(x, 0)
1
+ # The file is adapted from the Flax library (https://github.com/google/flax).
2
+ # The credit should go to the Flax authors.
3
+ #
4
+ # Copyright 2024 The Flax Authors.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Custom data structures that work seamlessly with JAX transformations.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import dataclasses
25
+ from collections.abc import Mapping, KeysView, ValuesView, ItemsView
26
+ from typing import Any, TypeVar, Generic, Iterator, overload
27
+
28
+ import jax
29
+ import jax.tree_util
30
+ from typing_extensions import dataclass_transform
31
+
32
# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    # pytree-aware dataclass helpers
    'field',
    'dataclass',
    'PyTreeNode',
    # immutable mapping type and its functional helpers
    'FrozenDict',
    'freeze',
    'unfreeze',
    'copy',
    'pop',
    'pretty_repr',
]
43
+
44
# Type variables used by the generic signatures in this module.
K = TypeVar('K')  # key type of FrozenDict and the dict helpers
V = TypeVar('V')  # value type of FrozenDict and the dict helpers
T = TypeVar('T')  # class / instance type flowing through ``dataclass``
TNode = TypeVar('TNode', bound='PyTreeNode')  # self type of ``PyTreeNode.replace``
49
+
50
+
51
def field(pytree_node: bool = True, **kwargs) -> dataclasses.Field:
    """
    Create a dataclass field with JAX pytree metadata.

    Parameters
    ----------
    pytree_node : bool, optional
        If True (default), this field will be treated as part of the pytree.
        If False, it will be treated as metadata and not be touched
        by JAX transformations.
    **kwargs
        Additional arguments to pass to dataclasses.field(). A ``metadata``
        mapping may be supplied. It is copied, never mutated, and the
        ``'pytree_node'`` entry is added to the copy. ``metadata=None`` is
        treated the same as omitting it.

    Returns
    -------
    dataclasses.Field
        A dataclass field with the appropriate metadata.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> from brainstate.util import dataclass, field

        >>> @dataclass
        ... class Model:
        ...     weights: jnp.ndarray
        ...     bias: jnp.ndarray
        ...     # This field won't be affected by JAX transformations
        ...     name: str = field(pytree_node=False, default="model")
    """
    # Work on a private copy. ``dataclasses.field`` keeps a live read-only
    # view of the dict it is given, so writing into the caller's dict would
    # leak ``pytree_node`` back to the caller. It would also let one field's
    # flag overwrite another's when the same dict is reused.
    metadata = dict(kwargs.pop('metadata', None) or {})
    metadata['pytree_node'] = pytree_node
    return dataclasses.field(metadata=metadata, **kwargs)
86
+
87
+
88
@dataclass_transform(field_specifiers=(field,))
def dataclass(cls: type[T], **kwargs) -> type[T]:
    """
    Create a dataclass that works with JAX transformations.

    This decorator creates immutable dataclasses that can be used safely
    with JAX transformations like jit, grad, vmap, etc. The created class
    will be registered as a JAX pytree node.

    Parameters
    ----------
    cls : type
        The class to decorate.
    **kwargs
        Additional arguments for dataclasses.dataclass().
        If 'frozen' is not specified, it defaults to True.

    Returns
    -------
    type
        The decorated class as an immutable JAX-compatible dataclass.

    See Also
    --------
    PyTreeNode : Base class for creating JAX-compatible pytree nodes.
    field : Create dataclass fields with pytree metadata.

    Notes
    -----
    The decorated class will be frozen (immutable) by default to ensure
    compatibility with JAX's functional programming paradigm.

    Applying the decorator again to a class it already converted is a
    no-op. A subclass of a converted class is converted and registered
    with JAX in its own right, so any fields it adds take part in the
    pytree.

    Examples
    --------
    .. code-block:: python

        >>> import jax
        >>> import jax.numpy as jnp
        >>> from brainstate.util import dataclass, field

        >>> @dataclass
        ... class Model:
        ...     weights: jax.Array
        ...     bias: jax.Array
        ...     name: str = field(pytree_node=False, default="model")

        >>> model = Model(weights=jnp.ones((3, 3)), bias=jnp.zeros(3))

        >>> # JAX transformations will only apply to weights and bias, not name
        >>> grad_fn = jax.grad(lambda m: jnp.sum(m.weights))
        >>> grads = grad_fn(model)

        >>> # Use replace to create modified copies
        >>> model2 = model.replace(weights=jnp.ones((3, 3)) * 2)
    """
    # Skip only classes that were converted *themselves*. ``hasattr`` would
    # also find the marker inherited from a converted base class. A subclass
    # of a brainstate dataclass (e.g. a second-level PyTreeNode subclass)
    # would then silently never become a dataclass or be registered with JAX.
    if '_brainstate_dataclass' in cls.__dict__:
        return cls

    # Default to frozen for immutability.
    kwargs.setdefault('frozen', True)

    # Apply the standard dataclass decorator (it may return a new class).
    cls = dataclasses.dataclass(**kwargs)(cls)

    # Split fields into pytree children and static metadata. Fields without
    # ``pytree_node`` metadata count as pytree children.
    pytree_fields = []
    meta_fields = []
    for field_info in dataclasses.fields(cls):
        if field_info.metadata.get('pytree_node', True):
            pytree_fields.append(field_info.name)
        else:
            meta_fields.append(field_info.name)

    def replace(self: T, **updates) -> T:
        """Replace specified fields with new values."""
        return dataclasses.replace(self, **updates)

    cls.replace = replace

    # Register with JAX.
    _register_pytree(cls, pytree_fields, meta_fields)

    # Mark this exact class as converted (checked via ``cls.__dict__`` above).
    cls._brainstate_dataclass = True

    return cls
177
+
178
+
179
def _register_pytree(cls: type, pytree_fields: list[str], meta_fields: list[str]) -> None:
    """
    Register ``cls`` in JAX's pytree registry.

    Attributes named in ``pytree_fields`` become pytree children. Attributes
    named in ``meta_fields`` are carried as static auxiliary data. When
    ``jax.tree_util.register_dataclass`` exists it is used directly.
    Otherwise a keyed flatten/unflatten pair is registered through
    ``jax.tree_util.register_pytree_with_keys``.
    """
    # Preferred path: let JAX handle the dataclass natively.
    if hasattr(jax.tree_util, 'register_dataclass'):
        jax.tree_util.register_dataclass(cls, pytree_fields, meta_fields)
        return

    # Fallback for older JAX versions.
    def _aux(obj):
        return tuple(getattr(obj, name) for name in meta_fields)

    def _flatten(obj):
        children = tuple(getattr(obj, name) for name in pytree_fields)
        return children, _aux(obj)

    def _flatten_with_keys(obj):
        keyed_children = tuple(
            (jax.tree_util.GetAttrKey(name), getattr(obj, name))
            for name in pytree_fields
        )
        return keyed_children, _aux(obj)

    def _unflatten(aux, children):
        init_kwargs = dict(zip(meta_fields, aux))
        init_kwargs.update(zip(pytree_fields, children))
        return cls(**init_kwargs)

    jax.tree_util.register_pytree_with_keys(
        cls,
        _flatten_with_keys,
        _unflatten,
        _flatten,
    )
213
+
214
+
215
@dataclass_transform(field_specifiers=(field,))
class PyTreeNode:
    """
    Base class for creating JAX-compatible pytree nodes.

    Every subclass is passed through :func:`dataclass` as soon as it is
    defined. It therefore becomes an immutable dataclass registered as a
    JAX pytree. Keyword arguments in the class statement are forwarded to
    :func:`dataclass`.

    See Also
    --------
    dataclass : Decorator for creating JAX-compatible dataclasses.
    field : Create dataclass fields with pytree metadata.

    Notes
    -----
    When subclassing PyTreeNode, all fields are automatically treated as
    part of the pytree unless explicitly marked with ``pytree_node=False``
    using the field() function.

    Examples
    --------
    .. code-block:: python

        >>> import jax
        >>> import jax.numpy as jnp
        >>> from brainstate.util import PyTreeNode, field

        >>> class Layer(PyTreeNode):
        ...     weights: jax.Array
        ...     bias: jax.Array
        ...     activation: str = field(pytree_node=False, default="relu")

        >>> layer = Layer(weights=jnp.ones((4, 4)), bias=jnp.zeros(4))

        >>> # Can be used in JAX transformations
        >>> def loss_fn(layer):
        ...     return jnp.sum(layer.weights ** 2)
        >>> grad_fn = jax.grad(loss_fn)
        >>> grads = grad_fn(layer)

        >>> # Create modified copies with replace
        >>> layer2 = layer.replace(bias=jnp.ones(4))
    """

    def __init_subclass__(cls, **kwargs):
        """Convert each new subclass into a brainstate dataclass."""
        dataclass(cls, **kwargs)

    def __init__(self, *_args, **_kwargs):
        """Placeholder for type checkers; subclasses get a generated ``__init__``."""
        raise NotImplementedError("PyTreeNode is a base class")

    def replace(self: TNode, **changes) -> TNode:
        """
        Return a copy of this node with some fields replaced.

        Parameters
        ----------
        **changes
            Field names mapped to their new values.

        Returns
        -------
        TNode
            A new instance carrying the updated fields.
        """
        # The real implementation is attached to each subclass by ``dataclass``.
        raise NotImplementedError("Implemented by dataclass decorator")
282
+
283
+
284
@jax.tree_util.register_pytree_with_keys_class
class FrozenDict(Mapping[K, V], Generic[K, V]):
    """
    An immutable dictionary that works as a JAX pytree.

    FrozenDict provides an immutable mapping interface that can be used
    safely with JAX transformations. It supports all standard dictionary
    operations in an immutable fashion.

    Parameters
    ----------
    *args
        Positional arguments for dict construction.
    **kwargs
        Keyword arguments for dict construction.

    Attributes
    ----------
    _data : dict
        Internal storage. Nested dicts and FrozenDicts are converted to
        plain dicts on construction. The storage is never mutated afterwards.
    _hash : int or None
        Cached hash value, computed lazily by ``__hash__``.

    See Also
    --------
    freeze : Convert a mapping to a FrozenDict.
    unfreeze : Convert a FrozenDict to a regular dict.

    Notes
    -----
    FrozenDict is immutable - all operations that would modify the dictionary
    instead return a new FrozenDict instance with the changes applied.
    Only nested dictionaries are frozen. Other values (lists, arrays, ...)
    are stored as given.

    When flattened as a pytree, entries are visited in sorted key order, so
    keys must be mutually orderable for JAX tree operations.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict

        >>> # Create a FrozenDict
        >>> fd = FrozenDict({'a': 1, 'b': 2})
        >>> fd['a']
        1

        >>> # Copy with updates (returns new FrozenDict)
        >>> new_fd = fd.copy({'c': 3})
        >>> new_fd['c']
        3

        >>> # Pop an item (returns new dict and popped value)
        >>> new_fd, value = fd.pop('b')
        >>> value
        2
        >>> 'b' in new_fd
        False

        >>> # Nested dictionaries are automatically frozen
        >>> fd = FrozenDict({'x': {'y': 1}})
        >>> isinstance(fd['x'], FrozenDict)
        True
    """

    __slots__ = ('_data', '_hash')

    def __init__(self, *args, **kwargs):
        """Initialize a FrozenDict from the arguments accepted by ``dict()``."""
        data = dict(*args, **kwargs)
        self._data = self._deep_freeze(data)
        self._hash = None

    @classmethod
    def _from_frozen_data(cls, data: dict) -> FrozenDict:
        """
        Wrap storage that is already deep-frozen, without copying it.

        ``data`` must come from another FrozenDict's ``_data``, so its nested
        mappings are already plain dicts. It is shared rather than copied.
        This is safe because FrozenDict never mutates its storage.
        """
        obj = cls.__new__(cls)
        obj._data = data
        obj._hash = None
        return obj

    @staticmethod
    def _deep_freeze(obj: Any) -> Any:
        """Recursively convert nested dicts/FrozenDicts into plain-dict storage."""
        if isinstance(obj, FrozenDict):
            return obj._data
        elif isinstance(obj, dict):
            return {k: FrozenDict._deep_freeze(v) for k, v in obj.items()}
        else:
            return obj

    def __getitem__(self, key: K) -> V:
        """Get an item; nested dictionaries are returned as FrozenDict."""
        value = self._data[key]
        if isinstance(value, dict):
            # Wrap the stored (already frozen) sub-dict instead of rebuilding
            # the whole subtree on every access.
            return FrozenDict._from_frozen_data(value)
        return value

    def __setitem__(self, key: K, value: V) -> None:
        """Raise an error - FrozenDict is immutable."""
        raise TypeError("FrozenDict does not support item assignment")

    def __delitem__(self, key: K) -> None:
        """Raise an error - FrozenDict is immutable."""
        raise TypeError("FrozenDict does not support item deletion")

    def __contains__(self, key: object) -> bool:
        """Check if a key is in the dictionary."""
        return key in self._data

    def __iter__(self) -> Iterator[K]:
        """Iterate over keys."""
        return iter(self._data)

    def __len__(self) -> int:
        """Return the number of items."""
        return len(self._data)

    def __repr__(self) -> str:
        """Return a string representation."""
        return self.pretty_repr()

    def __hash__(self) -> int:
        """
        Return a hash of the dictionary.

        Raises
        ------
        TypeError
            If any key or value is unhashable.
        """
        if self._hash is None:
            # A frozenset of (key, value) pairs is order-independent, which
            # matches the order-independent ``__eq__``. Unlike sorting, it
            # does not require the keys to be mutually orderable (e.g. keys
            # 1 and 'a'). ``items()`` already yields nested dicts as
            # (hashable) FrozenDicts.
            self._hash = hash(frozenset(self.items()))
        return self._hash

    def __eq__(self, other: object) -> bool:
        """Check equality with another FrozenDict or dict."""
        if not isinstance(other, (FrozenDict, dict)):
            return NotImplemented
        if isinstance(other, FrozenDict):
            return self._data == other._data
        return self._data == other

    def __reduce__(self):
        """
        Support for pickling.

        Reconstructs through ``type(self)``, so subclasses round-trip as
        themselves. This is consistent with ``copy`` and ``pop``.
        """
        return type(self), (self.unfreeze(),)

    def keys(self) -> KeysView[K]:
        """
        Return a view of the keys.

        Returns
        -------
        KeysView
            A view object of the dictionary's keys.
        """
        return FrozenKeysView(self)

    def values(self) -> ValuesView[V]:
        """
        Return a view of the values.

        Returns
        -------
        ValuesView
            A view object of the dictionary's values.
        """
        return FrozenValuesView(self)

    def items(self) -> ItemsView[K, V]:
        """
        Return a view of the items.

        Returns
        -------
        ItemsView
            A sized, re-iterable view of ``(key, value)`` pairs. Nested
            dictionaries are returned as FrozenDict instances.
        """
        # Return a real view rather than a one-shot generator. It supports
        # len(), repeated iteration, membership tests and set operations,
        # as ``Mapping.items`` promises.
        return ItemsView(self)

    def get(self, key: K, default: V | None = None) -> V | None:
        """
        Get a value with a default.

        Parameters
        ----------
        key : K
            The key to look up.
        default : V or None, optional
            The default value to return if key is not found.

        Returns
        -------
        V or None
            The value associated with the key, or default.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def copy(self, add_or_replace: Mapping[K, V] | None = None) -> FrozenDict[K, V]:
        """
        Create a new FrozenDict with added or replaced entries.

        Parameters
        ----------
        add_or_replace : Mapping or None, optional
            Entries to add or replace in the new dictionary.

        Returns
        -------
        FrozenDict
            A new FrozenDict with the updates applied.

        Examples
        --------
        .. code-block:: python

            >>> fd = FrozenDict({'a': 1, 'b': 2})
            >>> fd2 = fd.copy({'b': 3, 'c': 4})
            >>> fd2['b'], fd2['c']
            (3, 4)
        """
        if add_or_replace is None:
            add_or_replace = {}
        new_data = dict(self._data)
        new_data.update(add_or_replace)
        return type(self)(new_data)

    def pop(self, key: K) -> tuple[FrozenDict[K, V], V]:
        """
        Create a new FrozenDict with one entry removed.

        Parameters
        ----------
        key : K
            The key to remove.

        Returns
        -------
        tuple
            A tuple of (new FrozenDict without the key, removed value).

        Raises
        ------
        KeyError
            If the key is not found in the dictionary.

        Examples
        --------
        .. code-block:: python

            >>> fd = FrozenDict({'a': 1, 'b': 2})
            >>> fd2, value = fd.pop('a')
            >>> value
            1
            >>> 'a' in fd2
            False
        """
        if key not in self._data:
            raise KeyError(key)
        value = self[key]
        new_data = dict(self._data)
        del new_data[key]
        return type(self)(new_data), value

    def unfreeze(self) -> dict[K, V]:
        """
        Convert to a regular mutable dictionary.

        Returns
        -------
        dict
            A mutable dict with the same contents.

        Examples
        --------
        .. code-block:: python

            >>> fd = FrozenDict({'a': 1, 'b': {'c': 2}})
            >>> d = fd.unfreeze()
            >>> isinstance(d, dict)
            True
            >>> isinstance(d['b'], dict)  # Nested dicts also unfrozen
            True
        """
        return unfreeze(self)

    def pretty_repr(self, indent: int = 2) -> str:
        """
        Return a pretty-printed representation.

        Parameters
        ----------
        indent : int, optional
            Number of spaces per indentation level (default 2).

        Returns
        -------
        str
            A formatted string representation of the FrozenDict.
        """

        def format_value(v, level):
            if isinstance(v, dict):
                if not v:
                    return '{}'
                items = []
                for k, val in v.items():
                    formatted_val = format_value(val, level + 1)
                    items.append(f'{" " * (level + 1) * indent}{k!r}: {formatted_val}')
                return '{\n' + ',\n'.join(items) + f'\n{" " * level * indent}}}'
            else:
                return repr(v)

        if not self._data:
            return 'FrozenDict({})'

        return f'FrozenDict({format_value(self._data, 0)})'

    def tree_flatten_with_keys(self) -> tuple[list[tuple[Any, Any]], tuple[Any, ...]]:
        """Flatten for JAX pytree with keys (entries in sorted key order)."""
        sorted_keys = sorted(self._data.keys())
        values_with_keys = [
            (jax.tree_util.DictKey(k), self._data[k])
            for k in sorted_keys
        ]
        return values_with_keys, tuple(sorted_keys)

    @classmethod
    def tree_unflatten(cls, keys: tuple[Any, ...], values: list[Any]) -> FrozenDict:
        """Unflatten from JAX pytree."""
        return cls(dict(zip(keys, values)))
606
+
607
+
608
class FrozenKeysView(KeysView[K]):
    """Keys view returned by ``FrozenDict.keys()``, with a FrozenDict-style repr."""

    def __repr__(self) -> str:
        key_list = list(self)
        return 'FrozenDict.keys({})'.format(key_list)
613
+
614
+
615
class FrozenValuesView(ValuesView[V]):
    """Values view returned by ``FrozenDict.values()``, with a FrozenDict-style repr."""

    def __repr__(self) -> str:
        value_list = list(self)
        return 'FrozenDict.values({})'.format(value_list)
620
+
621
+
622
def freeze(x: Mapping[K, V]) -> FrozenDict[K, V]:
    """
    Convert a mapping to a FrozenDict.

    Parameters
    ----------
    x : Mapping
        A mapping (dict, FrozenDict, etc.) to freeze.

    Returns
    -------
    FrozenDict
        An immutable FrozenDict. If ``x`` already is one, it is returned
        unchanged rather than copied.

    See Also
    --------
    unfreeze : Convert a FrozenDict to a regular dict.
    FrozenDict : The immutable dictionary class.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict, freeze

        >>> d = {'a': 1, 'b': {'c': 2}}
        >>> fd = freeze(d)
        >>> isinstance(fd, FrozenDict)
        True
        >>> isinstance(fd['b'], FrozenDict)  # Nested dicts are frozen
        True
    """
    # An existing FrozenDict is immutable, so it can be handed back as-is.
    return x if isinstance(x, FrozenDict) else FrozenDict(x)
657
+
658
+
659
def unfreeze(x: FrozenDict[K, V] | dict[K, V]) -> dict[K, V]:
    """
    Convert a FrozenDict to a regular dict.

    Recursively turns every FrozenDict and dict in ``x`` into a fresh,
    mutable ``dict``. Any other value is returned unchanged.

    Parameters
    ----------
    x : FrozenDict or dict
        A FrozenDict or dict to unfreeze.

    Returns
    -------
    dict
        A mutable dictionary.

    See Also
    --------
    freeze : Convert a mapping to a FrozenDict.
    FrozenDict : The immutable dictionary class.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict, unfreeze

        >>> fd = FrozenDict({'a': 1, 'b': {'c': 2}})
        >>> d = unfreeze(fd)
        >>> isinstance(d, dict)
        True
        >>> isinstance(d['b'], dict)  # Nested FrozenDicts are unfrozen
        True
        >>> d['a'] = 10  # Can modify the result
    """
    if isinstance(x, FrozenDict):
        # Iterate the internal storage directly.
        source = x._data
    elif isinstance(x, dict):
        source = x
    else:
        # Not a mapping: nothing to unfreeze.
        return x
    return {key: unfreeze(value) for key, value in source.items()}
706
+
707
+
708
@overload
def copy(x: FrozenDict[K, V], add_or_replace: Mapping[K, V] | None = None) -> FrozenDict[K, V]:
    ...


@overload
def copy(x: dict[K, V], add_or_replace: Mapping[K, V] | None = None) -> dict[K, V]:
    ...


def copy(x, add_or_replace=None):
    """
    Copy a dictionary with optional updates.

    Accepts either a FrozenDict or a plain dict and returns a copy of the
    same kind with ``add_or_replace`` merged in. The input is left untouched.

    Parameters
    ----------
    x : FrozenDict or dict
        Dictionary to copy.
    add_or_replace : Mapping or None, optional
        Entries to add or replace in the copy.

    Returns
    -------
    FrozenDict or dict
        A copy of the same type as the input with updates applied.

    Raises
    ------
    TypeError
        If x is not a FrozenDict or dict.

    See Also
    --------
    FrozenDict.copy : Copy method for FrozenDict.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict, copy

        >>> # Works with FrozenDict
        >>> fd = FrozenDict({'a': 1})
        >>> fd2 = copy(fd, {'b': 2})
        >>> isinstance(fd2, FrozenDict)
        True
        >>> fd2['b']
        2

        >>> # Also works with regular dict
        >>> d = {'a': 1}
        >>> d2 = copy(d, {'b': 2})
        >>> isinstance(d2, dict)
        True
        >>> d2['b']
        2
    """
    updates = {} if add_or_replace is None else add_or_replace

    if isinstance(x, FrozenDict):
        return x.copy(updates)
    if not isinstance(x, dict):
        raise TypeError(f"Expected FrozenDict or dict, got {type(x)}")

    duplicate = dict(x)
    duplicate.update(updates)
    return duplicate
778
+
779
+
780
@overload
def pop(x: FrozenDict[K, V], key: K) -> tuple[FrozenDict[K, V], V]:
    ...


@overload
def pop(x: dict[K, V], key: K) -> tuple[dict[K, V], V]:
    ...


def pop(x, key):
    """
    Remove and return an item from a dictionary.

    Works with both FrozenDict and regular dict. The input is not modified.
    A new dictionary of the same kind, without ``key``, is returned together
    with the removed value.

    Parameters
    ----------
    x : FrozenDict or dict
        Dictionary to pop from.
    key : hashable
        Key to remove.

    Returns
    -------
    tuple
        A tuple of (new dictionary without the key, popped value).

    Raises
    ------
    TypeError
        If x is not a FrozenDict or dict.
    KeyError
        If the key is not found in the dictionary.

    See Also
    --------
    FrozenDict.pop : Pop method for FrozenDict.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict, pop

        >>> # Works with FrozenDict
        >>> fd = FrozenDict({'a': 1, 'b': 2})
        >>> fd2, value = pop(fd, 'a')
        >>> value
        1
        >>> 'a' in fd2
        False

        >>> # Also works with regular dict
        >>> d = {'a': 1, 'b': 2}
        >>> d2, value = pop(d, 'a')
        >>> value
        1
        >>> 'a' in d2
        False
    """
    if isinstance(x, FrozenDict):
        return x.pop(key)
    if not isinstance(x, dict):
        raise TypeError(f"Expected FrozenDict or dict, got {type(x)}")

    # Pop from a shallow copy so the caller's dict is left intact.
    remaining = dict(x)
    popped = remaining.pop(key)
    return remaining, popped
850
+
851
+
852
def pretty_repr(x: Any, indent: int = 2) -> str:
    """
    Create a pretty string representation.

    Parameters
    ----------
    x : any
        Object to represent. If a dict or FrozenDict, it is pretty-printed
        with indentation. Otherwise, returns repr(x).
    indent : int, optional
        Number of spaces per indentation level (default 2). This also applies
        to FrozenDict values nested inside a dict.

    Returns
    -------
    str
        A formatted string representation.

    See Also
    --------
    FrozenDict.pretty_repr : Pretty representation for FrozenDict.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import pretty_repr

        >>> d = {'a': 1, 'b': {'c': 2, 'd': 3}}
        >>> print(pretty_repr(d))
        {
          'a': 1,
          'b': {
            'c': 2,
            'd': 3
          }
        }

        >>> # Non-dict objects return normal repr
        >>> pretty_repr([1, 2, 3])
        '[1, 2, 3]'
    """
    if isinstance(x, FrozenDict):
        return x.pretty_repr(indent)
    if not isinstance(x, dict):
        return repr(x)

    def format_dict(d, level):
        if not d:
            return '{}'
        pad = ' ' * (level + 1) * indent
        items = []
        for k, v in d.items():
            if isinstance(v, dict):
                formatted = format_dict(v, level + 1)
            elif isinstance(v, FrozenDict):
                # Format nested FrozenDicts in place. Calling repr() would
                # start them at column 0 with the default indent, so they
                # would not line up with the enclosing dict.
                formatted = f'FrozenDict({format_dict(v._data, level + 1)})'
            else:
                formatted = repr(v)
            items.append(f'{pad}{k!r}: {formatted}')
        return '{\n' + ',\n'.join(items) + f'\n{" " * level * indent}}}'

    return format_dict(x, 0)