brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1356 -1321
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,524 @@
1
+ # The file is adapted from the Flax library (https://github.com/google/flax).
2
+ # The credit should go to the Flax authors.
3
+ #
4
+ # Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Utilities for defining custom classes that can be used with jax transformations."""
19
+
20
+ from __future__ import annotations
21
+
22
+ import collections
23
+ import dataclasses
24
+ from collections.abc import Hashable, Mapping
25
+ from types import MappingProxyType
26
+ from typing import Any, TypeVar
27
+
28
+ import jax
29
+ from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet
30
+
31
__all__ = ['dataclass', 'field', 'PyTreeNode', 'FrozenDict']

# Type variables used by the generic annotations in this module.
K = TypeVar('K')  # key type of FrozenDict
V = TypeVar('V')  # value type of FrozenDict
T = TypeVar('T')  # class type passed through ``dataclass``
41
+
42
+
43
def field(pytree_node=True, *, metadata=None, **kwargs):
    """Declare a dataclass field, tagging whether it is a JAX pytree child.

    Args:
      pytree_node: if ``True`` (default) the field is a pytree child that JAX
        transformations trace; if ``False`` it is stored as static metadata.
      metadata: optional mapping of extra field metadata. Any ``Mapping`` is
        accepted (not only ``dict``). A ``'pytree_node'`` entry in it is
        overridden by the ``pytree_node`` argument.
      **kwargs: forwarded unchanged to ``dataclasses.field`` (``default``,
        ``default_factory``, ...).

    Returns:
      The ``dataclasses.Field`` object created by ``dataclasses.field``.
    """
    # Unpacking works for any Mapping; ``dict | ...`` only works for dicts.
    merged = {**(metadata or {}), 'pytree_node': pytree_node}
    return dataclasses.field(metadata=merged, **kwargs)
