brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -58
- brainstate/_compatible_import.py +340 -148
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +45 -55
- brainstate/_state.py +1652 -1605
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -563
- brainstate/environ_test.py +1223 -62
- brainstate/graph/__init__.py +22 -29
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1433 -365
- brainstate/mixin_test.py +1017 -77
- brainstate/nn/__init__.py +137 -135
- brainstate/nn/_activations.py +1100 -808
- brainstate/nn/_activations_test.py +354 -331
- brainstate/nn/_collective_ops.py +633 -514
- brainstate/nn/_collective_ops_test.py +774 -43
- brainstate/nn/_common.py +226 -178
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +2010 -501
- brainstate/nn/_conv_test.py +849 -238
- brainstate/nn/_delay.py +575 -588
- brainstate/nn/_delay_test.py +243 -238
- brainstate/nn/_dropout.py +618 -426
- brainstate/nn/_dropout_test.py +477 -100
- brainstate/nn/_dynamics.py +1267 -1343
- brainstate/nn/_dynamics_test.py +67 -78
- brainstate/nn/_elementwise.py +1298 -1119
- brainstate/nn/_elementwise_test.py +830 -169
- brainstate/nn/_embedding.py +408 -58
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
- brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
- brainstate/nn/_exp_euler.py +254 -92
- brainstate/nn/_exp_euler_test.py +377 -35
- brainstate/nn/_linear.py +744 -424
- brainstate/nn/_linear_test.py +475 -107
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +384 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -975
- brainstate/nn/_normalizations_test.py +699 -73
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +2239 -1177
- brainstate/nn/_poolings_test.py +953 -217
- brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +216 -89
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +809 -553
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
- brainstate/random/__init__.py +270 -24
- brainstate/random/_rand_funs.py +3938 -3616
- brainstate/random/_rand_funs_test.py +640 -567
- brainstate/random/_rand_seed.py +675 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1409
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
- brainstate/{augment → transform}/_autograd.py +1025 -778
- brainstate/{augment → transform}/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +220 -220
- brainstate/{compile → transform}/_error_if.py +94 -92
- brainstate/{compile → transform}/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +38 -38
- brainstate/{compile → transform}/_jit.py +399 -346
- brainstate/{compile → transform}/_jit_test.py +143 -143
- brainstate/{compile → transform}/_loop_collect_return.py +675 -536
- brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
- brainstate/{compile → transform}/_loop_no_collection.py +283 -184
- brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +255 -202
- brainstate/{augment → transform}/_random.py +171 -151
- brainstate/{compile → transform}/_unvmap.py +256 -159
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +837 -304
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +27 -50
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +945 -469
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +910 -523
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,1738 +1,1624 @@
|
|
1
|
-
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
-
# The credit should go to the Flax authors.
|
3
|
-
#
|
4
|
-
# Copyright 2024 The Flax Authors.
|
5
|
-
#
|
6
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
-
# you may not use this file except in compliance with the License.
|
8
|
-
# You may obtain a copy of the License at
|
9
|
-
#
|
10
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
-
#
|
12
|
-
# Unless required by applicable law or agreed to in writing, software
|
13
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
-
# See the License for the specific language governing permissions and
|
16
|
-
# limitations under the License.
|
17
|
-
|
18
|
-
from __future__ import annotations
|
19
|
-
|
20
|
-
import dataclasses
|
21
|
-
from typing import (
|
22
|
-
Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
|
23
|
-
Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional
|
24
|
-
)
|
25
|
-
|
26
|
-
import jax
|
27
|
-
import numpy as np
|
28
|
-
from typing_extensions import TypeGuard, Unpack
|
29
|
-
|
30
|
-
from brainstate._state import State, TreefyState
|
31
|
-
from brainstate._utils import set_module_as
|
32
|
-
from brainstate.typing import PathParts, Filter, Predicate, Key
|
33
|
-
from brainstate.util.
|
34
|
-
from brainstate.util.
|
35
|
-
from brainstate.util.
|
36
|
-
from brainstate.util.struct import FrozenDict
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
'
|
44
|
-
|
45
|
-
|
46
|
-
'
|
47
|
-
|
48
|
-
#
|
49
|
-
'
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
def
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
def
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
def
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
#
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
)
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
823
|
-
|
824
|
-
|
825
|
-
|
826
|
-
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
|
846
|
-
|
847
|
-
|
848
|
-
|
849
|
-
|
850
|
-
|
851
|
-
|
852
|
-
|
853
|
-
|
854
|
-
|
855
|
-
|
856
|
-
|
857
|
-
|
858
|
-
|
859
|
-
|
860
|
-
|
861
|
-
|
862
|
-
|
863
|
-
|
864
|
-
|
865
|
-
|
866
|
-
|
867
|
-
|
868
|
-
|
869
|
-
|
870
|
-
|
871
|
-
|
872
|
-
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
877
|
-
|
878
|
-
|
879
|
-
|
880
|
-
|
881
|
-
|
882
|
-
|
883
|
-
|
884
|
-
|
885
|
-
|
886
|
-
|
887
|
-
|
888
|
-
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
897
|
-
|
898
|
-
|
899
|
-
|
900
|
-
|
901
|
-
|
902
|
-
|
903
|
-
|
904
|
-
|
905
|
-
|
906
|
-
|
907
|
-
|
908
|
-
|
909
|
-
|
910
|
-
|
911
|
-
|
912
|
-
|
913
|
-
|
914
|
-
|
915
|
-
|
916
|
-
|
917
|
-
#
|
918
|
-
|
919
|
-
|
920
|
-
|
921
|
-
|
922
|
-
|
923
|
-
|
924
|
-
|
925
|
-
|
926
|
-
|
927
|
-
|
928
|
-
|
929
|
-
|
930
|
-
|
931
|
-
|
932
|
-
|
933
|
-
|
934
|
-
|
935
|
-
|
936
|
-
|
937
|
-
|
938
|
-
|
939
|
-
|
940
|
-
|
941
|
-
|
942
|
-
|
943
|
-
|
944
|
-
|
945
|
-
|
946
|
-
|
947
|
-
|
948
|
-
|
949
|
-
|
950
|
-
|
951
|
-
|
952
|
-
|
953
|
-
|
954
|
-
|
955
|
-
|
956
|
-
|
957
|
-
|
958
|
-
|
959
|
-
|
960
|
-
|
961
|
-
|
962
|
-
|
963
|
-
|
964
|
-
|
965
|
-
|
966
|
-
|
967
|
-
|
968
|
-
|
969
|
-
|
970
|
-
|
971
|
-
|
972
|
-
|
973
|
-
|
974
|
-
|
975
|
-
|
976
|
-
|
977
|
-
|
978
|
-
|
979
|
-
|
980
|
-
|
981
|
-
|
982
|
-
|
983
|
-
|
984
|
-
|
985
|
-
|
986
|
-
|
987
|
-
|
988
|
-
|
989
|
-
|
990
|
-
|
991
|
-
|
992
|
-
|
993
|
-
|
994
|
-
|
995
|
-
|
996
|
-
|
997
|
-
|
998
|
-
|
999
|
-
|
1000
|
-
|
1001
|
-
|
1002
|
-
|
1003
|
-
|
1004
|
-
|
1005
|
-
|
1006
|
-
|
1007
|
-
states
|
1008
|
-
|
1009
|
-
|
1010
|
-
|
1011
|
-
|
1012
|
-
|
1013
|
-
|
1014
|
-
|
1015
|
-
|
1016
|
-
|
1017
|
-
|
1018
|
-
|
1019
|
-
|
1020
|
-
|
1021
|
-
|
1022
|
-
|
1023
|
-
|
1024
|
-
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1028
|
-
|
1029
|
-
|
1030
|
-
|
1031
|
-
|
1032
|
-
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
1038
|
-
|
1039
|
-
|
1040
|
-
|
1041
|
-
|
1042
|
-
|
1043
|
-
|
1044
|
-
|
1045
|
-
|
1046
|
-
|
1047
|
-
|
1048
|
-
|
1049
|
-
|
1050
|
-
|
1051
|
-
|
1052
|
-
|
1053
|
-
|
1054
|
-
|
1055
|
-
|
1056
|
-
|
1057
|
-
|
1058
|
-
|
1059
|
-
|
1060
|
-
|
1061
|
-
|
1062
|
-
|
1063
|
-
|
1064
|
-
|
1065
|
-
|
1066
|
-
|
1067
|
-
|
1068
|
-
|
1069
|
-
|
1070
|
-
|
1071
|
-
|
1072
|
-
|
1073
|
-
|
1074
|
-
|
1075
|
-
|
1076
|
-
|
1077
|
-
|
1078
|
-
|
1079
|
-
|
1080
|
-
|
1081
|
-
|
1082
|
-
|
1083
|
-
|
1084
|
-
|
1085
|
-
|
1086
|
-
|
1087
|
-
|
1088
|
-
|
1089
|
-
|
1090
|
-
|
1091
|
-
|
1092
|
-
|
1093
|
-
|
1094
|
-
|
1095
|
-
|
1096
|
-
|
1097
|
-
|
1098
|
-
|
1099
|
-
|
1100
|
-
|
1101
|
-
|
1102
|
-
"""
|
1103
|
-
|
1104
|
-
|
1105
|
-
return
|
1106
|
-
|
1107
|
-
|
1108
|
-
|
1109
|
-
|
1110
|
-
|
1111
|
-
|
1112
|
-
|
1113
|
-
|
1114
|
-
|
1115
|
-
|
1116
|
-
|
1117
|
-
|
1118
|
-
|
1119
|
-
|
1120
|
-
|
1121
|
-
|
1122
|
-
|
1123
|
-
|
1124
|
-
|
1125
|
-
|
1126
|
-
|
1127
|
-
|
1128
|
-
|
1129
|
-
|
1130
|
-
|
1131
|
-
|
1132
|
-
|
1133
|
-
|
1134
|
-
|
1135
|
-
|
1136
|
-
|
1137
|
-
|
1138
|
-
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
1149
|
-
|
1150
|
-
|
1151
|
-
|
1152
|
-
|
1153
|
-
|
1154
|
-
|
1155
|
-
|
1156
|
-
|
1157
|
-
|
1158
|
-
|
1159
|
-
|
1160
|
-
|
1161
|
-
|
1162
|
-
|
1163
|
-
|
1164
|
-
|
1165
|
-
|
1166
|
-
|
1167
|
-
|
1168
|
-
|
1169
|
-
|
1170
|
-
|
1171
|
-
|
1172
|
-
|
1173
|
-
|
1174
|
-
|
1175
|
-
|
1176
|
-
|
1177
|
-
|
1178
|
-
|
1179
|
-
|
1180
|
-
|
1181
|
-
|
1182
|
-
|
1183
|
-
|
1184
|
-
|
1185
|
-
|
1186
|
-
|
1187
|
-
|
1188
|
-
|
1189
|
-
|
1190
|
-
|
1191
|
-
|
1192
|
-
|
1193
|
-
|
1194
|
-
|
1195
|
-
|
1196
|
-
|
1197
|
-
|
1198
|
-
|
1199
|
-
|
1200
|
-
|
1201
|
-
|
1202
|
-
|
1203
|
-
|
1204
|
-
|
1205
|
-
|
1206
|
-
|
1207
|
-
|
1208
|
-
|
1209
|
-
|
1210
|
-
|
1211
|
-
|
1212
|
-
|
1213
|
-
|
1214
|
-
|
1215
|
-
|
1216
|
-
|
1217
|
-
|
1218
|
-
|
1219
|
-
|
1220
|
-
|
1221
|
-
|
1222
|
-
|
1223
|
-
|
1224
|
-
|
1225
|
-
|
1226
|
-
|
1227
|
-
|
1228
|
-
|
1229
|
-
|
1230
|
-
|
1231
|
-
|
1232
|
-
|
1233
|
-
|
1234
|
-
|
1235
|
-
|
1236
|
-
|
1237
|
-
|
1238
|
-
|
1239
|
-
|
1240
|
-
|
1241
|
-
|
1242
|
-
|
1243
|
-
|
1244
|
-
|
1245
|
-
|
1246
|
-
|
1247
|
-
|
1248
|
-
|
1249
|
-
|
1250
|
-
|
1251
|
-
|
1252
|
-
|
1253
|
-
|
1254
|
-
|
1255
|
-
|
1256
|
-
|
1257
|
-
)
|
1258
|
-
|
1259
|
-
|
1260
|
-
"""
|
1261
|
-
|
1262
|
-
if
|
1263
|
-
|
1264
|
-
else:
|
1265
|
-
|
1266
|
-
|
1267
|
-
|
1268
|
-
|
1269
|
-
|
1270
|
-
|
1271
|
-
|
1272
|
-
|
1273
|
-
|
1274
|
-
|
1275
|
-
|
1276
|
-
|
1277
|
-
|
1278
|
-
|
1279
|
-
|
1280
|
-
|
1281
|
-
|
1282
|
-
|
1283
|
-
|
1284
|
-
|
1285
|
-
|
1286
|
-
|
1287
|
-
|
1288
|
-
|
1289
|
-
|
1290
|
-
|
1291
|
-
|
1292
|
-
|
1293
|
-
|
1294
|
-
|
1295
|
-
|
1296
|
-
|
1297
|
-
|
1298
|
-
|
1299
|
-
|
1300
|
-
|
1301
|
-
|
1302
|
-
|
1303
|
-
|
1304
|
-
|
1305
|
-
|
1306
|
-
|
1307
|
-
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1313
|
-
|
1314
|
-
|
1315
|
-
|
1316
|
-
|
1317
|
-
|
1318
|
-
|
1319
|
-
|
1320
|
-
|
1321
|
-
|
1322
|
-
|
1323
|
-
|
1324
|
-
|
1325
|
-
|
1326
|
-
|
1327
|
-
|
1328
|
-
|
1329
|
-
|
1330
|
-
|
1331
|
-
|
1332
|
-
|
1333
|
-
|
1334
|
-
|
1335
|
-
|
1336
|
-
|
1337
|
-
|
1338
|
-
|
1339
|
-
|
1340
|
-
|
1341
|
-
|
1342
|
-
|
1343
|
-
|
1344
|
-
|
1345
|
-
|
1346
|
-
|
1347
|
-
|
1348
|
-
|
1349
|
-
|
1350
|
-
|
1351
|
-
|
1352
|
-
|
1353
|
-
|
1354
|
-
|
1355
|
-
|
1356
|
-
|
1357
|
-
|
1358
|
-
|
1359
|
-
|
1360
|
-
|
1361
|
-
|
1362
|
-
|
1363
|
-
|
1364
|
-
|
1365
|
-
|
1366
|
-
|
1367
|
-
|
1368
|
-
|
1369
|
-
|
1370
|
-
|
1371
|
-
|
1372
|
-
|
1373
|
-
|
1374
|
-
|
1375
|
-
|
1376
|
-
|
1377
|
-
|
1378
|
-
|
1379
|
-
|
1380
|
-
|
1381
|
-
|
1382
|
-
|
1383
|
-
|
1384
|
-
|
1385
|
-
|
1386
|
-
|
1387
|
-
|
1388
|
-
|
1389
|
-
|
1390
|
-
|
1391
|
-
|
1392
|
-
|
1393
|
-
|
1394
|
-
|
1395
|
-
|
1396
|
-
|
1397
|
-
|
1398
|
-
|
1399
|
-
|
1400
|
-
|
1401
|
-
|
1402
|
-
"""
|
1403
|
-
|
1404
|
-
|
1405
|
-
|
1406
|
-
|
1407
|
-
|
1408
|
-
|
1409
|
-
|
1410
|
-
|
1411
|
-
|
1412
|
-
|
1413
|
-
|
1414
|
-
|
1415
|
-
|
1416
|
-
|
1417
|
-
|
1418
|
-
|
1419
|
-
|
1420
|
-
|
1421
|
-
|
1422
|
-
|
1423
|
-
|
1424
|
-
|
1425
|
-
|
1426
|
-
|
1427
|
-
|
1428
|
-
|
1429
|
-
|
1430
|
-
|
1431
|
-
|
1432
|
-
|
1433
|
-
|
1434
|
-
|
1435
|
-
|
1436
|
-
|
1437
|
-
|
1438
|
-
|
1439
|
-
|
1440
|
-
|
1441
|
-
|
1442
|
-
|
1443
|
-
|
1444
|
-
|
1445
|
-
|
1446
|
-
def
|
1447
|
-
|
1448
|
-
|
1449
|
-
|
1450
|
-
|
1451
|
-
|
1452
|
-
|
1453
|
-
|
1454
|
-
|
1455
|
-
|
1456
|
-
|
1457
|
-
|
1458
|
-
|
1459
|
-
|
1460
|
-
|
1461
|
-
|
1462
|
-
|
1463
|
-
|
1464
|
-
|
1465
|
-
|
1466
|
-
|
1467
|
-
|
1468
|
-
|
1469
|
-
|
1470
|
-
|
1471
|
-
|
1472
|
-
|
1473
|
-
|
1474
|
-
|
1475
|
-
|
1476
|
-
|
1477
|
-
|
1478
|
-
|
1479
|
-
|
1480
|
-
|
1481
|
-
|
1482
|
-
|
1483
|
-
|
1484
|
-
|
1485
|
-
|
1486
|
-
|
1487
|
-
|
1488
|
-
|
1489
|
-
|
1490
|
-
|
1491
|
-
|
1492
|
-
|
1493
|
-
|
1494
|
-
|
1495
|
-
|
1496
|
-
|
1497
|
-
|
1498
|
-
|
1499
|
-
|
1500
|
-
|
1501
|
-
|
1502
|
-
|
1503
|
-
|
1504
|
-
|
1505
|
-
|
1506
|
-
|
1507
|
-
|
1508
|
-
|
1509
|
-
|
1510
|
-
|
1511
|
-
|
1512
|
-
|
1513
|
-
|
1514
|
-
|
1515
|
-
|
1516
|
-
|
1517
|
-
|
1518
|
-
|
1519
|
-
|
1520
|
-
|
1521
|
-
|
1522
|
-
|
1523
|
-
|
1524
|
-
|
1525
|
-
|
1526
|
-
|
1527
|
-
|
1528
|
-
|
1529
|
-
|
1530
|
-
|
1531
|
-
|
1532
|
-
|
1533
|
-
|
1534
|
-
|
1535
|
-
|
1536
|
-
|
1537
|
-
|
1538
|
-
|
1539
|
-
|
1540
|
-
|
1541
|
-
|
1542
|
-
|
1543
|
-
|
1544
|
-
|
1545
|
-
|
1546
|
-
|
1547
|
-
|
1548
|
-
|
1549
|
-
|
1550
|
-
|
1551
|
-
|
1552
|
-
|
1553
|
-
|
1554
|
-
|
1555
|
-
|
1556
|
-
|
1557
|
-
|
1558
|
-
|
1559
|
-
|
1560
|
-
|
1561
|
-
|
1562
|
-
|
1563
|
-
|
1564
|
-
|
1565
|
-
|
1566
|
-
|
1567
|
-
|
1568
|
-
node
|
1569
|
-
|
1570
|
-
|
1571
|
-
|
1572
|
-
|
1573
|
-
|
1574
|
-
|
1575
|
-
|
1576
|
-
|
1577
|
-
|
1578
|
-
|
1579
|
-
|
1580
|
-
|
1581
|
-
|
1582
|
-
|
1583
|
-
|
1584
|
-
|
1585
|
-
|
1586
|
-
|
1587
|
-
|
1588
|
-
|
1589
|
-
|
1590
|
-
|
1591
|
-
|
1592
|
-
|
1593
|
-
|
1594
|
-
|
1595
|
-
|
1596
|
-
|
1597
|
-
|
1598
|
-
|
1599
|
-
|
1600
|
-
|
1601
|
-
|
1602
|
-
|
1603
|
-
|
1604
|
-
|
1605
|
-
|
1606
|
-
|
1607
|
-
|
1608
|
-
|
1609
|
-
|
1610
|
-
) ->
|
1611
|
-
|
1612
|
-
|
1613
|
-
|
1614
|
-
|
1615
|
-
|
1616
|
-
|
1617
|
-
|
1618
|
-
|
1619
|
-
|
1620
|
-
|
1621
|
-
|
1622
|
-
|
1623
|
-
|
1624
|
-
|
1625
|
-
... self.a = brainstate.nn.Linear(1, 2)
|
1626
|
-
... self.b = brainstate.nn.Linear(2, 3)
|
1627
|
-
... self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
|
1628
|
-
... self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
|
1629
|
-
... self.b.a = brainstate.nn.LIF(2)
|
1630
|
-
...
|
1631
|
-
>>> model = Model()
|
1632
|
-
...
|
1633
|
-
>>> for path, node in brainstate.graph.iter_node([model, model]):
|
1634
|
-
... print(path, node.__class__.__name__)
|
1635
|
-
...
|
1636
|
-
(0, 'a') Linear
|
1637
|
-
(0, 'b', 'a') LIF
|
1638
|
-
(0, 'b') Linear
|
1639
|
-
(0, 'c', 0) Linear
|
1640
|
-
(0, 'c', 1) Linear
|
1641
|
-
(0, 'd', 'x') Linear
|
1642
|
-
(0, 'd', 'y') Linear
|
1643
|
-
(0,) Model
|
1644
|
-
|
1645
|
-
Parameters
|
1646
|
-
----------
|
1647
|
-
node: Node
|
1648
|
-
The node to iterate over.
|
1649
|
-
allowed_hierarchy: tuple of int
|
1650
|
-
The allowed hierarchy.
|
1651
|
-
|
1652
|
-
"""
|
1653
|
-
|
1654
|
-
def _iter_graph_node(
|
1655
|
-
node_: Any,
|
1656
|
-
visited_: set[int],
|
1657
|
-
path_parts_: PathParts,
|
1658
|
-
level_: int,
|
1659
|
-
) -> Iterator[tuple[PathParts, Any]]:
|
1660
|
-
if level_ > allowed_hierarchy[1]:
|
1661
|
-
return
|
1662
|
-
|
1663
|
-
if _is_node(node_):
|
1664
|
-
if id(node_) in visited_:
|
1665
|
-
return
|
1666
|
-
|
1667
|
-
visited_.add(id(node_))
|
1668
|
-
node_dict = _get_node_impl(node_).node_dict(node_)
|
1669
|
-
for key, value in node_dict.items():
|
1670
|
-
yield from _iter_graph_node(value, visited_, (*path_parts_, key),
|
1671
|
-
level_ + 1 if _is_graph_node(value) else level_)
|
1672
|
-
|
1673
|
-
if _is_graph_node(node_) and level_ >= allowed_hierarchy[0]:
|
1674
|
-
yield path_parts_, node_
|
1675
|
-
|
1676
|
-
visited: set[int] = set()
|
1677
|
-
path_parts: PathParts = ()
|
1678
|
-
level: int = 0
|
1679
|
-
yield from _iter_graph_node(node, visited, path_parts, level)
|
1680
|
-
|
1681
|
-
|
1682
|
-
# --------------------------------------------------------
|
1683
|
-
# Graph operations: end
|
1684
|
-
# --------------------------------------------------------
|
1685
|
-
|
1686
|
-
|
1687
|
-
@dataclasses.dataclass(frozen=True)
class Static(Generic[A]):
    """Childless pytree wrapper whose payload is treated as static.

    The class is registered with ``jax.tree_util.register_static`` right
    after its definition, so an instance contributes no leaves. Because the
    wrapped object is handled as static, it must implement both ``__eq__``
    and ``__hash__``.
    """

    # The wrapped object; must be hashable and comparable for equality.
    value: A


