brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,602 @@
1
+ """
2
+ Comprehensive tests for the struct module.
3
+ """
4
+
5
+ import pickle
6
+ from typing import Any
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import jax.tree_util
11
+ import pytest
12
+
13
+ # Import the modules to test
14
+ from brainstate.util import (
15
+ field,
16
+ dataclass,
17
+ PyTreeNode,
18
+ FrozenDict,
19
+ freeze,
20
+ unfreeze,
21
+ copy,
22
+ pop,
23
+ pretty_repr,
24
+ )
25
+
26
+
27
class TestField:
    """Unit tests for :func:`field` and the metadata it attaches to a dataclass field."""

    def test_field_with_pytree_node_true(self):
        """An explicit ``pytree_node=True`` is recorded in the field metadata."""
        fld = field(pytree_node=True)
        assert fld.metadata['pytree_node'] is True

    def test_field_with_pytree_node_false(self):
        """An explicit ``pytree_node=False`` is recorded in the field metadata."""
        fld = field(pytree_node=False)
        assert fld.metadata['pytree_node'] is False

    def test_field_with_default(self):
        """A default value is stored, and ``pytree_node`` falls back to True."""
        fld = field(default=42)
        assert fld.default == 42
        assert fld.metadata['pytree_node'] is True

    def test_field_with_metadata(self):
        """User-supplied metadata is kept alongside the ``pytree_node`` flag."""
        fld = field(pytree_node=False, metadata={'custom': 'data'})
        meta = fld.metadata
        assert meta['pytree_node'] is False
        assert meta['custom'] == 'data'
51
+
52
+
53
class TestDataclass:
    """Test the dataclass decorator (frozen, pytree-aware dataclasses)."""

    def test_basic_dataclass(self):
        """Positional construction populates every field."""

        @dataclass
        class Point:
            x: float
            y: float

        p = Point(1.0, 2.0)
        assert p.x == 1.0
        assert p.y == 2.0

    def test_dataclass_is_frozen(self):
        """Instances reject attribute assignment after construction."""

        @dataclass
        class Point:
            x: float
            y: float

        p = Point(1.0, 2.0)
        # Expect AttributeError specifically, rather than any Exception, so an
        # unrelated error cannot make this test pass. Stdlib frozen dataclasses
        # raise ``dataclasses.FrozenInstanceError``, an AttributeError subclass.
        # NOTE(review): assumes ``dataclass`` wraps ``dataclasses.dataclass(frozen=True)``.
        with pytest.raises(AttributeError):
            p.x = 3.0
        # The rejected assignment must leave the instance untouched.
        assert p.x == 1.0

    def test_dataclass_replace_method(self):
        """``replace`` returns a modified copy and leaves the original intact."""

        @dataclass
        class Point:
            x: float
            y: float

        p1 = Point(1.0, 2.0)
        p2 = p1.replace(x=3.0)
        assert p1.x == 1.0
        assert p2.x == 3.0
        assert p2.y == 2.0

    def test_dataclass_with_defaults(self):
        """Plain defaults and ``field(default=...)`` defaults are both honoured."""

        @dataclass
        class Config:
            learning_rate: float = 0.001
            batch_size: int = 32
            name: str = field(default="default", pytree_node=False)

        c1 = Config()
        assert c1.learning_rate == 0.001
        assert c1.batch_size == 32
        assert c1.name == "default"

        c2 = Config(learning_rate=0.01)
        assert c2.learning_rate == 0.01

    def test_dataclass_pytree_behavior(self):
        """Pytree fields are mapped/flattened; ``pytree_node=False`` fields are not."""

        @dataclass
        class Model:
            weights: jax.Array
            bias: jax.Array
            name: str = field(pytree_node=False, default="model")

        weights = jnp.ones((3, 3))
        # Non-zero bias so that the ``* 2`` check below can actually fail
        # (with zeros, ``bias * 2 == bias`` and the assertion is vacuous).
        bias = jnp.arange(1.0, 4.0)
        model = Model(weights=weights, bias=bias)

        # Test tree_map
        model2 = jax.tree_util.tree_map(lambda x: x * 2, model)
        assert jnp.allclose(model2.weights, weights * 2)
        assert jnp.allclose(model2.bias, bias * 2)
        assert model2.name == "model"  # Static field is not touched by tree_map

        # Test tree_leaves
        leaves = jax.tree_util.tree_leaves(model)
        assert len(leaves) == 2  # Only weights and bias

    def test_dataclass_with_jax_transformations(self):
        """Decorated classes can be passed through ``jax.jit`` and ``jax.grad``."""

        @dataclass
        class Linear:
            weight: jax.Array
            bias: jax.Array

        layer = Linear(
            weight=jnp.ones((4, 3)),
            bias=jnp.zeros(4)
        )

        # Test with jit
        @jax.jit
        def apply(layer, x):
            return jnp.dot(x, layer.weight.T) + layer.bias

        x = jnp.ones(3)
        y = apply(layer, x)
        assert y.shape == (4,)

        # Test with grad
        def loss_fn(layer):
            return jnp.sum(layer.weight ** 2) + jnp.sum(layer.bias ** 2)

        grad_fn = jax.grad(loss_fn)
        grads = grad_fn(layer)
        assert grads.weight.shape == layer.weight.shape
        assert grads.bias.shape == layer.bias.shape

    def test_dataclass_no_double_decoration(self):
        """Applying the decorator twice is harmless (idempotent)."""

        @dataclass
        @dataclass  # Should not cause issues
        class Point:
            x: float
            y: float

        p = Point(1.0, 2.0)
        assert p.x == 1.0
        assert hasattr(Point, '_brainstate_dataclass')
