brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,602 @@
1
+ """
2
+ Comprehensive tests for the struct module.
3
+ """
4
+
5
+ import pickle
6
+ from typing import Any
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import jax.tree_util
11
+ import pytest
12
+
13
+ # Import the modules to test
14
+ from brainstate.util import (
15
+ field,
16
+ dataclass,
17
+ PyTreeNode,
18
+ FrozenDict,
19
+ freeze,
20
+ unfreeze,
21
+ copy,
22
+ pop,
23
+ pretty_repr,
24
+ )
25
+
26
+
27
class TestField:
    """Unit tests for ``field``: the ``pytree_node`` flag and metadata handling."""

    def test_field_with_pytree_node_true(self):
        """An explicit ``pytree_node=True`` is recorded in the field metadata."""
        meta = field(pytree_node=True).metadata
        assert meta['pytree_node'] is True

    def test_field_with_pytree_node_false(self):
        """An explicit ``pytree_node=False`` is recorded in the field metadata."""
        meta = field(pytree_node=False).metadata
        assert meta['pytree_node'] is False

    def test_field_with_default(self):
        """A default value is kept and ``pytree_node`` falls back to ``True``."""
        fld = field(default=42)
        assert fld.default == 42
        assert fld.metadata['pytree_node'] is True

    def test_field_with_metadata(self):
        """User-supplied metadata is kept next to the ``pytree_node`` flag."""
        meta = field(pytree_node=False, metadata={'custom': 'data'}).metadata
        assert meta['pytree_node'] is False
        assert meta['custom'] == 'data'
51
+
52
+
53
class TestDataclass:
    """Tests for the ``dataclass`` decorator: construction, immutability,
    ``replace``, defaults and JAX pytree integration."""

    def test_basic_dataclass(self):
        """Positional construction stores each field as an attribute."""

        @dataclass
        class Point:
            x: float
            y: float

        p = Point(1.0, 2.0)
        assert p.x == 1.0
        assert p.y == 2.0

    def test_dataclass_is_frozen(self):
        """Test that dataclasses are frozen by default."""

        @dataclass
        class Point:
            x: float
            y: float

        p = Point(1.0, 2.0)
        # Standard frozen dataclasses raise dataclasses.FrozenInstanceError, a
        # subclass of AttributeError. Catching bare ``Exception`` would also
        # accept unrelated failures and hide real bugs.
        # NOTE(review): assumes ``dataclass`` is built on
        # ``dataclasses.dataclass(frozen=True)`` -- confirm in struct.py.
        with pytest.raises(AttributeError):
            p.x = 3.0
        # The rejected assignment must leave the instance unchanged.
        assert p.x == 1.0

    def test_dataclass_replace_method(self):
        """``replace`` returns an updated copy and leaves the original intact."""

        @dataclass
        class Point:
            x: float
            y: float

        p1 = Point(1.0, 2.0)
        p2 = p1.replace(x=3.0)
        assert p1.x == 1.0
        assert p2.x == 3.0
        assert p2.y == 2.0

    def test_dataclass_with_defaults(self):
        """Field defaults, including ``field(default=...)``, are applied."""

        @dataclass
        class Config:
            learning_rate: float = 0.001
            batch_size: int = 32
            name: str = field(default="default", pytree_node=False)

        c1 = Config()
        assert c1.learning_rate == 0.001
        assert c1.batch_size == 32
        assert c1.name == "default"

        c2 = Config(learning_rate=0.01)
        assert c2.learning_rate == 0.01
        # Fields that were not passed explicitly keep their defaults.
        assert c2.batch_size == 32
        assert c2.name == "default"

    def test_dataclass_pytree_behavior(self):
        """Only ``pytree_node=True`` fields are leaves; static fields ride along."""

        @dataclass
        class Model:
            weights: jax.Array
            bias: jax.Array
            name: str = field(pytree_node=False, default="model")

        weights = jnp.ones((3, 3))
        bias = jnp.zeros(3)
        model = Model(weights=weights, bias=bias)

        # tree_map touches only the array leaves; had ``name`` been a leaf,
        # ``x * 2`` would have turned it into "modelmodel".
        model2 = jax.tree_util.tree_map(lambda x: x * 2, model)
        assert jnp.allclose(model2.weights, weights * 2)
        assert jnp.allclose(model2.bias, bias * 2)
        assert model2.name == "model"

        leaves = jax.tree_util.tree_leaves(model)
        assert len(leaves) == 2  # only weights and bias

    def test_dataclass_with_jax_transformations(self):
        """Decorated instances pass through ``jax.jit`` and ``jax.grad``."""

        @dataclass
        class Linear:
            weight: jax.Array
            bias: jax.Array

        layer = Linear(
            weight=jnp.ones((4, 3)),
            bias=jnp.zeros(4)
        )

        @jax.jit
        def apply(layer, x):
            return jnp.dot(x, layer.weight.T) + layer.bias

        y = apply(layer, jnp.ones(3))
        assert y.shape == (4,)
        # Each output is a dot product of ones of length 3, plus a zero bias.
        assert jnp.allclose(y, 3.0)

        def loss_fn(layer):
            return jnp.sum(layer.weight ** 2) + jnp.sum(layer.bias ** 2)

        grads = jax.grad(loss_fn)(layer)
        assert grads.weight.shape == layer.weight.shape
        assert grads.bias.shape == layer.bias.shape
        # d/dw sum(w**2) = 2w and d/db sum(b**2) = 2b.
        assert jnp.allclose(grads.weight, 2 * layer.weight)
        assert jnp.allclose(grads.bias, 2 * layer.bias)

    def test_dataclass_no_double_decoration(self):
        """Applying ``dataclass`` twice is harmless (the decorator is idempotent)."""

        @dataclass
        @dataclass  # Should not cause issues
        class Point:
            x: float
            y: float

        p = Point(1.0, 2.0)
        assert p.x == 1.0
        # Decorated classes carry this marker attribute.
        assert hasattr(Point, '_brainstate_dataclass')
