brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,524 @@
|
|
1
|
+
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
+
# The credit should go to the Flax authors.
|
3
|
+
#
|
4
|
+
# Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
|
5
|
+
#
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
+
# you may not use this file except in compliance with the License.
|
8
|
+
# You may obtain a copy of the License at
|
9
|
+
#
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
+
#
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
+
# See the License for the specific language governing permissions and
|
16
|
+
# limitations under the License.
|
17
|
+
|
18
|
+
"""Utilities for defining custom classes that can be used with jax transformations."""
|
19
|
+
|
20
|
+
from __future__ import annotations
|
21
|
+
|
22
|
+
import collections
|
23
|
+
import dataclasses
|
24
|
+
from collections.abc import Hashable, Mapping
|
25
|
+
from types import MappingProxyType
|
26
|
+
from typing import Any, TypeVar
|
27
|
+
|
28
|
+
import jax
|
29
|
+
from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet
|
30
|
+
|
31
|
+
# Public names exported by ``from ... import *``.
__all__ = [
    'dataclass',
    'field',
    'PyTreeNode',
    'FrozenDict',
]

# Generic type variables used in annotations below:
# ``K``/``V`` parameterize the keys/values of ``FrozenDict``;
# ``T`` is the class type passed through (and returned by) ``dataclass``.
K = TypeVar('K')
V = TypeVar('V')
T = TypeVar('T')
|
41
|
+
|
42
|
+
|
43
|
+
def field(pytree_node=True, *, metadata=None, **kwargs):
    """Declare a dataclass field, tagging whether it is a pytree child.

    Args:
        pytree_node: when true (the default) the field is treated as a pytree
            child by ``dataclass``; when false it is treated as static metadata.
        metadata: optional extra metadata merged into the field's metadata.
            The ``'pytree_node'`` entry always reflects ``pytree_node``.
        **kwargs: forwarded unchanged to ``dataclasses.field``.

    Returns:
        The ``dataclasses.Field`` produced by ``dataclasses.field``.
    """
    base_metadata = metadata or {}
    merged_metadata = base_metadata | {'pytree_node': pytree_node}
    return dataclasses.field(metadata=merged_metadata, **kwargs)
