brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,1510 +1,1510 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- import threading
18
- import unittest
19
-
20
- import jax
21
- import jax.numpy as jnp
22
- import pytest
23
-
24
- import brainstate
25
- from brainstate._error import BatchAxisError
26
- from brainstate._compatible_import import jaxpr_as_fun
27
- from brainstate.transform._make_jaxpr import _BoundedCache, make_hashable
28
-
29
-
30
class TestMakeJaxpr(unittest.TestCase):
    """Compare ``brainstate.transform.make_jaxpr`` with ``jax.make_jaxpr``.

    Also checks that ``State`` objects are rejected when passed as arguments
    to, or returned from, a traced function.
    """

    def test_compar_jax_make_jaxpr(self):
        def func4(arg):  # ``arg`` is a pair of arrays
            temp = arg[0] + jnp.sin(arg[1]) * 3.
            c = brainstate.random.rand_like(arg[0])
            return jnp.sum(temp + c)

        args = (jnp.zeros(8), jnp.ones(8))
        key = brainstate.random.DEFAULT.value

        # Plain JAX: the random key value is closed over as a jaxpr constant.
        jaxpr = jax.make_jaxpr(func4)(args)
        self.assertEqual(len(jaxpr.in_avals), 2)
        self.assertEqual(len(jaxpr.consts), 1)
        self.assertEqual(len(jaxpr.out_avals), 1)
        self.assertTrue(jnp.allclose(jaxpr.consts[0], key))

        # NOTE(review): this reseeds the global DEFAULT generator, which may
        # leak into tests that run afterwards.
        brainstate.random.seed(1)

        # brainstate: no constants; one extra input and one extra output,
        # presumably the random key state read and written by ``rand_like``.
        jaxpr2, states = brainstate.transform.make_jaxpr(func4)(args)
        self.assertEqual(len(jaxpr2.in_avals), 3)
        self.assertEqual(len(jaxpr2.out_avals), 2)
        self.assertEqual(len(jaxpr2.consts), 0)

    def test_StatefulFunction_1(self):
        def func4(arg):  # ``arg`` is a pair of arrays
            temp = arg[0] + jnp.sin(arg[1]) * 3.
            c = brainstate.random.rand_like(arg[0])
            return jnp.sum(temp + c)

        args = (jnp.zeros(8), jnp.ones(8))
        # ``make_jaxpr`` returns an object exposing the per-key cache lookups.
        fun = brainstate.transform.StatefulFunction(func4).make_jaxpr(args)
        cache_key = fun.get_arg_cache_key(args)
        # Both lookups raise if nothing was compiled for ``cache_key``.
        self.assertIsNotNone(fun.get_states_by_cache(cache_key))
        self.assertIsNotNone(fun.get_jaxpr_by_cache(cache_key))

    def test_StatefulFunction_2(self):
        st1 = brainstate.State(jnp.ones(10))

        def f1(x):
            st1.value = x + st1.value

        def f2(x):
            # The nested trace is run only for its effect on the outer trace.
            _ = brainstate.transform.make_jaxpr(f1)(x)
            return 1. + x

        def f3(x):
            _ = brainstate.transform.make_jaxpr(f1)(x)
            return 1.

        x = jnp.zeros(1)
        # Smoke checks: nested make_jaxpr calls must trace under both APIs.
        brainstate.transform.make_jaxpr(f1)(x)
        jax.make_jaxpr(f2)(x)
        jax.make_jaxpr(f3)(x)

        jaxpr, _ = brainstate.transform.make_jaxpr(f3)(x)
        # The outer jaxpr is fed ``x`` plus the current value of ``st1``; its
        # first output must equal what ``f3`` returns eagerly.
        traced_out = jaxpr_as_fun(jaxpr)(x, st1.value)[0]
        self.assertTrue(jnp.allclose(traced_out, f3(x)))

    def test_compare_jax_make_jaxpr2(self):
        st1 = brainstate.State(jnp.ones(10))

        def fa(x):
            st1.value = x + st1.value

        def ffa(x):
            _ = brainstate.transform.make_jaxpr(fa)(x)
            return 1. + x

        x = jnp.zeros(1)
        # Smoke test: both jaxprs of the nested call must be executable.
        jaxpr, _ = brainstate.transform.make_jaxpr(ffa)(x)
        jaxpr_as_fun(jaxpr)(x, st1.value)
        jaxpr = jax.make_jaxpr(ffa)(x)
        jaxpr_as_fun(jaxpr)(x)

    def test_compare_jax_make_jaxpr3(self):
        def fa(x):
            return 1.

        # A function that ignores its input and touches no state must trace.
        jaxpr, _ = brainstate.transform.make_jaxpr(fa)(jnp.zeros(1))
        self.assertIsNotNone(jaxpr)
        jax.make_jaxpr(fa)(jnp.zeros(1))

    def test_static_argnames(self):
        def func4(a, b):
            temp = a + jnp.sin(b) * 3.
            c = brainstate.random.rand_like(a)
            return jnp.sum(temp + c)

        # ``b`` is static, so a plain Python float is accepted for it.
        jaxpr, _ = brainstate.transform.make_jaxpr(func4, static_argnames='b')(jnp.zeros(8), 1.)
        self.assertIsNotNone(jaxpr)

    def test_state_in(self):
        def f(a):
            return a.value

        # Passing a State as an argument is rejected.
        with pytest.raises(ValueError):
            brainstate.transform.StatefulFunction(f).make_jaxpr(brainstate.State(1.))

    def test_state_out(self):
        def f(a):
            return brainstate.State(a)

        # Returning a State from the traced function is rejected.
        with pytest.raises(ValueError):
            brainstate.transform.StatefulFunction(f).make_jaxpr(1.)

    def test_return_states(self):
        a = brainstate.State(jnp.ones(3))

        @brainstate.transform.jit
        def f():
            return a

        # The same rejection applies through ``jit``.
        with pytest.raises(ValueError):
            f()
161
-
162
-
163
class TestBoundedCache(unittest.TestCase):
    """Test the _BoundedCache class."""

    def test_cache_basic_operations(self):
        """Test basic get and set operations."""
        cache = _BoundedCache(maxsize=3)

        # Test set and get
        cache.set('key1', 'value1')
        self.assertEqual(cache.get('key1'), 'value1')

        # Test default value
        self.assertIsNone(cache.get('nonexistent'))
        self.assertEqual(cache.get('nonexistent', 'default'), 'default')

        # Test __contains__
        self.assertIn('key1', cache)
        self.assertNotIn('key2', cache)

        # Test __len__
        self.assertEqual(len(cache), 1)

    def test_cache_lru_eviction(self):
        """Test LRU eviction when cache is full."""
        cache = _BoundedCache(maxsize=3)

        # Fill cache
        cache.set('key1', 'value1')
        cache.set('key2', 'value2')
        cache.set('key3', 'value3')
        self.assertEqual(len(cache), 3)

        # Add one more, should evict key1 (least recently used)
        cache.set('key4', 'value4')
        self.assertEqual(len(cache), 3)
        self.assertNotIn('key1', cache)
        self.assertIn('key4', cache)

        # Access key2 to make it recently used
        cache.get('key2')

        # Add another key, should evict key3 (now least recently used)
        cache.set('key5', 'value5')
        self.assertNotIn('key3', cache)
        self.assertIn('key2', cache)

    def test_cache_update_existing(self):
        """Test that replacing an existing key refreshes its recency."""
        cache = _BoundedCache(maxsize=2)

        cache.set('key1', 'value1')
        cache.set('key2', 'value2')

        # Update key1; replace should move it to the end (most recently used).
        cache.replace('key1', 'updated_value1')

        # Check eviction before any ``get``. ``get`` also refreshes recency,
        # so reading key1 first would hide whether ``replace`` moved it.
        cache.set('key3', 'value3')
        self.assertNotIn('key2', cache)
        self.assertIn('key1', cache)
        self.assertEqual(cache.get('key1'), 'updated_value1')

    def test_cache_statistics(self):
        """Test cache statistics tracking."""
        cache = _BoundedCache(maxsize=5)

        # Initial stats
        stats = cache.get_stats()
        self.assertEqual(stats['size'], 0)
        self.assertEqual(stats['maxsize'], 5)
        self.assertEqual(stats['hits'], 0)
        self.assertEqual(stats['misses'], 0)
        self.assertEqual(stats['hit_rate'], 0.0)

        # Add items and test hits/misses
        cache.set('key1', 'value1')
        cache.set('key2', 'value2')

        # Generate hits
        cache.get('key1')  # hit
        cache.get('key1')  # hit
        cache.get('key3')  # miss
        cache.get('key2')  # hit

        stats = cache.get_stats()
        self.assertEqual(stats['size'], 2)
        self.assertEqual(stats['hits'], 3)
        self.assertEqual(stats['misses'], 1)
        # hit_rate is reported as a percentage.
        self.assertEqual(stats['hit_rate'], 75.0)

    def test_cache_clear(self):
        """Test clearing the cache."""
        cache = _BoundedCache(maxsize=5)

        # Add items
        cache.set('key1', 'value1')
        cache.set('key2', 'value2')
        cache.get('key1')  # Generate a hit

        # Clear cache
        cache.clear()

        self.assertEqual(len(cache), 0)
        self.assertNotIn('key1', cache)

        # Check stats are reset
        stats = cache.get_stats()
        self.assertEqual(stats['hits'], 0)
        self.assertEqual(stats['misses'], 0)

    def test_cache_keys(self):
        """Test getting all cache keys."""
        cache = _BoundedCache(maxsize=5)

        cache.set('key1', 'value1')
        cache.set('key2', 'value2')
        cache.set('key3', 'value3')

        keys = cache.keys()
        self.assertEqual(set(keys), {'key1', 'key2', 'key3'})

    def test_cache_set_duplicate_raises(self):
        """Test that setting an existing key raises ValueError."""
        cache = _BoundedCache(maxsize=5)

        cache.set('key1', 'value1')

        # Attempting to set the same key should raise ValueError
        with pytest.raises(ValueError, match="Cache key already exists"):
            cache.set('key1', 'value2')

    def test_cache_pop(self):
        """Test pop method."""
        cache = _BoundedCache(maxsize=5)

        cache.set('key1', 'value1')
        cache.set('key2', 'value2')

        # Pop existing key
        value = cache.pop('key1')
        self.assertEqual(value, 'value1')
        self.assertNotIn('key1', cache)
        self.assertEqual(len(cache), 1)

        # Pop non-existent key with default
        value = cache.pop('nonexistent', 'default')
        self.assertEqual(value, 'default')

        # Pop non-existent key without default
        value = cache.pop('nonexistent')
        self.assertIsNone(value)

    def test_cache_replace(self):
        """Test replace method, including its move-to-end LRU effect."""
        cache = _BoundedCache(maxsize=5)

        cache.set('key1', 'value1')
        cache.set('key2', 'value2')

        # Replace existing key
        cache.replace('key1', 'new_value1')
        self.assertEqual(cache.get('key1'), 'new_value1')

        # Replacing should move the key to the end (most recently used).
        cache.set('key3', 'value3')
        cache.replace('key2', 'new_value2')
        # Expected LRU order (oldest first): key1, key3, key2.

        cache.set('key4', 'value4')
        cache.set('key5', 'value5')  # full: key1, key3, key2, key4, key5

        # key1 is the oldest entry and is evicted first.
        cache.set('key6', 'value6')
        self.assertNotIn('key1', cache)

        # The next eviction must take key3, not key2. Only the move-to-end
        # done by ``replace('key2', ...)`` puts key2 behind key3; without it,
        # key2 would be evicted here.
        cache.set('key7', 'value7')
        self.assertNotIn('key3', cache)
        self.assertIn('key2', cache)
        self.assertEqual(cache.get('key2'), 'new_value2')

    def test_cache_replace_nonexistent_raises(self):
        """Test that replacing a non-existent key raises KeyError."""
        cache = _BoundedCache(maxsize=5)

        with pytest.raises(KeyError, match="Cache key does not exist"):
            cache.replace('nonexistent', 'value')

    def test_cache_get_with_raise_on_miss(self):
        """Test get method with raise_on_miss parameter."""
        cache = _BoundedCache(maxsize=5)

        cache.set('key1', 'value1')
        cache.set('key2', 'value2')

        # Should work normally for existing key
        value = cache.get('key1', raise_on_miss=True)
        self.assertEqual(value, 'value1')

        # Should raise ValueError for missing key with raise_on_miss=True
        with pytest.raises(ValueError, match="not compiled for the requested cache key"):
            cache.get('nonexistent', raise_on_miss=True, error_context="Test item")

    def test_cache_detailed_error_message(self):
        """Test that error message shows available keys."""
        cache = _BoundedCache(maxsize=5)

        cache.set('key1', 'value1')
        cache.set('key2', 'value2')

        # Error should include all available keys
        with pytest.raises(ValueError) as exc_info:
            cache.get('nonexistent', raise_on_miss=True, error_context="Test item")

        error_msg = str(exc_info.value)
        # Should show requested key
        self.assertIn('nonexistent', error_msg)
        # Should show available keys
        self.assertIn('key1', error_msg)
        self.assertIn('key2', error_msg)
        # Should have helpful message
        self.assertIn('make_jaxpr()', error_msg)

    def test_cache_error_message_no_keys(self):
        """Test error message when cache is empty."""
        cache = _BoundedCache(maxsize=5)

        with pytest.raises(ValueError) as exc_info:
            cache.get('key', raise_on_miss=True, error_context="Empty cache")

        error_msg = str(exc_info.value)
        # Should indicate no keys available
        self.assertIn('none', error_msg.lower())

    def test_cache_thread_safety(self):
        """Concurrent set/get must never corrupt entries or overflow ``maxsize``.

        More keys are inserted (threads x keys per thread) than the cache can
        hold, so eviction also runs concurrently. A ``get`` issued right after
        ``set`` can therefore legitimately return ``None``: another thread's
        inserts may have evicted the key in between. Only a *wrong* value is
        treated as a thread-safety failure.
        """
        maxsize = 100
        n_threads = 5
        n_per_thread = 50
        cache = _BoundedCache(maxsize=maxsize)
        errors = []

        def worker(thread_id):
            try:
                for i in range(n_per_thread):
                    key = f'key_{thread_id}_{i}'
                    expected = f'value_{thread_id}_{i}'
                    cache.set(key, expected)
                    value = cache.get(key)
                    # ``None`` means the key was evicted by concurrent inserts
                    # (correct LRU behaviour), not a race.
                    if value is not None and value != expected:
                        errors.append(f'Mismatch in thread {thread_id}')
            except Exception as e:
                errors.append(f'Error in thread {thread_id}: {e}')

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(n_threads)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Check no errors occurred
        self.assertEqual(len(errors), 0, f"Thread safety errors: {errors}")
        # More distinct keys than ``maxsize`` were inserted, so a correctly
        # synchronized cache must end exactly full, never over capacity.
        self.assertEqual(len(cache), maxsize)
