brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,602 +1,602 @@
1
- """
2
- Comprehensive tests for the struct module.
3
- """
4
-
5
- import pickle
6
- from typing import Any
7
-
8
- import jax
9
- import jax.numpy as jnp
10
- import jax.tree_util
11
- import pytest
12
-
13
- # Import the modules to test
14
- from brainstate.util import (
15
- field,
16
- dataclass,
17
- PyTreeNode,
18
- FrozenDict,
19
- freeze,
20
- unfreeze,
21
- copy,
22
- pop,
23
- pretty_repr,
24
- )
25
-
26
-
27
- class TestField:
28
- """Test the field function."""
29
-
30
- def test_field_with_pytree_node_true(self):
31
- """Test field with pytree_node=True."""
32
- f = field(pytree_node=True)
33
- assert f.metadata['pytree_node'] is True
34
-
35
- def test_field_with_pytree_node_false(self):
36
- """Test field with pytree_node=False."""
37
- f = field(pytree_node=False)
38
- assert f.metadata['pytree_node'] is False
39
-
40
- def test_field_with_default(self):
41
- """Test field with default value."""
42
- f = field(default=42)
43
- assert f.default == 42
44
- assert f.metadata['pytree_node'] is True
45
-
46
- def test_field_with_metadata(self):
47
- """Test field preserves additional metadata."""
48
- f = field(pytree_node=False, metadata={'custom': 'data'})
49
- assert f.metadata['pytree_node'] is False
50
- assert f.metadata['custom'] == 'data'
51
-
52
-
53
- class TestDataclass:
54
- """Test the dataclass decorator."""
55
-
56
- def test_basic_dataclass(self):
57
- """Test basic dataclass creation."""
58
-
59
- @dataclass
60
- class Point:
61
- x: float
62
- y: float
63
-
64
- p = Point(1.0, 2.0)
65
- assert p.x == 1.0
66
- assert p.y == 2.0
67
-
68
- def test_dataclass_is_frozen(self):
69
- """Test that dataclasses are frozen by default."""
70
-
71
- @dataclass
72
- class Point:
73
- x: float
74
- y: float
75
-
76
- p = Point(1.0, 2.0)
77
- with pytest.raises(Exception): # Should be immutable
78
- p.x = 3.0
79
-
80
- def test_dataclass_replace_method(self):
81
- """Test the replace method."""
82
-
83
- @dataclass
84
- class Point:
85
- x: float
86
- y: float
87
-
88
- p1 = Point(1.0, 2.0)
89
- p2 = p1.replace(x=3.0)
90
- assert p1.x == 1.0
91
- assert p2.x == 3.0
92
- assert p2.y == 2.0
93
-
94
- def test_dataclass_with_defaults(self):
95
- """Test dataclass with default values."""
96
-
97
- @dataclass
98
- class Config:
99
- learning_rate: float = 0.001
100
- batch_size: int = 32
101
- name: str = field(default="default", pytree_node=False)
102
-
103
- c1 = Config()
104
- assert c1.learning_rate == 0.001
105
- assert c1.batch_size == 32
106
- assert c1.name == "default"
107
-
108
- c2 = Config(learning_rate=0.01)
109
- assert c2.learning_rate == 0.01
110
-
111
- def test_dataclass_pytree_behavior(self):
112
- """Test that dataclass works as JAX pytree."""
113
-
114
- @dataclass
115
- class Model:
116
- weights: jax.Array
117
- bias: jax.Array
118
- name: str = field(pytree_node=False, default="model")
119
-
120
- weights = jnp.ones((3, 3))
121
- bias = jnp.zeros(3)
122
- model = Model(weights=weights, bias=bias)
123
-
124
- # Test tree_map
125
- model2 = jax.tree_util.tree_map(lambda x: x * 2, model)
126
- assert jnp.allclose(model2.weights, weights * 2)
127
- assert jnp.allclose(model2.bias, bias * 2)
128
- assert model2.name == "model" # Should not be affected
129
-
130
- # Test tree_leaves
131
- leaves = jax.tree_util.tree_leaves(model)
132
- assert len(leaves) == 2 # Only weights and bias
133
-
134
- def test_dataclass_with_jax_transformations(self):
135
- """Test dataclass with JAX transformations."""
136
-
137
- @dataclass
138
- class Linear:
139
- weight: jax.Array
140
- bias: jax.Array
141
-
142
- layer = Linear(
143
- weight=jnp.ones((4, 3)),
144
- bias=jnp.zeros(4)
145
- )
146
-
147
- # Test with jit
148
- @jax.jit
149
- def apply(layer, x):
150
- return jnp.dot(x, layer.weight.T) + layer.bias
151
-
152
- x = jnp.ones(3)
153
- y = apply(layer, x)
154
- assert y.shape == (4,)
155
-
156
- # Test with grad
157
- def loss_fn(layer):
158
- return jnp.sum(layer.weight ** 2) + jnp.sum(layer.bias ** 2)
159
-
160
- grad_fn = jax.grad(loss_fn)
161
- grads = grad_fn(layer)
162
- assert grads.weight.shape == layer.weight.shape
163
- assert grads.bias.shape == layer.bias.shape
164
-
165
- def test_dataclass_no_double_decoration(self):
166
- """Test that dataclass decorator is idempotent."""
167
-
168
- @dataclass
169
- @dataclass # Should not cause issues
170
- class Point:
171
- x: float
172
- y: float
173
-
174
- p = Point(1.0, 2.0)
175
- assert p.x == 1.0
176
- assert hasattr(Point, '_brainstate_dataclass')
177
-
178
-
179
- class TestPyTreeNode:
180
- """Test the PyTreeNode base class."""
181
-
182
- def test_pytreenode_subclass(self):
183
- """Test creating a PyTreeNode subclass."""
184
-
185
- class Layer(PyTreeNode):
186
- weights: jax.Array
187
- bias: jax.Array
188
- activation: str = field(pytree_node=False, default="relu")
189
-
190
- layer = Layer(
191
- weights=jnp.ones((4, 4)),
192
- bias=jnp.zeros(4)
193
- )
194
- assert layer.activation == "relu"
195
- assert jnp.allclose(layer.weights, jnp.ones((4, 4)))
196
-
197
- def test_pytreenode_is_frozen(self):
198
- """Test that PyTreeNode subclasses are frozen."""
199
-
200
- class Layer(PyTreeNode):
201
- weights: jax.Array
202
-
203
- layer = Layer(weights=jnp.ones(3))
204
- with pytest.raises(Exception):
205
- layer.weights = jnp.zeros(3)
206
-
207
- def test_pytreenode_replace(self):
208
- """Test replace method on PyTreeNode."""
209
-
210
- class Layer(PyTreeNode):
211
- weights: jax.Array
212
- bias: jax.Array
213
-
214
- layer1 = Layer(weights=jnp.ones(3), bias=jnp.zeros(3))
215
- layer2 = layer1.replace(weights=jnp.ones(3) * 2)
216
- assert jnp.allclose(layer2.weights, jnp.ones(3) * 2)
217
- assert jnp.allclose(layer2.bias, jnp.zeros(3))
218
-
219
- def test_pytreenode_with_jax(self):
220
- """Test PyTreeNode with JAX transformations."""
221
-
222
- class MLP(PyTreeNode):
223
- layer1: Any
224
- layer2: Any
225
-
226
- class Linear(PyTreeNode):
227
- weight: jax.Array
228
- bias: jax.Array
229
-
230
- mlp = MLP(
231
- layer1=Linear(weight=jnp.ones((4, 3)), bias=jnp.zeros(4)),
232
- layer2=Linear(weight=jnp.ones((2, 4)), bias=jnp.zeros(2))
233
- )
234
-
235
- # Test tree_map
236
- mlp2 = jax.tree_util.tree_map(lambda x: x * 2, mlp)
237
- assert jnp.allclose(mlp2.layer1.weight, mlp.layer1.weight * 2)
238
-
239
- # Test with grad
240
- def loss_fn(model):
241
- return jnp.sum(model.layer1.weight ** 2) + jnp.sum(model.layer2.weight ** 2)
242
-
243
- grad_fn = jax.grad(loss_fn)
244
- grads = grad_fn(mlp)
245
- assert grads.layer1.weight.shape == mlp.layer1.weight.shape
246
-
247
-
248
- class TestFrozenDict:
249
- """Test the FrozenDict class."""
250
-
251
- def test_frozendict_creation(self):
252
- """Test creating FrozenDict."""
253
- # From dict
254
- fd1 = FrozenDict({'a': 1, 'b': 2})
255
- assert fd1['a'] == 1
256
- assert fd1['b'] == 2
257
-
258
- # From kwargs
259
- fd2 = FrozenDict(a=1, b=2)
260
- assert fd2['a'] == 1
261
-
262
- # From items
263
- fd3 = FrozenDict([('a', 1), ('b', 2)])
264
- assert fd3['a'] == 1
265
-
266
- def test_frozendict_immutability(self):
267
- """Test that FrozenDict is immutable."""
268
- fd = FrozenDict({'a': 1})
269
-
270
- with pytest.raises(TypeError):
271
- fd['a'] = 2
272
-
273
- with pytest.raises(TypeError):
274
- fd['c'] = 3
275
-
276
- with pytest.raises(TypeError):
277
- del fd['a']
278
-
279
- def test_frozendict_basic_operations(self):
280
- """Test basic dictionary operations."""
281
- fd = FrozenDict({'a': 1, 'b': 2, 'c': 3})
282
-
283
- # Contains
284
- assert 'a' in fd
285
- assert 'd' not in fd
286
-
287
- # Length
288
- assert len(fd) == 3
289
-
290
- # Iteration
291
- keys = list(fd)
292
- assert set(keys) == {'a', 'b', 'c'}
293
-
294
- # Get
295
- assert fd.get('a') == 1
296
- assert fd.get('d') is None
297
- assert fd.get('d', 10) == 10
298
-
299
- def test_frozendict_views(self):
300
- """Test dictionary views."""
301
- fd = FrozenDict({'a': 1, 'b': 2})
302
-
303
- # Keys view
304
- keys = fd.keys()
305
- assert set(keys) == {'a', 'b'}
306
- assert 'FrozenDict.keys' in repr(keys)
307
-
308
- # Values view
309
- values = fd.values()
310
- assert set(values) == {1, 2}
311
- assert 'FrozenDict.values' in repr(values)
312
-
313
- # Items view
314
- items = list(fd.items())
315
- assert len(items) == 2
316
- assert ('a', 1) in items
317
-
318
- def test_frozendict_copy(self):
319
- """Test copy method."""
320
- fd1 = FrozenDict({'a': 1, 'b': 2})
321
-
322
- # Copy without changes
323
- fd2 = fd1.copy()
324
- assert fd2 == fd1
325
- assert fd2 is not fd1
326
-
327
- # Copy with updates
328
- fd3 = fd1.copy({'c': 3, 'a': 10})
329
- assert fd3['a'] == 10
330
- assert fd3['b'] == 2
331
- assert fd3['c'] == 3
332
- assert fd1['a'] == 1 # Original unchanged
333
-
334
- def test_frozendict_pop(self):
335
- """Test pop method."""
336
- fd1 = FrozenDict({'a': 1, 'b': 2, 'c': 3})
337
-
338
- fd2, value = fd1.pop('b')
339
- assert value == 2
340
- assert 'b' not in fd2
341
- assert len(fd2) == 2
342
- assert 'b' in fd1 # Original unchanged
343
-
344
- # Pop non-existent key
345
- with pytest.raises(KeyError):
346
- fd1.pop('d')
347
-
348
- def test_frozendict_nested(self):
349
- """Test nested FrozenDict."""
350
- fd = FrozenDict({
351
- 'a': 1,
352
- 'b': {'c': 2, 'd': {'e': 3}}
353
- })
354
-
355
- # Access nested values
356
- assert fd['b']['c'] == 2
357
- assert fd['b']['d']['e'] == 3
358
-
359
- # Nested values are also FrozenDict
360
- assert isinstance(fd['b'], FrozenDict)
361
- assert isinstance(fd['b']['d'], FrozenDict)
362
-
363
- def test_frozendict_hash(self):
364
- """Test FrozenDict hashing."""
365
- fd1 = FrozenDict({'a': 1, 'b': 2})
366
- fd2 = FrozenDict({'a': 1, 'b': 2})
367
- fd3 = FrozenDict({'a': 1, 'b': 3})
368
-
369
- # Equal dicts have same hash
370
- assert hash(fd1) == hash(fd2)
371
-
372
- # Can be used in sets
373
- s = {fd1, fd2, fd3}
374
- assert len(s) == 2
375
-
376
- def test_frozendict_equality(self):
377
- """Test FrozenDict equality."""
378
- fd1 = FrozenDict({'a': 1, 'b': 2})
379
- fd2 = FrozenDict({'a': 1, 'b': 2})
380
- fd3 = FrozenDict({'a': 1, 'b': 3})
381
- d = {'a': 1, 'b': 2}
382
-
383
- assert fd1 == fd2
384
- assert fd1 != fd3
385
- assert fd1 == d
386
- assert fd1 != "not a dict"
387
-
388
- def test_frozendict_pickle(self):
389
- """Test FrozenDict pickling."""
390
- fd = FrozenDict({'a': 1, 'b': {'c': 2}})
391
-
392
- # Pickle and unpickle
393
- pickled = pickle.dumps(fd)
394
- fd2 = pickle.loads(pickled)
395
-
396
- assert fd == fd2
397
- assert fd['b']['c'] == fd2['b']['c']
398
-
399
- def test_frozendict_pretty_repr(self):
400
- """Test pretty representation."""
401
- fd = FrozenDict({'a': 1, 'b': {'c': 2}})
402
- repr_str = fd.pretty_repr()
403
-
404
- assert 'FrozenDict' in repr_str
405
- assert "'a': 1" in repr_str
406
- assert "'c': 2" in repr_str
407
-
408
- def test_frozendict_as_pytree(self):
409
- """Test FrozenDict as JAX pytree."""
410
- fd = FrozenDict({'a': jnp.ones(3), 'b': jnp.zeros(2)})
411
-
412
- # Tree map
413
- fd2 = jax.tree_util.tree_map(lambda x: x * 2, fd)
414
- assert jnp.allclose(fd2['a'], jnp.ones(3) * 2)
415
- assert jnp.allclose(fd2['b'], jnp.zeros(2))
416
-
417
- # Tree leaves
418
- leaves = jax.tree_util.tree_leaves(fd)
419
- assert len(leaves) == 2
420
-
421
- # Tree flatten and unflatten
422
- values, treedef = jax.tree_util.tree_flatten(fd)
423
- fd3 = jax.tree_util.tree_unflatten(treedef, values)
424
- assert fd == fd3
425
-
426
-
427
- class TestUtilityFunctions:
428
- """Test utility functions."""
429
-
430
- def test_freeze(self):
431
- """Test freeze function."""
432
- # Regular dict
433
- d = {'a': 1, 'b': {'c': 2}}
434
- fd = freeze(d)
435
- assert isinstance(fd, FrozenDict)
436
- assert fd['a'] == 1
437
- assert isinstance(fd['b'], FrozenDict)
438
-
439
- # Already frozen
440
- fd2 = freeze(fd)
441
- assert fd2 is fd
442
-
443
- def test_unfreeze(self):
444
- """Test unfreeze function."""
445
- # FrozenDict
446
- fd = FrozenDict({'a': 1, 'b': {'c': 2}})
447
- d = unfreeze(fd)
448
- assert isinstance(d, dict)
449
- assert not isinstance(d, FrozenDict)
450
- assert d['a'] == 1
451
- assert isinstance(d['b'], dict)
452
-
453
- # Regular dict
454
- d2 = {'a': 1}
455
- d3 = unfreeze(d2)
456
- assert d3 == d2
457
- assert d3 is not d2 # Should be a copy
458
-
459
- # Non-dict
460
- assert unfreeze(42) == 42
461
-
462
- def test_copy_function(self):
463
- """Test copy function."""
464
- # FrozenDict
465
- fd1 = FrozenDict({'a': 1})
466
- fd2 = copy(fd1, {'b': 2})
467
- assert isinstance(fd2, FrozenDict)
468
- assert fd2['a'] == 1
469
- assert fd2['b'] == 2
470
-
471
- # Regular dict
472
- d1 = {'a': 1}
473
- d2 = copy(d1, {'b': 2})
474
- assert isinstance(d2, dict)
475
- assert not isinstance(d2, FrozenDict)
476
- assert d2['a'] == 1
477
- assert d2['b'] == 2
478
- assert d1 == {'a': 1} # Original unchanged
479
-
480
- # Invalid type
481
- with pytest.raises(TypeError):
482
- copy([1, 2, 3])
483
-
484
- def test_pop_function(self):
485
- """Test pop function."""
486
- # FrozenDict
487
- fd1 = FrozenDict({'a': 1, 'b': 2})
488
- fd2, value = pop(fd1, 'a')
489
- assert isinstance(fd2, FrozenDict)
490
- assert value == 1
491
- assert 'a' not in fd2
492
- assert 'a' in fd1
493
-
494
- # Regular dict
495
- d1 = {'a': 1, 'b': 2}
496
- d2, value = pop(d1, 'a')
497
- assert isinstance(d2, dict)
498
- assert value == 1
499
- assert 'a' not in d2
500
- assert 'a' in d1
501
-
502
- # Invalid type
503
- with pytest.raises(TypeError):
504
- pop([1, 2, 3], 0)
505
-
506
- def test_pretty_repr_function(self):
507
- """Test pretty_repr function."""
508
- # FrozenDict
509
- fd = FrozenDict({'a': 1, 'b': {'c': 2}})
510
- s = pretty_repr(fd)
511
- assert 'FrozenDict' in s
512
-
513
- # Regular dict
514
- d = {'a': 1, 'b': {'c': 2}}
515
- s = pretty_repr(d)
516
- assert 'a' in s
517
- assert 'c' in s
518
-
519
- # Other type
520
- s = pretty_repr([1, 2, 3])
521
- assert s == "[1, 2, 3]"
522
-
523
-
524
- class TestIntegration:
525
- """Integration tests combining multiple features."""
526
-
527
- def test_nested_structures(self):
528
- """Test complex nested structures."""
529
-
530
- @dataclass
531
- class Config:
532
- hyperparams: FrozenDict
533
- metadata: dict = field(pytree_node=False)
534
-
535
- class Model(PyTreeNode):
536
- config: Config
537
- weights: jax.Array
538
-
539
- config = Config(
540
- hyperparams=FrozenDict({'lr': 0.001, 'batch_size': 32}),
541
- metadata={'version': '1.0'}
542
- )
543
- model = Model(
544
- config=config,
545
- weights=jnp.ones((4, 4))
546
- )
547
-
548
- # Test tree operations
549
- model2 = jax.tree_util.tree_map(lambda x: x * 2 if isinstance(x, jax.Array) else x, model)
550
- assert jnp.allclose(model2.weights, model.weights * 2)
551
- assert model2.config.hyperparams['lr'] == 0.001
552
- assert model2.config.metadata['version'] == '1.0'
553
-
554
- def test_jax_transformations_integration(self):
555
- """Test integration with various JAX transformations."""
556
-
557
- @dataclass
558
- class State:
559
- params: FrozenDict
560
- step: int = field(pytree_node=False, default=0)
561
-
562
- state = State(
563
- params=FrozenDict({
564
- 'w': jnp.ones((3, 3)),
565
- 'b': jnp.zeros(3)
566
- })
567
- )
568
-
569
- # JIT compilation
570
- @jax.jit
571
- def update(state, grad):
572
- new_params = jax.tree_util.tree_map(
573
- lambda p, g: p - 0.01 * g,
574
- state.params,
575
- grad
576
- )
577
- return state.replace(
578
- params=new_params,
579
- step=state.step + 1
580
- )
581
-
582
- grad = FrozenDict({'w': jnp.ones((3, 3)), 'b': jnp.ones(3)})
583
- new_state = update(state, grad)
584
- assert new_state.step == 1
585
- assert jnp.allclose(new_state.params['w'], state.params['w'] - 0.01)
586
-
587
- # VMAP
588
- @jax.vmap
589
- def batch_process(params, x):
590
- return jnp.dot(x, params['w']) + params['b']
591
-
592
- batch_params = jax.tree_util.tree_map(
593
- lambda x: jnp.stack([x, x * 2]),
594
- state.params
595
- )
596
- batch_x = jnp.ones((2, 3))
597
- result = batch_process(batch_params, batch_x)
598
- assert result.shape == (2, 3)
599
-
600
-
601
- if __name__ == "__main__":
602
- pytest.main([__file__, "-v"])
1
+ """
2
+ Comprehensive tests for the struct module.
3
+ """
4
+
5
+ import pickle
6
+ from typing import Any
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import jax.tree_util
11
+ import pytest
12
+
13
+ # Import the modules to test
14
+ from brainstate.util import (
15
+ field,
16
+ dataclass,
17
+ PyTreeNode,
18
+ FrozenDict,
19
+ freeze,
20
+ unfreeze,
21
+ copy,
22
+ pop,
23
+ pretty_repr,
24
+ )
25
+
26
+
27
+ class TestField:
28
+ """Test the field function."""
29
+
30
+ def test_field_with_pytree_node_true(self):
31
+ """Test field with pytree_node=True."""
32
+ f = field(pytree_node=True)
33
+ assert f.metadata['pytree_node'] is True
34
+
35
+ def test_field_with_pytree_node_false(self):
36
+ """Test field with pytree_node=False."""
37
+ f = field(pytree_node=False)
38
+ assert f.metadata['pytree_node'] is False
39
+
40
+ def test_field_with_default(self):
41
+ """Test field with default value."""
42
+ f = field(default=42)
43
+ assert f.default == 42
44
+ assert f.metadata['pytree_node'] is True
45
+
46
+ def test_field_with_metadata(self):
47
+ """Test field preserves additional metadata."""
48
+ f = field(pytree_node=False, metadata={'custom': 'data'})
49
+ assert f.metadata['pytree_node'] is False
50
+ assert f.metadata['custom'] == 'data'
51
+
52
+
53
+ class TestDataclass:
54
+ """Test the dataclass decorator."""
55
+
56
+ def test_basic_dataclass(self):
57
+ """Test basic dataclass creation."""
58
+
59
+ @dataclass
60
+ class Point:
61
+ x: float
62
+ y: float
63
+
64
+ p = Point(1.0, 2.0)
65
+ assert p.x == 1.0
66
+ assert p.y == 2.0
67
+
68
+ def test_dataclass_is_frozen(self):
69
+ """Test that dataclasses are frozen by default."""
70
+
71
+ @dataclass
72
+ class Point:
73
+ x: float
74
+ y: float
75
+
76
+ p = Point(1.0, 2.0)
77
+ with pytest.raises(Exception): # Should be immutable
78
+ p.x = 3.0
79
+
80
+ def test_dataclass_replace_method(self):
81
+ """Test the replace method."""
82
+
83
+ @dataclass
84
+ class Point:
85
+ x: float
86
+ y: float
87
+
88
+ p1 = Point(1.0, 2.0)
89
+ p2 = p1.replace(x=3.0)
90
+ assert p1.x == 1.0
91
+ assert p2.x == 3.0
92
+ assert p2.y == 2.0
93
+
94
+ def test_dataclass_with_defaults(self):
95
+ """Test dataclass with default values."""
96
+
97
+ @dataclass
98
+ class Config:
99
+ learning_rate: float = 0.001
100
+ batch_size: int = 32
101
+ name: str = field(default="default", pytree_node=False)
102
+
103
+ c1 = Config()
104
+ assert c1.learning_rate == 0.001
105
+ assert c1.batch_size == 32
106
+ assert c1.name == "default"
107
+
108
+ c2 = Config(learning_rate=0.01)
109
+ assert c2.learning_rate == 0.01
110
+
111
+ def test_dataclass_pytree_behavior(self):
112
+ """Test that dataclass works as JAX pytree."""
113
+
114
+ @dataclass
115
+ class Model:
116
+ weights: jax.Array
117
+ bias: jax.Array
118
+ name: str = field(pytree_node=False, default="model")
119
+
120
+ weights = jnp.ones((3, 3))
121
+ bias = jnp.zeros(3)
122
+ model = Model(weights=weights, bias=bias)
123
+
124
+ # Test tree_map
125
+ model2 = jax.tree_util.tree_map(lambda x: x * 2, model)
126
+ assert jnp.allclose(model2.weights, weights * 2)
127
+ assert jnp.allclose(model2.bias, bias * 2)
128
+ assert model2.name == "model" # Should not be affected
129
+
130
+ # Test tree_leaves
131
+ leaves = jax.tree_util.tree_leaves(model)
132
+ assert len(leaves) == 2 # Only weights and bias
133
+
134
+ def test_dataclass_with_jax_transformations(self):
135
+ """Test dataclass with JAX transformations."""
136
+
137
+ @dataclass
138
+ class Linear:
139
+ weight: jax.Array
140
+ bias: jax.Array
141
+
142
+ layer = Linear(
143
+ weight=jnp.ones((4, 3)),
144
+ bias=jnp.zeros(4)
145
+ )
146
+
147
+ # Test with jit
148
+ @jax.jit
149
+ def apply(layer, x):
150
+ return jnp.dot(x, layer.weight.T) + layer.bias
151
+
152
+ x = jnp.ones(3)
153
+ y = apply(layer, x)
154
+ assert y.shape == (4,)
155
+
156
+ # Test with grad
157
+ def loss_fn(layer):
158
+ return jnp.sum(layer.weight ** 2) + jnp.sum(layer.bias ** 2)
159
+
160
+ grad_fn = jax.grad(loss_fn)
161
+ grads = grad_fn(layer)
162
+ assert grads.weight.shape == layer.weight.shape
163
+ assert grads.bias.shape == layer.bias.shape
164
+
165
+ def test_dataclass_no_double_decoration(self):
166
+ """Test that dataclass decorator is idempotent."""
167
+
168
+ @dataclass
169
+ @dataclass # Should not cause issues
170
+ class Point:
171
+ x: float
172
+ y: float
173
+
174
+ p = Point(1.0, 2.0)
175
+ assert p.x == 1.0
176
+ assert hasattr(Point, '_brainstate_dataclass')
177
+
178
+
179
+ class TestPyTreeNode:
180
+ """Test the PyTreeNode base class."""
181
+
182
+ def test_pytreenode_subclass(self):
183
+ """Test creating a PyTreeNode subclass."""
184
+
185
+ class Layer(PyTreeNode):
186
+ weights: jax.Array
187
+ bias: jax.Array
188
+ activation: str = field(pytree_node=False, default="relu")
189
+
190
+ layer = Layer(
191
+ weights=jnp.ones((4, 4)),
192
+ bias=jnp.zeros(4)
193
+ )
194
+ assert layer.activation == "relu"
195
+ assert jnp.allclose(layer.weights, jnp.ones((4, 4)))
196
+
197
+ def test_pytreenode_is_frozen(self):
198
+ """Test that PyTreeNode subclasses are frozen."""
199
+
200
+ class Layer(PyTreeNode):
201
+ weights: jax.Array
202
+
203
+ layer = Layer(weights=jnp.ones(3))
204
+ with pytest.raises(Exception):
205
+ layer.weights = jnp.zeros(3)
206
+
207
+ def test_pytreenode_replace(self):
208
+ """Test replace method on PyTreeNode."""
209
+
210
+ class Layer(PyTreeNode):
211
+ weights: jax.Array
212
+ bias: jax.Array
213
+
214
+ layer1 = Layer(weights=jnp.ones(3), bias=jnp.zeros(3))
215
+ layer2 = layer1.replace(weights=jnp.ones(3) * 2)
216
+ assert jnp.allclose(layer2.weights, jnp.ones(3) * 2)
217
+ assert jnp.allclose(layer2.bias, jnp.zeros(3))
218
+
219
+ def test_pytreenode_with_jax(self):
220
+ """Test PyTreeNode with JAX transformations."""
221
+
222
+ class MLP(PyTreeNode):
223
+ layer1: Any
224
+ layer2: Any
225
+
226
+ class Linear(PyTreeNode):
227
+ weight: jax.Array
228
+ bias: jax.Array
229
+
230
+ mlp = MLP(
231
+ layer1=Linear(weight=jnp.ones((4, 3)), bias=jnp.zeros(4)),
232
+ layer2=Linear(weight=jnp.ones((2, 4)), bias=jnp.zeros(2))
233
+ )
234
+
235
+ # Test tree_map
236
+ mlp2 = jax.tree_util.tree_map(lambda x: x * 2, mlp)
237
+ assert jnp.allclose(mlp2.layer1.weight, mlp.layer1.weight * 2)
238
+
239
+ # Test with grad
240
+ def loss_fn(model):
241
+ return jnp.sum(model.layer1.weight ** 2) + jnp.sum(model.layer2.weight ** 2)
242
+
243
+ grad_fn = jax.grad(loss_fn)
244
+ grads = grad_fn(mlp)
245
+ assert grads.layer1.weight.shape == mlp.layer1.weight.shape
246
+
247
+
248
+ class TestFrozenDict:
249
+ """Test the FrozenDict class."""
250
+
251
+ def test_frozendict_creation(self):
252
+ """Test creating FrozenDict."""
253
+ # From dict
254
+ fd1 = FrozenDict({'a': 1, 'b': 2})
255
+ assert fd1['a'] == 1
256
+ assert fd1['b'] == 2
257
+
258
+ # From kwargs
259
+ fd2 = FrozenDict(a=1, b=2)
260
+ assert fd2['a'] == 1
261
+
262
+ # From items
263
+ fd3 = FrozenDict([('a', 1), ('b', 2)])
264
+ assert fd3['a'] == 1
265
+
266
+ def test_frozendict_immutability(self):
267
+ """Test that FrozenDict is immutable."""
268
+ fd = FrozenDict({'a': 1})
269
+
270
+ with pytest.raises(TypeError):
271
+ fd['a'] = 2
272
+
273
+ with pytest.raises(TypeError):
274
+ fd['c'] = 3
275
+
276
+ with pytest.raises(TypeError):
277
+ del fd['a']
278
+
279
+ def test_frozendict_basic_operations(self):
280
+ """Test basic dictionary operations."""
281
+ fd = FrozenDict({'a': 1, 'b': 2, 'c': 3})
282
+
283
+ # Contains
284
+ assert 'a' in fd
285
+ assert 'd' not in fd
286
+
287
+ # Length
288
+ assert len(fd) == 3
289
+
290
+ # Iteration
291
+ keys = list(fd)
292
+ assert set(keys) == {'a', 'b', 'c'}
293
+
294
+ # Get
295
+ assert fd.get('a') == 1
296
+ assert fd.get('d') is None
297
+ assert fd.get('d', 10) == 10
298
+
299
+ def test_frozendict_views(self):
300
+ """Test dictionary views."""
301
+ fd = FrozenDict({'a': 1, 'b': 2})
302
+
303
+ # Keys view
304
+ keys = fd.keys()
305
+ assert set(keys) == {'a', 'b'}
306
+ assert 'FrozenDict.keys' in repr(keys)
307
+
308
+ # Values view
309
+ values = fd.values()
310
+ assert set(values) == {1, 2}
311
+ assert 'FrozenDict.values' in repr(values)
312
+
313
+ # Items view
314
+ items = list(fd.items())
315
+ assert len(items) == 2
316
+ assert ('a', 1) in items
317
+
318
+ def test_frozendict_copy(self):
319
+ """Test copy method."""
320
+ fd1 = FrozenDict({'a': 1, 'b': 2})
321
+
322
+ # Copy without changes
323
+ fd2 = fd1.copy()
324
+ assert fd2 == fd1
325
+ assert fd2 is not fd1
326
+
327
+ # Copy with updates
328
+ fd3 = fd1.copy({'c': 3, 'a': 10})
329
+ assert fd3['a'] == 10
330
+ assert fd3['b'] == 2
331
+ assert fd3['c'] == 3
332
+ assert fd1['a'] == 1 # Original unchanged
333
+
334
+ def test_frozendict_pop(self):
335
+ """Test pop method."""
336
+ fd1 = FrozenDict({'a': 1, 'b': 2, 'c': 3})
337
+
338
+ fd2, value = fd1.pop('b')
339
+ assert value == 2
340
+ assert 'b' not in fd2
341
+ assert len(fd2) == 2
342
+ assert 'b' in fd1 # Original unchanged
343
+
344
+ # Pop non-existent key
345
+ with pytest.raises(KeyError):
346
+ fd1.pop('d')
347
+
348
+ def test_frozendict_nested(self):
349
+ """Test nested FrozenDict."""
350
+ fd = FrozenDict({
351
+ 'a': 1,
352
+ 'b': {'c': 2, 'd': {'e': 3}}
353
+ })
354
+
355
+ # Access nested values
356
+ assert fd['b']['c'] == 2
357
+ assert fd['b']['d']['e'] == 3
358
+
359
+ # Nested values are also FrozenDict
360
+ assert isinstance(fd['b'], FrozenDict)
361
+ assert isinstance(fd['b']['d'], FrozenDict)
362
+
363
+ def test_frozendict_hash(self):
364
+ """Test FrozenDict hashing."""
365
+ fd1 = FrozenDict({'a': 1, 'b': 2})
366
+ fd2 = FrozenDict({'a': 1, 'b': 2})
367
+ fd3 = FrozenDict({'a': 1, 'b': 3})
368
+
369
+ # Equal dicts have same hash
370
+ assert hash(fd1) == hash(fd2)
371
+
372
+ # Can be used in sets
373
+ s = {fd1, fd2, fd3}
374
+ assert len(s) == 2
375
+
376
def test_frozendict_equality(self):
    """FrozenDict compares by content, including against plain dicts."""
    base = FrozenDict({'a': 1, 'b': 2})
    same = FrozenDict({'a': 1, 'b': 2})
    changed = FrozenDict({'a': 1, 'b': 3})
    plain = {'a': 1, 'b': 2}

    assert base == same
    assert base != changed
    # Content equality also holds against an ordinary dict.
    assert base == plain
    # A non-mapping compares unequal (and the comparison must not raise).
    assert base != "not a dict"