|
45
|
+
|
46
|
+
|
47
|
+
@dataclass_transform(field_specifiers=(field,))  # type: ignore[literal-required]
def dataclass(clz: T, **kwargs) -> T:
    """Create a class which can be passed to functional transformations.

    .. note::
        Inherit from ``PyTreeNode`` instead to avoid type checking issues when
        using PyType.

    Jax transformations such as ``jax.jit`` and ``jax.grad`` require objects that are
    immutable and can be mapped over using the ``jax.tree_util`` methods.
    The ``dataclass`` decorator makes it easy to define custom classes that can be
    passed safely to Jax. For example::

        >>> import brainstate as bst
        >>> import jax
        >>> from typing import Any, Callable

        >>> @bst.util.dataclass
        ... class Model:
        ...     params: Any
        ...     # use pytree_node=False to indicate an attribute should not be touched
        ...     # by Jax transformations.
        ...     apply_fn: Callable = bst.util.field(pytree_node=False)
        ...
        ...     def __apply__(self, *args):
        ...         return self.apply_fn(*args)

        >>> params = {}
        >>> params_b = {}
        >>> apply_fn = lambda v, x: x
        >>> model = Model(params, apply_fn)

        >>> # model.params = params_b  # Model is immutable. This will raise an error.
        >>> model_b = model.replace(params=params_b)  # Use the replace method instead.

        >>> # This class can now be used safely in Jax to compute gradients w.r.t. the
        >>> # parameters.
        >>> model = Model(params, apply_fn)
        >>> loss_fn = lambda model: 3.
        >>> model_grad = jax.grad(loss_fn)(model)

    Note that dataclasses have an auto-generated ``__init__`` where
    the arguments of the constructor and the attributes of the created
    instance match 1:1. This correspondence is what makes these objects
    valid containers that work with JAX transformations and
    more generally the ``jax.tree_util`` library.

    Sometimes a "smart constructor" is desired, for example because
    some of the attributes can be (optionally) derived from others.
    The way to do this with these dataclasses is to make a static or
    class method that provides the smart constructor.
    This way the simple constructor used by ``jax.tree_util`` is
    preserved. Consider the following example::

        >>> @bst.util.dataclass
        ... class DirectionAndScaleKernel:
        ...     direction: jax.Array
        ...     scale: jax.Array
        ...
        ...     @classmethod
        ...     def create(cls, kernel):
        ...         scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True)
        ...         direction = kernel / scale
        ...         return cls(direction, scale)

    Args:
        clz: the class that will be transformed by the decorator.
        **kwargs: forwarded to ``dataclasses.dataclass``; ``frozen`` defaults
            to ``True`` when not given.

    Returns:
        The transformed dataclass. If ``clz`` was already processed by this
        decorator it is returned unchanged.
    """
    # Decorating the same class twice must be a no-op.
    if '_flax_dataclass' in clz.__dict__:
        return clz

    # Instances are immutable unless the caller explicitly passes ``frozen``.
    kwargs.setdefault('frozen', True)
    data_clz = dataclasses.dataclass(**kwargs)(clz)  # type: ignore

    # Split fields by their ``pytree_node`` tag (see ``field``): tagged-true
    # fields are pytree children; tagged-false fields are registered as meta
    # (static) fields.
    meta_fields = []
    data_fields = []
    for field_info in dataclasses.fields(data_clz):
        if field_info.metadata.get('pytree_node', True):
            data_fields.append(field_info.name)
        else:
            meta_fields.append(field_info.name)

    def replace(self, **updates):
        """Returns a new object replacing the specified fields with new values."""
        return dataclasses.replace(self, **updates)

    data_clz.replace = replace

    # Remove this guard once the minimum JAX version is >0.4.26.
    try:
        if hasattr(jax.tree_util, 'register_dataclass'):
            jax.tree_util.register_dataclass(data_clz, data_fields, meta_fields)
        else:
            raise NotImplementedError
    except NotImplementedError:
        # Fallback for JAX without a usable ``register_dataclass``: register
        # flatten/unflatten functions by hand. Data fields become the children,
        # meta fields become the auxiliary data.

        def iterate_clz(x):
            meta = tuple(getattr(x, name) for name in meta_fields)
            data = tuple(getattr(x, name) for name in data_fields)
            return data, meta

        def iterate_clz_with_keys(x):
            meta = tuple(getattr(x, name) for name in meta_fields)
            data = tuple(
                (jax.tree_util.GetAttrKey(name), getattr(x, name))
                for name in data_fields
            )
            return data, meta

        def clz_from_iterable(meta, data):
            init_kwargs = dict(zip(meta_fields, meta))
            init_kwargs.update(zip(data_fields, data))
            return data_clz(**init_kwargs)

        jax.tree_util.register_pytree_with_keys(
            data_clz,
            iterate_clz_with_keys,
            clz_from_iterable,
            iterate_clz,
        )

    # Mark the class so a second application of this decorator is skipped.
    data_clz._flax_dataclass = True  # type: ignore[attr-defined]

    return data_clz  # type: ignore
|
179
|
+
|
180
|
+
|
181
|
+
# Self-type used by the ``replace`` stub below.
TNode = TypeVar('TNode', bound='PyTreeNode')


@dataclass_transform(field_specifiers=(field,))  # type: ignore[literal-required]
class PyTreeNode:
    """Base class whose subclasses become pytree-registered dataclasses.

    Every subclass is passed through this module's ``dataclass`` decorator when
    it is defined (keyword arguments from the ``class`` statement are forwarded
    to it). Subclasses therefore get the same ``jax.tree_util`` behavior as a
    class decorated with ``bst.util.dataclass``, while also avoiding type
    checking errors when using PyType.

    Example::

        >>> import brainstate as bst
        >>> import jax
        >>> from typing import Any, Callable

        >>> class Model(bst.util.PyTreeNode):
        ...     params: Any
        ...     # use pytree_node=False to indicate an attribute should not be touched
        ...     # by Jax transformations.
        ...     apply_fn: Callable = bst.util.field(pytree_node=False)
        ...
        ...     def __apply__(self, *args):
        ...         return self.apply_fn(*args)

        >>> params = {}
        >>> params_b = {}
        >>> apply_fn = lambda v, x: x
        >>> model = Model(params, apply_fn)

        >>> # model.params = params_b  # Model is immutable. This will raise an error.
        >>> model_b = model.replace(params=params_b)  # Use the replace method instead.

        >>> # This class can now be used safely in Jax to compute gradients w.r.t. the
        >>> # parameters.
        >>> model = Model(params, apply_fn)
        >>> loss_fn = lambda model: 3.
        >>> model_grad = jax.grad(loss_fn)(model)
    """

    def __init_subclass__(cls, **kwargs):
        # Turn each subclass into a pytree-registered dataclass at definition time.
        dataclass(cls, **kwargs)  # pytype: disable=wrong-arg-types

    def __init__(self, *args, **kwargs):
        # Placeholder for type checkers only; each subclass receives the
        # ``__init__`` generated by ``dataclasses.dataclass``.
        raise NotImplementedError

    def replace(self: TNode, **overrides) -> TNode:
        # Placeholder for type checkers only; ``dataclass`` attaches the real
        # ``replace`` to each subclass.
        raise NotImplementedError
