brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,1510 +1,1634 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- import threading
18
- import unittest
19
-
20
- import jax
21
- import jax.numpy as jnp
22
- import pytest
23
-
24
- import brainstate
25
- from brainstate._error import BatchAxisError
26
- from brainstate._compatible_import import jaxpr_as_fun
27
- from brainstate.transform._make_jaxpr import _BoundedCache, make_hashable
28
-
29
-
30
class TestMakeJaxpr(unittest.TestCase):
    """Compare ``brainstate.transform.make_jaxpr`` against ``jax.make_jaxpr``.

    The assertions below check that plain JAX captures state touched by the
    traced function (e.g. the RNG key) as jaxpr constants, whereas brainstate
    exposes it as extra jaxpr inputs/outputs.
    """

    def test_compar_jax_make_jaxpr(self):
        """RNG state is a constant for JAX but an explicit in/out for brainstate."""

        def func4(arg):  # ``arg`` is a pair of arrays
            temp = arg[0] + jnp.sin(arg[1]) * 3.
            c = brainstate.random.rand_like(arg[0])
            return jnp.sum(temp + c)

        key = brainstate.random.DEFAULT.value
        jaxpr = jax.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
        print(jaxpr)
        # Plain JAX: two array inputs, one output, the RNG key baked in as a const.
        self.assertEqual(len(jaxpr.in_avals), 2)
        self.assertEqual(len(jaxpr.consts), 1)
        self.assertEqual(len(jaxpr.out_avals), 1)
        self.assertTrue(jnp.allclose(jaxpr.consts[0], key))

        brainstate.random.seed(1)
        print(brainstate.random.DEFAULT.value)

        jaxpr2, states = brainstate.transform.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
        print(jaxpr2)
        # brainstate: the state used by ``rand_like`` shows up as one extra
        # input and one extra output, and no constants remain.
        self.assertEqual(len(jaxpr2.in_avals), 3)
        self.assertEqual(len(jaxpr2.out_avals), 2)
        self.assertEqual(len(jaxpr2.consts), 0)
        print(brainstate.random.DEFAULT.value)

    def test_StatefulFunction_1(self):
        """Compiled states and jaxpr can be retrieved by the argument cache key."""

        def func4(arg):  # ``arg`` is a pair of arrays
            temp = arg[0] + jnp.sin(arg[1]) * 3.
            c = brainstate.random.rand_like(arg[0])
            return jnp.sum(temp + c)

        fun = brainstate.transform.StatefulFunction(func4).make_jaxpr((jnp.zeros(8), jnp.ones(8)))
        cache_key = fun.get_arg_cache_key((jnp.zeros(8), jnp.ones(8)))
        print(fun.get_states_by_cache(cache_key))
        print(fun.get_jaxpr_by_cache(cache_key))

    def test_StatefulFunction_2(self):
        """Nested brainstate tracing inside JAX/brainstate tracing does not leak."""
        st1 = brainstate.State(jnp.ones(10))

        def f1(x):
            st1.value = x + st1.value

        def f2(x):
            # Trace ``f1`` inside the outer trace; the inner jaxpr is discarded.
            brainstate.transform.make_jaxpr(f1)(x)
            return 1. + x

        def f3(x):
            # Same nesting, but the result does not depend on ``x``.
            brainstate.transform.make_jaxpr(f1)(x)
            return 1.

        print()
        jaxpr = brainstate.transform.make_jaxpr(f1)(jnp.zeros(1))
        print(jaxpr)
        jaxpr = jax.make_jaxpr(f2)(jnp.zeros(1))
        print(jaxpr)
        jaxpr = jax.make_jaxpr(f3)(jnp.zeros(1))
        print(jaxpr)
        jaxpr, _ = brainstate.transform.make_jaxpr(f3)(jnp.zeros(1))
        print(jaxpr)
        self.assertTrue(jnp.allclose(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
                                     f3(jnp.zeros(1))))

    def test_compare_jax_make_jaxpr2(self):
        """A function that internally traces a stateful function can itself be traced."""
        st1 = brainstate.State(jnp.ones(10))

        def fa(x):
            st1.value = x + st1.value

        def ffa(x):
            # Inner trace is performed only for its tracing side effects.
            _, _ = brainstate.transform.make_jaxpr(fa)(x)
            return 1. + x

        jaxpr, states = brainstate.transform.make_jaxpr(ffa)(jnp.zeros(1))
        print()
        print(jaxpr)
        print(states)
        print(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value))
        jaxpr = jax.make_jaxpr(ffa)(jnp.zeros(1))
        print(jaxpr)
        print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))

    def test_compare_jax_make_jaxpr3(self):
        """Tracing a function whose output is a Python constant."""

        def fa(x):
            return 1.

        jaxpr, states = brainstate.transform.make_jaxpr(fa)(jnp.zeros(1))
        print()
        print(jaxpr)
        print(states)
        jaxpr = jax.make_jaxpr(fa)(jnp.zeros(1))
        print(jaxpr)

    def test_static_argnames(self):
        """``static_argnames`` lets a Python scalar be passed as a static argument."""

        def func4(a, b):  # ``a`` is an array, ``b`` is static (see static_argnames)
            temp = a + jnp.sin(b) * 3.
            c = brainstate.random.rand_like(a)
            return jnp.sum(temp + c)

        jaxpr, states = brainstate.transform.make_jaxpr(func4, static_argnames='b')(jnp.zeros(8), 1.)
        print()
        print(jaxpr)
        print(states)

    def test_state_in(self):
        """Passing a State object as a traced argument raises ValueError."""

        def f(a):
            return a.value

        with pytest.raises(ValueError):
            brainstate.transform.StatefulFunction(f).make_jaxpr(brainstate.State(1.))

    def test_state_out(self):
        """Returning a State object from a traced function raises ValueError."""

        def f(a):
            return brainstate.State(a)

        with pytest.raises(ValueError):
            brainstate.transform.StatefulFunction(f).make_jaxpr(1.)

    def test_return_states(self):
        """Returning a State from a jitted function raises ValueError."""
        a = brainstate.State(jnp.ones(3))

        @brainstate.transform.jit
        def f():
            return a

        with pytest.raises(ValueError):
            f()
161
-
162
-
163
class TestBoundedCache(unittest.TestCase):
    """Test the _BoundedCache class."""

    def test_cache_basic_operations(self):
        """Test basic get and set operations."""
        cache = _BoundedCache(maxsize=3)

        # Test set and get
        cache.set('key1', 'value1')
        self.assertEqual(cache.get('key1'), 'value1')

        # Test default value
        self.assertIsNone(cache.get('nonexistent'))
        self.assertEqual(cache.get('nonexistent', 'default'), 'default')

        # Test __contains__
        self.assertIn('key1', cache)
        self.assertNotIn('key2', cache)

        # Test __len__
        self.assertEqual(len(cache), 1)

    def test_cache_lru_eviction(self):
        """Test LRU eviction when cache is full."""
        cache = _BoundedCache(maxsize=3)

        # Fill cache
        cache.set('key1', 'value1')
        cache.set('key2', 'value2')
        cache.set('key3', 'value3')
        self.assertEqual(len(cache), 3)

        # Add one more, should evict key1 (least recently used)
        cache.set('key4', 'value4')
        self.assertEqual(len(cache), 3)
        self.assertNotIn('key1', cache)
        self.assertIn('key4', cache)

        # Access key2 to make it recently used
        cache.get('key2')

        # Add another key, should evict key3 (now least recently used)
        cache.set('key5', 'value5')
        self.assertNotIn('key3', cache)
        self.assertIn('key2', cache)

    def test_cache_update_existing(self):
        """Test updating an existing key."""
        cache = _BoundedCache(maxsize=2)

        cache.set('key1', 'value1')
        cache.set('key2', 'value2')

        # Update key1 (should move it to end)
        cache.replace('key1', 'updated_value1')
        self.assertEqual(cache.get('key1'), 'updated_value1')

        # Add new key, should evict key2 (now LRU)
        cache.set('key3', 'value3')
        self.assertNotIn('key2', cache)
        self.assertIn('key1', cache)

    def test_cache_statistics(self):
        """Test cache statistics tracking."""
        cache = _BoundedCache(maxsize=5)

        # Initial stats
        stats = cache.get_stats()
        self.assertEqual(stats['size'], 0)
        self.assertEqual(stats['maxsize'], 5)
        self.assertEqual(stats['hits'], 0)
        self.assertEqual(stats['misses'], 0)
        self.assertEqual(stats['hit_rate'], 0.0)

        # Add items and test hits/misses
        cache.set('key1', 'value1')
        cache.set('key2', 'value2')

        # Generate hits
        cache.get('key1')  # hit
        cache.get('key1')  # hit
        cache.get('key3')  # miss
        cache.get('key2')  # hit

        stats = cache.get_stats()
        self.assertEqual(stats['size'], 2)
        self.assertEqual(stats['hits'], 3)
        self.assertEqual(stats['misses'], 1)
        # hit_rate is expressed as a percentage: 3 hits / 4 lookups.
        self.assertEqual(stats['hit_rate'], 75.0)

    def test_cache_clear(self):
        """Test clearing the cache."""
        cache = _BoundedCache(maxsize=5)

        # Add items
        cache.set('key1', 'value1')
        cache.set('key2', 'value2')
        cache.get('key1')  # Generate a hit

        # Clear cache
        cache.clear()

        self.assertEqual(len(cache), 0)
        self.assertNotIn('key1', cache)

        # Check stats are reset
        stats = cache.get_stats()
        self.assertEqual(stats['hits'], 0)
        self.assertEqual(stats['misses'], 0)

    def test_cache_keys(self):
        """Test getting all cache keys."""
        cache = _BoundedCache(maxsize=5)

        cache.set('key1', 'value1')
        cache.set('key2', 'value2')
        cache.set('key3', 'value3')

        keys = cache.keys()
        self.assertEqual(set(keys), {'key1', 'key2', 'key3'})

    def test_cache_set_duplicate_raises(self):
        """Test that setting an existing key raises ValueError."""
        cache = _BoundedCache(maxsize=5)

        cache.set('key1', 'value1')

        # Attempting to set the same key should raise ValueError
        with pytest.raises(ValueError, match="Cache key already exists"):
            cache.set('key1', 'value2')

    def test_cache_pop(self):
        """Test pop method."""
        cache = _BoundedCache(maxsize=5)

        cache.set('key1', 'value1')
        cache.set('key2', 'value2')

        # Pop existing key
        value = cache.pop('key1')
        self.assertEqual(value, 'value1')
        self.assertNotIn('key1', cache)
        self.assertEqual(len(cache), 1)

        # Pop non-existent key with default
        value = cache.pop('nonexistent', 'default')
        self.assertEqual(value, 'default')

        # Pop non-existent key without default
        value = cache.pop('nonexistent')
        self.assertIsNone(value)

    def test_cache_replace(self):
        """Test replace method, including that replace refreshes LRU order."""
        cache = _BoundedCache(maxsize=5)

        cache.set('key1', 'value1')
        cache.set('key2', 'value2')

        # Replace existing key
        cache.replace('key1', 'new_value1')
        self.assertEqual(cache.get('key1'), 'new_value1')

        # Replacing should move to end (most recently used).
        # Order after these two lines: key1, key3, key2
        cache.set('key3', 'value3')
        cache.replace('key2', 'new_value2')

        # Fill the cache to capacity: key1, key3, key2, key4, key5
        cache.set('key4', 'value4')
        cache.set('key5', 'value5')

        # Adding key6 must evict key1, the oldest entry now that replace
        # moved key2 to the end.
        cache.set('key6', 'value6')

        self.assertEqual(len(cache), 5)
        self.assertNotIn('key1', cache)
        # key2 should still be there because replace moved it to end
        self.assertIn('key2', cache)

    def test_cache_replace_nonexistent_raises(self):
        """Test that replacing a non-existent key raises KeyError."""
        cache = _BoundedCache(maxsize=5)

        with pytest.raises(KeyError, match="Cache key does not exist"):
            cache.replace('nonexistent', 'value')

    def test_cache_get_with_raise_on_miss(self):
        """Test get method with raise_on_miss parameter."""
        cache = _BoundedCache(maxsize=5)

        cache.set('key1', 'value1')
        cache.set('key2', 'value2')

        # Should work normally for existing key
        value = cache.get('key1', raise_on_miss=True)
        self.assertEqual(value, 'value1')

        # Should raise ValueError for missing key with raise_on_miss=True
        with pytest.raises(ValueError, match="not compiled for the requested cache key"):
            cache.get('nonexistent', raise_on_miss=True, error_context="Test item")

    def test_cache_detailed_error_message(self):
        """Test that error message shows available keys."""
        cache = _BoundedCache(maxsize=5)

        cache.set('key1', 'value1')
        cache.set('key2', 'value2')

        # Error should include all available keys
        with pytest.raises(ValueError) as exc_info:
            cache.get('nonexistent', raise_on_miss=True, error_context="Test item")

        error_msg = str(exc_info.value)
        # Should show requested key
        self.assertIn('nonexistent', error_msg)
        # Should show available keys
        self.assertIn('key1', error_msg)
        self.assertIn('key2', error_msg)
        # Should have helpful message
        self.assertIn('make_jaxpr()', error_msg)

    def test_cache_error_message_no_keys(self):
        """Test error message when cache is empty."""
        cache = _BoundedCache(maxsize=5)

        with pytest.raises(ValueError) as exc_info:
            cache.get('key', raise_on_miss=True, error_context="Empty cache")

        error_msg = str(exc_info.value)
        # Should indicate no keys available
        self.assertIn('none', error_msg.lower())

    def test_cache_thread_safety(self):
        """Concurrent set/get from several threads must not lose or corrupt entries."""
        num_threads = 5
        keys_per_thread = 50
        total_keys = num_threads * keys_per_thread
        # Size the cache to hold every key. With a smaller bound, LRU eviction
        # triggered by *other* threads between a worker's ``set`` and ``get``
        # could legitimately drop the key and make this test flaky.
        cache = _BoundedCache(maxsize=total_keys)
        errors = []

        def worker(thread_id):
            try:
                for i in range(keys_per_thread):
                    key = f'key_{thread_id}_{i}'
                    cache.set(key, f'value_{thread_id}_{i}')
                    value = cache.get(key)
                    if value != f'value_{thread_id}_{i}':
                        errors.append(f'Mismatch in thread {thread_id}')
            except Exception as e:
                # Record instead of raising: exceptions in threads would
                # otherwise not fail the test.
                errors.append(f'Error in thread {thread_id}: {e}')

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_threads)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Check no errors occurred
        self.assertEqual(len(errors), 0, f"Thread safety errors: {errors}")
        # Every unique key inserted by every thread must still be present.
        self.assertEqual(len(cache), total_keys)