387
+
388
def test_frozendict_pickle(self):
    """A pickle round-trip preserves both content and the FrozenDict type.

    ``FrozenDict`` compares equal to a plain ``dict``, so an equality check
    alone would not catch an unpickling path that silently degrades the
    result to ``dict``. The types of the result and of its nested value are
    therefore asserted explicitly.
    """
    fd = FrozenDict({'a': 1, 'b': {'c': 2}})

    # Round-trip through pickle (locally produced, trusted data only).
    fd2 = pickle.loads(pickle.dumps(fd))

    assert fd == fd2
    assert fd['b']['c'] == fd2['b']['c']
    # Immutability must survive the round-trip at every level.
    assert isinstance(fd2, FrozenDict)
    assert isinstance(fd2['b'], FrozenDict)
398
+
399
def test_frozendict_pretty_repr(self):
    """``pretty_repr`` names the type and shows top-level and nested entries."""
    text = FrozenDict({'a': 1, 'b': {'c': 2}}).pretty_repr()

    for fragment in ('FrozenDict', "'a': 1", "'c': 2"):
        assert fragment in text
407
+
408
def test_frozendict_as_pytree(self):
    """FrozenDict behaves as a JAX pytree: map, leaves and flatten/unflatten.

    After ``tree_unflatten`` the result is compared key by key with
    ``jnp.array_equal`` rather than with mapping ``==``. A mapping-level
    ``==`` over multi-element arrays presumably relies on CPython's identity
    short-circuit, because the round-tripped leaves are the very same
    objects. With distinct but equal arrays it would hit the ambiguous
    truth value of an array comparison instead of testing the values.
    """
    fd = FrozenDict({'a': jnp.ones(3), 'b': jnp.zeros(2)})

    # Tree map transforms every leaf and keeps the container type.
    fd2 = jax.tree_util.tree_map(lambda x: x * 2, fd)
    assert isinstance(fd2, FrozenDict)
    assert jnp.allclose(fd2['a'], jnp.ones(3) * 2)
    assert jnp.allclose(fd2['b'], jnp.zeros(2))

    # One leaf per stored array.
    leaves = jax.tree_util.tree_leaves(fd)
    assert len(leaves) == 2

    # Flatten/unflatten round-trips the type, the keys and the values.
    values, treedef = jax.tree_util.tree_flatten(fd)
    fd3 = jax.tree_util.tree_unflatten(treedef, values)
    assert isinstance(fd3, FrozenDict)
    assert set(fd3.keys()) == set(fd.keys())
    for key in fd.keys():
        assert jnp.array_equal(fd3[key], fd[key])