|
231
|
+
|
232
|
+
|
233
|
+
def _indent(x, num_spaces):
|
234
|
+
indent_str = ' ' * num_spaces
|
235
|
+
lines = x.split('\n')
|
236
|
+
assert not lines[-1]
|
237
|
+
# skip the final line because it's empty and should not be indented.
|
238
|
+
return '\n'.join(indent_str + line for line in lines[:-1]) + '\n'
|
239
|
+
|
240
|
+
|
241
|
+
@jax.tree_util.register_pytree_with_keys_class
class FrozenDict(Mapping[K, V]):
    """An immutable variant of the Python dict.

    Nested plain ``dict`` values are recursively copied at construction time
    (unless ``__unsafe_skip_copy__=True``) and are returned wrapped in a
    ``FrozenDict`` by ``__getitem__``, so the stored data cannot be mutated
    through this object. Instances are hashable (the hash is cached after the
    first call) and are registered as JAX pytrees whose children are ordered by
    sorted key.
    """

    __slots__ = ('_dict', '_hash')

    def __init__(self, *args, __unsafe_skip_copy__=False, **kwargs):  # pylint: disable=invalid-name
        # Accept the same positional/keyword forms as ``dict(...)``.
        xs = dict(*args, **kwargs)
        if __unsafe_skip_copy__:
            # Skip the defensive copy; only safe when ``xs`` is not referenced
            # anywhere else (as in ``tree_unflatten``).
            self._dict = xs
        else:
            # Recursively copy nested dicts so later mutation of the input
            # cannot leak into this frozen mapping.
            self._dict = _prepare_freeze(xs)
        # Filled in lazily by ``__hash__``.
        self._hash = None

    def __getitem__(self, key):
        value = self._dict[key]
        # Wrap nested dicts so callers cannot mutate the internal state.
        if isinstance(value, dict):
            return FrozenDict(value)
        return value

    def __setitem__(self, key, value):
        raise ValueError('FrozenDict is immutable.')

    def __contains__(self, key):
        return key in self._dict

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        return len(self._dict)

    def __repr__(self):
        return self.pretty_repr()

    def __reduce__(self):
        # Pickle as the equivalent plain nested dict.
        return FrozenDict, (self.unfreeze(),)

    def pretty_repr(self, num_spaces=4):
        """Returns an indented representation of the nested dictionary.

        Args:
            num_spaces: number of spaces per indentation level.
        """

        def pretty_dict(x):
            if not isinstance(x, dict):
                return repr(x)
            rep = ''.join(f'{key}: {pretty_dict(val)},\n' for key, val in x.items())
            if rep:
                return '{\n' + _indent(rep, num_spaces) + '}'
            return '{}'

        return f'FrozenDict({pretty_dict(self._dict)})'

    def __hash__(self):
        # Order-independent hash (XOR of per-item hashes), computed once and cached.
        if self._hash is None:
            h = 0
            for key, value in self.items():
                h ^= hash((key, value))
            self._hash = h
        return self._hash

    def copy(
        self, add_or_replace: Mapping[K, V] = MappingProxyType({})
    ) -> 'FrozenDict[K, V]':
        """Create a new FrozenDict with additional or replaced entries."""
        return type(self)({**self, **unfreeze(add_or_replace)})  # type: ignore[arg-type]

    def keys(self):
        return FrozenKeysView(self)

    def values(self):
        return FrozenValuesView(self)

    def items(self):
        # Generator of ``(key, self[key])``; nested dicts come back as FrozenDict.
        for key in self._dict:
            yield (key, self[key])

    def pop(self, key: K) -> tuple['FrozenDict[K, V]', V]:
        """Create a new FrozenDict where one entry is removed.

        Example::

            >>> from brainstate.util._struct import FrozenDict
            >>> variables = FrozenDict({'params': {'w': 1}, 'batch_stats': {'mean': 0}})
            >>> new_variables, params = variables.pop('params')

        Args:
            key: the key to remove from the dict
        Returns:
            A pair with the new FrozenDict and the removed value.
        Raises:
            KeyError: if ``key`` is not present.
        """
        value = self[key]
        new_dict = dict(self._dict)
        new_dict.pop(key)
        new_self = type(self)(new_dict)
        return new_self, value

    def unfreeze(self) -> dict[K, V]:
        """Unfreeze this FrozenDict.

        Returns:
            An unfrozen version of this FrozenDict instance.
        """
        return unfreeze(self)

    def tree_flatten_with_keys(self) -> tuple[tuple[Any, ...], Hashable]:
        """Flattens this FrozenDict for ``jax.tree_util``.

        Returns:
            ``(children, aux_data)``: ``children`` is a tuple of
            ``(DictKey(k), value)`` pairs in sorted key order and ``aux_data``
            is the tuple of sorted keys.
        """
        sorted_keys = sorted(self._dict)
        children = tuple((jax.tree_util.DictKey(k), self._dict[k]) for k in sorted_keys)
        return children, tuple(sorted_keys)

    @classmethod
    def tree_unflatten(cls, keys, values):
        # data is already deep copied due to tree map mechanism
        # we can skip the deep copy in the constructor
        return cls(dict(zip(keys, values)), __unsafe_skip_copy__=True)