177
+
178
+
179
class TestPyTreeNode:
    """Test the PyTreeNode base class."""

    def test_pytreenode_subclass(self):
        """Subclassing PyTreeNode yields a keyword-constructible dataclass with defaults."""

        class Layer(PyTreeNode):
            weights: jax.Array
            bias: jax.Array
            activation: str = field(pytree_node=False, default="relu")

        layer = Layer(
            weights=jnp.ones((4, 4)),
            bias=jnp.zeros(4)
        )
        assert layer.activation == "relu"
        assert jnp.allclose(layer.weights, jnp.ones((4, 4)))

    def test_pytreenode_is_frozen(self):
        """PyTreeNode subclasses reject attribute assignment."""

        class Layer(PyTreeNode):
            weights: jax.Array

        layer = Layer(weights=jnp.ones(3))
        # Narrowed from ``Exception``: an unrelated error must not count as
        # "frozen". Stdlib frozen dataclasses raise FrozenInstanceError, which
        # subclasses AttributeError.
        # NOTE(review): assumes PyTreeNode subclasses are frozen stdlib dataclasses.
        with pytest.raises(AttributeError):
            layer.weights = jnp.zeros(3)
        # The rejected assignment must leave the original value in place.
        assert jnp.allclose(layer.weights, jnp.ones(3))

    def test_pytreenode_replace(self):
        """``replace`` swaps only the given field."""

        class Layer(PyTreeNode):
            weights: jax.Array
            bias: jax.Array

        layer1 = Layer(weights=jnp.ones(3), bias=jnp.zeros(3))
        layer2 = layer1.replace(weights=jnp.ones(3) * 2)
        assert jnp.allclose(layer2.weights, jnp.ones(3) * 2)
        assert jnp.allclose(layer2.bias, jnp.zeros(3))

    def test_pytreenode_with_jax(self):
        """Nested PyTreeNodes work with ``tree_map`` and ``jax.grad``."""

        class MLP(PyTreeNode):
            layer1: Any
            layer2: Any

        class Linear(PyTreeNode):
            weight: jax.Array
            bias: jax.Array

        mlp = MLP(
            layer1=Linear(weight=jnp.ones((4, 3)), bias=jnp.zeros(4)),
            layer2=Linear(weight=jnp.ones((2, 4)), bias=jnp.zeros(2))
        )

        # Test tree_map
        mlp2 = jax.tree_util.tree_map(lambda x: x * 2, mlp)
        assert jnp.allclose(mlp2.layer1.weight, mlp.layer1.weight * 2)

        # Test with grad
        def loss_fn(model):
            return jnp.sum(model.layer1.weight ** 2) + jnp.sum(model.layer2.weight ** 2)

        grad_fn = jax.grad(loss_fn)
        grads = grad_fn(mlp)
        assert grads.layer1.weight.shape == mlp.layer1.weight.shape