# Register ``Static`` as a leafless (static) pytree node type.
jax.tree_util.register_static(Static)
|
1697
|
-
|
1698
|
-
|
1699
|
-
# ---------------------------------------------------------
|
1700
|
-
# Pytree
|
1701
|
-
# ---------------------------------------------------------
|
1702
|
-
|
1703
|
-
class PytreeType:
    """Marker type standing in for "any generic JAX pytree".

    It has no attributes or behavior of its own; it is passed as the
    ``type`` of ``PYTREE_NODE_IMPL``.
    """
|
1705
|
-
|
1706
|
-
|
1707
|
-
def _key_path_to_key(key: Any) -> Key:
|
1708
|
-
if isinstance(key, jax.tree_util.SequenceKey):
|
1709
|
-
return key.idx
|
1710
|
-
elif isinstance(
|
1711
|
-
key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey)
|
1712
|
-
):
|
1713
|
-
if not isinstance(key.key, Key):
|
1714
|
-
raise ValueError(
|
1715
|
-
f'Invalid key: {key.key}. May be due to its type not being hashable or comparable.'
|
1716
|
-
)
|
1717
|
-
return key.key
|
1718
|
-
elif isinstance(key, jax.tree_util.GetAttrKey):
|
1719
|
-
return key.name
|
1720
|
-
else:
|
1721
|
-
return str(key)
|
1722
|
-
|
1723
|
-
|
1724
|
-
def _flatten_pytree(pytree: Any):
    """Flatten ``pytree`` exactly one level deep.

    Every object other than ``pytree`` itself is reported as a leaf, so only
    the direct children are returned. Each child's one-element key path is
    converted with ``_key_path_to_key``.

    Parameters
    ----------
    pytree : Any
        The pytree whose direct children should be extracted.

    Returns
    -------
    tuple
        ``(nodes, treedef)``, where ``nodes`` is a tuple of ``(key, child)``
        pairs and ``treedef`` is the structure returned by
        ``jax.tree_util.tree_flatten_with_path``.
    """

    def _stop_below_root(candidate):
        # Only the root is descended into; everything beneath it is a leaf.
        return candidate is not pytree

    path_child_pairs, treedef = jax.tree_util.tree_flatten_with_path(
        pytree, is_leaf=_stop_below_root
    )
    children = []
    for key_path, child in path_child_pairs:
        # A direct child has a key path of length one.
        children.append((_key_path_to_key(key_path[0]), child))
    return tuple(children), treedef