177
+
178
+
179
class TestPyTreeNode:
    """Tests for the ``PyTreeNode`` base class (dataclass behaviour via subclassing)."""

    def test_pytreenode_subclass(self):
        """Subclasses get a keyword constructor and honour field defaults."""

        class Layer(PyTreeNode):
            weights: jax.Array
            bias: jax.Array
            activation: str = field(pytree_node=False, default="relu")

        layer = Layer(
            weights=jnp.ones((4, 4)),
            bias=jnp.zeros(4)
        )
        assert layer.activation == "relu"
        assert jnp.allclose(layer.weights, jnp.ones((4, 4)))

    def test_pytreenode_is_frozen(self):
        """Test that PyTreeNode subclasses are frozen."""

        class Layer(PyTreeNode):
            weights: jax.Array

        layer = Layer(weights=jnp.ones(3))
        # Frozen dataclass instances raise dataclasses.FrozenInstanceError, a
        # subclass of AttributeError; a bare ``Exception`` would also accept
        # unrelated failures.
        # NOTE(review): assumes PyTreeNode subclasses are built with
        # ``dataclasses.dataclass(frozen=True)`` -- confirm in struct.py.
        with pytest.raises(AttributeError):
            layer.weights = jnp.zeros(3)
        # The rejected assignment must leave the field unchanged.
        assert jnp.allclose(layer.weights, jnp.ones(3))

    def test_pytreenode_replace(self):
        """``replace`` swaps the named field and keeps the others."""

        class Layer(PyTreeNode):
            weights: jax.Array
            bias: jax.Array

        layer1 = Layer(weights=jnp.ones(3), bias=jnp.zeros(3))
        layer2 = layer1.replace(weights=jnp.ones(3) * 2)
        assert jnp.allclose(layer2.weights, jnp.ones(3) * 2)
        assert jnp.allclose(layer2.bias, jnp.zeros(3))

    def test_pytreenode_with_jax(self):
        """Nested PyTreeNodes work with ``tree_map`` and ``jax.grad``."""

        class MLP(PyTreeNode):
            layer1: Any
            layer2: Any

        class Linear(PyTreeNode):
            weight: jax.Array
            bias: jax.Array

        mlp = MLP(
            layer1=Linear(weight=jnp.ones((4, 3)), bias=jnp.zeros(4)),
            layer2=Linear(weight=jnp.ones((2, 4)), bias=jnp.zeros(2))
        )

        mlp2 = jax.tree_util.tree_map(lambda x: x * 2, mlp)
        assert jnp.allclose(mlp2.layer1.weight, mlp.layer1.weight * 2)
        assert jnp.allclose(mlp2.layer2.weight, mlp.layer2.weight * 2)

        def loss_fn(model):
            return jnp.sum(model.layer1.weight ** 2) + jnp.sum(model.layer2.weight ** 2)

        grads = jax.grad(loss_fn)(mlp)
        assert grads.layer1.weight.shape == mlp.layer1.weight.shape
        # d/dw sum(w**2) = 2w.
        assert jnp.allclose(grads.layer1.weight, 2 * mlp.layer1.weight)
        # The loss does not depend on the biases, so their gradients are zero.
        assert jnp.allclose(grads.layer2.bias, 0.0)