246
+
247
+
248
class TestFrozenDict:
    """Test the FrozenDict class."""

    def test_frozendict_creation(self):
        """FrozenDict accepts a dict, keyword arguments, or an iterable of pairs."""
        # From dict
        fd1 = FrozenDict({'a': 1, 'b': 2})
        assert fd1['a'] == 1
        assert fd1['b'] == 2

        # From kwargs
        fd2 = FrozenDict(a=1, b=2)
        assert fd2['a'] == 1

        # From items
        fd3 = FrozenDict([('a', 1), ('b', 2)])
        assert fd3['a'] == 1

    def test_frozendict_immutability(self):
        """Item assignment, insertion and deletion all raise TypeError."""
        fd = FrozenDict({'a': 1})

        with pytest.raises(TypeError):
            fd['a'] = 2

        with pytest.raises(TypeError):
            fd['c'] = 3

        with pytest.raises(TypeError):
            del fd['a']

    def test_frozendict_basic_operations(self):
        """Membership, length, iteration and ``get`` behave like a dict."""
        fd = FrozenDict({'a': 1, 'b': 2, 'c': 3})

        # Contains
        assert 'a' in fd
        assert 'd' not in fd

        # Length
        assert len(fd) == 3

        # Iteration
        keys = list(fd)
        assert set(keys) == {'a', 'b', 'c'}

        # Get
        assert fd.get('a') == 1
        assert fd.get('d') is None
        assert fd.get('d', 10) == 10

    def test_frozendict_views(self):
        """keys/values/items views expose contents and FrozenDict-specific reprs."""
        fd = FrozenDict({'a': 1, 'b': 2})

        # Keys view
        keys = fd.keys()
        assert set(keys) == {'a', 'b'}
        assert 'FrozenDict.keys' in repr(keys)

        # Values view
        values = fd.values()
        assert set(values) == {1, 2}
        assert 'FrozenDict.values' in repr(values)

        # Items view
        items = list(fd.items())
        assert len(items) == 2
        assert ('a', 1) in items

    def test_frozendict_copy(self):
        """``copy`` returns a new FrozenDict, optionally with updated entries."""
        fd1 = FrozenDict({'a': 1, 'b': 2})

        # Copy without changes
        fd2 = fd1.copy()
        assert fd2 == fd1
        assert fd2 is not fd1

        # Copy with updates
        fd3 = fd1.copy({'c': 3, 'a': 10})
        assert fd3['a'] == 10
        assert fd3['b'] == 2
        assert fd3['c'] == 3
        assert fd1['a'] == 1  # Original unchanged

    def test_frozendict_pop(self):
        """``pop`` returns ``(new_dict, value)`` and leaves the original intact."""
        fd1 = FrozenDict({'a': 1, 'b': 2, 'c': 3})

        fd2, value = fd1.pop('b')
        assert value == 2
        assert 'b' not in fd2
        assert len(fd2) == 2
        assert 'b' in fd1  # Original unchanged

        # Pop non-existent key
        with pytest.raises(KeyError):
            fd1.pop('d')

    def test_frozendict_nested(self):
        """Nested dicts are reachable and come back as FrozenDict instances."""
        fd = FrozenDict({
            'a': 1,
            'b': {'c': 2, 'd': {'e': 3}}
        })

        # Access nested values
        assert fd['b']['c'] == 2
        assert fd['b']['d']['e'] == 3

        # Nested values are also FrozenDict
        assert isinstance(fd['b'], FrozenDict)
        assert isinstance(fd['b']['d'], FrozenDict)

    def test_frozendict_hash(self):
        """Equal FrozenDicts hash equally and deduplicate in sets."""
        fd1 = FrozenDict({'a': 1, 'b': 2})
        fd2 = FrozenDict({'a': 1, 'b': 2})
        fd3 = FrozenDict({'a': 1, 'b': 3})

        # Equal dicts have same hash
        assert hash(fd1) == hash(fd2)

        # Can be used in sets
        s = {fd1, fd2, fd3}
        assert len(s) == 2

    def test_frozendict_equality(self):
        """FrozenDict compares equal to same-content FrozenDicts and plain dicts."""
        fd1 = FrozenDict({'a': 1, 'b': 2})
        fd2 = FrozenDict({'a': 1, 'b': 2})
        fd3 = FrozenDict({'a': 1, 'b': 3})
        d = {'a': 1, 'b': 2}

        assert fd1 == fd2
        assert fd1 != fd3
        assert fd1 == d
        assert fd1 != "not a dict"

    def test_frozendict_pickle(self):
        """A nested FrozenDict survives a pickle round trip."""
        fd = FrozenDict({'a': 1, 'b': {'c': 2}})

        # Pickle and unpickle
        pickled = pickle.dumps(fd)
        fd2 = pickle.loads(pickled)

        assert fd == fd2
        assert fd['b']['c'] == fd2['b']['c']

    def test_frozendict_pretty_repr(self):
        """``pretty_repr`` names the type and shows nested entries."""
        fd = FrozenDict({'a': 1, 'b': {'c': 2}})
        repr_str = fd.pretty_repr()

        assert 'FrozenDict' in repr_str
        assert "'a': 1" in repr_str
        assert "'c': 2" in repr_str

    def test_frozendict_as_pytree(self):
        """FrozenDict flattens to its values and round-trips through JAX tree utilities."""
        # Non-zero values for every leaf so that the ``* 2`` checks are meaningful.
        fd = FrozenDict({'a': jnp.ones(3), 'b': jnp.arange(1.0, 3.0)})

        # Tree map
        fd2 = jax.tree_util.tree_map(lambda x: x * 2, fd)
        assert isinstance(fd2, FrozenDict)
        assert jnp.allclose(fd2['a'], jnp.ones(3) * 2)
        assert jnp.allclose(fd2['b'], jnp.arange(1.0, 3.0) * 2)

        # Tree leaves
        leaves = jax.tree_util.tree_leaves(fd)
        assert len(leaves) == 2

        # Tree flatten and unflatten
        values, treedef = jax.tree_util.tree_flatten(fd)
        fd3 = jax.tree_util.tree_unflatten(treedef, values)
        # Compare keys and array contents explicitly. ``fd == fd3`` would compare
        # multi-element jax arrays with ``==``, whose element-wise result has no
        # single truth value; it could only pass through an identity short-circuit
        # on the reused leaf objects (assuming dict-style equality).
        assert set(fd3.keys()) == set(fd.keys())
        for key in fd:
            assert jnp.array_equal(fd3[key], fd[key])