|
1728
|
-
|
1729
|
-
|
1730
|
-
def _unflatten_pytree(
|
1731
|
-
nodes: tuple[tuple[Key, Any], ...],
|
1732
|
-
treedef: jax.tree_util.PyTreeDef
|
1733
|
-
):
|
1734
|
-
pytree = treedef.unflatten(value for _, value in nodes)
|
1735
|
-
return pytree
|
1736
|
-
|
1737
|
-
|
1738
|
-
# Node implementation used for generic pytrees: a one-level flatten and the
# matching rebuild, keyed by the ``PytreeType`` marker.
PYTREE_NODE_IMPL = PyTreeNodeImpl(
    type=PytreeType,
    flatten=_flatten_pytree,
    unflatten=_unflatten_pytree,
)
|
1
|
+
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
+
# The credit should go to the Flax authors.
|
3
|
+
#
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
|
+
#
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
+
# you may not use this file except in compliance with the License.
|
8
|
+
# You may obtain a copy of the License at
|
9
|
+
#
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
+
#
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
+
# See the License for the specific language governing permissions and
|
16
|
+
# limitations under the License.
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
20
|
+
import dataclasses
|
21
|
+
from typing import (
|
22
|
+
Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
|
23
|
+
Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional
|
24
|
+
)
|
25
|
+
|
26
|
+
import jax
|
27
|
+
import numpy as np
|
28
|
+
from typing_extensions import TypeGuard, Unpack
|
29
|
+
|
30
|
+
from brainstate._state import State, TreefyState
|
31
|
+
from brainstate._utils import set_module_as
|
32
|
+
from brainstate.typing import PathParts, Filter, Predicate, Key
|
33
|
+
from brainstate.util._pretty_pytree import NestedDict, FlattedDict, PrettyDict
|
34
|
+
from brainstate.util._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
|
35
|
+
from brainstate.util.filter import to_predicate
|
36
|
+
from brainstate.util.struct import FrozenDict
|
37
|
+
|
38
|
+
__all__ = [
    'register_graph_node_type',

    # state management in the given graph or node
    'pop_states', 'nodes', 'states', 'treefy_states', 'update_states',

    # graph node operations
    'flatten', 'unflatten', 'treefy_split', 'treefy_merge',
    'iter_leaf', 'iter_node', 'clone', 'graphdef',

    # others
    'RefMap', 'GraphDef', 'NodeDef', 'NodeRef',
]

# Largest value representable by a signed 32-bit integer.
MAX_INT = np.iinfo(np.int32).max

# Generic type variables used by the containers in this module.
A = TypeVar('A')
B = TypeVar('B')
C = TypeVar('C')
F = TypeVar('F', bound=Callable)

# Type variables restricted to hashable key/value types (see HashableMapping).
HA = TypeVar('HA', bound=Hashable)
HB = TypeVar('HB', bound=Hashable)

# Aliases describing graph structure.
Index = int
Names = Sequence[int]
Node = TypeVar('Node')
Leaf = TypeVar('Leaf')
AuxData = TypeVar('AuxData')

# Leaf kinds: a tree-friendly state reference vs. a live State object.
StateLeaf = TreefyState[Any]
NodeLeaf = State[Any]
GraphStateMapping = NestedDict
|
84
|
+
|
85
|
+
|
86
|
+
# --------------------------------------------------------
|
87
|
+
|
88
|
+
def _is_state_leaf(x: Any) -> TypeGuard[StateLeaf]:
    """Return True when ``x`` is a :class:`TreefyState` instance."""
    is_treefy = isinstance(x, TreefyState)
    return is_treefy
|
90
|
+
|
91
|
+
|
92
|
+
def _is_node_leaf(x: Any) -> TypeGuard[NodeLeaf]:
    """Return True when ``x`` is a live :class:`State` instance."""
    is_state = isinstance(x, State)
    return is_state
|
94
|
+
|
95
|
+
|
96
|
+
class RefMap(MutableMapping[A, B], MappingReprMixin[A, B]):
    """
    A mutable mapping keyed by object identity rather than by ``hash``/``==``.

    Two keys are considered the same only if they are the very same object
    (``id(k1) == id(k2)``). This makes it possible to track which objects have
    already been visited, even when those objects are unhashable or define a
    custom equality.

    Parameters
    ----------
    mapping : Mapping[A, B] or Iterable[Tuple[A, B]], optional
        Initial contents, given as a mapping or as ``(key, value)`` pairs.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> obj1 = object()
        >>> obj2 = object()
        >>> ref_map = brainstate.graph.RefMap()
        >>> ref_map[obj1] = 'value1'
        >>> ref_map[obj2] = 'value2'
        >>> print(obj1 in ref_map)
        True
        >>> print(ref_map[obj1])
        value1

    """
    __module__ = 'brainstate.graph'

    def __init__(self, mapping: Union[Mapping[A, B], Iterable[Tuple[A, B]]] = ()) -> None:
        # Storage is ``id(key) -> (key, value)``; keeping the key object itself
        # both lets ``__iter__`` yield it back and keeps it referenced.
        self._mapping: Dict[int, Tuple[A, B]] = {}
        self.update(mapping)

    def __getitem__(self, key: A) -> B:
        _, stored_value = self._mapping[id(key)]
        return stored_value

    def __contains__(self, key: Any) -> bool:
        return id(key) in self._mapping

    def __setitem__(self, key: A, value: B) -> None:
        self._mapping[id(key)] = (key, value)

    def __delitem__(self, key: A) -> None:
        del self._mapping[id(key)]

    def __iter__(self) -> Iterator[A]:
        return (pair[0] for pair in self._mapping.values())

    def __len__(self) -> int:
        return len(self._mapping)

    def __str__(self) -> str:
        return repr(self)
|
150
|
+
|
151
|
+
|
152
|
+
@dataclasses.dataclass(frozen=True)
class NodeImplBase(Generic[Node, Leaf, AuxData]):
    """
    Common part of a node implementation: the handled type and its flattener.

    Attributes
    ----------
    type : type
        The node type this implementation handles.
    flatten : Callable
        Maps a node to ``(sequence of (key, leaf) pairs, aux_data)``.
    """
    type: type
    flatten: Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]]

    def node_dict(self, node: Node) -> dict[Key, Leaf]:
        """Return the node's children as a ``{key: leaf}`` dict (aux data is dropped)."""
        key_leaf_pairs, _aux_data = self.flatten(node)
        return dict(key_leaf_pairs)
|
160
|
+
|
161
|
+
|
162
|
+
@dataclasses.dataclass(frozen=True)
class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
    """
    Implementation for mutable graph nodes that can be built in place.

    Attributes
    ----------
    set_key : Callable[[Node, Key, Leaf], None]
        Store a child under a key.
    pop_key : Callable[[Node, Key], Leaf]
        Remove and return the child stored under a key.
    create_empty : Callable[[AuxData], Node]
        Create an empty node from the aux data produced by ``flatten``.
    clear : Callable[[Node], None]
        Remove all children from a node.
    """
    set_key: Callable[[Node, Key, Leaf], None]
    pop_key: Callable[[Node, Key], Leaf]
    create_empty: Callable[[AuxData], Node]
    clear: Callable[[Node], None]

    def init(self, node: Node, items: Tuple[Tuple[Key, Leaf], ...]) -> None:
        """Populate ``node`` by calling ``set_key`` for each ``(key, value)`` pair, in order."""
        setter = self.set_key
        for child_key, child_value in items:
            setter(node, child_key, child_value)
|
172
|
+
|
173
|
+
|
174
|
+
@dataclasses.dataclass(frozen=True)
class PyTreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
    """
    Implementation for immutable (pytree-like) nodes.

    Such nodes cannot be created empty and filled later; instead they are
    rebuilt in one step by ``unflatten`` from their ``(key, leaf)`` pairs and
    aux data.
    """
    unflatten: Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node]
|
177
|
+
|
178
|
+
|
179
|
+
# Either kind of node implementation.
NodeImpl = Union[
    GraphNodeImpl[Node, Leaf, AuxData],
    PyTreeNodeImpl[Node, Leaf, AuxData],
]

# --------------------------------------------------------
# Graph Node implementation: start
# --------------------------------------------------------

# Registry of graph node implementations, keyed by the exact node type.
_node_impl_for_type: dict[type, NodeImpl] = {}
|
186
|
+
|
187
|
+
|
188
|
+
def register_graph_node_type(
    type: type,
    flatten: Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]],
    set_key: Callable[[Node, Key, Leaf], None],
    pop_key: Callable[[Node, Key], Leaf],
    create_empty: Callable[[AuxData], Node],
    clear: Callable[[Node], None],
):
    """
    Register ``type`` as a graph node type.

    The callbacks are bundled into a :class:`GraphNodeImpl` and stored in the
    module registry under ``type``. Registering the same type again replaces
    the previous implementation.

    Parameters
    ----------
    type : type
        The node type to register.
    flatten : Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]]
        Flattens a node into ``(key, value)`` pairs plus aux data.
    set_key : Callable[[Node, Key, Leaf], None]
        Sets a key on a node.
    pop_key : Callable[[Node, Key], Leaf]
        Pops a key from a node.
    create_empty : Callable[[AuxData], Node]
        Creates an empty node from aux data.
    clear : Callable[[Node], None]
        Clears a node.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> # Custom node type implementation
        >>> class CustomNode:
        ...     def __init__(self):
        ...         self.data = {}
        ...
        >>> def flatten_custom(node):
        ...     return list(node.data.items()), None
        ...
        >>> def set_key_custom(node, key, value):
        ...     node.data[key] = value
        ...
        >>> def pop_key_custom(node, key):
        ...     return node.data.pop(key)
        ...
        >>> def create_empty_custom(metadata):
        ...     return CustomNode()
        ...
        >>> def clear_custom(node):
        ...     node.data.clear()
        ...
        >>> # Register the custom node type
        >>> brainstate.graph.register_graph_node_type(
        ...     CustomNode,
        ...     flatten_custom,
        ...     set_key_custom,
        ...     pop_key_custom,
        ...     create_empty_custom,
        ...     clear_custom
        ... )

    """
    impl = GraphNodeImpl(
        type=type,
        flatten=flatten,
        set_key=set_key,
        pop_key=pop_key,
        create_empty=create_empty,
        clear=clear,
    )
    _node_impl_for_type[type] = impl