422
-
423
-
424
class TestStatefulFunctionEnhancements(unittest.TestCase):
    """Cache-introspection and state-validation helpers of ``StatefulFunction``."""

    def test_cache_stats(self):
        """get_cache_stats reports every internal cache, each with the full counter set."""
        counter = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x):
            counter.value += x
            return counter.value * 2

        fn = brainstate.transform.StatefulFunction(step)

        # Compile for two inputs of identical shape but different values.
        for inp in (jnp.array([0.5, 0.5]), jnp.array([1.0, 1.0])):
            fn.make_jaxpr(inp)

        report = fn.get_cache_stats()

        expected_caches = ('jaxpr_cache', 'out_shapes_cache',
                           'jaxpr_out_tree_cache', 'state_trace_cache')
        for cache_name in expected_caches:
            self.assertIn(cache_name, report)

        expected_fields = ('size', 'maxsize', 'hits', 'misses', 'hit_rate')
        for per_cache in report.values():
            for field in expected_fields:
                self.assertIn(field, per_cache)

    def test_validate_states(self):
        """validate_states accepts the entry produced by a fresh compilation."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x):
            acc.value += x
            return acc.value

        fn = brainstate.transform.StatefulFunction(step)
        inp = jnp.array([0.5, 0.5])
        fn.make_jaxpr(inp)

        key = fn.get_arg_cache_key(inp)
        self.assertTrue(fn.validate_states(key))

    def test_validate_all_states(self):
        """validate_all_states returns one truthy verdict per cached entry."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x, n):
            acc.value += x
            return acc.value * n

        # ``n`` is static, so every distinct value gets its own cache entry.
        fn = brainstate.transform.StatefulFunction(step, static_argnums=(1,))
        inp = jnp.array([0.5, 0.5])
        for n in (1, 2):
            fn.make_jaxpr(inp, n)

        verdicts = fn.validate_all_states()

        self.assertEqual(len(verdicts), 2)
        for ok in verdicts.values():
            self.assertTrue(ok)

    def test_clear_cache(self):
        """clear_cache empties every internal cache."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x):
            acc.value += x
            return acc.value

        fn = brainstate.transform.StatefulFunction(step)
        fn.make_jaxpr(jnp.array([0.5, 0.5]))

        self.assertGreater(fn.get_cache_stats()['jaxpr_cache']['size'], 0)

        fn.clear_cache()

        after = fn.get_cache_stats()
        for cache_name in ('jaxpr_cache', 'out_shapes_cache',
                           'jaxpr_out_tree_cache', 'state_trace_cache'):
            self.assertEqual(after[cache_name]['size'], 0)

    def test_return_only_write_parameter(self):
        """return_only_write=True never reports more states than the default mode."""
        read_state = brainstate.State(jnp.array([1.0, 2.0]))
        write_state = brainstate.State(jnp.array([3.0, 4.0]))

        def step(x):
            # read_state is only read; write_state is read and written.
            _ = read_state.value + x
            write_state.value += x
            return write_state.value

        def collected_states(only_write):
            fn = brainstate.transform.StatefulFunction(step, return_only_write=only_write)
            fn.make_jaxpr(jnp.array([0.5, 0.5]))
            key = fn.get_arg_cache_key(jnp.array([0.5, 0.5]))
            return fn.get_states_by_cache(key)

        every_state = collected_states(False)
        written_only = collected_states(True)

        self.assertLessEqual(len(written_only), len(every_state))
557
-
558
-
559
class TestErrorHandling(unittest.TestCase):
    """Test error handling in StatefulFunction."""

    @staticmethod
    def _fake_cache_key():
        """Build an empty cache key.

        All fields are empty tuples. This is assumed not to match any key
        that ``make_jaxpr`` produces for the inputs used in these tests.
        """
        from brainstate.transform._make_jaxpr import hashabledict
        return hashabledict(
            static_args=(),
            dyn_args=(),
            static_kwargs=(),
            dyn_kwargs=(),
        )

    def test_jaxpr_call_state_mismatch(self):
        """Test error when state values length doesn't match."""
        state1 = brainstate.State(jnp.array([1.0, 2.0]))
        state2 = brainstate.State(jnp.array([3.0, 4.0]))

        def f(x):
            state1.value += x
            state2.value += x
            return state1.value + state2.value

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])
        sf.make_jaxpr(x)

        # Try to call with wrong number of state values (only 1 instead of 2)
        with pytest.raises(ValueError, match="State length mismatch"):
            sf.jaxpr_call([jnp.array([1.0, 1.0])], x)

    def test_get_jaxpr_not_compiled_detailed_error(self):
        """Test detailed error message when getting jaxpr for an uncompiled key."""

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)

        # Compile for one input shape so the cache has a key to list.
        sf.make_jaxpr(jnp.array([1.0, 2.0]))

        # Should raise detailed error for a key that was never compiled
        with pytest.raises(ValueError) as exc_info:
            sf.get_jaxpr_by_cache(self._fake_cache_key())

        error_msg = str(exc_info.value)
        # Should contain the requested key
        self.assertIn('Requested key:', error_msg)
        # Should show available keys
        self.assertIn('Available', error_msg)
        # Should have helpful message
        self.assertIn('make_jaxpr()', error_msg)

    def test_get_out_shapes_not_compiled_detailed_error(self):
        """Test detailed error message when getting output shapes for uncompiled function."""

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)

        # Should raise detailed error with context "Output shapes"
        with pytest.raises(ValueError) as exc_info:
            sf.get_out_shapes_by_cache(self._fake_cache_key())

        error_msg = str(exc_info.value)
        self.assertIn('Output shapes', error_msg)
        self.assertIn('Requested key:', error_msg)

    def test_get_out_treedef_not_compiled_detailed_error(self):
        """Test detailed error message when getting output tree for uncompiled function."""

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)

        # Should raise detailed error with context "Output tree"
        with pytest.raises(ValueError) as exc_info:
            sf.get_out_treedef_by_cache(self._fake_cache_key())

        error_msg = str(exc_info.value)
        self.assertIn('Output tree', error_msg)
        self.assertIn('Requested key:', error_msg)

    def test_get_state_trace_not_compiled_detailed_error(self):
        """Test detailed error message when getting state trace for uncompiled function."""

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)

        # Should raise detailed error with context "State trace"
        with pytest.raises(ValueError) as exc_info:
            sf.get_state_trace_by_cache(self._fake_cache_key())

        error_msg = str(exc_info.value)
        self.assertIn('State trace', error_msg)
        self.assertIn('Requested key:', error_msg)