425
+
426
+
427
class TestUtilityFunctions:
    """Tests for the helpers freeze, unfreeze, copy, pop and pretty_repr."""

    def test_freeze(self):
        """``freeze`` converts dicts recursively and is a no-op on FrozenDicts."""
        source = {'a': 1, 'b': {'c': 2}}
        frozen = freeze(source)
        assert isinstance(frozen, FrozenDict)
        assert frozen['a'] == 1
        assert isinstance(frozen['b'], FrozenDict)

        # Freezing an already frozen value hands back the very same object.
        assert freeze(frozen) is frozen

    def test_unfreeze(self):
        """``unfreeze`` yields plain dicts and passes non-mappings through."""
        thawed = unfreeze(FrozenDict({'a': 1, 'b': {'c': 2}}))
        assert isinstance(thawed, dict)
        assert not isinstance(thawed, FrozenDict)
        assert thawed['a'] == 1
        assert isinstance(thawed['b'], dict)

        # A plain dict comes back as an equal but separate object.
        plain = {'a': 1}
        result = unfreeze(plain)
        assert result == plain
        assert result is not plain

        # A non-mapping value comes back equal to the input.
        assert unfreeze(42) == 42

    def test_copy_function(self):
        """``copy`` keeps the input's mapping type and raises TypeError on a list."""
        frozen = copy(FrozenDict({'a': 1}), {'b': 2})
        assert isinstance(frozen, FrozenDict)
        assert (frozen['a'], frozen['b']) == (1, 2)

        plain_src = {'a': 1}
        plain = copy(plain_src, {'b': 2})
        assert isinstance(plain, dict)
        assert not isinstance(plain, FrozenDict)
        assert (plain['a'], plain['b']) == (1, 2)
        assert plain_src == {'a': 1}  # the input dict is not mutated

        with pytest.raises(TypeError):
            copy([1, 2, 3])

    def test_pop_function(self):
        """``pop`` returns ``(remaining, value)`` for both mapping flavours."""
        cases = (
            (FrozenDict({'a': 1, 'b': 2}), FrozenDict),
            ({'a': 1, 'b': 2}, dict),
        )
        for source, expected_type in cases:
            remaining, value = pop(source, 'a')
            assert isinstance(remaining, expected_type)
            assert value == 1
            assert 'a' not in remaining
            assert 'a' in source  # the input keeps its key

        with pytest.raises(TypeError):
            pop([1, 2, 3], 0)

    def test_pretty_repr_function(self):
        """``pretty_repr`` formats FrozenDicts, dicts and other values."""
        assert 'FrozenDict' in pretty_repr(FrozenDict({'a': 1, 'b': {'c': 2}}))

        plain_text = pretty_repr({'a': 1, 'b': {'c': 2}})
        assert 'a' in plain_text
        assert 'c' in plain_text

        # For a list the output equals the builtin list repr.
        assert pretty_repr([1, 2, 3]) == "[1, 2, 3]"