422
-
423
-
424
class TestStatefulFunctionEnhancements(unittest.TestCase):
    """Test cache introspection and state-validation helpers of StatefulFunction."""

    # Names of the per-key caches reported by ``StatefulFunction.get_cache_stats``.
    _CACHE_NAMES = ('jaxpr_cache', 'out_shapes_cache', 'jaxpr_out_tree_cache', 'state_trace_cache')

    def test_cache_stats(self):
        """Test get_cache_stats method."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value * 2

        sf = brainstate.transform.StatefulFunction(f)

        # x1 and x2 have identical shape and dtype, so they presumably share a
        # cache key. Distinct keys need different shapes or static arguments
        # (see test_validate_all_states). The second call therefore exercises
        # the lookup path rather than a second compilation.
        x1 = jnp.array([0.5, 0.5])
        x2 = jnp.array([1.0, 1.0])

        sf.make_jaxpr(x1)
        sf.make_jaxpr(x2)

        stats = sf.get_cache_stats()

        # Verify all cache types are present
        for name in self._CACHE_NAMES:
            self.assertIn(name, stats)

        # Verify each cache reports the full set of statistics; ``subTest``
        # names the offending cache when a field is missing.
        for cache_name, cache_stats in stats.items():
            with self.subTest(cache=cache_name):
                for field in ('size', 'maxsize', 'hits', 'misses', 'hit_rate'):
                    self.assertIn(field, cache_stats)

    def test_validate_states(self):
        """Test validate_states method."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])
        sf.make_jaxpr(x)

        cache_key = sf.get_arg_cache_key(x)

        # Should validate successfully
        self.assertTrue(sf.validate_states(cache_key))

    def test_validate_all_states(self):
        """Test validate_all_states method."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x, n):
            state.value += x
            return state.value * n

        # Use static_argnums to create different cache keys
        sf = brainstate.transform.StatefulFunction(f, static_argnums=(1,))

        # Compile for the same dynamic input with two different static args
        x = jnp.array([0.5, 0.5])

        sf.make_jaxpr(x, 1)
        sf.make_jaxpr(x, 2)

        # Validate all
        results = sf.validate_all_states()

        # Should have results for both cache keys
        self.assertEqual(len(results), 2)

        # All should be valid
        for key, result in results.items():
            with self.subTest(cache_key=key):
                self.assertTrue(result)

    def test_clear_cache(self):
        """Test clear_cache method."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])
        sf.make_jaxpr(x)

        # Verify cache has entries
        stats = sf.get_cache_stats()
        self.assertGreater(stats['jaxpr_cache']['size'], 0)

        # Clear cache
        sf.clear_cache()

        # Verify all caches are empty
        stats = sf.get_cache_stats()
        for name in self._CACHE_NAMES:
            with self.subTest(cache=name):
                self.assertEqual(stats[name]['size'], 0)

    def test_return_only_write_parameter(self):
        """Test return_only_write parameter."""
        read_state = brainstate.State(jnp.array([1.0, 2.0]))
        write_state = brainstate.State(jnp.array([3.0, 4.0]))

        def f(x):
            # Read from read_state, write to write_state
            _ = read_state.value + x
            write_state.value += x
            return write_state.value

        x = jnp.array([0.5, 0.5])

        # Test with return_only_write=False (default)
        sf_all = brainstate.transform.StatefulFunction(f, return_only_write=False)
        sf_all.make_jaxpr(x)
        states_all = sf_all.get_states_by_cache(sf_all.get_arg_cache_key(x))

        # Test with return_only_write=True
        sf_write_only = brainstate.transform.StatefulFunction(f, return_only_write=True)
        sf_write_only.make_jaxpr(x)
        states_write = sf_write_only.get_states_by_cache(sf_write_only.get_arg_cache_key(x))

        # With return_only_write=True, should have fewer or equal states
        self.assertLessEqual(len(states_write), len(states_all))
557
-
558
-
559
class TestErrorHandling(unittest.TestCase):
    """Test error handling in StatefulFunction."""

    @staticmethod
    def _uncompiled_cache_key():
        """Return a cache key with empty argument groups.

        The tests below only compile (if at all) one-argument functions, so
        this key presumably never matches a compiled entry.
        """
        from brainstate.transform._make_jaxpr import hashabledict
        return hashabledict(
            static_args=(),
            dyn_args=(),
            static_kwargs=(),
            dyn_kwargs=()
        )

    def test_jaxpr_call_state_mismatch(self):
        """Test error when state values length doesn't match."""
        state1 = brainstate.State(jnp.array([1.0, 2.0]))
        state2 = brainstate.State(jnp.array([3.0, 4.0]))

        def f(x):
            state1.value += x
            state2.value += x
            return state1.value + state2.value

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])
        sf.make_jaxpr(x)

        # Try to call with wrong number of state values (only 1 instead of 2)
        with pytest.raises(ValueError, match="State length mismatch"):
            sf.jaxpr_call([jnp.array([1.0, 1.0])], x)

    def test_get_jaxpr_not_compiled_detailed_error(self):
        """Test detailed error message when getting jaxpr for uncompiled function."""

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)

        # Compile for one input shape
        sf.make_jaxpr(jnp.array([1.0, 2.0]))

        # Should raise detailed error for a key that was never compiled
        with pytest.raises(ValueError) as exc_info:
            sf.get_jaxpr_by_cache(self._uncompiled_cache_key())

        error_msg = str(exc_info.value)
        # Should contain the requested key
        self.assertIn('Requested key:', error_msg)
        # Should show available keys
        self.assertIn('Available', error_msg)
        # Should have helpful message
        self.assertIn('make_jaxpr()', error_msg)

    def test_get_out_shapes_not_compiled_detailed_error(self):
        """Test detailed error message when getting output shapes for uncompiled function."""

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)

        # Should raise detailed error with context "Output shapes"
        with pytest.raises(ValueError) as exc_info:
            sf.get_out_shapes_by_cache(self._uncompiled_cache_key())

        error_msg = str(exc_info.value)
        self.assertIn('Output shapes', error_msg)
        self.assertIn('Requested key:', error_msg)

    def test_get_out_treedef_not_compiled_detailed_error(self):
        """Test detailed error message when getting output tree for uncompiled function."""

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)

        # Should raise detailed error with context "Output tree"
        with pytest.raises(ValueError) as exc_info:
            sf.get_out_treedef_by_cache(self._uncompiled_cache_key())

        error_msg = str(exc_info.value)
        self.assertIn('Output tree', error_msg)
        self.assertIn('Requested key:', error_msg)

    def test_get_state_trace_not_compiled_detailed_error(self):
        """Test detailed error message when getting state trace for uncompiled function."""

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)

        # Should raise detailed error with context "State trace"
        with pytest.raises(ValueError) as exc_info:
            sf.get_state_trace_by_cache(self._uncompiled_cache_key())

        error_msg = str(exc_info.value)
        self.assertIn('State trace', error_msg)
        self.assertIn('Requested key:', error_msg)