45
+
46
+
47
@dataclass_transform(field_specifiers=(field,))  # type: ignore[literal-required]
def dataclass(clz: T, **kwargs) -> T:
    """Create a class which can be passed to functional transformations.

    .. note::
       Inherit from ``PyTreeNode`` instead to avoid type checking issues when
       using PyType.

    Jax transformations such as ``jax.jit`` and ``jax.grad`` require objects that are
    immutable and can be mapped over using the ``jax.tree_util`` methods.
    The ``dataclass`` decorator makes it easy to define custom classes that can be
    passed safely to Jax. For example::

      >>> import brainstate as bst
      >>> import jax
      >>> from typing import Any, Callable

      >>> @bst.util.dataclass
      ... class Model:
      ...   params: Any
      ...   # use pytree_node=False to indicate an attribute should not be touched
      ...   # by Jax transformations.
      ...   apply_fn: Callable = bst.util.field(pytree_node=False)
      ...
      ...   def __apply__(self, *args):
      ...     return self.apply_fn(*args)

      >>> params = {}
      >>> params_b = {}
      >>> apply_fn = lambda v, x: x
      >>> model = Model(params, apply_fn)

      >>> # model.params = params_b  # Model is immutable. This will raise an error.
      >>> model_b = model.replace(params=params_b)  # Use the replace method instead.

      >>> # This class can now be used safely in Jax to compute gradients w.r.t. the
      >>> # parameters.
      >>> model = Model(params, apply_fn)
      >>> loss_fn = lambda model: 3.
      >>> model_grad = jax.grad(loss_fn)(model)

    Note that dataclasses have an auto-generated ``__init__`` where
    the arguments of the constructor and the attributes of the created
    instance match 1:1. This correspondence is what makes these objects
    valid containers that work with JAX transformations and
    more generally the ``jax.tree_util`` library.

    Sometimes a "smart constructor" is desired, for example because
    some of the attributes can be (optionally) derived from others.
    The way to do this with these dataclasses is to make a static or
    class method that provides the smart constructor.
    This way the simple constructor used by ``jax.tree_util`` is
    preserved. Consider the following example::

      >>> @bst.util.dataclass
      ... class DirectionAndScaleKernel:
      ...   direction: jax.Array
      ...   scale: jax.Array
      ...
      ...   @classmethod
      ...   def create(cls, kernel):
      ...     scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True)
      ...     direction = kernel / scale
      ...     return cls(direction, scale)

    Args:
      clz: the class that will be transformed by the decorator.
      **kwargs: keyword arguments forwarded to ``dataclasses.dataclass``.
        ``frozen`` defaults to ``True`` when it is not given.

    Returns:
      The new class. A class already processed by this decorator (it carries
      its own ``_flax_dataclass`` flag) is returned unchanged, and ``kwargs``
      are ignored in that case.
    """
    # The flag is looked up in the class's own __dict__ (not inherited), so a
    # subclass of an already processed class is processed again.
    if '_flax_dataclass' in clz.__dict__:
        return clz

    # Immutable by default so instances behave like JAX pytree values.
    kwargs.setdefault('frozen', True)
    data_clz = dataclasses.dataclass(**kwargs)(clz)  # type: ignore

    # Fields declared with ``field(pytree_node=False)`` become static metadata
    # of the pytree; every other field is a pytree child.
    meta_fields = []
    data_fields = []
    for field_info in dataclasses.fields(data_clz):
        if field_info.metadata.get('pytree_node', True):
            data_fields.append(field_info.name)
        else:
            meta_fields.append(field_info.name)

    def replace(self, **updates):
        """Returns a new object replacing the specified fields with new values."""
        return dataclasses.replace(self, **updates)

    data_clz.replace = replace

    # Remove this guard once the minimum JAX version is > 0.4.26.
    # Prefer ``jax.tree_util.register_dataclass``. When it is missing, or it
    # raises NotImplementedError, register the class manually with keys.
    try:
        if not hasattr(jax.tree_util, 'register_dataclass'):
            raise NotImplementedError
        jax.tree_util.register_dataclass(data_clz, data_fields, meta_fields)
    except NotImplementedError:

        def iterate_clz(x):
            meta = tuple(getattr(x, name) for name in meta_fields)
            data = tuple(getattr(x, name) for name in data_fields)
            return data, meta

        def iterate_clz_with_keys(x):
            meta = tuple(getattr(x, name) for name in meta_fields)
            data = tuple(
                (jax.tree_util.GetAttrKey(name), getattr(x, name))
                for name in data_fields
            )
            return data, meta

        def clz_from_iterable(meta, data):
            # Rebuild through the generated ``__init__`` using keyword args.
            field_values = dict(zip(meta_fields, meta))
            field_values.update(zip(data_fields, data))
            return data_clz(**field_values)

        jax.tree_util.register_pytree_with_keys(
            data_clz,
            iterate_clz_with_keys,
            clz_from_iterable,
            iterate_clz,
        )

    # Flag the class so applying the decorator again is a no-op.
    data_clz._flax_dataclass = True  # type: ignore[attr-defined]

    return data_clz  # type: ignore
179
+
180
+
181
TNode = TypeVar('TNode', bound='PyTreeNode')


@dataclass_transform(field_specifiers=(field,))  # type: ignore[literal-required]
class PyTreeNode:
    """Base class for dataclasses that should act like a JAX pytree node.

    Every subclass is passed through this module's ``dataclass`` decorator
    from ``__init_subclass__``, so see ``dataclass`` for the
    ``jax.tree_util`` behavior. Keyword arguments given in the class statement
    (e.g. ``class A(PyTreeNode, frozen=False)``) are forwarded to it.
    This base class additionally avoids type checking errors when using PyType.

    Example::

      >>> import brainstate as bst
      >>> import jax
      >>> from typing import Any, Callable

      >>> class Model(bst.util.PyTreeNode):
      ...   params: Any
      ...   # use pytree_node=False to indicate an attribute should not be touched
      ...   # by Jax transformations.
      ...   apply_fn: Callable = bst.util.field(pytree_node=False)
      ...
      ...   def __apply__(self, *args):
      ...     return self.apply_fn(*args)

      >>> params = {}
      >>> params_b = {}
      >>> apply_fn = lambda v, x: x
      >>> model = Model(params, apply_fn)

      >>> # model.params = params_b  # Model is immutable. This will raise an error.
      >>> model_b = model.replace(params=params_b)  # Use the replace method instead.

      >>> # This class can now be used safely in Jax to compute gradients w.r.t. the
      >>> # parameters.
      >>> model = Model(params, apply_fn)
      >>> loss_fn = lambda model: 3.
      >>> model_grad = jax.grad(loss_fn)(model)
    """

    def __init_subclass__(cls, **kwargs):
        # Turn every subclass into a pytree-registered dataclass.
        dataclass(cls, **kwargs)  # pytype: disable=wrong-arg-types

    def __init__(self, *args, **kwargs):
        # Stub for pytype; subclasses get the dataclass-generated ``__init__``.
        raise NotImplementedError

    def replace(self: TNode, **overrides) -> TNode:
        # Stub for pytype; ``dataclass`` installs the real ``replace`` on each
        # subclass.
        raise NotImplementedError