425
+
426
+
427
class TestUtilityFunctions:
    """Test the module-level helpers: freeze, unfreeze, copy, pop, pretty_repr."""

    def test_freeze(self):
        """``freeze`` converts nested dicts and is a no-op on a FrozenDict."""
        # Regular dict
        d = {'a': 1, 'b': {'c': 2}}
        fd = freeze(d)
        assert isinstance(fd, FrozenDict)
        assert fd['a'] == 1
        assert isinstance(fd['b'], FrozenDict)
        assert fd['b']['c'] == 2

        # Already frozen
        fd2 = freeze(fd)
        assert fd2 is fd

    def test_unfreeze(self):
        """``unfreeze`` deep-converts to mutable dicts and copies plain dicts."""
        # FrozenDict
        fd = FrozenDict({'a': 1, 'b': {'c': 2}})
        d = unfreeze(fd)
        assert isinstance(d, dict)
        assert not isinstance(d, FrozenDict)
        assert d['a'] == 1
        assert isinstance(d['b'], dict)
        # ``isinstance(..., dict)`` alone does not prove a *deep* unfreeze: also
        # require the nested value to be a plain, mutable dict whose mutation
        # does not leak back into the source FrozenDict.
        assert not isinstance(d['b'], FrozenDict)
        d['b']['c'] = 5
        assert d['b']['c'] == 5
        assert fd['b']['c'] == 2

        # Regular dict
        d2 = {'a': 1}
        d3 = unfreeze(d2)
        assert d3 == d2
        assert d3 is not d2  # Should be a copy

        # Non-dict
        assert unfreeze(42) == 42

    def test_copy_function(self):
        """``copy`` preserves the container type and rejects non-mappings."""
        # FrozenDict
        fd1 = FrozenDict({'a': 1})
        fd2 = copy(fd1, {'b': 2})
        assert isinstance(fd2, FrozenDict)
        assert fd2['a'] == 1
        assert fd2['b'] == 2

        # Regular dict
        d1 = {'a': 1}
        d2 = copy(d1, {'b': 2})
        assert isinstance(d2, dict)
        assert not isinstance(d2, FrozenDict)
        assert d2['a'] == 1
        assert d2['b'] == 2
        assert d1 == {'a': 1}  # Original unchanged

        # Invalid type
        with pytest.raises(TypeError):
            copy([1, 2, 3])

    def test_pop_function(self):
        """``pop`` returns ``(new_container, value)`` without mutating its input."""
        # FrozenDict
        fd1 = FrozenDict({'a': 1, 'b': 2})
        fd2, value = pop(fd1, 'a')
        assert isinstance(fd2, FrozenDict)
        assert value == 1
        assert 'a' not in fd2
        assert 'a' in fd1

        # Regular dict
        d1 = {'a': 1, 'b': 2}
        d2, value = pop(d1, 'a')
        assert isinstance(d2, dict)
        assert value == 1
        assert 'a' not in d2
        assert 'a' in d1

        # Invalid type
        with pytest.raises(TypeError):
            pop([1, 2, 3], 0)

    def test_pretty_repr_function(self):
        """``pretty_repr`` formats mappings and falls back to ``repr`` otherwise."""
        # FrozenDict
        fd = FrozenDict({'a': 1, 'b': {'c': 2}})
        s = pretty_repr(fd)
        assert 'FrozenDict' in s

        # Regular dict
        d = {'a': 1, 'b': {'c': 2}}
        s = pretty_repr(d)
        assert 'a' in s
        assert 'c' in s

        # Other type
        s = pretty_repr([1, 2, 3])
        assert s == "[1, 2, 3]"