684
-
685
-
686
class TestCompileIfMiss(unittest.TestCase):
    """The ``compile_if_miss`` flag of StatefulFunction's call-argument getters.

    Covers ``get_jaxpr``, ``get_out_shapes``, ``get_out_treedef``,
    ``get_state_trace``, ``get_states``, ``get_read_states`` and
    ``get_write_states``. Each takes the call arguments directly. As these
    tests check, they compile on a cache miss by default and raise when
    ``compile_if_miss=False``.
    """

    def test_get_jaxpr_by_call_with_compile_if_miss_true(self):
        """get_jaxpr compiles on demand when compile_if_miss=True."""

        def double(x):
            return x * 2

        fn = brainstate.transform.StatefulFunction(double)

        result = fn.get_jaxpr(jnp.array([1.0, 2.0]), compile_if_miss=True)
        self.assertIsNotNone(result)

    def test_get_jaxpr_by_call_with_compile_if_miss_false(self):
        """get_jaxpr raises on a cache miss when compile_if_miss=False."""

        def double(x):
            return x * 2

        fn = brainstate.transform.StatefulFunction(double)

        with pytest.raises(ValueError, match="not compiled"):
            fn.get_jaxpr(jnp.array([1.0, 2.0]), compile_if_miss=False)

    def test_get_out_shapes_by_call_compile_if_miss(self):
        """get_out_shapes compiles on demand, but a new shape with the flag off raises."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x):
            acc.value += x
            return acc.value * 2

        fn = brainstate.transform.StatefulFunction(step)

        self.assertIsNotNone(fn.get_out_shapes(jnp.array([1.0, 2.0]), compile_if_miss=True))

        # A differently shaped input maps to a cache key that was never compiled.
        with pytest.raises(ValueError):
            fn.get_out_shapes(jnp.array([1.0, 2.0, 3.0]), compile_if_miss=False)

    def test_get_out_treedef_by_call_compile_if_miss(self):
        """get_out_treedef compiles on demand under the default flag."""

        def pair(x):
            return x * 2, x + 1

        fn = brainstate.transform.StatefulFunction(pair)

        self.assertIsNotNone(fn.get_out_treedef(jnp.array([1.0, 2.0])))

    def test_get_state_trace_by_call_compile_if_miss(self):
        """get_state_trace compiles on demand when compile_if_miss=True."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x):
            acc.value += x
            return acc.value

        fn = brainstate.transform.StatefulFunction(step)

        self.assertIsNotNone(fn.get_state_trace(jnp.array([1.0, 2.0]), compile_if_miss=True))

    def test_get_states_by_call_compile_if_miss(self):
        """get_states compiles on demand and reports both touched states."""
        first = brainstate.State(jnp.array([1.0, 2.0]))
        second = brainstate.State(jnp.array([3.0, 4.0]))

        def step(x):
            first.value += x
            second.value += x
            return first.value + second.value

        fn = brainstate.transform.StatefulFunction(step)

        found = fn.get_states(jnp.array([1.0, 2.0]), compile_if_miss=True)
        self.assertEqual(len(found), 2)

    def test_get_read_states_by_call_compile_if_miss(self):
        """get_read_states compiles on demand when compile_if_miss=True."""
        read_state = brainstate.State(jnp.array([1.0, 2.0]))
        write_state = brainstate.State(jnp.array([3.0, 4.0]))

        def step(x):
            _ = read_state.value
            write_state.value += x
            return write_state.value

        fn = brainstate.transform.StatefulFunction(step)

        self.assertIsNotNone(fn.get_read_states(jnp.array([1.0, 2.0]), compile_if_miss=True))

    def test_get_write_states_by_call_compile_if_miss(self):
        """get_write_states compiles on demand when compile_if_miss=True."""
        read_state = brainstate.State(jnp.array([1.0, 2.0]))
        write_state = brainstate.State(jnp.array([3.0, 4.0]))

        def step(x):
            _ = read_state.value
            write_state.value += x
            return write_state.value

        fn = brainstate.transform.StatefulFunction(step)

        self.assertIsNotNone(fn.get_write_states(jnp.array([1.0, 2.0]), compile_if_miss=True))

    def test_compile_if_miss_default_behavior(self):
        """Without an explicit flag, the getters compile on a miss."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x):
            acc.value += x
            return acc.value

        # A fresh StatefulFunction per getter guarantees each starts with an empty cache.
        for getter_name in ('get_jaxpr', 'get_out_shapes', 'get_states'):
            fresh = brainstate.transform.StatefulFunction(step)
            getter = getattr(fresh, getter_name)
            self.assertIsNotNone(getter(jnp.array([1.0, 2.0])))
828
-
829
-
830
- class TestMakeHashable(unittest.TestCase):
831
- """Test the make_hashable utility function."""
832
-
833
- def test_hashable_list(self):
834
- """Test converting list to hashable."""
835
- result = make_hashable([1, 2, 3])
836
- # Should return a tuple
837
- self.assertIsInstance(result, tuple)
838
- # Should be hashable
839
- hash(result)
840
-
841
- def test_hashable_dict(self):
842
- """Test converting dict to hashable."""
843
- result = make_hashable({'b': 2, 'a': 1})
844
- # Should return a tuple of sorted key-value pairs
845
- self.assertIsInstance(result, tuple)
846
- # Should be hashable
847
- hash(result)
848
- # Keys should be sorted
849
- keys = [item[0] for item in result]
850
- self.assertEqual(keys, ['a', 'b'])
851
-
852
- def test_hashable_set(self):
853
- """Test converting set to hashable."""
854
- result = make_hashable({1, 2, 3})
855
- # Should return a frozenset
856
- self.assertIsInstance(result, frozenset)
857
- # Should be hashable
858
- hash(result)
859
-
860
- def test_hashable_nested(self):
861
- """Test converting nested structures."""
862
- nested = {
863
- 'list': [1, 2, 3],
864
- 'dict': {'a': 1, 'b': 2},
865
- 'set': {4, 5}
866
- }
867
- result = make_hashable(nested)
868
- # Should be hashable
869
- hash(result) # Should not raise
870
-
871
- def test_hashable_tuple(self):
872
- """Test with tuples."""
873
- result = make_hashable((1, 2, 3))
874
- # Should return a tuple
875
- self.assertIsInstance(result, tuple)
876
- # Should be hashable
877
- hash(result)
878
-
879
- def test_hashable_idempotent(self):
880
- """Test that applying make_hashable twice gives consistent results."""
881
- original = {'a': [1, 2], 'b': {3, 4}}
882
- result1 = make_hashable(original)
883
- result2 = make_hashable(original)
884
- # Should be the same
885
- self.assertEqual(result1, result2)
886
-
887
-
888
- class TestCacheCleanupOnError(unittest.TestCase):
889
- """Test that cache is properly cleaned up when compilation fails."""
890
-
891
- def test_cache_cleanup_on_compilation_error(self):
892
- """Test that partial cache entries are cleaned up when make_jaxpr fails."""
893
-
894
- def f(x):
895
- # This will cause an error during JAX tracing
896
- if x > 0: # Control flow not allowed in JAX
897
- return x * 2
898
- else:
899
- return x + 1
900
-
901
- sf = brainstate.transform.StatefulFunction(f)
902
-
903
- # Try to compile, should fail
904
- try:
905
- sf.make_jaxpr(jnp.array([1.0]))
906
- except Exception:
907
- pass # Expected to fail
908
-
909
- # Cache should be empty after error
910
- stats = sf.get_cache_stats()
911
- # All caches should be empty since error cleanup should have removed partial entries
912
- # Note: The actual behavior depends on when the error occurs during compilation
913
- # If error happens early, no cache entries; if late, entries might exist
914
- # This test just verifies the cleanup mechanism exists
915
-
916
-
917
- class TestMakeJaxprReturnOnlyWrite(unittest.TestCase):
918
- """Test make_jaxpr with return_only_write parameter."""
919
-
920
- def test_make_jaxpr_return_only_write(self):
921
- """Test make_jaxpr function with return_only_write parameter."""
922
- read_state = brainstate.State(jnp.array([1.0]))
923
- write_state = brainstate.State(jnp.array([2.0]))
924
-
925
- def f(x):
926
- _ = read_state.value # Read only
927
- write_state.value += x # Write
928
- return x * 2
929
-
930
- # Test with return_only_write=True
931
- jaxpr_maker = brainstate.transform.make_jaxpr(f, return_only_write=True)
932
- jaxpr, states = jaxpr_maker(jnp.array([1.0]))
933
-
934
- # Should compile successfully
935
- self.assertIsNotNone(jaxpr)
936
- self.assertIsInstance(states, tuple)
937
-
938
-
939
- class TestStatefulFunctionCallable(unittest.TestCase):
940
- """Test __call__ method of StatefulFunction."""
941
-
942
- def test_stateful_function_call(self):
943
- """Test calling StatefulFunction directly."""
944
- state = brainstate.State(jnp.array([1.0, 2.0]))
945
-
946
- def f(x):
947
- state.value += x
948
- return state.value * 2
949
-
950
- sf = brainstate.transform.StatefulFunction(f)
951
- x = jnp.array([0.5, 0.5])
952
- sf.make_jaxpr(x)
953
-
954
- # Test direct call
955
- result = sf(x)
956
- self.assertEqual(result.shape, (2,))
957
-
958
- def test_stateful_function_call_auto_compile(self):
959
- """Test that __call__ automatically compiles if needed."""
960
- state = brainstate.State(jnp.array([1.0, 2.0]))
961
-
962
- def f(x):
963
- state.value += x
964
- return state.value * 2
965
-
966
- sf = brainstate.transform.StatefulFunction(f)
967
- x = jnp.array([0.5, 0.5])
968
-
969
- # Call without pre-compilation should work
970
- result = sf(x)
971
- self.assertEqual(result.shape, (2,))
972
-
973
- def test_stateful_function_multiple_calls(self):
974
- """Test multiple calls to StatefulFunction."""
975
- state = brainstate.State(jnp.array([0.0]))
976
-
977
- def f(x):
978
- state.value += x
979
- return state.value
980
-
981
- sf = brainstate.transform.StatefulFunction(f)
982
-
983
- # Multiple calls should accumulate state
984
- result1 = sf(jnp.array([1.0]))
985
- result2 = sf(jnp.array([2.0]))
986
- result3 = sf(jnp.array([3.0]))
987
-
988
- # Each call should update the state
989
- self.assertIsNotNone(result1)
990
- self.assertIsNotNone(result2)
991
- self.assertIsNotNone(result3)
992
-
993
-
994
- class TestStatefulFunctionStaticArgs(unittest.TestCase):
995
- """Test StatefulFunction with static arguments."""
996
-
997
- def test_static_argnums_basic(self):
998
- """Test basic usage of static_argnums."""
999
- state = brainstate.State(jnp.array([1.0, 2.0]))
1000
-
1001
- def f(x, multiplier):
1002
- state.value += x
1003
- return state.value * multiplier
1004
-
1005
- sf = brainstate.transform.StatefulFunction(f, static_argnums=(1,))
1006
- x = jnp.array([0.5, 0.5])
1007
-
1008
- # Compile with multiplier=2
1009
- sf.make_jaxpr(x, 2)
1010
- cache_key1 = sf.get_arg_cache_key(x, 2)
1011
-
1012
- # Compile with multiplier=3
1013
- sf.make_jaxpr(x, 3)
1014
- cache_key2 = sf.get_arg_cache_key(x, 3)
1015
-
1016
- # Should have different cache keys
1017
- self.assertNotEqual(cache_key1, cache_key2)
1018
-
1019
- def test_static_argnames_basic(self):
1020
- """Test basic usage of static_argnames."""
1021
- state = brainstate.State(jnp.array([1.0, 2.0]))
1022
-
1023
- def f(x, multiplier=2):
1024
- state.value += x
1025
- return state.value * multiplier
1026
-
1027
- sf = brainstate.transform.StatefulFunction(f, static_argnames='multiplier')
1028
- x = jnp.array([0.5, 0.5])
1029
-
1030
- # Compile with different multiplier values
1031
- sf.make_jaxpr(x, multiplier=2)
1032
- cache_key1 = sf.get_arg_cache_key(x, multiplier=2)
1033
-
1034
- sf.make_jaxpr(x, multiplier=3)
1035
- cache_key2 = sf.get_arg_cache_key(x, multiplier=3)
1036
-
1037
- # Should have different cache keys
1038
- self.assertNotEqual(cache_key1, cache_key2)
1039
-
1040
- def test_static_args_combination(self):
1041
- """Test using both static_argnums and static_argnames."""
1042
- state = brainstate.State(jnp.array([1.0]))
1043
-
1044
- def f(x, multiplier, offset=0):
1045
- state.value += x
1046
- return state.value * multiplier + offset
1047
-
1048
- sf = brainstate.transform.StatefulFunction(
1049
- f, static_argnums=(1,), static_argnames='offset'
1050
- )
1051
- x = jnp.array([0.5])
1052
-
1053
- # Compile with different static args
1054
- sf.make_jaxpr(x, 2, offset=0)
1055
- cache_key1 = sf.get_arg_cache_key(x, 2, offset=0)
1056
-
1057
- sf.make_jaxpr(x, 3, offset=1)
1058
- cache_key2 = sf.get_arg_cache_key(x, 3, offset=1)
1059
-
1060
- # Should have different cache keys
1061
- self.assertNotEqual(cache_key1, cache_key2)
1062
-
1063
-
1064
- class TestStatefulFunctionComplexStates(unittest.TestCase):
1065
- """Test StatefulFunction with complex state scenarios."""
1066
-
1067
- def test_multiple_states(self):
1068
- """Test function with multiple states."""
1069
- state1 = brainstate.State(jnp.array([1.0]))
1070
- state2 = brainstate.State(jnp.array([2.0]))
1071
- state3 = brainstate.State(jnp.array([3.0]))
1072
-
1073
- def f(x):
1074
- state1.value += x
1075
- state2.value += x * 2
1076
- state3.value += x * 3
1077
- return state1.value + state2.value + state3.value
1078
-
1079
- sf = brainstate.transform.StatefulFunction(f)
1080
- x = jnp.array([1.0])
1081
- sf.make_jaxpr(x)
1082
-
1083
- cache_key = sf.get_arg_cache_key(x)
1084
- states = sf.get_states_by_cache(cache_key)
1085
-
1086
- # Should track all three states
1087
- self.assertEqual(len(states), 3)
1088
-
1089
- def test_nested_state_access(self):
1090
- """Test function with nested state access patterns."""
1091
- outer_state = brainstate.State(jnp.array([1.0]))
1092
- inner_state = brainstate.State(jnp.array([2.0]))
1093
-
1094
- def inner_fn(x):
1095
- inner_state.value += x
1096
- return inner_state.value
1097
-
1098
- def outer_fn(x):
1099
- outer_state.value += x
1100
- result = inner_fn(x)
1101
- return outer_state.value + result
1102
-
1103
- sf = brainstate.transform.StatefulFunction(outer_fn)
1104
- x = jnp.array([1.0])
1105
- sf.make_jaxpr(x)
1106
-
1107
- cache_key = sf.get_arg_cache_key(x)
1108
- states = sf.get_states_by_cache(cache_key)
1109
-
1110
- # Should track both states
1111
- self.assertGreaterEqual(len(states), 2)
1112
-
1113
- def test_conditional_state_write(self):
1114
- """Test function that conditionally writes to states."""
1115
- state1 = brainstate.State(jnp.array([1.0]))
1116
- state2 = brainstate.State(jnp.array([2.0]))
1117
-
1118
- def f(x, write_state1=True):
1119
- # Note: In JAX, actual control flow needs special handling
1120
- # This test is more about the framework's ability to track states
1121
- state1.value += x # Always write to state1
1122
- state2.value += x * 2 # Always write to state2
1123
- return state1.value + state2.value
1124
-
1125
- sf = brainstate.transform.StatefulFunction(f, static_argnames='write_state1')
1126
- x = jnp.array([1.0])
1127
- sf.make_jaxpr(x, write_state1=True)
1128
-
1129
- cache_key = sf.get_arg_cache_key(x, write_state1=True)
1130
- states = sf.get_states_by_cache(cache_key)
1131
-
1132
- # Should track states
1133
- self.assertGreaterEqual(len(states), 2)
1134
-
1135
-
1136
- class TestStatefulFunctionOutputShapes(unittest.TestCase):
1137
- """Test StatefulFunction output shape tracking."""
1138
-
1139
- def test_single_output(self):
1140
- """Test tracking single output shape."""
1141
- state = brainstate.State(jnp.array([1.0, 2.0, 3.0]))
1142
-
1143
- def f(x):
1144
- state.value += x
1145
- return state.value
1146
-
1147
- sf = brainstate.transform.StatefulFunction(f)
1148
- x = jnp.array([1.0, 2.0, 3.0])
1149
- sf.make_jaxpr(x)
1150
-
1151
- cache_key = sf.get_arg_cache_key(x)
1152
- out_shapes = sf.get_out_shapes_by_cache(cache_key)
1153
-
1154
- # Should have output shapes
1155
- self.assertIsNotNone(out_shapes)
1156
-
1157
- def test_multiple_outputs(self):
1158
- """Test tracking multiple output shapes."""
1159
- state = brainstate.State(jnp.array([1.0, 2.0]))
1160
-
1161
- def f(x):
1162
- state.value += x
1163
- return state.value, state.value * 2, jnp.sum(state.value)
1164
-
1165
- sf = brainstate.transform.StatefulFunction(f)
1166
- x = jnp.array([1.0, 2.0])
1167
- sf.make_jaxpr(x)
1168
-
1169
- cache_key = sf.get_arg_cache_key(x)
1170
- out_shapes = sf.get_out_shapes_by_cache(cache_key)
1171
-
1172
- # Should track all output shapes
1173
- self.assertIsNotNone(out_shapes)
1174
-
1175
- def test_nested_output_structure(self):
1176
- """Test tracking nested output structures."""
1177
- state = brainstate.State(jnp.array([1.0, 2.0]))
1178
-
1179
- def f(x):
1180
- state.value += x
1181
- return {
1182
- 'sum': jnp.sum(state.value),
1183
- 'prod': jnp.prod(state.value),
1184
- 'values': state.value
1185
- }
1186
-
1187
- sf = brainstate.transform.StatefulFunction(f)
1188
- x = jnp.array([1.0, 2.0])
1189
- sf.make_jaxpr(x)
1190
-
1191
- cache_key = sf.get_arg_cache_key(x)
1192
- out_treedef = sf.get_out_treedef_by_cache(cache_key)
1193
-
1194
- # Should have tree definition
1195
- self.assertIsNotNone(out_treedef)
1196
-
1197
-
1198
- class TestStatefulFunctionJaxprCall(unittest.TestCase):
1199
- """Test jaxpr_call and jaxpr_call_auto methods."""
1200
-
1201
- def test_jaxpr_call_basic(self):
1202
- """Test basic jaxpr_call usage."""
1203
- state = brainstate.State(jnp.array([1.0, 2.0]))
1204
-
1205
- def f(x):
1206
- state.value += x
1207
- return state.value * 2
1208
-
1209
- sf = brainstate.transform.StatefulFunction(f)
1210
- x = jnp.array([0.5, 0.5])
1211
- sf.make_jaxpr(x)
1212
-
1213
- # Get current state values
1214
- state_vals = [state.value]
1215
-
1216
- # Call at jaxpr level
1217
- new_state_vals, out = sf.jaxpr_call(state_vals, x)
1218
-
1219
- self.assertEqual(len(new_state_vals), 1)
1220
- self.assertEqual(out.shape, (2,))
1221
-
1222
- def test_jaxpr_call_auto_basic(self):
1223
- """Test basic jaxpr_call_auto usage."""
1224
- state = brainstate.State(jnp.array([1.0, 2.0]))
1225
-
1226
- def f(x):
1227
- state.value += x
1228
- return state.value * 2
1229
-
1230
- sf = brainstate.transform.StatefulFunction(f)
1231
- x = jnp.array([0.5, 0.5])
1232
- sf.make_jaxpr(x)
1233
-
1234
- # Call with automatic state management
1235
- result = sf.jaxpr_call_auto(x)
1236
-
1237
- self.assertEqual(result.shape, (2,))
1238
-
1239
- def test_jaxpr_call_preserves_state_order(self):
1240
- """Test that jaxpr_call preserves state order."""
1241
- state1 = brainstate.State(jnp.array([1.0]))
1242
- state2 = brainstate.State(jnp.array([2.0]))
1243
- state3 = brainstate.State(jnp.array([3.0]))
1244
-
1245
- def f(x):
1246
- state1.value += x
1247
- state2.value += x * 2
1248
- state3.value += x * 3
1249
- return state1.value + state2.value + state3.value
1250
-
1251
- sf = brainstate.transform.StatefulFunction(f)
1252
- x = jnp.array([1.0])
1253
- sf.make_jaxpr(x)
1254
-
1255
- cache_key = sf.get_arg_cache_key(x)
1256
- states = sf.get_states_by_cache(cache_key)
1257
-
1258
- # Get initial state values
1259
- state_vals = [s.value for s in states]
1260
-
1261
- # Call at jaxpr level
1262
- new_state_vals, _ = sf.jaxpr_call(state_vals, x)
1263
-
1264
- # Should return same number of states
1265
- self.assertEqual(len(new_state_vals), len(state_vals))
1266
-
1267
-
1268
- class TestStatefulFunctionEdgeCases(unittest.TestCase):
1269
- """Test edge cases and corner scenarios."""
1270
-
1271
- def test_no_state_function(self):
1272
- """Test function that doesn't use any states."""
1273
-
1274
- def f(x):
1275
- return x * 2 + 1
1276
-
1277
- sf = brainstate.transform.StatefulFunction(f)
1278
- x = jnp.array([1.0, 2.0])
1279
- sf.make_jaxpr(x)
1280
-
1281
- cache_key = sf.get_arg_cache_key(x)
1282
- states = sf.get_states_by_cache(cache_key)
1283
-
1284
- # Should have no states
1285
- self.assertEqual(len(states), 0)
1286
-
1287
- def test_read_only_state(self):
1288
- """Test function that only reads from states."""
1289
- state = brainstate.State(jnp.array([1.0, 2.0]))
1290
-
1291
- def f(x):
1292
- # Only read from state, don't write
1293
- return state.value + x
1294
-
1295
- sf = brainstate.transform.StatefulFunction(f, return_only_write=True)
1296
- x = jnp.array([1.0, 2.0])
1297
- sf.make_jaxpr(x)
1298
-
1299
- cache_key = sf.get_arg_cache_key(x)
1300
- write_states = sf.get_write_states_by_cache(cache_key)
1301
-
1302
- # Should have no write states
1303
- self.assertEqual(len(write_states), 0)
1304
-
1305
- def test_scalar_inputs_outputs(self):
1306
- """Test with scalar inputs and outputs."""
1307
- state = brainstate.State(jnp.array(1.0))
1308
-
1309
- def f(x):
1310
- state.value += x
1311
- return state.value
1312
-
1313
- sf = brainstate.transform.StatefulFunction(f)
1314
- x = jnp.array(0.5)
1315
- sf.make_jaxpr(x)
1316
-
1317
- cache_key = sf.get_arg_cache_key(x)
1318
- jaxpr = sf.get_jaxpr_by_cache(cache_key)
1319
-
1320
- # Should compile successfully
1321
- self.assertIsNotNone(jaxpr)
1322
-
1323
- def test_empty_function(self):
1324
- """Test function with no operations."""
1325
-
1326
- def f(x):
1327
- return x
1328
-
1329
- sf = brainstate.transform.StatefulFunction(f)
1330
- x = jnp.array([1.0, 2.0])
1331
- sf.make_jaxpr(x)
1332
-
1333
- cache_key = sf.get_arg_cache_key(x)
1334
- jaxpr = sf.get_jaxpr_by_cache(cache_key)
1335
-
1336
- # Should compile successfully
1337
- self.assertIsNotNone(jaxpr)
1338
-
1339
- def test_complex_dtype(self):
1340
- """Test with complex dtype arrays."""
1341
- state = brainstate.State(jnp.array([1.0 + 2.0j, 3.0 + 4.0j]))
1342
-
1343
- def f(x):
1344
- state.value += x
1345
- return state.value
1346
-
1347
- sf = brainstate.transform.StatefulFunction(f)
1348
- x = jnp.array([0.5 + 0.5j, 0.5 + 0.5j])
1349
- sf.make_jaxpr(x)
1350
-
1351
- cache_key = sf.get_arg_cache_key(x)
1352
- jaxpr = sf.get_jaxpr_by_cache(cache_key)
1353
-
1354
- # Should compile successfully
1355
- self.assertIsNotNone(jaxpr)
1356
-
1357
-
1358
- class TestStatefulFunctionCacheKey(unittest.TestCase):
1359
- """Test cache key generation and behavior."""
1360
-
1361
- def test_cache_key_different_shapes(self):
1362
- """Test that different input shapes produce different cache keys."""
1363
-
1364
- def f(x):
1365
- return x * 2
1366
-
1367
- sf = brainstate.transform.StatefulFunction(f)
1368
-
1369
- x1 = jnp.array([1.0, 2.0])
1370
- x2 = jnp.array([1.0, 2.0, 3.0])
1371
-
1372
- cache_key1 = sf.get_arg_cache_key(x1)
1373
- cache_key2 = sf.get_arg_cache_key(x2)
1374
-
1375
- # Should have different cache keys
1376
- self.assertNotEqual(cache_key1, cache_key2)
1377
-
1378
- def test_cache_key_different_dtypes(self):
1379
- """Test that different dtypes produce different cache keys."""
1380
-
1381
- def f(x):
1382
- return x * 2
1383
-
1384
- sf = brainstate.transform.StatefulFunction(f)
1385
-
1386
- # Use int32 and float32 instead, which are always available in JAX
1387
- x1 = jnp.array([1.0, 2.0], dtype=jnp.float32)
1388
- x2 = jnp.array([1, 2], dtype=jnp.int32)
1389
-
1390
- cache_key1 = sf.get_arg_cache_key(x1)
1391
- cache_key2 = sf.get_arg_cache_key(x2)
1392
-
1393
- # Should have different cache keys due to different dtypes
1394
- self.assertNotEqual(cache_key1, cache_key2)
1395
-
1396
- def test_cache_key_same_abstract_values(self):
1397
- """Test that same abstract values produce same cache keys."""
1398
-
1399
- def f(x):
1400
- return x * 2
1401
-
1402
- sf = brainstate.transform.StatefulFunction(f)
1403
-
1404
- x1 = jnp.array([1.0, 2.0])
1405
- x2 = jnp.array([3.0, 4.0]) # Different values, same shape/dtype
1406
-
1407
- cache_key1 = sf.get_arg_cache_key(x1)
1408
- cache_key2 = sf.get_arg_cache_key(x2)
1409
-
1410
- # Should have same cache keys (abstract values are the same)
1411
- self.assertEqual(cache_key1, cache_key2)
1412
-
1413
- def test_cache_key_with_pytree_inputs(self):
1414
- """Test cache key generation with pytree inputs."""
1415
-
1416
- def f(inputs):
1417
- x, y = inputs
1418
- return x + y
1419
-
1420
- sf = brainstate.transform.StatefulFunction(f)
1421
-
1422
- inputs1 = (jnp.array([1.0]), jnp.array([2.0]))
1423
- inputs2 = (jnp.array([3.0]), jnp.array([4.0]))
1424
-
1425
- cache_key1 = sf.get_arg_cache_key(inputs1)
1426
- cache_key2 = sf.get_arg_cache_key(inputs2)
1427
-
1428
- # Should have same cache keys (same structure/shapes)
1429
- self.assertEqual(cache_key1, cache_key2)
1430
-
1431
-
1432
- class TestStatefulFunctionRecompilation(unittest.TestCase):
1433
- """Test recompilation scenarios."""
1434
-
1435
- def test_cache_reuse(self):
1436
- """Test that cache is reused for same inputs."""
1437
- state = brainstate.State(jnp.array([1.0]))
1438
-
1439
- def f(x):
1440
- state.value += x
1441
- return state.value
1442
-
1443
- sf = brainstate.transform.StatefulFunction(f)
1444
-
1445
- x = jnp.array([1.0])
1446
-
1447
- # First compilation
1448
- sf.make_jaxpr(x)
1449
- stats1 = sf.get_cache_stats()
1450
-
1451
- # Second call with same shape should reuse cache
1452
- sf.make_jaxpr(x)
1453
- stats2 = sf.get_cache_stats()
1454
-
1455
- # Cache size should remain the same
1456
- self.assertEqual(
1457
- stats1['jaxpr_cache']['size'],
1458
- stats2['jaxpr_cache']['size']
1459
- )
1460
-
1461
- def test_multiple_compilations_different_shapes(self):
1462
- """Test multiple compilations with different shapes."""
1463
- state = brainstate.State(jnp.array([1.0]))
1464
-
1465
- def f(x):
1466
- return x * 2
1467
-
1468
- sf = brainstate.transform.StatefulFunction(f)
1469
-
1470
- # Compile for different shapes
1471
- shapes = [
1472
- jnp.array([1.0]),
1473
- jnp.array([1.0, 2.0]),
1474
- jnp.array([1.0, 2.0, 3.0]),
1475
- ]
1476
-
1477
- for x in shapes:
1478
- sf.make_jaxpr(x)
1479
-
1480
- stats = sf.get_cache_stats()
1481
-
1482
- # Should have 3 different cache entries
1483
- self.assertEqual(stats['jaxpr_cache']['size'], 3)
1484
-
1485
- def test_clear_and_recompile(self):
1486
- """Test clearing cache and recompiling."""
1487
- state = brainstate.State(jnp.array([1.0]))
1488
-
1489
- def f(x):
1490
- state.value += x
1491
- return state.value
1492
-
1493
- sf = brainstate.transform.StatefulFunction(f)
1494
- x = jnp.array([1.0])
1495
-
1496
- # Compile
1497
- sf.make_jaxpr(x)
1498
- stats_before = sf.get_cache_stats()
1499
- self.assertGreater(stats_before['jaxpr_cache']['size'], 0)
1500
-
1501
- # Clear cache
1502
- sf.clear_cache()
1503
- stats_after_clear = sf.get_cache_stats()
1504
- self.assertEqual(stats_after_clear['jaxpr_cache']['size'], 0)
1505
-
1506
- # Recompile
1507
- sf.make_jaxpr(x)
1508
- stats_after_recompile = sf.get_cache_stats()
1509
- self.assertGreater(stats_after_recompile['jaxpr_cache']['size'], 0)
1510
-
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ import threading
18
+ import unittest
19
+ import warnings
20
+
21
+ import jax
22
+ import jax.numpy as jnp
23
+ import jax.random as jr
24
+ import pytest
25
+
26
+ import brainstate
27
+ from brainstate._compatible_import import jaxpr_as_fun
28
+ from brainstate._error import BatchAxisError
29
+ from brainstate.transform._make_jaxpr import _BoundedCache, make_hashable
30
+ from brainstate.util import filter as state_filter
31
+
32
+
33
+ class TestMakeJaxpr(unittest.TestCase):
34
+ def test_compar_jax_make_jaxpr(self):
35
+ def func4(arg): # Arg is a pair
36
+ temp = arg[0] + jnp.sin(arg[1]) * 3.
37
+ c = brainstate.random.rand_like(arg[0])
38
+ return jnp.sum(temp + c)
39
+
40
+ key = brainstate.random.DEFAULT.value
41
+ jaxpr = jax.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
42
+ print(jaxpr)
43
+ self.assertTrue(len(jaxpr.in_avals) == 2)
44
+ self.assertTrue(len(jaxpr.consts) == 1)
45
+ self.assertTrue(len(jaxpr.out_avals) == 1)
46
+ self.assertTrue(jnp.allclose(jaxpr.consts[0], key))
47
+
48
+ brainstate.random.seed(1)
49
+ print(brainstate.random.DEFAULT.value)
50
+
51
+ jaxpr2, states = brainstate.transform.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
52
+ print(jaxpr2)
53
+ self.assertTrue(len(jaxpr2.in_avals) == 3)
54
+ self.assertTrue(len(jaxpr2.out_avals) == 2)
55
+ self.assertTrue(len(jaxpr2.consts) == 0)
56
+ print(brainstate.random.DEFAULT.value)
57
+
58
+ def test_StatefulFunction_1(self):
59
+ def func4(arg): # Arg is a pair
60
+ temp = arg[0] + jnp.sin(arg[1]) * 3.
61
+ c = brainstate.random.rand_like(arg[0])
62
+ return jnp.sum(temp + c)
63
+
64
+ fun = brainstate.transform.StatefulFunction(func4).make_jaxpr((jnp.zeros(8), jnp.ones(8)))
65
+ cache_key = fun.get_arg_cache_key((jnp.zeros(8), jnp.ones(8)))
66
+ print(fun.get_states_by_cache(cache_key))
67
+ print(fun.get_jaxpr_by_cache(cache_key))
68
+
69
+ def test_StatefulFunction_2(self):
70
+ st1 = brainstate.State(jnp.ones(10))
71
+
72
+ def f1(x):
73
+ st1.value = x + st1.value
74
+
75
+ def f2(x):
76
+ jaxpr = brainstate.transform.make_jaxpr(f1)(x)
77
+ c = 1. + x
78
+ return c
79
+
80
+ def f3(x):
81
+ jaxpr = brainstate.transform.make_jaxpr(f1)(x)
82
+ c = 1.
83
+ return c
84
+
85
+ print()
86
+ jaxpr = brainstate.transform.make_jaxpr(f1)(jnp.zeros(1))
87
+ print(jaxpr)
88
+ jaxpr = jax.make_jaxpr(f2)(jnp.zeros(1))
89
+ print(jaxpr)
90
+ jaxpr = jax.make_jaxpr(f3)(jnp.zeros(1))
91
+ print(jaxpr)
92
+ jaxpr, _ = brainstate.transform.make_jaxpr(f3)(jnp.zeros(1))
93
+ print(jaxpr)
94
+ self.assertTrue(jnp.allclose(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
95
+ f3(jnp.zeros(1))))
96
+
97
+ def test_compare_jax_make_jaxpr2(self):
98
+ st1 = brainstate.State(jnp.ones(10))
99
+
100
+ def fa(x):
101
+ st1.value = x + st1.value
102
+
103
+ def ffa(x):
104
+ jaxpr, states = brainstate.transform.make_jaxpr(fa)(x)
105
+ c = 1. + x
106
+ return c
107
+
108
+ jaxpr, states = brainstate.transform.make_jaxpr(ffa)(jnp.zeros(1))
109
+ print()
110
+ print(jaxpr)
111
+ print(states)
112
+ print(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value))
113
+ jaxpr = jax.make_jaxpr(ffa)(jnp.zeros(1))
114
+ print(jaxpr)
115
+ print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
116
+
117
+ def test_compare_jax_make_jaxpr3(self):
118
+ def fa(x):
119
+ return 1.
120
+
121
+ jaxpr, states = brainstate.transform.make_jaxpr(fa)(jnp.zeros(1))
122
+ print()
123
+ print(jaxpr)
124
+ print(states)
125
+ # print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
126
+ jaxpr = jax.make_jaxpr(fa)(jnp.zeros(1))
127
+ print(jaxpr)
128
+ # print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
129
+
130
+ def test_static_argnames(self):
131
+ def func4(a, b): # Arg is a pair
132
+ temp = a + jnp.sin(b) * 3.
133
+ c = brainstate.random.rand_like(a)
134
+ return jnp.sum(temp + c)
135
+
136
+ jaxpr, states = brainstate.transform.make_jaxpr(func4, static_argnames='b')(jnp.zeros(8), 1.)
137
+ print()
138
+ print(jaxpr)
139
+ print(states)
140
+
141
+ def test_state_in(self):
142
+ def f(a):
143
+ return a.value
144
+
145
+ with pytest.raises(ValueError):
146
+ brainstate.transform.StatefulFunction(f).make_jaxpr(brainstate.State(1.))
147
+
148
+ def test_state_out(self):
149
+ def f(a):
150
+ return brainstate.State(a)
151
+
152
+ with pytest.raises(ValueError):
153
+ brainstate.transform.StatefulFunction(f).make_jaxpr(1.)
154
+
155
+ def test_return_states(self):
156
+ a = brainstate.State(jnp.ones(3))
157
+
158
+ @brainstate.transform.jit
159
+ def f():
160
+ return a
161
+
162
+ with pytest.raises(ValueError):
163
+ f()
164
+
165
+
166
+ class TestBoundedCache(unittest.TestCase):
167
+ """Test the _BoundedCache class."""
168
+
169
+ def test_cache_basic_operations(self):
170
+ """Test basic get and set operations."""
171
+ cache = _BoundedCache(maxsize=3)
172
+
173
+ # Test set and get
174
+ cache.set('key1', 'value1')
175
+ self.assertEqual(cache.get('key1'), 'value1')
176
+
177
+ # Test default value
178
+ self.assertIsNone(cache.get('nonexistent'))
179
+ self.assertEqual(cache.get('nonexistent', 'default'), 'default')
180
+
181
+ # Test __contains__
182
+ self.assertIn('key1', cache)
183
+ self.assertNotIn('key2', cache)
184
+
185
+ # Test __len__
186
+ self.assertEqual(len(cache), 1)
187
+
188
+ def test_cache_lru_eviction(self):
189
+ """Test LRU eviction when cache is full."""
190
+ cache = _BoundedCache(maxsize=3)
191
+
192
+ # Fill cache
193
+ cache.set('key1', 'value1')
194
+ cache.set('key2', 'value2')
195
+ cache.set('key3', 'value3')
196
+ self.assertEqual(len(cache), 3)
197
+
198
+ # Add one more, should evict key1 (least recently used)
199
+ cache.set('key4', 'value4')
200
+ self.assertEqual(len(cache), 3)
201
+ self.assertNotIn('key1', cache)
202
+ self.assertIn('key4', cache)
203
+
204
+ # Access key2 to make it recently used
205
+ cache.get('key2')
206
+
207
+ # Add another key, should evict key3 (now least recently used)
208
+ cache.set('key5', 'value5')
209
+ self.assertNotIn('key3', cache)
210
+ self.assertIn('key2', cache)
211
+
212
+ def test_cache_update_existing(self):
213
+ """Test updating an existing key."""
214
+ cache = _BoundedCache(maxsize=2)
215
+
216
+ cache.set('key1', 'value1')
217
+ cache.set('key2', 'value2')
218
+
219
+ # Update key1 (should move it to end)
220
+ cache.replace('key1', 'updated_value1')
221
+ self.assertEqual(cache.get('key1'), 'updated_value1')
222
+
223
+ # Add new key, should evict key2 (now LRU)
224
+ cache.set('key3', 'value3')
225
+ self.assertNotIn('key2', cache)
226
+ self.assertIn('key1', cache)
227
+
228
+ def test_cache_statistics(self):
229
+ """Test cache statistics tracking."""
230
+ cache = _BoundedCache(maxsize=5)
231
+
232
+ # Initial stats
233
+ stats = cache.get_stats()
234
+ self.assertEqual(stats['size'], 0)
235
+ self.assertEqual(stats['maxsize'], 5)
236
+ self.assertEqual(stats['hits'], 0)
237
+ self.assertEqual(stats['misses'], 0)
238
+ self.assertEqual(stats['hit_rate'], 0.0)
239
+
240
+ # Add items and test hits/misses
241
+ cache.set('key1', 'value1')
242
+ cache.set('key2', 'value2')
243
+
244
+ # Generate hits
245
+ cache.get('key1') # hit
246
+ cache.get('key1') # hit
247
+ cache.get('key3') # miss
248
+ cache.get('key2') # hit
249
+
250
+ stats = cache.get_stats()
251
+ self.assertEqual(stats['size'], 2)
252
+ self.assertEqual(stats['hits'], 3)
253
+ self.assertEqual(stats['misses'], 1)
254
+ self.assertEqual(stats['hit_rate'], 75.0)
255
+
256
+ def test_cache_clear(self):
257
+ """Test clearing the cache."""
258
+ cache = _BoundedCache(maxsize=5)
259
+
260
+ # Add items
261
+ cache.set('key1', 'value1')
262
+ cache.set('key2', 'value2')
263
+ cache.get('key1') # Generate a hit
264
+
265
+ # Clear cache
266
+ cache.clear()
267
+
268
+ self.assertEqual(len(cache), 0)
269
+ self.assertNotIn('key1', cache)
270
+
271
+ # Check stats are reset
272
+ stats = cache.get_stats()
273
+ self.assertEqual(stats['hits'], 0)
274
+ self.assertEqual(stats['misses'], 0)
275
+
276
+ def test_cache_keys(self):
277
+ """Test getting all cache keys."""
278
+ cache = _BoundedCache(maxsize=5)
279
+
280
+ cache.set('key1', 'value1')
281
+ cache.set('key2', 'value2')
282
+ cache.set('key3', 'value3')
283
+
284
+ keys = cache.keys()
285
+ self.assertEqual(set(keys), {'key1', 'key2', 'key3'})
286
+
287
+ def test_cache_set_duplicate_raises(self):
288
+ """Test that setting an existing key raises ValueError."""
289
+ cache = _BoundedCache(maxsize=5)
290
+
291
+ cache.set('key1', 'value1')
292
+
293
+ # Attempting to set the same key should raise ValueError
294
+ with pytest.raises(ValueError, match="Cache key already exists"):
295
+ cache.set('key1', 'value2')
296
+
297
+ def test_cache_pop(self):
298
+ """Test pop method."""
299
+ cache = _BoundedCache(maxsize=5)
300
+
301
+ cache.set('key1', 'value1')
302
+ cache.set('key2', 'value2')
303
+
304
+ # Pop existing key
305
+ value = cache.pop('key1')
306
+ self.assertEqual(value, 'value1')
307
+ self.assertNotIn('key1', cache)
308
+ self.assertEqual(len(cache), 1)
309
+
310
+ # Pop non-existent key with default
311
+ value = cache.pop('nonexistent', 'default')
312
+ self.assertEqual(value, 'default')
313
+
314
+ # Pop non-existent key without default
315
+ value = cache.pop('nonexistent')
316
+ self.assertIsNone(value)
317
+
318
+ def test_cache_replace(self):
319
+ """Test replace method."""
320
+ cache = _BoundedCache(maxsize=5)
321
+
322
+ cache.set('key1', 'value1')
323
+ cache.set('key2', 'value2')
324
+
325
+ # Replace existing key
326
+ cache.replace('key1', 'new_value1')
327
+ self.assertEqual(cache.get('key1'), 'new_value1')
328
+
329
+ # Replacing should move to end (most recently used)
330
+ cache.set('key3', 'value3')
331
+ cache.replace('key2', 'new_value2')
332
+
333
+ # Add more items to test LRU behavior
334
+ cache.set('key4', 'value4')
335
+ cache.set('key5', 'value5')
336
+
337
+ # Now when we add key6, key1 should be evicted (oldest after replace moved key2 to end)
338
+ cache.set('key6', 'value6')
339
+
340
+ # key2 should still be there because replace moved it to end
341
+ self.assertIn('key2', cache)
342
+
343
+ def test_cache_replace_nonexistent_raises(self):
344
+ """Test that replacing a non-existent key raises KeyError."""
345
+ cache = _BoundedCache(maxsize=5)
346
+
347
+ with pytest.raises(KeyError, match="Cache key does not exist"):
348
+ cache.replace('nonexistent', 'value')
349
+
350
+ def test_cache_get_with_raise_on_miss(self):
351
+ """Test get method with raise_on_miss parameter."""
352
+ cache = _BoundedCache(maxsize=5)
353
+
354
+ cache.set('key1', 'value1')
355
+ cache.set('key2', 'value2')
356
+
357
+ # Should work normally for existing key
358
+ value = cache.get('key1', raise_on_miss=True)
359
+ self.assertEqual(value, 'value1')
360
+
361
+ # Should raise ValueError for missing key with raise_on_miss=True
362
+ with pytest.raises(ValueError, match="not compiled for the requested cache key"):
363
+ cache.get('nonexistent', raise_on_miss=True, error_context="Test item")
364
+
365
+ def test_cache_detailed_error_message(self):
366
+ """Test that error message shows available keys."""
367
+ cache = _BoundedCache(maxsize=5)
368
+
369
+ cache.set('key1', 'value1')
370
+ cache.set('key2', 'value2')
371
+
372
+ # Error should include all available keys
373
+ with pytest.raises(ValueError) as exc_info:
374
+ cache.get('nonexistent', raise_on_miss=True, error_context="Test item")
375
+
376
+ error_msg = str(exc_info.value)
377
+ # Should show requested key
378
+ self.assertIn('nonexistent', error_msg)
379
+ # Should show available keys
380
+ self.assertIn('key1', error_msg)
381
+ self.assertIn('key2', error_msg)
382
+ # Should have helpful message
383
+ self.assertIn('make_jaxpr()', error_msg)
384
+
385
+ def test_cache_error_message_no_keys(self):
386
+ """Test error message when cache is empty."""
387
+ cache = _BoundedCache(maxsize=5)
388
+
389
+ with pytest.raises(ValueError) as exc_info:
390
+ cache.get('key', raise_on_miss=True, error_context="Empty cache")
391
+
392
+ error_msg = str(exc_info.value)
393
+ # Should indicate no keys available
394
+ self.assertIn('none', error_msg.lower())
395
+
396
+ def test_cache_thread_safety(self):
397
+ """Test thread safety of cache operations."""
398
+ cache = _BoundedCache(maxsize=100)
399
+ errors = []
400
+
401
+ def worker(thread_id):
402
+ try:
403
+ for i in range(50):
404
+ key = f'key_{thread_id}_{i}'
405
+ cache.set(key, f'value_{thread_id}_{i}')
406
+ value = cache.get(key)
407
+ if value != f'value_{thread_id}_{i}':
408
+ errors.append(f'Mismatch in thread {thread_id}')
409
+ except Exception as e:
410
+ errors.append(f'Error in thread {thread_id}: {e}')
411
+
412
+ # Create multiple threads
413
+ threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
414
+
415
+ # Start all threads
416
+ for t in threads:
417
+ t.start()
418
+
419
+ # Wait for all threads to complete
420
+ for t in threads:
421
+ t.join()
422
+
423
+ # Check no errors occurred
424
+ self.assertEqual(len(errors), 0, f"Thread safety errors: {errors}")
425
+
426
+
427
class TestStatefulFunctionEnhancements(unittest.TestCase):
    """Tests for StatefulFunction's cache-introspection and validation helpers."""

    # Names of the internal caches that get_cache_stats() is expected to report.
    _CACHE_NAMES = ('jaxpr_cache', 'out_shapes_cache', 'jaxpr_out_tree_cache', 'state_trace_cache')

    def test_cache_stats(self):
        """get_cache_stats() reports every cache with the full set of statistics fields."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value * 2

        sf = brainstate.transform.StatefulFunction(f)

        # Trace twice. NOTE: x1 and x2 share shape and dtype, so the second
        # call is presumably served from the cache instead of adding a new
        # entry (assumes cache keys derive from abstract shapes -- confirm).
        x1 = jnp.array([0.5, 0.5])
        x2 = jnp.array([1.0, 1.0])
        sf.make_jaxpr(x1)
        sf.make_jaxpr(x2)

        stats = sf.get_cache_stats()

        # Verify all cache types are present
        for cache_name in self._CACHE_NAMES:
            self.assertIn(cache_name, stats)

        # Verify each cache exposes the full set of statistics fields; the
        # subTest label identifies which cache is incomplete on failure.
        for cache_name, cache_stats in stats.items():
            with self.subTest(cache=cache_name):
                for field in ('size', 'maxsize', 'hits', 'misses', 'hit_rate'):
                    self.assertIn(field, cache_stats)

    def test_validate_states(self):
        """validate_states() accepts the states recorded for a compiled key."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])
        sf.make_jaxpr(x)

        cache_key = sf.get_arg_cache_key(x)

        # Should validate successfully
        self.assertTrue(sf.validate_states(cache_key))

    def test_validate_all_states(self):
        """validate_all_states() returns one result per compiled cache key."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x, n):
            state.value += x
            return state.value * n

        # Static argument 1 gives a distinct cache key per value of n.
        sf = brainstate.transform.StatefulFunction(f, static_argnums=(1,))
        x = jnp.array([0.5, 0.5])

        sf.make_jaxpr(x, 1)
        sf.make_jaxpr(x, 2)

        results = sf.validate_all_states()

        # One result for each of the two cache keys, and all of them valid.
        self.assertEqual(len(results), 2)
        for result in results.values():
            self.assertTrue(result)

    def test_clear_cache(self):
        """clear_cache() empties every internal cache."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])
        sf.make_jaxpr(x)

        # Verify cache has entries
        stats = sf.get_cache_stats()
        self.assertGreater(stats['jaxpr_cache']['size'], 0)

        sf.clear_cache()

        # Verify all caches are empty
        stats = sf.get_cache_stats()
        for cache_name in self._CACHE_NAMES:
            self.assertEqual(stats[cache_name]['size'], 0)

    def test_return_only_write_parameter(self):
        """return_only_write=True never tracks more states than the default."""
        read_state = brainstate.State(jnp.array([1.0, 2.0]))
        write_state = brainstate.State(jnp.array([3.0, 4.0]))

        def f(x):
            # Read from read_state, write to write_state
            _ = read_state.value + x
            write_state.value += x
            return write_state.value

        # Default behaviour: return_only_write=False
        sf_all = brainstate.transform.StatefulFunction(f, return_only_write=False)
        sf_all.make_jaxpr(jnp.array([0.5, 0.5]))
        cache_key = sf_all.get_arg_cache_key(jnp.array([0.5, 0.5]))
        states_all = sf_all.get_states_by_cache(cache_key)

        # Restricted behaviour: return_only_write=True
        sf_write_only = brainstate.transform.StatefulFunction(f, return_only_write=True)
        sf_write_only.make_jaxpr(jnp.array([0.5, 0.5]))
        cache_key_write = sf_write_only.get_arg_cache_key(jnp.array([0.5, 0.5]))
        states_write = sf_write_only.get_states_by_cache(cache_key_write)

        # With return_only_write=True, should have fewer or equal states
        self.assertLessEqual(len(states_write), len(states_all))