|
365
|
+
|
366
|
+
|
367
|
+
def _prepare_freeze(xs: Any) -> Any:
    """Return a structure that is safe to store inside a ``FrozenDict``.

    * A ``FrozenDict`` contributes its internal dict directly: it is immutable,
      so sharing that reference is safe.
    * A plain ``dict`` is rebuilt recursively so no caller-owned dict is shared.
    * Any other value is returned unchanged as a leaf.
    """
    if isinstance(xs, FrozenDict):
        return xs._dict  # pylint: disable=protected-access
    if isinstance(xs, dict):
        return {k: _prepare_freeze(v) for k, v in xs.items()}
    return xs
|
378
|
+
|
379
|
+
|
380
|
+
def freeze(xs: Mapping[Any, Any]) -> FrozenDict[Any, Any]:
    """Freeze a nested dict.

    Makes a nested ``dict`` immutable by transforming it into ``FrozenDict``.
    Nested plain dicts are copied, so later mutation of ``xs`` does not affect
    the result.

    Args:
        xs: Dictionary to freeze (a regular Python dict).
    Returns:
        The frozen dictionary.
    """
    return FrozenDict(xs)
|
391
|
+
|
392
|
+
|
393
|
+
def unfreeze(x: FrozenDict | dict[str, Any]) -> dict[Any, Any]:
    """Turn a ``FrozenDict`` (or a nested dict) into a fresh mutable nested dict.

    Args:
        x: the mapping to thaw. Values that are neither a ``FrozenDict`` nor a
            ``dict`` are returned unchanged.
    Returns:
        A newly built nested ``dict`` for mapping inputs; ``x`` itself is
        never modified.
    """
    if isinstance(x, FrozenDict):
        # An identity ``tree_map`` rebuilds every pytree container of the
        # internal dict using JAX's C implementation, which is much faster than
        # the recursive dict branch below.
        return jax.tree_util.tree_map(lambda leaf: leaf, x._dict)  # type: ignore
    if not isinstance(x, dict):
        return x
    return {key: unfreeze(value) for key, value in x.items()}