246
+
247
+
248
class TestFrozenDict:
    """Tests for ``FrozenDict``: construction, immutability, the mapping API,
    functional ``copy``/``pop``, hashing, pickling and pytree support."""

    def test_frozendict_creation(self):
        """FrozenDict accepts a dict, keyword arguments, or key/value pairs."""
        # From dict
        fd1 = FrozenDict({'a': 1, 'b': 2})
        assert fd1['a'] == 1
        assert fd1['b'] == 2

        # From kwargs
        fd2 = FrozenDict(a=1, b=2)
        assert fd2['a'] == 1

        # From items
        fd3 = FrozenDict([('a', 1), ('b', 2)])
        assert fd3['a'] == 1

    def test_frozendict_immutability(self):
        """Item assignment, insertion and deletion all raise ``TypeError``."""
        fd = FrozenDict({'a': 1})

        with pytest.raises(TypeError):
            fd['a'] = 2

        with pytest.raises(TypeError):
            fd['c'] = 3

        with pytest.raises(TypeError):
            del fd['a']

    def test_frozendict_basic_operations(self):
        """Membership, ``len``, iteration and ``get`` behave like ``dict``."""
        fd = FrozenDict({'a': 1, 'b': 2, 'c': 3})

        assert 'a' in fd
        assert 'd' not in fd

        assert len(fd) == 3

        assert set(fd) == {'a', 'b', 'c'}

        assert fd.get('a') == 1
        assert fd.get('d') is None
        assert fd.get('d', 10) == 10

    def test_frozendict_views(self):
        """``keys``/``values``/``items`` expose the contents; views have custom reprs."""
        fd = FrozenDict({'a': 1, 'b': 2})

        keys = fd.keys()
        assert set(keys) == {'a', 'b'}
        assert 'FrozenDict.keys' in repr(keys)

        values = fd.values()
        assert set(values) == {1, 2}
        assert 'FrozenDict.values' in repr(values)

        items = list(fd.items())
        assert len(items) == 2
        assert ('a', 1) in items

    def test_frozendict_copy(self):
        """``copy`` returns a new FrozenDict, optionally with entries added or overridden."""
        fd1 = FrozenDict({'a': 1, 'b': 2})

        fd2 = fd1.copy()
        assert fd2 == fd1
        assert fd2 is not fd1

        fd3 = fd1.copy({'c': 3, 'a': 10})
        assert fd3['a'] == 10
        assert fd3['b'] == 2
        assert fd3['c'] == 3
        assert fd1['a'] == 1  # Original unchanged

    def test_frozendict_pop(self):
        """``pop`` returns ``(new_dict, value)`` and does not mutate the receiver."""
        fd1 = FrozenDict({'a': 1, 'b': 2, 'c': 3})

        fd2, value = fd1.pop('b')
        assert value == 2
        assert 'b' not in fd2
        assert len(fd2) == 2
        assert 'b' in fd1  # Original unchanged

        with pytest.raises(KeyError):
            fd1.pop('d')

    def test_frozendict_nested(self):
        """Nested plain dicts come back as FrozenDict on access."""
        fd = FrozenDict({
            'a': 1,
            'b': {'c': 2, 'd': {'e': 3}}
        })

        assert fd['b']['c'] == 2
        assert fd['b']['d']['e'] == 3

        assert isinstance(fd['b'], FrozenDict)
        assert isinstance(fd['b']['d'], FrozenDict)

    def test_frozendict_hash(self):
        """Equal FrozenDicts hash equally and deduplicate in sets."""
        fd1 = FrozenDict({'a': 1, 'b': 2})
        fd2 = FrozenDict({'a': 1, 'b': 2})
        fd3 = FrozenDict({'a': 1, 'b': 3})

        assert hash(fd1) == hash(fd2)

        s = {fd1, fd2, fd3}
        assert len(s) == 2

    def test_frozendict_equality(self):
        """FrozenDict compares by content, also against a plain dict."""
        fd1 = FrozenDict({'a': 1, 'b': 2})
        fd2 = FrozenDict({'a': 1, 'b': 2})
        fd3 = FrozenDict({'a': 1, 'b': 3})
        d = {'a': 1, 'b': 2}

        assert fd1 == fd2
        assert fd1 != fd3
        assert fd1 == d
        assert fd1 != "not a dict"

    def test_frozendict_pickle(self):
        """A pickle round trip preserves the contents, including nested values."""
        fd = FrozenDict({'a': 1, 'b': {'c': 2}})

        # Safe here: the payload is produced by this test, not external input.
        fd2 = pickle.loads(pickle.dumps(fd))

        assert fd == fd2
        assert fd['b']['c'] == fd2['b']['c']

    def test_frozendict_pretty_repr(self):
        """``pretty_repr`` names the type and shows nested entries."""
        fd = FrozenDict({'a': 1, 'b': {'c': 2}})
        repr_str = fd.pretty_repr()

        assert 'FrozenDict' in repr_str
        assert "'a': 1" in repr_str
        assert "'c': 2" in repr_str

    def test_frozendict_as_pytree(self):
        """FrozenDict is a JAX pytree whose leaves are its values."""
        # Every leaf is non-zero: doubling zeros yields zeros, so an all-zero
        # leaf could not reveal a tree_map that skipped it.
        fd = FrozenDict({'a': jnp.ones(3), 'b': jnp.arange(1.0, 3.0)})

        fd2 = jax.tree_util.tree_map(lambda x: x * 2, fd)
        assert isinstance(fd2, FrozenDict)
        assert jnp.allclose(fd2['a'], fd['a'] * 2)
        assert jnp.allclose(fd2['b'], fd['b'] * 2)

        leaves = jax.tree_util.tree_leaves(fd)
        assert len(leaves) == 2

        values, treedef = jax.tree_util.tree_flatten(fd)
        fd3 = jax.tree_util.tree_unflatten(treedef, values)
        # Compare leaf-wise instead of ``fd == fd3``. If FrozenDict.__eq__
        # delegates to dict comparison (presumably), that only succeeds because
        # unflatten reuses the same array objects and dict equality short-circuits
        # on identity. Equal but distinct arrays would raise on ``bool(a == b)``.
        assert isinstance(fd3, FrozenDict)
        assert set(fd3.keys()) == set(fd.keys())
        for key in fd:
            assert jnp.allclose(fd3[key], fd[key])