|
258
|
+
|
259
|
+
|
260
|
+
# --------------------------------------------------------
|
261
|
+
# Graph node implementation: end
|
262
|
+
# --------------------------------------------------------
|
263
|
+
|
264
|
+
|
265
|
+
def _is_node(x: Any) -> bool:
    """Return True if ``x`` is a registered graph node or a non-leaf JAX pytree."""
    if _is_graph_node(x):
        return True
    return _is_pytree_node(x)
|
267
|
+
|
268
|
+
|
269
|
+
def _is_pytree_node(x: Any) -> bool:
|
270
|
+
return not jax.tree_util.all_leaves((x,))
|
271
|
+
|
272
|
+
|
273
|
+
def _is_graph_node(x: Any) -> bool:
    """Return True if the exact type of ``x`` has a registered graph node implementation."""
    node_type = type(x)
    return node_type in _node_impl_for_type
|
275
|
+
|
276
|
+
|
277
|
+
def _is_node_type(x: Type[Any]) -> bool:
    """Return True if ``x`` is a registered graph node type or the pytree marker type."""
    return x is PytreeType or x in _node_impl_for_type
|
279
|
+
|
280
|
+
|
281
|
+
def _get_node_impl(x: Any) -> NodeImpl:
    """
    Look up the node implementation for the instance ``x``.

    Registered graph node types (looked up by exact type) take precedence;
    otherwise any JAX pytree container gets the shared pytree implementation.

    Raises
    ------
    ValueError
        If ``x`` is a :class:`State`, or is neither registered nor a pytree.
    """
    if isinstance(x, State):
        raise ValueError(f'State is not a node: {x}')

    node_type = type(x)
    if node_type in _node_impl_for_type:
        return _node_impl_for_type[node_type]
    if _is_pytree_node(x):
        return PYTREE_NODE_IMPL
    raise ValueError(f'Unknown node type: {x}')
|
293
|
+
|
294
|
+
|
295
|
+
def get_node_impl_for_type(x: Type[Any]) -> NodeImpl:
    """Return the implementation for node type ``x`` (raises ``KeyError`` if unregistered)."""
    return PYTREE_NODE_IMPL if x is PytreeType else _node_impl_for_type[x]
|
299
|
+
|
300
|
+
|
301
|
+
class HashableMapping(Mapping[HA, HB], Hashable):
    """
    A read-only, hashable mapping holding a shallow copy of its input.

    Parameters
    ----------
    mapping : Mapping[HA, HB] or Iterable[tuple[HA, HB]]
        The entries to store. A copy is taken, so later changes to the input
        do not affect this object.

    Notes
    -----
    The hash is computed from a ``frozenset`` of the items. It is therefore
    independent of insertion order, which matches ``==`` (dict equality is
    order-independent). It also does not require the keys to be mutually
    orderable, so mappings mixing e.g. ``int`` and ``str`` keys can be
    hashed. Sorting the items would raise ``TypeError`` for such keys. All
    values must still be hashable.
    """

    def __init__(self, mapping: Union[Mapping[HA, HB], Iterable[tuple[HA, HB]]]) -> None:
        self._mapping = dict(mapping)

    def __contains__(self, key: object) -> bool:
        return key in self._mapping

    def __getitem__(self, key: HA) -> HB:
        return self._mapping[key]

    def __iter__(self) -> Iterator[HA]:
        return iter(self._mapping)

    def __len__(self) -> int:
        return len(self._mapping)

    def __hash__(self) -> int:
        # Order-independent and comparison-free, unlike hashing sorted items.
        return hash(frozenset(self._mapping.items()))

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, HashableMapping) and self._mapping == other._mapping

    def __repr__(self) -> str:
        return repr(self._mapping)
|
325
|
+
|
326
|
+
|
327
|
+
class GraphDef(Generic[Node]):
    """
    Base class for descriptions of the graph structure of a :class:`Node`.

    Every graph definition records the ``type`` of the node it describes and
    the node's ``index`` within the graph. There are two concrete subclasses:

    - :class:`NodeDef`: the full structural description of a node (its
      attributes, subgraphs, static fields and leaves).
    - :class:`NodeRef`: a lightweight reference, by index, to a node or State
      in the graph.

    Attributes
    ----------
    type : Type[Node]
        The type of the node.
    index : int
        The index of the node in the graph.
    """
    type: Type[Node]
    index: int
|
350
|
+
|
351
|
+
|
352
|
+
@dataclasses.dataclass(frozen=True, repr=False)
class NodeDef(GraphDef[Node], PrettyRepr):
    """
    Immutable structural description of one node in a flattened graph.

    Attributes
    ----------
    type : Type[Node]
        Type of the node.
    index : int
        Index of the node in the graph (``-1`` for pytree nodes).
    attributes : Tuple[Key, ...]
        All attribute keys of the node, in flatten order.
    subgraphs : HashableMapping[Key, NodeDef[Any] | NodeRef[Any]]
        Definitions of child nodes, by key.
    static_fields : HashableMapping
        Attribute values stored verbatim in the definition, by key.
    leaves : HashableMapping[Key, NodeRef[Any] | None]
        State leaves by key: a ``NodeRef`` for ``State`` objects, ``None`` for
        ``TreefyState`` leaves.
    metadata : Hashable
        Aux data returned by the node's ``flatten`` function.
    index_mapping : FrozenDict[Index, Index] | None
        Optional index mapping for node references.
    """

    type: Type[Node]  # type of the node
    index: int  # index of the node in the graph
    attributes: Tuple[Key, ...]  # attributes for the node
    subgraphs: HashableMapping[Key, NodeDef[Any] | NodeRef[Any]]
    static_fields: HashableMapping
    leaves: HashableMapping[Key, NodeRef[Any] | None]
    metadata: Hashable
    index_mapping: FrozenDict[Index, Index] | None

    @classmethod
    def create(
        cls,
        type: Type[Node],
        index: int,
        attributes: tuple[Key, ...],
        subgraphs: Iterable[tuple[Key, NodeDef[Any] | NodeRef[Any]]],
        static_fields: Iterable[tuple],
        leaves: Iterable[tuple[Key, NodeRef[Any] | None]],
        metadata: Hashable,
        index_mapping: Mapping[Index, Index] | None,
    ):
        """Build a ``NodeDef``, freezing the iterable/mapping arguments into hashable containers."""
        return cls(
            type=type,
            index=index,
            attributes=attributes,
            subgraphs=HashableMapping(subgraphs),
            static_fields=HashableMapping(static_fields),
            leaves=HashableMapping(leaves),
            metadata=metadata,
            index_mapping=None if index_mapping is None else FrozenDict(index_mapping),
        )

    def __pretty_repr__(self):
        yield PrettyType(type=type(self))

        yield PrettyAttr('type', self.type.__name__)
        yield PrettyAttr('index', self.index)
        yield PrettyAttr('attributes', self.attributes)
        yield PrettyAttr('subgraphs', PrettyMapping(self.subgraphs))
        yield PrettyAttr('static_fields', PrettyMapping(self.static_fields))
        yield PrettyAttr('leaves', PrettyMapping(self.leaves))
        yield PrettyAttr('metadata', self.metadata)
        mapping = self.index_mapping
        yield PrettyAttr('index_mapping', None if mapping is None else PrettyMapping(mapping))


# A NodeDef carries no array data; register it as a static (hashable) pytree.
jax.tree_util.register_static(NodeDef)
|
424
|
+
|
425
|
+
|
426
|
+
@dataclasses.dataclass(frozen=True, repr=False)
class NodeRef(GraphDef[Node], PrettyRepr):
    """
    Reference, by index, to a node or :class:`State` in a flattened graph.

    Attributes
    ----------
    type : Type[Node]
        The type of the referenced object.
    index : int
        The index of the referenced object in the graph.
    """
    type: Type[Node]
    index: int

    def __pretty_repr__(self):
        yield PrettyType(type=type(self))
        type_name = self.type.__name__
        yield PrettyAttr('type', type_name)
        yield PrettyAttr('index', self.index)


# A NodeRef carries no array data; register it as a static (hashable) pytree.
jax.tree_util.register_static(NodeRef)
|
451
|
+
|
452
|
+
|
453
|
+
# --------------------------------------------------------
|
454
|
+
# Graph operations: start
|
455
|
+
# --------------------------------------------------------
|
456
|
+
|
457
|
+
|
458
|
+
def _graph_flatten(
    path: PathParts,
    ref_index: RefMap[Any, Index],
    flatted_state_mapping: Dict[PathParts, StateLeaf],
    node: Any,
    treefy_state: bool = False,
) -> Union[NodeDef[Any], NodeRef[Any]]:
    """
    Recursively flatten ``node`` into a graph definition.

    Parameters
    ----------
    path : PathParts
        Path from the root to ``node``.
    ref_index : RefMap[Any, Index]
        Graph nodes and States already visited, mapped to their indices.
        Updated in place; graph nodes and States share one index space.
    flatted_state_mapping : Dict[PathParts, StateLeaf]
        Output mapping from full paths to state leaves, filled in place.
    node : Any
        The node to flatten.
    treefy_state : bool, optional
        If True, States are stored as ``state.to_state_ref()``; otherwise the
        State object itself is stored. Default is False.

    Returns
    -------
    NodeDef or NodeRef
        A ``NodeRef`` if ``node`` was already visited, otherwise its ``NodeDef``.

    Raises
    ------
    RuntimeError
        If ``node`` is not a graph node or pytree.
    """
    if not _is_node(node):
        raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')

    # A node seen before is emitted as a reference. This also terminates
    # recursion on reference cycles.
    if node in ref_index:
        return NodeRef(type(node), ref_index[node])

    # GraphNodeImpl: mutable graph node; PyTreeNodeImpl: immutable pytree.
    node_impl = _get_node_impl(node)

    # Only graph nodes receive an index; pytree nodes are anonymous (-1) and
    # are not recorded in ``ref_index``.
    if isinstance(node_impl, GraphNodeImpl):
        index = len(ref_index)
        ref_index[node] = index
    else:
        index = -1

    subgraphs: list[tuple[Key, Union[NodeDef[Any], NodeRef[Any]]]] = []
    static_fields: list[tuple] = []
    leaves: list[tuple[Key, Union[NodeRef[Any], None]]] = []

    values, metadata = node_impl.flatten(node)
    for key, value in values:
        child_path = (*path, key)
        if _is_node(value):
            child_def = _graph_flatten(child_path, ref_index, flatted_state_mapping, value, treefy_state)
            subgraphs.append((key, child_def))
        elif isinstance(value, State):
            # Each State is recorded in the state mapping only the first time
            # it is met; later occurrences become references to its index.
            if value not in ref_index:
                flatted_state_mapping[child_path] = value.to_state_ref() if treefy_state else value
                ref_index[value] = len(ref_index)
            leaves.append((key, NodeRef(type(value), ref_index[value])))
        elif _is_state_leaf(value):
            # A TreefyState is stored as-is and has no index.
            flatted_state_mapping[child_path] = value
            leaves.append((key, None))
        else:
            # Anything else is kept verbatim in the definition.
            static_fields.append((key, value))

    return NodeDef.create(
        type=node_impl.type,
        index=index,
        attributes=tuple(key for key, _ in values),
        subgraphs=subgraphs,
        static_fields=static_fields,
        leaves=leaves,
        metadata=metadata,
        index_mapping=None,
    )