|
417
|
+
|
418
|
+
|
419
|
+
def copy(
    x: FrozenDict | dict[str, Any],
    add_or_replace: FrozenDict[str, Any] | dict[str, Any] = FrozenDict({}),
) -> FrozenDict | dict[str, Any]:
    """Create a new dict with additional and/or replaced entries. This is a utility
    function that can act on either a FrozenDict or regular dict and mimics the
    behavior of ``FrozenDict.copy``.

    Example::

        >>> from brainstate.util._struct import FrozenDict, copy
        >>> variables = FrozenDict({'params': {'w': 1}, 'batch_stats': {'mean': 0}})
        >>> new_variables = copy(variables, {'additional_entries': 1})

    Args:
        x: the dictionary to be copied and updated
        add_or_replace: dictionary of key-value pairs to add or replace in the dict x
    Returns:
        A new dict with the additional and/or replaced entries. ``x`` itself is
        never modified.
    Raises:
        TypeError: if ``x`` is neither a ``FrozenDict`` nor a ``dict``.
    """
    if isinstance(x, FrozenDict):
        return x.copy(add_or_replace)
    if isinstance(x, dict):
        # An identity ``tree_map`` rebuilds every pytree container, giving a
        # copy of ``x`` whose registered nested containers are not shared.
        new_dict = jax.tree_util.tree_map(lambda leaf: leaf, x)
        if new_dict is x:
            # A dict subclass that is not registered as a pytree is treated as
            # a leaf and returned as-is; copy it so ``update`` below does not
            # mutate the caller's dict.
            new_dict = dict(x)
        new_dict.update(add_or_replace)
        return new_dict
    raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')
|
449
|
+
|
450
|
+
|
451
|
+
def pop(
    x: FrozenDict | dict[str, Any], key: str
) -> tuple[FrozenDict | dict[str, Any], Any]:
    """Create a new dict where one entry is removed. This is a utility
    function that can act on either a FrozenDict or regular dict and
    mimics the behavior of ``FrozenDict.pop``.

    Example::

        >>> from brainstate.util._struct import FrozenDict, pop
        >>> variables = FrozenDict({'params': {'w': 1}, 'batch_stats': {'mean': 0}})
        >>> new_variables, params = pop(variables, 'params')

    Args:
        x: the dictionary to remove the entry from
        key: the key to remove from the dict
    Returns:
        A pair with the new dict and the removed value. ``x`` itself is never
        modified.
    Raises:
        TypeError: if ``x`` is neither a ``FrozenDict`` nor a ``dict``.
        KeyError: if ``key`` is not present.
    """
    if isinstance(x, FrozenDict):
        return x.pop(key)
    if isinstance(x, dict):
        # An identity ``tree_map`` rebuilds every pytree container, giving a
        # copy of ``x`` whose registered nested containers are not shared.
        new_dict = jax.tree_util.tree_map(lambda leaf: leaf, x)
        if new_dict is x:
            # A dict subclass that is not registered as a pytree is treated as
            # a leaf and returned as-is; copy it so ``pop`` below does not
            # mutate the caller's dict.
            new_dict = dict(x)
        value = new_dict.pop(key)
        return new_dict, value
    raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')
|
480
|
+
|
481
|
+
|
482
|
+
def pretty_repr(x: Any, num_spaces: int = 4) -> str:
    """Returns an indented representation of the nested dictionary.

    This is a utility function that can act on either a FrozenDict or
    regular dict and mimics the behavior of ``FrozenDict.pretty_repr``.
    If x is any other type, this function will return ``repr(x)``.

    Args:
        x: the dictionary to be represented
        num_spaces: the number of space characters in each indentation level
    Returns:
        An indented string representation of the nested dictionary.
    """
    if isinstance(x, FrozenDict):
        # Forward the caller's indentation width instead of using the default.
        return x.pretty_repr(num_spaces)

    def pretty_dict(obj):
        if not isinstance(obj, dict):
            return repr(obj)
        rep = ''.join(f'{key}: {pretty_dict(val)},\n' for key, val in obj.items())
        if rep:
            return '{\n' + _indent(rep, num_spaces) + '}'
        return '{}'

    return pretty_dict(x)
|
511
|
+
|
512
|
+
|
513
|
+
class FrozenKeysView(collections.abc.KeysView):
    """Keys view of a frozen dict whose repr spells out the keys as a list."""

    def __repr__(self):
        key_list = list(self)
        return 'frozen_dict_keys(' + repr(key_list) + ')'
|
518
|
+
|
519
|
+
|
520
|
+
class FrozenValuesView(collections.abc.ValuesView):
    """Values view of a frozen dict whose repr spells out the values as a list."""

    def __repr__(self):
        value_list = list(self)
        return 'frozen_dict_values(' + repr(value_list) + ')'