231
+
232
+
233
+ def _indent(x, num_spaces):
234
+ indent_str = ' ' * num_spaces
235
+ lines = x.split('\n')
236
+ assert not lines[-1]
237
+ # skip the final line because it's empty and should not be indented.
238
+ return '\n'.join(indent_str + line for line in lines[:-1]) + '\n'
239
+
240
+
241
@jax.tree_util.register_pytree_with_keys_class
class FrozenDict(Mapping[K, V]):
    """An immutable variant of the Python dict.

    Nested ``dict`` values are deep-copied at construction time and are handed
    out as ``FrozenDict`` instances by ``__getitem__``. The class is registered
    as a JAX pytree whose children are the top-level values, ordered by sorted
    key.
    """

    __slots__ = ('_dict', '_hash')

    def __init__(self, *args, __unsafe_skip_copy__=False, **kwargs):  # pylint: disable=invalid-name
        # Accept every argument form that ``dict(...)`` accepts.
        xs = dict(*args, **kwargs)
        if __unsafe_skip_copy__:
            # Caller promises the nested values are not aliased by mutable
            # outside state, so the deep copy can be skipped.
            self._dict = xs
        else:
            # Deep-copy nested dicts (and unwrap nested FrozenDicts) so the
            # frozen state cannot be changed through an outside reference.
            self._dict = _prepare_freeze(xs)

        # Cache for ``__hash__``; computed on first use.
        self._hash = None

    def __getitem__(self, key):
        v = self._dict[key]
        if isinstance(v, dict):
            # Nested dicts in ``_dict`` are private and treated as immutable,
            # so the wrapper can share them instead of deep-copying again.
            return FrozenDict(v, __unsafe_skip_copy__=True)
        return v

    def __setitem__(self, key, value):
        raise ValueError('FrozenDict is immutable.')

    def __contains__(self, key):
        return key in self._dict

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        return len(self._dict)

    def __repr__(self):
        return self.pretty_repr()

    def __reduce__(self):
        # Pickle as a plain nested dict and rebuild through the constructor.
        return FrozenDict, (self.unfreeze(),)

    def pretty_repr(self, num_spaces=4):
        """Returns an indented representation of the nested dictionary.

        Args:
          num_spaces: number of spaces added per nesting level.
        """

        def pretty_dict(x):
            if not isinstance(x, dict):
                return repr(x)
            rep = ''
            for key, val in x.items():
                rep += f'{key}: {pretty_dict(val)},\n'
            if rep:
                return '{\n' + _indent(rep, num_spaces) + '}'
            else:
                return '{}'

        return f'FrozenDict({pretty_dict(self._dict)})'

    def __hash__(self):
        # Order-independent hash (XOR over the items), cached after the first
        # call. Raises TypeError if any value is unhashable.
        if self._hash is None:
            h = 0
            for key, value in self.items():
                h ^= hash((key, value))
            self._hash = h
        return self._hash

    def copy(
        self, add_or_replace: Mapping[K, V] = MappingProxyType({})
    ) -> 'FrozenDict[K, V]':
        """Create a new FrozenDict with additional or replaced entries."""
        return type(self)({**self, **unfreeze(add_or_replace)})  # type: ignore[arg-type]

    def keys(self):
        return FrozenKeysView(self)

    def values(self):
        return FrozenValuesView(self)

    def items(self):
        # NOTE: a generator (single pass), not a reusable items view.
        for key in self._dict:
            yield (key, self[key])

    def pop(self, key: K) -> tuple['FrozenDict[K, V]', V]:
        """Create a new FrozenDict where one entry is removed.

        Example::

          >>> variables = FrozenDict({'params': {...}, 'batch_stats': {...}})
          >>> new_variables, params = variables.pop('params')

        Args:
          key: the key to remove from the dict

        Returns:
          A pair with the new FrozenDict and the removed value.

        Raises:
          KeyError: if ``key`` is not present.
        """
        value = self[key]
        new_dict = dict(self._dict)
        new_dict.pop(key)
        new_self = type(self)(new_dict)
        return new_self, value

    def unfreeze(self) -> dict[K, V]:
        """Unfreeze this FrozenDict.

        Returns:
          An unfrozen version of this FrozenDict instance.
        """
        return unfreeze(self)

    def tree_flatten_with_keys(self) -> tuple[tuple[Any, ...], Hashable]:
        """Flattens this FrozenDict.

        Returns:
          A ``(children, aux_data)`` pair. ``children`` holds one
          ``(DictKey, value)`` pair per entry in sorted-key order, and
          ``aux_data`` is the tuple of sorted keys.
        """
        sorted_keys = sorted(self._dict)
        return tuple(
            [(jax.tree_util.DictKey(k), self._dict[k]) for k in sorted_keys]
        ), tuple(sorted_keys)

    @classmethod
    def tree_unflatten(cls, keys, values):
        # data is already deep copied due to tree map mechanism
        # we can skip the deep copy in the constructor
        return cls({k: v for k, v in zip(keys, values)}, __unsafe_skip_copy__=True)
