brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2319 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +1652 -1652
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1624 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1433 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +137 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +633 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +154 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +477 -477
- brainstate/nn/_dynamics.py +1267 -1267
- brainstate/nn/_dynamics_test.py +67 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +384 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/_rand_funs.py +3938 -3938
- brainstate/random/_rand_funs_test.py +640 -640
- brainstate/random/_rand_seed.py +675 -675
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1617
- brainstate/random/_rand_state_test.py +551 -551
- brainstate/transform/__init__.py +59 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -145
- brainstate/transform/_eval_shape_test.py +38 -38
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -2016
- brainstate/transform/_make_jaxpr_test.py +1510 -1510
- brainstate/transform/_mapping.py +529 -529
- brainstate/transform/_mapping_test.py +194 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_random.py +171 -171
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate-0.2.0.dist-info/RECORD +0 -111
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/util/struct.py
CHANGED
@@ -1,910 +1,910 @@
|
|
1
|
-
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
-
# The credit should go to the Flax authors.
|
3
|
-
#
|
4
|
-
# Copyright 2024 The Flax Authors.
|
5
|
-
#
|
6
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
-
# you may not use this file except in compliance with the License.
|
8
|
-
# You may obtain a copy of the License at
|
9
|
-
#
|
10
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
-
#
|
12
|
-
# Unless required by applicable law or agreed to in writing, software
|
13
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
-
# See the License for the specific language governing permissions and
|
16
|
-
# limitations under the License.
|
17
|
-
|
18
|
-
"""
|
19
|
-
Custom data structures that work seamlessly with JAX transformations.
|
20
|
-
"""
|
21
|
-
|
22
|
-
from __future__ import annotations
|
23
|
-
|
24
|
-
import dataclasses
|
25
|
-
from collections.abc import Mapping, KeysView, ValuesView, ItemsView
|
26
|
-
from typing import Any, TypeVar, Generic, Iterator, overload
|
27
|
-
|
28
|
-
import jax
|
29
|
-
import jax.tree_util
|
30
|
-
from typing_extensions import dataclass_transform
|
31
|
-
|
32
|
-
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    'field',
    'dataclass',
    'PyTreeNode',
    'FrozenDict',
    'freeze',
    'unfreeze',
    'copy',
    'pop',
    'pretty_repr',
]

# Type variables
K = TypeVar('K')  # key type parameter of FrozenDict and the mapping helpers
V = TypeVar('V')  # value type parameter of FrozenDict and the mapping helpers
T = TypeVar('T')  # class type passed through (and returned by) ``dataclass``
# Bound to PyTreeNode (forward reference by name) so ``replace`` returns the subclass type.
TNode = TypeVar('TNode', bound='PyTreeNode')
|
49
|
-
|
50
|
-
|
51
|
-
def field(pytree_node: bool = True, **kwargs) -> dataclasses.Field:
    """
    Create a dataclass field with JAX pytree metadata.

    Parameters
    ----------
    pytree_node : bool, optional
        If True (default), this field will be treated as part of the pytree.
        If False, it will be treated as metadata and not be touched
        by JAX transformations.
    **kwargs
        Additional arguments to pass to dataclasses.field(). If a
        ``metadata`` mapping is given, it is copied (never modified) and the
        ``'pytree_node'`` entry is added to the copy. ``metadata=None`` is
        treated like an empty mapping.

    Returns
    -------
    dataclasses.Field
        A dataclass field with the appropriate metadata.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> from brainstate.util import dataclass, field

        >>> @dataclass
        ... class Model:
        ...     weights: jnp.ndarray
        ...     bias: jnp.ndarray
        ...     # This field won't be affected by JAX transformations
        ...     name: str = field(pytree_node=False, default="model")
    """
    # Work on a private copy of the metadata. ``dataclasses.field`` keeps a live
    # read-only view of the dict it receives, so writing into the caller's
    # dict would leak the flag into it and into every other field sharing that
    # dict. Copying also supports ``None`` and read-only mappings such as
    # ``types.MappingProxyType``.
    metadata = dict(kwargs.pop('metadata', None) or {})
    metadata['pytree_node'] = pytree_node
    return dataclasses.field(metadata=metadata, **kwargs)
|
86
|
-
|
87
|
-
|
88
|
-
@dataclass_transform(field_specifiers=(field,))
def dataclass(cls: type[T], **kwargs) -> type[T]:
    """
    Create a dataclass that works with JAX transformations.

    This decorator creates immutable dataclasses that can be used safely
    with JAX transformations like jit, grad, vmap, etc. The created class
    will be registered as a JAX pytree node.

    Parameters
    ----------
    cls : type
        The class to decorate.
    **kwargs
        Additional arguments for dataclasses.dataclass().
        If 'frozen' is not specified, it defaults to True.

    Returns
    -------
    type
        The decorated class as an immutable JAX-compatible dataclass.

    See Also
    --------
    PyTreeNode : Base class for creating JAX-compatible pytree nodes.
    field : Create dataclass fields with pytree metadata.

    Notes
    -----
    The decorated class will be frozen (immutable) by default to ensure
    compatibility with JAX's functional programming paradigm.

    Applying the decorator twice to the same class is a no-op. A subclass of
    an already-converted class is converted (and registered as its own pytree
    type) when it is decorated itself or created through ``PyTreeNode``.

    Examples
    --------
    .. code-block:: python

        >>> import jax
        >>> import jax.numpy as jnp
        >>> from brainstate.util import dataclass, field

        >>> @dataclass
        ... class Model:
        ...     weights: jax.Array
        ...     bias: jax.Array
        ...     name: str = field(pytree_node=False, default="model")

        >>> model = Model(weights=jnp.ones((3, 3)), bias=jnp.zeros(3))

        >>> # JAX transformations will only apply to weights and bias, not name
        >>> grad_fn = jax.grad(lambda m: jnp.sum(m.weights))
        >>> grads = grad_fn(model)

        >>> # Use replace to create modified copies
        >>> model2 = model.replace(weights=jnp.ones((3, 3)) * 2)
    """
    # Skip only classes that were converted *themselves*. ``hasattr`` would
    # also find the marker inherited from a converted base class. A subclass
    # declaring new fields would then silently stay a non-dataclass and never
    # be registered with JAX.
    if vars(cls).get('_brainstate_dataclass', False):
        return cls

    # Default to frozen for immutability
    kwargs.setdefault('frozen', True)

    # Apply standard dataclass decorator
    cls = dataclasses.dataclass(**kwargs)(cls)

    # Separate fields into pytree children and static metadata; fields without
    # an explicit ``pytree_node`` flag are treated as pytree children.
    pytree_fields = []
    meta_fields = []
    for field_info in dataclasses.fields(cls):
        if field_info.metadata.get('pytree_node', True):
            pytree_fields.append(field_info.name)
        else:
            meta_fields.append(field_info.name)

    # Add replace method
    def replace(self: T, **updates) -> T:
        """Replace specified fields with new values."""
        return dataclasses.replace(self, **updates)

    cls.replace = replace

    # Register with JAX
    _register_pytree(cls, pytree_fields, meta_fields)

    # Mark as BrainState dataclass (set on this exact class, checked via vars()).
    cls._brainstate_dataclass = True

    return cls
