brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/util/struct.py
CHANGED
@@ -15,509 +15,896 @@
|
|
15
15
|
# See the License for the specific language governing permissions and
|
16
16
|
# limitations under the License.
|
17
17
|
|
18
|
-
"""
|
18
|
+
"""
|
19
|
+
Custom data structures that work seamlessly with JAX transformations.
|
20
|
+
"""
|
21
|
+
|
22
|
+
from __future__ import annotations
|
19
23
|
|
20
|
-
import collections
|
21
24
|
import dataclasses
|
22
|
-
from collections.abc import
|
23
|
-
from
|
24
|
-
from typing import Any, TypeVar
|
25
|
+
from collections.abc import Mapping, KeysView, ValuesView, ItemsView
|
26
|
+
from typing import Any, TypeVar, Generic, Iterator, overload
|
25
27
|
|
26
28
|
import jax
|
27
|
-
|
29
|
+
import jax.tree_util
|
30
|
+
from typing_extensions import dataclass_transform
|
28
31
|
|
29
32
|
__all__ = [
|
30
|
-
'dataclass',
|
31
33
|
'field',
|
34
|
+
'dataclass',
|
32
35
|
'PyTreeNode',
|
33
36
|
'FrozenDict',
|
37
|
+
'freeze',
|
38
|
+
'unfreeze',
|
39
|
+
'copy',
|
40
|
+
'pop',
|
41
|
+
'pretty_repr',
|
34
42
|
]
|
35
43
|
|
44
|
+
# Type variables
|
36
45
|
K = TypeVar('K')
|
37
46
|
V = TypeVar('V')
|
38
47
|
T = TypeVar('T')
|
48
|
+
TNode = TypeVar('TNode', bound='PyTreeNode')
|
39
49
|
|
40
50
|
|
41
|
-
def field(pytree_node
|
42
|
-
|
51
|
+
def field(pytree_node: bool = True, **kwargs) -> dataclasses.Field:
|
52
|
+
"""
|
53
|
+
Create a dataclass field with JAX pytree metadata.
|
54
|
+
|
55
|
+
Parameters
|
56
|
+
----------
|
57
|
+
pytree_node : bool, optional
|
58
|
+
If True (default), this field will be treated as part of the pytree.
|
59
|
+
If False, it will be treated as metadata and not be touched
|
60
|
+
by JAX transformations.
|
61
|
+
**kwargs
|
62
|
+
Additional arguments to pass to dataclasses.field().
|
63
|
+
|
64
|
+
Returns
|
65
|
+
-------
|
66
|
+
dataclasses.Field
|
67
|
+
A dataclass field with the appropriate metadata.
|
68
|
+
|
69
|
+
Examples
|
70
|
+
--------
|
71
|
+
.. code-block:: python
|
72
|
+
|
73
|
+
>>> import jax.numpy as jnp
|
74
|
+
>>> from brainstate.util import dataclass, field
|
75
|
+
|
76
|
+
>>> @dataclass
|
77
|
+
... class Model:
|
78
|
+
... weights: jnp.ndarray
|
79
|
+
... bias: jnp.ndarray
|
80
|
+
... # This field won't be affected by JAX transformations
|
81
|
+
... name: str = field(pytree_node=False, default="model")
|
82
|
+
"""
|
83
|
+
metadata = kwargs.pop('metadata', {})
|
84
|
+
metadata['pytree_node'] = pytree_node
|
85
|
+
return dataclasses.field(metadata=metadata, **kwargs)
|
43
86
|
|
44
87
|
|
45
|
-
@dataclass_transform(field_specifiers=(field,))
|
46
|
-
def dataclass(
|
88
|
+
@dataclass_transform(field_specifiers=(field,))
|
89
|
+
def dataclass(cls: type[T], **kwargs) -> type[T]:
|
47
90
|
"""
|
48
|
-
Create a
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
>>> @brainstate.util.dataclass
|
101
|
-
... class DirectionAndScaleKernel:
|
102
|
-
... direction: jax.Array
|
103
|
-
... scale: jax.Array
|
104
|
-
|
105
|
-
... @classmethod
|
106
|
-
... def create(cls, kernel):
|
107
|
-
... scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True)
|
108
|
-
... direction = direction / scale
|
109
|
-
... return cls(direction, scale)
|
110
|
-
|
111
|
-
Args:
|
112
|
-
clz: the class that will be transformed by the decorator.
|
113
|
-
Returns:
|
114
|
-
The new class.
|
91
|
+
Create a dataclass that works with JAX transformations.
|
92
|
+
|
93
|
+
This decorator creates immutable dataclasses that can be used safely
|
94
|
+
with JAX transformations like jit, grad, vmap, etc. The created class
|
95
|
+
will be registered as a JAX pytree node.
|
96
|
+
|
97
|
+
Parameters
|
98
|
+
----------
|
99
|
+
cls : type
|
100
|
+
The class to decorate.
|
101
|
+
**kwargs
|
102
|
+
Additional arguments for dataclasses.dataclass().
|
103
|
+
If 'frozen' is not specified, it defaults to True.
|
104
|
+
|
105
|
+
Returns
|
106
|
+
-------
|
107
|
+
type
|
108
|
+
The decorated class as an immutable JAX-compatible dataclass.
|
109
|
+
|
110
|
+
See Also
|
111
|
+
--------
|
112
|
+
PyTreeNode : Base class for creating JAX-compatible pytree nodes.
|
113
|
+
field : Create dataclass fields with pytree metadata.
|
114
|
+
|
115
|
+
Notes
|
116
|
+
-----
|
117
|
+
The decorated class will be frozen (immutable) by default to ensure
|
118
|
+
compatibility with JAX's functional programming paradigm.
|
119
|
+
|
120
|
+
Examples
|
121
|
+
--------
|
122
|
+
.. code-block:: python
|
123
|
+
|
124
|
+
>>> import jax
|
125
|
+
>>> import jax.numpy as jnp
|
126
|
+
>>> from brainstate.util import dataclass, field
|
127
|
+
|
128
|
+
>>> @dataclass
|
129
|
+
... class Model:
|
130
|
+
... weights: jax.Array
|
131
|
+
... bias: jax.Array
|
132
|
+
... name: str = field(pytree_node=False, default="model")
|
133
|
+
|
134
|
+
>>> model = Model(weights=jnp.ones((3, 3)), bias=jnp.zeros(3))
|
135
|
+
|
136
|
+
>>> # JAX transformations will only apply to weights and bias, not name
|
137
|
+
>>> grad_fn = jax.grad(lambda m: jnp.sum(m.weights))
|
138
|
+
>>> grads = grad_fn(model)
|
139
|
+
|
140
|
+
>>> # Use replace to create modified copies
|
141
|
+
>>> model2 = model.replace(weights=jnp.ones((3, 3)) * 2)
|
115
142
|
"""
|
116
|
-
#
|
117
|
-
if '_brainstate_dataclass'
|
118
|
-
return
|
143
|
+
# Check if already converted
|
144
|
+
if hasattr(cls, '_brainstate_dataclass'):
|
145
|
+
return cls
|
146
|
+
|
147
|
+
# Default to frozen for immutability
|
148
|
+
kwargs.setdefault('frozen', True)
|
119
149
|
|
120
|
-
|
121
|
-
|
122
|
-
|
150
|
+
# Apply standard dataclass decorator
|
151
|
+
cls = dataclasses.dataclass(**kwargs)(cls)
|
152
|
+
|
153
|
+
# Separate fields into pytree and metadata
|
154
|
+
pytree_fields = []
|
123
155
|
meta_fields = []
|
124
|
-
|
125
|
-
for field_info in dataclasses.fields(
|
126
|
-
|
127
|
-
|
128
|
-
data_fields.append(field_info.name)
|
156
|
+
|
157
|
+
for field_info in dataclasses.fields(cls):
|
158
|
+
if field_info.metadata.get('pytree_node', True):
|
159
|
+
pytree_fields.append(field_info.name)
|
129
160
|
else:
|
130
161
|
meta_fields.append(field_info.name)
|
131
162
|
|
132
|
-
|
133
|
-
|
163
|
+
# Add replace method
|
164
|
+
def replace(self: T, **updates) -> T:
|
165
|
+
"""Replace specified fields with new values."""
|
134
166
|
return dataclasses.replace(self, **updates)
|
135
167
|
|
136
|
-
|
168
|
+
cls.replace = replace
|
137
169
|
|
138
|
-
#
|
139
|
-
|
140
|
-
if hasattr(jax.tree_util, 'register_dataclass'):
|
141
|
-
jax.tree_util.register_dataclass(
|
142
|
-
data_clz, data_fields, meta_fields
|
143
|
-
)
|
144
|
-
else:
|
145
|
-
raise NotImplementedError
|
146
|
-
except NotImplementedError:
|
147
|
-
|
148
|
-
def iterate_clz(x):
|
149
|
-
meta = tuple(getattr(x, name) for name in meta_fields)
|
150
|
-
data = tuple(getattr(x, name) for name in data_fields)
|
151
|
-
return data, meta
|
152
|
-
|
153
|
-
def iterate_clz_with_keys(x):
|
154
|
-
meta = tuple(getattr(x, name) for name in meta_fields)
|
155
|
-
data = tuple(
|
156
|
-
(jax.tree_util.GetAttrKey(name), getattr(x, name))
|
157
|
-
for name in data_fields
|
158
|
-
)
|
159
|
-
return data, meta
|
160
|
-
|
161
|
-
def clz_from_iterable(meta, data):
|
162
|
-
meta_args = tuple(zip(meta_fields, meta))
|
163
|
-
data_args = tuple(zip(data_fields, data))
|
164
|
-
kwargs = dict(meta_args + data_args)
|
165
|
-
return data_clz(**kwargs)
|
170
|
+
# Register with JAX
|
171
|
+
_register_pytree(cls, pytree_fields, meta_fields)
|
166
172
|
|
167
|
-
|
168
|
-
|
169
|
-
iterate_clz_with_keys,
|
170
|
-
clz_from_iterable,
|
171
|
-
iterate_clz,
|
172
|
-
)
|
173
|
+
# Mark as BrainState dataclass
|
174
|
+
cls._brainstate_dataclass = True
|
173
175
|
|
174
|
-
|
175
|
-
data_clz._brainstate_dataclass = True # type: ignore[attr-defined]
|
176
|
+
return cls
|
176
177
|
|
177
|
-
return data_clz # type: ignore
|
178
178
|
|
179
|
+
def _register_pytree(cls: type, pytree_fields: list[str], meta_fields: list[str]) -> None:
|
180
|
+
"""Register a class as a JAX pytree."""
|
179
181
|
|
180
|
-
|
182
|
+
def flatten_fn(obj):
|
183
|
+
pytree_data = tuple(getattr(obj, name) for name in pytree_fields)
|
184
|
+
metadata = tuple(getattr(obj, name) for name in meta_fields)
|
185
|
+
return pytree_data, metadata
|
186
|
+
|
187
|
+
def flatten_with_keys_fn(obj):
|
188
|
+
pytree_data = tuple(
|
189
|
+
(jax.tree_util.GetAttrKey(name), getattr(obj, name))
|
190
|
+
for name in pytree_fields
|
191
|
+
)
|
192
|
+
metadata = tuple(getattr(obj, name) for name in meta_fields)
|
193
|
+
return pytree_data, metadata
|
194
|
+
|
195
|
+
def unflatten_fn(metadata, pytree_data):
|
196
|
+
kwargs = {}
|
197
|
+
for name, value in zip(meta_fields, metadata):
|
198
|
+
kwargs[name] = value
|
199
|
+
for name, value in zip(pytree_fields, pytree_data):
|
200
|
+
kwargs[name] = value
|
201
|
+
return cls(**kwargs)
|
202
|
+
|
203
|
+
# Use new API if available, otherwise fall back
|
204
|
+
if hasattr(jax.tree_util, 'register_dataclass'):
|
205
|
+
jax.tree_util.register_dataclass(cls, pytree_fields, meta_fields)
|
206
|
+
else:
|
207
|
+
jax.tree_util.register_pytree_with_keys(
|
208
|
+
cls,
|
209
|
+
flatten_with_keys_fn,
|
210
|
+
unflatten_fn,
|
211
|
+
flatten_fn
|
212
|
+
)
|
181
213
|
|
182
214
|
|
183
|
-
@dataclass_transform(field_specifiers=(field,))
|
215
|
+
@dataclass_transform(field_specifiers=(field,))
|
184
216
|
class PyTreeNode:
|
185
|
-
"""
|
217
|
+
"""
|
218
|
+
Base class for creating JAX-compatible pytree nodes.
|
186
219
|
|
187
|
-
|
188
|
-
|
220
|
+
Subclasses of PyTreeNode are automatically converted to immutable
|
221
|
+
dataclasses that work with JAX transformations.
|
189
222
|
|
190
|
-
|
223
|
+
See Also
|
224
|
+
--------
|
225
|
+
dataclass : Decorator for creating JAX-compatible dataclasses.
|
226
|
+
field : Create dataclass fields with pytree metadata.
|
191
227
|
|
192
|
-
|
193
|
-
|
194
|
-
|
228
|
+
Notes
|
229
|
+
-----
|
230
|
+
When subclassing PyTreeNode, all fields are automatically treated as
|
231
|
+
part of the pytree unless explicitly marked with ``pytree_node=False``
|
232
|
+
using the field() function.
|
195
233
|
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
... # by Jax transformations.
|
200
|
-
... apply_fn: Callable = brainstate.util.field(pytree_node=False)
|
234
|
+
Examples
|
235
|
+
--------
|
236
|
+
.. code-block:: python
|
201
237
|
|
202
|
-
|
203
|
-
|
238
|
+
>>> import jax
|
239
|
+
>>> import jax.numpy as jnp
|
240
|
+
>>> from brainstate.util import PyTreeNode, field
|
204
241
|
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
242
|
+
>>> class Layer(PyTreeNode):
|
243
|
+
... weights: jax.Array
|
244
|
+
... bias: jax.Array
|
245
|
+
... activation: str = field(pytree_node=False, default="relu")
|
209
246
|
|
210
|
-
|
211
|
-
>>> model_b = model.replace(params=params_b) # Use the replace method instead.
|
247
|
+
>>> layer = Layer(weights=jnp.ones((4, 4)), bias=jnp.zeros(4))
|
212
248
|
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
249
|
+
>>> # Can be used in JAX transformations
|
250
|
+
>>> def loss_fn(layer):
|
251
|
+
... return jnp.sum(layer.weights ** 2)
|
252
|
+
>>> grad_fn = jax.grad(loss_fn)
|
253
|
+
>>> grads = grad_fn(layer)
|
254
|
+
|
255
|
+
>>> # Create modified copies with replace
|
256
|
+
>>> layer2 = layer.replace(bias=jnp.ones(4))
|
218
257
|
"""
|
219
258
|
|
220
259
|
def __init_subclass__(cls, **kwargs):
|
221
|
-
dataclass
|
260
|
+
"""Automatically apply dataclass decorator to subclasses."""
|
261
|
+
dataclass(cls, **kwargs)
|
222
262
|
|
223
263
|
def __init__(self, *args, **kwargs):
|
224
|
-
|
225
|
-
raise NotImplementedError
|
264
|
+
"""Stub for type checkers."""
|
265
|
+
raise NotImplementedError("PyTreeNode is a base class")
|
226
266
|
|
227
|
-
def replace(self: TNode, **
|
228
|
-
|
229
|
-
|
267
|
+
def replace(self: TNode, **updates) -> TNode:
|
268
|
+
"""
|
269
|
+
Replace specified fields with new values.
|
230
270
|
|
271
|
+
Parameters
|
272
|
+
----------
|
273
|
+
**updates
|
274
|
+
Field names and their new values.
|
231
275
|
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
276
|
+
Returns
|
277
|
+
-------
|
278
|
+
TNode
|
279
|
+
A new instance with updated fields.
|
280
|
+
"""
|
281
|
+
raise NotImplementedError("Implemented by dataclass decorator")
|
238
282
|
|
239
283
|
|
240
284
|
@jax.tree_util.register_pytree_with_keys_class
|
241
|
-
class FrozenDict(Mapping[K, V]):
|
285
|
+
class FrozenDict(Mapping[K, V], Generic[K, V]):
|
242
286
|
"""
|
243
|
-
An immutable
|
287
|
+
An immutable dictionary that works as a JAX pytree.
|
288
|
+
|
289
|
+
FrozenDict provides an immutable mapping interface that can be used
|
290
|
+
safely with JAX transformations. It supports all standard dictionary
|
291
|
+
operations in an immutable fashion.
|
292
|
+
|
293
|
+
Parameters
|
294
|
+
----------
|
295
|
+
*args
|
296
|
+
Positional arguments for dict construction.
|
297
|
+
**kwargs
|
298
|
+
Keyword arguments for dict construction.
|
299
|
+
|
300
|
+
Attributes
|
301
|
+
----------
|
302
|
+
_data : dict
|
303
|
+
Internal immutable data storage.
|
304
|
+
_hash : int or None
|
305
|
+
Cached hash value.
|
306
|
+
|
307
|
+
See Also
|
308
|
+
--------
|
309
|
+
freeze : Convert a mapping to a FrozenDict.
|
310
|
+
unfreeze : Convert a FrozenDict to a regular dict.
|
311
|
+
|
312
|
+
Notes
|
313
|
+
-----
|
314
|
+
FrozenDict is immutable - all operations that would modify the dictionary
|
315
|
+
instead return a new FrozenDict instance with the changes applied.
|
316
|
+
|
317
|
+
Examples
|
318
|
+
--------
|
319
|
+
.. code-block:: python
|
320
|
+
|
321
|
+
>>> from brainstate.util import FrozenDict
|
322
|
+
|
323
|
+
>>> # Create a FrozenDict
|
324
|
+
>>> fd = FrozenDict({'a': 1, 'b': 2})
|
325
|
+
>>> fd['a']
|
326
|
+
1
|
327
|
+
|
328
|
+
>>> # Copy with updates (returns new FrozenDict)
|
329
|
+
>>> new_fd = fd.copy({'c': 3})
|
330
|
+
>>> new_fd['c']
|
331
|
+
3
|
332
|
+
|
333
|
+
>>> # Pop an item (returns new dict and popped value)
|
334
|
+
>>> new_fd, value = fd.pop('b')
|
335
|
+
>>> value
|
336
|
+
2
|
337
|
+
>>> 'b' in new_fd
|
338
|
+
False
|
339
|
+
|
340
|
+
>>> # Nested dictionaries are automatically frozen
|
341
|
+
>>> fd = FrozenDict({'x': {'y': 1}})
|
342
|
+
>>> isinstance(fd['x'], FrozenDict)
|
343
|
+
True
|
244
344
|
"""
|
245
345
|
|
246
|
-
__slots__ = ('
|
247
|
-
|
248
|
-
def __init__(self, *args, __unsafe_skip_copy__=False, **kwargs): # pylint: disable=invalid-name
|
249
|
-
# make sure the dict is as
|
250
|
-
xs = dict(*args, **kwargs)
|
251
|
-
if __unsafe_skip_copy__:
|
252
|
-
self._dict = xs
|
253
|
-
else:
|
254
|
-
self._dict = _prepare_freeze(xs)
|
346
|
+
__slots__ = ('_data', '_hash')
|
255
347
|
|
348
|
+
def __init__(self, *args, **kwargs):
|
349
|
+
"""Initialize a FrozenDict."""
|
350
|
+
data = dict(*args, **kwargs)
|
351
|
+
self._data = self._deep_freeze(data)
|
256
352
|
self._hash = None
|
257
353
|
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
354
|
+
@staticmethod
|
355
|
+
def _deep_freeze(obj: Any) -> Any:
|
356
|
+
"""Recursively freeze nested dictionaries."""
|
357
|
+
if isinstance(obj, FrozenDict):
|
358
|
+
return obj._data
|
359
|
+
elif isinstance(obj, dict):
|
360
|
+
return {k: FrozenDict._deep_freeze(v) for k, v in obj.items()}
|
361
|
+
else:
|
362
|
+
return obj
|
263
363
|
|
264
|
-
def
|
265
|
-
|
364
|
+
def __getitem__(self, key: K) -> V:
|
365
|
+
"""Get an item from the dictionary."""
|
366
|
+
value = self._data[key]
|
367
|
+
if isinstance(value, dict):
|
368
|
+
return FrozenDict(value)
|
369
|
+
return value
|
266
370
|
|
267
|
-
def
|
268
|
-
|
371
|
+
def __setitem__(self, key: K, value: V) -> None:
|
372
|
+
"""Raise an error - FrozenDict is immutable."""
|
373
|
+
raise TypeError("FrozenDict does not support item assignment")
|
269
374
|
|
270
|
-
def
|
271
|
-
|
375
|
+
def __delitem__(self, key: K) -> None:
|
376
|
+
"""Raise an error - FrozenDict is immutable."""
|
377
|
+
raise TypeError("FrozenDict does not support item deletion")
|
272
378
|
|
273
|
-
def
|
274
|
-
|
379
|
+
def __contains__(self, key: object) -> bool:
|
380
|
+
"""Check if a key is in the dictionary."""
|
381
|
+
return key in self._data
|
275
382
|
|
276
|
-
def
|
277
|
-
|
383
|
+
def __iter__(self) -> Iterator[K]:
|
384
|
+
"""Iterate over keys."""
|
385
|
+
return iter(self._data)
|
278
386
|
|
279
|
-
def
|
280
|
-
|
281
|
-
|
282
|
-
def pretty_repr(self, num_spaces=4):
|
283
|
-
"""Returns an indented representation of the nested dictionary."""
|
284
|
-
|
285
|
-
def pretty_dict(x):
|
286
|
-
if not isinstance(x, dict):
|
287
|
-
return repr(x)
|
288
|
-
rep = ''
|
289
|
-
for key, val in x.items():
|
290
|
-
rep += f'{key}: {pretty_dict(val)},\n'
|
291
|
-
if rep:
|
292
|
-
return '{\n' + _indent(rep, num_spaces) + '}'
|
293
|
-
else:
|
294
|
-
return '{}'
|
387
|
+
def __len__(self) -> int:
|
388
|
+
"""Return the number of items."""
|
389
|
+
return len(self._data)
|
295
390
|
|
296
|
-
|
391
|
+
def __repr__(self) -> str:
|
392
|
+
"""Return a string representation."""
|
393
|
+
return self.pretty_repr()
|
297
394
|
|
298
|
-
def __hash__(self):
|
395
|
+
def __hash__(self) -> int:
|
396
|
+
"""Return a hash of the dictionary."""
|
299
397
|
if self._hash is None:
|
300
|
-
|
398
|
+
items = []
|
301
399
|
for key, value in self.items():
|
302
|
-
|
303
|
-
|
400
|
+
if isinstance(value, dict):
|
401
|
+
value = FrozenDict(value)
|
402
|
+
items.append((key, value))
|
403
|
+
self._hash = hash(tuple(sorted(items)))
|
304
404
|
return self._hash
|
305
405
|
|
306
|
-
def
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
406
|
+
def __eq__(self, other: object) -> bool:
|
407
|
+
"""Check equality with another object."""
|
408
|
+
if not isinstance(other, (FrozenDict, dict)):
|
409
|
+
return NotImplemented
|
410
|
+
if isinstance(other, FrozenDict):
|
411
|
+
return self._data == other._data
|
412
|
+
return self._data == other
|
413
|
+
|
414
|
+
def __reduce__(self):
|
415
|
+
"""Support for pickling."""
|
416
|
+
return FrozenDict, (self.unfreeze(),)
|
417
|
+
|
418
|
+
def keys(self) -> KeysView[K]:
|
419
|
+
"""
|
420
|
+
Return a view of the keys.
|
312
421
|
|
313
|
-
|
422
|
+
Returns
|
423
|
+
-------
|
424
|
+
KeysView
|
425
|
+
A view object of the dictionary's keys.
|
426
|
+
"""
|
314
427
|
return FrozenKeysView(self)
|
315
428
|
|
316
|
-
def values(self):
|
429
|
+
def values(self) -> ValuesView[V]:
|
430
|
+
"""
|
431
|
+
Return a view of the values.
|
432
|
+
|
433
|
+
Returns
|
434
|
+
-------
|
435
|
+
ValuesView
|
436
|
+
A view object of the dictionary's values.
|
437
|
+
"""
|
317
438
|
return FrozenValuesView(self)
|
318
439
|
|
319
|
-
def items(self):
|
320
|
-
|
321
|
-
|
440
|
+
def items(self) -> ItemsView[K, V]:
|
441
|
+
"""
|
442
|
+
Return a view of the items.
|
322
443
|
|
323
|
-
|
324
|
-
|
444
|
+
Yields
|
445
|
+
------
|
446
|
+
tuple
|
447
|
+
Key-value pairs from the dictionary.
|
448
|
+
"""
|
449
|
+
for key in self._data:
|
450
|
+
yield (key, self[key])
|
325
451
|
|
326
|
-
|
452
|
+
def get(self, key: K, default: V | None = None) -> V | None:
|
453
|
+
"""
|
454
|
+
Get a value with a default.
|
455
|
+
|
456
|
+
Parameters
|
457
|
+
----------
|
458
|
+
key : K
|
459
|
+
The key to look up.
|
460
|
+
default : V or None, optional
|
461
|
+
The default value to return if key is not found.
|
462
|
+
|
463
|
+
Returns
|
464
|
+
-------
|
465
|
+
V or None
|
466
|
+
The value associated with the key, or default.
|
467
|
+
"""
|
468
|
+
try:
|
469
|
+
return self[key]
|
470
|
+
except KeyError:
|
471
|
+
return default
|
327
472
|
|
328
|
-
|
329
|
-
|
473
|
+
def copy(self, add_or_replace: Mapping[K, V] | None = None) -> FrozenDict[K, V]:
|
474
|
+
"""
|
475
|
+
Create a new FrozenDict with added or replaced entries.
|
476
|
+
|
477
|
+
Parameters
|
478
|
+
----------
|
479
|
+
add_or_replace : Mapping or None, optional
|
480
|
+
Entries to add or replace in the new dictionary.
|
481
|
+
|
482
|
+
Returns
|
483
|
+
-------
|
484
|
+
FrozenDict
|
485
|
+
A new FrozenDict with the updates applied.
|
486
|
+
|
487
|
+
Examples
|
488
|
+
--------
|
489
|
+
.. code-block:: python
|
490
|
+
|
491
|
+
>>> fd = FrozenDict({'a': 1, 'b': 2})
|
492
|
+
>>> fd2 = fd.copy({'b': 3, 'c': 4})
|
493
|
+
>>> fd2['b'], fd2['c']
|
494
|
+
(3, 4)
|
495
|
+
"""
|
496
|
+
if add_or_replace is None:
|
497
|
+
add_or_replace = {}
|
498
|
+
new_data = dict(self._data)
|
499
|
+
new_data.update(add_or_replace)
|
500
|
+
return type(self)(new_data)
|
330
501
|
|
331
|
-
|
332
|
-
key: the key to remove from the dict
|
333
|
-
Returns:
|
334
|
-
A pair with the new FrozenDict and the removed value.
|
502
|
+
def pop(self, key: K) -> tuple[FrozenDict[K, V], V]:
|
335
503
|
"""
|
504
|
+
Create a new FrozenDict with one entry removed.
|
505
|
+
|
506
|
+
Parameters
|
507
|
+
----------
|
508
|
+
key : K
|
509
|
+
The key to remove.
|
510
|
+
|
511
|
+
Returns
|
512
|
+
-------
|
513
|
+
tuple
|
514
|
+
A tuple of (new FrozenDict without the key, removed value).
|
515
|
+
|
516
|
+
Raises
|
517
|
+
------
|
518
|
+
KeyError
|
519
|
+
If the key is not found in the dictionary.
|
520
|
+
|
521
|
+
Examples
|
522
|
+
--------
|
523
|
+
.. code-block:: python
|
524
|
+
|
525
|
+
>>> fd = FrozenDict({'a': 1, 'b': 2})
|
526
|
+
>>> fd2, value = fd.pop('a')
|
527
|
+
>>> value
|
528
|
+
1
|
529
|
+
>>> 'a' in fd2
|
530
|
+
False
|
531
|
+
"""
|
532
|
+
if key not in self._data:
|
533
|
+
raise KeyError(key)
|
336
534
|
value = self[key]
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
return new_self, value
|
535
|
+
new_data = dict(self._data)
|
536
|
+
del new_data[key]
|
537
|
+
return type(self)(new_data), value
|
341
538
|
|
342
539
|
def unfreeze(self) -> dict[K, V]:
|
343
|
-
"""
|
344
|
-
|
345
|
-
|
346
|
-
|
540
|
+
"""
|
541
|
+
Convert to a regular mutable dictionary.
|
542
|
+
|
543
|
+
Returns
|
544
|
+
-------
|
545
|
+
dict
|
546
|
+
A mutable dict with the same contents.
|
547
|
+
|
548
|
+
Examples
|
549
|
+
--------
|
550
|
+
.. code-block:: python
|
551
|
+
|
552
|
+
>>> fd = FrozenDict({'a': 1, 'b': {'c': 2}})
|
553
|
+
>>> d = fd.unfreeze()
|
554
|
+
>>> isinstance(d, dict)
|
555
|
+
True
|
556
|
+
>>> isinstance(d['b'], dict) # Nested dicts also unfrozen
|
557
|
+
True
|
347
558
|
"""
|
348
559
|
return unfreeze(self)
|
349
560
|
|
350
|
-
def
|
351
|
-
"""
|
561
|
+
def pretty_repr(self, indent: int = 2) -> str:
|
562
|
+
"""
|
563
|
+
Return a pretty-printed representation.
|
352
564
|
|
353
|
-
|
354
|
-
|
565
|
+
Parameters
|
566
|
+
----------
|
567
|
+
indent : int, optional
|
568
|
+
Number of spaces per indentation level (default 2).
|
569
|
+
|
570
|
+
Returns
|
571
|
+
-------
|
572
|
+
str
|
573
|
+
A formatted string representation of the FrozenDict.
|
355
574
|
"""
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
575
|
+
|
576
|
+
def format_value(v, level):
|
577
|
+
if isinstance(v, dict):
|
578
|
+
if not v:
|
579
|
+
return '{}'
|
580
|
+
items = []
|
581
|
+
for k, val in v.items():
|
582
|
+
formatted_val = format_value(val, level + 1)
|
583
|
+
items.append(f'{" " * (level + 1) * indent}{k!r}: {formatted_val}')
|
584
|
+
return '{\n' + ',\n'.join(items) + f'\n{" " * level * indent}}}'
|
585
|
+
else:
|
586
|
+
return repr(v)
|
587
|
+
|
588
|
+
if not self._data:
|
589
|
+
return 'FrozenDict({})'
|
590
|
+
|
591
|
+
return f'FrozenDict({format_value(self._data, 0)})'
|
592
|
+
|
593
|
+
def tree_flatten_with_keys(self) -> tuple[list[tuple[Any, Any]], tuple[Any, ...]]:
|
594
|
+
"""Flatten for JAX pytree with keys."""
|
595
|
+
sorted_keys = sorted(self._data.keys())
|
596
|
+
values_with_keys = [
|
597
|
+
(jax.tree_util.DictKey(k), self._data[k])
|
598
|
+
for k in sorted_keys
|
599
|
+
]
|
600
|
+
return values_with_keys, tuple(sorted_keys)
|
360
601
|
|
361
602
|
@classmethod
|
362
|
-
def tree_unflatten(cls, keys, values):
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
"""Deep copy unfrozen dicts to make the dictionary FrozenDict safe."""
|
370
|
-
if isinstance(xs, FrozenDict):
|
371
|
-
# we can safely ref share the internal state of a FrozenDict
|
372
|
-
# because it is immutable.
|
373
|
-
return xs._dict # pylint: disable=protected-access
|
374
|
-
if not isinstance(xs, dict):
|
375
|
-
# return a leaf as is.
|
376
|
-
return xs
|
377
|
-
# recursively copy dictionary to avoid ref sharing
|
378
|
-
return {key: _prepare_freeze(val) for key, val in xs.items()}
|
379
|
-
|
380
|
-
|
381
|
-
def freeze(xs: Mapping[Any, Any]) -> FrozenDict[Any, Any]:
|
382
|
-
"""Freeze a nested dict.
|
383
|
-
|
384
|
-
Makes a nested ``dict`` immutable by transforming it into ``FrozenDict``.
|
385
|
-
|
386
|
-
Args:
|
387
|
-
xs: Dictionary to freeze (a regualr Python dict).
|
388
|
-
Returns:
|
389
|
-
The frozen dictionary.
|
390
|
-
"""
|
391
|
-
return FrozenDict(xs)
|
603
|
+
def tree_unflatten(cls, keys: tuple[Any, ...], values: list[Any]) -> FrozenDict:
|
604
|
+
"""Unflatten from JAX pytree."""
|
605
|
+
return cls(dict(zip(keys, values)))
|
606
|
+
|
607
|
+
|
608
|
+
class FrozenKeysView(KeysView[K]):
|
609
|
+
"""View of keys in a FrozenDict."""
|
392
610
|
|
611
|
+
def __repr__(self) -> str:
|
612
|
+
return f'FrozenDict.keys({list(self)})'
|
393
613
|
|
394
|
-
def unfreeze(x: FrozenDict | dict[str, Any]) -> dict[Any, Any]:
|
395
|
-
"""Unfreeze a FrozenDict.
|
396
614
|
|
397
|
-
|
398
|
-
|
615
|
+
class FrozenValuesView(ValuesView[V]):
|
616
|
+
"""View of values in a FrozenDict."""
|
399
617
|
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
618
|
+
def __repr__(self) -> str:
|
619
|
+
return f'FrozenDict.values({list(self)})'
|
620
|
+
|
621
|
+
|
622
|
+
def freeze(x: Mapping[K, V]) -> FrozenDict[K, V]:
|
623
|
+
"""
|
624
|
+
Convert a mapping to a FrozenDict.
|
625
|
+
|
626
|
+
Parameters
|
627
|
+
----------
|
628
|
+
x : Mapping
|
629
|
+
A mapping (dict, FrozenDict, etc.) to freeze.
|
630
|
+
|
631
|
+
Returns
|
632
|
+
-------
|
633
|
+
FrozenDict
|
634
|
+
An immutable FrozenDict.
|
635
|
+
|
636
|
+
See Also
|
637
|
+
--------
|
638
|
+
unfreeze : Convert a FrozenDict to a regular dict.
|
639
|
+
FrozenDict : The immutable dictionary class.
|
640
|
+
|
641
|
+
Examples
|
642
|
+
--------
|
643
|
+
.. code-block:: python
|
644
|
+
|
645
|
+
>>> from brainstate.util import freeze
|
646
|
+
|
647
|
+
>>> d = {'a': 1, 'b': {'c': 2}}
|
648
|
+
>>> fd = freeze(d)
|
649
|
+
>>> isinstance(fd, FrozenDict)
|
650
|
+
True
|
651
|
+
>>> isinstance(fd['b'], FrozenDict) # Nested dicts are frozen
|
652
|
+
True
|
653
|
+
"""
|
654
|
+
if isinstance(x, FrozenDict):
|
655
|
+
return x
|
656
|
+
return FrozenDict(x)
|
657
|
+
|
658
|
+
|
659
|
+
def unfreeze(x: FrozenDict[K, V] | dict[K, V]) -> dict[K, V]:
|
660
|
+
"""
|
661
|
+
Convert a FrozenDict to a regular dict.
|
662
|
+
|
663
|
+
Recursively converts FrozenDict instances to mutable dicts.
|
664
|
+
|
665
|
+
Parameters
|
666
|
+
----------
|
667
|
+
x : FrozenDict or dict
|
668
|
+
A FrozenDict or dict to unfreeze.
|
669
|
+
|
670
|
+
Returns
|
671
|
+
-------
|
672
|
+
dict
|
673
|
+
A mutable dictionary.
|
674
|
+
|
675
|
+
See Also
|
676
|
+
--------
|
677
|
+
freeze : Convert a mapping to a FrozenDict.
|
678
|
+
FrozenDict : The immutable dictionary class.
|
679
|
+
|
680
|
+
Examples
|
681
|
+
--------
|
682
|
+
.. code-block:: python
|
683
|
+
|
684
|
+
>>> from brainstate.util import FrozenDict, unfreeze
|
685
|
+
|
686
|
+
>>> fd = FrozenDict({'a': 1, 'b': {'c': 2}})
|
687
|
+
>>> d = unfreeze(fd)
|
688
|
+
>>> isinstance(d, dict)
|
689
|
+
True
|
690
|
+
>>> isinstance(d['b'], dict) # Nested FrozenDicts are unfrozen
|
691
|
+
True
|
692
|
+
>>> d['a'] = 10 # Can modify the result
|
404
693
|
"""
|
405
694
|
if isinstance(x, FrozenDict):
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
return jax.tree_util.tree_map(lambda y: y, x._dict) # type: ignore
|
695
|
+
result = {}
|
696
|
+
for key, value in x._data.items():
|
697
|
+
result[key] = unfreeze(value)
|
698
|
+
return result
|
411
699
|
elif isinstance(x, dict):
|
412
|
-
|
700
|
+
result = {}
|
413
701
|
for key, value in x.items():
|
414
|
-
|
415
|
-
return
|
702
|
+
result[key] = unfreeze(value)
|
703
|
+
return result
|
416
704
|
else:
|
417
705
|
return x
|
418
706
|
|
419
707
|
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
) -> FrozenDict | dict[str, Any]:
|
424
|
-
"""Create a new dict with additional and/or replaced entries. This is a utility
|
425
|
-
function that can act on either a FrozenDict or regular dict and mimics the
|
426
|
-
behavior of ``FrozenDict.copy``.
|
708
|
+
@overload
|
709
|
+
def copy(x: FrozenDict[K, V], add_or_replace: Mapping[K, V] | None = None) -> FrozenDict[K, V]:
|
710
|
+
...
|
427
711
|
|
428
|
-
Example::
|
429
712
|
|
430
|
-
|
431
|
-
|
713
|
+
@overload
|
714
|
+
def copy(x: dict[K, V], add_or_replace: Mapping[K, V] | None = None) -> dict[K, V]:
|
715
|
+
...
|
432
716
|
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
717
|
+
|
718
|
+
def copy(x, add_or_replace=None):
|
719
|
+
"""
|
720
|
+
Copy a dictionary with optional updates.
|
721
|
+
|
722
|
+
Works with both FrozenDict and regular dict.
|
723
|
+
|
724
|
+
Parameters
|
725
|
+
----------
|
726
|
+
x : FrozenDict or dict
|
727
|
+
Dictionary to copy.
|
728
|
+
add_or_replace : Mapping or None, optional
|
729
|
+
Entries to add or replace in the copy.
|
730
|
+
|
731
|
+
Returns
|
732
|
+
-------
|
733
|
+
FrozenDict or dict
|
734
|
+
A copy of the same type as the input with updates applied.
|
735
|
+
|
736
|
+
Raises
|
737
|
+
------
|
738
|
+
TypeError
|
739
|
+
If x is not a FrozenDict or dict.
|
740
|
+
|
741
|
+
See Also
|
742
|
+
--------
|
743
|
+
FrozenDict.copy : Copy method for FrozenDict.
|
744
|
+
|
745
|
+
Examples
|
746
|
+
--------
|
747
|
+
.. code-block:: python
|
748
|
+
|
749
|
+
>>> from brainstate.util import FrozenDict, copy
|
750
|
+
|
751
|
+
>>> # Works with FrozenDict
|
752
|
+
>>> fd = FrozenDict({'a': 1})
|
753
|
+
>>> fd2 = copy(fd, {'b': 2})
|
754
|
+
>>> isinstance(fd2, FrozenDict)
|
755
|
+
True
|
756
|
+
>>> fd2['b']
|
757
|
+
2
|
758
|
+
|
759
|
+
>>> # Also works with regular dict
|
760
|
+
>>> d = {'a': 1}
|
761
|
+
>>> d2 = copy(d, {'b': 2})
|
762
|
+
>>> isinstance(d2, dict)
|
763
|
+
True
|
764
|
+
>>> d2['b']
|
765
|
+
2
|
438
766
|
"""
|
767
|
+
if add_or_replace is None:
|
768
|
+
add_or_replace = {}
|
439
769
|
|
440
770
|
if isinstance(x, FrozenDict):
|
441
771
|
return x.copy(add_or_replace)
|
442
772
|
elif isinstance(x, dict):
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
def pop(
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
>>> variables = FrozenDict({'params': {...}, 'batch_stats': {...}})
|
461
|
-
>>> new_variables, params = pop(variables, 'params')
|
462
|
-
|
463
|
-
Args:
|
464
|
-
x: the dictionary to remove the entry from
|
465
|
-
key: the key to remove from the dict
|
466
|
-
Returns:
|
467
|
-
A pair with the new dict and the removed value.
|
468
|
-
"""
|
773
|
+
result = dict(x)
|
774
|
+
result.update(add_or_replace)
|
775
|
+
return result
|
776
|
+
else:
|
777
|
+
raise TypeError(f"Expected FrozenDict or dict, got {type(x)}")
|
778
|
+
|
779
|
+
|
780
|
+
@overload
|
781
|
+
def pop(x: FrozenDict[K, V], key: K) -> tuple[FrozenDict[K, V], V]:
|
782
|
+
...
|
783
|
+
|
784
|
+
|
785
|
+
@overload
|
786
|
+
def pop(x: dict[K, V], key: K) -> tuple[dict[K, V], V]:
|
787
|
+
...
|
788
|
+
|
469
789
|
|
790
|
+
def pop(x, key):
|
791
|
+
"""
|
792
|
+
Remove and return an item from a dictionary.
|
793
|
+
|
794
|
+
Works with both FrozenDict and regular dict, returning a new
|
795
|
+
dictionary without the specified key along with the popped value.
|
796
|
+
|
797
|
+
Parameters
|
798
|
+
----------
|
799
|
+
x : FrozenDict or dict
|
800
|
+
Dictionary to pop from.
|
801
|
+
key : hashable
|
802
|
+
Key to remove.
|
803
|
+
|
804
|
+
Returns
|
805
|
+
-------
|
806
|
+
tuple
|
807
|
+
A tuple of (new dictionary without the key, popped value).
|
808
|
+
|
809
|
+
Raises
|
810
|
+
------
|
811
|
+
TypeError
|
812
|
+
If x is not a FrozenDict or dict.
|
813
|
+
KeyError
|
814
|
+
If the key is not found in the dictionary.
|
815
|
+
|
816
|
+
See Also
|
817
|
+
--------
|
818
|
+
FrozenDict.pop : Pop method for FrozenDict.
|
819
|
+
|
820
|
+
Examples
|
821
|
+
--------
|
822
|
+
.. code-block:: python
|
823
|
+
|
824
|
+
>>> from brainstate.util import FrozenDict, pop
|
825
|
+
|
826
|
+
>>> # Works with FrozenDict
|
827
|
+
>>> fd = FrozenDict({'a': 1, 'b': 2})
|
828
|
+
>>> fd2, value = pop(fd, 'a')
|
829
|
+
>>> value
|
830
|
+
1
|
831
|
+
>>> 'a' in fd2
|
832
|
+
False
|
833
|
+
|
834
|
+
>>> # Also works with regular dict
|
835
|
+
>>> d = {'a': 1, 'b': 2}
|
836
|
+
>>> d2, value = pop(d, 'a')
|
837
|
+
>>> value
|
838
|
+
1
|
839
|
+
>>> 'a' in d2
|
840
|
+
False
|
841
|
+
"""
|
470
842
|
if isinstance(x, FrozenDict):
|
471
843
|
return x.pop(key)
|
472
844
|
elif isinstance(x, dict):
|
473
|
-
new_dict =
|
474
|
-
lambda x: x, x
|
475
|
-
) # make a deep copy of dict x
|
845
|
+
new_dict = dict(x)
|
476
846
|
value = new_dict.pop(key)
|
477
847
|
return new_dict, value
|
478
|
-
|
479
|
-
|
848
|
+
else:
|
849
|
+
raise TypeError(f"Expected FrozenDict or dict, got {type(x)}")
|
480
850
|
|
481
|
-
def pretty_repr(x: Any, num_spaces: int = 4) -> str:
|
482
|
-
"""Returns an indented representation of the nested dictionary.
|
483
|
-
This is a utility function that can act on either a FrozenDict or
|
484
|
-
regular dict and mimics the behavior of ``FrozenDict.pretty_repr``.
|
485
|
-
If x is any other dtype, this function will return ``repr(x)``.
|
486
851
|
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
852
|
+
def pretty_repr(x: Any, indent: int = 2) -> str:
|
853
|
+
"""
|
854
|
+
Create a pretty string representation.
|
855
|
+
|
856
|
+
Parameters
|
857
|
+
----------
|
858
|
+
x : any
|
859
|
+
Object to represent. If a dict or FrozenDict, will be
|
860
|
+
pretty-printed with indentation. Otherwise, returns repr(x).
|
861
|
+
indent : int, optional
|
862
|
+
Number of spaces per indentation level (default 2).
|
863
|
+
|
864
|
+
Returns
|
865
|
+
-------
|
866
|
+
str
|
867
|
+
A formatted string representation.
|
868
|
+
|
869
|
+
See Also
|
870
|
+
--------
|
871
|
+
FrozenDict.pretty_repr : Pretty representation for FrozenDict.
|
872
|
+
|
873
|
+
Examples
|
874
|
+
--------
|
875
|
+
.. code-block:: python
|
876
|
+
|
877
|
+
>>> from brainstate.util import pretty_repr
|
878
|
+
|
879
|
+
>>> d = {'a': 1, 'b': {'c': 2, 'd': 3}}
|
880
|
+
>>> print(pretty_repr(d))
|
881
|
+
{
|
882
|
+
'a': 1,
|
883
|
+
'b': {
|
884
|
+
'c': 2,
|
885
|
+
'd': 3
|
886
|
+
}
|
887
|
+
}
|
888
|
+
|
889
|
+
>>> # Non-dict objects return normal repr
|
890
|
+
>>> pretty_repr([1, 2, 3])
|
891
|
+
'[1, 2, 3]'
|
492
892
|
"""
|
493
|
-
|
494
893
|
if isinstance(x, FrozenDict):
|
495
|
-
return x.pretty_repr()
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
if not isinstance(x, dict):
|
500
|
-
return repr(x)
|
501
|
-
rep = ''
|
502
|
-
for key, val in x.items():
|
503
|
-
rep += f'{key}: {pretty_dict(val)},\n'
|
504
|
-
if rep:
|
505
|
-
return '{\n' + _indent(rep, num_spaces) + '}'
|
506
|
-
else:
|
894
|
+
return x.pretty_repr(indent)
|
895
|
+
elif isinstance(x, dict):
|
896
|
+
def format_dict(d, level):
|
897
|
+
if not d:
|
507
898
|
return '{}'
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
"""A wrapper for a more useful repr of the values in a frozen dict."""
|
521
|
-
|
522
|
-
def __repr__(self):
|
523
|
-
return f'frozen_dict_values({list(self)})'
|
899
|
+
items = []
|
900
|
+
for k, v in d.items():
|
901
|
+
if isinstance(v, dict):
|
902
|
+
formatted = format_dict(v, level + 1)
|
903
|
+
else:
|
904
|
+
formatted = repr(v)
|
905
|
+
items.append(f'{" " * (level + 1) * indent}{k!r}: {formatted}')
|
906
|
+
return '{\n' + ',\n'.join(items) + f'\n{" " * level * indent}}}'
|
907
|
+
|
908
|
+
return format_dict(x, 0)
|
909
|
+
else:
|
910
|
+
return repr(x)
|