365
+
366
+
367
def _prepare_freeze(xs: Any) -> Any:
    """Build the private internal state for a ``FrozenDict``.

    * ``FrozenDict``: its internal dict is shared by reference. This is safe
      because that state is never mutated.
    * ``dict``: a new dict is returned whose values are processed recursively,
      so no nested dict is shared with the caller.
    * anything else: treated as a leaf and returned unchanged.
    """
    if isinstance(xs, FrozenDict):
        return xs._dict  # pylint: disable=protected-access
    if isinstance(xs, dict):
        return {k: _prepare_freeze(v) for k, v in xs.items()}
    return xs
378
+
379
+
380
def freeze(xs: Mapping[Any, Any]) -> FrozenDict[Any, Any]:
    """Freeze a nested dict.

    Makes a nested ``dict`` immutable by transforming it into a ``FrozenDict``.
    Nested dicts are deep-copied by the ``FrozenDict`` constructor, so later
    changes to ``xs`` do not affect the result.

    Args:
      xs: Dictionary to freeze (a regular Python dict).

    Returns:
      The frozen dictionary.
    """
    return FrozenDict(xs)
391
+
392
+
393
def unfreeze(x: FrozenDict | dict[str, Any]) -> dict[Any, Any]:
    """Return a mutable (nested) ``dict`` copy of a ``FrozenDict``.

    Plain dicts are copied recursively, converting any ``FrozenDict`` values
    they contain. Any other object is returned unchanged.

    Args:
      x: Frozen dictionary to unfreeze.

    Returns:
      The unfrozen dictionary (a regular Python dict).
    """
    if isinstance(x, FrozenDict):
        # An identity tree_map rebuilds every container, which deep-copies the
        # internal state. It is much faster than the recursive branch below
        # because jax.tree_util uses an optimized C implementation.
        return jax.tree_util.tree_map(lambda leaf: leaf, x._dict)  # type: ignore
    if not isinstance(x, dict):
        return x
    return {key: unfreeze(value) for key, value in x.items()}