|
177
|
-
|
178
|
-
|
179
|
-
def _register_pytree(cls: type, pytree_fields: list[str], meta_fields: list[str]) -> None:
    """
    Register ``cls`` in JAX's pytree registry.

    Attributes named in ``pytree_fields`` become the node's children, and
    those in ``meta_fields`` become its static auxiliary data. When
    ``jax.tree_util.register_dataclass`` is available it is used directly.
    Otherwise an equivalent key-aware registration is assembled by hand.
    """
    tree_util = jax.tree_util

    # Preferred path: newer JAX versions ship a dataclass-aware registrar.
    if hasattr(tree_util, 'register_dataclass'):
        tree_util.register_dataclass(cls, pytree_fields, meta_fields)
        return

    def _static_part(node):
        # Static (non-traced) attribute values, in meta_fields order.
        return tuple(getattr(node, attr) for attr in meta_fields)

    def _flatten(node):
        children = tuple(getattr(node, attr) for attr in pytree_fields)
        return children, _static_part(node)

    def _flatten_with_keys(node):
        keyed_children = tuple(
            (tree_util.GetAttrKey(attr), getattr(node, attr))
            for attr in pytree_fields
        )
        return keyed_children, _static_part(node)

    def _unflatten(static, children):
        # Rebuild through the constructor: metadata first, then children.
        init_kwargs = dict(zip(meta_fields, static))
        init_kwargs.update(zip(pytree_fields, children))
        return cls(**init_kwargs)

    tree_util.register_pytree_with_keys(cls, _flatten_with_keys, _unflatten, _flatten)
|
213
|
-
|
214
|
-
|
215
|
-
@dataclass_transform(field_specifiers=(field,))
class PyTreeNode:
    """
    Base class for creating JAX-compatible pytree nodes.

    Subclasses of PyTreeNode are automatically converted to immutable
    dataclasses that work with JAX transformations.

    See Also
    --------
    dataclass : Decorator for creating JAX-compatible dataclasses.
    field : Create dataclass fields with pytree metadata.

    Notes
    -----
    When subclassing PyTreeNode, all fields are automatically treated as
    part of the pytree unless explicitly marked with ``pytree_node=False``
    using the field() function. Class keyword arguments (for example
    ``class Foo(PyTreeNode, frozen=False)``) are forwarded to ``dataclass``.

    Examples
    --------
    .. code-block:: python

        >>> import jax
        >>> import jax.numpy as jnp
        >>> from brainstate.util import PyTreeNode, field

        >>> class Layer(PyTreeNode):
        ...     weights: jax.Array
        ...     bias: jax.Array
        ...     activation: str = field(pytree_node=False, default="relu")

        >>> layer = Layer(weights=jnp.ones((4, 4)), bias=jnp.zeros(4))

        >>> # Can be used in JAX transformations
        >>> def loss_fn(layer):
        ...     return jnp.sum(layer.weights ** 2)
        >>> grad_fn = jax.grad(loss_fn)
        >>> grads = grad_fn(layer)

        >>> # Create modified copies with replace
        >>> layer2 = layer.replace(bias=jnp.ones(4))
    """

    def __init_subclass__(cls, **kwargs):
        """Automatically apply the dataclass decorator to every subclass."""
        # Keep the hook cooperative so other bases in the MRO still get their
        # ``__init_subclass__`` called. Class keyword arguments are consumed
        # by ``dataclass`` below and therefore not forwarded up the chain.
        super().__init_subclass__()
        dataclass(cls, **kwargs)

    def __init__(self, *args, **kwargs):
        """Stub for type checkers."""
        raise NotImplementedError("PyTreeNode is a base class")

    def replace(self: TNode, **updates) -> TNode:
        """
        Replace specified fields with new values.

        Parameters
        ----------
        **updates
            Field names and their new values.

        Returns
        -------
        TNode
            A new instance with updated fields.
        """
        raise NotImplementedError("Implemented by dataclass decorator")