560
+
561
+
562
class TestErrorHandling(unittest.TestCase):
    """Test error handling in StatefulFunction."""

    @staticmethod
    def _uncompiled_cache_key():
        """Build a well-formed cache key that no ``make_jaxpr`` call here produces.

        Every function traced in these tests takes one dynamic positional
        argument, so a key with empty ``dyn_args`` presumably never matches a
        compiled entry.
        """
        from brainstate.transform._make_jaxpr import hashabledict
        return hashabledict(
            static_args=(),
            dyn_args=(),
            static_kwargs=(),
            dyn_kwargs=(),
        )

    def _miss_error_message(self, lookup):
        """Call ``lookup`` with an uncompiled key, require ValueError, return its text."""
        with pytest.raises(ValueError) as exc_info:
            lookup(self._uncompiled_cache_key())
        return str(exc_info.value)

    def test_jaxpr_call_state_mismatch(self):
        """jaxpr_call() rejects a state list whose length differs from the trace."""
        state1 = brainstate.State(jnp.array([1.0, 2.0]))
        state2 = brainstate.State(jnp.array([3.0, 4.0]))

        def f(x):
            state1.value += x
            state2.value += x
            return state1.value + state2.value

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])
        sf.make_jaxpr(x)

        # The function touches two states; pass only one.
        with pytest.raises(ValueError, match="State length mismatch"):
            sf.jaxpr_call([jnp.array([1.0, 1.0])], x)

    def test_get_jaxpr_not_compiled_detailed_error(self):
        """get_jaxpr_by_cache() on an unknown key produces a detailed error."""

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)

        # Compile for one input shape so the cache holds a real entry.
        sf.make_jaxpr(jnp.array([1.0, 2.0]))

        error_msg = self._miss_error_message(sf.get_jaxpr_by_cache)
        # The requested key, the available keys and the remedy are all mentioned.
        self.assertIn('Requested key:', error_msg)
        self.assertIn('Available', error_msg)
        self.assertIn('make_jaxpr()', error_msg)

    def test_get_out_shapes_not_compiled_detailed_error(self):
        """get_out_shapes_by_cache() on an unknown key names the 'Output shapes' context."""

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)

        error_msg = self._miss_error_message(sf.get_out_shapes_by_cache)
        self.assertIn('Output shapes', error_msg)
        self.assertIn('Requested key:', error_msg)

    def test_get_out_treedef_not_compiled_detailed_error(self):
        """get_out_treedef_by_cache() on an unknown key names the 'Output tree' context."""

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)

        error_msg = self._miss_error_message(sf.get_out_treedef_by_cache)
        self.assertIn('Output tree', error_msg)
        self.assertIn('Requested key:', error_msg)

    def test_get_state_trace_not_compiled_detailed_error(self):
        """get_state_trace_by_cache() on an unknown key names the 'State trace' context."""

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)

        error_msg = self._miss_error_message(sf.get_state_trace_by_cache)
        self.assertIn('State trace', error_msg)
        self.assertIn('Requested key:', error_msg)