417
+
418
+
419
def copy(
    x: FrozenDict | dict[str, Any],
    add_or_replace: FrozenDict[str, Any] | dict[str, Any] = FrozenDict({}),
) -> FrozenDict | dict[str, Any]:
    """Return a new dict with additional and/or replaced entries.

    Works on either a ``FrozenDict`` (delegating to ``FrozenDict.copy``) or a
    regular ``dict``. A regular dict is deep-copied first and then updated, so
    the input is left untouched.

    Example::

      >>> variables = FrozenDict({'params': {...}, 'batch_stats': {...}})
      >>> new_variables = copy(variables, {'additional_entries': 1})

    Args:
      x: the dictionary to be copied and updated
      add_or_replace: dictionary of key-value pairs to add or replace in the dict x

    Returns:
      A new dict with the additional and/or replaced entries.

    Raises:
      TypeError: if ``x`` is neither a FrozenDict nor a dict.
    """
    if isinstance(x, FrozenDict):
        return x.copy(add_or_replace)
    if not isinstance(x, dict):
        raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')
    # An identity tree_map rebuilds the nested containers, which deep-copies x.
    duplicate = jax.tree_util.tree_map(lambda leaf: leaf, x)
    duplicate.update(add_or_replace)
    return duplicate
449
+
450
+
451
def pop(
    x: FrozenDict | dict[str, Any], key: str
) -> tuple[FrozenDict | dict[str, Any], Any]:
    """Return a copy of ``x`` without ``key``, together with the removed value.

    Works on either a ``FrozenDict`` (delegating to ``FrozenDict.pop``) or a
    regular ``dict``. A regular dict is deep-copied first, so the input is
    left untouched.

    Example::

      >>> variables = FrozenDict({'params': {...}, 'batch_stats': {...}})
      >>> new_variables, params = pop(variables, 'params')

    Args:
      x: the dictionary to remove the entry from
      key: the key to remove from the dict

    Returns:
      A pair with the new dict and the removed value.

    Raises:
      TypeError: if ``x`` is neither a FrozenDict nor a dict.
      KeyError: if ``key`` is missing.
    """
    if isinstance(x, FrozenDict):
        return x.pop(key)
    if not isinstance(x, dict):
        raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')
    # An identity tree_map rebuilds the nested containers, which deep-copies x.
    remaining = jax.tree_util.tree_map(lambda leaf: leaf, x)
    removed = remaining.pop(key)
    return remaining, removed
480
+
481
+
482
def pretty_repr(x: Any, num_spaces: int = 4) -> str:
    """Returns an indented representation of the nested dictionary.

    This utility works on either a FrozenDict or a regular dict and mimics the
    behavior of ``FrozenDict.pretty_repr``. For any other object it returns
    ``repr(x)``.

    Args:
      x: the dictionary to be represented
      num_spaces: the number of space characters in each indentation level.
        It is honored for both FrozenDict and dict inputs.

    Returns:
      An indented string representation of the nested dictionary.
    """
    if isinstance(x, FrozenDict):
        # Forward the indentation width instead of using the method's default.
        return x.pretty_repr(num_spaces)

    def pretty_dict(obj):
        if not isinstance(obj, dict):
            return repr(obj)
        rep = ''.join(f'{key}: {pretty_dict(val)},\n' for key, val in obj.items())
        if rep:
            return '{\n' + _indent(rep, num_spaces) + '}'
        return '{}'

    return pretty_dict(x)
511
+
512
+
513
class FrozenKeysView(collections.abc.KeysView):
    """Keys view of a frozen mapping whose ``repr`` lists the keys."""

    def __repr__(self):
        return 'frozen_dict_keys({})'.format(list(self))
518
+
519
+
520
class FrozenValuesView(collections.abc.ValuesView):
    """Values view of a frozen mapping whose ``repr`` lists the values."""

    def __repr__(self):
        values = [value for value in self]
        return 'frozen_dict_values(%s)' % (values,)