|
282
|
-
|
283
|
-
|
284
|
-
@jax.tree_util.register_pytree_with_keys_class
class FrozenDict(Mapping[K, V], Generic[K, V]):
    """
    An immutable dictionary that works as a JAX pytree.

    FrozenDict provides an immutable mapping interface that can be used
    safely with JAX transformations. It supports all standard dictionary
    read operations, and its "modifying" operations (``copy``, ``pop``)
    return new instances.

    Parameters
    ----------
    *args
        Positional arguments for dict construction.
    **kwargs
        Keyword arguments for dict construction.

    Attributes
    ----------
    _data : dict
        Internal storage. Nested mappings are kept as private plain dicts and
        are wrapped in a FrozenDict when accessed.
    _hash : int or None
        Cached hash value (None until ``hash()`` is first called).

    See Also
    --------
    freeze : Convert a mapping to a FrozenDict.
    unfreeze : Convert a FrozenDict to a regular dict.

    Notes
    -----
    FrozenDict is immutable - all operations that would modify the dictionary
    instead return a new FrozenDict instance with the changes applied.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict

        >>> # Create a FrozenDict
        >>> fd = FrozenDict({'a': 1, 'b': 2})
        >>> fd['a']
        1

        >>> # Copy with updates (returns new FrozenDict)
        >>> new_fd = fd.copy({'c': 3})
        >>> new_fd['c']
        3

        >>> # Pop an item (returns new dict and popped value)
        >>> new_fd, value = fd.pop('b')
        >>> value
        2
        >>> 'b' in new_fd
        False

        >>> # Nested dictionaries are automatically frozen
        >>> fd = FrozenDict({'x': {'y': 1}})
        >>> isinstance(fd['x'], FrozenDict)
        True
    """

    __slots__ = ('_data', '_hash')

    def __init__(self, *args, **kwargs):
        """Initialize a FrozenDict from the same arguments ``dict()`` accepts."""
        data = dict(*args, **kwargs)
        self._data = self._deep_freeze(data)
        self._hash = None

    @staticmethod
    def _deep_freeze(obj: Any) -> Any:
        """
        Recursively copy nested dictionaries into private storage.

        A FrozenDict is unwrapped to its (already private) internal dict. A
        plain dict is copied recursively so later mutation of the caller's
        dict cannot leak in. Any other value is stored as-is.
        """
        if isinstance(obj, FrozenDict):
            return obj._data
        elif isinstance(obj, dict):
            return {k: FrozenDict._deep_freeze(v) for k, v in obj.items()}
        else:
            return obj

    def __getitem__(self, key: K) -> V:
        """Get an item; nested dicts are returned wrapped as FrozenDict."""
        value = self._data[key]
        if isinstance(value, dict):
            # Never hand out the mutable internal dict itself.
            return FrozenDict(value)
        return value

    def __setitem__(self, key: K, value: V) -> None:
        """Raise an error - FrozenDict is immutable."""
        raise TypeError("FrozenDict does not support item assignment")

    def __delitem__(self, key: K) -> None:
        """Raise an error - FrozenDict is immutable."""
        raise TypeError("FrozenDict does not support item deletion")

    def __contains__(self, key: object) -> bool:
        """Check if a key is in the dictionary."""
        return key in self._data

    def __iter__(self) -> Iterator[K]:
        """Iterate over keys in insertion order."""
        return iter(self._data)

    def __len__(self) -> int:
        """Return the number of items."""
        return len(self._data)

    def __repr__(self) -> str:
        """Return a string representation (same as ``pretty_repr()``)."""
        return self.pretty_repr()

    def __hash__(self) -> int:
        """
        Return a hash of the dictionary, cached after the first call.

        The hash ignores insertion order, consistent with ``__eq__``. It does
        not require keys to be mutually orderable, so mixed key types such as
        ``{1: ..., 'a': ...}`` work. Raises ``TypeError`` if a value is
        unhashable.
        """
        if self._hash is None:
            # ``items()`` yields nested dicts already wrapped as FrozenDict,
            # so each (key, value) pair is hashable whenever its leaves are.
            self._hash = hash(frozenset(self.items()))
        return self._hash

    def __eq__(self, other: object) -> bool:
        """Compare contents with another FrozenDict or dict (order-insensitive)."""
        if not isinstance(other, (FrozenDict, dict)):
            return NotImplemented
        if isinstance(other, FrozenDict):
            return self._data == other._data
        return self._data == other

    def __reduce__(self):
        """Support for pickling; preserves the concrete (sub)class."""
        return type(self), (self.unfreeze(),)

    def keys(self) -> KeysView[K]:
        """
        Return a view of the keys.

        Returns
        -------
        KeysView
            A view object of the dictionary's keys.
        """
        return FrozenKeysView(self)

    def values(self) -> ValuesView[V]:
        """
        Return a view of the values.

        Returns
        -------
        ValuesView
            A view object of the dictionary's values.
        """
        return FrozenValuesView(self)

    def items(self) -> ItemsView[K, V]:
        """
        Return a view of the items.

        Returns
        -------
        ItemsView
            A reusable view of ``(key, value)`` pairs. Values are obtained via
            ``self[key]``, so nested dicts appear as FrozenDict. The view
            supports ``len()``, membership tests and repeated iteration.
        """
        return ItemsView(self)

    def get(self, key: K, default: V | None = None) -> V | None:
        """
        Get a value with a default.

        Parameters
        ----------
        key : K
            The key to look up.
        default : V or None, optional
            The default value to return if key is not found.

        Returns
        -------
        V or None
            The value associated with the key, or default.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def copy(self, add_or_replace: Mapping[K, V] | None = None) -> FrozenDict[K, V]:
        """
        Create a new FrozenDict with added or replaced entries.

        Parameters
        ----------
        add_or_replace : Mapping or None, optional
            Entries to add or replace in the new dictionary (shallow merge).

        Returns
        -------
        FrozenDict
            A new FrozenDict with the updates applied.

        Examples
        --------
        .. code-block:: python

            >>> fd = FrozenDict({'a': 1, 'b': 2})
            >>> fd2 = fd.copy({'b': 3, 'c': 4})
            >>> fd2['b'], fd2['c']
            (3, 4)
        """
        if add_or_replace is None:
            add_or_replace = {}
        new_data = dict(self._data)
        new_data.update(add_or_replace)
        return type(self)(new_data)

    def pop(self, key: K) -> tuple[FrozenDict[K, V], V]:
        """
        Create a new FrozenDict with one entry removed.

        Parameters
        ----------
        key : K
            The key to remove.

        Returns
        -------
        tuple
            A tuple of (new FrozenDict without the key, removed value).

        Raises
        ------
        KeyError
            If the key is not found in the dictionary.

        Examples
        --------
        .. code-block:: python

            >>> fd = FrozenDict({'a': 1, 'b': 2})
            >>> fd2, value = fd.pop('a')
            >>> value
            1
            >>> 'a' in fd2
            False
        """
        if key not in self._data:
            raise KeyError(key)
        value = self[key]
        new_data = dict(self._data)
        del new_data[key]
        return type(self)(new_data), value

    def unfreeze(self) -> dict[K, V]:
        """
        Convert to a regular mutable dictionary.

        Returns
        -------
        dict
            A mutable dict with the same contents.

        Examples
        --------
        .. code-block:: python

            >>> fd = FrozenDict({'a': 1, 'b': {'c': 2}})
            >>> d = fd.unfreeze()
            >>> isinstance(d, dict)
            True
            >>> isinstance(d['b'], dict)  # Nested dicts also unfrozen
            True
        """
        return unfreeze(self)

    def pretty_repr(self, indent: int = 2) -> str:
        """
        Return a pretty-printed representation.

        Parameters
        ----------
        indent : int, optional
            Number of spaces per indentation level (default 2).

        Returns
        -------
        str
            A formatted string representation of the FrozenDict.
        """

        def format_value(v, level):
            # Nested mappings live in ``_data`` as plain dicts.
            if isinstance(v, dict):
                if not v:
                    return '{}'
                items = []
                for k, val in v.items():
                    formatted_val = format_value(val, level + 1)
                    items.append(f'{" " * (level + 1) * indent}{k!r}: {formatted_val}')
                return '{\n' + ',\n'.join(items) + f'\n{" " * level * indent}}}'
            else:
                return repr(v)

        if not self._data:
            return 'FrozenDict({})'

        return f'FrozenDict({format_value(self._data, 0)})'

    def tree_flatten_with_keys(self) -> tuple[list[tuple[Any, Any]], tuple[Any, ...]]:
        """Flatten for JAX pytree with keys; children are ordered by sorted key."""
        sorted_keys = sorted(self._data.keys())
        values_with_keys = [
            (jax.tree_util.DictKey(k), self._data[k])
            for k in sorted_keys
        ]
        return values_with_keys, tuple(sorted_keys)

    @classmethod
    def tree_unflatten(cls, keys: tuple[Any, ...], values: list[Any]) -> FrozenDict:
        """Unflatten from JAX pytree (keys come from the aux data)."""
        return cls(dict(zip(keys, values)))
|
606
|
-
|
607
|
-
|
608
|
-
class FrozenKeysView(KeysView[K]):
    """Keys view of a FrozenDict whose repr names the owning type."""

    def __repr__(self) -> str:
        key_list = [key for key in self]
        return f'FrozenDict.keys({key_list})'
|
613
|
-
|
614
|
-
|
615
|
-
class FrozenValuesView(ValuesView[V]):
    """Values view of a FrozenDict whose repr names the owning type."""

    def __repr__(self) -> str:
        value_list = [value for value in self]
        return f'FrozenDict.values({value_list})'
|
620
|
-
|
621
|
-
|
622
|
-
def freeze(x: Mapping[K, V]) -> FrozenDict[K, V]:
    """
    Convert a mapping to a FrozenDict.

    An object that is already a FrozenDict is returned unchanged (no copy).
    Anything else is passed to the FrozenDict constructor.

    Parameters
    ----------
    x : Mapping
        A mapping (dict, FrozenDict, etc.) to freeze.

    Returns
    -------
    FrozenDict
        An immutable FrozenDict.

    See Also
    --------
    unfreeze : Convert a FrozenDict to a regular dict.
    FrozenDict : The immutable dictionary class.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import freeze

        >>> d = {'a': 1, 'b': {'c': 2}}
        >>> fd = freeze(d)
        >>> isinstance(fd, FrozenDict)
        True
        >>> isinstance(fd['b'], FrozenDict)  # Nested dicts are frozen
        True
    """
    already_frozen = isinstance(x, FrozenDict)
    return x if already_frozen else FrozenDict(x)
|
657
|
-
|
658
|
-
|
659
|
-
def unfreeze(x: FrozenDict[K, V] | dict[K, V]) -> dict[K, V]:
    """
    Convert a FrozenDict to a regular dict.

    Works recursively: every nested FrozenDict or dict becomes a fresh plain
    ``dict``. Values that are neither are returned unchanged.

    Parameters
    ----------
    x : FrozenDict or dict
        A FrozenDict or dict to unfreeze.

    Returns
    -------
    dict
        A mutable dictionary.

    See Also
    --------
    freeze : Convert a mapping to a FrozenDict.
    FrozenDict : The immutable dictionary class.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict, unfreeze

        >>> fd = FrozenDict({'a': 1, 'b': {'c': 2}})
        >>> d = unfreeze(fd)
        >>> isinstance(d, dict)
        True
        >>> isinstance(d['b'], dict)  # Nested FrozenDicts are unfrozen
        True
        >>> d['a'] = 10  # Can modify the result
    """
    if isinstance(x, FrozenDict):
        source = x._data
    elif isinstance(x, dict):
        source = x
    else:
        # Leaf value: nothing to unfreeze.
        return x
    return {key: unfreeze(value) for key, value in source.items()}
|
706
|
-
|
707
|
-
|
708
|
-
@overload
def copy(x: FrozenDict[K, V], add_or_replace: Mapping[K, V] | None = None) -> FrozenDict[K, V]:
    ...


@overload
def copy(x: dict[K, V], add_or_replace: Mapping[K, V] | None = None) -> dict[K, V]:
    ...


def copy(x, add_or_replace=None):
    """
    Copy a dictionary with optional updates.

    Works with both FrozenDict and regular dict; the result has the same
    kind as the input and the original is left untouched.

    Parameters
    ----------
    x : FrozenDict or dict
        Dictionary to copy.
    add_or_replace : Mapping or None, optional
        Entries to add or replace in the copy.

    Returns
    -------
    FrozenDict or dict
        A copy of the same type as the input with updates applied.

    Raises
    ------
    TypeError
        If x is not a FrozenDict or dict.

    See Also
    --------
    FrozenDict.copy : Copy method for FrozenDict.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict, copy

        >>> # Works with FrozenDict
        >>> fd = FrozenDict({'a': 1})
        >>> fd2 = copy(fd, {'b': 2})
        >>> isinstance(fd2, FrozenDict)
        True
        >>> fd2['b']
        2

        >>> # Also works with regular dict
        >>> d = {'a': 1}
        >>> d2 = copy(d, {'b': 2})
        >>> isinstance(d2, dict)
        True
        >>> d2['b']
        2
    """
    updates = {} if add_or_replace is None else add_or_replace

    if isinstance(x, FrozenDict):
        return x.copy(updates)
    if not isinstance(x, dict):
        raise TypeError(f"Expected FrozenDict or dict, got {type(x)}")

    merged = dict(x)
    merged.update(updates)
    return merged
|
778
|
-
|
779
|
-
|
780
|
-
@overload
def pop(x: FrozenDict[K, V], key: K) -> tuple[FrozenDict[K, V], V]:
    ...


@overload
def pop(x: dict[K, V], key: K) -> tuple[dict[K, V], V]:
    ...


def pop(x, key):
    """
    Remove and return an item from a dictionary.

    Works with both FrozenDict and regular dict, returning a new
    dictionary without the specified key along with the popped value.
    The input is never modified.

    Parameters
    ----------
    x : FrozenDict or dict
        Dictionary to pop from.
    key : hashable
        Key to remove.

    Returns
    -------
    tuple
        A tuple of (new dictionary without the key, popped value).

    Raises
    ------
    TypeError
        If x is not a FrozenDict or dict.
    KeyError
        If the key is not found in the dictionary.

    See Also
    --------
    FrozenDict.pop : Pop method for FrozenDict.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict, pop

        >>> # Works with FrozenDict
        >>> fd = FrozenDict({'a': 1, 'b': 2})
        >>> fd2, value = pop(fd, 'a')
        >>> value
        1
        >>> 'a' in fd2
        False

        >>> # Also works with regular dict
        >>> d = {'a': 1, 'b': 2}
        >>> d2, value = pop(d, 'a')
        >>> value
        1
        >>> 'a' in d2
        False
    """
    if isinstance(x, FrozenDict):
        return x.pop(key)
    if not isinstance(x, dict):
        raise TypeError(f"Expected FrozenDict or dict, got {type(x)}")

    # Work on a shallow copy so the caller's dict is left untouched.
    remaining = dict(x)
    popped = remaining.pop(key)
    return remaining, popped
|
850
|
-
|
851
|
-
|
852
|
-
def pretty_repr(x: Any, indent: int = 2) -> str:
    """
    Build a human-readable, indented string for a (possibly nested) dict.

    FrozenDict inputs delegate to ``FrozenDict.pretty_repr``. Plain dicts
    are rendered one entry per line, recursing into nested plain dicts.
    Anything else falls back to ``repr``.

    Parameters
    ----------
    x : any
        Object to represent.
    indent : int, optional
        Number of spaces per indentation level (default 2).

    Returns
    -------
    str
        A formatted string representation.

    See Also
    --------
    FrozenDict.pretty_repr : Pretty representation for FrozenDict.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import pretty_repr

        >>> d = {'a': 1, 'b': {'c': 2, 'd': 3}}
        >>> print(pretty_repr(d))
        {
          'a': 1,
          'b': {
            'c': 2,
            'd': 3
          }
        }

        >>> # Non-dict objects return normal repr
        >>> pretty_repr([1, 2, 3])
        '[1, 2, 3]'
    """
    if isinstance(x, FrozenDict):
        return x.pretty_repr(indent)
    if not isinstance(x, dict):
        return repr(x)

    def render(mapping, depth):
        # Empty mappings stay on a single line.
        if not mapping:
            return '{}'
        entry_pad = ' ' * ((depth + 1) * indent)
        entries = [
            f'{entry_pad}{key!r}: '
            + (render(val, depth + 1) if isinstance(val, dict) else repr(val))
            for key, val in mapping.items()
        ]
        closing_pad = ' ' * (depth * indent)
        return '{\n' + ',\n'.join(entries) + '\n' + closing_pad + '}'

    return render(x, 0)
|
1
|
+
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
+
# The credit should go to the Flax authors.
|
3
|
+
#
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
|
+
#
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
+
# you may not use this file except in compliance with the License.
|
8
|
+
# You may obtain a copy of the License at
|
9
|
+
#
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
+
#
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
+
# See the License for the specific language governing permissions and
|
16
|
+
# limitations under the License.
|
17
|
+
|
18
|
+
"""
|
19
|
+
Custom data structures that work seamlessly with JAX transformations.
|
20
|
+
"""
|
21
|
+
|
22
|
+
from __future__ import annotations
|
23
|
+
|
24
|
+
import dataclasses
|
25
|
+
from collections.abc import Mapping, KeysView, ValuesView, ItemsView
|
26
|
+
from typing import Any, TypeVar, Generic, Iterator, overload
|
27
|
+
|
28
|
+
import jax
|
29
|
+
import jax.tree_util
|
30
|
+
from typing_extensions import dataclass_transform
|
31
|
+
|
32
|
+
# Public API exported by ``from ... import *``.
__all__ = [
    # JAX-aware dataclass helpers
    'field', 'dataclass', 'PyTreeNode',
    # immutable mapping type and its functional helpers
    'FrozenDict', 'freeze', 'unfreeze', 'copy', 'pop', 'pretty_repr',
]

# Type variables shared by the annotations in this module.
K = TypeVar('K')  # mapping key type
V = TypeVar('V')  # mapping value type
T = TypeVar('T')  # class passed through ``dataclass``
TNode = TypeVar('TNode', bound='PyTreeNode')  # a PyTreeNode subclass
|
49
|
+
|
50
|
+
|
51
|
+
def field(pytree_node: bool = True, **kwargs) -> dataclasses.Field:
    """
    Create a dataclass field with JAX pytree metadata.

    Parameters
    ----------
    pytree_node : bool, optional
        If True (default), this field will be treated as part of the pytree.
        If False, it will be treated as metadata and not be touched
        by JAX transformations.
    **kwargs
        Additional arguments to pass to dataclasses.field(). A ``metadata``
        mapping, if given, is copied (never modified) and the
        ``'pytree_node'`` flag is added to the copy.

    Returns
    -------
    dataclasses.Field
        A dataclass field with the appropriate metadata.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> from brainstate.util import dataclass, field

        >>> @dataclass
        ... class Model:
        ...     weights: jnp.ndarray
        ...     bias: jnp.ndarray
        ...     # This field won't be affected by JAX transformations
        ...     name: str = field(pytree_node=False, default="model")
    """
    # Copy the caller's metadata: writing into it directly would leak the
    # 'pytree_node' flag into a dict the caller may reuse elsewhere, and
    # would fail for ``metadata=None`` or read-only mappings.
    metadata = dict(kwargs.pop('metadata', None) or {})
    metadata['pytree_node'] = pytree_node
    return dataclasses.field(metadata=metadata, **kwargs)
|
86
|
+
|
87
|
+
|
88
|
+
@dataclass_transform(field_specifiers=(field,))
def dataclass(cls: type[T] | None = None, **kwargs) -> Any:
    """
    Create a dataclass that works with JAX transformations.

    This decorator creates immutable dataclasses that can be used safely
    with JAX transformations like jit, grad, vmap, etc. The created class
    will be registered as a JAX pytree node.

    It can be applied bare (``@dataclass``) or with keyword arguments
    (``@dataclass(frozen=False)``).

    Parameters
    ----------
    cls : type, optional
        The class to decorate. When omitted, a decorator that accepts the
        class is returned so that keyword arguments can be supplied.
    **kwargs
        Additional arguments for dataclasses.dataclass().
        If 'frozen' is not specified, it defaults to True.

    Returns
    -------
    type or callable
        The decorated class as an immutable JAX-compatible dataclass, or,
        when ``cls`` is omitted, a decorator that produces one.

    See Also
    --------
    PyTreeNode : Base class for creating JAX-compatible pytree nodes.
    field : Create dataclass fields with pytree metadata.

    Notes
    -----
    The decorated class will be frozen (immutable) by default to ensure
    compatibility with JAX's functional programming paradigm.

    Every class in an inheritance chain is converted and registered on its
    own: decorating a subclass of an already-converted class converts the
    subclass too, so its additional fields are picked up.

    Examples
    --------
    .. code-block:: python

        >>> import jax
        >>> import jax.numpy as jnp
        >>> from brainstate.util import dataclass, field

        >>> @dataclass
        ... class Model:
        ...     weights: jax.Array
        ...     bias: jax.Array
        ...     name: str = field(pytree_node=False, default="model")

        >>> model = Model(weights=jnp.ones((3, 3)), bias=jnp.zeros(3))

        >>> # JAX transformations will only apply to weights and bias, not name
        >>> grad_fn = jax.grad(lambda m: jnp.sum(m.weights))
        >>> grads = grad_fn(model)

        >>> # Use replace to create modified copies
        >>> model2 = model.replace(weights=jnp.ones((3, 3)) * 2)
    """
    if cls is None:
        # Decorator-factory form, e.g. ``@dataclass(frozen=False)``.
        def decorator(klass: type[T]) -> type[T]:
            return dataclass(klass, **kwargs)

        return decorator

    # Skip only classes converted *directly*. ``hasattr`` would also find the
    # marker inherited from a converted base class and wrongly skip the
    # subclass, leaving its new fields out of ``__init__`` and leaving it
    # unregistered with JAX.
    if '_brainstate_dataclass' in cls.__dict__:
        return cls

    # Default to frozen for immutability
    kwargs.setdefault('frozen', True)

    # Apply standard dataclass decorator
    cls = dataclasses.dataclass(**kwargs)(cls)

    # Separate fields into pytree leaves and static metadata
    pytree_fields = []
    meta_fields = []
    for field_info in dataclasses.fields(cls):
        if field_info.metadata.get('pytree_node', True):
            pytree_fields.append(field_info.name)
        else:
            meta_fields.append(field_info.name)

    def replace(self: T, **updates) -> T:
        """Replace specified fields with new values."""
        return dataclasses.replace(self, **updates)

    cls.replace = replace

    # Register with JAX
    _register_pytree(cls, pytree_fields, meta_fields)

    # Mark this exact class as converted (checked via ``__dict__`` above)
    cls._brainstate_dataclass = True

    return cls
|
177
|
+
|
178
|
+
|
179
|
+
def _register_pytree(cls: type, pytree_fields: list[str], meta_fields: list[str]) -> None:
|
180
|
+
"""Register a class as a JAX pytree."""
|
181
|
+
|
182
|
+
def flatten_fn(obj):
|
183
|
+
pytree_data = tuple(getattr(obj, name) for name in pytree_fields)
|
184
|
+
metadata = tuple(getattr(obj, name) for name in meta_fields)
|
185
|
+
return pytree_data, metadata
|
186
|
+
|
187
|
+
def flatten_with_keys_fn(obj):
|
188
|
+
pytree_data = tuple(
|
189
|
+
(jax.tree_util.GetAttrKey(name), getattr(obj, name))
|
190
|
+
for name in pytree_fields
|
191
|
+
)
|
192
|
+
metadata = tuple(getattr(obj, name) for name in meta_fields)
|
193
|
+
return pytree_data, metadata
|
194
|
+
|
195
|
+
def unflatten_fn(metadata, pytree_data):
|
196
|
+
kwargs = {}
|
197
|
+
for name, value in zip(meta_fields, metadata):
|
198
|
+
kwargs[name] = value
|
199
|
+
for name, value in zip(pytree_fields, pytree_data):
|
200
|
+
kwargs[name] = value
|
201
|
+
return cls(**kwargs)
|
202
|
+
|
203
|
+
# Use new API if available, otherwise fall back
|
204
|
+
if hasattr(jax.tree_util, 'register_dataclass'):
|
205
|
+
jax.tree_util.register_dataclass(cls, pytree_fields, meta_fields)
|
206
|
+
else:
|
207
|
+
jax.tree_util.register_pytree_with_keys(
|
208
|
+
cls,
|
209
|
+
flatten_with_keys_fn,
|
210
|
+
unflatten_fn,
|
211
|
+
flatten_fn
|
212
|
+
)
|
213
|
+
|
214
|
+
|
215
|
+
@dataclass_transform(field_specifiers=(field,))
class PyTreeNode:
    """
    Base class for creating JAX-compatible pytree nodes.

    Subclasses of PyTreeNode are automatically converted to immutable
    dataclasses that work with JAX transformations.

    See Also
    --------
    dataclass : Decorator for creating JAX-compatible dataclasses.
    field : Create dataclass fields with pytree metadata.

    Notes
    -----
    When subclassing PyTreeNode, all fields are automatically treated as
    part of the pytree unless explicitly marked with ``pytree_node=False``
    using the field() function. Class keyword arguments (for example
    ``class Node(PyTreeNode, frozen=False)``) are forwarded to ``dataclass``.

    Examples
    --------
    .. code-block:: python

        >>> import jax
        >>> import jax.numpy as jnp
        >>> from brainstate.util import PyTreeNode, field

        >>> class Layer(PyTreeNode):
        ...     weights: jax.Array
        ...     bias: jax.Array
        ...     activation: str = field(pytree_node=False, default="relu")

        >>> layer = Layer(weights=jnp.ones((4, 4)), bias=jnp.zeros(4))

        >>> # Can be used in JAX transformations
        >>> def loss_fn(layer):
        ...     return jnp.sum(layer.weights ** 2)
        >>> grad_fn = jax.grad(loss_fn)
        >>> grads = grad_fn(layer)

        >>> # Create modified copies with replace
        >>> layer2 = layer.replace(bias=jnp.ones(4))
    """

    def __init_subclass__(cls, **kwargs):
        """Automatically apply the dataclass decorator to subclasses."""
        # Class keyword arguments are consumed by the dataclass conversion.
        dataclass(cls, **kwargs)
        # Keep cooperative inheritance working: mixins later in the MRO may
        # define their own ``__init_subclass__`` hook.
        super().__init_subclass__()

    def __init__(self, *args, **kwargs):
        """Stub for type checkers; subclasses get a generated ``__init__``."""
        raise NotImplementedError("PyTreeNode is a base class")

    def replace(self: TNode, **updates) -> TNode:
        """
        Replace specified fields with new values.

        Parameters
        ----------
        **updates
            Field names and their new values.

        Returns
        -------
        TNode
            A new instance with updated fields.
        """
        raise NotImplementedError("Implemented by dataclass decorator")
|
282
|
+
|
283
|
+
|
284
|
+
@jax.tree_util.register_pytree_with_keys_class
class FrozenDict(Mapping[K, V], Generic[K, V]):
    """
    An immutable dictionary that works as a JAX pytree.

    FrozenDict provides an immutable mapping interface that can be used
    safely with JAX transformations. It supports all standard dictionary
    operations in an immutable fashion.

    Parameters
    ----------
    *args
        Positional arguments for dict construction.
    **kwargs
        Keyword arguments for dict construction.

    Attributes
    ----------
    _data : dict
        Internal storage; nested dictionaries are kept as plain dicts and
        re-wrapped as FrozenDict on access.
    _hash : int or None
        Cached hash value (computed lazily).

    See Also
    --------
    freeze : Convert a mapping to a FrozenDict.
    unfreeze : Convert a FrozenDict to a regular dict.

    Notes
    -----
    FrozenDict is immutable - all operations that would modify the dictionary
    instead return a new FrozenDict instance with the changes applied.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict

        >>> # Create a FrozenDict
        >>> fd = FrozenDict({'a': 1, 'b': 2})
        >>> fd['a']
        1

        >>> # Copy with updates (returns new FrozenDict)
        >>> new_fd = fd.copy({'c': 3})
        >>> new_fd['c']
        3

        >>> # Pop an item (returns new dict and popped value)
        >>> new_fd, value = fd.pop('b')
        >>> value
        2
        >>> 'b' in new_fd
        False

        >>> # Nested dictionaries are automatically frozen
        >>> fd = FrozenDict({'x': {'y': 1}})
        >>> isinstance(fd['x'], FrozenDict)
        True
    """

    __slots__ = ('_data', '_hash')

    def __init__(self, *args, **kwargs):
        """Initialize a FrozenDict from the same arguments ``dict()`` accepts."""
        data = dict(*args, **kwargs)
        self._data = self._deep_freeze(data)
        self._hash = None

    @staticmethod
    def _deep_freeze(obj: Any) -> Any:
        """Normalize nested dicts/FrozenDicts into plain-dict internal storage."""
        if isinstance(obj, FrozenDict):
            return obj._data
        elif isinstance(obj, dict):
            return {k: FrozenDict._deep_freeze(v) for k, v in obj.items()}
        else:
            return obj

    def __getitem__(self, key: K) -> V:
        """Get an item; nested dictionaries are returned as FrozenDict."""
        value = self._data[key]
        if isinstance(value, dict):
            return FrozenDict(value)
        return value

    def __setitem__(self, key: K, value: V) -> None:
        """Raise an error - FrozenDict is immutable."""
        raise TypeError("FrozenDict does not support item assignment")

    def __delitem__(self, key: K) -> None:
        """Raise an error - FrozenDict is immutable."""
        raise TypeError("FrozenDict does not support item deletion")

    def __contains__(self, key: object) -> bool:
        """Check if a key is in the dictionary."""
        return key in self._data

    def __iter__(self) -> Iterator[K]:
        """Iterate over keys."""
        return iter(self._data)

    def __len__(self) -> int:
        """Return the number of items."""
        return len(self._data)

    def __repr__(self) -> str:
        """Return a string representation."""
        return self.pretty_repr()

    def __hash__(self) -> int:
        """
        Return a hash of the dictionary.

        The hash is independent of insertion order and is cached after the
        first call. All values must be hashable; nested dictionaries are
        hashed as FrozenDicts.
        """
        if self._hash is None:
            # A frozenset hash is order-independent without sorting, so keys
            # of mutually unorderable types (e.g. ``1`` and ``'a'``) work.
            # ``items()`` already yields nested dicts as FrozenDict.
            self._hash = hash(frozenset(self.items()))
        return self._hash

    def __eq__(self, other: object) -> bool:
        """Check equality with another FrozenDict or dict."""
        if not isinstance(other, (FrozenDict, dict)):
            return NotImplemented
        if isinstance(other, FrozenDict):
            return self._data == other._data
        return self._data == other

    def __reduce__(self):
        """Support for pickling (preserves the concrete subclass)."""
        return type(self), (self.unfreeze(),)

    def keys(self) -> KeysView[K]:
        """
        Return a view of the keys.

        Returns
        -------
        KeysView
            A view object of the dictionary's keys.
        """
        return FrozenKeysView(self)

    def values(self) -> ValuesView[V]:
        """
        Return a view of the values.

        Returns
        -------
        ValuesView
            A view object of the dictionary's values.
        """
        return FrozenValuesView(self)

    def items(self) -> ItemsView[K, V]:
        """
        Return a view of the items.

        Returns
        -------
        ItemsView
            A sized, re-iterable view of ``(key, value)`` pairs. Nested
            dictionaries are yielded as FrozenDict instances.
        """
        # A real view (rather than a one-shot generator) supports len(),
        # repeated iteration and non-consuming membership tests.
        return ItemsView(self)

    def get(self, key: K, default: V | None = None) -> V | None:
        """
        Get a value with a default.

        Parameters
        ----------
        key : K
            The key to look up.
        default : V or None, optional
            The default value to return if key is not found.

        Returns
        -------
        V or None
            The value associated with the key, or default.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def copy(self, add_or_replace: Mapping[K, V] | None = None) -> FrozenDict[K, V]:
        """
        Create a new FrozenDict with added or replaced entries.

        Parameters
        ----------
        add_or_replace : Mapping or None, optional
            Entries to add or replace in the new dictionary.

        Returns
        -------
        FrozenDict
            A new FrozenDict with the updates applied.

        Examples
        --------
        .. code-block:: python

            >>> fd = FrozenDict({'a': 1, 'b': 2})
            >>> fd2 = fd.copy({'b': 3, 'c': 4})
            >>> fd2['b'], fd2['c']
            (3, 4)
        """
        if add_or_replace is None:
            add_or_replace = {}
        new_data = dict(self._data)
        new_data.update(add_or_replace)
        return type(self)(new_data)

    def pop(self, key: K) -> tuple[FrozenDict[K, V], V]:
        """
        Create a new FrozenDict with one entry removed.

        Parameters
        ----------
        key : K
            The key to remove.

        Returns
        -------
        tuple
            A tuple of (new FrozenDict without the key, removed value).

        Raises
        ------
        KeyError
            If the key is not found in the dictionary.

        Examples
        --------
        .. code-block:: python

            >>> fd = FrozenDict({'a': 1, 'b': 2})
            >>> fd2, value = fd.pop('a')
            >>> value
            1
            >>> 'a' in fd2
            False
        """
        if key not in self._data:
            raise KeyError(key)
        value = self[key]
        new_data = dict(self._data)
        del new_data[key]
        return type(self)(new_data), value

    def unfreeze(self) -> dict[K, V]:
        """
        Convert to a regular mutable dictionary.

        Returns
        -------
        dict
            A mutable dict with the same contents.

        Examples
        --------
        .. code-block:: python

            >>> fd = FrozenDict({'a': 1, 'b': {'c': 2}})
            >>> d = fd.unfreeze()
            >>> isinstance(d, dict)
            True
            >>> isinstance(d['b'], dict)  # Nested dicts also unfrozen
            True
        """
        return unfreeze(self)

    def pretty_repr(self, indent: int = 2) -> str:
        """
        Return a pretty-printed representation.

        Parameters
        ----------
        indent : int, optional
            Number of spaces per indentation level (default 2).

        Returns
        -------
        str
            A formatted string representation of the FrozenDict.
        """

        def format_value(v, level):
            if isinstance(v, dict):
                if not v:
                    return '{}'
                items = []
                for k, val in v.items():
                    formatted_val = format_value(val, level + 1)
                    items.append(f'{" " * (level + 1) * indent}{k!r}: {formatted_val}')
                return '{\n' + ',\n'.join(items) + f'\n{" " * level * indent}}}'
            else:
                return repr(v)

        if not self._data:
            return 'FrozenDict({})'

        return f'FrozenDict({format_value(self._data, 0)})'

    def tree_flatten_with_keys(self) -> tuple[list[tuple[Any, Any]], tuple[Any, ...]]:
        """Flatten for JAX pytree with keys (keys sorted for a deterministic order)."""
        sorted_keys = sorted(self._data.keys())
        values_with_keys = [
            (jax.tree_util.DictKey(k), self._data[k])
            for k in sorted_keys
        ]
        return values_with_keys, tuple(sorted_keys)

    @classmethod
    def tree_unflatten(cls, keys: tuple[Any, ...], values: list[Any]) -> FrozenDict:
        """Unflatten from JAX pytree."""
        return cls(dict(zip(keys, values)))