|
555
|
+
|
556
|
+
|
557
|
+
@set_module_as('brainstate.graph')
def flatten(
    node: Any,
    /,
    ref_index: Optional[RefMap[Any, Index]] = None,
    treefy_state: bool = True,
) -> Tuple[GraphDef[Any], NestedDict]:
    """
    Flattens a graph node into a (graph_def, state_mapping) pair.

    Parameters
    ----------
    node : Node
        A graph node.
    ref_index : RefMap[Any, Index], optional
        A mapping from already-visited nodes and States to their indices,
        defaults to None. If not provided, a new empty :class:`RefMap` is
        created. Passing the same ``RefMap`` to several calls lets a sequence
        of graph nodes that share references be flattened into one index space.
    treefy_state : bool, optional
        Controls how each :class:`State` found in the graph is stored in the
        returned mapping. If True, the stored value is ``state.to_state_ref()``;
        if False, the ``State`` object itself is stored. ``TreefyState`` leaves
        are stored unchanged either way. The returned mapping is always a
        ``NestedDict``. Default is True.

    Returns
    -------
    tuple[GraphDef, NestedDict]
        The graph definition and the nested state mapping (keyed by attribute
        path). The graph definition is a ``NodeRef`` instead of a ``NodeDef``
        when ``node`` is already present in ``ref_index``.

    Raises
    ------
    AssertionError
        If ``ref_index`` is not a :class:`RefMap`.
    RuntimeError
        If ``node`` is neither a registered graph node nor a JAX pytree.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> node = brainstate.graph.Node()
        >>> graph_def, state_mapping = brainstate.graph.flatten(node)
        >>> print(graph_def)
        >>> print(state_mapping)

    """
    ref_index = RefMap() if ref_index is None else ref_index
    assert isinstance(ref_index, RefMap), f"ref_index must be a RefMap. But we got: {ref_index}"
    flatted_state_mapping: dict[PathParts, StateLeaf] = {}
    graph_def = _graph_flatten((), ref_index, flatted_state_mapping, node, treefy_state)
    return graph_def, NestedDict.from_flat(flatted_state_mapping)
|
600
|
+
|
601
|
+
|
602
|
+
def _get_children(
    graph_def: NodeDef[Any],
    state_mapping: Mapping,
    index_ref: dict[Index, Any],
    index_ref_cache: Optional[dict[Index, Any]],
) -> dict[Key, Union[StateLeaf, Any]]:
    """
    Build the attribute values (children) of the node described by ``graph_def``.

    Parameters
    ----------
    graph_def : NodeDef
        Structure of the node being rebuilt.
    state_mapping : Mapping
        State for this node keyed by attribute name. Subgraph entries are
        dicts; leaf entries are ``State``/``TreefyState`` objects.
    index_ref : dict[Index, Any]
        Nodes and States already created during this unflatten, by index.
        Newly created States are added to it.
    index_ref_cache : dict[Index, Any] or None
        Existing nodes/States that may be reused and updated in place.

    Returns
    -------
    dict[Key, Any]
        One entry per key of ``graph_def.attributes``, in that order.

    Raises
    ------
    ValueError
        If ``state_mapping`` contains keys that are not attributes, a required
        leaf is missing, or a value has the wrong kind.
    TypeError
        If the state given for a subgraph is not a dict.
    RuntimeError
        If an attribute is neither a static field, a subgraph nor a leaf.
    """
    # NOTE: we could allow adding new StateLeafs here.
    # All state keys must be present in the graph definition (the object attributes).
    if unknown_keys := set(state_mapping) - set(graph_def.attributes):
        raise ValueError(f'Unknown keys: {unknown_keys}')

    # For every attribute there are 6 possible cases:
    # - (2) the key is either present in the state or not;
    # - (3) the key is a subgraph, a leaf, or a static attribute.
    children: dict[Key, Union[StateLeaf, Any]] = {}
    for key in graph_def.attributes:
        if key in state_mapping:
            children[key] = _get_child_from_state(
                graph_def, key, state_mapping[key], index_ref, index_ref_cache
            )
        else:
            children[key] = _get_child_without_state(graph_def, key, index_ref, index_ref_cache)
    return children


def _get_child_without_state(
    graph_def: NodeDef[Any],
    key: Key,
    index_ref: dict[Index, Any],
    index_ref_cache: Optional[dict[Index, Any]],
) -> Any:
    """Resolve attribute ``key`` when ``state_mapping`` has no entry for it."""
    # Missing keys are allowed for static fields and subgraphs, which enables
    # partial state restoration.
    if key in graph_def.static_fields:
        return graph_def.static_fields[key]

    if key in graph_def.subgraphs:
        subgraphdef = graph_def.subgraphs[key]
        if isinstance(subgraphdef, NodeRef):
            # The subgraph was already built; take it from the cache.
            return index_ref[subgraphdef.index]
        # Build the subgraph from an empty state: it either has no state at all
        # or only reaches state through references to already-created nodes.
        return _graph_unflatten(subgraphdef, {}, index_ref, index_ref_cache)

    if key in graph_def.leaves:
        noderef = graph_def.leaves[key]
        if noderef is not None and noderef.index in index_ref:
            # The State was already created; take it from the cache.
            return index_ref[noderef.index]
        raise ValueError(
            f'Expected key {key!r} in state while building node of type '
            f'{graph_def.type.__name__}.'
        )

    raise RuntimeError(f'Unknown static field: {key!r}')


def _get_child_from_state(
    graph_def: NodeDef[Any],
    key: Key,
    value: Any,
    index_ref: dict[Index, Any],
    index_ref_cache: Optional[dict[Index, Any]],
) -> Any:
    """Resolve attribute ``key`` from its entry ``value`` in ``state_mapping``."""
    if isinstance(value, PrettyDict):
        value = dict(value)

    if key in graph_def.static_fields:
        raise ValueError(f'Got state for static field {key!r}, this is not supported.')

    if key in graph_def.subgraphs:
        if isinstance(value, (TreefyState, State)):
            raise ValueError(
                f'Expected value of type {graph_def.subgraphs[key]} '
                f'for {key!r}, but got {value!r}'
            )
        if not isinstance(value, dict):
            raise TypeError(f'Expected a dict for {key!r}, but got {type(value)}.')
        subgraphdef = graph_def.subgraphs[key]
        if isinstance(subgraphdef, NodeRef):
            return index_ref[subgraphdef.index]
        return _graph_unflatten(subgraphdef, value, index_ref, index_ref_cache)

    if key in graph_def.leaves:
        if not isinstance(value, (TreefyState, State)):
            raise ValueError(f'Expected a leaf for {key!r}, but got {value!r}')
        # From here on ``value`` is a TreefyState or a State.

        noderef = graph_def.leaves[key]
        if noderef is None:
            # The leaf was originally a TreefyState (no index). A TreefyState,
            # presumably produced by editing the NestedDict, is turned into a
            # State; a State is used as-is.
            return value.to_state() if isinstance(value, TreefyState) else value

        if noderef.index in index_ref:
            # Reuse a State already created during this unflatten.
            return index_ref[noderef.index]

        variable = _reuse_or_create_state(key, value, noderef.index, index_ref_cache)
        index_ref[noderef.index] = variable
        return variable

    raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')


def _reuse_or_create_state(
    key: Key,
    value: Any,
    index: Index,
    index_ref_cache: Optional[dict[Index, Any]],
) -> Any:
    """Return a State for ``index``: update the cached one in place if present, else build one from ``value``."""
    if index_ref_cache is not None and index in index_ref_cache:
        variable = index_ref_cache[index]
        if not isinstance(variable, State):
            raise ValueError(f'Expected a State type for {key!r}, but got {type(variable)}.')
        if isinstance(value, TreefyState):
            variable.update_from_ref(value)
        # NOTE(review): ``_been_writen`` is a private State attribute; by its
        # name it presumably flags a written value — confirm in brainstate._state.
        elif value._been_writen:
            variable.value = value.value
        else:
            variable.restore_value(value.value)
        return variable

    return value.to_state() if isinstance(value, TreefyState) else value
|
737
|
+
|
738
|
+
|
739
|
+
def _graph_unflatten(
    graph_def: Union[NodeDef[Any], NodeRef[Any]],
    state_mapping: Mapping[Key, Union[StateLeaf, Mapping]],
    index_ref: dict[Index, Any],
    index_ref_cache: Optional[dict[Index, Any]],
) -> Any:
    """
    Recursive helper for graph unflatten.

    Parameters
    ----------
    graph_def : NodeDef or NodeRef
        Definition of the node to build, or a reference to a node that has
        already been built.
    state_mapping : Mapping
        State mapping from attribute names to States or nested subgraph mappings.
    index_ref : dict[Index, Any]
        Nodes created so far, by index. A node found here is not built again.
    index_ref_cache : dict[Index, Any] or None
        Existing nodes that may be reused. A reused node is emptied with
        ``GraphNodeImpl.clear`` and then refilled, so existing graph nodes are
        mutated to take the new content/topology given by the definition.

    Returns
    -------
    Any
        The node instance.
    """
    # A reference means the node was already created: return it from the cache.
    if isinstance(graph_def, NodeRef):
        return index_ref[graph_def.index]
    assert isinstance(graph_def, NodeDef), f"graph_def must be a NodeDef. But we got: {graph_def}"

    if not _is_node_type(graph_def.type):
        raise RuntimeError(f'Unsupported type: {graph_def.type}, this is a bug.')
    if graph_def.index in index_ref:
        raise RuntimeError(f'GraphDef index {graph_def.index} already used.')

    node_impl = get_node_impl_for_type(graph_def.type)

    if not isinstance(node_impl, GraphNodeImpl):
        # Pytree nodes cannot be created empty, so they cannot reference
        # themselves: build the children first, then the node from them.
        children = _get_children(graph_def, state_mapping, index_ref, index_ref_cache)
        return node_impl.unflatten(tuple(children.items()), graph_def.metadata)

    # Graph nodes: obtain an empty node and register it *before* building the
    # children, so reference cycles resolve to this node instead of recursing.
    reuse_existing = index_ref_cache is not None and graph_def.index in index_ref_cache
    if reuse_existing:
        node = index_ref_cache[graph_def.index]
        if type(node) != graph_def.type:
            raise ValueError(f'Expected a node of type {graph_def.type} for index '
                             f'{graph_def.index}, but got a node of type {type(node)}.')
        node_impl.clear(node)
    else:
        node = node_impl.create_empty(graph_def.metadata)

    index_ref[graph_def.index] = node
    children = _get_children(graph_def, state_mapping, index_ref, index_ref_cache)
    node_impl.init(node, tuple(children.items()))
    return node
|
815
|
+
|
816
|
+
|
817
|
+
@set_module_as('brainstate.graph')
def unflatten(
    graph_def: GraphDef[Any],
    state_mapping: NestedDict,
    /,
    *,
    index_ref: Optional[dict[Index, Any]] = None,
    index_ref_cache: Optional[dict[Index, Any]] = None,
) -> Any:
    """
    Rebuild a node from a graph definition and its state mapping.

    Parameters
    ----------
    graph_def : GraphDef
        A ``NodeDef`` or ``NodeRef`` produced by :func:`flatten`.
    state_mapping : NestedDict
        The state mapping; it is converted with ``to_dict()`` before use.
    index_ref : dict[Index, Any], optional
        Nodes already created, by index. A new empty dict is used if omitted.
        Sharing one dict across calls lets several ``(graph_def, state_mapping)``
        pairs that share the same index space be unflattened consistently.
    index_ref_cache : dict[Index, Any], optional
        Existing nodes that may be reused. A reused node is emptied with
        ``GraphNodeImpl.clear`` and then refilled, so existing graph nodes are
        mutated to take the new content/topology specified by the graphdef.

    Returns
    -------
    Node
        The reconstructed node.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> class MyNode(brainstate.graph.Node):
        ...     def __init__(self):
        ...         self.a = brainstate.nn.Linear(2, 3)
        ...         self.b = brainstate.nn.Linear(3, 4)
        ...
        >>> # Flatten a node
        >>> node = MyNode()
        >>> graphdef, statetree = brainstate.graph.flatten(node)
        >>>
        >>> # Unflatten back to node
        >>> reconstructed_node = brainstate.graph.unflatten(graphdef, statetree)
        >>> assert isinstance(reconstructed_node, MyNode)
        >>> assert isinstance(reconstructed_node.a, brainstate.nn.Linear)
        >>> assert isinstance(reconstructed_node.b, brainstate.nn.Linear)
    """
    if index_ref is None:
        index_ref = {}
    assert isinstance(graph_def, (NodeDef, NodeRef)), f"graph_def must be a NodeDef or NodeRef. But we got: {graph_def}"
    plain_state = state_mapping.to_dict()
    return _graph_unflatten(graph_def, plain_state, index_ref, index_ref_cache)