687
+
688
+
689
class TestCompileIfMiss(unittest.TestCase):
    """Behaviour of the ``compile_if_miss`` flag on the argument-based getters."""

    def test_get_jaxpr_by_call_with_compile_if_miss_true(self):
        """With compile_if_miss=True, get_jaxpr() traces on demand."""

        def double(x):
            return x * 2

        fn = brainstate.transform.StatefulFunction(double)
        self.assertIsNotNone(fn.get_jaxpr(jnp.array([1.0, 2.0]), compile_if_miss=True))

    def test_get_jaxpr_by_call_with_compile_if_miss_false(self):
        """With compile_if_miss=False, an untraced input is reported as an error."""

        def double(x):
            return x * 2

        fn = brainstate.transform.StatefulFunction(double)
        with pytest.raises(ValueError, match="not compiled"):
            fn.get_jaxpr(jnp.array([1.0, 2.0]), compile_if_miss=False)

    def test_get_out_shapes_by_call_compile_if_miss(self):
        """get_out_shapes() traces on demand, but only when allowed to."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x):
            acc.value += x
            return acc.value * 2

        fn = brainstate.transform.StatefulFunction(step)

        # Allowed to compile: shapes come back.
        self.assertIsNotNone(fn.get_out_shapes(jnp.array([1.0, 2.0]), compile_if_miss=True))

        # A differently shaped input was never traced and may not be compiled now.
        with pytest.raises(ValueError):
            fn.get_out_shapes(jnp.array([1.0, 2.0, 3.0]), compile_if_miss=False)

    def test_get_out_treedef_by_call_compile_if_miss(self):
        """get_out_treedef() compiles on demand under the default flag."""

        def pair(x):
            return x * 2, x + 1

        fn = brainstate.transform.StatefulFunction(pair)
        self.assertIsNotNone(fn.get_out_treedef(jnp.array([1.0, 2.0])))

    def test_get_state_trace_by_call_compile_if_miss(self):
        """get_state_trace() compiles on demand when compile_if_miss=True."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x):
            acc.value += x
            return acc.value

        fn = brainstate.transform.StatefulFunction(step)
        self.assertIsNotNone(fn.get_state_trace(jnp.array([1.0, 2.0]), compile_if_miss=True))

    def test_get_states_by_call_compile_if_miss(self):
        """get_states() compiles on demand and reports both touched states."""
        first = brainstate.State(jnp.array([1.0, 2.0]))
        second = brainstate.State(jnp.array([3.0, 4.0]))

        def step(x):
            first.value += x
            second.value += x
            return first.value + second.value

        fn = brainstate.transform.StatefulFunction(step)
        tracked = fn.get_states(jnp.array([1.0, 2.0]), compile_if_miss=True)
        self.assertEqual(len(tracked), 2)

    def test_get_read_states_by_call_compile_if_miss(self):
        """get_read_states() compiles on demand when compile_if_miss=True."""
        reader = brainstate.State(jnp.array([1.0, 2.0]))
        writer = brainstate.State(jnp.array([3.0, 4.0]))

        def step(x):
            _ = reader.value
            writer.value += x
            return writer.value

        fn = brainstate.transform.StatefulFunction(step)
        self.assertIsNotNone(fn.get_read_states(jnp.array([1.0, 2.0]), compile_if_miss=True))

    def test_get_write_states_by_call_compile_if_miss(self):
        """get_write_states() compiles on demand when compile_if_miss=True."""
        reader = brainstate.State(jnp.array([1.0, 2.0]))
        writer = brainstate.State(jnp.array([3.0, 4.0]))

        def step(x):
            _ = reader.value
            writer.value += x
            return writer.value

        fn = brainstate.transform.StatefulFunction(step)
        self.assertIsNotNone(fn.get_write_states(jnp.array([1.0, 2.0]), compile_if_miss=True))

    def test_compile_if_miss_default_behavior(self):
        """Without an explicit flag, each getter compiles on a fresh instance."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x):
            acc.value += x
            return acc.value

        # A brand-new StatefulFunction per getter guarantees an empty cache.
        for getter_name in ('get_jaxpr', 'get_out_shapes', 'get_states'):
            fresh = brainstate.transform.StatefulFunction(step)
            self.assertIsNotNone(getattr(fresh, getter_name)(jnp.array([1.0, 2.0])))
831
+
832
+
833
class TestMakeHashable(unittest.TestCase):
    """Test the make_hashable utility function."""

    def test_hashable_list(self):
        """A list is converted to a hashable tuple."""
        result = make_hashable([1, 2, 3])
        self.assertIsInstance(result, tuple)
        hash(result)  # must not raise

    def test_hashable_dict(self):
        """A dict becomes a tuple of (key, value) pairs sorted by key."""
        result = make_hashable({'b': 2, 'a': 1})
        self.assertIsInstance(result, tuple)
        hash(result)  # must not raise
        # Sorting makes the result independent of dict insertion order.
        self.assertEqual([item[0] for item in result], ['a', 'b'])

    def test_hashable_set(self):
        """A set is converted to a frozenset."""
        result = make_hashable({1, 2, 3})
        self.assertIsInstance(result, frozenset)
        hash(result)  # must not raise

    def test_hashable_nested(self):
        """Lists, dicts and sets nested inside a dict are all made hashable."""
        nested = {
            'list': [1, 2, 3],
            'dict': {'a': 1, 'b': 2},
            'set': {4, 5},
        }
        result = make_hashable(nested)
        hash(result)  # must not raise

    def test_hashable_tuple(self):
        """A tuple stays a tuple."""
        result = make_hashable((1, 2, 3))
        self.assertIsInstance(result, tuple)
        hash(result)  # must not raise

    def test_hashable_idempotent(self):
        """make_hashable is deterministic for a given input.

        Despite the method name, this checks repeatability: converting the
        *same* input twice gives equal results, and hence equal hashes, as
        cache keys require. It does not apply make_hashable to its own output.
        """
        original = {'a': [1, 2], 'b': {3, 4}}
        result1 = make_hashable(original)
        result2 = make_hashable(original)
        self.assertEqual(result1, result2)
        self.assertEqual(hash(result1), hash(result2))
889
+
890
+
891
class TestCacheCleanupOnError(unittest.TestCase):
    """Test that cache is properly cleaned up when compilation fails."""

    def test_cache_cleanup_on_compilation_error(self):
        """A failed make_jaxpr() must not leave a compiled jaxpr behind."""

        def f(x):
            # Python control flow on a traced value cannot be converted to a
            # concrete bool during abstract tracing, so tracing this fails.
            if x > 0:
                return x * 2
            else:
                return x + 1

        sf = brainstate.transform.StatefulFunction(f)

        # The failure must actually surface. The previous `except: pass`
        # let this test pass even if nothing was raised.
        with self.assertRaises(Exception):
            sf.make_jaxpr(jnp.array([1.0]))

        # Tracing never completed, so no jaxpr can have been cached. The other
        # caches are not asserted on: which of them are touched before the
        # failure point is an implementation detail.
        stats = sf.get_cache_stats()
        self.assertEqual(stats['jaxpr_cache']['size'], 0)
918
+
919
+
920
class TestMakeJaxprReturnOnlyWrite(unittest.TestCase):
    """The functional make_jaxpr() wrapper honours ``return_only_write``."""

    def test_make_jaxpr_return_only_write(self):
        """Compiling with return_only_write=True yields a jaxpr and a tuple of states."""
        read_state = brainstate.State(jnp.array([1.0]))
        write_state = brainstate.State(jnp.array([2.0]))

        def f(x):
            _ = read_state.value  # read, never written
            write_state.value += x  # written
            return x * 2

        compiled = brainstate.transform.make_jaxpr(f, return_only_write=True)
        jaxpr, states = compiled(jnp.array([1.0]))

        self.assertIsNotNone(jaxpr)
        self.assertIsInstance(states, tuple)
940
+
941
+
942
class TestStatefulFunctionCallable(unittest.TestCase):
    """Test __call__ method of StatefulFunction."""

    def test_stateful_function_call(self):
        """Calling a pre-compiled StatefulFunction returns the traced output."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value * 2

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])
        sf.make_jaxpr(x)

        result = sf(x)
        self.assertEqual(result.shape, (2,))

    def test_stateful_function_call_auto_compile(self):
        """__call__ compiles on demand when no trace exists yet."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value * 2

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])

        # Call without pre-compilation should work
        result = sf(x)
        self.assertEqual(result.shape, (2,))

    def test_stateful_function_multiple_calls(self):
        """Successive calls see and update the state left by the previous call."""
        state = brainstate.State(jnp.array([0.0]))

        def f(x):
            state.value += x
            return state.value

        sf = brainstate.transform.StatefulFunction(f)

        # Running totals: 0 + 1 = 1, 1 + 2 = 3, 3 + 3 = 6. The previous check
        # (`is not None`) could not detect a failure to accumulate.
        result1 = sf(jnp.array([1.0]))
        result2 = sf(jnp.array([2.0]))
        result3 = sf(jnp.array([3.0]))

        self.assertTrue(jnp.allclose(result1, jnp.array([1.0])))
        self.assertTrue(jnp.allclose(result2, jnp.array([3.0])))
        self.assertTrue(jnp.allclose(result3, jnp.array([6.0])))
        # The written-back state must hold the final total.
        self.assertTrue(jnp.allclose(state.value, jnp.array([6.0])))
995
+
996
+
997
class TestStatefulFunctionStaticArgs(unittest.TestCase):
    """Static arguments participate in the cache key of a StatefulFunction."""

    def test_static_argnums_basic(self):
        """Different values of a positional static argument give different keys."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def scaled(x, multiplier):
            acc.value += x
            return acc.value * multiplier

        fn = brainstate.transform.StatefulFunction(scaled, static_argnums=(1,))
        inp = jnp.array([0.5, 0.5])

        keys = []
        for multiplier in (2, 3):
            fn.make_jaxpr(inp, multiplier)
            keys.append(fn.get_arg_cache_key(inp, multiplier))

        first, second = keys
        self.assertNotEqual(first, second)

    def test_static_argnames_basic(self):
        """Different values of a keyword static argument give different keys."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def scaled(x, multiplier=2):
            acc.value += x
            return acc.value * multiplier

        fn = brainstate.transform.StatefulFunction(scaled, static_argnames='multiplier')
        inp = jnp.array([0.5, 0.5])

        keys = []
        for multiplier in (2, 3):
            fn.make_jaxpr(inp, multiplier=multiplier)
            keys.append(fn.get_arg_cache_key(inp, multiplier=multiplier))

        first, second = keys
        self.assertNotEqual(first, second)

    def test_static_args_combination(self):
        """Positional and keyword static arguments can be combined."""
        acc = brainstate.State(jnp.array([1.0]))

        def affine(x, multiplier, offset=0):
            acc.value += x
            return acc.value * multiplier + offset

        fn = brainstate.transform.StatefulFunction(
            affine, static_argnums=(1,), static_argnames='offset'
        )
        inp = jnp.array([0.5])

        keys = []
        for multiplier, offset in ((2, 0), (3, 1)):
            fn.make_jaxpr(inp, multiplier, offset=offset)
            keys.append(fn.get_arg_cache_key(inp, multiplier, offset=offset))

        first, second = keys
        self.assertNotEqual(first, second)
1065
+
1066
+
1067
class TestStatefulFunctionComplexStates(unittest.TestCase):
    """State tracking when a function touches several states."""

    @staticmethod
    def _traced_states(fn, x):
        """Trace ``fn`` on ``x`` and return the states recorded for that key."""
        compiled = brainstate.transform.StatefulFunction(fn)
        compiled.make_jaxpr(x)
        return compiled.get_states_by_cache(compiled.get_arg_cache_key(x))

    def test_multiple_states(self):
        """Each of three written states is tracked."""
        s1 = brainstate.State(jnp.array([1.0]))
        s2 = brainstate.State(jnp.array([2.0]))
        s3 = brainstate.State(jnp.array([3.0]))

        def update_all(x):
            s1.value += x
            s2.value += x * 2
            s3.value += x * 3
            return s1.value + s2.value + s3.value

        tracked = self._traced_states(update_all, jnp.array([1.0]))
        self.assertEqual(len(tracked), 3)

    def test_nested_state_access(self):
        """States touched inside a helper called by the traced function are tracked."""
        outer_state = brainstate.State(jnp.array([1.0]))
        inner_state = brainstate.State(jnp.array([2.0]))

        def helper(x):
            inner_state.value += x
            return inner_state.value

        def entry(x):
            outer_state.value += x
            partial = helper(x)
            return outer_state.value + partial

        tracked = self._traced_states(entry, jnp.array([1.0]))
        self.assertGreaterEqual(len(tracked), 2)

    def test_conditional_state_write(self):
        """A function with an unused static flag still tracks both written states."""
        s1 = brainstate.State(jnp.array([1.0]))
        s2 = brainstate.State(jnp.array([2.0]))

        def update(x, write_state1=True):
            # The flag is not consulted: real branching on it would need
            # JAX-aware control flow. Both states are written unconditionally.
            s1.value += x
            s2.value += x * 2
            return s1.value + s2.value

        compiled = brainstate.transform.StatefulFunction(update, static_argnames='write_state1')
        inp = jnp.array([1.0])
        compiled.make_jaxpr(inp, write_state1=True)

        key = compiled.get_arg_cache_key(inp, write_state1=True)
        tracked = compiled.get_states_by_cache(key)
        self.assertGreaterEqual(len(tracked), 2)
1137
+
1138
+
1139
class TestStatefulFunctionOutputShapes(unittest.TestCase):
    """Output shape and tree-structure bookkeeping of StatefulFunction."""

    @staticmethod
    def _compile(fn, x):
        """Trace ``fn`` on ``x``; return the StatefulFunction and its cache key."""
        compiled = brainstate.transform.StatefulFunction(fn)
        compiled.make_jaxpr(x)
        return compiled, compiled.get_arg_cache_key(x)

    def test_single_output(self):
        """A single array output has recorded shapes."""
        acc = brainstate.State(jnp.array([1.0, 2.0, 3.0]))

        def step(x):
            acc.value += x
            return acc.value

        compiled, key = self._compile(step, jnp.array([1.0, 2.0, 3.0]))
        self.assertIsNotNone(compiled.get_out_shapes_by_cache(key))

    def test_multiple_outputs(self):
        """A tuple of outputs has recorded shapes."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x):
            acc.value += x
            return acc.value, acc.value * 2, jnp.sum(acc.value)

        compiled, key = self._compile(step, jnp.array([1.0, 2.0]))
        self.assertIsNotNone(compiled.get_out_shapes_by_cache(key))

    def test_nested_output_structure(self):
        """A dict-shaped output has a recorded tree definition."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x):
            acc.value += x
            summary = {
                'sum': jnp.sum(acc.value),
                'prod': jnp.prod(acc.value),
                'values': acc.value,
            }
            return summary

        compiled, key = self._compile(step, jnp.array([1.0, 2.0]))
        self.assertIsNotNone(compiled.get_out_treedef_by_cache(key))
1199
+
1200
+
1201
class TestStatefulFunctionJaxprCall(unittest.TestCase):
    """Low-level execution through jaxpr_call() and jaxpr_call_auto()."""

    def test_jaxpr_call_basic(self):
        """jaxpr_call() returns the new state values alongside the output."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x):
            acc.value += x
            return acc.value * 2

        fn = brainstate.transform.StatefulFunction(step)
        inp = jnp.array([0.5, 0.5])
        fn.make_jaxpr(inp)

        # State values are passed explicitly, as a list, at the jaxpr level.
        updated, out = fn.jaxpr_call([acc.value], inp)

        self.assertEqual(len(updated), 1)
        self.assertEqual(out.shape, (2,))

    def test_jaxpr_call_auto_basic(self):
        """jaxpr_call_auto() manages the state values itself."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x):
            acc.value += x
            return acc.value * 2

        fn = brainstate.transform.StatefulFunction(step)
        inp = jnp.array([0.5, 0.5])
        fn.make_jaxpr(inp)

        self.assertEqual(fn.jaxpr_call_auto(inp).shape, (2,))

    def test_jaxpr_call_preserves_state_order(self):
        """jaxpr_call() returns as many state values as it was given."""
        s1 = brainstate.State(jnp.array([1.0]))
        s2 = brainstate.State(jnp.array([2.0]))
        s3 = brainstate.State(jnp.array([3.0]))

        def update_all(x):
            s1.value += x
            s2.value += x * 2
            s3.value += x * 3
            return s1.value + s2.value + s3.value

        fn = brainstate.transform.StatefulFunction(update_all)
        inp = jnp.array([1.0])
        fn.make_jaxpr(inp)

        tracked = fn.get_states_by_cache(fn.get_arg_cache_key(inp))
        # Feed the values in the order the trace recorded the states.
        current = [st.value for st in tracked]

        updated, _ = fn.jaxpr_call(current, inp)
        self.assertEqual(len(updated), len(current))
1269
+
1270
+
1271
class TestStatefulFunctionEdgeCases(unittest.TestCase):
    """Corner cases for tracing: stateless, read-only, scalar, identity and complex-valued functions."""

    def _trace(self, fn, arg, **options):
        """Wrap ``fn`` in a StatefulFunction, trace it on ``arg`` and return (wrapper, cache key)."""
        wrapped = brainstate.transform.StatefulFunction(fn, **options)
        wrapped.make_jaxpr(arg)
        return wrapped, wrapped.get_arg_cache_key(arg)

    def _assert_traced(self, fn, arg):
        """Trace ``fn`` on ``arg`` and check that a jaxpr was cached for that input."""
        wrapped, key = self._trace(fn, arg)
        self.assertIsNotNone(wrapped.get_jaxpr_by_cache(key))

    def test_no_state_function(self):
        """A function that touches no State records zero states."""

        def affine(v):
            return v * 2 + 1

        wrapped, key = self._trace(affine, jnp.array([1.0, 2.0]))
        self.assertEqual(len(wrapped.get_states_by_cache(key)), 0)

    def test_read_only_state(self):
        """Reading, but never assigning, a State must not register it as written."""
        source = brainstate.State(jnp.array([1.0, 2.0]))

        def read_only(v):
            # The State is read here but never assigned to.
            return source.value + v

        wrapped, key = self._trace(read_only, jnp.array([1.0, 2.0]), return_only_write=True)
        self.assertEqual(len(wrapped.get_write_states_by_cache(key)), 0)

    def test_scalar_inputs_outputs(self):
        """Zero-dimensional inputs, states and outputs trace successfully."""
        scalar = brainstate.State(jnp.array(1.0))

        def bump(v):
            scalar.value += v
            return scalar.value

        self._assert_traced(bump, jnp.array(0.5))

    def test_empty_function(self):
        """An identity function (no operations) still produces a jaxpr."""

        def identity(v):
            return v

        self._assert_traced(identity, jnp.array([1.0, 2.0]))

    def test_complex_dtype(self):
        """Complex-valued states and inputs trace successfully."""
        cplx = brainstate.State(jnp.array([1.0 + 2.0j, 3.0 + 4.0j]))

        def bump(v):
            cplx.value += v
            return cplx.value

        self._assert_traced(bump, jnp.array([0.5 + 0.5j, 0.5 + 0.5j]))
1359
+
1360
+
1361
class TestStatefulFunctionCacheKey(unittest.TestCase):
    """Argument cache keys depend on abstract input info (shape, dtype, tree structure), not on values."""

    @staticmethod
    def _keys_for(fn, *candidates):
        """Wrap ``fn`` once and return the argument cache key of each candidate input, in order."""
        wrapped = brainstate.transform.StatefulFunction(fn)
        return [wrapped.get_arg_cache_key(candidate) for candidate in candidates]

    def test_cache_key_different_shapes(self):
        """Inputs of different shapes map to different cache keys."""

        def double(v):
            return v * 2

        short_key, long_key = self._keys_for(
            double,
            jnp.array([1.0, 2.0]),
            jnp.array([1.0, 2.0, 3.0]),
        )
        self.assertNotEqual(short_key, long_key)

    def test_cache_key_different_dtypes(self):
        """Inputs of different dtypes map to different cache keys."""

        def double(v):
            return v * 2

        # float32 vs int32: both dtypes exist in JAX's default (non-x64) configuration.
        float_key, int_key = self._keys_for(
            double,
            jnp.array([1.0, 2.0], dtype=jnp.float32),
            jnp.array([1, 2], dtype=jnp.int32),
        )
        self.assertNotEqual(float_key, int_key)

    def test_cache_key_same_abstract_values(self):
        """Same shape and dtype with different values share a cache key."""

        def double(v):
            return v * 2

        first_key, second_key = self._keys_for(
            double,
            jnp.array([1.0, 2.0]),
            jnp.array([3.0, 4.0]),
        )
        self.assertEqual(first_key, second_key)

    def test_cache_key_with_pytree_inputs(self):
        """Tuple inputs with matching structure and leaf shapes share a cache key."""

        def add_pair(pair):
            lhs, rhs = pair
            return lhs + rhs

        first_key, second_key = self._keys_for(
            add_pair,
            (jnp.array([1.0]), jnp.array([2.0])),
            (jnp.array([3.0]), jnp.array([4.0])),
        )
        self.assertEqual(first_key, second_key)
1433
+
1434
+
1435
class TestStatefulFunctionRecompilation(unittest.TestCase):
    """Test recompilation scenarios: cache reuse, one entry per input shape, and clearing."""

    def test_cache_reuse(self):
        """Test that the jaxpr cache is reused for inputs with the same abstract signature."""
        state = brainstate.State(jnp.array([1.0]))

        def f(x):
            state.value += x
            return state.value

        sf = brainstate.transform.StatefulFunction(f)

        x = jnp.array([1.0])

        # First compilation populates exactly one entry on a fresh StatefulFunction.
        sf.make_jaxpr(x)
        stats1 = sf.get_cache_stats()
        self.assertEqual(stats1['jaxpr_cache']['size'], 1)

        # Second call with same shape should reuse cache
        sf.make_jaxpr(x)
        stats2 = sf.get_cache_stats()

        # Cache size should remain the same. Pinning both sizes to 1 keeps the check
        # meaningful even if get_cache_stats returned a live view rather than a snapshot.
        self.assertEqual(
            stats1['jaxpr_cache']['size'],
            stats2['jaxpr_cache']['size']
        )
        self.assertEqual(stats2['jaxpr_cache']['size'], 1)

    def test_multiple_compilations_different_shapes(self):
        """Test that each distinct input shape gets its own cache entry."""

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)

        # Compile for inputs of three different shapes
        inputs = [
            jnp.array([1.0]),
            jnp.array([1.0, 2.0]),
            jnp.array([1.0, 2.0, 3.0]),
        ]

        for x in inputs:
            sf.make_jaxpr(x)

        stats = sf.get_cache_stats()

        # Should have one cache entry per distinct shape
        self.assertEqual(stats['jaxpr_cache']['size'], len(inputs))

    def test_clear_and_recompile(self):
        """Test clearing the cache and recompiling."""
        state = brainstate.State(jnp.array([1.0]))

        def f(x):
            state.value += x
            return state.value

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([1.0])

        # Compile
        sf.make_jaxpr(x)
        stats_before = sf.get_cache_stats()
        self.assertGreater(stats_before['jaxpr_cache']['size'], 0)

        # Clear cache
        sf.clear_cache()
        stats_after_clear = sf.get_cache_stats()
        self.assertEqual(stats_after_clear['jaxpr_cache']['size'], 0)

        # Recompile
        sf.make_jaxpr(x)
        stats_after_recompile = sf.get_cache_stats()
        self.assertGreater(stats_after_recompile['jaxpr_cache']['size'], 0)
1513
+
1514
+
1515
+ class TestStatefulMapping(unittest.TestCase):
1516
+ def test_state_filters_and_caching(self):
1517
+ counter = brainstate.ShortTermState(jnp.zeros(3))
1518
+
1519
+ def accumulate(x):
1520
+ counter.value = counter.value + x
1521
+ return counter.value
1522
+
1523
+ mapper = brainstate.transform.StatefulMapping(
1524
+ accumulate,
1525
+ in_axes=0,
1526
+ out_axes=0,
1527
+ state_in_axes={0: state_filter.OfType(brainstate.ShortTermState)},
1528
+ state_out_axes={0: state_filter.OfType(brainstate.ShortTermState)},
1529
+ )
1530
+
1531
+ xs = jnp.asarray([1.0, 2.0, 3.0])
1532
+ result = mapper(xs)
1533
+ self.assertTrue(jnp.allclose(result, xs))
1534
+ self.assertTrue(jnp.allclose(counter.value, xs))
1535
+
1536
+ def test_random_state_restoration(self):
1537
+ rng_state = brainstate.random.RandomState(0)
1538
+
1539
+ def draw(_):
1540
+ key = rng_state.split_key()
1541
+ return jr.normal(key, ())
1542
+
1543
+ mapper = brainstate.transform.StatefulMapping(
1544
+ draw,
1545
+ in_axes=0,
1546
+ out_axes=0,
1547
+ )
1548
+
1549
+ xs = jnp.ones((4,))
1550
+ before = rng_state.value
1551
+ samples = mapper(xs)
1552
+ self.assertEqual(samples.shape, xs.shape)
1553
+ self.assertFalse(jnp.allclose(samples, jnp.repeat(samples[0], xs.shape[0])))
1554
+ self.assertTrue(jnp.array_equal(rng_state.value.shape, before.shape))
1555
+
1556
+ def test_inconsistent_batch_sizes_raise(self):
1557
+ tracker = brainstate.ShortTermState(jnp.array(0.0))
1558
+
1559
+ def combine(x, y):
1560
+ tracker.value = tracker.value + x + y
1561
+ return tracker.value
1562
+
1563
+ mapper = brainstate.transform.StatefulMapping(
1564
+ combine,
1565
+ in_axes=(0, 0),
1566
+ out_axes=0,
1567
+ state_in_axes={0: state_filter.OfType(brainstate.ShortTermState)},
1568
+ state_out_axes={0: state_filter.OfType(brainstate.ShortTermState)},
1569
+ )
1570
+
1571
+ with self.assertRaisesRegex(ValueError, "Inconsistent batch sizes"):
1572
+ mapper(jnp.ones((3,)), jnp.ones((4,)))
1573
+
1574
+ def test_unexpected_out_state_mapping_raise(self):
1575
+ leak = brainstate.ShortTermState(jnp.array(0.0))
1576
+
1577
+ def mutate(x):
1578
+ leak.value = leak.value + x
1579
+ return x
1580
+
1581
+ mapper = brainstate.transform.StatefulMapping(
1582
+ mutate,
1583
+ in_axes=0,
1584
+ out_axes=0,
1585
+ state_in_axes={},
1586
+ state_out_axes={},
1587
+ unexpected_out_state_mapping='raise',
1588
+ )
1589
+
1590
+ with self.assertRaises(BatchAxisError):
1591
+ mapper(jnp.ones((2,)))
1592
+
1593
+ def test_unexpected_out_state_mapping_warn(self):
1594
+ leak = brainstate.ShortTermState(jnp.array(0.0))
1595
+
1596
+ def mutate(x):
1597
+ leak.value = leak.value + x
1598
+ return x
1599
+
1600
+ mapper = brainstate.transform.StatefulMapping(
1601
+ mutate,
1602
+ in_axes=0,
1603
+ out_axes=0,
1604
+ state_in_axes={},
1605
+ state_out_axes={},
1606
+ unexpected_out_state_mapping='warn',
1607
+ )
1608
+
1609
+ with pytest.warns(UserWarning):
1610
+ mapper(jnp.ones((2,)))
1611
+ self.assertTrue(jnp.allclose(leak.value, 1.0))
1612
+
1613
+ def test_unexpected_out_state_mapping_ignore(self):
1614
+ leak = brainstate.ShortTermState(jnp.array(0.0))
1615
+
1616
+ def mutate(x):
1617
+ leak.value = leak.value + x
1618
+ return x
1619
+
1620
+ mapper = brainstate.transform.StatefulMapping(
1621
+ mutate,
1622
+ in_axes=0,
1623
+ out_axes=0,
1624
+ state_in_axes={},
1625
+ state_out_axes={},
1626
+ unexpected_out_state_mapping='ignore',
1627
+ )
1628
+
1629
+ with warnings.catch_warnings(record=True) as caught:
1630
+ warnings.simplefilter('always')
1631
+ mapper(jnp.ones((2,)))
1632
+ self.assertEqual(len(caught), 0)
1633
+ self.assertTrue(jnp.allclose(leak.value, 1.0))
1634
+