|
606
|
+
|
607
|
+
|
608
|
+
class FrozenKeysView(KeysView[K]):
    """Keys view returned by ``FrozenDict.keys()`` with a FrozenDict-flavoured repr."""

    def __repr__(self) -> str:
        # Renders as e.g. FrozenDict.keys(['a', 'b'])
        return 'FrozenDict.keys(' + repr(list(self)) + ')'
|
613
|
+
|
614
|
+
|
615
|
+
class FrozenValuesView(ValuesView[V]):
    """Values view returned by ``FrozenDict.values()`` with a FrozenDict-flavoured repr."""

    def __repr__(self) -> str:
        # Renders as e.g. FrozenDict.values([1, 2])
        return 'FrozenDict.values(' + repr(list(self)) + ')'
|
620
|
+
|
621
|
+
|
622
|
+
def freeze(x: Mapping[K, V]) -> FrozenDict[K, V]:
    """
    Turn a mapping into a FrozenDict.

    A FrozenDict argument is returned unchanged (same object, no copy);
    any other mapping is passed to the FrozenDict constructor.

    Parameters
    ----------
    x : Mapping
        A mapping (dict, FrozenDict, etc.) to freeze.

    Returns
    -------
    FrozenDict
        An immutable FrozenDict.

    See Also
    --------
    unfreeze : Convert a FrozenDict to a regular dict.
    FrozenDict : The immutable dictionary class.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import freeze

        >>> d = {'a': 1, 'b': {'c': 2}}
        >>> fd = freeze(d)
        >>> isinstance(fd, FrozenDict)
        True
        >>> isinstance(fd['b'], FrozenDict)  # Nested dicts are frozen
        True
    """
    return x if isinstance(x, FrozenDict) else FrozenDict(x)