@@ -0,0 +1,75 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import jax
18
+ import jax.core
19
+ from jax.interpreters import partial_eval as pe
20
+
21
+ from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr
22
+
23
__all__ = ['StateJaxTracer']
26
+
27
+
28
def new_jax_trace():
    """Create a ``pe.DynamicJaxprTrace`` for the last entry of JAX's legacy trace stack.

    Returns:
      A ``(frame, trace)`` pair. ``frame`` is the last element of that main
      trace's ``jaxpr_stack``. ``trace`` is a new ``pe.DynamicJaxprTrace``
      built for the main trace at ``jax.core.cur_sublevel()``.

    Raises:
      NotImplementedError: if the installed JAX does not expose the legacy
        private APIs this relies on (``jax.core.thread_local_state`` and
        ``jax.core.cur_sublevel``).
    """
    # These are private, legacy JAX internals. Fail with a clear message rather
    # than an obscure AttributeError when they are absent.
    thread_local_state = getattr(jax.core, 'thread_local_state', None)
    cur_sublevel = getattr(jax.core, 'cur_sublevel', None)
    if thread_local_state is None or cur_sublevel is None:
        raise NotImplementedError(
            'new_jax_trace requires the legacy JAX trace-stack API '
            '(jax.core.thread_local_state / jax.core.cur_sublevel), which is '
            f'not available in JAX {jax.__version__}.'
        )
    main = thread_local_state.trace_state.trace_stack.stack[-1]
    frame = main.jaxpr_stack[-1]
    trace = pe.DynamicJaxprTrace(main, cur_sublevel())
    return frame, trace
33
+
34
+
35
def current_jax_trace():
    """Return an object identifying the JAX trace that is active right now.

    For JAX newer than 0.4.33 this is ``jax.core.get_opaque_trace_state`` with
    the ``"nnx"`` convention. Older versions return the ``dynamic`` entry of
    the thread-local trace stack.
    """
    if jax.__version_info__ > (0, 4, 33):
        return jax.core.get_opaque_trace_state(convention="nnx")
    return jax.core.thread_local_state.trace_state.trace_stack.dynamic
40
+
41
+
42
class StateJaxTracer(PrettyRepr):
    """Records the JAX trace that is current when the instance is created.

    ``is_valid`` tells whether that recorded trace is still the current one.
    Two instances compare equal when they recorded the same trace.
    """
    __slots__ = ['_jax_trace']

    def __init__(self):
        self._jax_trace = current_jax_trace()

    @property
    def jax_trace(self):
        return self._jax_trace

    @staticmethod
    def _same_trace(recorded, other) -> bool:
        # JAX <= 0.4.33 compares trace objects by identity; newer versions
        # compare the opaque trace state by equality.
        if jax.__version_info__ <= (0, 4, 33):
            return recorded is other
        return recorded == other

    def is_valid(self) -> bool:
        return self._same_trace(self._jax_trace, current_jax_trace())

    def __eq__(self, other):
        if not isinstance(other, StateJaxTracer):
            return False
        return self._same_trace(self._jax_trace, other._jax_trace)

    def __pretty_repr__(self):
        yield PrettyType(type(self).__name__)
        yield PrettyAttr('jax_trace', self._jax_trace)

    def __treescope_repr__(self, path, subtree_renderer):
        import treescope  # type: ignore[import-not-found,import-untyped]
        attributes = {'jax_trace': self._jax_trace}
        return treescope.repr_lib.render_object_constructor(
            object_type=type(self),
            attributes=attributes,
            path=path,
            subtree_renderer=subtree_renderer,
        )
@@ -13,35 +13,35 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ from __future__ import annotations
16
17
 
17
18
__all__ = ['display']
20
21
 
21
-
22
22
  import importlib.util
23
23
 
24
24
# True when the optional ``treescope`` package can be imported.
treescope_installed = importlib.util.find_spec('treescope') is not None

try:
    from IPython import get_ipython

    in_ipython = get_ipython() is not None
except ImportError:
    # IPython is not installed, so this cannot be an IPython session.
    in_ipython = False


def display(*args):
    """Show each argument, preferring the Treescope pretty-printer.

    Treescope is used only when it is installed *and* the code runs under
    IPython. Otherwise every object is written with ``print``, one per line.
    """
    if treescope_installed and in_ipython:
        import treescope  # type: ignore[import-not-found,import-untyped]

        for obj in args:
            treescope.display(obj, ignore_exceptions=True, autovisualize=True)
        return

    for obj in args:
        print(obj)