425
+
426
+
427
class TestUtilityFunctions:
    """Tests for the module-level helpers ``freeze``, ``unfreeze``, ``copy``,
    ``pop`` and ``pretty_repr``."""

    def test_freeze(self):
        """``freeze`` converts nested dicts and is a no-op on a FrozenDict."""
        frozen = freeze({'a': 1, 'b': {'c': 2}})
        assert isinstance(frozen, FrozenDict)
        assert frozen['a'] == 1
        assert isinstance(frozen['b'], FrozenDict)

        # Freezing an already-frozen mapping hands back the very same object.
        refrozen = freeze(frozen)
        assert refrozen is frozen

    def test_unfreeze(self):
        """``unfreeze`` yields plain dicts, copies dict input, passes other values through."""
        thawed = unfreeze(FrozenDict({'a': 1, 'b': {'c': 2}}))
        assert isinstance(thawed, dict)
        assert not isinstance(thawed, FrozenDict)
        assert thawed['a'] == 1
        assert isinstance(thawed['b'], dict)

        plain = {'a': 1}
        thawed_plain = unfreeze(plain)
        assert thawed_plain == plain
        assert thawed_plain is not plain  # a fresh copy, not the input itself

        assert unfreeze(42) == 42

    def test_copy_function(self):
        """``copy`` keeps the input's mapping type and rejects non-mappings."""
        frozen_copy = copy(FrozenDict({'a': 1}), {'b': 2})
        assert isinstance(frozen_copy, FrozenDict)
        assert frozen_copy['a'] == 1
        assert frozen_copy['b'] == 2

        source = {'a': 1}
        plain_copy = copy(source, {'b': 2})
        assert isinstance(plain_copy, dict)
        assert not isinstance(plain_copy, FrozenDict)
        assert plain_copy['a'] == 1
        assert plain_copy['b'] == 2
        assert source == {'a': 1}  # the input dict is left alone

        with pytest.raises(TypeError):
            copy([1, 2, 3])

    def test_pop_function(self):
        """``pop`` returns ``(remainder, value)`` without touching its input."""
        frozen_src = FrozenDict({'a': 1, 'b': 2})
        frozen_rest, popped = pop(frozen_src, 'a')
        assert isinstance(frozen_rest, FrozenDict)
        assert popped == 1
        assert 'a' not in frozen_rest
        assert 'a' in frozen_src

        plain_src = {'a': 1, 'b': 2}
        plain_rest, popped = pop(plain_src, 'a')
        assert isinstance(plain_rest, dict)
        assert popped == 1
        assert 'a' not in plain_rest
        assert 'a' in plain_src

        with pytest.raises(TypeError):
            pop([1, 2, 3], 0)

    def test_pretty_repr_function(self):
        """``pretty_repr`` formats mappings and falls back to ``repr`` otherwise."""
        assert 'FrozenDict' in pretty_repr(FrozenDict({'a': 1, 'b': {'c': 2}}))

        rendered = pretty_repr({'a': 1, 'b': {'c': 2}})
        assert 'a' in rendered
        assert 'c' in rendered

        assert pretty_repr([1, 2, 3]) == "[1, 2, 3]"