|
657
|
+
|
658
|
+
|
659
|
+
def unfreeze(x: FrozenDict[K, V] | dict[K, V]) -> dict[K, V]:
    """
    Recursively convert FrozenDicts (and dicts) into fresh mutable dicts.

    Values that are FrozenDict or dict are converted at every nesting
    level; any other value is returned as-is.

    Parameters
    ----------
    x : FrozenDict or dict
        A FrozenDict or dict to unfreeze.

    Returns
    -------
    dict
        A mutable dictionary (or ``x`` itself when it is neither a
        FrozenDict nor a dict).

    See Also
    --------
    freeze : Convert a mapping to a FrozenDict.
    FrozenDict : The immutable dictionary class.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict, unfreeze

        >>> fd = FrozenDict({'a': 1, 'b': {'c': 2}})
        >>> d = unfreeze(fd)
        >>> isinstance(d, dict)
        True
        >>> isinstance(d['b'], dict)  # Nested FrozenDicts are unfrozen
        True
        >>> d['a'] = 10  # Can modify the result
    """
    if isinstance(x, FrozenDict):
        # Read the internal storage directly to avoid re-wrapping nested dicts.
        source = x._data
    elif isinstance(x, dict):
        source = x
    else:
        return x
    return {key: unfreeze(value) for key, value in source.items()}