684
-
685
-
686
class TestCompileIfMiss(unittest.TestCase):
    """Test the ``compile_if_miss`` parameter of the per-call getters.

    Covers ``get_jaxpr``, ``get_out_shapes``, ``get_out_treedef``,
    ``get_state_trace``, ``get_states``, ``get_read_states`` and
    ``get_write_states``, which take call arguments rather than a cache key.
    """

    def test_get_jaxpr_by_call_with_compile_if_miss_true(self):
        """``get_jaxpr`` with ``compile_if_miss=True`` compiles on a cache miss."""

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)

        # Should compile automatically
        jaxpr = sf.get_jaxpr(jnp.array([1.0, 2.0]), compile_if_miss=True)
        self.assertIsNotNone(jaxpr)

    def test_get_jaxpr_by_call_with_compile_if_miss_false(self):
        """``get_jaxpr`` with ``compile_if_miss=False`` raises on a cache miss."""

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)

        # Should raise error because not compiled
        with pytest.raises(ValueError, match="not compiled"):
            sf.get_jaxpr(jnp.array([1.0, 2.0]), compile_if_miss=False)

    def test_get_out_shapes_by_call_compile_if_miss(self):
        """``get_out_shapes`` honours ``compile_if_miss`` in both directions."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value * 2

        sf = brainstate.transform.StatefulFunction(f)

        # With compile_if_miss=True, should compile automatically
        shapes = sf.get_out_shapes(jnp.array([1.0, 2.0]), compile_if_miss=True)
        self.assertIsNotNone(shapes)

        # With compile_if_miss=False, a new input shape is a cache miss. The
        # ``match`` ensures the error is the missing-compilation error and not
        # some unrelated ValueError.
        with pytest.raises(ValueError, match="not compiled"):
            sf.get_out_shapes(jnp.array([1.0, 2.0, 3.0]), compile_if_miss=False)

    def test_get_out_treedef_by_call_compile_if_miss(self):
        """``get_out_treedef`` compiles on a miss with the default ``compile_if_miss``."""

        def f(x):
            return x * 2, x + 1

        sf = brainstate.transform.StatefulFunction(f)

        # Should compile automatically with default compile_if_miss=True
        treedef = sf.get_out_treedef(jnp.array([1.0, 2.0]))
        self.assertIsNotNone(treedef)

    def test_get_state_trace_by_call_compile_if_miss(self):
        """``get_state_trace`` with ``compile_if_miss=True`` compiles on a miss."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value

        sf = brainstate.transform.StatefulFunction(f)

        # Should compile automatically
        trace = sf.get_state_trace(jnp.array([1.0, 2.0]), compile_if_miss=True)
        self.assertIsNotNone(trace)

    def test_get_states_by_call_compile_if_miss(self):
        """``get_states`` with ``compile_if_miss=True`` compiles and returns every touched state."""
        state1 = brainstate.State(jnp.array([1.0, 2.0]))
        state2 = brainstate.State(jnp.array([3.0, 4.0]))

        def f(x):
            state1.value += x
            state2.value += x
            return state1.value + state2.value

        sf = brainstate.transform.StatefulFunction(f)

        # Should compile automatically
        states = sf.get_states(jnp.array([1.0, 2.0]), compile_if_miss=True)
        self.assertEqual(len(states), 2)

    def test_get_read_states_by_call_compile_if_miss(self):
        """``get_read_states`` with ``compile_if_miss=True`` compiles on a miss."""
        read_state = brainstate.State(jnp.array([1.0, 2.0]))
        write_state = brainstate.State(jnp.array([3.0, 4.0]))

        def f(x):
            _ = read_state.value
            write_state.value += x
            return write_state.value

        sf = brainstate.transform.StatefulFunction(f)

        # Should compile automatically
        read_states = sf.get_read_states(jnp.array([1.0, 2.0]), compile_if_miss=True)
        self.assertIsNotNone(read_states)

    def test_get_write_states_by_call_compile_if_miss(self):
        """``get_write_states`` with ``compile_if_miss=True`` compiles on a miss."""
        read_state = brainstate.State(jnp.array([1.0, 2.0]))
        write_state = brainstate.State(jnp.array([3.0, 4.0]))

        def f(x):
            _ = read_state.value
            write_state.value += x
            return write_state.value

        sf = brainstate.transform.StatefulFunction(f)

        # Should compile automatically
        write_states = sf.get_write_states(jnp.array([1.0, 2.0]), compile_if_miss=True)
        self.assertIsNotNone(write_states)

    def test_compile_if_miss_default_behavior(self):
        """The per-call getters compile on a miss when ``compile_if_miss`` is omitted."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value

        x = jnp.array([1.0, 2.0])

        # Each getter gets a fresh StatefulFunction so its cache starts empty.
        jaxpr = brainstate.transform.StatefulFunction(f).get_jaxpr(x)
        self.assertIsNotNone(jaxpr)

        shapes = brainstate.transform.StatefulFunction(f).get_out_shapes(x)
        self.assertIsNotNone(shapes)

        # ``f`` touches exactly one state.
        states = brainstate.transform.StatefulFunction(f).get_states(x)
        self.assertEqual(len(states), 1)
828
-
829
-
830
- class TestMakeHashable(unittest.TestCase):
831
- """Test the make_hashable utility function."""
832
-
833
- def test_hashable_list(self):
834
- """Test converting list to hashable."""
835
- result = make_hashable([1, 2, 3])
836
- # Should return a tuple
837
- self.assertIsInstance(result, tuple)
838
- # Should be hashable
839
- hash(result)
840
-
841
- def test_hashable_dict(self):
842
- """Test converting dict to hashable."""
843
- result = make_hashable({'b': 2, 'a': 1})
844
- # Should return a tuple of sorted key-value pairs
845
- self.assertIsInstance(result, tuple)
846
- # Should be hashable
847
- hash(result)
848
- # Keys should be sorted
849
- keys = [item[0] for item in result]
850
- self.assertEqual(keys, ['a', 'b'])
851
-
852
- def test_hashable_set(self):
853
- """Test converting set to hashable."""
854
- result = make_hashable({1, 2, 3})
855
- # Should return a frozenset
856
- self.assertIsInstance(result, frozenset)
857
- # Should be hashable
858
- hash(result)
859
-
860
- def test_hashable_nested(self):
861
- """Test converting nested structures."""
862
- nested = {
863
- 'list': [1, 2, 3],
864
- 'dict': {'a': 1, 'b': 2},
865
- 'set': {4, 5}
866
- }
867
- result = make_hashable(nested)
868
- # Should be hashable
869
- hash(result) # Should not raise
870
-
871
- def test_hashable_tuple(self):
872
- """Test with tuples."""
873
- result = make_hashable((1, 2, 3))
874
- # Should return a tuple
875
- self.assertIsInstance(result, tuple)
876
- # Should be hashable
877
- hash(result)
878
-
879
- def test_hashable_idempotent(self):
880
- """Test that applying make_hashable twice gives consistent results."""
881
- original = {'a': [1, 2], 'b': {3, 4}}
882
- result1 = make_hashable(original)
883
- result2 = make_hashable(original)
884
- # Should be the same
885
- self.assertEqual(result1, result2)
886
-
887
-
888
- class TestCacheCleanupOnError(unittest.TestCase):
889
- """Test that cache is properly cleaned up when compilation fails."""
890
-
891
- def test_cache_cleanup_on_compilation_error(self):
892
- """Test that partial cache entries are cleaned up when make_jaxpr fails."""
893
-
894
- def f(x):
895
- # This will cause an error during JAX tracing
896
- if x > 0: # Control flow not allowed in JAX
897
- return x * 2
898
- else:
899
- return x + 1
900
-
901
- sf = brainstate.transform.StatefulFunction(f)
902
-
903
- # Try to compile, should fail
904
- try:
905
- sf.make_jaxpr(jnp.array([1.0]))
906
- except Exception:
907
- pass # Expected to fail
908
-
909
- # Cache should be empty after error
910
- stats = sf.get_cache_stats()
911
- # All caches should be empty since error cleanup should have removed partial entries
912
- # Note: The actual behavior depends on when the error occurs during compilation
913
- # If error happens early, no cache entries; if late, entries might exist
914
- # This test just verifies the cleanup mechanism exists
915
-
916
-
917
- class TestMakeJaxprReturnOnlyWrite(unittest.TestCase):
918
- """Test make_jaxpr with return_only_write parameter."""
919
-
920
- def test_make_jaxpr_return_only_write(self):
921
- """Test make_jaxpr function with return_only_write parameter."""
922
- read_state = brainstate.State(jnp.array([1.0]))
923
- write_state = brainstate.State(jnp.array([2.0]))
924
-
925
- def f(x):
926
- _ = read_state.value # Read only
927
- write_state.value += x # Write
928
- return x * 2
929
-
930
- # Test with return_only_write=True
931
- jaxpr_maker = brainstate.transform.make_jaxpr(f, return_only_write=True)
932
- jaxpr, states = jaxpr_maker(jnp.array([1.0]))
933
-
934
- # Should compile successfully
935
- self.assertIsNotNone(jaxpr)
936
- self.assertIsInstance(states, tuple)
937
-
938
-
939
- class TestStatefulFunctionCallable(unittest.TestCase):
940
- """Test __call__ method of StatefulFunction."""
941
-
942
- def test_stateful_function_call(self):
943
- """Test calling StatefulFunction directly."""
944
- state = brainstate.State(jnp.array([1.0, 2.0]))
945
-
946
- def f(x):
947
- state.value += x
948
- return state.value * 2
949
-
950
- sf = brainstate.transform.StatefulFunction(f)
951
- x = jnp.array([0.5, 0.5])
952
- sf.make_jaxpr(x)
953
-
954
- # Test direct call
955
- result = sf(x)
956
- self.assertEqual(result.shape, (2,))
957
-
958
- def test_stateful_function_call_auto_compile(self):
959
- """Test that __call__ automatically compiles if needed."""
960
- state = brainstate.State(jnp.array([1.0, 2.0]))
961
-
962
- def f(x):
963
- state.value += x
964
- return state.value * 2
965
-
966
- sf = brainstate.transform.StatefulFunction(f)
967
- x = jnp.array([0.5, 0.5])
968
-
969
- # Call without pre-compilation should work
970
- result = sf(x)
971
- self.assertEqual(result.shape, (2,))
972
-
973
- def test_stateful_function_multiple_calls(self):
974
- """Test multiple calls to StatefulFunction."""
975
- state = brainstate.State(jnp.array([0.0]))
976
-
977
- def f(x):
978
- state.value += x
979
- return state.value
980
-
981
- sf = brainstate.transform.StatefulFunction(f)
982
-
983
- # Multiple calls should accumulate state
984
- result1 = sf(jnp.array([1.0]))
985
- result2 = sf(jnp.array([2.0]))
986
- result3 = sf(jnp.array([3.0]))
987
-
988
- # Each call should update the state
989
- self.assertIsNotNone(result1)
990
- self.assertIsNotNone(result2)
991
- self.assertIsNotNone(result3)
992
-
993
-
994
- class TestStatefulFunctionStaticArgs(unittest.TestCase):
995
- """Test StatefulFunction with static arguments."""
996
-
997
- def test_static_argnums_basic(self):
998
- """Test basic usage of static_argnums."""
999
- state = brainstate.State(jnp.array([1.0, 2.0]))
1000
-
1001
- def f(x, multiplier):
1002
- state.value += x
1003
- return state.value * multiplier
1004
-
1005
- sf = brainstate.transform.StatefulFunction(f, static_argnums=(1,))
1006
- x = jnp.array([0.5, 0.5])
1007
-
1008
- # Compile with multiplier=2
1009
- sf.make_jaxpr(x, 2)
1010
- cache_key1 = sf.get_arg_cache_key(x, 2)
1011
-
1012
- # Compile with multiplier=3
1013
- sf.make_jaxpr(x, 3)
1014
- cache_key2 = sf.get_arg_cache_key(x, 3)
1015
-
1016
- # Should have different cache keys
1017
- self.assertNotEqual(cache_key1, cache_key2)
1018
-
1019
- def test_static_argnames_basic(self):
1020
- """Test basic usage of static_argnames."""
1021
- state = brainstate.State(jnp.array([1.0, 2.0]))
1022
-
1023
- def f(x, multiplier=2):
1024
- state.value += x
1025
- return state.value * multiplier
1026
-
1027
- sf = brainstate.transform.StatefulFunction(f, static_argnames='multiplier')
1028
- x = jnp.array([0.5, 0.5])
1029
-
1030
- # Compile with different multiplier values
1031
- sf.make_jaxpr(x, multiplier=2)
1032
- cache_key1 = sf.get_arg_cache_key(x, multiplier=2)
1033
-
1034
- sf.make_jaxpr(x, multiplier=3)
1035
- cache_key2 = sf.get_arg_cache_key(x, multiplier=3)
1036
-
1037
- # Should have different cache keys
1038
- self.assertNotEqual(cache_key1, cache_key2)
1039
-
1040
- def test_static_args_combination(self):
1041
- """Test using both static_argnums and static_argnames."""
1042
- state = brainstate.State(jnp.array([1.0]))
1043
-
1044
- def f(x, multiplier, offset=0):
1045
- state.value += x
1046
- return state.value * multiplier + offset
1047
-
1048
- sf = brainstate.transform.StatefulFunction(
1049
- f, static_argnums=(1,), static_argnames='offset'
1050
- )
1051
- x = jnp.array([0.5])
1052
-
1053
- # Compile with different static args
1054
- sf.make_jaxpr(x, 2, offset=0)
1055
- cache_key1 = sf.get_arg_cache_key(x, 2, offset=0)
1056
-
1057
- sf.make_jaxpr(x, 3, offset=1)
1058
- cache_key2 = sf.get_arg_cache_key(x, 3, offset=1)
1059
-
1060
- # Should have different cache keys
1061
- self.assertNotEqual(cache_key1, cache_key2)
1062
-
1063
-
1064
- class TestStatefulFunctionComplexStates(unittest.TestCase):
1065
- """Test StatefulFunction with complex state scenarios."""
1066
-
1067
- def test_multiple_states(self):
1068
- """Test function with multiple states."""
1069
- state1 = brainstate.State(jnp.array([1.0]))
1070
- state2 = brainstate.State(jnp.array([2.0]))
1071
- state3 = brainstate.State(jnp.array([3.0]))
1072
-
1073
- def f(x):
1074
- state1.value += x
1075
- state2.value += x * 2
1076
- state3.value += x * 3
1077
- return state1.value + state2.value + state3.value
1078
-
1079
- sf = brainstate.transform.StatefulFunction(f)
1080
- x = jnp.array([1.0])
1081
- sf.make_jaxpr(x)
1082
-
1083
- cache_key = sf.get_arg_cache_key(x)
1084
- states = sf.get_states_by_cache(cache_key)
1085
-
1086
- # Should track all three states
1087
- self.assertEqual(len(states), 3)
1088
-
1089
- def test_nested_state_access(self):
1090
- """Test function with nested state access patterns."""
1091
- outer_state = brainstate.State(jnp.array([1.0]))
1092
- inner_state = brainstate.State(jnp.array([2.0]))
1093
-
1094
- def inner_fn(x):
1095
- inner_state.value += x
1096
- return inner_state.value
1097
-
1098
- def outer_fn(x):
1099
- outer_state.value += x
1100
- result = inner_fn(x)
1101
- return outer_state.value + result
1102
-
1103
- sf = brainstate.transform.StatefulFunction(outer_fn)
1104
- x = jnp.array([1.0])
1105
- sf.make_jaxpr(x)
1106
-
1107
- cache_key = sf.get_arg_cache_key(x)
1108
- states = sf.get_states_by_cache(cache_key)
1109
-
1110
- # Should track both states
1111
- self.assertGreaterEqual(len(states), 2)
1112
-
1113
- def test_conditional_state_write(self):
1114
- """Test function that conditionally writes to states."""
1115
- state1 = brainstate.State(jnp.array([1.0]))
1116
- state2 = brainstate.State(jnp.array([2.0]))
1117
-
1118
- def f(x, write_state1=True):
1119
- # Note: In JAX, actual control flow needs special handling
1120
- # This test is more about the framework's ability to track states
1121
- state1.value += x # Always write to state1
1122
- state2.value += x * 2 # Always write to state2
1123
- return state1.value + state2.value
1124
-
1125
- sf = brainstate.transform.StatefulFunction(f, static_argnames='write_state1')
1126
- x = jnp.array([1.0])
1127
- sf.make_jaxpr(x, write_state1=True)
1128
-
1129
- cache_key = sf.get_arg_cache_key(x, write_state1=True)
1130
- states = sf.get_states_by_cache(cache_key)
1131
-
1132
- # Should track states
1133
- self.assertGreaterEqual(len(states), 2)
1134
-
1135
-
1136
- class TestStatefulFunctionOutputShapes(unittest.TestCase):
1137
- """Test StatefulFunction output shape tracking."""
1138
-
1139
- def test_single_output(self):
1140
- """Test tracking single output shape."""
1141
- state = brainstate.State(jnp.array([1.0, 2.0, 3.0]))
1142
-
1143
- def f(x):
1144
- state.value += x
1145
- return state.value
1146
-
1147
- sf = brainstate.transform.StatefulFunction(f)
1148
- x = jnp.array([1.0, 2.0, 3.0])
1149
- sf.make_jaxpr(x)
1150
-
1151
- cache_key = sf.get_arg_cache_key(x)
1152
- out_shapes = sf.get_out_shapes_by_cache(cache_key)
1153
-
1154
- # Should have output shapes
1155
- self.assertIsNotNone(out_shapes)
1156
-
1157
- def test_multiple_outputs(self):
1158
- """Test tracking multiple output shapes."""
1159
- state = brainstate.State(jnp.array([1.0, 2.0]))
1160
-
1161
- def f(x):
1162
- state.value += x
1163
- return state.value, state.value * 2, jnp.sum(state.value)
1164
-
1165
- sf = brainstate.transform.StatefulFunction(f)
1166
- x = jnp.array([1.0, 2.0])
1167
- sf.make_jaxpr(x)
1168
-
1169
- cache_key = sf.get_arg_cache_key(x)
1170
- out_shapes = sf.get_out_shapes_by_cache(cache_key)
1171
-
1172
- # Should track all output shapes
1173
- self.assertIsNotNone(out_shapes)
1174
-
1175
- def test_nested_output_structure(self):
1176
- """Test tracking nested output structures."""
1177
- state = brainstate.State(jnp.array([1.0, 2.0]))
1178
-
1179
- def f(x):
1180
- state.value += x
1181
- return {
1182
- 'sum': jnp.sum(state.value),
1183
- 'prod': jnp.prod(state.value),
1184
- 'values': state.value
1185
- }
1186
-
1187
- sf = brainstate.transform.StatefulFunction(f)
1188
- x = jnp.array([1.0, 2.0])
1189
- sf.make_jaxpr(x)
1190
-
1191
- cache_key = sf.get_arg_cache_key(x)
1192
- out_treedef = sf.get_out_treedef_by_cache(cache_key)
1193
-
1194
- # Should have tree definition
1195
- self.assertIsNotNone(out_treedef)
1196
-
1197
-
1198
- class TestStatefulFunctionJaxprCall(unittest.TestCase):
1199
- """Test jaxpr_call and jaxpr_call_auto methods."""
1200
-
1201
- def test_jaxpr_call_basic(self):
1202
- """Test basic jaxpr_call usage."""
1203
- state = brainstate.State(jnp.array([1.0, 2.0]))
1204
-
1205
- def f(x):
1206
- state.value += x
1207
- return state.value * 2
1208
-
1209
- sf = brainstate.transform.StatefulFunction(f)
1210
- x = jnp.array([0.5, 0.5])
1211
- sf.make_jaxpr(x)
1212
-
1213
- # Get current state values
1214
- state_vals = [state.value]
1215
-
1216
- # Call at jaxpr level
1217
- new_state_vals, out = sf.jaxpr_call(state_vals, x)
1218
-
1219
- self.assertEqual(len(new_state_vals), 1)
1220
- self.assertEqual(out.shape, (2,))
1221
-
1222
- def test_jaxpr_call_auto_basic(self):
1223
- """Test basic jaxpr_call_auto usage."""
1224
- state = brainstate.State(jnp.array([1.0, 2.0]))
1225
-
1226
- def f(x):
1227
- state.value += x
1228
- return state.value * 2
1229
-
1230
- sf = brainstate.transform.StatefulFunction(f)
1231
- x = jnp.array([0.5, 0.5])
1232
- sf.make_jaxpr(x)
1233
-
1234
- # Call with automatic state management
1235
- result = sf.jaxpr_call_auto(x)
1236
-
1237
- self.assertEqual(result.shape, (2,))
1238
-
1239
- def test_jaxpr_call_preserves_state_order(self):
1240
- """Test that jaxpr_call preserves state order."""
1241
- state1 = brainstate.State(jnp.array([1.0]))
1242
- state2 = brainstate.State(jnp.array([2.0]))
1243
- state3 = brainstate.State(jnp.array([3.0]))
1244
-
1245
- def f(x):
1246
- state1.value += x
1247
- state2.value += x * 2
1248
- state3.value += x * 3
1249
- return state1.value + state2.value + state3.value
1250
-
1251
- sf = brainstate.transform.StatefulFunction(f)
1252
- x = jnp.array([1.0])
1253
- sf.make_jaxpr(x)
1254
-
1255
- cache_key = sf.get_arg_cache_key(x)
1256
- states = sf.get_states_by_cache(cache_key)
1257
-
1258
- # Get initial state values
1259
- state_vals = [s.value for s in states]
1260
-
1261
- # Call at jaxpr level
1262
- new_state_vals, _ = sf.jaxpr_call(state_vals, x)
1263
-
1264
- # Should return same number of states
1265
- self.assertEqual(len(new_state_vals), len(state_vals))
1266
-
1267
-
1268
- class TestStatefulFunctionEdgeCases(unittest.TestCase):
1269
- """Test edge cases and corner scenarios."""
1270
-
1271
- def test_no_state_function(self):
1272
- """Test function that doesn't use any states."""
1273
-
1274
- def f(x):
1275
- return x * 2 + 1
1276
-
1277
- sf = brainstate.transform.StatefulFunction(f)
1278
- x = jnp.array([1.0, 2.0])
1279
- sf.make_jaxpr(x)
1280
-
1281
- cache_key = sf.get_arg_cache_key(x)
1282
- states = sf.get_states_by_cache(cache_key)
1283
-
1284
- # Should have no states
1285
- self.assertEqual(len(states), 0)
1286
-
1287
- def test_read_only_state(self):
1288
- """Test function that only reads from states."""
1289
- state = brainstate.State(jnp.array([1.0, 2.0]))
1290
-
1291
- def f(x):
1292
- # Only read from state, don't write
1293
- return state.value + x
1294
-
1295
- sf = brainstate.transform.StatefulFunction(f, return_only_write=True)
1296
- x = jnp.array([1.0, 2.0])
1297
- sf.make_jaxpr(x)
1298
-
1299
- cache_key = sf.get_arg_cache_key(x)
1300
- write_states = sf.get_write_states_by_cache(cache_key)
1301
-
1302
- # Should have no write states
1303
- self.assertEqual(len(write_states), 0)
1304
-
1305
- def test_scalar_inputs_outputs(self):
1306
- """Test with scalar inputs and outputs."""
1307
- state = brainstate.State(jnp.array(1.0))
1308
-
1309
- def f(x):
1310
- state.value += x
1311
- return state.value
1312
-
1313
- sf = brainstate.transform.StatefulFunction(f)
1314
- x = jnp.array(0.5)
1315
- sf.make_jaxpr(x)
1316
-
1317
- cache_key = sf.get_arg_cache_key(x)
1318
- jaxpr = sf.get_jaxpr_by_cache(cache_key)
1319
-
1320
- # Should compile successfully
1321
- self.assertIsNotNone(jaxpr)
1322
-
1323
- def test_empty_function(self):
1324
- """Test function with no operations."""
1325
-
1326
- def f(x):
1327
- return x
1328
-
1329
- sf = brainstate.transform.StatefulFunction(f)
1330
- x = jnp.array([1.0, 2.0])
1331
- sf.make_jaxpr(x)
1332
-
1333
- cache_key = sf.get_arg_cache_key(x)
1334
- jaxpr = sf.get_jaxpr_by_cache(cache_key)
1335
-
1336
- # Should compile successfully
1337
- self.assertIsNotNone(jaxpr)
1338
-
1339
- def test_complex_dtype(self):
1340
- """Test with complex dtype arrays."""
1341
- state = brainstate.State(jnp.array([1.0 + 2.0j, 3.0 + 4.0j]))
1342
-
1343
- def f(x):
1344
- state.value += x
1345
- return state.value
1346
-
1347
- sf = brainstate.transform.StatefulFunction(f)
1348
- x = jnp.array([0.5 + 0.5j, 0.5 + 0.5j])
1349
- sf.make_jaxpr(x)
1350
-
1351
- cache_key = sf.get_arg_cache_key(x)
1352
- jaxpr = sf.get_jaxpr_by_cache(cache_key)
1353
-
1354
- # Should compile successfully
1355
- self.assertIsNotNone(jaxpr)
1356
-
1357
-
1358
- class TestStatefulFunctionCacheKey(unittest.TestCase):
1359
- """Test cache key generation and behavior."""
1360
-
1361
- def test_cache_key_different_shapes(self):
1362
- """Test that different input shapes produce different cache keys."""
1363
-
1364
- def f(x):
1365
- return x * 2
1366
-
1367
- sf = brainstate.transform.StatefulFunction(f)
1368
-
1369
- x1 = jnp.array([1.0, 2.0])
1370
- x2 = jnp.array([1.0, 2.0, 3.0])
1371
-
1372
- cache_key1 = sf.get_arg_cache_key(x1)
1373
- cache_key2 = sf.get_arg_cache_key(x2)
1374
-
1375
- # Should have different cache keys
1376
- self.assertNotEqual(cache_key1, cache_key2)
1377
-
1378
- def test_cache_key_different_dtypes(self):
1379
- """Test that different dtypes produce different cache keys."""
1380
-
1381
- def f(x):
1382
- return x * 2
1383
-
1384
- sf = brainstate.transform.StatefulFunction(f)
1385
-
1386
- # Use int32 and float32 instead, which are always available in JAX
1387
- x1 = jnp.array([1.0, 2.0], dtype=jnp.float32)
1388
- x2 = jnp.array([1, 2], dtype=jnp.int32)
1389
-
1390
- cache_key1 = sf.get_arg_cache_key(x1)
1391
- cache_key2 = sf.get_arg_cache_key(x2)
1392
-
1393
- # Should have different cache keys due to different dtypes
1394
- self.assertNotEqual(cache_key1, cache_key2)
1395
-
1396
- def test_cache_key_same_abstract_values(self):
1397
- """Test that same abstract values produce same cache keys."""
1398
-
1399
- def f(x):
1400
- return x * 2
1401
-
1402
- sf = brainstate.transform.StatefulFunction(f)
1403
-
1404
- x1 = jnp.array([1.0, 2.0])
1405
- x2 = jnp.array([3.0, 4.0]) # Different values, same shape/dtype
1406
-
1407
- cache_key1 = sf.get_arg_cache_key(x1)
1408
- cache_key2 = sf.get_arg_cache_key(x2)
1409
-
1410
- # Should have same cache keys (abstract values are the same)
1411
- self.assertEqual(cache_key1, cache_key2)
1412
-
1413
- def test_cache_key_with_pytree_inputs(self):
1414
- """Test cache key generation with pytree inputs."""
1415
-
1416
- def f(inputs):
1417
- x, y = inputs
1418
- return x + y
1419
-
1420
- sf = brainstate.transform.StatefulFunction(f)
1421
-
1422
- inputs1 = (jnp.array([1.0]), jnp.array([2.0]))
1423
- inputs2 = (jnp.array([3.0]), jnp.array([4.0]))
1424
-
1425
- cache_key1 = sf.get_arg_cache_key(inputs1)
1426
- cache_key2 = sf.get_arg_cache_key(inputs2)
1427
-
1428
- # Should have same cache keys (same structure/shapes)
1429
- self.assertEqual(cache_key1, cache_key2)
1430
-
1431
-
1432
- class TestStatefulFunctionRecompilation(unittest.TestCase):
1433
- """Test recompilation scenarios."""
1434
-
1435
- def test_cache_reuse(self):
1436
- """Test that cache is reused for same inputs."""
1437
- state = brainstate.State(jnp.array([1.0]))
1438
-
1439
- def f(x):
1440
- state.value += x
1441
- return state.value
1442
-
1443
- sf = brainstate.transform.StatefulFunction(f)
1444
-
1445
- x = jnp.array([1.0])
1446
-
1447
- # First compilation
1448
- sf.make_jaxpr(x)
1449
- stats1 = sf.get_cache_stats()
1450
-
1451
- # Second call with same shape should reuse cache
1452
- sf.make_jaxpr(x)
1453
- stats2 = sf.get_cache_stats()
1454
-
1455
- # Cache size should remain the same
1456
- self.assertEqual(
1457
- stats1['jaxpr_cache']['size'],
1458
- stats2['jaxpr_cache']['size']
1459
- )
1460
-
1461
- def test_multiple_compilations_different_shapes(self):
1462
- """Test multiple compilations with different shapes."""
1463
- state = brainstate.State(jnp.array([1.0]))
1464
-
1465
- def f(x):
1466
- return x * 2
1467
-
1468
- sf = brainstate.transform.StatefulFunction(f)
1469
-
1470
- # Compile for different shapes
1471
- shapes = [
1472
- jnp.array([1.0]),
1473
- jnp.array([1.0, 2.0]),
1474
- jnp.array([1.0, 2.0, 3.0]),
1475
- ]
1476
-
1477
- for x in shapes:
1478
- sf.make_jaxpr(x)
1479
-
1480
- stats = sf.get_cache_stats()
1481
-
1482
- # Should have 3 different cache entries
1483
- self.assertEqual(stats['jaxpr_cache']['size'], 3)
1484
-
1485
- def test_clear_and_recompile(self):
1486
- """Test clearing cache and recompiling."""
1487
- state = brainstate.State(jnp.array([1.0]))
1488
-
1489
- def f(x):
1490
- state.value += x
1491
- return state.value
1492
-
1493
- sf = brainstate.transform.StatefulFunction(f)
1494
- x = jnp.array([1.0])
1495
-
1496
- # Compile
1497
- sf.make_jaxpr(x)
1498
- stats_before = sf.get_cache_stats()
1499
- self.assertGreater(stats_before['jaxpr_cache']['size'], 0)
1500
-
1501
- # Clear cache
1502
- sf.clear_cache()
1503
- stats_after_clear = sf.get_cache_stats()
1504
- self.assertEqual(stats_after_clear['jaxpr_cache']['size'], 0)
1505
-
1506
- # Recompile
1507
- sf.make_jaxpr(x)
1508
- stats_after_recompile = sf.get_cache_stats()
1509
- self.assertGreater(stats_after_recompile['jaxpr_cache']['size'], 0)
1510
-
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ import threading
18
+ import unittest
19
+
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import pytest
23
+
24
+ import brainstate
25
+ from brainstate._error import BatchAxisError
26
+ from brainstate._compatible_import import jaxpr_as_fun
27
+ from brainstate.transform._make_jaxpr import _BoundedCache, make_hashable
28
+
29
+
30
+ class TestMakeJaxpr(unittest.TestCase):
31
+ def test_compar_jax_make_jaxpr(self):
32
+ def func4(arg): # Arg is a pair
33
+ temp = arg[0] + jnp.sin(arg[1]) * 3.
34
+ c = brainstate.random.rand_like(arg[0])
35
+ return jnp.sum(temp + c)
36
+
37
+ key = brainstate.random.DEFAULT.value
38
+ jaxpr = jax.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
39
+ print(jaxpr)
40
+ self.assertTrue(len(jaxpr.in_avals) == 2)
41
+ self.assertTrue(len(jaxpr.consts) == 1)
42
+ self.assertTrue(len(jaxpr.out_avals) == 1)
43
+ self.assertTrue(jnp.allclose(jaxpr.consts[0], key))
44
+
45
+ brainstate.random.seed(1)
46
+ print(brainstate.random.DEFAULT.value)
47
+
48
+ jaxpr2, states = brainstate.transform.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
49
+ print(jaxpr2)
50
+ self.assertTrue(len(jaxpr2.in_avals) == 3)
51
+ self.assertTrue(len(jaxpr2.out_avals) == 2)
52
+ self.assertTrue(len(jaxpr2.consts) == 0)
53
+ print(brainstate.random.DEFAULT.value)
54
+
55
+ def test_StatefulFunction_1(self):
56
+ def func4(arg): # Arg is a pair
57
+ temp = arg[0] + jnp.sin(arg[1]) * 3.
58
+ c = brainstate.random.rand_like(arg[0])
59
+ return jnp.sum(temp + c)
60
+
61
+ fun = brainstate.transform.StatefulFunction(func4).make_jaxpr((jnp.zeros(8), jnp.ones(8)))
62
+ cache_key = fun.get_arg_cache_key((jnp.zeros(8), jnp.ones(8)))
63
+ print(fun.get_states_by_cache(cache_key))
64
+ print(fun.get_jaxpr_by_cache(cache_key))
65
+
66
+ def test_StatefulFunction_2(self):
67
+ st1 = brainstate.State(jnp.ones(10))
68
+
69
+ def f1(x):
70
+ st1.value = x + st1.value
71
+
72
+ def f2(x):
73
+ jaxpr = brainstate.transform.make_jaxpr(f1)(x)
74
+ c = 1. + x
75
+ return c
76
+
77
+ def f3(x):
78
+ jaxpr = brainstate.transform.make_jaxpr(f1)(x)
79
+ c = 1.
80
+ return c
81
+
82
+ print()
83
+ jaxpr = brainstate.transform.make_jaxpr(f1)(jnp.zeros(1))
84
+ print(jaxpr)
85
+ jaxpr = jax.make_jaxpr(f2)(jnp.zeros(1))
86
+ print(jaxpr)
87
+ jaxpr = jax.make_jaxpr(f3)(jnp.zeros(1))
88
+ print(jaxpr)
89
+ jaxpr, _ = brainstate.transform.make_jaxpr(f3)(jnp.zeros(1))
90
+ print(jaxpr)
91
+ self.assertTrue(jnp.allclose(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
92
+ f3(jnp.zeros(1))))
93
+
94
+ def test_compare_jax_make_jaxpr2(self):
95
+ st1 = brainstate.State(jnp.ones(10))
96
+
97
+ def fa(x):
98
+ st1.value = x + st1.value
99
+
100
+ def ffa(x):
101
+ jaxpr, states = brainstate.transform.make_jaxpr(fa)(x)
102
+ c = 1. + x
103
+ return c
104
+
105
+ jaxpr, states = brainstate.transform.make_jaxpr(ffa)(jnp.zeros(1))
106
+ print()
107
+ print(jaxpr)
108
+ print(states)
109
+ print(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value))
110
+ jaxpr = jax.make_jaxpr(ffa)(jnp.zeros(1))
111
+ print(jaxpr)
112
+ print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
113
+
114
+ def test_compare_jax_make_jaxpr3(self):
115
+ def fa(x):
116
+ return 1.
117
+
118
+ jaxpr, states = brainstate.transform.make_jaxpr(fa)(jnp.zeros(1))
119
+ print()
120
+ print(jaxpr)
121
+ print(states)
122
+ # print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
123
+ jaxpr = jax.make_jaxpr(fa)(jnp.zeros(1))
124
+ print(jaxpr)
125
+ # print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
126
+
127
+ def test_static_argnames(self):
128
+ def func4(a, b): # Arg is a pair
129
+ temp = a + jnp.sin(b) * 3.
130
+ c = brainstate.random.rand_like(a)
131
+ return jnp.sum(temp + c)
132
+
133
+ jaxpr, states = brainstate.transform.make_jaxpr(func4, static_argnames='b')(jnp.zeros(8), 1.)
134
+ print()
135
+ print(jaxpr)
136
+ print(states)
137
+
138
+ def test_state_in(self):
139
+ def f(a):
140
+ return a.value
141
+
142
+ with pytest.raises(ValueError):
143
+ brainstate.transform.StatefulFunction(f).make_jaxpr(brainstate.State(1.))
144
+
145
+ def test_state_out(self):
146
+ def f(a):
147
+ return brainstate.State(a)
148
+
149
+ with pytest.raises(ValueError):
150
+ brainstate.transform.StatefulFunction(f).make_jaxpr(1.)
151
+
152
+ def test_return_states(self):
153
+ a = brainstate.State(jnp.ones(3))
154
+
155
+ @brainstate.transform.jit
156
+ def f():
157
+ return a
158
+
159
+ with pytest.raises(ValueError):
160
+ f()
161
+
162
+
163
+ class TestBoundedCache(unittest.TestCase):
164
+ """Test the _BoundedCache class."""
165
+
166
+ def test_cache_basic_operations(self):
167
+ """Test basic get and set operations."""
168
+ cache = _BoundedCache(maxsize=3)
169
+
170
+ # Test set and get
171
+ cache.set('key1', 'value1')
172
+ self.assertEqual(cache.get('key1'), 'value1')
173
+
174
+ # Test default value
175
+ self.assertIsNone(cache.get('nonexistent'))
176
+ self.assertEqual(cache.get('nonexistent', 'default'), 'default')
177
+
178
+ # Test __contains__
179
+ self.assertIn('key1', cache)
180
+ self.assertNotIn('key2', cache)
181
+
182
+ # Test __len__
183
+ self.assertEqual(len(cache), 1)
184
+
185
+ def test_cache_lru_eviction(self):
186
+ """Test LRU eviction when cache is full."""
187
+ cache = _BoundedCache(maxsize=3)
188
+
189
+ # Fill cache
190
+ cache.set('key1', 'value1')
191
+ cache.set('key2', 'value2')
192
+ cache.set('key3', 'value3')
193
+ self.assertEqual(len(cache), 3)
194
+
195
+ # Add one more, should evict key1 (least recently used)
196
+ cache.set('key4', 'value4')
197
+ self.assertEqual(len(cache), 3)
198
+ self.assertNotIn('key1', cache)
199
+ self.assertIn('key4', cache)
200
+
201
+ # Access key2 to make it recently used
202
+ cache.get('key2')
203
+
204
+ # Add another key, should evict key3 (now least recently used)
205
+ cache.set('key5', 'value5')
206
+ self.assertNotIn('key3', cache)
207
+ self.assertIn('key2', cache)
208
+
209
+ def test_cache_update_existing(self):
210
+ """Test updating an existing key."""
211
+ cache = _BoundedCache(maxsize=2)
212
+
213
+ cache.set('key1', 'value1')
214
+ cache.set('key2', 'value2')
215
+
216
+ # Update key1 (should move it to end)
217
+ cache.replace('key1', 'updated_value1')
218
+ self.assertEqual(cache.get('key1'), 'updated_value1')
219
+
220
+ # Add new key, should evict key2 (now LRU)
221
+ cache.set('key3', 'value3')
222
+ self.assertNotIn('key2', cache)
223
+ self.assertIn('key1', cache)
224
+
225
+ def test_cache_statistics(self):
226
+ """Test cache statistics tracking."""
227
+ cache = _BoundedCache(maxsize=5)
228
+
229
+ # Initial stats
230
+ stats = cache.get_stats()
231
+ self.assertEqual(stats['size'], 0)
232
+ self.assertEqual(stats['maxsize'], 5)
233
+ self.assertEqual(stats['hits'], 0)
234
+ self.assertEqual(stats['misses'], 0)
235
+ self.assertEqual(stats['hit_rate'], 0.0)
236
+
237
+ # Add items and test hits/misses
238
+ cache.set('key1', 'value1')
239
+ cache.set('key2', 'value2')
240
+
241
+ # Generate hits
242
+ cache.get('key1') # hit
243
+ cache.get('key1') # hit
244
+ cache.get('key3') # miss
245
+ cache.get('key2') # hit
246
+
247
+ stats = cache.get_stats()
248
+ self.assertEqual(stats['size'], 2)
249
+ self.assertEqual(stats['hits'], 3)
250
+ self.assertEqual(stats['misses'], 1)
251
+ self.assertEqual(stats['hit_rate'], 75.0)
252
+
253
+ def test_cache_clear(self):
254
+ """Test clearing the cache."""
255
+ cache = _BoundedCache(maxsize=5)
256
+
257
+ # Add items
258
+ cache.set('key1', 'value1')
259
+ cache.set('key2', 'value2')
260
+ cache.get('key1') # Generate a hit
261
+
262
+ # Clear cache
263
+ cache.clear()
264
+
265
+ self.assertEqual(len(cache), 0)
266
+ self.assertNotIn('key1', cache)
267
+
268
+ # Check stats are reset
269
+ stats = cache.get_stats()
270
+ self.assertEqual(stats['hits'], 0)
271
+ self.assertEqual(stats['misses'], 0)
272
+
273
+ def test_cache_keys(self):
274
+ """Test getting all cache keys."""
275
+ cache = _BoundedCache(maxsize=5)
276
+
277
+ cache.set('key1', 'value1')
278
+ cache.set('key2', 'value2')
279
+ cache.set('key3', 'value3')
280
+
281
+ keys = cache.keys()
282
+ self.assertEqual(set(keys), {'key1', 'key2', 'key3'})
283
+
284
+ def test_cache_set_duplicate_raises(self):
285
+ """Test that setting an existing key raises ValueError."""
286
+ cache = _BoundedCache(maxsize=5)
287
+
288
+ cache.set('key1', 'value1')
289
+
290
+ # Attempting to set the same key should raise ValueError
291
+ with pytest.raises(ValueError, match="Cache key already exists"):
292
+ cache.set('key1', 'value2')
293
+
294
+ def test_cache_pop(self):
295
+ """Test pop method."""
296
+ cache = _BoundedCache(maxsize=5)
297
+
298
+ cache.set('key1', 'value1')
299
+ cache.set('key2', 'value2')
300
+
301
+ # Pop existing key
302
+ value = cache.pop('key1')
303
+ self.assertEqual(value, 'value1')
304
+ self.assertNotIn('key1', cache)
305
+ self.assertEqual(len(cache), 1)
306
+
307
+ # Pop non-existent key with default
308
+ value = cache.pop('nonexistent', 'default')
309
+ self.assertEqual(value, 'default')
310
+
311
+ # Pop non-existent key without default
312
+ value = cache.pop('nonexistent')
313
+ self.assertIsNone(value)
314
+
315
+ def test_cache_replace(self):
316
+ """Test replace method."""
317
+ cache = _BoundedCache(maxsize=5)
318
+
319
+ cache.set('key1', 'value1')
320
+ cache.set('key2', 'value2')
321
+
322
+ # Replace existing key
323
+ cache.replace('key1', 'new_value1')
324
+ self.assertEqual(cache.get('key1'), 'new_value1')
325
+
326
+ # Replacing should move to end (most recently used)
327
+ cache.set('key3', 'value3')
328
+ cache.replace('key2', 'new_value2')
329
+
330
+ # Add more items to test LRU behavior
331
+ cache.set('key4', 'value4')
332
+ cache.set('key5', 'value5')
333
+
334
+ # Now when we add key6, key1 should be evicted (oldest after replace moved key2 to end)
335
+ cache.set('key6', 'value6')
336
+
337
+ # key2 should still be there because replace moved it to end
338
+ self.assertIn('key2', cache)
339
+
340
+ def test_cache_replace_nonexistent_raises(self):
341
+ """Test that replacing a non-existent key raises KeyError."""
342
+ cache = _BoundedCache(maxsize=5)
343
+
344
+ with pytest.raises(KeyError, match="Cache key does not exist"):
345
+ cache.replace('nonexistent', 'value')
346
+
347
+ def test_cache_get_with_raise_on_miss(self):
348
+ """Test get method with raise_on_miss parameter."""
349
+ cache = _BoundedCache(maxsize=5)
350
+
351
+ cache.set('key1', 'value1')
352
+ cache.set('key2', 'value2')
353
+
354
+ # Should work normally for existing key
355
+ value = cache.get('key1', raise_on_miss=True)
356
+ self.assertEqual(value, 'value1')
357
+
358
+ # Should raise ValueError for missing key with raise_on_miss=True
359
+ with pytest.raises(ValueError, match="not compiled for the requested cache key"):
360
+ cache.get('nonexistent', raise_on_miss=True, error_context="Test item")
361
+
362
+ def test_cache_detailed_error_message(self):
363
+ """Test that error message shows available keys."""
364
+ cache = _BoundedCache(maxsize=5)
365
+
366
+ cache.set('key1', 'value1')
367
+ cache.set('key2', 'value2')
368
+
369
+ # Error should include all available keys
370
+ with pytest.raises(ValueError) as exc_info:
371
+ cache.get('nonexistent', raise_on_miss=True, error_context="Test item")
372
+
373
+ error_msg = str(exc_info.value)
374
+ # Should show requested key
375
+ self.assertIn('nonexistent', error_msg)
376
+ # Should show available keys
377
+ self.assertIn('key1', error_msg)
378
+ self.assertIn('key2', error_msg)
379
+ # Should have helpful message
380
+ self.assertIn('make_jaxpr()', error_msg)
381
+
382
+ def test_cache_error_message_no_keys(self):
383
+ """Test error message when cache is empty."""
384
+ cache = _BoundedCache(maxsize=5)
385
+
386
+ with pytest.raises(ValueError) as exc_info:
387
+ cache.get('key', raise_on_miss=True, error_context="Empty cache")
388
+
389
+ error_msg = str(exc_info.value)
390
+ # Should indicate no keys available
391
+ self.assertIn('none', error_msg.lower())
392
+
393
+ def test_cache_thread_safety(self):
394
+ """Test thread safety of cache operations."""
395
+ cache = _BoundedCache(maxsize=100)
396
+ errors = []
397
+
398
+ def worker(thread_id):
399
+ try:
400
+ for i in range(50):
401
+ key = f'key_{thread_id}_{i}'
402
+ cache.set(key, f'value_{thread_id}_{i}')
403
+ value = cache.get(key)
404
+ if value != f'value_{thread_id}_{i}':
405
+ errors.append(f'Mismatch in thread {thread_id}')
406
+ except Exception as e:
407
+ errors.append(f'Error in thread {thread_id}: {e}')
408
+
409
+ # Create multiple threads
410
+ threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
411
+
412
+ # Start all threads
413
+ for t in threads:
414
+ t.start()
415
+
416
+ # Wait for all threads to complete
417
+ for t in threads:
418
+ t.join()
419
+
420
+ # Check no errors occurred
421
+ self.assertEqual(len(errors), 0, f"Thread safety errors: {errors}")
422
+
423
+
424
class TestStatefulFunctionEnhancements(unittest.TestCase):
    """Cache introspection, state validation and ``return_only_write`` of StatefulFunction."""

    def test_cache_stats(self):
        """``get_cache_stats`` exposes every internal cache with its counters."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value * 2

        sf = brainstate.transform.StatefulFunction(f)

        # Compile for two different inputs so the caches are populated.
        for x in (jnp.array([0.5, 0.5]), jnp.array([1.0, 1.0])):
            sf.make_jaxpr(x)

        stats = sf.get_cache_stats()

        expected_caches = ('jaxpr_cache', 'out_shapes_cache', 'jaxpr_out_tree_cache', 'state_trace_cache')
        for cache_name in expected_caches:
            self.assertIn(cache_name, stats)

        expected_fields = ('size', 'maxsize', 'hits', 'misses', 'hit_rate')
        for cache_stats in stats.values():
            for field_name in expected_fields:
                self.assertIn(field_name, cache_stats)

    def test_validate_states(self):
        """``validate_states`` succeeds for the key of an existing compilation."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])
        sf.make_jaxpr(x)

        key = sf.get_arg_cache_key(x)
        self.assertTrue(sf.validate_states(key))

    def test_validate_all_states(self):
        """``validate_all_states`` returns one truthy result per compiled cache key."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x, n):
            state.value += x
            return state.value * n

        # ``n`` is static, so each distinct value produces its own cache key.
        sf = brainstate.transform.StatefulFunction(f, static_argnums=(1,))
        x = jnp.array([0.5, 0.5])
        for n in (1, 2):
            sf.make_jaxpr(x, n)

        results = sf.validate_all_states()

        self.assertEqual(len(results), 2)
        for is_valid in results.values():
            self.assertTrue(is_valid)

    def test_clear_cache(self):
        """``clear_cache`` empties every internal cache."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value

        sf = brainstate.transform.StatefulFunction(f)
        sf.make_jaxpr(jnp.array([0.5, 0.5]))

        # Sanity check: compilation populated the jaxpr cache.
        self.assertGreater(sf.get_cache_stats()['jaxpr_cache']['size'], 0)

        sf.clear_cache()

        stats_after = sf.get_cache_stats()
        for cache_name in ('jaxpr_cache', 'out_shapes_cache', 'jaxpr_out_tree_cache', 'state_trace_cache'):
            self.assertEqual(stats_after[cache_name]['size'], 0)

    def test_return_only_write_parameter(self):
        """``return_only_write=True`` never tracks more states than the default mode."""
        read_state = brainstate.State(jnp.array([1.0, 2.0]))
        write_state = brainstate.State(jnp.array([3.0, 4.0]))

        def f(x):
            # read_state is only read; write_state is updated.
            _ = read_state.value + x
            write_state.value += x
            return write_state.value

        def tracked_states(return_only_write):
            # Compile a fresh StatefulFunction and return the states it tracks.
            sf = brainstate.transform.StatefulFunction(f, return_only_write=return_only_write)
            sf.make_jaxpr(jnp.array([0.5, 0.5]))
            cache_key = sf.get_arg_cache_key(jnp.array([0.5, 0.5]))
            return sf.get_states_by_cache(cache_key)

        states_all = tracked_states(False)
        states_write = tracked_states(True)

        self.assertLessEqual(len(states_write), len(states_all))
