brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,1510 +1,1634 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
import threading
|
18
|
-
import unittest
|
19
|
-
|
20
|
-
|
21
|
-
import jax
|
22
|
-
import
|
23
|
-
|
24
|
-
import
|
25
|
-
|
26
|
-
|
27
|
-
from brainstate.
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
self.assertTrue(
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
print(
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
def
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
jaxpr
|
88
|
-
|
89
|
-
jaxpr
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
def
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
print(
|
110
|
-
jaxpr
|
111
|
-
print(
|
112
|
-
print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
jaxpr
|
124
|
-
print(
|
125
|
-
# print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
brainstate.
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
cache
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
self.
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
self.
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
self.
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
cache
|
191
|
-
|
192
|
-
cache
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
#
|
205
|
-
cache.
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
cache
|
215
|
-
|
216
|
-
|
217
|
-
cache.
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
cache.
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
self.assertEqual(stats['
|
235
|
-
self.assertEqual(stats['
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
cache.
|
243
|
-
|
244
|
-
|
245
|
-
cache.get('
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
self.assertEqual(stats['
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
cache
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
cache.
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
cache.
|
278
|
-
cache
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
cache
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
value
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
cache
|
321
|
-
|
322
|
-
|
323
|
-
cache.
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
cache.
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
cache.
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
cache.set('
|
336
|
-
|
337
|
-
#
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
cache
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
cache
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
self.assertIn('
|
379
|
-
# Should
|
380
|
-
self.assertIn('
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
#
|
413
|
-
for
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
#
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
self.assertIn('
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
self.assertIn('
|
459
|
-
self.assertIn('
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
#
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
#
|
500
|
-
|
501
|
-
|
502
|
-
#
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
#
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
self.assertEqual(stats['
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
sf.
|
591
|
-
|
592
|
-
#
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
823
|
-
|
824
|
-
|
825
|
-
|
826
|
-
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
|
846
|
-
|
847
|
-
|
848
|
-
|
849
|
-
|
850
|
-
|
851
|
-
|
852
|
-
|
853
|
-
|
854
|
-
|
855
|
-
|
856
|
-
|
857
|
-
|
858
|
-
|
859
|
-
|
860
|
-
|
861
|
-
|
862
|
-
|
863
|
-
|
864
|
-
|
865
|
-
|
866
|
-
|
867
|
-
|
868
|
-
|
869
|
-
|
870
|
-
|
871
|
-
|
872
|
-
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
877
|
-
|
878
|
-
|
879
|
-
|
880
|
-
|
881
|
-
|
882
|
-
|
883
|
-
|
884
|
-
|
885
|
-
|
886
|
-
|
887
|
-
|
888
|
-
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
897
|
-
|
898
|
-
|
899
|
-
|
900
|
-
|
901
|
-
|
902
|
-
|
903
|
-
|
904
|
-
|
905
|
-
|
906
|
-
|
907
|
-
|
908
|
-
|
909
|
-
|
910
|
-
|
911
|
-
|
912
|
-
#
|
913
|
-
|
914
|
-
#
|
915
|
-
|
916
|
-
|
917
|
-
|
918
|
-
|
919
|
-
|
920
|
-
|
921
|
-
|
922
|
-
|
923
|
-
|
924
|
-
|
925
|
-
|
926
|
-
|
927
|
-
|
928
|
-
|
929
|
-
|
930
|
-
|
931
|
-
|
932
|
-
|
933
|
-
|
934
|
-
|
935
|
-
|
936
|
-
|
937
|
-
|
938
|
-
|
939
|
-
|
940
|
-
|
941
|
-
|
942
|
-
|
943
|
-
|
944
|
-
|
945
|
-
|
946
|
-
|
947
|
-
|
948
|
-
|
949
|
-
|
950
|
-
|
951
|
-
|
952
|
-
|
953
|
-
|
954
|
-
|
955
|
-
|
956
|
-
|
957
|
-
|
958
|
-
|
959
|
-
|
960
|
-
|
961
|
-
|
962
|
-
|
963
|
-
|
964
|
-
|
965
|
-
|
966
|
-
|
967
|
-
|
968
|
-
|
969
|
-
|
970
|
-
|
971
|
-
|
972
|
-
|
973
|
-
|
974
|
-
|
975
|
-
|
976
|
-
|
977
|
-
|
978
|
-
|
979
|
-
|
980
|
-
|
981
|
-
|
982
|
-
|
983
|
-
|
984
|
-
|
985
|
-
|
986
|
-
|
987
|
-
|
988
|
-
|
989
|
-
|
990
|
-
|
991
|
-
|
992
|
-
|
993
|
-
|
994
|
-
|
995
|
-
|
996
|
-
|
997
|
-
|
998
|
-
|
999
|
-
|
1000
|
-
|
1001
|
-
|
1002
|
-
|
1003
|
-
|
1004
|
-
|
1005
|
-
|
1006
|
-
|
1007
|
-
|
1008
|
-
|
1009
|
-
|
1010
|
-
|
1011
|
-
|
1012
|
-
|
1013
|
-
sf.
|
1014
|
-
|
1015
|
-
|
1016
|
-
|
1017
|
-
|
1018
|
-
|
1019
|
-
|
1020
|
-
|
1021
|
-
|
1022
|
-
|
1023
|
-
|
1024
|
-
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1028
|
-
|
1029
|
-
|
1030
|
-
|
1031
|
-
|
1032
|
-
|
1033
|
-
|
1034
|
-
sf.make_jaxpr(x, multiplier=
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
1038
|
-
|
1039
|
-
|
1040
|
-
|
1041
|
-
|
1042
|
-
|
1043
|
-
|
1044
|
-
|
1045
|
-
|
1046
|
-
|
1047
|
-
|
1048
|
-
|
1049
|
-
|
1050
|
-
|
1051
|
-
|
1052
|
-
|
1053
|
-
|
1054
|
-
|
1055
|
-
|
1056
|
-
|
1057
|
-
sf.make_jaxpr(x,
|
1058
|
-
|
1059
|
-
|
1060
|
-
|
1061
|
-
|
1062
|
-
|
1063
|
-
|
1064
|
-
|
1065
|
-
|
1066
|
-
|
1067
|
-
|
1068
|
-
|
1069
|
-
|
1070
|
-
|
1071
|
-
|
1072
|
-
|
1073
|
-
|
1074
|
-
|
1075
|
-
|
1076
|
-
|
1077
|
-
|
1078
|
-
|
1079
|
-
|
1080
|
-
|
1081
|
-
|
1082
|
-
|
1083
|
-
|
1084
|
-
|
1085
|
-
|
1086
|
-
|
1087
|
-
|
1088
|
-
|
1089
|
-
|
1090
|
-
|
1091
|
-
|
1092
|
-
|
1093
|
-
|
1094
|
-
|
1095
|
-
|
1096
|
-
|
1097
|
-
|
1098
|
-
|
1099
|
-
|
1100
|
-
|
1101
|
-
|
1102
|
-
|
1103
|
-
|
1104
|
-
|
1105
|
-
|
1106
|
-
|
1107
|
-
|
1108
|
-
|
1109
|
-
|
1110
|
-
|
1111
|
-
|
1112
|
-
|
1113
|
-
|
1114
|
-
|
1115
|
-
|
1116
|
-
|
1117
|
-
|
1118
|
-
|
1119
|
-
|
1120
|
-
|
1121
|
-
|
1122
|
-
|
1123
|
-
|
1124
|
-
|
1125
|
-
|
1126
|
-
|
1127
|
-
|
1128
|
-
|
1129
|
-
|
1130
|
-
|
1131
|
-
|
1132
|
-
|
1133
|
-
|
1134
|
-
|
1135
|
-
|
1136
|
-
|
1137
|
-
|
1138
|
-
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
1149
|
-
|
1150
|
-
|
1151
|
-
|
1152
|
-
|
1153
|
-
|
1154
|
-
|
1155
|
-
|
1156
|
-
|
1157
|
-
|
1158
|
-
|
1159
|
-
|
1160
|
-
|
1161
|
-
|
1162
|
-
|
1163
|
-
|
1164
|
-
|
1165
|
-
|
1166
|
-
|
1167
|
-
|
1168
|
-
|
1169
|
-
|
1170
|
-
|
1171
|
-
|
1172
|
-
|
1173
|
-
|
1174
|
-
|
1175
|
-
|
1176
|
-
|
1177
|
-
|
1178
|
-
|
1179
|
-
|
1180
|
-
|
1181
|
-
|
1182
|
-
|
1183
|
-
|
1184
|
-
|
1185
|
-
|
1186
|
-
|
1187
|
-
|
1188
|
-
|
1189
|
-
|
1190
|
-
|
1191
|
-
|
1192
|
-
|
1193
|
-
|
1194
|
-
|
1195
|
-
|
1196
|
-
|
1197
|
-
|
1198
|
-
|
1199
|
-
|
1200
|
-
|
1201
|
-
|
1202
|
-
|
1203
|
-
|
1204
|
-
|
1205
|
-
|
1206
|
-
|
1207
|
-
|
1208
|
-
|
1209
|
-
|
1210
|
-
|
1211
|
-
|
1212
|
-
|
1213
|
-
|
1214
|
-
|
1215
|
-
|
1216
|
-
#
|
1217
|
-
|
1218
|
-
|
1219
|
-
|
1220
|
-
|
1221
|
-
|
1222
|
-
|
1223
|
-
|
1224
|
-
|
1225
|
-
|
1226
|
-
|
1227
|
-
|
1228
|
-
|
1229
|
-
|
1230
|
-
|
1231
|
-
|
1232
|
-
|
1233
|
-
|
1234
|
-
|
1235
|
-
|
1236
|
-
|
1237
|
-
|
1238
|
-
|
1239
|
-
|
1240
|
-
|
1241
|
-
|
1242
|
-
|
1243
|
-
|
1244
|
-
|
1245
|
-
|
1246
|
-
|
1247
|
-
|
1248
|
-
|
1249
|
-
|
1250
|
-
|
1251
|
-
|
1252
|
-
|
1253
|
-
|
1254
|
-
|
1255
|
-
|
1256
|
-
|
1257
|
-
|
1258
|
-
|
1259
|
-
|
1260
|
-
|
1261
|
-
#
|
1262
|
-
|
1263
|
-
|
1264
|
-
#
|
1265
|
-
|
1266
|
-
|
1267
|
-
|
1268
|
-
|
1269
|
-
|
1270
|
-
|
1271
|
-
|
1272
|
-
|
1273
|
-
|
1274
|
-
|
1275
|
-
|
1276
|
-
|
1277
|
-
|
1278
|
-
|
1279
|
-
|
1280
|
-
|
1281
|
-
|
1282
|
-
|
1283
|
-
|
1284
|
-
|
1285
|
-
|
1286
|
-
|
1287
|
-
|
1288
|
-
|
1289
|
-
|
1290
|
-
|
1291
|
-
|
1292
|
-
|
1293
|
-
|
1294
|
-
|
1295
|
-
|
1296
|
-
|
1297
|
-
|
1298
|
-
|
1299
|
-
|
1300
|
-
|
1301
|
-
|
1302
|
-
|
1303
|
-
|
1304
|
-
|
1305
|
-
|
1306
|
-
|
1307
|
-
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1313
|
-
|
1314
|
-
|
1315
|
-
|
1316
|
-
|
1317
|
-
|
1318
|
-
|
1319
|
-
|
1320
|
-
|
1321
|
-
|
1322
|
-
|
1323
|
-
|
1324
|
-
|
1325
|
-
|
1326
|
-
|
1327
|
-
|
1328
|
-
|
1329
|
-
|
1330
|
-
|
1331
|
-
|
1332
|
-
|
1333
|
-
|
1334
|
-
|
1335
|
-
|
1336
|
-
|
1337
|
-
|
1338
|
-
|
1339
|
-
|
1340
|
-
|
1341
|
-
|
1342
|
-
|
1343
|
-
|
1344
|
-
|
1345
|
-
|
1346
|
-
|
1347
|
-
|
1348
|
-
|
1349
|
-
|
1350
|
-
|
1351
|
-
|
1352
|
-
|
1353
|
-
|
1354
|
-
|
1355
|
-
|
1356
|
-
|
1357
|
-
|
1358
|
-
|
1359
|
-
|
1360
|
-
|
1361
|
-
|
1362
|
-
|
1363
|
-
|
1364
|
-
|
1365
|
-
|
1366
|
-
|
1367
|
-
|
1368
|
-
|
1369
|
-
|
1370
|
-
|
1371
|
-
|
1372
|
-
|
1373
|
-
|
1374
|
-
|
1375
|
-
|
1376
|
-
|
1377
|
-
|
1378
|
-
|
1379
|
-
|
1380
|
-
|
1381
|
-
|
1382
|
-
|
1383
|
-
|
1384
|
-
|
1385
|
-
|
1386
|
-
|
1387
|
-
|
1388
|
-
|
1389
|
-
|
1390
|
-
|
1391
|
-
|
1392
|
-
|
1393
|
-
|
1394
|
-
|
1395
|
-
|
1396
|
-
|
1397
|
-
|
1398
|
-
|
1399
|
-
|
1400
|
-
|
1401
|
-
|
1402
|
-
|
1403
|
-
|
1404
|
-
|
1405
|
-
|
1406
|
-
|
1407
|
-
|
1408
|
-
|
1409
|
-
|
1410
|
-
|
1411
|
-
|
1412
|
-
|
1413
|
-
|
1414
|
-
|
1415
|
-
|
1416
|
-
|
1417
|
-
|
1418
|
-
|
1419
|
-
|
1420
|
-
|
1421
|
-
|
1422
|
-
|
1423
|
-
|
1424
|
-
|
1425
|
-
|
1426
|
-
|
1427
|
-
|
1428
|
-
|
1429
|
-
|
1430
|
-
|
1431
|
-
|
1432
|
-
|
1433
|
-
|
1434
|
-
|
1435
|
-
|
1436
|
-
|
1437
|
-
|
1438
|
-
|
1439
|
-
|
1440
|
-
|
1441
|
-
|
1442
|
-
|
1443
|
-
|
1444
|
-
|
1445
|
-
|
1446
|
-
|
1447
|
-
|
1448
|
-
|
1449
|
-
|
1450
|
-
|
1451
|
-
|
1452
|
-
sf.
|
1453
|
-
|
1454
|
-
|
1455
|
-
|
1456
|
-
|
1457
|
-
|
1458
|
-
|
1459
|
-
|
1460
|
-
|
1461
|
-
|
1462
|
-
|
1463
|
-
|
1464
|
-
|
1465
|
-
|
1466
|
-
|
1467
|
-
|
1468
|
-
|
1469
|
-
|
1470
|
-
|
1471
|
-
|
1472
|
-
|
1473
|
-
|
1474
|
-
|
1475
|
-
|
1476
|
-
|
1477
|
-
|
1478
|
-
|
1479
|
-
|
1480
|
-
|
1481
|
-
|
1482
|
-
|
1483
|
-
|
1484
|
-
|
1485
|
-
|
1486
|
-
|
1487
|
-
|
1488
|
-
|
1489
|
-
|
1490
|
-
|
1491
|
-
|
1492
|
-
|
1493
|
-
|
1494
|
-
|
1495
|
-
|
1496
|
-
|
1497
|
-
|
1498
|
-
|
1499
|
-
|
1500
|
-
|
1501
|
-
|
1502
|
-
|
1503
|
-
|
1504
|
-
|
1505
|
-
|
1506
|
-
|
1507
|
-
|
1508
|
-
|
1509
|
-
|
1510
|
-
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
import threading
|
18
|
+
import unittest
|
19
|
+
import warnings
|
20
|
+
|
21
|
+
import jax
|
22
|
+
import jax.numpy as jnp
|
23
|
+
import jax.random as jr
|
24
|
+
import pytest
|
25
|
+
|
26
|
+
import brainstate
|
27
|
+
from brainstate._compatible_import import jaxpr_as_fun
|
28
|
+
from brainstate._error import BatchAxisError
|
29
|
+
from brainstate.transform._make_jaxpr import _BoundedCache, make_hashable
|
30
|
+
from brainstate.util import filter as state_filter
|
31
|
+
|
32
|
+
|
33
|
+
class TestMakeJaxpr(unittest.TestCase):
|
34
|
+
def test_compar_jax_make_jaxpr(self):
|
35
|
+
def func4(arg): # Arg is a pair
|
36
|
+
temp = arg[0] + jnp.sin(arg[1]) * 3.
|
37
|
+
c = brainstate.random.rand_like(arg[0])
|
38
|
+
return jnp.sum(temp + c)
|
39
|
+
|
40
|
+
key = brainstate.random.DEFAULT.value
|
41
|
+
jaxpr = jax.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
|
42
|
+
print(jaxpr)
|
43
|
+
self.assertTrue(len(jaxpr.in_avals) == 2)
|
44
|
+
self.assertTrue(len(jaxpr.consts) == 1)
|
45
|
+
self.assertTrue(len(jaxpr.out_avals) == 1)
|
46
|
+
self.assertTrue(jnp.allclose(jaxpr.consts[0], key))
|
47
|
+
|
48
|
+
brainstate.random.seed(1)
|
49
|
+
print(brainstate.random.DEFAULT.value)
|
50
|
+
|
51
|
+
jaxpr2, states = brainstate.transform.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
|
52
|
+
print(jaxpr2)
|
53
|
+
self.assertTrue(len(jaxpr2.in_avals) == 3)
|
54
|
+
self.assertTrue(len(jaxpr2.out_avals) == 2)
|
55
|
+
self.assertTrue(len(jaxpr2.consts) == 0)
|
56
|
+
print(brainstate.random.DEFAULT.value)
|
57
|
+
|
58
|
+
def test_StatefulFunction_1(self):
|
59
|
+
def func4(arg): # Arg is a pair
|
60
|
+
temp = arg[0] + jnp.sin(arg[1]) * 3.
|
61
|
+
c = brainstate.random.rand_like(arg[0])
|
62
|
+
return jnp.sum(temp + c)
|
63
|
+
|
64
|
+
fun = brainstate.transform.StatefulFunction(func4).make_jaxpr((jnp.zeros(8), jnp.ones(8)))
|
65
|
+
cache_key = fun.get_arg_cache_key((jnp.zeros(8), jnp.ones(8)))
|
66
|
+
print(fun.get_states_by_cache(cache_key))
|
67
|
+
print(fun.get_jaxpr_by_cache(cache_key))
|
68
|
+
|
69
|
+
def test_StatefulFunction_2(self):
|
70
|
+
st1 = brainstate.State(jnp.ones(10))
|
71
|
+
|
72
|
+
def f1(x):
|
73
|
+
st1.value = x + st1.value
|
74
|
+
|
75
|
+
def f2(x):
|
76
|
+
jaxpr = brainstate.transform.make_jaxpr(f1)(x)
|
77
|
+
c = 1. + x
|
78
|
+
return c
|
79
|
+
|
80
|
+
def f3(x):
|
81
|
+
jaxpr = brainstate.transform.make_jaxpr(f1)(x)
|
82
|
+
c = 1.
|
83
|
+
return c
|
84
|
+
|
85
|
+
print()
|
86
|
+
jaxpr = brainstate.transform.make_jaxpr(f1)(jnp.zeros(1))
|
87
|
+
print(jaxpr)
|
88
|
+
jaxpr = jax.make_jaxpr(f2)(jnp.zeros(1))
|
89
|
+
print(jaxpr)
|
90
|
+
jaxpr = jax.make_jaxpr(f3)(jnp.zeros(1))
|
91
|
+
print(jaxpr)
|
92
|
+
jaxpr, _ = brainstate.transform.make_jaxpr(f3)(jnp.zeros(1))
|
93
|
+
print(jaxpr)
|
94
|
+
self.assertTrue(jnp.allclose(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
|
95
|
+
f3(jnp.zeros(1))))
|
96
|
+
|
97
|
+
def test_compare_jax_make_jaxpr2(self):
|
98
|
+
st1 = brainstate.State(jnp.ones(10))
|
99
|
+
|
100
|
+
def fa(x):
|
101
|
+
st1.value = x + st1.value
|
102
|
+
|
103
|
+
def ffa(x):
|
104
|
+
jaxpr, states = brainstate.transform.make_jaxpr(fa)(x)
|
105
|
+
c = 1. + x
|
106
|
+
return c
|
107
|
+
|
108
|
+
jaxpr, states = brainstate.transform.make_jaxpr(ffa)(jnp.zeros(1))
|
109
|
+
print()
|
110
|
+
print(jaxpr)
|
111
|
+
print(states)
|
112
|
+
print(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value))
|
113
|
+
jaxpr = jax.make_jaxpr(ffa)(jnp.zeros(1))
|
114
|
+
print(jaxpr)
|
115
|
+
print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
116
|
+
|
117
|
+
def test_compare_jax_make_jaxpr3(self):
|
118
|
+
def fa(x):
|
119
|
+
return 1.
|
120
|
+
|
121
|
+
jaxpr, states = brainstate.transform.make_jaxpr(fa)(jnp.zeros(1))
|
122
|
+
print()
|
123
|
+
print(jaxpr)
|
124
|
+
print(states)
|
125
|
+
# print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
126
|
+
jaxpr = jax.make_jaxpr(fa)(jnp.zeros(1))
|
127
|
+
print(jaxpr)
|
128
|
+
# print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
129
|
+
|
130
|
+
def test_static_argnames(self):
|
131
|
+
def func4(a, b): # Arg is a pair
|
132
|
+
temp = a + jnp.sin(b) * 3.
|
133
|
+
c = brainstate.random.rand_like(a)
|
134
|
+
return jnp.sum(temp + c)
|
135
|
+
|
136
|
+
jaxpr, states = brainstate.transform.make_jaxpr(func4, static_argnames='b')(jnp.zeros(8), 1.)
|
137
|
+
print()
|
138
|
+
print(jaxpr)
|
139
|
+
print(states)
|
140
|
+
|
141
|
+
def test_state_in(self):
|
142
|
+
def f(a):
|
143
|
+
return a.value
|
144
|
+
|
145
|
+
with pytest.raises(ValueError):
|
146
|
+
brainstate.transform.StatefulFunction(f).make_jaxpr(brainstate.State(1.))
|
147
|
+
|
148
|
+
def test_state_out(self):
|
149
|
+
def f(a):
|
150
|
+
return brainstate.State(a)
|
151
|
+
|
152
|
+
with pytest.raises(ValueError):
|
153
|
+
brainstate.transform.StatefulFunction(f).make_jaxpr(1.)
|
154
|
+
|
155
|
+
def test_return_states(self):
|
156
|
+
a = brainstate.State(jnp.ones(3))
|
157
|
+
|
158
|
+
@brainstate.transform.jit
|
159
|
+
def f():
|
160
|
+
return a
|
161
|
+
|
162
|
+
with pytest.raises(ValueError):
|
163
|
+
f()
|
164
|
+
|
165
|
+
|
166
|
+
class TestBoundedCache(unittest.TestCase):
|
167
|
+
"""Test the _BoundedCache class."""
|
168
|
+
|
169
|
+
def test_cache_basic_operations(self):
|
170
|
+
"""Test basic get and set operations."""
|
171
|
+
cache = _BoundedCache(maxsize=3)
|
172
|
+
|
173
|
+
# Test set and get
|
174
|
+
cache.set('key1', 'value1')
|
175
|
+
self.assertEqual(cache.get('key1'), 'value1')
|
176
|
+
|
177
|
+
# Test default value
|
178
|
+
self.assertIsNone(cache.get('nonexistent'))
|
179
|
+
self.assertEqual(cache.get('nonexistent', 'default'), 'default')
|
180
|
+
|
181
|
+
# Test __contains__
|
182
|
+
self.assertIn('key1', cache)
|
183
|
+
self.assertNotIn('key2', cache)
|
184
|
+
|
185
|
+
# Test __len__
|
186
|
+
self.assertEqual(len(cache), 1)
|
187
|
+
|
188
|
+
def test_cache_lru_eviction(self):
|
189
|
+
"""Test LRU eviction when cache is full."""
|
190
|
+
cache = _BoundedCache(maxsize=3)
|
191
|
+
|
192
|
+
# Fill cache
|
193
|
+
cache.set('key1', 'value1')
|
194
|
+
cache.set('key2', 'value2')
|
195
|
+
cache.set('key3', 'value3')
|
196
|
+
self.assertEqual(len(cache), 3)
|
197
|
+
|
198
|
+
# Add one more, should evict key1 (least recently used)
|
199
|
+
cache.set('key4', 'value4')
|
200
|
+
self.assertEqual(len(cache), 3)
|
201
|
+
self.assertNotIn('key1', cache)
|
202
|
+
self.assertIn('key4', cache)
|
203
|
+
|
204
|
+
# Access key2 to make it recently used
|
205
|
+
cache.get('key2')
|
206
|
+
|
207
|
+
# Add another key, should evict key3 (now least recently used)
|
208
|
+
cache.set('key5', 'value5')
|
209
|
+
self.assertNotIn('key3', cache)
|
210
|
+
self.assertIn('key2', cache)
|
211
|
+
|
212
|
+
def test_cache_update_existing(self):
|
213
|
+
"""Test updating an existing key."""
|
214
|
+
cache = _BoundedCache(maxsize=2)
|
215
|
+
|
216
|
+
cache.set('key1', 'value1')
|
217
|
+
cache.set('key2', 'value2')
|
218
|
+
|
219
|
+
# Update key1 (should move it to end)
|
220
|
+
cache.replace('key1', 'updated_value1')
|
221
|
+
self.assertEqual(cache.get('key1'), 'updated_value1')
|
222
|
+
|
223
|
+
# Add new key, should evict key2 (now LRU)
|
224
|
+
cache.set('key3', 'value3')
|
225
|
+
self.assertNotIn('key2', cache)
|
226
|
+
self.assertIn('key1', cache)
|
227
|
+
|
228
|
+
def test_cache_statistics(self):
|
229
|
+
"""Test cache statistics tracking."""
|
230
|
+
cache = _BoundedCache(maxsize=5)
|
231
|
+
|
232
|
+
# Initial stats
|
233
|
+
stats = cache.get_stats()
|
234
|
+
self.assertEqual(stats['size'], 0)
|
235
|
+
self.assertEqual(stats['maxsize'], 5)
|
236
|
+
self.assertEqual(stats['hits'], 0)
|
237
|
+
self.assertEqual(stats['misses'], 0)
|
238
|
+
self.assertEqual(stats['hit_rate'], 0.0)
|
239
|
+
|
240
|
+
# Add items and test hits/misses
|
241
|
+
cache.set('key1', 'value1')
|
242
|
+
cache.set('key2', 'value2')
|
243
|
+
|
244
|
+
# Generate hits
|
245
|
+
cache.get('key1') # hit
|
246
|
+
cache.get('key1') # hit
|
247
|
+
cache.get('key3') # miss
|
248
|
+
cache.get('key2') # hit
|
249
|
+
|
250
|
+
stats = cache.get_stats()
|
251
|
+
self.assertEqual(stats['size'], 2)
|
252
|
+
self.assertEqual(stats['hits'], 3)
|
253
|
+
self.assertEqual(stats['misses'], 1)
|
254
|
+
self.assertEqual(stats['hit_rate'], 75.0)
|
255
|
+
|
256
|
+
def test_cache_clear(self):
    """``clear`` empties the cache and zeroes its hit/miss counters."""
    cache = _BoundedCache(maxsize=5)

    for k, v in (('key1', 'value1'), ('key2', 'value2')):
        cache.set(k, v)
    cache.get('key1')  # one hit, so the counter reset below is observable

    cache.clear()

    self.assertEqual(len(cache), 0)
    self.assertNotIn('key1', cache)

    counters = cache.get_stats()
    self.assertEqual(counters['hits'], 0)
    self.assertEqual(counters['misses'], 0)
|
275
|
+
|
276
|
+
def test_cache_keys(self):
    """``keys`` exposes every key currently stored in the cache."""
    cache = _BoundedCache(maxsize=5)

    for idx in (1, 2, 3):
        cache.set(f'key{idx}', f'value{idx}')

    self.assertEqual(set(cache.keys()), {'key1', 'key2', 'key3'})
|
286
|
+
|
287
|
+
def test_cache_set_duplicate_raises(self):
    """``set`` on a key that is already cached raises ``ValueError``."""
    cache = _BoundedCache(maxsize=5)
    cache.set('key1', 'value1')

    # A second ``set`` for the same key is rejected.
    with self.assertRaisesRegex(ValueError, "Cache key already exists"):
        cache.set('key1', 'value2')
|
296
|
+
|
297
|
+
def test_cache_pop(self):
    """``pop`` removes and returns an entry, or returns a default on a miss."""
    cache = _BoundedCache(maxsize=5)
    cache.set('key1', 'value1')
    cache.set('key2', 'value2')

    # A present key: its value comes back and the cache shrinks by one.
    self.assertEqual(cache.pop('key1'), 'value1')
    self.assertNotIn('key1', cache)
    self.assertEqual(len(cache), 1)

    # A missing key with an explicit default yields that default ...
    self.assertEqual(cache.pop('nonexistent', 'default'), 'default')
    # ... and without a default it yields None instead of raising.
    self.assertIsNone(cache.pop('nonexistent'))
|
317
|
+
|
318
|
+
def test_cache_replace(self):
    """``replace`` overwrites a value and promotes the key to most-recently-used."""
    cache = _BoundedCache(maxsize=5)

    cache.set('key1', 'value1')
    cache.set('key2', 'value2')

    # Replacing an existing key updates its stored value.
    cache.replace('key1', 'new_value1')
    self.assertEqual(cache.get('key1'), 'new_value1')

    # LRU order is now key2, key1. Add key3, then replace key2 so it moves
    # to the MRU end: key1, key3, key2.
    cache.set('key3', 'value3')
    cache.replace('key2', 'new_value2')

    # Fill to capacity: key1, key3, key2, key4, key5.
    cache.set('key4', 'value4')
    cache.set('key5', 'value5')

    # key6 overflows maxsize=5, so the LRU entry (key1) must be evicted.
    cache.set('key6', 'value6')

    self.assertNotIn('key1', cache)
    # key2 survives only because ``replace`` promoted it; without that it
    # would have been the LRU entry and been evicted instead.
    self.assertIn('key2', cache)
    self.assertEqual(cache.get('key2'), 'new_value2')
    self.assertEqual(len(cache), 5)
|
342
|
+
|
343
|
+
def test_cache_replace_nonexistent_raises(self):
    """``replace`` on a key that was never stored raises ``KeyError``."""
    empty = _BoundedCache(maxsize=5)

    with self.assertRaisesRegex(KeyError, "Cache key does not exist"):
        empty.replace('nonexistent', 'value')
|
349
|
+
|
350
|
+
def test_cache_get_with_raise_on_miss(self):
    """With ``raise_on_miss=True``, ``get`` returns hits and raises on misses."""
    cache = _BoundedCache(maxsize=5)
    for k, v in (('key1', 'value1'), ('key2', 'value2')):
        cache.set(k, v)

    # A present key behaves like a plain lookup.
    self.assertEqual(cache.get('key1', raise_on_miss=True), 'value1')

    # An absent key raises ValueError explaining it was not compiled.
    with self.assertRaisesRegex(ValueError, "not compiled for the requested cache key"):
        cache.get('nonexistent', raise_on_miss=True, error_context="Test item")
|
364
|
+
|
365
|
+
def test_cache_detailed_error_message(self):
    """The miss error names the requested key, the stored keys and ``make_jaxpr()``."""
    cache = _BoundedCache(maxsize=5)
    cache.set('key1', 'value1')
    cache.set('key2', 'value2')

    with self.assertRaises(ValueError) as ctx:
        cache.get('nonexistent', raise_on_miss=True, error_context="Test item")
    message = str(ctx.exception)

    # Expected fragments: the requested key, every stored key, and the hint.
    for fragment in ('nonexistent', 'key1', 'key2', 'make_jaxpr()'):
        self.assertIn(fragment, message)
|
384
|
+
|
385
|
+
def test_cache_error_message_no_keys(self):
    """On an empty cache, the miss error says no keys are available."""
    cache = _BoundedCache(maxsize=5)

    with self.assertRaises(ValueError) as ctx:
        cache.get('key', raise_on_miss=True, error_context="Empty cache")

    # Compared case-insensitively so "None" and "none" both pass.
    self.assertIn('none', str(ctx.exception).lower())
|
395
|
+
|
396
|
+
def test_cache_thread_safety(self):
    """Concurrent ``set``/``get`` from several threads never returns a wrong value.

    Five threads each insert 50 distinct keys into a cache bounded at 100, so
    evictions happen while other threads are reading. A key written by one
    thread may legitimately be evicted by the others before its own ``get``
    runs. That lookup is then a miss, assumed to return ``None`` (TODO confirm
    against ``_BoundedCache.get``). Failures are:
      - a *different* value being returned,
      - any exception, or
      - the cache exceeding its bound.
    """
    maxsize = 100
    num_threads = 5
    per_thread = 50
    cache = _BoundedCache(maxsize=maxsize)
    errors = []

    def worker(thread_id):
        try:
            for i in range(per_thread):
                key = f'key_{thread_id}_{i}'
                expected = f'value_{thread_id}_{i}'
                cache.set(key, expected)
                value = cache.get(key)
                # None means concurrent inserts evicted the key (correct LRU
                # behaviour under load); any other mismatch indicates a race.
                if value is not None and value != expected:
                    errors.append(f'Mismatch in thread {thread_id}')
        except Exception as e:
            errors.append(f'Error in thread {thread_id}: {e}')

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_threads)]

    for t in threads:
        t.start()
    for t in threads:
        t.join()

    self.assertEqual(len(errors), 0, f"Thread safety errors: {errors}")
    # The capacity bound must hold even under concurrent insertion.
    self.assertLessEqual(len(cache), maxsize)
|
425
|
+
|
426
|
+
|
427
|
+
class TestStatefulFunctionEnhancements(unittest.TestCase):
    """Cache-introspection and state-validation helpers of ``StatefulFunction``."""

    def test_cache_stats(self):
        """``get_cache_stats`` reports every internal cache with a complete record."""
        counter = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x):
            counter.value += x
            return counter.value * 2

        sf = brainstate.transform.StatefulFunction(step)

        # Trace twice so the caches have seen some traffic.
        inputs = (jnp.array([0.5, 0.5]), jnp.array([1.0, 1.0]))
        for inp in inputs:
            sf.make_jaxpr(inp)

        stats = sf.get_cache_stats()

        for cache_name in ('jaxpr_cache', 'out_shapes_cache',
                           'jaxpr_out_tree_cache', 'state_trace_cache'):
            self.assertIn(cache_name, stats)

        required = ('size', 'maxsize', 'hits', 'misses', 'hit_rate')
        for per_cache in stats.values():
            for field in required:
                self.assertIn(field, per_cache)

    def test_validate_states(self):
        """``validate_states`` returns True for a freshly compiled cache key."""
        counter = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x):
            counter.value += x
            return counter.value

        sf = brainstate.transform.StatefulFunction(step)
        inp = jnp.array([0.5, 0.5])
        sf.make_jaxpr(inp)

        key = sf.get_arg_cache_key(inp)
        self.assertTrue(sf.validate_states(key))

    def test_validate_all_states(self):
        """``validate_all_states`` yields one passing result per compiled cache key."""
        counter = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x, n):
            counter.value += x
            return counter.value * n

        # The static second argument forces a separate cache entry per value.
        sf = brainstate.transform.StatefulFunction(step, static_argnums=(1,))
        inp = jnp.array([0.5, 0.5])

        for n in (1, 2):
            sf.make_jaxpr(inp, n)

        outcome = sf.validate_all_states()

        self.assertEqual(len(outcome), 2)
        for ok in outcome.values():
            self.assertTrue(ok)

    def test_clear_cache(self):
        """``clear_cache`` empties every internal cache."""
        counter = brainstate.State(jnp.array([1.0, 2.0]))

        def step(x):
            counter.value += x
            return counter.value

        sf = brainstate.transform.StatefulFunction(step)
        sf.make_jaxpr(jnp.array([0.5, 0.5]))

        self.assertGreater(sf.get_cache_stats()['jaxpr_cache']['size'], 0)

        sf.clear_cache()

        emptied = sf.get_cache_stats()
        for cache_name in ('jaxpr_cache', 'out_shapes_cache',
                           'jaxpr_out_tree_cache', 'state_trace_cache'):
            self.assertEqual(emptied[cache_name]['size'], 0)

    def test_return_only_write_parameter(self):
        """``return_only_write=True`` never tracks more states than the default."""
        read_state = brainstate.State(jnp.array([1.0, 2.0]))
        write_state = brainstate.State(jnp.array([3.0, 4.0]))

        def step(x):
            # read_state is only read; write_state is read and written.
            _ = read_state.value + x
            write_state.value += x
            return write_state.value

        tracked = {}
        for only_write in (False, True):
            sf = brainstate.transform.StatefulFunction(step, return_only_write=only_write)
            sf.make_jaxpr(jnp.array([0.5, 0.5]))
            key = sf.get_arg_cache_key(jnp.array([0.5, 0.5]))
            tracked[only_write] = sf.get_states_by_cache(key)

        self.assertLessEqual(len(tracked[True]), len(tracked[False]))
|
560
|
+
|
561
|
+
|
562
|
+
class TestErrorHandling(unittest.TestCase):
    """Error reporting of ``StatefulFunction`` for bad calls and cache misses."""

    @staticmethod
    def _empty_cache_key():
        """Return a cache key whose argument groups are all empty.

        The functions under test take one positional argument, so this key is
        not expected to match any compiled entry.
        """
        from brainstate.transform._make_jaxpr import hashabledict
        return hashabledict(
            static_args=(),
            dyn_args=(),
            static_kwargs=(),
            dyn_kwargs=(),
        )

    @staticmethod
    def _doubler():
        """Return a stateless ``StatefulFunction`` computing ``x * 2``."""

        def f(x):
            return x * 2

        return brainstate.transform.StatefulFunction(f)

    def test_jaxpr_call_state_mismatch(self):
        """``jaxpr_call`` rejects a state-value list of the wrong length."""
        state1 = brainstate.State(jnp.array([1.0, 2.0]))
        state2 = brainstate.State(jnp.array([3.0, 4.0]))

        def f(x):
            state1.value += x
            state2.value += x
            return state1.value + state2.value

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])
        sf.make_jaxpr(x)

        # Two states were traced, but only one value is supplied.
        with self.assertRaisesRegex(ValueError, "State length mismatch"):
            sf.jaxpr_call([jnp.array([1.0, 1.0])], x)

    def test_get_jaxpr_not_compiled_detailed_error(self):
        """A jaxpr cache miss reports the requested key, available keys and a hint."""
        sf = self._doubler()

        # Compile one entry so the "available keys" part is non-empty.
        sf.make_jaxpr(jnp.array([1.0, 2.0]))

        with self.assertRaises(ValueError) as ctx:
            sf.get_jaxpr_by_cache(self._empty_cache_key())

        error_msg = str(ctx.exception)
        self.assertIn('Requested key:', error_msg)
        self.assertIn('Available', error_msg)
        self.assertIn('make_jaxpr()', error_msg)

    def test_get_out_shapes_not_compiled_detailed_error(self):
        """An output-shapes cache miss names its context and the requested key."""
        sf = self._doubler()

        with self.assertRaises(ValueError) as ctx:
            sf.get_out_shapes_by_cache(self._empty_cache_key())

        error_msg = str(ctx.exception)
        self.assertIn('Output shapes', error_msg)
        self.assertIn('Requested key:', error_msg)

    def test_get_out_treedef_not_compiled_detailed_error(self):
        """An output-tree cache miss names its context and the requested key."""
        sf = self._doubler()

        with self.assertRaises(ValueError) as ctx:
            sf.get_out_treedef_by_cache(self._empty_cache_key())

        error_msg = str(ctx.exception)
        self.assertIn('Output tree', error_msg)
        self.assertIn('Requested key:', error_msg)

    def test_get_state_trace_not_compiled_detailed_error(self):
        """A state-trace cache miss names its context and the requested key."""
        sf = self._doubler()

        with self.assertRaises(ValueError) as ctx:
            sf.get_state_trace_by_cache(self._empty_cache_key())

        error_msg = str(ctx.exception)
        self.assertIn('State trace', error_msg)
        self.assertIn('Requested key:', error_msg)
|
687
|
+
|
688
|
+
|
689
|
+
class TestCompileIfMiss(unittest.TestCase):
    """``compile_if_miss`` handling of the call-argument ``get_*`` accessors."""

    def test_get_jaxpr_by_call_with_compile_if_miss_true(self):
        """With ``compile_if_miss=True`` the jaxpr is traced on demand."""

        def double(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(double)
        self.assertIsNotNone(sf.get_jaxpr(jnp.array([1.0, 2.0]), compile_if_miss=True))

    def test_get_jaxpr_by_call_with_compile_if_miss_false(self):
        """With ``compile_if_miss=False`` an untraced input raises ``ValueError``."""

        def double(x):
            return x * 2

        sf = brainstate.transform.StatefulFunction(double)
        with self.assertRaisesRegex(ValueError, "not compiled"):
            sf.get_jaxpr(jnp.array([1.0, 2.0]), compile_if_miss=False)

    def test_get_out_shapes_by_call_compile_if_miss(self):
        """``get_out_shapes`` compiles on demand, but only when allowed to."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            acc.value += x
            return acc.value * 2

        sf = brainstate.transform.StatefulFunction(f)

        self.assertIsNotNone(sf.get_out_shapes(jnp.array([1.0, 2.0]), compile_if_miss=True))

        # The length-3 input was never traced, so refusing to compile must fail.
        with self.assertRaises(ValueError):
            sf.get_out_shapes(jnp.array([1.0, 2.0, 3.0]), compile_if_miss=False)

    def test_get_out_treedef_by_call_compile_if_miss(self):
        """``get_out_treedef`` compiles on demand with the default setting."""

        def pair(x):
            return x * 2, x + 1

        sf = brainstate.transform.StatefulFunction(pair)
        self.assertIsNotNone(sf.get_out_treedef(jnp.array([1.0, 2.0])))

    def test_get_state_trace_by_call_compile_if_miss(self):
        """``get_state_trace`` compiles on demand."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            acc.value += x
            return acc.value

        sf = brainstate.transform.StatefulFunction(f)
        self.assertIsNotNone(sf.get_state_trace(jnp.array([1.0, 2.0]), compile_if_miss=True))

    def test_get_states_by_call_compile_if_miss(self):
        """``get_states`` compiles on demand and reports both touched states."""
        first = brainstate.State(jnp.array([1.0, 2.0]))
        second = brainstate.State(jnp.array([3.0, 4.0]))

        def f(x):
            first.value += x
            second.value += x
            return first.value + second.value

        sf = brainstate.transform.StatefulFunction(f)
        touched = sf.get_states(jnp.array([1.0, 2.0]), compile_if_miss=True)
        self.assertEqual(len(touched), 2)

    def test_get_read_states_by_call_compile_if_miss(self):
        """``get_read_states`` compiles on demand."""
        read_state = brainstate.State(jnp.array([1.0, 2.0]))
        write_state = brainstate.State(jnp.array([3.0, 4.0]))

        def f(x):
            _ = read_state.value
            write_state.value += x
            return write_state.value

        sf = brainstate.transform.StatefulFunction(f)
        self.assertIsNotNone(sf.get_read_states(jnp.array([1.0, 2.0]), compile_if_miss=True))

    def test_get_write_states_by_call_compile_if_miss(self):
        """``get_write_states`` compiles on demand."""
        read_state = brainstate.State(jnp.array([1.0, 2.0]))
        write_state = brainstate.State(jnp.array([3.0, 4.0]))

        def f(x):
            _ = read_state.value
            write_state.value += x
            return write_state.value

        sf = brainstate.transform.StatefulFunction(f)
        self.assertIsNotNone(sf.get_write_states(jnp.array([1.0, 2.0]), compile_if_miss=True))

    def test_compile_if_miss_default_behavior(self):
        """Without an explicit flag, the accessors compile on demand."""
        acc = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            acc.value += x
            return acc.value

        # A fresh StatefulFunction per accessor, so each starts with an empty cache.
        for accessor in ('get_jaxpr', 'get_out_shapes', 'get_states'):
            sf = brainstate.transform.StatefulFunction(f)
            self.assertIsNotNone(getattr(sf, accessor)(jnp.array([1.0, 2.0])))
|
831
|
+
|
832
|
+
|
833
|
+
class TestMakeHashable(unittest.TestCase):
    """Tests for the ``make_hashable`` utility."""

    def test_hashable_list(self):
        """A list becomes a hashable tuple."""
        result = make_hashable([1, 2, 3])
        self.assertIsInstance(result, tuple)
        hash(result)

    def test_hashable_dict(self):
        """A dict becomes a hashable tuple of key/value pairs sorted by key."""
        result = make_hashable({'b': 2, 'a': 1})
        self.assertIsInstance(result, tuple)
        hash(result)
        # Sorting makes the result independent of insertion order.
        self.assertEqual([item[0] for item in result], ['a', 'b'])

    def test_hashable_set(self):
        """A set becomes a hashable frozenset."""
        result = make_hashable({1, 2, 3})
        self.assertIsInstance(result, frozenset)
        hash(result)

    def test_hashable_nested(self):
        """Nested lists, dicts and sets are converted recursively into something hashable."""
        nested = {
            'list': [1, 2, 3],
            'dict': {'a': 1, 'b': 2},
            'set': {4, 5},
        }
        result = make_hashable(nested)
        hash(result)  # must not raise

    def test_hashable_tuple(self):
        """A tuple stays a hashable tuple."""
        result = make_hashable((1, 2, 3))
        self.assertIsInstance(result, tuple)
        hash(result)

    def test_hashable_idempotent(self):
        """Equal inputs always convert to equal, equally-hashing results.

        This checks determinism and dict-order independence (repeated
        conversion of equal inputs). It does not re-apply ``make_hashable``
        to its own output.
        """
        original = {'a': [1, 2], 'b': {3, 4}}
        result1 = make_hashable(original)
        result2 = make_hashable(original)

        self.assertEqual(result1, result2)
        self.assertEqual(hash(result1), hash(result2))

        # Same content, different dict insertion order and set literal order:
        # keys are sorted, so the converted form must be identical.
        reordered = {'b': {4, 3}, 'a': [1, 2]}
        self.assertEqual(make_hashable(reordered), result1)
|
889
|
+
|
890
|
+
|
891
|
+
class TestCacheCleanupOnError(unittest.TestCase):
    """A failed compilation must not leave a cached jaxpr behind."""

    def test_cache_cleanup_on_compilation_error(self):
        """A tracing failure inside ``make_jaxpr`` propagates, and no jaxpr is cached."""

        def f(x):
            # A Python ``if`` on a traced array needs a concrete boolean, which
            # JAX cannot provide during tracing, so tracing fails here.
            if x > 0:
                return x * 2
            else:
                return x + 1

        sf = brainstate.transform.StatefulFunction(f)

        # The error must reach the caller instead of being swallowed.
        with self.assertRaises(Exception):
            sf.make_jaxpr(jnp.array([1.0]))

        stats = sf.get_cache_stats()
        # Tracing never completed, so no jaxpr can exist for this input.
        # Other caches (e.g. the state trace) may hold partial entries
        # depending on where the failure occurred, so they are not asserted.
        self.assertEqual(stats['jaxpr_cache']['size'], 0)
|
918
|
+
|
919
|
+
|
920
|
+
class TestMakeJaxprReturnOnlyWrite(unittest.TestCase):
    """``make_jaxpr`` accepts the ``return_only_write`` option."""

    def test_make_jaxpr_return_only_write(self):
        """Tracing with ``return_only_write=True`` yields a jaxpr and a tuple of states."""
        read_only = brainstate.State(jnp.array([1.0]))
        written = brainstate.State(jnp.array([2.0]))

        def f(x):
            _ = read_only.value  # read, never written
            written.value += x   # read and written
            return x * 2

        maker = brainstate.transform.make_jaxpr(f, return_only_write=True)
        jaxpr, states = maker(jnp.array([1.0]))

        self.assertIsNotNone(jaxpr)
        self.assertIsInstance(states, tuple)
|
940
|
+
|
941
|
+
|
942
|
+
class TestStatefulFunctionCallable(unittest.TestCase):
    """Calling a ``StatefulFunction`` instance directly."""

    def test_stateful_function_call(self):
        """A pre-compiled ``StatefulFunction`` is callable and returns the traced output."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value * 2

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])
        sf.make_jaxpr(x)

        result = sf(x)
        self.assertEqual(result.shape, (2,))

    def test_stateful_function_call_auto_compile(self):
        """Calling without a prior ``make_jaxpr`` compiles automatically."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x):
            state.value += x
            return state.value * 2

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([0.5, 0.5])

        result = sf(x)
        self.assertEqual(result.shape, (2,))

    def test_stateful_function_multiple_calls(self):
        """Repeated calls write the updated state back, so increments accumulate."""
        state = brainstate.State(jnp.array([0.0]))

        def f(x):
            state.value += x
            return state.value

        sf = brainstate.transform.StatefulFunction(f)

        # Running totals after adding 1, 2 and 3 to an initial value of 0.
        for increment, expected_total in ((1.0, 1.0), (2.0, 3.0), (3.0, 6.0)):
            result = sf(jnp.array([increment]))
            self.assertIsNotNone(result)
            self.assertAlmostEqual(float(result[0]), expected_total)
            self.assertAlmostEqual(float(state.value[0]), expected_total)
|
995
|
+
|
996
|
+
|
997
|
+
class TestStatefulFunctionStaticArgs(unittest.TestCase):
    """Static arguments take part in the ``StatefulFunction`` cache key."""

    def test_static_argnums_basic(self):
        """Different values of a positional static argument give different cache keys."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x, multiplier):
            state.value += x
            return state.value * multiplier

        sf = brainstate.transform.StatefulFunction(f, static_argnums=(1,))
        x = jnp.array([0.5, 0.5])

        keys = []
        for multiplier in (2, 3):
            sf.make_jaxpr(x, multiplier)
            keys.append(sf.get_arg_cache_key(x, multiplier))

        self.assertNotEqual(keys[0], keys[1])

    def test_static_argnames_basic(self):
        """Different values of a keyword static argument give different cache keys."""
        state = brainstate.State(jnp.array([1.0, 2.0]))

        def f(x, multiplier=2):
            state.value += x
            return state.value * multiplier

        sf = brainstate.transform.StatefulFunction(f, static_argnames='multiplier')
        x = jnp.array([0.5, 0.5])

        keys = []
        for multiplier in (2, 3):
            sf.make_jaxpr(x, multiplier=multiplier)
            keys.append(sf.get_arg_cache_key(x, multiplier=multiplier))

        self.assertNotEqual(keys[0], keys[1])

    def test_static_args_combination(self):
        """``static_argnums`` and ``static_argnames`` can be combined."""
        state = brainstate.State(jnp.array([1.0]))

        def f(x, multiplier, offset=0):
            state.value += x
            return state.value * multiplier + offset

        sf = brainstate.transform.StatefulFunction(
            f, static_argnums=(1,), static_argnames='offset'
        )
        x = jnp.array([0.5])

        keys = []
        for multiplier, offset in ((2, 0), (3, 1)):
            sf.make_jaxpr(x, multiplier, offset=offset)
            keys.append(sf.get_arg_cache_key(x, multiplier, offset=offset))

        self.assertNotEqual(keys[0], keys[1])
|
1065
|
+
|
1066
|
+
|
1067
|
+
class TestStatefulFunctionComplexStates(unittest.TestCase):
    """State tracking of ``StatefulFunction`` in multi-state scenarios."""

    def test_multiple_states(self):
        """Every state touched by the function is tracked."""
        state1 = brainstate.State(jnp.array([1.0]))
        state2 = brainstate.State(jnp.array([2.0]))
        state3 = brainstate.State(jnp.array([3.0]))

        def f(x):
            state1.value += x
            state2.value += x * 2
            state3.value += x * 3
            return state1.value + state2.value + state3.value

        sf = brainstate.transform.StatefulFunction(f)
        x = jnp.array([1.0])
        sf.make_jaxpr(x)

        states = sf.get_states_by_cache(sf.get_arg_cache_key(x))
        self.assertEqual(len(states), 3)

    def test_nested_state_access(self):
        """States touched inside a helper called by the function are tracked too."""
        outer_state = brainstate.State(jnp.array([1.0]))
        inner_state = brainstate.State(jnp.array([2.0]))

        def inner_fn(x):
            inner_state.value += x
            return inner_state.value

        def outer_fn(x):
            outer_state.value += x
            result = inner_fn(x)
            return outer_state.value + result

        sf = brainstate.transform.StatefulFunction(outer_fn)
        x = jnp.array([1.0])
        sf.make_jaxpr(x)

        states = sf.get_states_by_cache(sf.get_arg_cache_key(x))
        self.assertGreaterEqual(len(states), 2)

    def test_conditional_state_write(self):
        """A static flag selects which states the traced program touches."""
        state1 = brainstate.State(jnp.array([1.0]))
        state2 = brainstate.State(jnp.array([2.0]))

        def f(x, write_state1=True):
            # ``write_state1`` is static, so this is plain Python control flow
            # evaluated at trace time: each flag value gets its own cache entry
            # and its own set of tracked states.
            if write_state1:
                state1.value += x
            state2.value += x * 2
            # Return only state2 so state1 is untouched when the flag is off.
            return state2.value

        sf = brainstate.transform.StatefulFunction(f, static_argnames='write_state1')
        x = jnp.array([1.0])

        sf.make_jaxpr(x, write_state1=True)
        states_on = sf.get_states_by_cache(sf.get_arg_cache_key(x, write_state1=True))

        sf.make_jaxpr(x, write_state1=False)
        states_off = sf.get_states_by_cache(sf.get_arg_cache_key(x, write_state1=False))

        # Flag on: state1 and state2. Flag off: only state2.
        self.assertEqual(len(states_on), 2)
        self.assertEqual(len(states_off), 1)
|
1137
|
+
|
1138
|
+
|
1139
|
+
class TestStatefulFunctionOutputShapes(unittest.TestCase):
|
1140
|
+
"""Test StatefulFunction output shape tracking."""
|
1141
|
+
|
1142
|
+
def test_single_output(self):
|
1143
|
+
"""Test tracking single output shape."""
|
1144
|
+
state = brainstate.State(jnp.array([1.0, 2.0, 3.0]))
|
1145
|
+
|
1146
|
+
def f(x):
|
1147
|
+
state.value += x
|
1148
|
+
return state.value
|
1149
|
+
|
1150
|
+
sf = brainstate.transform.StatefulFunction(f)
|
1151
|
+
x = jnp.array([1.0, 2.0, 3.0])
|
1152
|
+
sf.make_jaxpr(x)
|
1153
|
+
|
1154
|
+
cache_key = sf.get_arg_cache_key(x)
|
1155
|
+
out_shapes = sf.get_out_shapes_by_cache(cache_key)
|
1156
|
+
|
1157
|
+
# Should have output shapes
|
1158
|
+
self.assertIsNotNone(out_shapes)
|
1159
|
+
|
1160
|
+
def test_multiple_outputs(self):
|
1161
|
+
"""Test tracking multiple output shapes."""
|
1162
|
+
state = brainstate.State(jnp.array([1.0, 2.0]))
|
1163
|
+
|
1164
|
+
def f(x):
|
1165
|
+
state.value += x
|
1166
|
+
return state.value, state.value * 2, jnp.sum(state.value)
|
1167
|
+
|
1168
|
+
sf = brainstate.transform.StatefulFunction(f)
|
1169
|
+
x = jnp.array([1.0, 2.0])
|
1170
|
+
sf.make_jaxpr(x)
|
1171
|
+
|
1172
|
+
cache_key = sf.get_arg_cache_key(x)
|
1173
|
+
out_shapes = sf.get_out_shapes_by_cache(cache_key)
|
1174
|
+
|
1175
|
+
# Should track all output shapes
|
1176
|
+
self.assertIsNotNone(out_shapes)
|
1177
|
+
|
1178
|
+
def test_nested_output_structure(self):
|
1179
|
+
"""Test tracking nested output structures."""
|
1180
|
+
state = brainstate.State(jnp.array([1.0, 2.0]))
|
1181
|
+
|
1182
|
+
def f(x):
|
1183
|
+
state.value += x
|
1184
|
+
return {
|
1185
|
+
'sum': jnp.sum(state.value),
|
1186
|
+
'prod': jnp.prod(state.value),
|
1187
|
+
'values': state.value
|
1188
|
+
}
|
1189
|
+
|
1190
|
+
sf = brainstate.transform.StatefulFunction(f)
|
1191
|
+
x = jnp.array([1.0, 2.0])
|
1192
|
+
sf.make_jaxpr(x)
|
1193
|
+
|
1194
|
+
cache_key = sf.get_arg_cache_key(x)
|
1195
|
+
out_treedef = sf.get_out_treedef_by_cache(cache_key)
|
1196
|
+
|
1197
|
+
# Should have tree definition
|
1198
|
+
self.assertIsNotNone(out_treedef)
|
1199
|
+
|
1200
|
+
|
1201
|
+
class TestStatefulFunctionJaxprCall(unittest.TestCase):
|
1202
|
+
"""Test jaxpr_call and jaxpr_call_auto methods."""
|
1203
|
+
|
1204
|
+
def test_jaxpr_call_basic(self):
|
1205
|
+
"""Test basic jaxpr_call usage."""
|
1206
|
+
state = brainstate.State(jnp.array([1.0, 2.0]))
|
1207
|
+
|
1208
|
+
def f(x):
|
1209
|
+
state.value += x
|
1210
|
+
return state.value * 2
|
1211
|
+
|
1212
|
+
sf = brainstate.transform.StatefulFunction(f)
|
1213
|
+
x = jnp.array([0.5, 0.5])
|
1214
|
+
sf.make_jaxpr(x)
|
1215
|
+
|
1216
|
+
# Get current state values
|
1217
|
+
state_vals = [state.value]
|
1218
|
+
|
1219
|
+
# Call at jaxpr level
|
1220
|
+
new_state_vals, out = sf.jaxpr_call(state_vals, x)
|
1221
|
+
|
1222
|
+
self.assertEqual(len(new_state_vals), 1)
|
1223
|
+
self.assertEqual(out.shape, (2,))
|
1224
|
+
|
1225
|
+
def test_jaxpr_call_auto_basic(self):
|
1226
|
+
"""Test basic jaxpr_call_auto usage."""
|
1227
|
+
state = brainstate.State(jnp.array([1.0, 2.0]))
|
1228
|
+
|
1229
|
+
def f(x):
|
1230
|
+
state.value += x
|
1231
|
+
return state.value * 2
|
1232
|
+
|
1233
|
+
sf = brainstate.transform.StatefulFunction(f)
|
1234
|
+
x = jnp.array([0.5, 0.5])
|
1235
|
+
sf.make_jaxpr(x)
|
1236
|
+
|
1237
|
+
# Call with automatic state management
|
1238
|
+
result = sf.jaxpr_call_auto(x)
|
1239
|
+
|
1240
|
+
self.assertEqual(result.shape, (2,))
|
1241
|
+
|
1242
|
+
def test_jaxpr_call_preserves_state_order(self):
|
1243
|
+
"""Test that jaxpr_call preserves state order."""
|
1244
|
+
state1 = brainstate.State(jnp.array([1.0]))
|
1245
|
+
state2 = brainstate.State(jnp.array([2.0]))
|
1246
|
+
state3 = brainstate.State(jnp.array([3.0]))
|
1247
|
+
|
1248
|
+
def f(x):
|
1249
|
+
state1.value += x
|
1250
|
+
state2.value += x * 2
|
1251
|
+
state3.value += x * 3
|
1252
|
+
return state1.value + state2.value + state3.value
|
1253
|
+
|
1254
|
+
sf = brainstate.transform.StatefulFunction(f)
|
1255
|
+
x = jnp.array([1.0])
|
1256
|
+
sf.make_jaxpr(x)
|
1257
|
+
|
1258
|
+
cache_key = sf.get_arg_cache_key(x)
|
1259
|
+
states = sf.get_states_by_cache(cache_key)
|
1260
|
+
|
1261
|
+
# Get initial state values
|
1262
|
+
state_vals = [s.value for s in states]
|
1263
|
+
|
1264
|
+
# Call at jaxpr level
|
1265
|
+
new_state_vals, _ = sf.jaxpr_call(state_vals, x)
|
1266
|
+
|
1267
|
+
# Should return same number of states
|
1268
|
+
self.assertEqual(len(new_state_vals), len(state_vals))
|
1269
|
+
|
1270
|
+
|
1271
|
+
class TestStatefulFunctionEdgeCases(unittest.TestCase):
|
1272
|
+
"""Test edge cases and corner scenarios."""
|
1273
|
+
|
1274
|
+
def test_no_state_function(self):
|
1275
|
+
"""Test function that doesn't use any states."""
|
1276
|
+
|
1277
|
+
def f(x):
|
1278
|
+
return x * 2 + 1
|
1279
|
+
|
1280
|
+
sf = brainstate.transform.StatefulFunction(f)
|
1281
|
+
x = jnp.array([1.0, 2.0])
|
1282
|
+
sf.make_jaxpr(x)
|
1283
|
+
|
1284
|
+
cache_key = sf.get_arg_cache_key(x)
|
1285
|
+
states = sf.get_states_by_cache(cache_key)
|
1286
|
+
|
1287
|
+
# Should have no states
|
1288
|
+
self.assertEqual(len(states), 0)
|
1289
|
+
|
1290
|
+
def test_read_only_state(self):
|
1291
|
+
"""Test function that only reads from states."""
|
1292
|
+
state = brainstate.State(jnp.array([1.0, 2.0]))
|
1293
|
+
|
1294
|
+
def f(x):
|
1295
|
+
# Only read from state, don't write
|
1296
|
+
return state.value + x
|
1297
|
+
|
1298
|
+
sf = brainstate.transform.StatefulFunction(f, return_only_write=True)
|
1299
|
+
x = jnp.array([1.0, 2.0])
|
1300
|
+
sf.make_jaxpr(x)
|
1301
|
+
|
1302
|
+
cache_key = sf.get_arg_cache_key(x)
|
1303
|
+
write_states = sf.get_write_states_by_cache(cache_key)
|
1304
|
+
|
1305
|
+
# Should have no write states
|
1306
|
+
self.assertEqual(len(write_states), 0)
|
1307
|
+
|
1308
|
+
def test_scalar_inputs_outputs(self):
|
1309
|
+
"""Test with scalar inputs and outputs."""
|
1310
|
+
state = brainstate.State(jnp.array(1.0))
|
1311
|
+
|
1312
|
+
def f(x):
|
1313
|
+
state.value += x
|
1314
|
+
return state.value
|
1315
|
+
|
1316
|
+
sf = brainstate.transform.StatefulFunction(f)
|
1317
|
+
x = jnp.array(0.5)
|
1318
|
+
sf.make_jaxpr(x)
|
1319
|
+
|
1320
|
+
cache_key = sf.get_arg_cache_key(x)
|
1321
|
+
jaxpr = sf.get_jaxpr_by_cache(cache_key)
|
1322
|
+
|
1323
|
+
# Should compile successfully
|
1324
|
+
self.assertIsNotNone(jaxpr)
|
1325
|
+
|
1326
|
+
def test_empty_function(self):
|
1327
|
+
"""Test function with no operations."""
|
1328
|
+
|
1329
|
+
def f(x):
|
1330
|
+
return x
|
1331
|
+
|
1332
|
+
sf = brainstate.transform.StatefulFunction(f)
|
1333
|
+
x = jnp.array([1.0, 2.0])
|
1334
|
+
sf.make_jaxpr(x)
|
1335
|
+
|
1336
|
+
cache_key = sf.get_arg_cache_key(x)
|
1337
|
+
jaxpr = sf.get_jaxpr_by_cache(cache_key)
|
1338
|
+
|
1339
|
+
# Should compile successfully
|
1340
|
+
self.assertIsNotNone(jaxpr)
|
1341
|
+
|
1342
|
+
def test_complex_dtype(self):
|
1343
|
+
"""Test with complex dtype arrays."""
|
1344
|
+
state = brainstate.State(jnp.array([1.0 + 2.0j, 3.0 + 4.0j]))
|
1345
|
+
|
1346
|
+
def f(x):
|
1347
|
+
state.value += x
|
1348
|
+
return state.value
|
1349
|
+
|
1350
|
+
sf = brainstate.transform.StatefulFunction(f)
|
1351
|
+
x = jnp.array([0.5 + 0.5j, 0.5 + 0.5j])
|
1352
|
+
sf.make_jaxpr(x)
|
1353
|
+
|
1354
|
+
cache_key = sf.get_arg_cache_key(x)
|
1355
|
+
jaxpr = sf.get_jaxpr_by_cache(cache_key)
|
1356
|
+
|
1357
|
+
# Should compile successfully
|
1358
|
+
self.assertIsNotNone(jaxpr)
|
1359
|
+
|
1360
|
+
|
1361
|
+
class TestStatefulFunctionCacheKey(unittest.TestCase):
|
1362
|
+
"""Test cache key generation and behavior."""
|
1363
|
+
|
1364
|
+
def test_cache_key_different_shapes(self):
|
1365
|
+
"""Test that different input shapes produce different cache keys."""
|
1366
|
+
|
1367
|
+
def f(x):
|
1368
|
+
return x * 2
|
1369
|
+
|
1370
|
+
sf = brainstate.transform.StatefulFunction(f)
|
1371
|
+
|
1372
|
+
x1 = jnp.array([1.0, 2.0])
|
1373
|
+
x2 = jnp.array([1.0, 2.0, 3.0])
|
1374
|
+
|
1375
|
+
cache_key1 = sf.get_arg_cache_key(x1)
|
1376
|
+
cache_key2 = sf.get_arg_cache_key(x2)
|
1377
|
+
|
1378
|
+
# Should have different cache keys
|
1379
|
+
self.assertNotEqual(cache_key1, cache_key2)
|
1380
|
+
|
1381
|
+
def test_cache_key_different_dtypes(self):
|
1382
|
+
"""Test that different dtypes produce different cache keys."""
|
1383
|
+
|
1384
|
+
def f(x):
|
1385
|
+
return x * 2
|
1386
|
+
|
1387
|
+
sf = brainstate.transform.StatefulFunction(f)
|
1388
|
+
|
1389
|
+
# Use int32 and float32 instead, which are always available in JAX
|
1390
|
+
x1 = jnp.array([1.0, 2.0], dtype=jnp.float32)
|
1391
|
+
x2 = jnp.array([1, 2], dtype=jnp.int32)
|
1392
|
+
|
1393
|
+
cache_key1 = sf.get_arg_cache_key(x1)
|
1394
|
+
cache_key2 = sf.get_arg_cache_key(x2)
|
1395
|
+
|
1396
|
+
# Should have different cache keys due to different dtypes
|
1397
|
+
self.assertNotEqual(cache_key1, cache_key2)
|
1398
|
+
|
1399
|
+
def test_cache_key_same_abstract_values(self):
|
1400
|
+
"""Test that same abstract values produce same cache keys."""
|
1401
|
+
|
1402
|
+
def f(x):
|
1403
|
+
return x * 2
|
1404
|
+
|
1405
|
+
sf = brainstate.transform.StatefulFunction(f)
|
1406
|
+
|
1407
|
+
x1 = jnp.array([1.0, 2.0])
|
1408
|
+
x2 = jnp.array([3.0, 4.0]) # Different values, same shape/dtype
|
1409
|
+
|
1410
|
+
cache_key1 = sf.get_arg_cache_key(x1)
|
1411
|
+
cache_key2 = sf.get_arg_cache_key(x2)
|
1412
|
+
|
1413
|
+
# Should have same cache keys (abstract values are the same)
|
1414
|
+
self.assertEqual(cache_key1, cache_key2)
|
1415
|
+
|
1416
|
+
def test_cache_key_with_pytree_inputs(self):
|
1417
|
+
"""Test cache key generation with pytree inputs."""
|
1418
|
+
|
1419
|
+
def f(inputs):
|
1420
|
+
x, y = inputs
|
1421
|
+
return x + y
|
1422
|
+
|
1423
|
+
sf = brainstate.transform.StatefulFunction(f)
|
1424
|
+
|
1425
|
+
inputs1 = (jnp.array([1.0]), jnp.array([2.0]))
|
1426
|
+
inputs2 = (jnp.array([3.0]), jnp.array([4.0]))
|
1427
|
+
|
1428
|
+
cache_key1 = sf.get_arg_cache_key(inputs1)
|
1429
|
+
cache_key2 = sf.get_arg_cache_key(inputs2)
|
1430
|
+
|
1431
|
+
# Should have same cache keys (same structure/shapes)
|
1432
|
+
self.assertEqual(cache_key1, cache_key2)
|
1433
|
+
|
1434
|
+
|
1435
|
+
class TestStatefulFunctionRecompilation(unittest.TestCase):
|
1436
|
+
"""Test recompilation scenarios."""
|
1437
|
+
|
1438
|
+
def test_cache_reuse(self):
|
1439
|
+
"""Test that cache is reused for same inputs."""
|
1440
|
+
state = brainstate.State(jnp.array([1.0]))
|
1441
|
+
|
1442
|
+
def f(x):
|
1443
|
+
state.value += x
|
1444
|
+
return state.value
|
1445
|
+
|
1446
|
+
sf = brainstate.transform.StatefulFunction(f)
|
1447
|
+
|
1448
|
+
x = jnp.array([1.0])
|
1449
|
+
|
1450
|
+
# First compilation
|
1451
|
+
sf.make_jaxpr(x)
|
1452
|
+
stats1 = sf.get_cache_stats()
|
1453
|
+
|
1454
|
+
# Second call with same shape should reuse cache
|
1455
|
+
sf.make_jaxpr(x)
|
1456
|
+
stats2 = sf.get_cache_stats()
|
1457
|
+
|
1458
|
+
# Cache size should remain the same
|
1459
|
+
self.assertEqual(
|
1460
|
+
stats1['jaxpr_cache']['size'],
|
1461
|
+
stats2['jaxpr_cache']['size']
|
1462
|
+
)
|
1463
|
+
|
1464
|
+
def test_multiple_compilations_different_shapes(self):
|
1465
|
+
"""Test multiple compilations with different shapes."""
|
1466
|
+
state = brainstate.State(jnp.array([1.0]))
|
1467
|
+
|
1468
|
+
def f(x):
|
1469
|
+
return x * 2
|
1470
|
+
|
1471
|
+
sf = brainstate.transform.StatefulFunction(f)
|
1472
|
+
|
1473
|
+
# Compile for different shapes
|
1474
|
+
shapes = [
|
1475
|
+
jnp.array([1.0]),
|
1476
|
+
jnp.array([1.0, 2.0]),
|
1477
|
+
jnp.array([1.0, 2.0, 3.0]),
|
1478
|
+
]
|
1479
|
+
|
1480
|
+
for x in shapes:
|
1481
|
+
sf.make_jaxpr(x)
|
1482
|
+
|
1483
|
+
stats = sf.get_cache_stats()
|
1484
|
+
|
1485
|
+
# Should have 3 different cache entries
|
1486
|
+
self.assertEqual(stats['jaxpr_cache']['size'], 3)
|
1487
|
+
|
1488
|
+
def test_clear_and_recompile(self):
|
1489
|
+
"""Test clearing cache and recompiling."""
|
1490
|
+
state = brainstate.State(jnp.array([1.0]))
|
1491
|
+
|
1492
|
+
def f(x):
|
1493
|
+
state.value += x
|
1494
|
+
return state.value
|
1495
|
+
|
1496
|
+
sf = brainstate.transform.StatefulFunction(f)
|
1497
|
+
x = jnp.array([1.0])
|
1498
|
+
|
1499
|
+
# Compile
|
1500
|
+
sf.make_jaxpr(x)
|
1501
|
+
stats_before = sf.get_cache_stats()
|
1502
|
+
self.assertGreater(stats_before['jaxpr_cache']['size'], 0)
|
1503
|
+
|
1504
|
+
# Clear cache
|
1505
|
+
sf.clear_cache()
|
1506
|
+
stats_after_clear = sf.get_cache_stats()
|
1507
|
+
self.assertEqual(stats_after_clear['jaxpr_cache']['size'], 0)
|
1508
|
+
|
1509
|
+
# Recompile
|
1510
|
+
sf.make_jaxpr(x)
|
1511
|
+
stats_after_recompile = sf.get_cache_stats()
|
1512
|
+
self.assertGreater(stats_after_recompile['jaxpr_cache']['size'], 0)
|
1513
|
+
|
1514
|
+
|
1515
|
+
class TestStatefulMapping(unittest.TestCase):
|
1516
|
+
def test_state_filters_and_caching(self):
|
1517
|
+
counter = brainstate.ShortTermState(jnp.zeros(3))
|
1518
|
+
|
1519
|
+
def accumulate(x):
|
1520
|
+
counter.value = counter.value + x
|
1521
|
+
return counter.value
|
1522
|
+
|
1523
|
+
mapper = brainstate.transform.StatefulMapping(
|
1524
|
+
accumulate,
|
1525
|
+
in_axes=0,
|
1526
|
+
out_axes=0,
|
1527
|
+
state_in_axes={0: state_filter.OfType(brainstate.ShortTermState)},
|
1528
|
+
state_out_axes={0: state_filter.OfType(brainstate.ShortTermState)},
|
1529
|
+
)
|
1530
|
+
|
1531
|
+
xs = jnp.asarray([1.0, 2.0, 3.0])
|
1532
|
+
result = mapper(xs)
|
1533
|
+
self.assertTrue(jnp.allclose(result, xs))
|
1534
|
+
self.assertTrue(jnp.allclose(counter.value, xs))
|
1535
|
+
|
1536
|
+
def test_random_state_restoration(self):
|
1537
|
+
rng_state = brainstate.random.RandomState(0)
|
1538
|
+
|
1539
|
+
def draw(_):
|
1540
|
+
key = rng_state.split_key()
|
1541
|
+
return jr.normal(key, ())
|
1542
|
+
|
1543
|
+
mapper = brainstate.transform.StatefulMapping(
|
1544
|
+
draw,
|
1545
|
+
in_axes=0,
|
1546
|
+
out_axes=0,
|
1547
|
+
)
|
1548
|
+
|
1549
|
+
xs = jnp.ones((4,))
|
1550
|
+
before = rng_state.value
|
1551
|
+
samples = mapper(xs)
|
1552
|
+
self.assertEqual(samples.shape, xs.shape)
|
1553
|
+
self.assertFalse(jnp.allclose(samples, jnp.repeat(samples[0], xs.shape[0])))
|
1554
|
+
self.assertTrue(jnp.array_equal(rng_state.value.shape, before.shape))
|
1555
|
+
|
1556
|
+
def test_inconsistent_batch_sizes_raise(self):
|
1557
|
+
tracker = brainstate.ShortTermState(jnp.array(0.0))
|
1558
|
+
|
1559
|
+
def combine(x, y):
|
1560
|
+
tracker.value = tracker.value + x + y
|
1561
|
+
return tracker.value
|
1562
|
+
|
1563
|
+
mapper = brainstate.transform.StatefulMapping(
|
1564
|
+
combine,
|
1565
|
+
in_axes=(0, 0),
|
1566
|
+
out_axes=0,
|
1567
|
+
state_in_axes={0: state_filter.OfType(brainstate.ShortTermState)},
|
1568
|
+
state_out_axes={0: state_filter.OfType(brainstate.ShortTermState)},
|
1569
|
+
)
|
1570
|
+
|
1571
|
+
with self.assertRaisesRegex(ValueError, "Inconsistent batch sizes"):
|
1572
|
+
mapper(jnp.ones((3,)), jnp.ones((4,)))
|
1573
|
+
|
1574
|
+
def test_unexpected_out_state_mapping_raise(self):
|
1575
|
+
leak = brainstate.ShortTermState(jnp.array(0.0))
|
1576
|
+
|
1577
|
+
def mutate(x):
|
1578
|
+
leak.value = leak.value + x
|
1579
|
+
return x
|
1580
|
+
|
1581
|
+
mapper = brainstate.transform.StatefulMapping(
|
1582
|
+
mutate,
|
1583
|
+
in_axes=0,
|
1584
|
+
out_axes=0,
|
1585
|
+
state_in_axes={},
|
1586
|
+
state_out_axes={},
|
1587
|
+
unexpected_out_state_mapping='raise',
|
1588
|
+
)
|
1589
|
+
|
1590
|
+
with self.assertRaises(BatchAxisError):
|
1591
|
+
mapper(jnp.ones((2,)))
|
1592
|
+
|
1593
|
+
def test_unexpected_out_state_mapping_warn(self):
|
1594
|
+
leak = brainstate.ShortTermState(jnp.array(0.0))
|
1595
|
+
|
1596
|
+
def mutate(x):
|
1597
|
+
leak.value = leak.value + x
|
1598
|
+
return x
|
1599
|
+
|
1600
|
+
mapper = brainstate.transform.StatefulMapping(
|
1601
|
+
mutate,
|
1602
|
+
in_axes=0,
|
1603
|
+
out_axes=0,
|
1604
|
+
state_in_axes={},
|
1605
|
+
state_out_axes={},
|
1606
|
+
unexpected_out_state_mapping='warn',
|
1607
|
+
)
|
1608
|
+
|
1609
|
+
with pytest.warns(UserWarning):
|
1610
|
+
mapper(jnp.ones((2,)))
|
1611
|
+
self.assertTrue(jnp.allclose(leak.value, 1.0))
|
1612
|
+
|
1613
|
+
def test_unexpected_out_state_mapping_ignore(self):
|
1614
|
+
leak = brainstate.ShortTermState(jnp.array(0.0))
|
1615
|
+
|
1616
|
+
def mutate(x):
|
1617
|
+
leak.value = leak.value + x
|
1618
|
+
return x
|
1619
|
+
|
1620
|
+
mapper = brainstate.transform.StatefulMapping(
|
1621
|
+
mutate,
|
1622
|
+
in_axes=0,
|
1623
|
+
out_axes=0,
|
1624
|
+
state_in_axes={},
|
1625
|
+
state_out_axes={},
|
1626
|
+
unexpected_out_state_mapping='ignore',
|
1627
|
+
)
|
1628
|
+
|
1629
|
+
with warnings.catch_warnings(record=True) as caught:
|
1630
|
+
warnings.simplefilter('always')
|
1631
|
+
mapper(jnp.ones((2,)))
|
1632
|
+
self.assertEqual(len(caught), 0)
|
1633
|
+
self.assertTrue(jnp.allclose(leak.value, 1.0))
|
1634
|
+
|