|
706
|
+
|
707
|
+
|
708
|
+
@overload
def copy(x: FrozenDict[K, V], add_or_replace: Mapping[K, V] | None = None) -> FrozenDict[K, V]:
    ...


@overload
def copy(x: dict[K, V], add_or_replace: Mapping[K, V] | None = None) -> dict[K, V]:
    ...


def copy(x, add_or_replace=None):
    """
    Return a copy of a dictionary, optionally with extra or replaced entries.

    The result has the same kind as the input: a FrozenDict for a
    FrozenDict, a plain (shallow-copied) dict for a dict. The input itself
    is never modified.

    Parameters
    ----------
    x : FrozenDict or dict
        Dictionary to copy.
    add_or_replace : Mapping or None, optional
        Entries to add or replace in the copy.

    Returns
    -------
    FrozenDict or dict
        A copy of the same type as the input with updates applied.

    Raises
    ------
    TypeError
        If x is not a FrozenDict or dict.

    See Also
    --------
    FrozenDict.copy : Copy method for FrozenDict.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict, copy

        >>> # Works with FrozenDict
        >>> fd = FrozenDict({'a': 1})
        >>> fd2 = copy(fd, {'b': 2})
        >>> isinstance(fd2, FrozenDict)
        True
        >>> fd2['b']
        2

        >>> # Also works with regular dict
        >>> d = {'a': 1}
        >>> d2 = copy(d, {'b': 2})
        >>> isinstance(d2, dict)
        True
        >>> d2['b']
        2
    """
    updates = {} if add_or_replace is None else add_or_replace
    if isinstance(x, FrozenDict):
        return x.copy(updates)
    if not isinstance(x, dict):
        raise TypeError(f"Expected FrozenDict or dict, got {type(x)}")
    duplicate = dict(x)
    duplicate.update(updates)
    return duplicate