522
+
523
+
524
class TestIntegration:
    """Integration tests combining multiple features."""

    def test_nested_structures(self):
        """Dataclass, PyTreeNode and FrozenDict nest correctly under ``tree_map``."""

        @dataclass
        class Config:
            hyperparams: FrozenDict
            metadata: dict = field(pytree_node=False)

        class Model(PyTreeNode):
            config: Config
            weights: jax.Array

        config = Config(
            hyperparams=FrozenDict({'lr': 0.001, 'batch_size': 32}),
            metadata={'version': '1.0'}
        )
        model = Model(
            config=config,
            weights=jnp.ones((4, 4))
        )

        # Test tree operations: only jax.Array leaves are doubled; Python scalars
        # inside ``hyperparams`` pass through unchanged.
        model2 = jax.tree_util.tree_map(lambda x: x * 2 if isinstance(x, jax.Array) else x, model)
        assert jnp.allclose(model2.weights, model.weights * 2)
        assert model2.config.hyperparams['lr'] == 0.001
        assert model2.config.metadata['version'] == '1.0'

    def test_jax_transformations_integration(self):
        """A dataclass holding a FrozenDict works under ``jax.jit`` and ``jax.vmap``."""

        @dataclass
        class State:
            params: FrozenDict
            step: int = field(pytree_node=False, default=0)

        state = State(
            params=FrozenDict({
                'w': jnp.ones((3, 3)),
                'b': jnp.zeros(3)
            })
        )

        # JIT compilation
        @jax.jit
        def update(state, grad):
            new_params = jax.tree_util.tree_map(
                lambda p, g: p - 0.01 * g,
                state.params,
                grad
            )
            # ``step`` is a static field, so it is a plain Python int inside jit.
            return state.replace(
                params=new_params,
                step=state.step + 1
            )

        grad = FrozenDict({'w': jnp.ones((3, 3)), 'b': jnp.ones(3)})
        new_state = update(state, grad)
        assert new_state.step == 1
        assert state.step == 0  # Original state is not mutated
        assert jnp.allclose(new_state.params['w'], state.params['w'] - 0.01)
        # The bias must be updated too, not just the weights.
        assert jnp.allclose(new_state.params['b'], state.params['b'] - 0.01)

        # VMAP
        @jax.vmap
        def batch_process(params, x):
            return jnp.dot(x, params['w']) + params['b']

        batch_params = jax.tree_util.tree_map(
            lambda x: jnp.stack([x, x * 2]),
            state.params
        )
        batch_x = jnp.ones((2, 3))
        result = batch_process(batch_params, batch_x)
        assert result.shape == (2, 3)
        # Row 0: ones(3) @ ones((3, 3)) + 0 -> 3; row 1 uses doubled weights -> 6.
        # Checking values (not only the shape) confirms each batch element saw
        # its own slice of the stacked FrozenDict.
        expected = jnp.array([[3.0, 3.0, 3.0], [6.0, 6.0, 6.0]])
        assert jnp.allclose(result, expected)