522
+
523
+
524
class TestIntegration:
    """End-to-end tests that combine ``dataclass``, ``PyTreeNode``,
    ``FrozenDict`` and JAX transformations."""

    def test_nested_structures(self):
        """A PyTreeNode holding a dataclass holding a FrozenDict maps only its arrays."""

        @dataclass
        class Config:
            hyperparams: FrozenDict
            metadata: dict = field(pytree_node=False)

        class Model(PyTreeNode):
            config: Config
            weights: jax.Array

        model = Model(
            config=Config(
                hyperparams=FrozenDict({'lr': 0.001, 'batch_size': 32}),
                metadata={'version': '1.0'},
            ),
            weights=jnp.ones((4, 4)),
        )

        def double_arrays(leaf):
            # The scalar hyperparameters are pytree leaves too; leave them as-is.
            if isinstance(leaf, jax.Array):
                return leaf * 2
            return leaf

        mapped = jax.tree_util.tree_map(double_arrays, model)
        assert jnp.allclose(mapped.weights, model.weights * 2)
        assert mapped.config.hyperparams['lr'] == 0.001
        assert mapped.config.metadata['version'] == '1.0'

    def test_jax_transformations_integration(self):
        """A dataclass state survives ``jax.jit`` updates and feeds ``jax.vmap``."""

        @dataclass
        class State:
            params: FrozenDict
            step: int = field(pytree_node=False, default=0)

        state = State(params=FrozenDict({'w': jnp.ones((3, 3)), 'b': jnp.zeros(3)}))

        @jax.jit
        def update(state, grad):
            # ``step`` is declared with pytree_node=False, so inside jit it is a
            # plain Python value and the increment is ordinary Python arithmetic.
            stepped = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, state.params, grad)
            return state.replace(params=stepped, step=state.step + 1)

        new_state = update(state, FrozenDict({'w': jnp.ones((3, 3)), 'b': jnp.ones(3)}))
        assert new_state.step == 1
        assert jnp.allclose(new_state.params['w'], state.params['w'] - 0.01)

        def process_one(params, x):
            return jnp.dot(x, params['w']) + params['b']

        batch_process = jax.vmap(process_one)
        batch_params = jax.tree_util.tree_map(
            lambda leaf: jnp.stack([leaf, leaf * 2]), state.params
        )
        result = batch_process(batch_params, jnp.ones((2, 3)))
        assert result.shape == (2, 3)
599
+
600
+
601
if __name__ == "__main__":
    # Propagate pytest's exit status so scripted runs (CI, shell) see failures;
    # a bare ``pytest.main(...)`` call would always exit with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))