|
778
|
+
|
779
|
+
|
780
|
+
@overload
def pop(x: FrozenDict[K, V], key: K) -> tuple[FrozenDict[K, V], V]:
    ...


@overload
def pop(x: dict[K, V], key: K) -> tuple[dict[K, V], V]:
    ...


def pop(x, key):
    """
    Remove an item from a dictionary without mutating the input.

    Works with both FrozenDict and regular dict. The input is left
    untouched; a new dictionary of the same kind, minus ``key``, is
    returned together with the removed value.

    Parameters
    ----------
    x : FrozenDict or dict
        Dictionary to pop from.
    key : hashable
        Key to remove.

    Returns
    -------
    tuple
        ``(new_dictionary, popped_value)``.

    Raises
    ------
    TypeError
        If x is not a FrozenDict or dict.
    KeyError
        If the key is not found in the dictionary.

    See Also
    --------
    FrozenDict.pop : Pop method for FrozenDict.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import FrozenDict, pop

        >>> # Works with FrozenDict
        >>> fd = FrozenDict({'a': 1, 'b': 2})
        >>> fd2, value = pop(fd, 'a')
        >>> value
        1
        >>> 'a' in fd2
        False

        >>> # Also works with regular dict
        >>> d = {'a': 1, 'b': 2}
        >>> d2, value = pop(d, 'a')
        >>> value
        1
        >>> 'a' in d2
        False
    """
    if isinstance(x, FrozenDict):
        # FrozenDict.pop already returns (new FrozenDict, value).
        return x.pop(key)
    if not isinstance(x, dict):
        raise TypeError(f"Expected FrozenDict or dict, got {type(x)}")
    # Work on a shallow copy so the caller's dict is not modified.
    remaining = dict(x)
    popped = remaining.pop(key)
    return remaining, popped