522
+
523
+
524
class TestIntegration:
    """Integration tests combining dataclasses, PyTreeNode, FrozenDict and JAX."""

    def test_nested_structures(self):
        """Tree-mapping a model changes only array leaves; other fields survive."""

        @dataclass
        class Config:
            hyperparams: FrozenDict
            # pytree_node=False: intended as static metadata, not a traced leaf.
            metadata: dict = field(pytree_node=False)

        class Model(PyTreeNode):
            config: Config
            weights: jax.Array

        config = Config(
            hyperparams=FrozenDict({'lr': 0.001, 'batch_size': 32}),
            metadata={'version': '1.0'},
        )
        model = Model(config=config, weights=jnp.ones((4, 4)))

        # Double only jax.Array leaves; Python scalars pass through untouched.
        model2 = jax.tree_util.tree_map(
            lambda x: x * 2 if isinstance(x, jax.Array) else x, model
        )
        assert jnp.allclose(model2.weights, model.weights * 2)
        assert model2.config.hyperparams['lr'] == 0.001
        assert model2.config.hyperparams['batch_size'] == 32
        assert model2.config.metadata['version'] == '1.0'

    def test_jax_transformations_integration(self):
        """Dataclass state works under ``jax.jit``; FrozenDict params under ``jax.vmap``."""

        @dataclass
        class State:
            params: FrozenDict
            # Non-pytree field (pytree_node=False); expected to stay a plain int.
            step: int = field(pytree_node=False, default=0)

        state = State(
            params=FrozenDict({
                'w': jnp.ones((3, 3)),
                'b': jnp.zeros(3),
            })
        )

        # JIT compilation: a single gradient step with learning rate 0.01.
        @jax.jit
        def update(state, grad):
            new_params = jax.tree_util.tree_map(
                lambda p, g: p - 0.01 * g,
                state.params,
                grad,
            )
            return state.replace(params=new_params, step=state.step + 1)

        grad = FrozenDict({'w': jnp.ones((3, 3)), 'b': jnp.ones(3)})
        new_state = update(state, grad)
        assert new_state.step == 1
        assert state.step == 0  # ``replace`` builds a new state object
        # The gradient is all ones, so every parameter moves by exactly -0.01.
        assert jnp.allclose(new_state.params['w'], state.params['w'] - 0.01)
        assert jnp.allclose(new_state.params['b'], state.params['b'] - 0.01)

        # VMAP over a leading batch axis of size 2 on params and inputs.
        @jax.vmap
        def batch_process(params, x):
            return jnp.dot(x, params['w']) + params['b']

        batch_params = jax.tree_util.tree_map(
            lambda x: jnp.stack([x, x * 2]),
            state.params,
        )
        batch_x = jnp.ones((2, 3))
        result = batch_process(batch_params, batch_x)
        assert result.shape == (2, 3)
        # Row 0: ones(3) @ ones((3, 3)) -> 3 per entry.
        # Row 1: ones(3) @ 2 * ones((3, 3)) -> 6 per entry.
        # The bias rows are zeros and 2 * zeros, so they add nothing.
        expected = jnp.array([[3.0, 3.0, 3.0], [6.0, 6.0, 6.0]])
        assert jnp.allclose(result, expected)
599
+
600
+
601
if __name__ == "__main__":
    # Propagate pytest's exit status; otherwise running this file directly
    # would exit 0 even when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))