557
+
558
+
559
class TestErrorHandling(unittest.TestCase):
    """Error reporting of StatefulFunction for mismatched states and cache misses."""

    @staticmethod
    def _uncompiled_cache_key():
        """Build a cache key with empty argument tuples.

        None of the compilations performed in these tests use empty arguments,
        so lookups with this key are expected to miss.
        """
        from brainstate.transform._make_jaxpr import hashabledict
        return hashabledict(
            static_args=(),
            dyn_args=(),
            static_kwargs=(),
            dyn_kwargs=(),
        )

    def _assert_detailed_miss_error(self, lookup_name, context):
        """Check the error raised by ``sf.<lookup_name>`` on an uncompiled key.

        Args:
            lookup_name: name of a ``*_by_cache`` getter on StatefulFunction.
            context: text describing what was looked up; the error message
                must contain it alongside the requested key.
        """

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)
        lookup = getattr(sf, lookup_name)
        fake_key = self._uncompiled_cache_key()

        with pytest.raises(ValueError) as exc_info:
            lookup(fake_key)

        error_msg = str(exc_info.value)
        self.assertIn(context, error_msg)
        self.assertIn('Requested key:', error_msg)

    def test_jaxpr_call_state_mismatch(self):
        """``jaxpr_call`` rejects a state-value list of the wrong length."""
        state1 = brainstate.State(jnp.array([1.0, 2.0]))
        state2 = brainstate.State(jnp.array([3.0, 4.0]))

        def f(x):
            state1.value += x
            state2.value += x
            return state1.value + state2.value

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])
        sf.make_jaxpr(x)

        # Two states are tracked, but only one value is supplied.
        with pytest.raises(ValueError, match="State length mismatch"):
            sf.jaxpr_call([jnp.array([1.0, 1.0])], x)

    def test_get_jaxpr_not_compiled_detailed_error(self):
        """A jaxpr cache miss reports the requested key, the available keys and a hint."""

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)

        # Compile for one input so the cache has an available key to list.
        sf.make_jaxpr(jnp.array([1.0, 2.0]))

        fake_key = self._uncompiled_cache_key()
        with pytest.raises(ValueError) as exc_info:
            sf.get_jaxpr_by_cache(fake_key)

        error_msg = str(exc_info.value)
        self.assertIn('Requested key:', error_msg)
        self.assertIn('Available', error_msg)
        self.assertIn('make_jaxpr()', error_msg)

    def test_get_out_shapes_not_compiled_detailed_error(self):
        """An output-shape cache miss is reported with the "Output shapes" context."""
        self._assert_detailed_miss_error('get_out_shapes_by_cache', 'Output shapes')

    def test_get_out_treedef_not_compiled_detailed_error(self):
        """An output-tree cache miss is reported with the "Output tree" context."""
        self._assert_detailed_miss_error('get_out_treedef_by_cache', 'Output tree')

    def test_get_state_trace_not_compiled_detailed_error(self):
        """A state-trace cache miss is reported with the "State trace" context."""
        self._assert_detailed_miss_error('get_state_trace_by_cache', 'State trace')