599
+
600
+
601
if __name__ == "__main__":
    # Run this file's tests directly. Propagate pytest's exit status so that a
    # failing run exits non-zero (visible to shells and CI) instead of always 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
@@ -1,14 +1,15 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: brainstate
3
- Version: 0.1.9
4
- Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
- Home-page: https://github.com/chaobrain/brainstate
6
- Author: BrainState Developers
3
+ Version: 0.2.0
4
+ Summary: A State-based Transformation System for Brain Modeling.
7
5
  Author-email: BrainState Developers <chao.brain@qq.com>
8
6
  License: Apache-2.0 license
9
- Project-URL: homepage, http://github.com/chaobrain/brainstate
10
- Project-URL: repository, http://github.com/chaobrain/brainstate
11
- Keywords: computational neuroscience,brain-inspired computation,brain dynamics programming
7
+ Project-URL: homepage, https://github.com/chaobrain/brainstate
8
+ Project-URL: repository, https://github.com/chaobrain/brainstate
9
+ Project-URL: Documentation, https://brainstate.readthedocs.io/
10
+ Project-URL: Source Code, https://github.com/chaobrain/brainstate
11
+ Project-URL: Bug Tracker, https://github.com/chaobrain/brainstate/issues
12
+ Keywords: computational neuroscience,brain-inspired computing,brain simulation,brain modeling,spiking neural networks
12
13
  Classifier: Natural Language :: English
13
14
  Classifier: Operating System :: OS Independent
14
15
  Classifier: Development Status :: 4 - Beta
@@ -28,20 +29,36 @@ Classifier: Topic :: Software Development :: Libraries
28
29
  Requires-Python: >=3.10
29
30
  Description-Content-Type: text/markdown
30
31
  License-File: LICENSE
31
- Requires-Dist: jax
32
- Requires-Dist: jaxlib
33
- Requires-Dist: numpy
34
- Requires-Dist: brainunit>=0.1.0
32
+ Requires-Dist: numpy>=1.15
33
+ Requires-Dist: tqdm
34
+ Requires-Dist: brainunit
35
35
  Requires-Dist: brainevent
36
+ Provides-Extra: cpu
37
+ Requires-Dist: jax[cpu]; extra == "cpu"
38
+ Requires-Dist: brainunit; extra == "cpu"
39
+ Requires-Dist: brainevent; extra == "cpu"
40
+ Provides-Extra: cuda12
41
+ Requires-Dist: jax[cuda12]; extra == "cuda12"
42
+ Requires-Dist: brainunit; extra == "cuda12"
43
+ Requires-Dist: brainevent; extra == "cuda12"
44
+ Provides-Extra: cuda13
45
+ Requires-Dist: jax[cuda13]; extra == "cuda13"
46
+ Requires-Dist: brainunit; extra == "cuda13"
47
+ Requires-Dist: brainevent; extra == "cuda13"
48
+ Provides-Extra: tpu
49
+ Requires-Dist: jax[tpu]; extra == "tpu"
50
+ Requires-Dist: brainunit; extra == "tpu"
51
+ Requires-Dist: brainevent; extra == "tpu"
36
52
  Provides-Extra: testing
53
+ Requires-Dist: absl-py; extra == "testing"
37
54
  Requires-Dist: pytest; extra == "testing"
38
- Dynamic: author
39
- Dynamic: home-page
55
+ Requires-Dist: jax; extra == "testing"
56
+ Requires-Dist: brainunit; extra == "testing"
57
+ Requires-Dist: brainevent; extra == "testing"
40
58
  Dynamic: license-file
41
- Dynamic: requires-python
42
59
 
43
60
 
44
- # A ``State``-based Transformation System for Program Compilation and Augmentation
61
+ # A ``State``-based Transformation System for Brain Modeling
45
62
 
46
63
 
47
64
 
@@ -84,8 +101,8 @@ The official documentation is hosted on Read the Docs: [https://brainstate.readt
84
101
 
85
102
 
86
103
 
87
- ## See also the brain modeling ecosystem
104
+ ## See also the ecosystem
88
105
 
89
- We are building the brain modeling ecosystem: https://brainmodeling.readthedocs.io/
106
+ ``brainstate`` is one part of our brain simulation ecosystem: https://brainmodeling.readthedocs.io/
90
107
 
91
108