|
@@ -0,0 +1,75 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from __future__ import annotations
|
16
|
+
|
17
|
+
import jax
|
18
|
+
import jax.core
|
19
|
+
from jax.interpreters import partial_eval as pe
|
20
|
+
|
21
|
+
from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr
|
22
|
+
|
23
|
+
# Public names exported by ``from ... import *``.
__all__ = [
    'StateJaxTracer',
]
|
26
|
+
|
27
|
+
|
28
|
+
def new_jax_trace():
    """Build a new ``DynamicJaxprTrace`` on top of the innermost main trace.

    NOTE(review): relies on private JAX internals (``thread_local_state``,
    ``cur_sublevel`` and the main trace's ``jaxpr_stack``). ``current_jax_trace``
    in this module only touches ``thread_local_state`` for JAX <= 0.4.33, so
    confirm this helper is still usable on newer JAX releases.

    Returns:
        ``(frame, trace)``: the last entry of the innermost main trace's
        ``jaxpr_stack`` and a ``DynamicJaxprTrace`` created at the current
        sublevel.
    """
    trace_stack = jax.core.thread_local_state.trace_state.trace_stack
    innermost_main = trace_stack.stack[-1]
    frame = innermost_main.jaxpr_stack[-1]
    return frame, pe.DynamicJaxprTrace(innermost_main, jax.core.cur_sublevel())
|
33
|
+
|
34
|
+
|
35
|
+
def current_jax_trace():
    """Returns a handle for the JAX tracing state that is active right now.

    For JAX <= 0.4.33 this is ``trace_stack.dynamic`` from JAX's thread-local
    trace state; for newer releases it is the value returned by
    ``jax.core.get_opaque_trace_state(convention="nnx")``.
    """
    if jax.__version_info__ > (0, 4, 33):
        return jax.core.get_opaque_trace_state(convention="nnx")
    return jax.core.thread_local_state.trace_state.trace_stack.dynamic
|
40
|
+
|
41
|
+
|
42
|
+
class StateJaxTracer(PrettyRepr):
    """Snapshot of the JAX trace that was active when this object was created.

    ``is_valid`` reports whether the code is still running under that same
    trace, and two instances compare equal when they captured the same trace.
    Because ``__eq__`` is defined without ``__hash__``, instances are unhashable.
    """

    __slots__ = ['_jax_trace']

    def __init__(self):
        self._jax_trace = current_jax_trace()

    @property
    def jax_trace(self):
        """The trace handle captured at construction time."""
        return self._jax_trace

    @staticmethod
    def _traces_match(a, b) -> bool:
        # On JAX <= 0.4.33 the handles are compared by identity; newer releases
        # return opaque trace-state objects that are compared with ``==``.
        if jax.__version_info__ > (0, 4, 33):
            return a == b
        return a is b

    def is_valid(self) -> bool:
        return self._traces_match(self._jax_trace, current_jax_trace())

    def __eq__(self, other):
        if not isinstance(other, StateJaxTracer):
            return False
        return self._traces_match(self._jax_trace, other._jax_trace)

    def __pretty_repr__(self):
        yield PrettyType(f'{type(self).__name__}')
        yield PrettyAttr('jax_trace', self._jax_trace)

    def __treescope_repr__(self, path, subtree_renderer):
        # Imported lazily: treescope is only needed when rendering.
        import treescope  # type: ignore[import-not-found,import-untyped]
        return treescope.repr_lib.render_object_constructor(
            object_type=type(self),
            attributes={'jax_trace': self._jax_trace},
            path=path,
            subtree_renderer=subtree_renderer,
        )
|
@@ -13,35 +13,35 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
from __future__ import annotations
|
16
17
|
|
17
18
|
# Public names exported by ``from ... import *``.
__all__ = [
    'display',
]

import importlib.util

# Detect the optional ``treescope`` package without importing it.
treescope_installed = importlib.util.find_spec('treescope') is not None
try:
    from IPython import get_ipython

    # ``get_ipython()`` returns None when not running inside IPython.
    in_ipython = get_ipython() is not None
except ImportError:
    # IPython itself is not available.
    in_ipython = False
|
31
31
|
|
32
32
|
|
33
33
|
def display(*args):
    """Show each argument, using the Treescope pretty-printer when possible.

    Treescope rendering is used only when the ``treescope`` package is
    installed and the code runs inside IPython; otherwise every object is
    simply written with ``print``.
    """
    use_treescope = treescope_installed and in_ipython
    if use_treescope:
        import treescope  # type: ignore[import-not-found,import-untyped]

        for obj in args:
            treescope.display(obj, ignore_exceptions=True, autovisualize=True)
    else:
        for obj in args:
            print(obj)
|