684
+
685
+
686
class TestCompileIfMiss(unittest.TestCase):
    """``compile_if_miss`` handling of the call-based getters (``get_jaxpr``, ``get_states``, ...)."""

    @staticmethod
    def _make_read_write_function():
        """Return a StatefulFunction that reads one state and writes another."""
        read_state = brainstate.State(jnp.array([1.0, 2.0]))
        write_state = brainstate.State(jnp.array([3.0, 4.0]))

        def f(x):
            _ = read_state.value
            write_state.value += x
            return write_state.value

        return brainstate.transform.StatefulFunction(f)

    def test_get_jaxpr_by_call_with_compile_if_miss_true(self):
        """``get_jaxpr`` compiles on demand when ``compile_if_miss=True``."""

        def f(x):
            return x * 2

        stateful_fn = brainstate.transform.StatefulFunction(f)
        self.assertIsNotNone(stateful_fn.get_jaxpr(jnp.array([1.0, 2.0]), compile_if_miss=True))

    def test_get_jaxpr_by_call_with_compile_if_miss_false(self):
        """``get_jaxpr`` raises instead of compiling when ``compile_if_miss=False``."""

        def f(x):
            return x * 2

        stateful_fn = brainstate.transform.StatefulFunction(f)
        with pytest.raises(ValueError, match="not compiled"):
            stateful_fn.get_jaxpr(jnp.array([1.0, 2.0]), compile_if_miss=False)

    def test_get_out_shapes_by_call_compile_if_miss(self):
        """``get_out_shapes`` compiles on a miss only when allowed to."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value * 2

        stateful_fn = brainstate.transform.StatefulFunction(f)

        self.assertIsNotNone(stateful_fn.get_out_shapes(jnp.array([1.0, 2.0]), compile_if_miss=True))

        # A new input shape is a cache miss; without compilation it must fail.
        with pytest.raises(ValueError):
            stateful_fn.get_out_shapes(jnp.array([1.0, 2.0, 3.0]), compile_if_miss=False)

    def test_get_out_treedef_by_call_compile_if_miss(self):
        """``get_out_treedef`` compiles with the default ``compile_if_miss=True``."""

        def f(x):
            return x * 2, x + 1

        stateful_fn = brainstate.transform.StatefulFunction(f)
        self.assertIsNotNone(stateful_fn.get_out_treedef(jnp.array([1.0, 2.0])))

    def test_get_state_trace_by_call_compile_if_miss(self):
        """``get_state_trace`` compiles on demand when ``compile_if_miss=True``."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value

        stateful_fn = brainstate.transform.StatefulFunction(f)
        self.assertIsNotNone(stateful_fn.get_state_trace(jnp.array([1.0, 2.0]), compile_if_miss=True))

    def test_get_states_by_call_compile_if_miss(self):
        """``get_states`` compiles on demand and reports both touched states."""
        state1 = brainstate.State(jnp.array([1.0, 2.0]))
        state2 = brainstate.State(jnp.array([3.0, 4.0]))

        def f(x):
            state1.value += x
            state2.value += x
            return state1.value + state2.value

        stateful_fn = brainstate.transform.StatefulFunction(f)
        tracked = stateful_fn.get_states(jnp.array([1.0, 2.0]), compile_if_miss=True)
        self.assertEqual(len(tracked), 2)

    def test_get_read_states_by_call_compile_if_miss(self):
        """``get_read_states`` compiles on demand when ``compile_if_miss=True``."""
        stateful_fn = self._make_read_write_function()
        self.assertIsNotNone(stateful_fn.get_read_states(jnp.array([1.0, 2.0]), compile_if_miss=True))

    def test_get_write_states_by_call_compile_if_miss(self):
        """``get_write_states`` compiles on demand when ``compile_if_miss=True``."""
        stateful_fn = self._make_read_write_function()
        self.assertIsNotNone(stateful_fn.get_write_states(jnp.array([1.0, 2.0]), compile_if_miss=True))

    def test_compile_if_miss_default_behavior(self):
        """The getters compile on a miss without an explicit ``compile_if_miss``."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value

        for getter_name in ('get_jaxpr', 'get_out_shapes', 'get_states'):
            # A fresh instance per getter so every call starts with an empty cache.
            stateful_fn = brainstate.transform.StatefulFunction(f)
            getter = getattr(stateful_fn, getter_name)
            self.assertIsNotNone(getter(jnp.array([1.0, 2.0])))
828
+
829
+
830
class TestMakeHashable(unittest.TestCase):
    """``make_hashable`` converts (nested) containers into hashable values."""

    def test_hashable_list(self):
        """A list is converted to a hashable tuple."""
        converted = make_hashable([1, 2, 3])
        self.assertIsInstance(converted, tuple)
        hash(converted)  # must not raise

    def test_hashable_dict(self):
        """A dict is converted to a hashable tuple of entries sorted by key."""
        converted = make_hashable({'b': 2, 'a': 1})
        self.assertIsInstance(converted, tuple)
        hash(converted)  # must not raise
        self.assertEqual([entry[0] for entry in converted], ['a', 'b'])

    def test_hashable_set(self):
        """A set is converted to a frozenset."""
        converted = make_hashable({1, 2, 3})
        self.assertIsInstance(converted, frozenset)
        hash(converted)  # must not raise

    def test_hashable_nested(self):
        """A dict holding a list, a dict and a set becomes hashable as a whole."""
        converted = make_hashable({
            'list': [1, 2, 3],
            'dict': {'a': 1, 'b': 2},
            'set': {4, 5},
        })
        hash(converted)  # must not raise

    def test_hashable_tuple(self):
        """A tuple is returned as a hashable tuple."""
        converted = make_hashable((1, 2, 3))
        self.assertIsInstance(converted, tuple)
        hash(converted)  # must not raise

    def test_hashable_idempotent(self):
        """Converting the same input twice yields equal results (determinism)."""
        source = {'a': [1, 2], 'b': {3, 4}}
        first = make_hashable(source)
        second = make_hashable(source)
        self.assertEqual(first, second)
886
+
887
+
888
class TestCacheCleanupOnError(unittest.TestCase):
    """Caches must not keep partial entries after a failed compilation."""

    def test_cache_cleanup_on_compilation_error(self):
        """A failing ``make_jaxpr`` leaves every StatefulFunction cache empty."""

        def f(x):
            # Python control flow on a traced value cannot be traced by JAX,
            # so compiling this function is expected to raise.
            if x > 0:
                return x * 2
            else:
                return x + 1

        sf = brainstate.transform.StatefulFunction(f)

        # The failure itself is part of the scenario: if compilation ever
        # succeeded, the test would no longer exercise the cleanup path.
        with self.assertRaises(Exception):
            sf.make_jaxpr(jnp.array([1.0]))

        # Whatever stage the error occurred at, no partial entry may survive.
        stats = sf.get_cache_stats()
        for cache_name, cache_stats in stats.items():
            self.assertEqual(
                cache_stats['size'],
                0,
                f"{cache_name} kept an entry after a failed compilation",
            )
915
+
916
+
917
class TestMakeJaxprReturnOnlyWrite(unittest.TestCase):
    """``brainstate.transform.make_jaxpr`` with the ``return_only_write`` option."""

    def test_make_jaxpr_return_only_write(self):
        """Compiling with ``return_only_write=True`` yields a jaxpr and a tuple of states."""
        read_state = brainstate.State(jnp.array([1.0]))
        write_state = brainstate.State(jnp.array([2.0]))

        def f(x):
            _ = read_state.value  # read, never written
            write_state.value += x  # written
            return x * 2

        compile_fn = brainstate.transform.make_jaxpr(f, return_only_write=True)
        jaxpr, states = compile_fn(jnp.array([1.0]))

        self.assertIsNotNone(jaxpr)
        self.assertIsInstance(states, tuple)
937
+
938
+
939
class TestStatefulFunctionCallable(unittest.TestCase):
    """Calling a StatefulFunction directly through ``__call__``."""

    def test_stateful_function_call(self):
        """A pre-compiled StatefulFunction can be called like the wrapped function."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value * 2

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])
        sf.make_jaxpr(x)

        result = sf(x)
        self.assertEqual(result.shape, (2,))

    def test_stateful_function_call_auto_compile(self):
        """``__call__`` compiles on first use when no jaxpr is cached."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value * 2

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])

        # No explicit make_jaxpr call before invoking the function.
        result = sf(x)
        self.assertEqual(result.shape, (2,))

    def test_stateful_function_multiple_calls(self):
        """Each call sees, and updates, the state left behind by the previous one."""
        state = brainstate.State(jnp.array([0.0]))

        def f(x):
            state.value += x
            return state.value

        sf = brainstate.transform.StatefulFunction(f)

        result1 = sf(jnp.array([1.0]))
        result2 = sf(jnp.array([2.0]))
        result3 = sf(jnp.array([3.0]))

        # Running totals starting from 0.0: 0 + 1 = 1, 1 + 2 = 3, 3 + 3 = 6.
        self.assertTrue(jnp.allclose(result1, jnp.array([1.0])))
        self.assertTrue(jnp.allclose(result2, jnp.array([3.0])))
        self.assertTrue(jnp.allclose(result3, jnp.array([6.0])))
        # The final total must also be written back to the state object.
        self.assertTrue(jnp.allclose(state.value, jnp.array([6.0])))
992
+
993
+
994
class TestStatefulFunctionStaticArgs(unittest.TestCase):
    """Static arguments (``static_argnums`` / ``static_argnames``) become part of the cache key."""

    def test_static_argnums_basic(self):
        """Different values of a positional static argument give different cache keys."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x, multiplier):
            state.value += x
            return state.value * multiplier

        sf = brainstate.transform.StatefulFunction(f, static_argnums=(1,))
        x = jnp.array([0.5, 0.5])

        keys = []
        for multiplier in (2, 3):
            sf.make_jaxpr(x, multiplier)
            keys.append(sf.get_arg_cache_key(x, multiplier))

        self.assertNotEqual(keys[0], keys[1])

    def test_static_argnames_basic(self):
        """Different values of a keyword static argument give different cache keys."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x, multiplier=2):
            state.value += x
            return state.value * multiplier

        sf = brainstate.transform.StatefulFunction(f, static_argnames='multiplier')
        x = jnp.array([0.5, 0.5])

        keys = []
        for multiplier in (2, 3):
            sf.make_jaxpr(x, multiplier=multiplier)
            keys.append(sf.get_arg_cache_key(x, multiplier=multiplier))

        self.assertNotEqual(keys[0], keys[1])

    def test_static_args_combination(self):
        """Positional and keyword static arguments can be combined."""
        state = brainstate.State(jnp.array([1.0]))

        def f(x, multiplier, offset=0):
            state.value += x
            return state.value * multiplier + offset

        sf = brainstate.transform.StatefulFunction(
            f, static_argnums=(1,), static_argnames='offset'
        )
        x = jnp.array([0.5])

        keys = []
        for multiplier, offset in ((2, 0), (3, 1)):
            sf.make_jaxpr(x, multiplier, offset=offset)
            keys.append(sf.get_arg_cache_key(x, multiplier, offset=offset))

        self.assertNotEqual(keys[0], keys[1])
1062
+
1063
+
1064
class TestStatefulFunctionComplexStates(unittest.TestCase):
    """State tracking when a function touches several states."""

    def test_multiple_states(self):
        """All three states written by the function are tracked."""
        first = brainstate.State(jnp.array([1.0]))
        second = brainstate.State(jnp.array([2.0]))
        third = brainstate.State(jnp.array([3.0]))

        def f(x):
            first.value += x
            second.value += x * 2
            third.value += x * 3
            return first.value + second.value + third.value

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([1.0])
        sf.make_jaxpr(x)

        tracked = sf.get_states_by_cache(sf.get_arg_cache_key(x))
        self.assertEqual(len(tracked), 3)

    def test_nested_state_access(self):
        """States touched inside a helper called by the function are tracked too."""
        outer_state = brainstate.State(jnp.array([1.0]))
        inner_state = brainstate.State(jnp.array([2.0]))

        def inner_fn(x):
            inner_state.value += x
            return inner_state.value

        def outer_fn(x):
            outer_state.value += x
            inner_result = inner_fn(x)
            return outer_state.value + inner_result

        sf = brainstate.transform.StatefulFunction(outer_fn)
        x = jnp.array([1.0])
        sf.make_jaxpr(x)

        tracked = sf.get_states_by_cache(sf.get_arg_cache_key(x))
        self.assertGreaterEqual(len(tracked), 2)

    def test_conditional_state_write(self):
        """States are tracked when a static flag is passed alongside the input."""
        state1 = brainstate.State(jnp.array([1.0]))
        state2 = brainstate.State(jnp.array([2.0]))

        def f(x, write_state1=True):
            # The flag is static and currently unused: both states are always
            # written. The test targets state tracking, not JAX control flow.
            state1.value += x
            state2.value += x * 2
            return state1.value + state2.value

        sf = brainstate.transform.StatefulFunction(f, static_argnames='write_state1')
        x = jnp.array([1.0])
        sf.make_jaxpr(x, write_state1=True)

        tracked = sf.get_states_by_cache(sf.get_arg_cache_key(x, write_state1=True))
        self.assertGreaterEqual(len(tracked), 2)
1134
+
1135
+
1136
class TestStatefulFunctionOutputShapes(unittest.TestCase):
    """Output shape and output tree information recorded at compilation."""

    @staticmethod
    def _compile(fn, x):
        """Compile ``fn`` for ``x``; return the StatefulFunction and its cache key."""
        sf = brainstate.transform.StatefulFunction(fn)
        sf.make_jaxpr(x)
        return sf, sf.get_arg_cache_key(x)

    def test_single_output(self):
        """Output shapes are recorded for a single-array return value."""
        state = brainstate.State(jnp.array([1.0, 2.0, 3.0]))

        def f(x):
            state.value += x
            return state.value

        sf, key = self._compile(f, jnp.array([1.0, 2.0, 3.0]))
        self.assertIsNotNone(sf.get_out_shapes_by_cache(key))

    def test_multiple_outputs(self):
        """Output shapes are recorded for a tuple of return values."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value, state.value * 2, jnp.sum(state.value)

        sf, key = self._compile(f, jnp.array([1.0, 2.0]))
        self.assertIsNotNone(sf.get_out_shapes_by_cache(key))

    def test_nested_output_structure(self):
        """An output tree definition is recorded for a dict return value."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return {
                'sum': jnp.sum(state.value),
                'prod': jnp.prod(state.value),
                'values': state.value,
            }

        sf, key = self._compile(f, jnp.array([1.0, 2.0]))
        self.assertIsNotNone(sf.get_out_treedef_by_cache(key))
1196
+
1197
+
1198
class TestStatefulFunctionJaxprCall(unittest.TestCase):
    """Jaxpr-level invocation through ``jaxpr_call`` and ``jaxpr_call_auto``."""

    def test_jaxpr_call_basic(self):
        """``jaxpr_call`` takes explicit state values and returns updated ones plus the output."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value * 2

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])
        sf.make_jaxpr(x)

        # The function touches exactly one state, so one value is passed in.
        updated_vals, out = sf.jaxpr_call([state.value], x)

        self.assertEqual(len(updated_vals), 1)
        self.assertEqual(out.shape, (2,))

    def test_jaxpr_call_auto_basic(self):
        """``jaxpr_call_auto`` manages the state values itself."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value * 2

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])
        sf.make_jaxpr(x)

        self.assertEqual(sf.jaxpr_call_auto(x).shape, (2,))

    def test_jaxpr_call_preserves_state_order(self):
        """``jaxpr_call`` returns one updated value per tracked state."""
        state1 = brainstate.State(jnp.array([1.0]))
        state2 = brainstate.State(jnp.array([2.0]))
        state3 = brainstate.State(jnp.array([3.0]))

        def f(x):
            state1.value += x
            state2.value += x * 2
            state3.value += x * 3
            return state1.value + state2.value + state3.value

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([1.0])
        sf.make_jaxpr(x)

        tracked = sf.get_states_by_cache(sf.get_arg_cache_key(x))
        # Feed the values in the order the StatefulFunction tracks the states.
        initial_vals = [tracked_state.value for tracked_state in tracked]

        updated_vals, _ = sf.jaxpr_call(initial_vals, x)

        self.assertEqual(len(updated_vals), len(initial_vals))
1266
+
1267
+
1268
class TestStatefulFunctionEdgeCases(unittest.TestCase):
    """Corner cases: stateless, read-only, scalar, identity and complex-valued functions."""

    @staticmethod
    def _compiled(fn, arg, **options):
        """Wrap ``fn`` in a StatefulFunction, trace it on ``arg``, and return it with its cache key."""
        stateful = brainstate.transform.StatefulFunction(fn, **options)
        stateful.make_jaxpr(arg)
        return stateful, stateful.get_arg_cache_key(arg)

    def test_no_state_function(self):
        """A pure function should capture no states at all."""

        def pure(x):
            return x * 2 + 1

        inp = jnp.array([1.0, 2.0])
        stateful, key = self._compiled(pure, inp)

        captured = stateful.get_states_by_cache(key)
        self.assertEqual(len(captured), 0)

    def test_read_only_state(self):
        """A function that only reads a state should report no written states."""
        counter = brainstate.State(jnp.array([1.0, 2.0]))

        def reader(x):
            # The state is consumed but never assigned.
            return counter.value + x

        inp = jnp.array([1.0, 2.0])
        stateful, key = self._compiled(reader, inp, return_only_write=True)

        written = stateful.get_write_states_by_cache(key)
        self.assertEqual(len(written), 0)

    def test_scalar_inputs_outputs(self):
        """Zero-dimensional inputs and outputs should trace to a jaxpr."""
        acc = brainstate.State(jnp.array(1.0))

        def bump(x):
            acc.value += x
            return acc.value

        inp = jnp.array(0.5)
        stateful, key = self._compiled(bump, inp)

        self.assertIsNotNone(stateful.get_jaxpr_by_cache(key))

    def test_empty_function(self):
        """An identity function should still trace to a jaxpr."""

        def identity(x):
            return x

        inp = jnp.array([1.0, 2.0])
        stateful, key = self._compiled(identity, inp)

        self.assertIsNotNone(stateful.get_jaxpr_by_cache(key))

    def test_complex_dtype(self):
        """Complex-valued states and inputs should trace to a jaxpr."""
        acc = brainstate.State(jnp.array([1.0 + 2.0j, 3.0 + 4.0j]))

        def bump(x):
            acc.value += x
            return acc.value

        inp = jnp.array([0.5 + 0.5j, 0.5 + 0.5j])
        stateful, key = self._compiled(bump, inp)

        self.assertIsNotNone(stateful.get_jaxpr_by_cache(key))
1356
+
1357
+
1358
class TestStatefulFunctionCacheKey(unittest.TestCase):
    """Cache keys must track abstract signatures (shape, dtype, tree structure), not concrete values."""

    @staticmethod
    def _keys_for(fn, first, second):
        """Return the cache keys one StatefulFunction wrapping ``fn`` assigns to two inputs."""
        stateful = brainstate.transform.StatefulFunction(fn)
        return stateful.get_arg_cache_key(first), stateful.get_arg_cache_key(second)

    @staticmethod
    def _double(x):
        return x * 2

    def test_cache_key_different_shapes(self):
        """Inputs that differ only in shape must map to distinct keys."""
        key_a, key_b = self._keys_for(
            self._double,
            jnp.array([1.0, 2.0]),
            jnp.array([1.0, 2.0, 3.0]),
        )
        self.assertNotEqual(key_a, key_b)

    def test_cache_key_different_dtypes(self):
        """Inputs that differ in dtype must map to distinct keys."""
        # float32 and int32 are available on every JAX backend, so the check
        # does not depend on x64 mode being enabled.
        key_a, key_b = self._keys_for(
            self._double,
            jnp.array([1.0, 2.0], dtype=jnp.float32),
            jnp.array([1, 2], dtype=jnp.int32),
        )
        self.assertNotEqual(key_a, key_b)

    def test_cache_key_same_abstract_values(self):
        """Inputs with equal shape and dtype but different values must share a key."""
        key_a, key_b = self._keys_for(
            self._double,
            jnp.array([1.0, 2.0]),
            jnp.array([3.0, 4.0]),
        )
        self.assertEqual(key_a, key_b)

    def test_cache_key_with_pytree_inputs(self):
        """Pytrees with matching structure and leaf signatures must share a key."""

        def add_pair(inputs):
            lhs, rhs = inputs
            return lhs + rhs

        key_a, key_b = self._keys_for(
            add_pair,
            (jnp.array([1.0]), jnp.array([2.0])),
            (jnp.array([3.0]), jnp.array([4.0])),
        )
        self.assertEqual(key_a, key_b)
1430
+
1431
+
1432
class TestStatefulFunctionRecompilation(unittest.TestCase):
    """Test recompilation scenarios."""

    def test_cache_reuse(self):
        """Test that cache is reused for same inputs."""
        state = brainstate.State(jnp.array([1.0]))

        def f(x):
            state.value += x
            return state.value

        sf = brainstate.transform.StatefulFunction(f)

        x = jnp.array([1.0])

        # First compilation creates exactly one entry. Without this check the
        # equality below would also pass if nothing were ever cached.
        sf.make_jaxpr(x)
        stats1 = sf.get_cache_stats()
        self.assertEqual(stats1['jaxpr_cache']['size'], 1)

        # Second call with same shape should reuse cache
        sf.make_jaxpr(x)
        stats2 = sf.get_cache_stats()

        # Cache size should remain the same
        self.assertEqual(
            stats1['jaxpr_cache']['size'],
            stats2['jaxpr_cache']['size']
        )

    def test_multiple_compilations_different_shapes(self):
        """Test multiple compilations with different shapes."""

        def f(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(f)

        # One input per distinct shape
        inputs = [
            jnp.array([1.0]),
            jnp.array([1.0, 2.0]),
            jnp.array([1.0, 2.0, 3.0]),
        ]

        for x in inputs:
            sf.make_jaxpr(x)

        stats = sf.get_cache_stats()

        # Should have one cache entry per distinct shape
        self.assertEqual(stats['jaxpr_cache']['size'], len(inputs))

    def test_clear_and_recompile(self):
        """Test clearing cache and recompiling."""
        state = brainstate.State(jnp.array([1.0]))

        def f(x):
            state.value += x
            return state.value

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([1.0])

        # Compile
        sf.make_jaxpr(x)
        stats_before = sf.get_cache_stats()
        self.assertGreater(stats_before['jaxpr_cache']['size'], 0)

        # Clear cache
        sf.clear_cache()
        stats_after_clear = sf.get_cache_stats()
        self.assertEqual(stats_after_clear['jaxpr_cache']['size'], 0)

        # Recompile
        sf.make_jaxpr(x)
        stats_after_recompile = sf.get_cache_stats()
        self.assertGreater(stats_after_recompile['jaxpr_cache']['size'], 0)
1510
+