|
850
|
+
|
851
|
+
|
852
|
+
def pretty_repr(x: Any, indent: int = 2) -> str:
    """
    Build a human-readable, indented string for a (possibly nested) dict.

    FrozenDict inputs delegate to ``FrozenDict.pretty_repr``. Plain dicts
    are rendered one entry per line, recursing into nested plain dicts.
    Anything else falls back to ``repr``.

    Parameters
    ----------
    x : any
        Object to represent.
    indent : int, optional
        Number of spaces per indentation level (default 2).

    Returns
    -------
    str
        A formatted string representation.

    See Also
    --------
    FrozenDict.pretty_repr : Pretty representation for FrozenDict.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util import pretty_repr

        >>> d = {'a': 1, 'b': {'c': 2, 'd': 3}}
        >>> print(pretty_repr(d))
        {
          'a': 1,
          'b': {
            'c': 2,
            'd': 3
          }
        }

        >>> # Non-dict objects return normal repr
        >>> pretty_repr([1, 2, 3])
        '[1, 2, 3]'
    """
    if isinstance(x, FrozenDict):
        return x.pretty_repr(indent)
    if not isinstance(x, dict):
        return repr(x)

    def render(mapping, depth):
        # Empty mappings stay on a single line.
        if not mapping:
            return '{}'
        entry_pad = ' ' * ((depth + 1) * indent)
        entries = [
            f'{entry_pad}{key!r}: '
            + (render(val, depth + 1) if isinstance(val, dict) else repr(val))
            for key, val in mapping.items()
        ]
        closing_pad = ' ' * (depth * indent)
        return '{\n' + ',\n'.join(entries) + '\n' + closing_pad + '}'

    return render(x, 0)
|