brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/_state.py
CHANGED
@@ -1,1652 +1,2157 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
|
-
import contextlib
|
19
|
-
import threading
|
20
|
-
from functools import partial
|
21
|
-
from typing import (
|
22
|
-
Any,
|
23
|
-
Union,
|
24
|
-
Callable,
|
25
|
-
Generic,
|
26
|
-
TypeVar,
|
27
|
-
Optional,
|
28
|
-
TYPE_CHECKING,
|
29
|
-
Tuple,
|
30
|
-
Dict,
|
31
|
-
List,
|
32
|
-
Sequence,
|
33
|
-
Generator,
|
34
|
-
)
|
35
|
-
|
36
|
-
import
|
37
|
-
import
|
38
|
-
|
39
|
-
from jax.
|
40
|
-
|
41
|
-
|
42
|
-
from brainstate.
|
43
|
-
from brainstate.util
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
'
|
48
|
-
'
|
49
|
-
'
|
50
|
-
'
|
51
|
-
'
|
52
|
-
'
|
53
|
-
'
|
54
|
-
|
55
|
-
'
|
56
|
-
'
|
57
|
-
|
58
|
-
'
|
59
|
-
'
|
60
|
-
'
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
self.
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
>>>
|
124
|
-
>>>
|
125
|
-
>>>
|
126
|
-
|
127
|
-
>>> # The
|
128
|
-
>>> state.value =
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
TRACE_CONTEXT.tree_check.
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
Any: The value
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
>>>
|
173
|
-
>>>
|
174
|
-
>>>
|
175
|
-
>>>
|
176
|
-
>>>
|
177
|
-
>>>
|
178
|
-
>>>
|
179
|
-
>>>
|
180
|
-
>>>
|
181
|
-
>>>
|
182
|
-
>>>
|
183
|
-
>>>
|
184
|
-
>>>
|
185
|
-
>>>
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
TRACE_CONTEXT.jax_tracer_check.
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
>>>
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
The
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
- :py:class:`
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
#
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
self.
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
with
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
else
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
#
|
497
|
-
|
498
|
-
|
499
|
-
return
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
attributes
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
if k == '
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
def
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
state's
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
state's
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
"""
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
823
|
-
|
824
|
-
|
825
|
-
|
826
|
-
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
|
846
|
-
|
847
|
-
|
848
|
-
|
849
|
-
|
850
|
-
|
851
|
-
|
852
|
-
|
853
|
-
|
854
|
-
|
855
|
-
|
856
|
-
|
857
|
-
|
858
|
-
|
859
|
-
|
860
|
-
|
861
|
-
|
862
|
-
|
863
|
-
|
864
|
-
|
865
|
-
|
866
|
-
|
867
|
-
|
868
|
-
|
869
|
-
|
870
|
-
|
871
|
-
|
872
|
-
"""
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
877
|
-
|
878
|
-
|
879
|
-
|
880
|
-
|
881
|
-
|
882
|
-
|
883
|
-
|
884
|
-
|
885
|
-
|
886
|
-
|
887
|
-
|
888
|
-
|
889
|
-
|
890
|
-
for
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
897
|
-
|
898
|
-
|
899
|
-
|
900
|
-
|
901
|
-
|
902
|
-
|
903
|
-
|
904
|
-
|
905
|
-
|
906
|
-
|
907
|
-
|
908
|
-
|
909
|
-
|
910
|
-
|
911
|
-
|
912
|
-
|
913
|
-
|
914
|
-
|
915
|
-
|
916
|
-
|
917
|
-
|
918
|
-
|
919
|
-
|
920
|
-
|
921
|
-
|
922
|
-
|
923
|
-
|
924
|
-
|
925
|
-
|
926
|
-
|
927
|
-
|
928
|
-
|
929
|
-
|
930
|
-
|
931
|
-
|
932
|
-
|
933
|
-
|
934
|
-
|
935
|
-
|
936
|
-
|
937
|
-
|
938
|
-
|
939
|
-
|
940
|
-
|
941
|
-
|
942
|
-
|
943
|
-
|
944
|
-
|
945
|
-
|
946
|
-
|
947
|
-
|
948
|
-
|
949
|
-
|
950
|
-
|
951
|
-
|
952
|
-
|
953
|
-
|
954
|
-
|
955
|
-
|
956
|
-
|
957
|
-
|
958
|
-
|
959
|
-
|
960
|
-
|
961
|
-
|
962
|
-
|
963
|
-
|
964
|
-
|
965
|
-
|
966
|
-
|
967
|
-
|
968
|
-
|
969
|
-
|
970
|
-
|
971
|
-
|
972
|
-
|
973
|
-
|
974
|
-
|
975
|
-
|
976
|
-
|
977
|
-
|
978
|
-
|
979
|
-
|
980
|
-
|
981
|
-
|
982
|
-
|
983
|
-
|
984
|
-
|
985
|
-
|
986
|
-
|
987
|
-
|
988
|
-
|
989
|
-
|
990
|
-
|
991
|
-
|
992
|
-
|
993
|
-
|
994
|
-
|
995
|
-
|
996
|
-
|
997
|
-
|
998
|
-
|
999
|
-
|
1000
|
-
|
1001
|
-
|
1002
|
-
|
1003
|
-
|
1004
|
-
|
1005
|
-
|
1006
|
-
|
1007
|
-
|
1008
|
-
|
1009
|
-
|
1010
|
-
|
1011
|
-
|
1012
|
-
|
1013
|
-
|
1014
|
-
|
1015
|
-
|
1016
|
-
|
1017
|
-
|
1018
|
-
|
1019
|
-
|
1020
|
-
|
1021
|
-
|
1022
|
-
|
1023
|
-
|
1024
|
-
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1028
|
-
|
1029
|
-
|
1030
|
-
|
1031
|
-
|
1032
|
-
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
1038
|
-
|
1039
|
-
|
1040
|
-
|
1041
|
-
|
1042
|
-
|
1043
|
-
|
1044
|
-
|
1045
|
-
|
1046
|
-
|
1047
|
-
|
1048
|
-
|
1049
|
-
|
1050
|
-
|
1051
|
-
|
1052
|
-
|
1053
|
-
|
1054
|
-
|
1055
|
-
|
1056
|
-
|
1057
|
-
|
1058
|
-
|
1059
|
-
|
1060
|
-
|
1061
|
-
|
1062
|
-
|
1063
|
-
|
1064
|
-
|
1065
|
-
|
1066
|
-
|
1067
|
-
|
1068
|
-
|
1069
|
-
|
1070
|
-
|
1071
|
-
|
1072
|
-
|
1073
|
-
|
1074
|
-
|
1075
|
-
|
1076
|
-
|
1077
|
-
|
1078
|
-
|
1079
|
-
|
1080
|
-
|
1081
|
-
|
1082
|
-
|
1083
|
-
|
1084
|
-
|
1085
|
-
|
1086
|
-
|
1087
|
-
|
1088
|
-
|
1089
|
-
|
1090
|
-
|
1091
|
-
|
1092
|
-
|
1093
|
-
|
1094
|
-
|
1095
|
-
|
1096
|
-
|
1097
|
-
|
1098
|
-
|
1099
|
-
|
1100
|
-
|
1101
|
-
|
1102
|
-
|
1103
|
-
|
1104
|
-
|
1105
|
-
|
1106
|
-
|
1107
|
-
|
1108
|
-
|
1109
|
-
|
1110
|
-
|
1111
|
-
|
1112
|
-
|
1113
|
-
|
1114
|
-
|
1115
|
-
|
1116
|
-
|
1117
|
-
|
1118
|
-
|
1119
|
-
|
1120
|
-
|
1121
|
-
|
1122
|
-
|
1123
|
-
|
1124
|
-
|
1125
|
-
|
1126
|
-
|
1127
|
-
|
1128
|
-
|
1129
|
-
|
1130
|
-
|
1131
|
-
|
1132
|
-
|
1133
|
-
|
1134
|
-
|
1135
|
-
|
1136
|
-
|
1137
|
-
|
1138
|
-
|
1139
|
-
"""
|
1140
|
-
|
1141
|
-
for
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
1149
|
-
|
1150
|
-
|
1151
|
-
|
1152
|
-
|
1153
|
-
|
1154
|
-
|
1155
|
-
|
1156
|
-
|
1157
|
-
|
1158
|
-
|
1159
|
-
|
1160
|
-
|
1161
|
-
|
1162
|
-
|
1163
|
-
|
1164
|
-
|
1165
|
-
|
1166
|
-
|
1167
|
-
|
1168
|
-
|
1169
|
-
|
1170
|
-
|
1171
|
-
|
1172
|
-
|
1173
|
-
|
1174
|
-
|
1175
|
-
|
1176
|
-
|
1177
|
-
|
1178
|
-
|
1179
|
-
|
1180
|
-
|
1181
|
-
|
1182
|
-
|
1183
|
-
|
1184
|
-
|
1185
|
-
|
1186
|
-
|
1187
|
-
|
1188
|
-
|
1189
|
-
|
1190
|
-
|
1191
|
-
|
1192
|
-
|
1193
|
-
|
1194
|
-
|
1195
|
-
|
1196
|
-
|
1197
|
-
|
1198
|
-
|
1199
|
-
|
1200
|
-
|
1201
|
-
|
1202
|
-
|
1203
|
-
|
1204
|
-
|
1205
|
-
|
1206
|
-
|
1207
|
-
|
1208
|
-
|
1209
|
-
|
1210
|
-
|
1211
|
-
|
1212
|
-
|
1213
|
-
|
1214
|
-
|
1215
|
-
|
1216
|
-
|
1217
|
-
|
1218
|
-
|
1219
|
-
|
1220
|
-
|
1221
|
-
|
1222
|
-
"""
|
1223
|
-
|
1224
|
-
|
1225
|
-
|
1226
|
-
|
1227
|
-
|
1228
|
-
|
1229
|
-
|
1230
|
-
|
1231
|
-
|
1232
|
-
|
1233
|
-
|
1234
|
-
|
1235
|
-
|
1236
|
-
|
1237
|
-
|
1238
|
-
|
1239
|
-
|
1240
|
-
|
1241
|
-
|
1242
|
-
|
1243
|
-
|
1244
|
-
|
1245
|
-
|
1246
|
-
|
1247
|
-
|
1248
|
-
|
1249
|
-
|
1250
|
-
|
1251
|
-
|
1252
|
-
|
1253
|
-
|
1254
|
-
|
1255
|
-
|
1256
|
-
|
1257
|
-
|
1258
|
-
|
1259
|
-
|
1260
|
-
|
1261
|
-
|
1262
|
-
|
1263
|
-
|
1264
|
-
|
1265
|
-
|
1266
|
-
|
1267
|
-
|
1268
|
-
|
1269
|
-
|
1270
|
-
|
1271
|
-
|
1272
|
-
|
1273
|
-
|
1274
|
-
|
1275
|
-
|
1276
|
-
|
1277
|
-
|
1278
|
-
|
1279
|
-
|
1280
|
-
|
1281
|
-
|
1282
|
-
|
1283
|
-
|
1284
|
-
""
|
1285
|
-
|
1286
|
-
|
1287
|
-
|
1288
|
-
|
1289
|
-
|
1290
|
-
|
1291
|
-
|
1292
|
-
|
1293
|
-
|
1294
|
-
|
1295
|
-
|
1296
|
-
|
1297
|
-
|
1298
|
-
|
1299
|
-
|
1300
|
-
|
1301
|
-
|
1302
|
-
|
1303
|
-
|
1304
|
-
|
1305
|
-
|
1306
|
-
|
1307
|
-
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1313
|
-
|
1314
|
-
|
1315
|
-
|
1316
|
-
|
1317
|
-
|
1318
|
-
|
1319
|
-
def
|
1320
|
-
|
1321
|
-
|
1322
|
-
|
1323
|
-
|
1324
|
-
|
1325
|
-
|
1326
|
-
|
1327
|
-
|
1328
|
-
|
1329
|
-
|
1330
|
-
|
1331
|
-
|
1332
|
-
|
1333
|
-
|
1334
|
-
|
1335
|
-
|
1336
|
-
|
1337
|
-
|
1338
|
-
|
1339
|
-
|
1340
|
-
|
1341
|
-
|
1342
|
-
|
1343
|
-
|
1344
|
-
|
1345
|
-
|
1346
|
-
|
1347
|
-
|
1348
|
-
|
1349
|
-
|
1350
|
-
|
1351
|
-
|
1352
|
-
|
1353
|
-
|
1354
|
-
|
1355
|
-
|
1356
|
-
|
1357
|
-
|
1358
|
-
|
1359
|
-
|
1360
|
-
|
1361
|
-
|
1362
|
-
|
1363
|
-
|
1364
|
-
|
1365
|
-
|
1366
|
-
|
1367
|
-
|
1368
|
-
|
1369
|
-
|
1370
|
-
|
1371
|
-
|
1372
|
-
|
1373
|
-
|
1374
|
-
|
1375
|
-
|
1376
|
-
|
1377
|
-
|
1378
|
-
|
1379
|
-
|
1380
|
-
|
1381
|
-
|
1382
|
-
|
1383
|
-
|
1384
|
-
|
1385
|
-
|
1386
|
-
|
1387
|
-
|
1388
|
-
|
1389
|
-
|
1390
|
-
|
1391
|
-
|
1392
|
-
|
1393
|
-
|
1394
|
-
|
1395
|
-
|
1396
|
-
|
1397
|
-
|
1398
|
-
|
1399
|
-
|
1400
|
-
|
1401
|
-
|
1402
|
-
|
1403
|
-
|
1404
|
-
|
1405
|
-
|
1406
|
-
|
1407
|
-
|
1408
|
-
|
1409
|
-
|
1410
|
-
|
1411
|
-
|
1412
|
-
|
1413
|
-
|
1414
|
-
|
1415
|
-
|
1416
|
-
|
1417
|
-
|
1418
|
-
|
1419
|
-
|
1420
|
-
|
1421
|
-
|
1422
|
-
|
1423
|
-
|
1424
|
-
|
1425
|
-
|
1426
|
-
|
1427
|
-
|
1428
|
-
|
1429
|
-
|
1430
|
-
|
1431
|
-
|
1432
|
-
|
1433
|
-
|
1434
|
-
|
1435
|
-
|
1436
|
-
|
1437
|
-
|
1438
|
-
|
1439
|
-
|
1440
|
-
|
1441
|
-
|
1442
|
-
|
1443
|
-
|
1444
|
-
|
1445
|
-
|
1446
|
-
|
1447
|
-
|
1448
|
-
|
1449
|
-
|
1450
|
-
|
1451
|
-
|
1452
|
-
|
1453
|
-
|
1454
|
-
|
1455
|
-
|
1456
|
-
|
1457
|
-
|
1458
|
-
|
1459
|
-
|
1460
|
-
|
1461
|
-
|
1462
|
-
|
1463
|
-
|
1464
|
-
|
1465
|
-
|
1466
|
-
|
1467
|
-
|
1468
|
-
|
1469
|
-
|
1470
|
-
|
1471
|
-
|
1472
|
-
|
1473
|
-
|
1474
|
-
|
1475
|
-
|
1476
|
-
|
1477
|
-
|
1478
|
-
|
1479
|
-
|
1480
|
-
|
1481
|
-
|
1482
|
-
|
1483
|
-
|
1484
|
-
|
1485
|
-
|
1486
|
-
|
1487
|
-
|
1488
|
-
|
1489
|
-
|
1490
|
-
|
1491
|
-
|
1492
|
-
|
1493
|
-
|
1494
|
-
|
1495
|
-
|
1496
|
-
|
1497
|
-
|
1498
|
-
|
1499
|
-
|
1500
|
-
|
1501
|
-
|
1502
|
-
|
1503
|
-
|
1504
|
-
|
1505
|
-
|
1506
|
-
|
1507
|
-
|
1508
|
-
|
1509
|
-
|
1510
|
-
|
1511
|
-
|
1512
|
-
|
1513
|
-
|
1514
|
-
Returns:
|
1515
|
-
|
1516
|
-
|
1517
|
-
|
1518
|
-
|
1519
|
-
|
1520
|
-
"""
|
1521
|
-
|
1522
|
-
|
1523
|
-
|
1524
|
-
|
1525
|
-
|
1526
|
-
|
1527
|
-
|
1528
|
-
|
1529
|
-
|
1530
|
-
|
1531
|
-
|
1532
|
-
|
1533
|
-
|
1534
|
-
|
1535
|
-
|
1536
|
-
|
1537
|
-
|
1538
|
-
|
1539
|
-
|
1540
|
-
|
1541
|
-
|
1542
|
-
|
1543
|
-
|
1544
|
-
|
1545
|
-
|
1546
|
-
"""
|
1547
|
-
|
1548
|
-
|
1549
|
-
|
1550
|
-
|
1551
|
-
|
1552
|
-
|
1553
|
-
|
1554
|
-
|
1555
|
-
|
1556
|
-
|
1557
|
-
|
1558
|
-
|
1559
|
-
|
1560
|
-
|
1561
|
-
|
1562
|
-
|
1563
|
-
|
1564
|
-
|
1565
|
-
|
1566
|
-
|
1567
|
-
|
1568
|
-
|
1569
|
-
|
1570
|
-
|
1571
|
-
|
1572
|
-
|
1573
|
-
|
1574
|
-
|
1575
|
-
|
1576
|
-
|
1577
|
-
|
1578
|
-
|
1579
|
-
|
1580
|
-
|
1581
|
-
|
1582
|
-
|
1583
|
-
|
1584
|
-
|
1585
|
-
|
1586
|
-
|
1587
|
-
|
1588
|
-
|
1589
|
-
|
1590
|
-
|
1591
|
-
|
1592
|
-
|
1593
|
-
|
1594
|
-
|
1595
|
-
|
1596
|
-
|
1597
|
-
|
1598
|
-
|
1599
|
-
|
1600
|
-
|
1601
|
-
|
1602
|
-
|
1603
|
-
|
1604
|
-
|
1605
|
-
|
1606
|
-
|
1607
|
-
|
1608
|
-
|
1609
|
-
|
1610
|
-
|
1611
|
-
|
1612
|
-
|
1613
|
-
|
1614
|
-
|
1615
|
-
|
1616
|
-
|
1617
|
-
|
1618
|
-
|
1619
|
-
|
1620
|
-
|
1621
|
-
|
1622
|
-
|
1623
|
-
|
1624
|
-
|
1625
|
-
|
1626
|
-
|
1627
|
-
|
1628
|
-
|
1629
|
-
|
1630
|
-
|
1631
|
-
|
1632
|
-
|
1633
|
-
|
1634
|
-
|
1635
|
-
|
1636
|
-
|
1637
|
-
|
1638
|
-
|
1639
|
-
|
1640
|
-
|
1641
|
-
|
1642
|
-
|
1643
|
-
|
1644
|
-
|
1645
|
-
|
1646
|
-
|
1647
|
-
|
1648
|
-
|
1649
|
-
|
1650
|
-
|
1651
|
-
|
1652
|
-
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import contextlib
|
19
|
+
import threading
|
20
|
+
from functools import partial
|
21
|
+
from typing import (
|
22
|
+
Any,
|
23
|
+
Union,
|
24
|
+
Callable,
|
25
|
+
Generic,
|
26
|
+
TypeVar,
|
27
|
+
Optional,
|
28
|
+
TYPE_CHECKING,
|
29
|
+
Tuple,
|
30
|
+
Dict,
|
31
|
+
List,
|
32
|
+
Sequence,
|
33
|
+
Generator,
|
34
|
+
)
|
35
|
+
|
36
|
+
import brainunit as u
|
37
|
+
import jax
|
38
|
+
import numpy as np
|
39
|
+
from jax.api_util import shaped_abstractify
|
40
|
+
from jax.extend import source_info_util
|
41
|
+
|
42
|
+
from brainstate.typing import ArrayLike, PyTree, Missing, Filter
|
43
|
+
from brainstate.util import DictManager, PrettyObject
|
44
|
+
from brainstate.util.filter import Nothing
|
45
|
+
|
46
|
+
# Public API of this module.
__all__ = [
    # State types.
    'State',
    'ShortTermState',
    'LongTermState',
    'HiddenState',
    'HiddenGroupState',
    'HiddenTreeState',
    'ParamState',
    'BatchState',
    'TreefyState',
    'FakeState',

    # State management and tracing helpers.
    'StateDictManager',
    'StateTraceStack',
    'check_state_value_tree',
    'check_state_jax_tracer',
    'catch_new_states',
    'maybe_state',
]
|
65
|
+
|
66
|
+
# Generic type variables used by annotations in this module.
A = TypeVar('A')
B = TypeVar('B')
T = TypeVar('T')

# Type variable restricted to callables.
F = TypeVar('F', bound=Callable[..., Any])

# NOTE(review): despite the name, this is an ``np.iinfo`` object describing the
# int32 dtype (its largest value is ``max_int.max``), not an integer itself.
max_int = np.iinfo(np.int32)
|
72
|
+
|
73
|
+
|
74
|
+
# The global state of the state stack is accessed by a thread-local object.
|
75
|
+
# This allows concurrent tracing in separate threads; passing traced objects
|
76
|
+
# between threads is forbidden.
|
77
|
+
class ThreadLocalStack(threading.local):
    """
    Per-thread container for the bookkeeping used by state tracing.

    Because this class derives from :class:`threading.local`, each thread that
    touches an instance sees its own independent copy of the attributes below.

    Attributes:
        state_stack (List[StateTraceStack]): Stack of ``StateTraceStack`` objects; starts empty.
        tree_check (List[bool]): Stack of tree-structure-check flags; starts as ``[False]``.
        jax_tracer_check (List[bool]): Stack of JAX-tracer-check flags; starts as ``[False]``.
        new_state_catcher (List[StateCatcher]): Stack of catchers used to capture newly
            created states; starts empty.
    """

    def __init__(self):
        """
        Create fresh bookkeeping containers for the calling thread.

        ``threading.local`` invokes this initializer separately in every thread
        in which the instance is used.
        """
        # Flag stacks begin with a single ``False`` entry; context managers in
        # this module push onto and pop from the top of these lists.
        self.tree_check: List[bool] = [False]
        self.jax_tracer_check: List[bool] = [False]

        # Object stacks begin empty.
        self.state_stack: List[StateTraceStack] = []
        self.new_state_catcher: List[StateCatcher] = []


# Module-wide, thread-local tracing context.
TRACE_CONTEXT = ThreadLocalStack()
|
106
|
+
|
107
|
+
|
108
|
+
@contextlib.contextmanager
def check_state_value_tree(val: bool = True) -> Generator[None, None, None]:
    """
    The context manager to check whether the tree structure of the state value stays consistent.

    Once a :py:class:`~.State` is created, the tree structure of the value is fixed. By default,
    the tree structure of the value is not checked, to avoid the cost of repeated evaluation.
    If you want to check the tree structure of the value whenever a new value is assigned,
    you can use this context manager.

    Args:
        val (bool): The flag pushed onto the thread-local tree-check stack for the
            duration of the context. ``True`` (the default) requests checking.

    Examples
    --------

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> state = brainstate.ShortTermState(jnp.zeros((2, 3)))
        >>> with brainstate.check_state_value_tree():
        ...     # The line below will not raise an error.
        ...     state.value = jnp.zeros((2, 3))
        ...
        ...     # The following code will raise an error, since it changes the tree structure.
        ...     state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))

    """
    # Push before entering ``try`` so that ``finally`` only pops an entry this
    # call actually pushed.
    TRACE_CONTEXT.tree_check.append(val)
    try:
        yield
    finally:
        TRACE_CONTEXT.tree_check.pop()
|
139
|
+
|
140
|
+
|
141
|
+
def maybe_state(val: Any) -> Any:
    """
    Unwrap a :class:`State` to its stored value; pass anything else through.

    This lets callers accept either a ``State`` or a plain value and then work
    with the underlying value in the same way.

    Args:
        val (Any): Either a ``State`` instance or an arbitrary object.

    Returns:
        Any: ``val.value`` when ``val`` is a ``State``; otherwise ``val`` unchanged.
    """
    return val.value if isinstance(val, State) else val
|
160
|
+
|
161
|
+
|
162
|
+
@contextlib.contextmanager
def check_state_jax_tracer(val: bool = True) -> Generator[None, None, None]:
    """
    The context manager to check whether the state is valid to trace.

    Args:
        val (bool): The flag pushed onto the thread-local JAX-tracer-check stack
            for the duration of the context. ``True`` (the default) requests checking.

    Example
    -------

    .. code-block:: python

        >>> import jax
        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> a = brainstate.ShortTermState(jnp.zeros((2, 3)))
        >>>
        >>> @jax.jit
        ... def run_state(b):
        ...     a.value = b
        ...     return a.value
        >>>
        >>> # The following code will not raise an error, since the state is valid to trace.
        >>> run_state(jnp.ones((2, 3)))
        >>>
        >>> with brainstate.check_state_jax_tracer():
        ...     # Tracing inside this context runs with the JAX tracer check enabled.
        ...     run_state(jnp.ones((2, 4)))
    """
    # Push before entering ``try`` so that ``finally`` only pops an entry this
    # call actually pushed.
    TRACE_CONTEXT.jax_tracer_check.append(val)
    try:
        yield
    finally:
        TRACE_CONTEXT.jax_tracer_check.pop()
|
195
|
+
|
196
|
+
|
197
|
+
def _get_trace_stack_level() -> int:
    """Return how many ``StateTraceStack`` entries the current thread has on its stack."""
    stack = TRACE_CONTEXT.state_stack
    return len(stack)
|
199
|
+
|
200
|
+
|
201
|
+
class State(Generic[A], PrettyObject):
    """
    A generic class representing a dynamic data pointer in the BrainState framework.

    The State class serves as a base for various types of state objects used to
    manage and track dynamic data within a program. It provides mechanisms for
    value storage, metadata management, and integration with the BrainState
    tracing system: creation, reads, writes and restores of the value are
    reported to the active trace context.

    Type Parameters:
        A: The type of the value stored in the state.

    Attributes:
        name (Optional[str]): An optional name for the state.
        value (PyTree): The actual value stored in the state.
        tag (Optional[str]): An optional tag for categorizing or grouping states.

    Args:
        value (PyTree[ArrayLike]): The initial value for the state. It can be any
            PyTree. If a :class:`State` is given, its current value is used.
        name (Optional[str]): An optional name for the state.
        **metadata: Additional metadata stored as instance attributes. The
            ``tag`` entry, if present, becomes :attr:`tag`.

    Example
    -------

    .. code-block:: python

        >>> class MyState(State):
        ...     pass
        >>> state = MyState(jnp.zeros((3, 3)), name="my_matrix")
        >>> print(state.value)
        [[0. 0. 0.]
         [0. 0. 0.]
         [0. 0. 0.]]

    Note:
        - Subclasses of :class:`State` (e.g., ShortTermState, LongTermState, ParamState,
          RandomState) are typically used for specific purposes in a program.
        - The class integrates with BrainState's tracing system to track state
          creation and modifications.

    The typical examples of :py:class:`~.State` subclass are:

    - :py:class:`ShortTermState`: The short-term state, which is used to store the short-term data in the program.
    - :py:class:`LongTermState`: The long-term state, which is used to store the long-term data in the program.
    - :py:class:`ParamState`: The parameter state, which is used to store the parameters in the program.
    - :py:class:`RandomState`: The random generator state, which is used to store the random key in the program.
    """
    __module__ = 'brainstate'
    # Index into ``TRACE_CONTEXT.state_stack``; trace stacks at this index and
    # above record reads/writes of this state (see ``record_state_value_read``).
    _level: int
    # Source location captured when the state was created; used for error reports.
    _source_info: source_info_util.SourceInfo
    _name: Optional[str]
    _value: PyTree
    _been_writen: bool  # useful in `unflatten` and `flatten` graph processing
    tag: Optional[str]

    def __init__(
        self,
        value: PyTree[ArrayLike],
        name: Optional[str] = None,
        **metadata: Any
    ):
        """
        Initialize a new State instance.

        Stores the initial value together with tracing bookkeeping (current trace
        stack level, creation source info, write flag) and any extra metadata as
        instance attributes, then registers the new state with every active
        state catcher.

        Args:
            value (PyTree[ArrayLike]): The initial value for the state. If a
                :class:`State` is passed, its current value is read and stored.
            name (Optional[str], optional): A name for the state. Defaults to None.
            **metadata: Additional metadata to be stored with the state, including:
                - tag (Optional[str]): A tag for categorizing or grouping states.
                - Any other custom metadata fields, stored verbatim as attributes.
        """
        tag = metadata.pop('tag', None)

        # set the value and metadata
        if isinstance(value, State):
            value = value.value

        # update metadata
        metadata.update(
            _value=value,
            _level=_get_trace_stack_level(),
            _source_info=source_info_util.current(),
            _name=name,
            _been_writen=False,
            tag=tag,
        )

        # avoid using self._setattr to avoid the check
        vars(self).update(metadata)

        # record the state initialization
        record_state_init(self)

    def decrease_stack_level(self) -> None:
        """
        Decrease the stack level of the state by one, clamping it at zero.

        Because trace stacks from index ``stack_level`` upward record this
        state's accesses, lowering the level makes one additional (lower-indexed)
        trace stack record them.
        """
        self._level = max(self._level - 1, 0)

    def increase_stack_level(self) -> None:
        """
        Increase the stack level of the state by one.

        Because trace stacks from index ``stack_level`` upward record this
        state's accesses, raising the level stops the trace stack at the old
        level from recording them.
        """
        self._level = self._level + 1

    @property
    def name(self) -> Optional[str]:
        """
        The name of the state (``None`` if unnamed).
        """
        return self._name

    @name.setter
    def name(self, name: str) -> None:
        """
        Set the name of the state.
        """
        self._name = name

    @property
    def value(self) -> PyTree[ArrayLike]:
        """
        The data stored in the state.

        Reading this property records the read in every relevant trace stack
        before returning the value.
        """
        record_state_value_read(self)
        return self._read_value()

    @value.setter
    def value(self, v) -> None:
        """
        Set the value of the state.

        The new value must not be a :class:`State`, and (when tree checking is
        enabled) must have the same PyTree structure as the current value.

        Args:
            v: The value.

        Raises:
            ValueError: If ``v`` is a :class:`State` or its tree structure differs.
        """
        # NOTE: the following order is important

        if isinstance(v, State):  # value checking
            raise ValueError('Cannot set value to a State, ' 'use `copy_from` method instead')
        self._check_value_tree(v)  # check the tree structure
        record_state_value_write(self)  # record the value by the stack (>= level)
        self._been_writen = True  # set the flag
        self._write_value(v)  # write the value

    @property
    def stack_level(self) -> int:
        """
        The stack level of the state.

        Returns:
            The stack level.
        """
        return self._level

    @stack_level.setter
    def stack_level(self, level: int):
        """
        Set the stack level of the state.

        Args:
            level: The stack level.
        """
        self._level = level

    def _read_value(self) -> PyTree[ArrayLike]:
        """
        The interface to customize the value reading.

        The base implementation calls :meth:`check_if_deleted` and returns the
        raw stored value without recording a read.
        """
        self.check_if_deleted()
        return self._value

    def _write_value(self, v) -> None:
        """
        The interface to customize the value writing.

        The base implementation stores ``v`` directly, without any checks.
        """
        self._value = v

    def restore_value(self, v) -> None:
        """
        Restore the value of the state.

        Unlike the ``value`` setter, this always checks the tree structure (it
        forces ``check_state_value_tree``), is recorded as a restore (which is
        currently recorded as a read) rather than a write, and does not set the
        written flag.

        Args:
            v: The value.

        Raises:
            ValueError: If ``v`` is a :class:`State` or its tree structure differs.
        """
        # value checking
        if isinstance(v, State):
            raise ValueError('Cannot set value to a State, ' 'use `copy_from` method instead')
        with check_state_value_tree():
            self._check_value_tree(v)
        # record the value by the stack (>= level)
        record_state_value_restore(self)
        # set the value
        self._value = v

    def value_call(self, func: Callable[..., Any]) -> Any:
        """
        Apply ``func`` to every leaf of the state's value.

        Returns:
            A PyTree with the same structure as the value, produced by
            ``jax.tree.map(func, self.value)``.
        """
        return jax.tree.map(func, self.value)

    def _check_value_tree(self, v):
        """
        Check that ``v`` has the same PyTree structure as the current value.

        The check only runs when the innermost ``TRACE_CONTEXT.tree_check`` flag
        is truthy. On mismatch, a ``ValueError`` is raised via
        :meth:`raise_error_with_source_info`.
        """
        if TRACE_CONTEXT.tree_check[-1]:
            in_tree = jax.tree.structure(v)
            self_tree = jax.tree.structure(self._value)
            if in_tree != self_tree:
                self.raise_error_with_source_info(
                    ValueError(f'The given value {in_tree} does not match with the origin tree structure {self_tree}.')
                )

    def raise_error_with_source_info(self, error: Exception):
        """
        Raise ``error`` inside a user context built from this state's creation
        traceback and name stack, for easier debugging.
        """
        name_stack = source_info_util.current_name_stack() + self.source_info.name_stack
        with source_info_util.user_context(self.source_info.traceback, name_stack=name_stack):
            raise error

    def check_if_deleted(self):
        """
        Hook invoked by :meth:`_read_value` before returning the value.

        It is a no-op in the base class.
        """
        pass

    @property
    def source_info(self) -> source_info_util.SourceInfo:
        """
        The source information of the state, which can be useful to identify
        the source code where the state was defined.

        Returns:
            The source information.
        """
        return self._source_info

    def update_from_ref(self, state_ref: TreefyState[A]) -> None:
        """
        Update the state from the state reference :py:class:`TreefyState`.

        All metadata of ``state_ref`` is copied onto this instance's attributes.
        The value is then written through the ``value`` setter if the reference
        reports ``_been_writen`` (treated as ``True`` when absent), otherwise
        through :meth:`restore_value`.

        Args:
            state_ref: The state reference.
        """
        metadata = state_ref.get_metadata()
        variable_vars = vars(self)
        variable_vars.update(**metadata)
        if metadata.pop('_been_writen', True):
            self.value = state_ref.value
        else:
            self.restore_value(state_ref.value)

    def replace(self, value: Any = Missing, **kwargs) -> State[Any]:
        """
        Return a new state of the same type with some attributes replaced.

        Args:
            value: The new value. If it is itself a :class:`State` of exactly the
                same type, that state is returned (after applying the remaining
                ``kwargs`` to it via ``replace``) instead of a new object.
            **kwargs: Attribute names and values to override on the copy.

        Raises:
            ValueError: If ``value`` is a :class:`State` of a different type.
        """
        if value is not Missing:
            kwargs['_value'] = value

        # return `value` if it is a State
        if '_value' in kwargs and isinstance(value := kwargs['_value'], State):
            # remove value from kwargs
            kwargs.pop('_value')
            if type(self) is not type(value):
                raise ValueError('Cannot replace value from incompatible container, '
                                 f'expected {type(self).__name__}, got {type(value).__name__}')
            # if kwargs aren't empty, recursively call replace
            # else return variable value
            if kwargs:
                return value.replace(**kwargs)
            else:
                return value

        # get and update attributes
        attributes = vars(self).copy()
        attributes.update(**kwargs)
        # return new instance with updated attributes
        # (``object.__new__`` bypasses ``__init__``, so no init is recorded)
        obj = object.__new__(type(self))
        vars(obj).update(attributes)
        return obj

    def copy(self: State[A]) -> State[A]:
        """
        Copy the state.

        The copy shares all attribute values (shallow copy), but gets the current
        trace stack level and source info, and has no ``_been_writen`` attribute.
        ``__init__`` is bypassed, so the copy is not reported to state catchers.
        """
        obj = object.__new__(type(self))
        attributes = vars(self).copy()
        # keep its own trace state and stack level
        attributes['_level'] = _get_trace_stack_level()
        attributes['_source_info'] = source_info_util.current()
        attributes.pop('_been_writen', None)
        # update the metadata
        vars(obj).update(attributes)
        return obj

    def to_state_ref(self: State[A]) -> TreefyState[A]:
        """
        Build a :class:`TreefyState` reference holding this state's type, raw
        value (read without recording) and all other attributes as metadata.
        """
        metadata = vars(self).copy()
        del metadata['_value']
        return TreefyState(type(self), self._value, **metadata)

    def __pretty_repr_item__(self, k, v):
        """
        Map one attribute to its pretty-repr entry, or ``None`` to hide it.

        Tracing bookkeeping is hidden, the value is shown as abstract shapes,
        and ``name``/``tag`` are shown only when not ``None``.
        """
        if k in ['_level', '_source_info', '_been_writen']:
            return None
        if k == '_value':
            return 'value', jax.tree.map(shaped_abstractify, v)

        if k == '_name':
            if self.name is None:
                return None
            else:
                return 'name', v

        if k == 'tag':
            if self.tag is None:
                return None
            else:
                return 'tag', v

        return k, v

    # def __eq__(self, other: object) -> bool:
    #   return type(self) is type(other) and vars(other) == vars(self)

    def __hash__(self):
        """
        Make the state hashable by identity (consistent with the default,
        identity-based ``__eq__``).
        """
        return hash(id(self))

    def numel(self) -> int:
        """
        Calculate the total number of elements in the state value.

        This method traverses the state's value, which may be a nested structure (PyTree),
        and computes the sum of sizes of all leaf nodes. The raw value is read
        without recording a read in the trace stacks.

        Returns:
            int: The total number of elements across all leaves of the state value.
            For a scalar value, this is 1; for a PyTree with no leaves, it is 0.

        Note:
            This method uses jax.tree.leaves to flatten any nested structure in the state value,
            and jax.numpy.size to compute the size of each leaf node.
        """
        sizes = [jax.numpy.size(val) for val in jax.tree.leaves(self._value)]
        return sum(sizes)
|
570
|
+
|
571
|
+
|
572
|
+
def record_state_init(st: State[A]):
    """
    Notify every active state catcher that ``st`` has just been created.

    Each catcher registered in ``TRACE_CONTEXT.new_state_catcher`` receives
    ``st`` through its ``append`` method, in registration order.

    Args:
        st (State[A]): The freshly constructed :class:`State`.

    Note:
        Called internally from ``State.__init__``; it returns nothing.
    """
    catchers = TRACE_CONTEXT.new_state_catcher
    catcher: StateCatcher
    for catcher in catchers:
        catcher.append(st)
|
590
|
+
|
591
|
+
|
592
|
+
def record_state_value_read(st: State[A]):
    """
    Report a read of ``st`` to the trace stacks that track it.

    Only the trace stacks in ``TRACE_CONTEXT.state_stack`` whose index is at
    least ``st.stack_level`` are notified, via their ``read_its_value`` method.

    Args:
        st (State[A]): The state whose value was read; ``A`` is the type of its value.
    """
    level = st.stack_level
    stack: StateTraceStack
    for stack in TRACE_CONTEXT.state_stack[level:]:
        stack.read_its_value(st)
|
612
|
+
|
613
|
+
|
614
|
+
def record_state_value_write(st: State[A]):
    """
    Report a write of ``st`` to the trace stacks that track it.

    Only the trace stacks in ``TRACE_CONTEXT.state_stack`` whose index is at
    least ``st.stack_level`` are notified, via their ``write_its_value`` method.

    Args:
        st (State[A]): The state whose value was written; ``A`` is the type of its value.
    """
    level = st.stack_level
    stack: StateTraceStack
    for stack in TRACE_CONTEXT.state_stack[level:]:
        stack.write_its_value(st)
|
634
|
+
|
635
|
+
|
636
|
+
def record_state_value_restore(st: State[A]):
    """
    Report that the value of ``st`` has been restored.

    A restore is currently tracked exactly like a read: this simply delegates
    to :func:`record_state_value_read`. It performs no restoration itself.

    Args:
        st (State[A]): The state whose value was restored; ``A`` is the type of its value.
    """
    # restores are recorded as reads, not writes
    record_state_value_read(st)
|
657
|
+
|
658
|
+
|
659
|
+
class ShortTermState(State):
    """
    State holding transient data that only matters for a short stretch of execution.

    :class:`ShortTermState` is a purely semantic specialization of :class:`State`.
    It adds no methods or attributes of its own, but marks the stored data as
    short-lived so that it can be told apart from other kinds of state.

    A typical example is the gradient of a model during training. It is
    recomputed in every iteration and does not need to survive across iterations.

    Attributes:
        All attributes are inherited from :class:`State`.

    Example:
        >>> gradient = ShortTermState(np.zeros(100), name="model_gradient")
        >>> intermediate_result = ShortTermState({}, name="layer_activations")
    """

    __module__ = 'brainstate'
|
686
|
+
|
687
|
+
|
688
|
+
class LongTermState(State):
    """
    State holding data that persists over the whole run of a program.

    :class:`LongTermState` is a purely semantic specialization of :class:`State`.
    It introduces no new behavior; it only labels the data as long-lived. Such
    data typically survives across many iterations or epochs of a process.

    In a training loop, for instance, the model weights are long-term state:
    they are updated and kept throughout the entire procedure.

    Attributes:
        All attributes are inherited from :class:`State`.

    Example:
        >>> model_weights = LongTermState(np.random.randn(100, 100), name="model_weights")
        >>> optimizer_state = LongTermState({}, name="optimizer_state")
    """

    __module__ = 'brainstate'
|
714
|
+
|
715
|
+
|
716
|
+
class BatchState(LongTermState):
    """
    State holding the data of a batch.

    :class:`BatchState` is a purely semantic specialization of
    :class:`LongTermState`. It adds nothing beyond what it inherits. It exists
    so that batch data and its metadata can be distinguished from other
    long-term states, e.g. during batch processing in machine learning or data
    analysis tasks.

    Attributes:
        All attributes are inherited from :class:`LongTermState`.

    Example:
        >>> batch_data = BatchState(np.array([1, 2, 3, 4, 5]), name="current_batch")
        >>> batch_labels = BatchState(np.array([0, 1, 0, 1, 1]), name="batch_labels")
    """

    __module__ = 'brainstate'
|
740
|
+
|
741
|
+
|
742
|
+
class HiddenState(ShortTermState):
    """
    Represents hidden state variables in neurons or synapses.

    This class extends :class:`ShortTermState` and is specifically designed to represent
    and manage hidden states within dynamic models, such as recurrent neural networks.
    It provides a way to encapsulate hidden state values and associated metadata,
    facilitating operations like state updates during model execution.

    Note:
        :class:`HiddenState` and :class:`ParamState` are the two most important state types
        in brainstate. The former is used to store the hidden states in neurons, synapses,
        or networks. The latter is used to store the trainable parameters in the model,
        such as synaptic weights.

    Note:
        From version 0.2.2, :class:`HiddenState` only supports values of type numpy.ndarray,
        jax.Array or brainunit.Quantity; this type is checked at construction. Moreover,
        it is equivalent to :class:`brainscale.ETraceState`. Dynamics models defined with
        :class:`HiddenState` can be seamlessly integrated with BrainScale online learning.

    Example:
        >>> lstm_hidden = HiddenState(np.zeros(128), name="lstm_hidden_state")
        >>> gru_hidden = HiddenState(np.zeros(64), name="gru_hidden_state")
    """

    __module__ = 'brainstate'

    value: ArrayLike

    def __init__(self, value: ArrayLike, name: Optional[str] = None, **metadata: Any):
        """
        Initialize a hidden state.

        Args:
            value: The initial value. It must be a ``numpy.ndarray``, ``jax.Array``
                or ``brainunit.Quantity``.
            name: An optional name for the state.
            **metadata: Extra metadata forwarded to :class:`State`, e.g. ``tag``.

        Raises:
            TypeError: If ``value`` is not one of the supported array types.
        """
        self._check_value(value)
        super().__init__(value, name=name, **metadata)

    @property
    def varshape(self) -> Tuple[int, ...]:
        """
        Get the shape of the hidden state variable.

        This property returns the shape of the hidden state variable stored in the instance.
        It provides the dimensions of the array representing the hidden state.

        Returns:
            Tuple[int, ...]: A tuple representing the shape of the hidden state variable.
        """
        return self.value.shape

    @property
    def num_state(self) -> int:
        """
        Get the number of hidden states.

        A :class:`HiddenState` always represents a single hidden state.

        Returns:
            int: The number of hidden states, which is 1 for this class.
        """
        return 1

    def _check_value(self, value: ArrayLike):
        """
        Validate the type of an initial value.

        Raises:
            TypeError: If ``value`` is not a ``numpy.ndarray``, ``jax.Array`` or
                ``brainunit.Quantity``.
        """
        if not isinstance(value, (np.ndarray, jax.Array, u.Quantity)):
            # Use the runtime class name so subclasses that inherit this
            # validator report their own name (matches HiddenGroupState).
            raise TypeError(
                f'Currently, {self.__class__.__name__} only supports '
                f'numpy.ndarray, jax.Array or brainunit.Quantity. '
                f'But we got {type(value)}.'
            )
|
809
|
+
|
810
|
+
|
811
|
+
class HiddenGroupState(HiddenState):
    """
    A group of multiple hidden states for eligibility trace-based learning.

    This class stores several hidden states inside a single instance. The
    states are packed along the last axis of one array. A value of shape
    ``(*varshape, num_state)`` holds ``num_state`` hidden states, each of shape
    ``varshape``. Normally you would define one :py:class:`HiddenState` per
    hidden variable; :py:class:`HiddenGroupState` lets you define several of
    them within a single instance.

    The following is the way to initialize the hidden states.

    .. code-block:: python

        import brainunit as u
        value = np.random.randn(10, 10, 5) * u.mV
        state = HiddenGroupState(value)

    Then, you can retrieve the hidden state value with the following method.

    .. code-block:: python

        state.get_value(0)  # get the first hidden state
        # or
        state.get_value('0')  # get the hidden state with the name '0'

    You can write the hidden state value with the following method.

    .. code-block:: python

        state.set_value({0: np.random.randn(10, 10) * u.mV})  # set the first hidden state
        # or
        state.set_value({'0': np.random.randn(10, 10) * u.mV})  # set the hidden state with the name '0'
        # or
        state.value = np.random.randn(10, 10, 5) * u.mV  # set all hidden state value

    Args:
        value: The packed values of the hidden states. It must be a
            ``numpy.ndarray``, ``jax.Array`` or ``brainunit.Quantity`` with at
            least two dimensions. The last dimension indexes the hidden states,
            and each hidden state is named by the string form of its index
            (``'0'``, ``'1'``, ...).
    """

    __module__ = 'brainstate'
    value: ArrayLike
    name2index: Dict[str, int]  # hidden-state name ('0', '1', ...) -> index on the last axis

    def __init__(self, value: ArrayLike):
        value, name2index = self._check_value(value)
        # Set before State.__init__, which only adds attributes via
        # ``vars(self).update(...)``, so this mapping is preserved.
        self.name2index = name2index
        # Skip HiddenState.__init__, which would re-run the validation above.
        ShortTermState.__init__(self, value)

    @property
    def varshape(self) -> Tuple[int, ...]:
        """
        Get the shape of each hidden state variable.

        This property returns the shape of the hidden state variables, excluding
        the last dimension which represents the number of hidden states.

        Returns:
            Tuple[int, ...]: A tuple representing the shape of each hidden state variable.
        """
        return self.value.shape[:-1]

    @property
    def num_state(self) -> int:
        """
        Get the number of hidden states.

        This property returns the number of hidden states represented by the last dimension
        of the value array.

        Returns:
            int: The number of hidden states.
        """
        return self.value.shape[-1]

    def _check_value(self, value) -> Tuple[ArrayLike, Dict[str, int]]:
        """
        Validate an initial value and build the name-to-index mapping.

        Parameters
        ----------
        value (ArrayLike): The packed hidden states. It must be an instance of
            numpy.ndarray, jax.Array, or brainunit.Quantity with at least two dimensions.

        Returns
        -------
        Tuple[ArrayLike, Dict[str, int]]: A tuple containing:
            - The input value, unchanged.
            - A dictionary mapping ``str(i)`` to ``i`` for every index on the last axis.

        Raises
        ------
        TypeError: If the input value is not of a supported type.
        ValueError: If the input value has fewer than two dimensions.
        """
        if not isinstance(value, (np.ndarray, jax.Array, u.Quantity)):
            raise TypeError(
                f'Currently, {self.__class__.__name__} only supports '
                f'numpy.ndarray, jax.Array or brainunit.Quantity. '
                f'But we got {type(value)}.'
            )
        if value.ndim < 2:
            # The message matches the check: exactly 2 dimensions is accepted.
            raise ValueError(
                f'Currently, {self.__class__.__name__} only supports '
                f'hidden states with at least 2 dimensions, where the last '
                f'dimension is the number of hidden states and the other dimensions '
                f'are the hidden shape. '
                f'But we got {value.ndim} dimensions.'
            )
        name2index = {str(i): i for i in range(value.shape[-1])}
        return value, name2index

    def get_value(self, item: int | str) -> ArrayLike:
        """
        Get the value of the hidden state with the item.

        Args:
            item: int or str. The index of the hidden state.
                - If int, the index of the hidden state on the last axis.
                - If str, the name of the hidden state (see ``name2index``).

        Returns:
            The value of the hidden state, i.e. ``self.value[..., index]``.

        Raises:
            AssertionError: If the index is too large or the name is unknown.
            TypeError: If ``item`` is neither int nor str.
        """
        if isinstance(item, int):
            assert item < self.value.shape[-1], (f'Index {item} out of range. '
                                                 f'The maximum index is {self.value.shape[-1] - 1}.')
            return self.value[..., item]
        elif isinstance(item, str):
            assert item in self.name2index, (f'Hidden state name {item} not found. '
                                             f'Please check the hidden state names.')
            index = self.name2index[item]
            return self.value[..., index]
        else:
            raise TypeError(
                f'Currently, {self.__class__.__name__} only supports '
                f'int or str for getting the hidden state. '
                f'But we got {type(item)}.'
            )

    def set_value(
        self,
        val: Dict[int | str, ArrayLike] | Sequence[ArrayLike]
    ) -> None:
        """
        Set the value of the hidden states with the specified items.

        This method updates the hidden state values based on the provided dictionary or sequence.
        The values are set according to the indices or names specified in the input.
        An empty dictionary or sequence leaves the state unchanged.

        Parameters
        ----------
        val (Dict[int | str, ArrayLike] | Sequence[ArrayLike]):
            A dictionary or sequence containing the new values for the hidden states.
            - If a dictionary, keys can be integers (indices) or strings (names) of the hidden states.
            - If a tuple/list, it is converted to a dictionary with positions as keys.
            Every value must have shape ``varshape``.

        Returns
        -------
        None: This method does not return any value. It updates the hidden state values in place.
        """
        if isinstance(val, (tuple, list)):
            val = {i: v for i, v in enumerate(val)}
        assert isinstance(val, dict), (
            f'Currently, {self.__class__.__name__}.set_value() only supports '
            f'dictionary of hidden states. But we got {type(val)}.'
        )
        if not val:
            # Nothing to update; stacking an empty list below would fail.
            return
        varshape = self.varshape  # hoisted: constant for the whole loop
        indices = []
        values = []
        for k, v in val.items():
            if isinstance(k, str):
                k = self.name2index[k]
            assert isinstance(k, int), (
                f'Key {k} should be int or str. '
                f'But we got {type(k)}.'
            )
            assert v.shape == varshape, (
                f'The shape of the hidden state should be {varshape}. '
                f'But we got {v.shape}.'
            )
            indices.append(k)
            values.append(v)
        values = u.math.stack(values, axis=-1)
        current = self.value
        if isinstance(current, np.ndarray):
            # ``_check_value`` accepts NumPy arrays, but they have no functional
            # ``.at[...]`` updater; lift to a JAX array (still a single PyTree
            # leaf, so the tree-structure check in the setter passes).
            current = jax.numpy.asarray(current)
        self.value = current.at[..., indices].set(values)
|
1002
|
+
|
1003
|
+
|
1004
|
+
class HiddenTreeState(HiddenGroupState):
    """
    A pytree of multiple hidden states for eligibility trace-based learning.

    .. note::

        The value in this state class behaves like a dictionary/sequence of hidden states.
        However, the state is actually stored as a single dimensionless array whose
        last axis indexes the individual hidden states.

    There are two ways to define the hidden states.

    1. The first is to define a sequence of hidden states. The hidden states are then
       named by the string form of their position (``'0'``, ``'1'``, ...).

    .. code-block:: python

        import brainunit as u
        value = [np.random.randn(10, 10) * u.mV,
                 np.random.randn(10, 10) * u.mA,
                 np.random.randn(10, 10) * u.mS]
        state = HiddenTreeState(value)

    Then, you can retrieve the hidden state value with the following method.

    .. code-block:: python

        state.get_value(0)  # get the first hidden state
        # or
        state.get_value('0')  # get the hidden state with the name '0'
        # or
        state.get_value(-1)  # get the last hidden state

    You can write the hidden state value with the following method.

    .. code-block:: python

        state.set_value({0: np.random.randn(10, 10) * u.mV})  # set the first hidden state
        # or
        state.set_value({'1': np.random.randn(10, 10) * u.mA})  # set the hidden state with the name '1'
        # or
        state.set_value([np.random.randn(10, 10) * u.mV,
                         np.random.randn(10, 10) * u.mA,
                         np.random.randn(10, 10) * u.mS])  # set all hidden state value
        # or
        state.set_value({
            0: np.random.randn(10, 10) * u.mV,
            1: np.random.randn(10, 10) * u.mA,
            2: np.random.randn(10, 10) * u.mS
        })  # set all hidden state value

    2. The second is to define a dictionary of hidden states.

    .. code-block:: python

        import brainunit as u
        value = {'v': np.random.randn(10, 10) * u.mV,
                 'i': np.random.randn(10, 10) * u.mA,
                 'g': np.random.randn(10, 10) * u.mS}
        state = HiddenTreeState(value)

    Then, you can retrieve the hidden state value with the following method.

    .. code-block:: python

        state.get_value('v')  # get the hidden state with the name 'v'
        # or
        state.get_value('i')  # get the hidden state with the name 'i'

    You can write the hidden state value with the following method.

    .. code-block:: python

        state.set_value({'v': np.random.randn(10, 10) * u.mV})  # set the hidden state with the name 'v'
        # or
        state.set_value({'i': np.random.randn(10, 10) * u.mA})  # set the hidden state with the name 'i'
        # or
        state.set_value([np.random.randn(10, 10) * u.mV,
                         np.random.randn(10, 10) * u.mA,
                         np.random.randn(10, 10) * u.mS])  # set all hidden state value
        # or
        state.set_value({
            'v': np.random.randn(10, 10) * u.mV,
            'i': np.random.randn(10, 10) * u.mA,
            'g': np.random.randn(10, 10) * u.mS
        })  # set all hidden state value

    .. note::

        Avoid using ``HiddenTreeState.value`` to get the state value, or
        ``HiddenTreeState.value =`` to assign the state value.

        Instead, use ``HiddenTreeState.get_value()`` and ``HiddenTreeState.set_value()``.
        This is because ``.value`` loses the hidden state units and other information:
        it only holds dimensionless data.

        This design aims to ensure that any etrace hidden state has only one array.

    Args:
        value: The values of the hidden states, either a sequence of arrays or a
            dictionary mapping names to arrays. All entries must have the same shape.
    """

    __module__ = 'brainstate'
    value: ArrayLike

    def __init__(
        self,
        value: Dict[str, ArrayLike] | Sequence[ArrayLike],
    ):
        value, name2unit, name2index = self._check_value(value)
        self.name2unit: Dict[str, u.Unit] = name2unit
        self.name2index: Dict[str, int] = name2index
        # Key the unit table by the positions stored in name2index so that the
        # index -> unit and name -> index lookups can never drift apart.
        self.index2unit: Dict[int, u.Unit] = {name2index[k]: unit for k, unit in name2unit.items()}
        self.index2name: Dict[int, str] = {i: k for k, i in name2index.items()}
        ShortTermState.__init__(self, value)

    @property
    def varshape(self) -> Tuple[int, ...]:
        """
        The shape of each hidden state variable (the stored shape without its last axis).
        """
        return self.value.shape[:-1]

    @property
    def num_state(self) -> int:
        """
        The number of hidden states (the size of the stored array's last axis).
        """
        assert self.value.shape[-1] == len(self.name2index), (
            f'The number of hidden states '
            f'is not equal to the number of hidden state names.'
        )
        return self.value.shape[-1]

    def _check_value(
        self,
        value: dict | Sequence
    ) -> Tuple[ArrayLike, Dict[str, u.Unit], Dict[str, int]]:
        """
        Validate the hidden states and pack them into a single dimensionless array.

        Args:
            value (dict | Sequence): A dictionary or sequence representing hidden states.
                - If a tuple/list, it is converted to a dictionary keyed by ``str(position)``.
                - Each hidden state should be a numpy.ndarray, jax.Array, or brainunit.Quantity.

        Returns:
            Tuple[ArrayLike, Dict[str, u.Unit], Dict[str, int]]:
                - The hidden state magnitudes stacked along a new last axis.
                - A dictionary mapping hidden state names to their units.
                - A dictionary mapping hidden state names to their positions on the last axis.

        Raises:
            TypeError: If any hidden state is not a numpy.ndarray, jax.Array, or brainunit.Quantity.
            ValueError: If there are no hidden states, or they do not all have the same shape.
        """
        if isinstance(value, (tuple, list)):
            value = {str(i): v for i, v in enumerate(value)}
        assert isinstance(value, dict), (
            f'Currently, {self.__class__.__name__} only supports '
            f'dictionary/sequence of hidden states. But we got {type(value)}.'
        )
        if len(value) == 0:
            # Stacking zero arrays is impossible; report it explicitly.
            raise ValueError(
                f'{self.__class__.__name__} requires at least one hidden state, '
                f'but got an empty collection.'
            )
        shapes = []
        for k, v in value.items():
            if not isinstance(v, (np.ndarray, jax.Array, u.Quantity)):
                raise TypeError(
                    f'Currently, {self.__class__.__name__} only supports '
                    f'numpy.ndarray, jax.Array or brainunit.Quantity. '
                    f'But we got {type(v)} for key {k}.'
                )
            shapes.append(v.shape)
        if len(set(shapes)) > 1:
            info = {k: v.shape for k, v in value.items()}
            raise ValueError(
                f'Currently, {self.__class__.__name__} only supports '
                f'hidden states with the same shape. '
                f'But we got {info}.'
            )
        name2unit = {k: u.get_unit(v) for k, v in value.items()}
        name2index = {k: i for i, k in enumerate(value.keys())}
        value = u.math.stack([u.get_magnitude(v) for v in value.values()], axis=-1)
        return value, name2unit, name2index

    def get_value(self, item: str | int) -> ArrayLike:
        """
        Get the value of the hidden state with the key, re-attaching its unit.

        Args:
            item: The key of the hidden state.
                - If int, the position of the hidden state; negative positions count
                  from the end, as in Python sequences.
                - If str, the name of the hidden state.

        Raises:
            AssertionError: If an int position is out of range or a name is unknown.
            TypeError: If ``item`` is neither int nor str.
        """
        if isinstance(item, int):
            n = self.value.shape[-1]
            assert -n <= item < n, (f'Index {item} out of range. '
                                    f'The maximum index is {n - 1}.')
            # Map negative positions onto the non-negative keys used by index2unit.
            item = item % n
        elif isinstance(item, str):
            assert item in self.name2index, (f'Hidden state name {item} not found. '
                                             f'Please check the hidden state names.')
            item = self.name2index[item]
        else:
            raise TypeError(
                f'Currently, {self.__class__.__name__} only supports '
                f'int or str for getting the hidden state. '
                f'But we got {type(item)}.'
            )
        val = self.value[..., item]
        unit = self.index2unit[item]
        if unit.dim.is_dimensionless:
            return val
        return val * unit

    def set_value(
        self,
        val: Dict[int | str, ArrayLike] | Sequence[ArrayLike]
    ) -> None:
        """
        Set the value of one or more hidden states.

        Each new value is converted to the unit recorded for its hidden state, and
        only the resulting magnitude is written into the stored array.

        Parameters
        ----------
        val (Dict[int | str, ArrayLike] | Sequence[ArrayLike]):
            The new values for the hidden states.
            - If a dictionary, keys can be integer positions (negative positions count
              from the end) or string names of the hidden states. If the same hidden
              state is addressed more than once, the last entry wins.
            - If a tuple/list, it is converted to a dictionary keyed by position.

        Raises
        ------
        ValueError
            If ``val`` contains no hidden state to set.

        Returns
        -------
        None: ``self.value`` is replaced by a copy updated at the given positions.
        """
        if isinstance(val, (tuple, list)):
            val = {i: v for i, v in enumerate(val)}
        assert isinstance(val, dict), (f'Currently, {self.__class__.__name__}.set_value() only supports '
                                       f'dictionary of hidden states. But we got {type(val)}.')
        n = self.value.shape[-1]
        # position -> dimensionless magnitude; a dict makes repeated keys
        # deterministic (last write wins) instead of a duplicate-index scatter.
        updates = {}
        for index, v in val.items():
            if isinstance(index, str):
                index = self.name2index[index]
            assert isinstance(index, int), (f'Key {index} should be int or str. '
                                            f'But we got {type(index)}.')
            assert v.shape == self.varshape, (f'The shape of the hidden state should be {self.varshape}. '
                                              f'But we got {v.shape}.')
            if -n <= index < 0:
                # Negative positions count from the end, consistent with get_value().
                index += n
            updates[index] = u.Quantity(v).to(self.index2unit[index]).mantissa
        if len(updates) == 0:
            raise ValueError(
                f'No hidden state is set. Please check the hidden state names or indices.'
            )
        indices = list(updates.keys())
        values = list(updates.values())
        if len(indices) == 1:
            self.value = self.value.at[..., indices[0]].set(values[0])
        else:
            self.value = self.value.at[..., np.asarray(indices)].set(u.math.stack(values, axis=-1))
|
1265
|
+
|
1266
|
+
|
1267
|
+
class ParamState(LongTermState):
    """
    State that holds trainable model parameters.

    ``ParamState`` specializes :class:`LongTermState` for values that are meant to be
    optimized during training, such as synaptic weights and biases. The class body
    defines no methods of its own; it only sets ``__module__``.

    Note:
        :class:`HiddenState` and :class:`ParamState` are the two most important state
        types in brainstate. The former stores the hidden states of neurons, synapses,
        or networks; the latter stores the trainable parameters of the model, such as
        synaptic weights.

    Example:
        >>> weight = ParamState(np.random.randn(10, 10), name="layer1_weights")
        >>> bias = ParamState(np.zeros(10), name="layer1_bias")
    """

    __module__ = 'brainstate'
|
1288
|
+
|
1289
|
+
|
1290
|
+
class FakeState:
    """
    A minimal state-like holder for a value and an optional name.

    Both ``value`` and ``name`` are plain attributes behind properties: reading
    returns the stored object and assigning replaces it, with no other side effect.
    """

    __module__ = 'brainstate'

    def __init__(self, value: Any, name: Optional[str] = None):
        """
        Create the holder.

        Args:
            value (Any): Initial value to keep.
            name (Optional[str], optional): Optional label. Defaults to None.
        """
        self._name = name
        self._value = value

    @property
    def value(self) -> Any:
        """
        The object currently stored.

        Returns:
            Any: Whatever was last passed to the constructor or assigned.
        """
        return self._value

    @value.setter
    def value(self, v) -> None:
        """
        Replace the stored object.

        Args:
            v (Any): The new object to keep.
        """
        self._value = v

    def __repr__(self) -> str:
        """
        Describe the holder by its stored value.

        Returns:
            str: Text of the form ``FakedState(value=...)``.
        """
        return 'FakedState(value={})'.format(self._value)

    @property
    def name(self) -> Optional[str]:
        """
        The label of this holder.

        Returns:
            Optional[str]: The label, or None when none was given.
        """
        return self._name

    @name.setter
    def name(self, name: str) -> None:
        """
        Replace the label.

        Args:
            name (str): The new label.
        """
        self._name = name
|
1356
|
+
|
1357
|
+
|
1358
|
+
class StateDictManager(DictManager):
    """
    Dictionary of :py:class:`~.State` objects collected while running a program.

    It supports all features of a python dict (via ``DictManager``) and adds
    helpers to read or assign the ``value`` of the stored states.
    ``_check_elem`` asserts that an element is a :py:class:`~.State`
    (presumably invoked by ``DictManager`` on insertion — confirm there).
    """

    __module__ = 'brainstate'

    def assign_values(self, *args: Dict) -> None:
        """
        Assign new values to the stored states.

        Args:
            *args: Dicts mapping a key of this manager to the new value for the
                state stored under that key. They are applied in order, so later
                dicts override earlier ones.
        """
        for mapping in args:
            assert isinstance(mapping, dict), 'Must be an instance of dict.'
            for key, new_value in mapping.items():
                self._set_elem(key, new_value)

    def split_values(self, *filters: type) -> Tuple[Dict, ...]:
        """
        Partition the state values by type.

        Each state goes to the bucket of the first filter type it is an instance
        of; states matching no filter go to an extra final bucket.

        Returns:
            A tuple of ``len(filters) + 1`` ``DictManager`` objects mapping keys to
            state values.
        """
        buckets = tuple(DictManager() for _ in range(len(filters) + 1))
        for key, state in self.items():
            slot = next(
                (i for i, filt in enumerate(filters) if isinstance(state, filt)),
                len(filters),
            )
            buckets[slot][key] = state.value
        return buckets

    def collect_values(self) -> Dict:
        """
        Gather the value of every stored state into a new ``DictManager``.
        """
        collected = DictManager()
        for key, state in self.items():
            collected[key] = state.value
        return collected

    def split(self, first: type, *others: type) -> Tuple['StateDictManager', ...]:
        """
        Split the stored states by type; delegates to ``DictManager.split``.
        """
        return super().split(first, *others)

    def to_dict_values(self) -> Dict:
        """
        Return a plain ``dict`` mapping each key to its state's value.
        """
        return dict((key, state.value) for key, state in self.items())

    def _check_elem(self, elem):
        # Only State instances are acceptable elements.
        assert isinstance(elem, State), f'must be instance of {State}'

    def _set_elem(self, key: Any, value: Any) -> None:
        # Write through the state's ``value`` setter rather than replacing the entry.
        state = self[key]
        state.value = value
|
1413
|
+
|
1414
|
+
|
1415
|
+
class StateTraceStack(Generic[A]):
|
1416
|
+
"""
|
1417
|
+
A stack for tracing and managing states during program execution.
|
1418
|
+
|
1419
|
+
``StateTraceStack`` is used to automatically trace and manage State objects,
|
1420
|
+
keeping track of which states are read from or written to during the
|
1421
|
+
execution of a function or block of code. It provides methods for
|
1422
|
+
recording state accesses, retrieving state values, and managing the
|
1423
|
+
lifecycle of states within a tracing context.
|
1424
|
+
|
1425
|
+
The class is generic over type A, allowing for type-safe usage with
|
1426
|
+
different types of State objects.
|
1427
|
+
|
1428
|
+
The ``StateTraceStack`` is a crucial component in implementing state-based
|
1429
|
+
computations and is particularly useful in scenarios involving automatic
|
1430
|
+
differentiation or other forms of program transformation.
|
1431
|
+
"""
|
1432
|
+
|
1433
|
+
def __init__(
    self,
    new_arg: Callable = None,
    name: Optional[str] = None,
):
    # Filled in by ``__enter__`` with the ' / '-joined names of the active stacks.
    self._stack_level = None
    # Optional callback that ``new_arg`` applies to each newly recorded state.
    self._jax_trace_new_arg: Callable = new_arg
    # Human-readable label of this stack.
    self.name = name
    # Tracked states in first-access order, with a parallel list of flags:
    # False = only read so far, True = written at least once.
    self.states: List[State] = []
    self.been_writen: List[bool] = []
    # id(state) -> position in ``self.states``, to avoid linear scans.
    self._state_id_index = {}
    # ``state._value`` captured when each state was first recorded.
    self._original_state_values = []
|
1445
|
+
|
1446
|
+
def __str__(self) -> str:
    # Prefer the nesting path computed on entry; otherwise fall back to the name.
    label = self._stack_level if self._stack_level is not None else self.name
    return f"{self.__class__.__name__}({'' if label is None else label})"
|
1451
|
+
|
1452
|
+
@property
def original_state_values(self) -> Tuple[PyTree, ...]:
    """
    Snapshots of every tracked state's value, taken when it was first recorded.

    Returns:
        Tuple[PyTree, ...]: One entry per tracked state, in the same order as
        ``self.states``. A new tuple is built on each access, so callers cannot
        mutate the internal list through it.
    """
    snapshot = tuple(self._original_state_values)
    return snapshot
|
1468
|
+
|
1469
|
+
def set_new_arg(self, new_arg: Callable) -> None:
    """
    Install the callback that ``new_arg`` applies to newly recorded states.

    Args:
        new_arg: Called with a state object; its return value becomes the
            state's ``_value``. Pass None to disable the transformation.
    """
    self._jax_trace_new_arg = new_arg
|
1471
|
+
|
1472
|
+
def new_arg(self, state: State) -> None:
    """
    Pass a newly recorded state through the registered callback, if any.

    When ``_jax_trace_new_arg`` is set, it is called with the ``state`` object
    itself, and its return value replaces ``state._value``. When no callback is
    registered, the state is left untouched.

    Args:
        state (State): The state whose stored value may be replaced.

    Returns:
        None: The state is modified in place.

    Note:
        Intended for internal use; the callback is installed through the
        constructor or ``set_new_arg``.
    """
    transform = self._jax_trace_new_arg
    if transform is None:
        return
    state._value = transform(state)  # internal use
|
1493
|
+
|
1494
|
+
def __enter__(self) -> 'StateTraceStack':
    # Make this stack the innermost active one.
    TRACE_CONTEXT.state_stack.append(self)
    # Remember the chain of named stacks (outermost first) for display in __str__.
    named = (st.name for st in TRACE_CONTEXT.state_stack if st.name is not None)
    self._stack_level = ' / '.join(named)
    return self
|
1498
|
+
|
1499
|
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
    # Leave the tracing context by dropping the innermost active stack. Returning
    # None means any in-flight exception keeps propagating.
    active = TRACE_CONTEXT.state_stack
    active.pop()
|
1501
|
+
|
1502
|
+
def read_its_value(self, state: State) -> None:
    """
    Register ``state`` as accessed in this trace (read-only until written).

    The first time a state is seen, it is appended to ``self.states`` with a
    False write flag, its current ``_value`` is saved as the original value, and
    ``new_arg`` is applied to it. Later calls for the same object do nothing.

    Args:
        state (State): The state being read.

    Returns:
        None

    Note:
        Only bookkeeping is done here; the state's value is not returned.
    """
    key = id(state)
    if key in self._state_id_index:
        return  # already tracked
    self._state_id_index[key] = len(self.states)
    self.states.append(state)
    self.been_writen.append(False)
    self._original_state_values.append(state._value)  # internal use
    self.new_arg(state)
|
1528
|
+
|
1529
|
+
def write_its_value(self, state: State) -> None:
    """
    Mark ``state`` as written in this trace.

    A state not yet tracked is first registered through ``read_its_value``;
    then its write flag is set to True.

    Args:
        state (State): The state being written.

    Returns:
        None

    Note:
        Only bookkeeping is done here; the state's value is not modified.
    """
    key = id(state)
    if key not in self._state_id_index:
        self.read_its_value(state)
    self.been_writen[self._state_id_index[key]] = True
|
1552
|
+
|
1553
|
+
def get_state_values(
|
1554
|
+
self,
|
1555
|
+
separate: bool = False,
|
1556
|
+
replace: bool = False
|
1557
|
+
) -> Sequence[PyTree] | Tuple[Sequence[PyTree], Sequence[PyTree]]:
|
1558
|
+
"""
|
1559
|
+
Retrieve the values of all states in the StateTraceStack.
|
1560
|
+
|
1561
|
+
This method returns the values of all states, optionally separating them
|
1562
|
+
into written and read states, and optionally replacing values with None
|
1563
|
+
for states that weren't accessed in a particular way.
|
1564
|
+
|
1565
|
+
Args:
|
1566
|
+
separate (bool, optional): If True, separate the values into written
|
1567
|
+
and read states. If False, return all values in a single sequence.
|
1568
|
+
Defaults to False.
|
1569
|
+
replace (bool, optional): If True and separate is True, replace values
|
1570
|
+
with None for states that weren't written/read. If False, only
|
1571
|
+
include values for states that were written/read. Defaults to False.
|
1572
|
+
|
1573
|
+
Returns:
|
1574
|
+
Sequence[PyTree] | Tuple[Sequence[PyTree], Sequence[PyTree]]:
|
1575
|
+
If separate is False:
|
1576
|
+
A sequence of all state values.
|
1577
|
+
If separate is True:
|
1578
|
+
A tuple containing two sequences:
|
1579
|
+
- The first sequence contains values of written states.
|
1580
|
+
- The second sequence contains values of read states.
|
1581
|
+
If replace is True, these sequences will have None for
|
1582
|
+
states that weren't written/read respectively.
|
1583
|
+
|
1584
|
+
"""
|
1585
|
+
if separate:
|
1586
|
+
if replace:
|
1587
|
+
writes, reads = [], []
|
1588
|
+
for st, been_writen in zip(self.states, self.been_writen):
|
1589
|
+
if been_writen:
|
1590
|
+
writes.append(st.value)
|
1591
|
+
reads.append(None)
|
1592
|
+
else:
|
1593
|
+
reads.append(st.value)
|
1594
|
+
writes.append(None)
|
1595
|
+
return tuple(writes), tuple(reads)
|
1596
|
+
else:
|
1597
|
+
writes, reads = [], []
|
1598
|
+
for st, been_writen in zip(self.states, self.been_writen):
|
1599
|
+
if been_writen:
|
1600
|
+
writes.append(st.value)
|
1601
|
+
else:
|
1602
|
+
reads.append(st.value)
|
1603
|
+
return tuple(writes), tuple(reads)
|
1604
|
+
else:
|
1605
|
+
return tuple([st.value for st in self.states])
|
1606
|
+
|
1607
|
+
def recovery_original_values(self) -> None:
    """
    Put every tracked state back to the value saved when it was first recorded.

    Each state's ``restore_value`` is called with its saved snapshot, in the
    order the states were recorded.

    Returns:
        None: The states are modified in place.
    """
    for tracked, snapshot in zip(self.states, self._original_state_values):
        tracked.restore_value(snapshot)  # internal use
|
1625
|
+
|
1626
|
+
def merge(self, *traces) -> 'StateTraceStack':
    """
    Fold other ``StateTraceStack`` instances into this one.

    States not yet tracked here are appended as read-only, carrying over the
    other trace's snapshot of their original value. ``new_arg`` is not applied to
    them. Any state written in another trace is then marked as written here via
    ``write_its_value``.

    Args:
        *traces: The ``StateTraceStack`` instances to merge, processed in order.

    Returns:
        StateTraceStack: ``self``, modified in place.
    """
    for other in traces:
        entries = zip(other.states, other.been_writen, other._original_state_values)
        for state, written, original in entries:
            key = id(state)
            if key not in self._state_id_index:
                self._state_id_index[key] = len(self.states)
                self._original_state_values.append(original)
                self.states.append(state)
                self.been_writen.append(False)
            if written:
                self.write_its_value(state)
    return self
|
1656
|
+
|
1657
|
+
def get_read_states(self, replace_writen: bool = False) -> Tuple[State, ...]:
    """
    Return the tracked states that were only read (never written).

    Args:
        replace_writen (bool, optional): When True, keep one slot per tracked
            state and put None where the state was written. When False, omit
            written states. Defaults to False.

    Returns:
        Tuple[State, ...]: The read-only states, in recording order.
    """
    pairs = zip(self.states, self.been_writen)
    if replace_writen:
        return tuple(None if written else st for st, written in pairs)
    return tuple(st for st, written in pairs if not written)
|
1681
|
+
|
1682
|
+
def get_read_state_values(self, replace_writen: bool = False) -> Tuple[PyTree, ...]:
    """
    Return the current values of the tracked states that were never written.

    ``.value`` is accessed only for read-only states.

    Args:
        replace_writen (bool, optional): When True, keep one slot per tracked
            state and put None where the state was written. When False, omit
            written states. Defaults to False.

    Returns:
        Tuple[PyTree, ...]: The read-only state values, in recording order.
    """
    pairs = zip(self.states, self.been_writen)
    if replace_writen:
        return tuple(None if written else st.value for st, written in pairs)
    return tuple(st.value for st, written in pairs if not written)
|
1708
|
+
|
1709
|
+
def get_write_states(self, replace_read: bool = False) -> Tuple[State, ...]:
    """
    Return the tracked states that were written at least once.

    Args:
        replace_read (bool, optional): When True, keep one slot per tracked
            state and put None where the state was only read. When False, omit
            read-only states. Defaults to False.

    Returns:
        Tuple[State, ...]: The written states, in recording order.
    """
    pairs = zip(self.states, self.been_writen)
    if replace_read:
        return tuple(st if written else None for st, written in pairs)
    return tuple(st for st, written in pairs if written)
|
1733
|
+
|
1734
|
+
def get_write_state_values(self, replace_read: bool = False) -> Tuple[PyTree, ...]:
    """
    Return the current values of the tracked states that were written.

    ``.value`` is accessed only for written states.

    Args:
        replace_read (bool, optional): When True, keep one slot per tracked
            state and put None where the state was only read. When False, omit
            read-only states. Defaults to False.

    Returns:
        Tuple[PyTree, ...]: The written state values, in recording order.
    """
    pairs = zip(self.states, self.been_writen)
    if replace_read:
        return tuple(st.value if written else None for st, written in pairs)
    return tuple(st.value for st, written in pairs if written)
|
1759
|
+
|
1760
|
+
def __add__(self, other: 'StateTraceStack') -> 'StateTraceStack':
    """
    Combine two trace stacks with the ``+`` operator.

    A fresh ``StateTraceStack`` is created, and both operands are merged into
    it through its ``merge`` method. Whatever ``merge`` returns is the result.
    """
    combined = StateTraceStack()
    return combined.merge(self, other)
|
1765
|
+
|
1766
|
+
def state_subset(self, state_type: type) -> List:
    """
    Select the tracked states that are instances of ``state_type``.

    Args:
        state_type (type): Class used in an ``isinstance`` check against each
            tracked state, typically ``State`` or one of its subclasses.

    Returns:
        List[State]: The matching states, in the order they appear in
        ``self.states``.

    Example:
        >>> stack = StateTraceStack()
        >>> # Assume stack has been populated with various state types
        >>> short_term_states = stack.state_subset(ShortTermState)
    """
    return list(filter(lambda candidate: isinstance(candidate, state_type), self.states))
|
1787
|
+
|
1788
|
+
def assign_state_vals(self, state_vals: Sequence[PyTree]) -> None:
    """
    Assign new values to the states tracked by this ``StateTraceStack``.

    States that were written during tracing receive the new value through
    ordinary assignment (``st.value = val``). States that were only read are
    updated through their ``restore_value`` method instead.

    Args:
        state_vals (Sequence[PyTree]): New values, one per tracked state, in
            the same order as ``self.states``.

    Raises:
        ValueError: If ``state_vals`` does not have exactly one entry per
            tracked state, or if the write flags (``been_writen``) are out of
            sync with ``self.states``.

    Returns:
        None
    """
    n_states = len(self.states)
    if len(state_vals) != n_states:
        raise ValueError(
            'The length of the state values must be equal to the states. '
            f'But got {len(state_vals)} and {n_states}'
        )
    # ``zip`` below would silently truncate on a desynchronised flag list and
    # leave trailing states unassigned, so reject it explicitly (mirrors the
    # check performed by ``assign_state_vals_v2``).
    if len(self.been_writen) != n_states:
        raise ValueError(
            'The length of the write flags (been_writen) must be equal to the states. '
            f'But got {len(self.been_writen)} and {n_states}'
        )
    for st, written, val in zip(self.states, self.been_writen, state_vals):
        if written:
            st.value = val
        else:
            st.restore_value(val)
|
1823
|
+
|
1824
|
+
def assign_state_vals_v2(
    self: StateTraceStack,
    read_state_vals: Sequence[PyTree],
    write_state_vals: Sequence[PyTree],
) -> None:
    """
    Write back state values to their corresponding states after computation.

    For every tracked state, the ``been_writen`` flag selects the source of
    its new value. A written state is assigned the matching entry of
    ``write_state_vals`` (``st.value = ...``). A read-only state is restored
    from the matching entry of ``read_state_vals`` (``st.restore_value(...)``).

    Parameters
    ----------
    read_state_vals : sequence of PyTree
        The original state values that were read at the beginning, one per
        tracked state, in the order of ``self.states``.
    write_state_vals : sequence of PyTree
        The new state values that were written during computation, one per
        tracked state, in the order of ``self.states``.

    Raises
    ------
    ValueError
        If ``been_writen``, ``read_state_vals`` or ``write_state_vals`` does
        not have the same length as ``self.states``.

    Examples
    --------
    Basic usage in a compilation context:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> # Create states
        >>> state1 = brainstate.State(jnp.array([1.0, 2.0]))
        >>> state2 = brainstate.State(jnp.array([3.0, 4.0]))
        >>>
        >>> def f(x):
        ...     state1.value += x  # This state will be written
        ...     return state1.value + state2.value  # state2 is only read
        >>>
        >>> # When ``f`` is traced, ``state1`` is flagged as written and
        >>> # ``state2`` as read-only. ``assign_state_vals_v2`` then sets
        >>> # ``state1.value`` from ``write_state_vals`` and restores
        >>> # ``state2`` from ``read_state_vals``.
    """
    n_states = len(self.states)
    if len(self.been_writen) != n_states:
        # This check concerns the write flags, not the supplied values.
        raise ValueError(
            'The length of the write flags (been_writen) must be equal to the states. '
            f'But got {len(self.been_writen)} and {n_states}'
        )
    if len(read_state_vals) != n_states:
        raise ValueError(
            'The length of the read state values must be equal to the states. '
            f'But got {len(read_state_vals)} and {n_states}'
        )
    if len(write_state_vals) != n_states:
        raise ValueError(
            'The length of the write state values must be equal to the states. '
            f'But got {len(write_state_vals)} and {n_states}'
        )
    for st, write, val_r, val_w in zip(
        self.states, self.been_writen, read_state_vals, write_state_vals
    ):
        if write:
            st.value = val_w
        else:
            st.restore_value(val_r)
|
1876
|
+
|
1877
|
+
|
1878
|
+
class TreefyState(Generic[A], PrettyObject):
    """
    A pytree-friendly snapshot of a ``State``.

    The concrete ``State`` class is kept in ``type`` and the payload in
    ``value``. Any additional keyword metadata is stored directly as instance
    attributes. When flattened as a JAX pytree (see ``_state_ref_flatten``),
    ``value`` is the only child, while ``type`` and the metadata form the
    static auxiliary data.
    """

    def __init__(
        self,
        type: type[State[Any]],
        value: A,
        **metadata
    ):
        self.type = type
        self.value = value
        # Every extra keyword argument becomes a plain instance attribute.
        self.__dict__.update(metadata)

    if TYPE_CHECKING:
        def __getattr__(self, name: str) -> None: ...

        def __setattr__(self, name: str, value: Any) -> None: ...

        def __delattr__(self, name: str) -> None: ...

    def __pretty_repr_item__(self, k, v):
        # Internal bookkeeping fields are hidden from the pretty representation.
        if k in ('_level', '_source_info', '_been_writen'):
            return None
        # Private storage names are displayed under their public spelling.
        if k == '_value':
            return 'value', v
        if k == '_name':
            if v is None:
                return None
            return 'name', v
        return k, v

    @property
    def name(self) -> Optional[str]:
        """
        The name of the state, read from the ``_name`` metadata attribute.
        """
        return self._name

    @name.setter
    def name(self, name: str) -> None:
        """
        Store ``name`` in the ``_name`` metadata attribute.
        """
        self._name = name

    def replace(self, value: B) -> TreefyState[B]:
        """
        Return a new ``TreefyState`` with the same type and metadata but the
        given ``value``.
        """
        meta = self.get_metadata()
        return TreefyState(self.type, value, **meta)

    def to_state(self) -> State[A]:
        """
        Rebuild a ``State`` instance of class ``self.type`` from this snapshot.
        """
        # ``object.__new__`` skips ``__init__`` on purpose: the state's
        # initialisation logic must not run a second time. Its attributes are
        # instead filled in directly from the stored metadata.
        attrs = self.get_metadata()
        new_state = object.__new__(self.type)
        for reserved in ('_value', '_level'):
            attrs.pop(reserved, None)
        attrs['_value'] = self.value
        attrs['_level'] = _get_trace_stack_level()
        vars(new_state).update(attrs)
        return new_state

    def copy(self: TreefyState[A]) -> TreefyState[A]:
        """
        Return a copy obtained by rebuilding this pytree with identity-mapped
        leaves; the leaf objects themselves are shared, not duplicated.
        """
        return jax.tree.map(lambda leaf: leaf, self)

    def get_metadata(self) -> Dict[str, Any]:
        """
        Return a new dict of all instance attributes except ``type`` and
        ``value``.
        """
        meta = dict(vars(self))
        # ``pop`` without a default raises KeyError exactly like ``del``.
        meta.pop('type')
        meta.pop('value')
        return meta
|
1957
|
+
|
1958
|
+
|
1959
|
+
def _state_ref_flatten(x: TreefyState[Any], *, with_keys: bool):
|
1960
|
+
metadata = tuple(x.get_metadata().items())
|
1961
|
+
if with_keys:
|
1962
|
+
node = (jax.tree_util.GetAttrKey('value'), x.value)
|
1963
|
+
else:
|
1964
|
+
node = x.value
|
1965
|
+
return (node,), (x.type, metadata)
|
1966
|
+
|
1967
|
+
|
1968
|
+
def _state_ref_unflatten(
    static: Tuple[type[State[A]], Tuple[Tuple[str, Any], ...]],
    children: Tuple[A],
) -> TreefyState[A]:
    """
    Rebuild a ``TreefyState`` from the output of ``_state_ref_flatten``.

    ``static`` is ``(state_class, metadata_items)``, and the first element of
    ``children`` is the state's value.
    """
    state_cls = static[0]
    value = children[0]
    metadata = dict(static[1])
    return TreefyState(type=state_cls, value=value, **metadata)
|
1973
|
+
|
1974
|
+
|
1975
|
+
# Register ``TreefyState`` as a JAX pytree node. The keyed flattener labels the
# single child with ``GetAttrKey('value')`` for key-path APIs. The plain
# flattener is supplied as ``flatten_func`` for ordinary (key-less)
# flattening. Both variants are rebuilt by ``_state_ref_unflatten``.
jax.tree_util.register_pytree_with_keys(
    TreefyState,
    flatten_with_keys=partial(_state_ref_flatten, with_keys=True),  # type: ignore
    unflatten_func=_state_ref_unflatten,  # type: ignore
    flatten_func=partial(_state_ref_flatten, with_keys=False),  # type: ignore
)
|
1981
|
+
|
1982
|
+
|
1983
|
+
class StateCatcher(PrettyObject):
    """
    The catcher to catch and manage new states.

    Each state is recorded at most once, with de-duplication by ``id``. Every
    state accepted by :meth:`append` has its ``tag`` attribute set to this
    catcher's ``state_tag``. States matched by ``state_to_exclude`` are
    ignored.

    Attributes:
        state_tag (str or None): The tag assigned to caught states.
        state_to_exclude (Filter): Filter called as ``state_to_exclude((), state)``;
            states for which it returns a truthy value are not caught.
        state_ids (set): ``id()`` of every caught state, used for uniqueness.
        states (list): The caught State objects, in insertion order.
    """

    def __init__(
        self,
        state_tag: Optional[str],
        state_to_exclude: Filter = Nothing()
    ):
        """
        Initialize a new StateCatcher instance.

        Args:
            state_tag (str, optional): The tag to be assigned to caught states.
            state_to_exclude (Filter, optional): A filter to exclude states from
                being caught. ``None`` is treated as ``Nothing()``, which
                excludes no states.
        """
        if state_to_exclude is None:
            state_to_exclude = Nothing()
        self.state_to_exclude = state_to_exclude
        self.state_tag = state_tag
        self.state_ids = set()
        self.states = []

    def get_state_values(self) -> List[PyTree]:
        """
        Get the values of the caught states.

        Returns:
            list: A new list holding ``state.value`` for each caught state.
        """
        return [state.value for state in self.states]

    def get_states(self) -> List[State]:
        """
        Get the caught states.

        Returns:
            list: The internal list of caught states (not a copy).
        """
        return self.states

    def append(self, state: State):
        """
        Add a new state to the catcher if it hasn't been added before.

        Excluded states are skipped. Otherwise, a state not seen before is
        recorded by ``id``, appended to ``states``, and tagged with
        ``state_tag``.

        Args:
            state (State): The State object to be added.
        """
        if self.state_to_exclude((), state):
            return
        if id(state) not in self.state_ids:
            self.state_ids.add(id(state))
            self.states.append(state)
            state.tag = self.state_tag

    def __iter__(self):
        """
        Allow iteration over the caught states.

        Returns:
            iterator: An iterator over the list of caught states.
        """
        return iter(self.states)

    def __len__(self):
        """
        Return the number of caught states.

        Returns:
            int: The number of caught states.
        """
        return len(self.states)

    def __getitem__(self, index):
        """
        Get a state by index.

        Args:
            index (int): The index of the state to retrieve.

        Returns:
            State: The state at the specified index.
        """
        return self.states[index]

    def clear(self):
        """
        Clear all caught states.
        """
        self.state_ids.clear()
        self.states.clear()

    def get_by_tag(self, tag: str):
        """
        Get all states with a specific tag.

        Args:
            tag (str): The tag to filter by.

        Returns:
            list: A list of states whose ``tag`` compares equal to ``tag``.
        """
        return [state for state in self.states if state.tag == tag]

    def remove(self, state: State):
        """
        Remove a specific state from the catcher.

        Removal is by identity, matching how membership is tracked in
        ``state_ids``. A state that was never caught is ignored.

        Args:
            state (State): The state to remove.
        """
        if id(state) in self.state_ids:
            self.state_ids.remove(id(state))
            # ``list.remove`` matches with ``==``. A State type with custom
            # equality could then drop a different entry or fail on the
            # comparison, so locate the entry by identity instead.
            for i, caught in enumerate(self.states):
                if caught is state:
                    del self.states[i]
                    break

    def __contains__(self, state: State):
        """
        Check if a state is in the catcher.

        Args:
            state (State): The state to check for.

        Returns:
            bool: True if the state is in the catcher, False otherwise.
        """
        return id(state) in self.state_ids
|
2121
|
+
|
2122
|
+
|
2123
|
+
@contextlib.contextmanager
def catch_new_states(
    state_tag: Optional[str] = None,
    state_to_exclude: Filter = Nothing()
) -> Generator[StateCatcher, None, None]:
    """
    A context manager that catches and tracks new states created within its scope.

    This function creates a new ``StateCatcher`` and pushes it onto
    ``TRACE_CONTEXT.new_state_catcher``. On exit, it pops the catcher again,
    even if the body raised. This allows for tracking and managing new states
    created within the context.

    Args:
        state_tag (str, optional): A string tag to associate with the caught states.
            Defaults to None.
        state_to_exclude (Filter, optional): A filter object to specify which states
            should be excluded from catching. Defaults to Nothing(), which excludes no states.

    Yields:
        StateCatcher: A catcher that can be used to access and manage the
            newly created states within the context.

    Example::

        with catch_new_states("my_tag") as catcher:
            # Create new states here
            # They will be caught and tagged with "my_tag"
        # Access caught states through catcher object
    """
    catcher = StateCatcher(state_tag=state_tag, state_to_exclude=state_to_exclude)
    # Register before entering ``try``. If construction or registration failed
    # inside the ``try``, the ``finally`` clause would pop a catcher belonging
    # to an enclosing context, or raise IndexError on an empty list and mask
    # the original error.
    TRACE_CONTEXT.new_state_catcher.append(catcher)
    try:
        yield catcher
    finally:
        TRACE_CONTEXT.new_state_catcher.pop()
|