brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -58
- brainstate/_compatible_import.py +340 -148
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +45 -55
- brainstate/_state.py +1652 -1605
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -563
- brainstate/environ_test.py +1223 -62
- brainstate/graph/__init__.py +22 -29
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1433 -365
- brainstate/mixin_test.py +1017 -77
- brainstate/nn/__init__.py +137 -135
- brainstate/nn/_activations.py +1100 -808
- brainstate/nn/_activations_test.py +354 -331
- brainstate/nn/_collective_ops.py +633 -514
- brainstate/nn/_collective_ops_test.py +774 -43
- brainstate/nn/_common.py +226 -178
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +2010 -501
- brainstate/nn/_conv_test.py +849 -238
- brainstate/nn/_delay.py +575 -588
- brainstate/nn/_delay_test.py +243 -238
- brainstate/nn/_dropout.py +618 -426
- brainstate/nn/_dropout_test.py +477 -100
- brainstate/nn/_dynamics.py +1267 -1343
- brainstate/nn/_dynamics_test.py +67 -78
- brainstate/nn/_elementwise.py +1298 -1119
- brainstate/nn/_elementwise_test.py +830 -169
- brainstate/nn/_embedding.py +408 -58
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
- brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
- brainstate/nn/_exp_euler.py +254 -92
- brainstate/nn/_exp_euler_test.py +377 -35
- brainstate/nn/_linear.py +744 -424
- brainstate/nn/_linear_test.py +475 -107
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +384 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -975
- brainstate/nn/_normalizations_test.py +699 -73
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +2239 -1177
- brainstate/nn/_poolings_test.py +953 -217
- brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +216 -89
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +809 -553
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
- brainstate/random/__init__.py +270 -24
- brainstate/random/_rand_funs.py +3938 -3616
- brainstate/random/_rand_funs_test.py +640 -567
- brainstate/random/_rand_seed.py +675 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1409
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
- brainstate/{augment → transform}/_autograd.py +1025 -778
- brainstate/{augment → transform}/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +220 -220
- brainstate/{compile → transform}/_error_if.py +94 -92
- brainstate/{compile → transform}/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +38 -38
- brainstate/{compile → transform}/_jit.py +399 -346
- brainstate/{compile → transform}/_jit_test.py +143 -143
- brainstate/{compile → transform}/_loop_collect_return.py +675 -536
- brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
- brainstate/{compile → transform}/_loop_no_collection.py +283 -184
- brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +255 -202
- brainstate/{augment → transform}/_random.py +171 -151
- brainstate/{compile → transform}/_unvmap.py +256 -159
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +837 -304
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +27 -50
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +945 -469
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +910 -523
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,602 @@
|
|
1
|
+
"""
|
2
|
+
Comprehensive tests for the struct module.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import pickle
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
import jax
|
9
|
+
import jax.numpy as jnp
|
10
|
+
import jax.tree_util
|
11
|
+
import pytest
|
12
|
+
|
13
|
+
# Import the modules to test
|
14
|
+
from brainstate.util import (
|
15
|
+
field,
|
16
|
+
dataclass,
|
17
|
+
PyTreeNode,
|
18
|
+
FrozenDict,
|
19
|
+
freeze,
|
20
|
+
unfreeze,
|
21
|
+
copy,
|
22
|
+
pop,
|
23
|
+
pretty_repr,
|
24
|
+
)
|
25
|
+
|
26
|
+
|
27
|
+
class TestField:
    """Checks for the ``field`` helper and the metadata it records."""

    def test_field_with_pytree_node_true(self):
        """An explicit ``pytree_node=True`` must show up in the metadata."""
        spec = field(pytree_node=True)
        assert spec.metadata['pytree_node'] is True

    def test_field_with_pytree_node_false(self):
        """An explicit ``pytree_node=False`` must show up in the metadata."""
        spec = field(pytree_node=False)
        assert spec.metadata['pytree_node'] is False

    def test_field_with_default(self):
        """A default value is kept, and ``pytree_node`` is expected to be True."""
        spec = field(default=42)
        assert spec.default == 42
        # No pytree_node argument was given; the test expects True here.
        assert spec.metadata['pytree_node'] is True

    def test_field_with_metadata(self):
        """User-supplied metadata must survive next to the pytree flag."""
        spec = field(pytree_node=False, metadata={'custom': 'data'})
        recorded = spec.metadata
        assert recorded['pytree_node'] is False
        assert recorded['custom'] == 'data'
|
51
|
+
|
52
|
+
|
53
|
+
class TestDataclass:
    """Checks for the ``dataclass`` decorator."""

    def test_basic_dataclass(self):
        """Positional construction populates the declared fields."""

        @dataclass
        class Point:
            x: float
            y: float

        point = Point(1.0, 2.0)
        assert point.x == 1.0
        assert point.y == 2.0

    def test_dataclass_is_frozen(self):
        """Assigning to a field of an instance must raise."""

        @dataclass
        class Point:
            x: float
            y: float

        point = Point(1.0, 2.0)
        # Should be immutable
        with pytest.raises(Exception):
            point.x = 3.0

    def test_dataclass_replace_method(self):
        """``replace`` returns an updated copy and leaves the source alone."""

        @dataclass
        class Point:
            x: float
            y: float

        original = Point(1.0, 2.0)
        updated = original.replace(x=3.0)

        assert original.x == 1.0
        assert updated.x == 3.0
        assert updated.y == 2.0

    def test_dataclass_with_defaults(self):
        """Declared defaults apply and can be overridden by keyword."""

        @dataclass
        class Config:
            learning_rate: float = 0.001
            batch_size: int = 32
            name: str = field(default="default", pytree_node=False)

        defaults = Config()
        assert defaults.learning_rate == 0.001
        assert defaults.batch_size == 32
        assert defaults.name == "default"

        overridden = Config(learning_rate=0.01)
        assert overridden.learning_rate == 0.01

    def test_dataclass_pytree_behavior(self):
        """A decorated class takes part in ``jax.tree_util`` operations."""

        @dataclass
        class Model:
            weights: jax.Array
            bias: jax.Array
            name: str = field(pytree_node=False, default="model")

        weights = jnp.ones((3, 3))
        bias = jnp.zeros(3)
        model = Model(weights=weights, bias=bias)

        # tree_map: array fields are doubled, the non-pytree field is untouched.
        doubled = jax.tree_util.tree_map(lambda leaf: leaf * 2, model)
        assert jnp.allclose(doubled.weights, weights * 2)
        assert jnp.allclose(doubled.bias, bias * 2)
        assert doubled.name == "model"

        # tree_leaves: only weights and bias are expected as leaves.
        leaf_list = jax.tree_util.tree_leaves(model)
        assert len(leaf_list) == 2

    def test_dataclass_with_jax_transformations(self):
        """Instances can be passed through ``jax.jit`` and ``jax.grad``."""

        @dataclass
        class Linear:
            weight: jax.Array
            bias: jax.Array

        layer = Linear(weight=jnp.ones((4, 3)), bias=jnp.zeros(4))

        # jit: the instance is passed as a traced argument.
        @jax.jit
        def apply(layer, x):
            projected = jnp.dot(x, layer.weight.T)
            return projected + layer.bias

        inputs = jnp.ones(3)
        outputs = apply(layer, inputs)
        assert outputs.shape == (4,)

        # grad: gradients come back in the same structure as the input.
        def loss_fn(layer):
            weight_term = jnp.sum(layer.weight ** 2)
            bias_term = jnp.sum(layer.bias ** 2)
            return weight_term + bias_term

        grads = jax.grad(loss_fn)(layer)
        assert grads.weight.shape == layer.weight.shape
        assert grads.bias.shape == layer.bias.shape

    def test_dataclass_no_double_decoration(self):
        """Applying the decorator twice must still give a usable class."""

        # The inner decoration should not cause issues for the outer one.
        @dataclass
        @dataclass
        class Point:
            x: float
            y: float

        point = Point(1.0, 2.0)
        assert point.x == 1.0
        assert hasattr(Point, '_brainstate_dataclass')
|
177
|
+
|
178
|
+
|
179
|
+
class TestPyTreeNode:
    """Checks for classes derived from ``PyTreeNode``."""

    def test_pytreenode_subclass(self):
        """A subclass accepts keyword construction and applies defaults."""

        class Layer(PyTreeNode):
            weights: jax.Array
            bias: jax.Array
            activation: str = field(pytree_node=False, default="relu")

        layer = Layer(weights=jnp.ones((4, 4)), bias=jnp.zeros(4))

        assert layer.activation == "relu"
        assert jnp.allclose(layer.weights, jnp.ones((4, 4)))

    def test_pytreenode_is_frozen(self):
        """Assigning to a field of an instance must raise."""

        class Layer(PyTreeNode):
            weights: jax.Array

        layer = Layer(weights=jnp.ones(3))
        with pytest.raises(Exception):
            layer.weights = jnp.zeros(3)

    def test_pytreenode_replace(self):
        """``replace`` swaps the named field and keeps the others."""

        class Layer(PyTreeNode):
            weights: jax.Array
            bias: jax.Array

        original = Layer(weights=jnp.ones(3), bias=jnp.zeros(3))
        updated = original.replace(weights=jnp.ones(3) * 2)

        assert jnp.allclose(updated.weights, jnp.ones(3) * 2)
        assert jnp.allclose(updated.bias, jnp.zeros(3))

    def test_pytreenode_with_jax(self):
        """Nested nodes work with ``tree_map`` and ``jax.grad``."""

        class MLP(PyTreeNode):
            layer1: Any
            layer2: Any

        class Linear(PyTreeNode):
            weight: jax.Array
            bias: jax.Array

        first = Linear(weight=jnp.ones((4, 3)), bias=jnp.zeros(4))
        second = Linear(weight=jnp.ones((2, 4)), bias=jnp.zeros(2))
        mlp = MLP(layer1=first, layer2=second)

        # tree_map reaches the arrays inside the nested nodes.
        doubled = jax.tree_util.tree_map(lambda leaf: leaf * 2, mlp)
        assert jnp.allclose(doubled.layer1.weight, mlp.layer1.weight * 2)

        # grad returns a structure that mirrors the model.
        def loss_fn(model):
            first_term = jnp.sum(model.layer1.weight ** 2)
            second_term = jnp.sum(model.layer2.weight ** 2)
            return first_term + second_term

        grads = jax.grad(loss_fn)(mlp)
        assert grads.layer1.weight.shape == mlp.layer1.weight.shape
|
246
|
+
|
247
|
+
|
248
|
+
class TestFrozenDict:
    """Checks for the ``FrozenDict`` mapping type."""

    def test_frozendict_creation(self):
        """Construction from a dict, from keywords, and from item pairs."""
        # From dict
        from_dict = FrozenDict({'a': 1, 'b': 2})
        assert from_dict['a'] == 1
        assert from_dict['b'] == 2

        # From kwargs
        from_kwargs = FrozenDict(a=1, b=2)
        assert from_kwargs['a'] == 1

        # From items
        from_items = FrozenDict([('a', 1), ('b', 2)])
        assert from_items['a'] == 1

    def test_frozendict_immutability(self):
        """Item assignment and deletion are expected to raise ``TypeError``."""
        frozen = FrozenDict({'a': 1})

        # Overwriting an existing key.
        with pytest.raises(TypeError):
            frozen['a'] = 2

        # Adding a new key.
        with pytest.raises(TypeError):
            frozen['c'] = 3

        # Deleting a key.
        with pytest.raises(TypeError):
            del frozen['a']

    def test_frozendict_basic_operations(self):
        """Membership, length, iteration and ``get``."""
        frozen = FrozenDict({'a': 1, 'b': 2, 'c': 3})

        # Contains
        assert 'a' in frozen
        assert 'd' not in frozen

        # Length
        assert len(frozen) == 3

        # Iteration
        seen_keys = list(frozen)
        assert set(seen_keys) == {'a', 'b', 'c'}

        # Get
        assert frozen.get('a') == 1
        assert frozen.get('d') is None
        assert frozen.get('d', 10) == 10

    def test_frozendict_views(self):
        """Keys, values and items views, including their reprs."""
        frozen = FrozenDict({'a': 1, 'b': 2})

        # Keys view
        key_view = frozen.keys()
        assert set(key_view) == {'a', 'b'}
        assert 'FrozenDict.keys' in repr(key_view)

        # Values view
        value_view = frozen.values()
        assert set(value_view) == {1, 2}
        assert 'FrozenDict.values' in repr(value_view)

        # Items view
        pairs = list(frozen.items())
        assert len(pairs) == 2
        assert ('a', 1) in pairs

    def test_frozendict_copy(self):
        """``copy`` with and without updates; the source stays unchanged."""
        source = FrozenDict({'a': 1, 'b': 2})

        # Copy without changes
        plain_copy = source.copy()
        assert plain_copy == source
        assert plain_copy is not source

        # Copy with updates
        patched = source.copy({'c': 3, 'a': 10})
        assert patched['a'] == 10
        assert patched['b'] == 2
        assert patched['c'] == 3
        assert source['a'] == 1  # Original unchanged

    def test_frozendict_pop(self):
        """``pop`` returns a new mapping together with the removed value."""
        source = FrozenDict({'a': 1, 'b': 2, 'c': 3})

        remainder, removed = source.pop('b')
        assert removed == 2
        assert 'b' not in remainder
        assert len(remainder) == 2
        assert 'b' in source  # Original unchanged

        # Pop non-existent key
        with pytest.raises(KeyError):
            source.pop('d')

    def test_frozendict_nested(self):
        """Nested plain dicts are reachable and come back as ``FrozenDict``."""
        inner = {'c': 2, 'd': {'e': 3}}
        frozen = FrozenDict({'a': 1, 'b': inner})

        # Access nested values
        assert frozen['b']['c'] == 2
        assert frozen['b']['d']['e'] == 3

        # Nested values are also FrozenDict
        assert isinstance(frozen['b'], FrozenDict)
        assert isinstance(frozen['b']['d'], FrozenDict)

    def test_frozendict_hash(self):
        """Equal instances hash alike and collapse inside a set."""
        first = FrozenDict({'a': 1, 'b': 2})
        twin = FrozenDict({'a': 1, 'b': 2})
        other = FrozenDict({'a': 1, 'b': 3})

        # Equal dicts have same hash
        assert hash(first) == hash(twin)

        # Can be used in sets
        unique = {first, twin, other}
        assert len(unique) == 2

    def test_frozendict_equality(self):
        """Equality against other instances, plain dicts and non-mappings."""
        first = FrozenDict({'a': 1, 'b': 2})
        twin = FrozenDict({'a': 1, 'b': 2})
        other = FrozenDict({'a': 1, 'b': 3})
        plain = {'a': 1, 'b': 2}

        assert first == twin
        assert first != other
        assert first == plain
        assert first != "not a dict"

    def test_frozendict_pickle(self):
        """A pickle round trip yields an equal mapping."""
        frozen = FrozenDict({'a': 1, 'b': {'c': 2}})

        # Pickle and unpickle
        payload = pickle.dumps(frozen)
        restored = pickle.loads(payload)

        assert frozen == restored
        assert frozen['b']['c'] == restored['b']['c']

    def test_frozendict_pretty_repr(self):
        """``pretty_repr`` mentions the type name and the stored entries."""
        frozen = FrozenDict({'a': 1, 'b': {'c': 2}})
        rendered = frozen.pretty_repr()

        assert 'FrozenDict' in rendered
        assert "'a': 1" in rendered
        assert "'c': 2" in rendered

    def test_frozendict_as_pytree(self):
        """``FrozenDict`` takes part in ``jax.tree_util`` operations."""
        frozen = FrozenDict({'a': jnp.ones(3), 'b': jnp.zeros(2)})

        # Tree map
        doubled = jax.tree_util.tree_map(lambda leaf: leaf * 2, frozen)
        assert jnp.allclose(doubled['a'], jnp.ones(3) * 2)
        assert jnp.allclose(doubled['b'], jnp.zeros(2))

        # Tree leaves
        leaf_list = jax.tree_util.tree_leaves(frozen)
        assert len(leaf_list) == 2

        # Tree flatten and unflatten
        flat, treedef = jax.tree_util.tree_flatten(frozen)
        rebuilt = jax.tree_util.tree_unflatten(treedef, flat)
        assert frozen == rebuilt
|
425
|
+
|
426
|
+
|
427
|
+
class TestUtilityFunctions:
    """Checks for the module-level helpers that accompany ``FrozenDict``."""

    def test_freeze(self):
        """``freeze`` converts nested dicts and returns frozen input as is."""
        # Regular dict
        plain = {'a': 1, 'b': {'c': 2}}
        frozen = freeze(plain)
        assert isinstance(frozen, FrozenDict)
        assert frozen['a'] == 1
        assert isinstance(frozen['b'], FrozenDict)

        # Already frozen: the very same object is expected back.
        refrozen = freeze(frozen)
        assert refrozen is frozen

    def test_unfreeze(self):
        """``unfreeze`` on a ``FrozenDict``, a plain dict and a non-dict."""
        # FrozenDict
        frozen = FrozenDict({'a': 1, 'b': {'c': 2}})
        thawed = unfreeze(frozen)
        assert isinstance(thawed, dict)
        assert not isinstance(thawed, FrozenDict)
        assert thawed['a'] == 1
        assert isinstance(thawed['b'], dict)

        # Regular dict
        plain = {'a': 1}
        duplicate = unfreeze(plain)
        assert duplicate == plain
        assert duplicate is not plain  # Should be a copy

        # Non-dict
        assert unfreeze(42) == 42

    def test_copy_function(self):
        """``copy`` keeps the container type and rejects non-mappings."""
        # FrozenDict
        frozen_src = FrozenDict({'a': 1})
        frozen_out = copy(frozen_src, {'b': 2})
        assert isinstance(frozen_out, FrozenDict)
        assert frozen_out['a'] == 1
        assert frozen_out['b'] == 2

        # Regular dict
        plain_src = {'a': 1}
        plain_out = copy(plain_src, {'b': 2})
        assert isinstance(plain_out, dict)
        assert not isinstance(plain_out, FrozenDict)
        assert plain_out['a'] == 1
        assert plain_out['b'] == 2
        assert plain_src == {'a': 1}  # Original unchanged

        # Invalid type
        with pytest.raises(TypeError):
            copy([1, 2, 3])

    def test_pop_function(self):
        """``pop`` returns ``(new_mapping, value)`` and rejects non-mappings."""
        # FrozenDict
        frozen_src = FrozenDict({'a': 1, 'b': 2})
        frozen_rest, removed = pop(frozen_src, 'a')
        assert isinstance(frozen_rest, FrozenDict)
        assert removed == 1
        assert 'a' not in frozen_rest
        assert 'a' in frozen_src

        # Regular dict
        plain_src = {'a': 1, 'b': 2}
        plain_rest, removed = pop(plain_src, 'a')
        assert isinstance(plain_rest, dict)
        assert removed == 1
        assert 'a' not in plain_rest
        assert 'a' in plain_src

        # Invalid type
        with pytest.raises(TypeError):
            pop([1, 2, 3], 0)

    def test_pretty_repr_function(self):
        """``pretty_repr`` on a ``FrozenDict``, a plain dict and a list."""
        # FrozenDict
        frozen = FrozenDict({'a': 1, 'b': {'c': 2}})
        rendered = pretty_repr(frozen)
        assert 'FrozenDict' in rendered

        # Regular dict
        plain = {'a': 1, 'b': {'c': 2}}
        rendered = pretty_repr(plain)
        assert 'a' in rendered
        assert 'c' in rendered

        # Other type
        rendered = pretty_repr([1, 2, 3])
        assert rendered == "[1, 2, 3]"
|
522
|
+
|
523
|
+
|
524
|
+
class TestIntegration:
    """End-to-end checks that combine dataclasses, nodes and ``FrozenDict``."""

    def test_nested_structures(self):
        """A node holding a dataclass holding a ``FrozenDict`` maps cleanly."""

        @dataclass
        class Config:
            hyperparams: FrozenDict
            metadata: dict = field(pytree_node=False)

        class Model(PyTreeNode):
            config: Config
            weights: jax.Array

        config = Config(
            hyperparams=FrozenDict({'lr': 0.001, 'batch_size': 32}),
            metadata={'version': '1.0'},
        )
        model = Model(config=config, weights=jnp.ones((4, 4)))

        def double_arrays(leaf):
            # Only array leaves are scaled; anything else passes through.
            return leaf * 2 if isinstance(leaf, jax.Array) else leaf

        # Test tree operations
        mapped = jax.tree_util.tree_map(double_arrays, model)
        assert jnp.allclose(mapped.weights, model.weights * 2)
        assert mapped.config.hyperparams['lr'] == 0.001
        assert mapped.config.metadata['version'] == '1.0'

    def test_jax_transformations_integration(self):
        """A dataclass with ``FrozenDict`` params under ``jit`` and ``vmap``."""

        @dataclass
        class State:
            params: FrozenDict
            step: int = field(pytree_node=False, default=0)

        initial_params = FrozenDict({
            'w': jnp.ones((3, 3)),
            'b': jnp.zeros(3),
        })
        state = State(params=initial_params)

        # JIT compilation
        @jax.jit
        def update(state, grad):
            stepped = jax.tree_util.tree_map(
                lambda p, g: p - 0.01 * g,
                state.params,
                grad,
            )
            return state.replace(params=stepped, step=state.step + 1)

        grad = FrozenDict({'w': jnp.ones((3, 3)), 'b': jnp.ones(3)})
        new_state = update(state, grad)
        assert new_state.step == 1
        assert jnp.allclose(new_state.params['w'], state.params['w'] - 0.01)

        # VMAP
        @jax.vmap
        def batch_process(params, x):
            return jnp.dot(x, params['w']) + params['b']

        # Stack each parameter with a doubled copy to form a batch of two.
        batch_params = jax.tree_util.tree_map(
            lambda leaf: jnp.stack([leaf, leaf * 2]),
            state.params,
        )
        batch_x = jnp.ones((2, 3))
        result = batch_process(batch_params, batch_x)
        assert result.shape == (2, 3)
|
599
|
+
|
600
|
+
|
601
|
+
# Allow running this test module directly as a script.
if __name__ == "__main__":
    cli_args = [__file__, "-v"]
    pytest.main(cli_args)
|