|
875
|
+
|
876
|
+
|
877
|
+
def _graph_pop(
|
878
|
+
node: Any,
|
879
|
+
id_to_index: dict[int, Index],
|
880
|
+
path_parts: PathParts,
|
881
|
+
flatted_state_dicts: tuple[FlattedDict[PathParts, StateLeaf], ...],
|
882
|
+
predicates: tuple[Predicate, ...],
|
883
|
+
) -> None:
|
884
|
+
if not _is_node(node):
|
885
|
+
raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
|
886
|
+
|
887
|
+
if id(node) in id_to_index:
|
888
|
+
return
|
889
|
+
|
890
|
+
id_to_index[id(node)] = len(id_to_index)
|
891
|
+
node_impl = _get_node_impl(node)
|
892
|
+
node_dict = node_impl.node_dict(node)
|
893
|
+
|
894
|
+
for name, value in node_dict.items():
|
895
|
+
if _is_node(value):
|
896
|
+
_graph_pop(
|
897
|
+
node=value,
|
898
|
+
id_to_index=id_to_index,
|
899
|
+
path_parts=(*path_parts, name),
|
900
|
+
flatted_state_dicts=flatted_state_dicts,
|
901
|
+
predicates=predicates,
|
902
|
+
)
|
903
|
+
continue
|
904
|
+
elif not _is_node_leaf(value):
|
905
|
+
continue
|
906
|
+
elif id(value) in id_to_index:
|
907
|
+
continue
|
908
|
+
|
909
|
+
node_path = (*path_parts, name)
|
910
|
+
node_impl = _get_node_impl(node)
|
911
|
+
for state_dicts, predicate in zip(flatted_state_dicts, predicates):
|
912
|
+
if predicate(node_path, value):
|
913
|
+
if isinstance(node_impl, PyTreeNodeImpl):
|
914
|
+
raise ValueError(f'Cannot pop key {name!r} from node of type {type(node).__name__}')
|
915
|
+
id_to_index[id(value)] = len(id_to_index)
|
916
|
+
node_impl.pop_key(node, name)
|
917
|
+
# if isinstance(value, State):
|
918
|
+
# value = value.to_state_ref()
|
919
|
+
state_dicts[node_path] = value # type: ignore[index] # mypy is wrong here?
|
920
|
+
break
|
921
|
+
else:
|
922
|
+
# NOTE: should we raise an error here?
|
923
|
+
pass
|
924
|
+
|
925
|
+
|
926
|
+
@set_module_as('brainstate.graph')
|
927
|
+
def pop_states(
|
928
|
+
node: Any, *filters: Any
|
929
|
+
) -> Union[NestedDict, Tuple[NestedDict, ...]]:
|
930
|
+
"""
|
931
|
+
Pop one or more :class:`State` types from the graph node.
|
932
|
+
|
933
|
+
Parameters
|
934
|
+
----------
|
935
|
+
node : Node
|
936
|
+
A graph node object.
|
937
|
+
*filters
|
938
|
+
One or more :class:`State` objects to filter by.
|
939
|
+
|
940
|
+
Returns
|
941
|
+
-------
|
942
|
+
NestedDict or tuple[NestedDict, ...]
|
943
|
+
The popped :class:`NestedDict` containing the :class:`State`
|
944
|
+
objects that were filtered for.
|
945
|
+
|
946
|
+
Examples
|
947
|
+
--------
|
948
|
+
.. code-block:: python
|
949
|
+
|
950
|
+
>>> import brainstate
|
951
|
+
>>> import jax.numpy as jnp
|
952
|
+
|
953
|
+
>>> class Model(brainstate.nn.Module):
|
954
|
+
... def __init__(self):
|
955
|
+
... super().__init__()
|
956
|
+
... self.a = brainstate.nn.Linear(2, 3)
|
957
|
+
... self.b = brainstate.nn.LIF([10, 2])
|
958
|
+
|
959
|
+
>>> model = Model()
|
960
|
+
>>> with brainstate.catch_new_states('new'):
|
961
|
+
... brainstate.nn.init_all_states(model)
|
962
|
+
|
963
|
+
>>> assert len(model.states()) == 2
|
964
|
+
>>> model_states = brainstate.graph.pop_states(model, 'new')
|
965
|
+
>>> model_states # doctest: +SKIP
|
966
|
+
NestedDict({
|
967
|
+
'b': {
|
968
|
+
'V': {
|
969
|
+
'st': ShortTermState(
|
970
|
+
value=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
971
|
+
0., 0., 0.], dtype=float32),
|
972
|
+
tag='new'
|
973
|
+
)
|
974
|
+
}
|
975
|
+
}
|
976
|
+
})
|
977
|
+
"""
|
978
|
+
if len(filters) == 0:
|
979
|
+
raise ValueError('Expected at least one filter')
|
980
|
+
|
981
|
+
id_to_index: dict[int, Index] = {}
|
982
|
+
path_parts: PathParts = ()
|
983
|
+
predicates = tuple(to_predicate(filter) for filter in filters)
|
984
|
+
flatted_state_dicts: tuple[FlattedDict[PathParts, StateLeaf], ...] = tuple({} for _ in predicates)
|
985
|
+
_graph_pop(
|
986
|
+
node=node,
|
987
|
+
id_to_index=id_to_index,
|
988
|
+
path_parts=path_parts,
|
989
|
+
flatted_state_dicts=flatted_state_dicts,
|
990
|
+
predicates=predicates,
|
991
|
+
)
|
992
|
+
states = tuple(NestedDict.from_flat(flat_state) for flat_state in flatted_state_dicts)
|
993
|
+
|
994
|
+
if len(states) == 1:
|
995
|
+
return states[0]
|
996
|
+
else:
|
997
|
+
return states
|
998
|
+
|
999
|
+
|
1000
|
+
def _split_state(
|
1001
|
+
state: GraphStateMapping,
|
1002
|
+
filters: tuple[Filter, ...],
|
1003
|
+
) -> tuple[GraphStateMapping, Unpack[tuple[GraphStateMapping, ...]]]:
|
1004
|
+
if not filters:
|
1005
|
+
return (state,)
|
1006
|
+
states = state.split(*filters)
|
1007
|
+
if isinstance(states, NestedDict):
|
1008
|
+
return (states,)
|
1009
|
+
assert len(states) > 0
|
1010
|
+
return states # type: ignore[return-value]
|
1011
|
+
|
1012
|
+
|
1013
|
+
@set_module_as('brainstate.graph')
|
1014
|
+
def treefy_split(
|
1015
|
+
node: A, *filters: Filter
|
1016
|
+
):
|
1017
|
+
"""
|
1018
|
+
Split a graph node into a :class:`GraphDef` and one or more :class:`NestedDict`s.
|
1019
|
+
|
1020
|
+
NestedDict is a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States.
|
1021
|
+
GraphDef contains all the static information needed to reconstruct a ``Module`` graph, it is
|
1022
|
+
analogous to JAX's ``PyTreeDef``. :func:`split` is used in conjunction with :func:`merge` to
|
1023
|
+
switch seamlessly between stateful and stateless representations of the graph.
|
1024
|
+
|
1025
|
+
Parameters
|
1026
|
+
----------
|
1027
|
+
node : A
|
1028
|
+
Graph node to split.
|
1029
|
+
*filters
|
1030
|
+
Optional filters to group the state into mutually exclusive substates.
|
1031
|
+
|
1032
|
+
Returns
|
1033
|
+
-------
|
1034
|
+
tuple
|
1035
|
+
``GraphDef`` and one or more ``States`` equal to the number of filters passed.
|
1036
|
+
If no filters are passed, a single ``NestedDict`` is returned.
|
1037
|
+
|
1038
|
+
Examples
|
1039
|
+
--------
|
1040
|
+
.. code-block:: python
|
1041
|
+
|
1042
|
+
>>> import brainstate
|
1043
|
+
>>> import jax, jax.numpy as jnp
|
1044
|
+
|
1045
|
+
>>> class Foo(brainstate.graph.Node):
|
1046
|
+
... def __init__(self):
|
1047
|
+
... self.a = brainstate.nn.BatchNorm1d([10, 2])
|
1048
|
+
... self.b = brainstate.nn.Linear(2, 3)
|
1049
|
+
...
|
1050
|
+
>>> node = Foo()
|
1051
|
+
>>> graphdef, params, others = brainstate.graph.treefy_split(
|
1052
|
+
... node, brainstate.ParamState, ...
|
1053
|
+
... )
|
1054
|
+
>>> # params contains ParamState variables
|
1055
|
+
>>> # others contains all other state variables
|
1056
|
+
"""
|
1057
|
+
graphdef, state_tree = flatten(node)
|
1058
|
+
states = tuple(_split_state(state_tree, filters))
|
1059
|
+
return graphdef, *states
|
1060
|
+
|
1061
|
+
|
1062
|
+
@set_module_as('brainstate.graph')
|
1063
|
+
def treefy_merge(graphdef: GraphDef[A], *state_mappings) -> A:
|
1064
|
+
"""
|
1065
|
+
The inverse of :func:`split`.
|
1066
|
+
|
1067
|
+
``merge`` takes a :class:`GraphDef` and one or more :class:`NestedDict`'s and creates
|
1068
|
+
a new node with the same structure as the original node.
|
1069
|
+
|
1070
|
+
Parameters
|
1071
|
+
----------
|
1072
|
+
graphdef : GraphDef[A]
|
1073
|
+
A :class:`GraphDef` object.
|
1074
|
+
*state_mappings
|
1075
|
+
Additional :class:`NestedDict` objects.
|
1076
|
+
|
1077
|
+
Returns
|
1078
|
+
-------
|
1079
|
+
A
|
1080
|
+
The merged :class:`Module`.
|
1081
|
+
|
1082
|
+
Examples
|
1083
|
+
--------
|
1084
|
+
.. code-block:: python
|
1085
|
+
|
1086
|
+
>>> import brainstate
|
1087
|
+
>>> import jax, jax.numpy as jnp
|
1088
|
+
|
1089
|
+
>>> class Foo(brainstate.graph.Node):
|
1090
|
+
... def __init__(self):
|
1091
|
+
... self.a = brainstate.nn.BatchNorm1d([10, 2])
|
1092
|
+
... self.b = brainstate.nn.Linear(2, 3)
|
1093
|
+
...
|
1094
|
+
>>> node = Foo()
|
1095
|
+
>>> graphdef, params, others = brainstate.graph.treefy_split(
|
1096
|
+
... node, brainstate.ParamState, ...
|
1097
|
+
... )
|
1098
|
+
>>> new_node = brainstate.graph.treefy_merge(graphdef, params, others)
|
1099
|
+
>>> assert isinstance(new_node, Foo)
|
1100
|
+
>>> assert isinstance(new_node.b, brainstate.nn.BatchNorm1d)
|
1101
|
+
>>> assert isinstance(new_node.a, brainstate.nn.Linear)
|
1102
|
+
"""
|
1103
|
+
state_mapping = GraphStateMapping.merge(*state_mappings)
|
1104
|
+
node = unflatten(graphdef, state_mapping)
|
1105
|
+
return node
|
1106
|
+
|
1107
|
+
|
1108
|
+
def _filters_to_predicates(filters: Tuple[Filter, ...]) -> Tuple[Predicate, ...]:
|
1109
|
+
for i, filter_ in enumerate(filters):
|
1110
|
+
if filter_ in (..., True) and i != len(filters) - 1:
|
1111
|
+
remaining_filters = filters[i + 1:]
|
1112
|
+
if not all(f in (..., True) for f in remaining_filters):
|
1113
|
+
raise ValueError('`...` or `True` can only be used as the last filters, '
|
1114
|
+
f'got {filter_} it at index {i}.')
|
1115
|
+
return tuple(map(to_predicate, filters))
|
1116
|
+
|
1117
|
+
|
1118
|
+
def _split_flatted(
|
1119
|
+
flatted: Iterable[tuple[PathParts, Any]],
|
1120
|
+
filters: tuple[Filter, ...],
|
1121
|
+
) -> tuple[list[tuple[PathParts, Any]], ...]:
|
1122
|
+
predicates = _filters_to_predicates(filters)
|
1123
|
+
|
1124
|
+
# we have n + 1 states, where n is the number of predicates
|
1125
|
+
# the last state is for values that don't match any predicate
|
1126
|
+
flat_states: tuple[list[tuple[PathParts, Any]], ...] = tuple([] for _ in predicates)
|
1127
|
+
|
1128
|
+
for path, value in flatted:
|
1129
|
+
for i, predicate in enumerate(predicates):
|
1130
|
+
if predicate(path, value):
|
1131
|
+
flat_states[i].append((path, value))
|
1132
|
+
break
|
1133
|
+
else:
|
1134
|
+
raise ValueError('Non-exhaustive filters, got a non-empty remainder: '
|
1135
|
+
f'{path} -> {value}.'
|
1136
|
+
'\nUse `...` to match all remaining elements.')
|
1137
|
+
|
1138
|
+
return flat_states
|
1139
|
+
|
1140
|
+
|
1141
|
+
@set_module_as('brainstate.graph')
|
1142
|
+
def nodes(
|
1143
|
+
node, *filters: Filter, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
|
1144
|
+
):
|
1145
|
+
"""
|
1146
|
+
Similar to :func:`split` but only returns the :class:`NestedDict`'s indicated by the filters.
|
1147
|
+
|
1148
|
+
Parameters
|
1149
|
+
----------
|
1150
|
+
node : Node
|
1151
|
+
The node to get nodes from.
|
1152
|
+
*filters
|
1153
|
+
Filters to apply to the nodes.
|
1154
|
+
allowed_hierarchy : tuple[int, int], optional
|
1155
|
+
The allowed hierarchy levels, by default (0, MAX_INT).
|
1156
|
+
|
1157
|
+
Returns
|
1158
|
+
-------
|
1159
|
+
FlattedDict or tuple[FlattedDict, ...]
|
1160
|
+
The filtered nodes.
|
1161
|
+
|
1162
|
+
"""
|
1163
|
+
num_filters = len(filters)
|
1164
|
+
if num_filters == 0:
|
1165
|
+
filters = (..., ...)
|
1166
|
+
else:
|
1167
|
+
filters = (*filters, ...)
|
1168
|
+
|
1169
|
+
nodes_iterable = iter_node(node, allowed_hierarchy=allowed_hierarchy)
|
1170
|
+
flat_nodes = _split_flatted(nodes_iterable, (*filters, ...))
|
1171
|
+
node_maps = tuple(FlattedDict(flat_node) for flat_node in flat_nodes)
|
1172
|
+
if num_filters < 2:
|
1173
|
+
return node_maps[0]
|
1174
|
+
return node_maps[:num_filters]
|
1175
|
+
|
1176
|
+
|
1177
|
+
def _states_generator(node, allowed_hierarchy) -> Iterable[Tuple[PathParts, State]]:
|
1178
|
+
for path, value in iter_leaf(node, allowed_hierarchy=allowed_hierarchy):
|
1179
|
+
if isinstance(value, State):
|
1180
|
+
yield path, value
|
1181
|
+
|
1182
|
+
|
1183
|
+
@set_module_as('brainstate.graph')
|
1184
|
+
def states(
|
1185
|
+
node, *filters: Filter, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
|
1186
|
+
) -> Union[FlattedDict, tuple[FlattedDict, ...]]:
|
1187
|
+
"""
|
1188
|
+
Similar to :func:`split` but only returns the :class:`NestedDict`'s indicated by the filters.
|
1189
|
+
|
1190
|
+
Parameters
|
1191
|
+
----------
|
1192
|
+
node : Node
|
1193
|
+
The node to get states from.
|
1194
|
+
*filters
|
1195
|
+
Filters to apply to the states.
|
1196
|
+
allowed_hierarchy : tuple[int, int], optional
|
1197
|
+
The allowed hierarchy levels, by default (0, MAX_INT).
|
1198
|
+
|
1199
|
+
Returns
|
1200
|
+
-------
|
1201
|
+
FlattedDict or tuple[FlattedDict, ...]
|
1202
|
+
The filtered states.
|
1203
|
+
|
1204
|
+
"""
|
1205
|
+
num_filters = len(filters)
|
1206
|
+
if num_filters == 0:
|
1207
|
+
filters = (..., ...)
|
1208
|
+
else:
|
1209
|
+
filters = (*filters, ...)
|
1210
|
+
|
1211
|
+
states_iterable = _states_generator(node, allowed_hierarchy=allowed_hierarchy)
|
1212
|
+
flat_states = _split_flatted(states_iterable, (*filters, ...))
|
1213
|
+
state_maps = tuple(FlattedDict(flat_state) for flat_state in flat_states)
|
1214
|
+
if num_filters < 2:
|
1215
|
+
return state_maps[0]
|
1216
|
+
return state_maps[:num_filters]
|
1217
|
+
|
1218
|
+
|
1219
|
+
@set_module_as('brainstate.graph')
|
1220
|
+
def treefy_states(
|
1221
|
+
node, *filters,
|
1222
|
+
):
|
1223
|
+
"""
|
1224
|
+
Similar to :func:`split` but only returns the :class:`NestedDict`'s indicated by the filters.
|
1225
|
+
|
1226
|
+
Parameters
|
1227
|
+
----------
|
1228
|
+
node : Node
|
1229
|
+
A graph node object.
|
1230
|
+
*filters
|
1231
|
+
One or more :class:`State` objects to filter by.
|
1232
|
+
|
1233
|
+
Returns
|
1234
|
+
-------
|
1235
|
+
NestedDict or tuple of NestedDict
|
1236
|
+
One or more :class:`NestedDict` mappings.
|
1237
|
+
|
1238
|
+
Examples
|
1239
|
+
--------
|
1240
|
+
.. code-block:: python
|
1241
|
+
|
1242
|
+
>>> import brainstate
|
1243
|
+
>>> class Model(brainstate.nn.Module):
|
1244
|
+
... def __init__(self):
|
1245
|
+
... super().__init__()
|
1246
|
+
... self.l1 = brainstate.nn.Linear(2, 3)
|
1247
|
+
... self.l2 = brainstate.nn.Linear(3, 4)
|
1248
|
+
... def __call__(self, x):
|
1249
|
+
... return self.l2(self.l1(x))
|
1250
|
+
|
1251
|
+
>>> model = Model()
|
1252
|
+
>>> # Get the learnable parameters
|
1253
|
+
>>> params = brainstate.graph.treefy_states(model, brainstate.ParamState)
|
1254
|
+
>>> # Get them separately
|
1255
|
+
>>> params, others = brainstate.graph.treefy_states(
|
1256
|
+
... model, brainstate.ParamState, brainstate.ShortTermState
|
1257
|
+
... )
|
1258
|
+
>>> # Get all states together
|
1259
|
+
>>> states = brainstate.graph.treefy_states(model)
|
1260
|
+
"""
|
1261
|
+
_, state_mapping = flatten(node)
|
1262
|
+
if len(filters) == 0:
|
1263
|
+
return state_mapping
|
1264
|
+
else:
|
1265
|
+
state_mappings = state_mapping.filter(*filters)
|
1266
|
+
if len(filters) == 1:
|
1267
|
+
return state_mappings[0]
|
1268
|
+
else:
|
1269
|
+
return state_mappings
|
1270
|
+
|
1271
|
+
|
1272
|
+
def _graph_update_dynamic(node: Any, state: Mapping) -> None:
|
1273
|
+
if not _is_node(node):
|
1274
|
+
raise RuntimeError(f'Unsupported type: {type(node)}')
|
1275
|
+
|
1276
|
+
node_impl = _get_node_impl(node)
|
1277
|
+
node_dict = node_impl.node_dict(node)
|
1278
|
+
for key, value in state.items():
|
1279
|
+
# case 1: new state is being added
|
1280
|
+
if key not in node_dict:
|
1281
|
+
if isinstance(node_impl, PyTreeNodeImpl):
|
1282
|
+
raise ValueError(f'Cannot set key {key!r} on immutable node of '
|
1283
|
+
f'type {type(node).__name__}')
|
1284
|
+
if isinstance(value, State):
|
1285
|
+
# TODO: here maybe error? we should check if the state already belongs to another node?
|
1286
|
+
value = value.to_state_ref() # Convert to state reference for proper state management
|
1287
|
+
node_impl.set_key(node, key, value)
|
1288
|
+
continue
|
1289
|
+
|
1290
|
+
# check values are of the same type
|
1291
|
+
current_value = node_dict[key]
|
1292
|
+
|
1293
|
+
# case 2: subgraph is being updated
|
1294
|
+
if _is_node(current_value):
|
1295
|
+
if _is_state_leaf(value):
|
1296
|
+
raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}')
|
1297
|
+
_graph_update_dynamic(current_value, value)
|
1298
|
+
elif isinstance(value, TreefyState):
|
1299
|
+
# case 3: state leaf is being updated
|
1300
|
+
if not isinstance(current_value, State):
|
1301
|
+
raise ValueError(f'Trying to update a non-State attribute {key!r} with a State: '
|
1302
|
+
f'{value!r}')
|
1303
|
+
current_value.update_from_ref(value)
|
1304
|
+
elif _is_state_leaf(value):
|
1305
|
+
# case 4: state field is being updated
|
1306
|
+
if isinstance(node_impl, PyTreeNodeImpl):
|
1307
|
+
raise ValueError(f'Cannot set key {key!r} on immutable node of '
|
1308
|
+
f'type {type(node).__name__}')
|
1309
|
+
node_impl.set_key(node, key, value)
|
1310
|
+
else:
|
1311
|
+
raise ValueError(f'Unsupported update type: {type(value)} for key {key!r}')
|
1312
|
+
|
1313
|
+
|
1314
|
+
def update_states(
|
1315
|
+
node: Any,
|
1316
|
+
state_dict: Union[NestedDict, FlattedDict],
|
1317
|
+
/,
|
1318
|
+
*state_dicts: Union[NestedDict, FlattedDict]
|
1319
|
+
) -> None:
|
1320
|
+
"""
|
1321
|
+
Update the given graph node with a new :class:`NestedMapping` in-place.
|
1322
|
+
|
1323
|
+
Parameters
|
1324
|
+
----------
|
1325
|
+
node : Node
|
1326
|
+
A graph node to update.
|
1327
|
+
state_dict : NestedDict | FlattedDict
|
1328
|
+
A :class:`NestedMapping` object.
|
1329
|
+
*state_dicts : NestedDict | FlattedDict
|
1330
|
+
Additional :class:`NestedMapping` objects.
|
1331
|
+
|
1332
|
+
"""
|
1333
|
+
if state_dicts:
|
1334
|
+
state_dict = NestedDict.merge(state_dict, *state_dicts)
|
1335
|
+
_graph_update_dynamic(node, state_dict.to_dict())
|
1336
|
+
|
1337
|
+
|
1338
|
+
@set_module_as('brainstate.graph')
|
1339
|
+
def graphdef(node: Any) -> GraphDef[Any]:
|
1340
|
+
"""
|
1341
|
+
Get the :class:`GraphDef` of the given graph node.
|
1342
|
+
|
1343
|
+
Parameters
|
1344
|
+
----------
|
1345
|
+
node : Any
|
1346
|
+
A graph node object.
|
1347
|
+
|
1348
|
+
Returns
|
1349
|
+
-------
|
1350
|
+
GraphDef[Any]
|
1351
|
+
The :class:`GraphDef` of the :class:`Module` object.
|
1352
|
+
|
1353
|
+
Examples
|
1354
|
+
--------
|
1355
|
+
.. code-block:: python
|
1356
|
+
|
1357
|
+
>>> import brainstate
|
1358
|
+
|
1359
|
+
>>> model = brainstate.nn.Linear(2, 3)
|
1360
|
+
>>> graphdef, _ = brainstate.graph.treefy_split(model)
|
1361
|
+
>>> assert graphdef == brainstate.graph.graphdef(model)
|
1362
|
+
|
1363
|
+
"""
|
1364
|
+
graphdef, _ = flatten(node)
|
1365
|
+
return graphdef
|
1366
|
+
|
1367
|
+
|
1368
|
+
@set_module_as('brainstate.graph')
|
1369
|
+
def clone(node: A) -> A:
|
1370
|
+
"""
|
1371
|
+
Create a deep copy of the given graph node.
|
1372
|
+
|
1373
|
+
Parameters
|
1374
|
+
----------
|
1375
|
+
node : Node
|
1376
|
+
A graph node object.
|
1377
|
+
|
1378
|
+
Returns
|
1379
|
+
-------
|
1380
|
+
Node
|
1381
|
+
A deep copy of the :class:`Module` object.
|
1382
|
+
|
1383
|
+
Examples
|
1384
|
+
--------
|
1385
|
+
.. code-block:: python
|
1386
|
+
|
1387
|
+
>>> import brainstate
|
1388
|
+
>>> model = brainstate.nn.Linear(2, 3)
|
1389
|
+
>>> cloned_model = brainstate.graph.clone(model)
|
1390
|
+
>>> model.weight.value['bias'] += 1
|
1391
|
+
>>> assert (model.weight.value['bias'] != cloned_model.weight.value['bias']).all()
|
1392
|
+
|
1393
|
+
"""
|
1394
|
+
graphdef, state = treefy_split(node)
|
1395
|
+
return treefy_merge(graphdef, state)
|
1396
|
+
|
1397
|
+
|
1398
|
+
@set_module_as('brainstate.graph')
|
1399
|
+
def iter_leaf(
|
1400
|
+
node: Any, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
|
1401
|
+
) -> Iterator[tuple[PathParts, Any]]:
|
1402
|
+
"""
|
1403
|
+
Iterates over all nested leaves in the given graph node, including the current node.
|
1404
|
+
|
1405
|
+
``iter_graph`` creates a generator that yields path and value pairs, where
|
1406
|
+
the path is a tuple of strings or integers representing the path to the value from the
|
1407
|
+
root. Repeated nodes are visited only once. Leaves include static values.
|
1408
|
+
|
1409
|
+
Parameters
|
1410
|
+
----------
|
1411
|
+
node : Any
|
1412
|
+
The node to iterate over.
|
1413
|
+
allowed_hierarchy : tuple[int, int], optional
|
1414
|
+
The allowed hierarchy levels, by default (0, MAX_INT).
|
1415
|
+
|
1416
|
+
Yields
|
1417
|
+
------
|
1418
|
+
Iterator[tuple[PathParts, Any]]
|
1419
|
+
Path and value pairs.
|
1420
|
+
|
1421
|
+
Examples
|
1422
|
+
--------
|
1423
|
+
.. code-block:: python
|
1424
|
+
|
1425
|
+
>>> import brainstate
|
1426
|
+
>>> import jax.numpy as jnp
|
1427
|
+
|
1428
|
+
>>> class Linear(brainstate.nn.Module):
|
1429
|
+
... def __init__(self, din, dout):
|
1430
|
+
... super().__init__()
|
1431
|
+
... self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
|
1432
|
+
... self.bias = brainstate.ParamState(brainstate.random.randn(dout))
|
1433
|
+
... self.a = 1
|
1434
|
+
...
|
1435
|
+
>>> module = Linear(3, 4)
|
1436
|
+
...
|
1437
|
+
>>> for path, value in brainstate.graph.iter_leaf([module, module]):
|
1438
|
+
... print(path, type(value).__name__)
|
1439
|
+
...
|
1440
|
+
(0, 'a') int
|
1441
|
+
(0, 'bias') ParamState
|
1442
|
+
(0, 'weight') ParamState
|
1443
|
+
|
1444
|
+
"""
|
1445
|
+
|
1446
|
+
def _iter_graph_leaf(
|
1447
|
+
node_: Any,
|
1448
|
+
visited_: set[int],
|
1449
|
+
path_parts_: PathParts,
|
1450
|
+
level_: int,
|
1451
|
+
) -> Iterator[tuple[PathParts, Any]]:
|
1452
|
+
if level_ > allowed_hierarchy[1]:
|
1453
|
+
return
|
1454
|
+
|
1455
|
+
if _is_node(node_):
|
1456
|
+
if id(node_) in visited_:
|
1457
|
+
return
|
1458
|
+
visited_.add(id(node_))
|
1459
|
+
node_dict = _get_node_impl(node_).node_dict(node_)
|
1460
|
+
for key, value in node_dict.items():
|
1461
|
+
yield from _iter_graph_leaf(
|
1462
|
+
value,
|
1463
|
+
visited_,
|
1464
|
+
(*path_parts_, key),
|
1465
|
+
level_ + 1 if _is_graph_node(value) else level_
|
1466
|
+
)
|
1467
|
+
else:
|
1468
|
+
if level_ >= allowed_hierarchy[0]:
|
1469
|
+
yield path_parts_, node_
|
1470
|
+
|
1471
|
+
visited: set[int] = set()
|
1472
|
+
path_parts: PathParts = ()
|
1473
|
+
level: int = 0
|
1474
|
+
yield from _iter_graph_leaf(node, visited, path_parts, level)
|
1475
|
+
|
1476
|
+
|
1477
|
+
@set_module_as('brainstate.graph')
|
1478
|
+
def iter_node(
|
1479
|
+
node: Any, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
|
1480
|
+
) -> Iterator[Tuple[PathParts, Any]]:
|
1481
|
+
"""
|
1482
|
+
Iterates over all nested nodes of the given graph node, including the current node.
|
1483
|
+
|
1484
|
+
``iter_graph`` creates a generator that yields path and value pairs, where
|
1485
|
+
the path is a tuple of strings or integers representing the path to the value from the
|
1486
|
+
root. Repeated nodes are visited only once. Leaves include static values.
|
1487
|
+
|
1488
|
+
Parameters
|
1489
|
+
----------
|
1490
|
+
node : Any
|
1491
|
+
The node to iterate over.
|
1492
|
+
allowed_hierarchy : tuple[int, int], optional
|
1493
|
+
The allowed hierarchy levels, by default (0, MAX_INT).
|
1494
|
+
|
1495
|
+
Yields
|
1496
|
+
------
|
1497
|
+
Iterator[tuple[PathParts, Any]]
|
1498
|
+
Path and node pairs.
|
1499
|
+
|
1500
|
+
Examples
|
1501
|
+
--------
|
1502
|
+
.. code-block:: python
|
1503
|
+
|
1504
|
+
>>> import brainstate
|
1505
|
+
>>> import jax.numpy as jnp
|
1506
|
+
|
1507
|
+
>>> class Model(brainstate.nn.Module):
|
1508
|
+
... def __init__(self):
|
1509
|
+
... super().__init__()
|
1510
|
+
... self.a = brainstate.nn.Linear(1, 2)
|
1511
|
+
... self.b = brainstate.nn.Linear(2, 3)
|
1512
|
+
... self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
|
1513
|
+
... self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
|
1514
|
+
... self.b.a = brainstate.nn.LIF(2)
|
1515
|
+
...
|
1516
|
+
>>> model = Model()
|
1517
|
+
...
|
1518
|
+
>>> for path, node in brainstate.graph.iter_node([model, model]):
|
1519
|
+
... print(path, node.__class__.__name__)
|
1520
|
+
...
|
1521
|
+
(0, 'a') Linear
|
1522
|
+
(0, 'b', 'a') LIF
|
1523
|
+
(0, 'b') Linear
|
1524
|
+
(0, 'c', 0) Linear
|
1525
|
+
(0, 'c', 1) Linear
|
1526
|
+
(0, 'd', 'x') Linear
|
1527
|
+
(0, 'd', 'y') Linear
|
1528
|
+
(0,) Model
|
1529
|
+
|
1530
|
+
"""
|
1531
|
+
|
1532
|
+
def _iter_graph_node(
|
1533
|
+
node_: Any,
|
1534
|
+
visited_: set[int],
|
1535
|
+
path_parts_: PathParts,
|
1536
|
+
level_: int,
|
1537
|
+
) -> Iterator[tuple[PathParts, Any]]:
|
1538
|
+
if level_ > allowed_hierarchy[1]:
|
1539
|
+
return
|
1540
|
+
|
1541
|
+
if _is_node(node_):
|
1542
|
+
if id(node_) in visited_:
|
1543
|
+
return
|
1544
|
+
|
1545
|
+
visited_.add(id(node_))
|
1546
|
+
node_dict = _get_node_impl(node_).node_dict(node_)
|
1547
|
+
for key, value in node_dict.items():
|
1548
|
+
yield from _iter_graph_node(value, visited_, (*path_parts_, key),
|
1549
|
+
level_ + 1 if _is_graph_node(value) else level_)
|
1550
|
+
|
1551
|
+
if _is_graph_node(node_) and level_ >= allowed_hierarchy[0]:
|
1552
|
+
yield path_parts_, node_
|
1553
|
+
|
1554
|
+
visited: set[int] = set()
|
1555
|
+
path_parts: PathParts = ()
|
1556
|
+
level: int = 0
|
1557
|
+
yield from _iter_graph_node(node, visited, path_parts, level)
|
1558
|
+
|
1559
|
+
|
1560
|
+
# --------------------------------------------------------
|
1561
|
+
# Graph operations: end
|
1562
|
+
# --------------------------------------------------------
|
1563
|
+
|
1564
|
+
|
1565
|
+
@dataclasses.dataclass(frozen=True)
|
1566
|
+
class Static(Generic[A]):
|
1567
|
+
"""
|
1568
|
+
An empty pytree node that treats its inner value as static.
|
1569
|
+
|
1570
|
+
``value`` must define ``__eq__`` and ``__hash__``.
|
1571
|
+
|
1572
|
+
Attributes
|
1573
|
+
----------
|
1574
|
+
value : A
|
1575
|
+
The static value to wrap.
|
1576
|
+
|
1577
|
+
"""
|
1578
|
+
|
1579
|
+
value: A
|
1580
|
+
|
1581
|
+
|
1582
|
+
jax.tree_util.register_static(Static)
|
1583
|
+
|
1584
|
+
|
1585
|
+
# ---------------------------------------------------------
|
1586
|
+
# Pytree
|
1587
|
+
# ---------------------------------------------------------
|
1588
|
+
|
1589
|
+
class PytreeType:
|
1590
|
+
...
|
1591
|
+
|
1592
|
+
|
1593
|
+
def _key_path_to_key(key: Any) -> Key:
|
1594
|
+
if isinstance(key, jax.tree_util.SequenceKey):
|
1595
|
+
return key.idx
|
1596
|
+
elif isinstance(
|
1597
|
+
key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey)
|
1598
|
+
):
|
1599
|
+
if not isinstance(key.key, Key):
|
1600
|
+
raise ValueError(
|
1601
|
+
f'Invalid key: {key.key}. May be due to its type not being hashable or comparable.'
|
1602
|
+
)
|
1603
|
+
return key.key
|
1604
|
+
elif isinstance(key, jax.tree_util.GetAttrKey):
|
1605
|
+
return key.name
|
1606
|
+
else:
|
1607
|
+
return str(key)
|
1608
|
+
|
1609
|
+
|
1610
|
+
def _flatten_pytree(pytree: Any) -> Tuple[Tuple[Tuple, ...], jax.tree_util.PyTreeDef]:
|
1611
|
+
leaves, treedef = jax.tree_util.tree_flatten_with_path(pytree, is_leaf=lambda x: x is not pytree)
|
1612
|
+
nodes = tuple((_key_path_to_key(path[0]), value) for path, value in leaves)
|
1613
|
+
return nodes, treedef
|
1614
|
+
|
1615
|
+
|
1616
|
+
def _unflatten_pytree(
|
1617
|
+
nodes: tuple[tuple, ...],
|
1618
|
+
treedef: jax.tree_util.PyTreeDef
|
1619
|
+
) -> Any:
|
1620
|
+
pytree = treedef.unflatten(value for _, value in nodes)
|
1621
|
+
return pytree
|
1622
|
+
|
1623
|
+
|
1624
|
+
PYTREE_NODE_IMPL = PyTreeNodeImpl(type=PytreeType, flatten=_flatten_pytree, unflatten=_unflatten_pytree)
|