brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/nn/_dynamics.py
CHANGED
@@ -1,1267 +1,870 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
|
19
|
-
"""
|
20
|
-
All the basic dynamics class for the ``brainstate``.
|
21
|
-
|
22
|
-
For handling dynamical systems:
|
23
|
-
|
24
|
-
- ``DynamicsGroup``: The class for a group of modules, which update ``Projection`` first,
|
25
|
-
then ``Dynamics``, finally others.
|
26
|
-
- ``Projection``: The class for the synaptic projection.
|
27
|
-
- ``Dynamics``: The class for the dynamical system.
|
28
|
-
|
29
|
-
For handling the delays:
|
30
|
-
|
31
|
-
- ``Delay``: The class for all delays.
|
32
|
-
- ``DelayAccess``: The class for the delay access.
|
33
|
-
|
34
|
-
"""
|
35
|
-
|
36
|
-
from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar,
|
37
|
-
|
38
|
-
import jax
|
39
|
-
import numpy as np
|
40
|
-
|
41
|
-
from brainstate import environ
|
42
|
-
from brainstate._state import State
|
43
|
-
from brainstate.graph import Node
|
44
|
-
from brainstate.
|
45
|
-
from
|
46
|
-
from .
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
'Dynamics',
|
52
|
-
|
53
|
-
'
|
54
|
-
'
|
55
|
-
'
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
- ``
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
#
|
131
|
-
|
132
|
-
|
133
|
-
#
|
134
|
-
|
135
|
-
|
136
|
-
#
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
"""
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
"""
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
self
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
if
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
"""
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
The
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
677
|
-
"""
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
823
|
-
|
824
|
-
|
825
|
-
|
826
|
-
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
|
846
|
-
|
847
|
-
|
848
|
-
|
849
|
-
|
850
|
-
|
851
|
-
|
852
|
-
|
853
|
-
|
854
|
-
|
855
|
-
|
856
|
-
|
857
|
-
|
858
|
-
|
859
|
-
|
860
|
-
|
861
|
-
|
862
|
-
|
863
|
-
|
864
|
-
|
865
|
-
|
866
|
-
|
867
|
-
|
868
|
-
|
869
|
-
|
870
|
-
|
871
|
-
-----
|
872
|
-
When called, this class retrieves the current value of the referenced item.
|
873
|
-
Use the `.delay` property to access delayed versions of the state.
|
874
|
-
|
875
|
-
"""
|
876
|
-
|
877
|
-
def __init__(self, module: Dynamics, item: str):
|
878
|
-
"""
|
879
|
-
Initialize a Prefetch object.
|
880
|
-
|
881
|
-
Parameters
|
882
|
-
----------
|
883
|
-
module : Module
|
884
|
-
The module that contains or will contain the referenced item.
|
885
|
-
item : str
|
886
|
-
The attribute name of the state or variable to prefetch.
|
887
|
-
"""
|
888
|
-
super().__init__()
|
889
|
-
self.module = module
|
890
|
-
self.item = item
|
891
|
-
|
892
|
-
@property
|
893
|
-
def delay(self):
|
894
|
-
"""
|
895
|
-
Access delayed versions of the prefetched item.
|
896
|
-
|
897
|
-
Returns
|
898
|
-
-------
|
899
|
-
PrefetchDelay
|
900
|
-
An object that provides access to delayed versions of the prefetched item.
|
901
|
-
"""
|
902
|
-
return PrefetchDelay(self.module, self.item)
|
903
|
-
# return PrefetchDelayAt(self.module, self.item, time)
|
904
|
-
|
905
|
-
def __call__(self, *args, **kwargs):
|
906
|
-
"""
|
907
|
-
Get the current value of the prefetched item.
|
908
|
-
|
909
|
-
Returns
|
910
|
-
-------
|
911
|
-
Any
|
912
|
-
The current value of the referenced item. If the item is a State object,
|
913
|
-
returns its value attribute, otherwise returns the item itself.
|
914
|
-
"""
|
915
|
-
item = _get_prefetch_item(self)
|
916
|
-
return item.value if isinstance(item, State) else item
|
917
|
-
|
918
|
-
def get_item_value(self):
|
919
|
-
"""
|
920
|
-
Get the current value of the prefetched item.
|
921
|
-
|
922
|
-
Similar to __call__, but explicitly named for clarity.
|
923
|
-
|
924
|
-
Returns
|
925
|
-
-------
|
926
|
-
Any
|
927
|
-
The current value of the referenced item. If the item is a State object,
|
928
|
-
returns its value attribute, otherwise returns the item itself.
|
929
|
-
"""
|
930
|
-
item = _get_prefetch_item(self)
|
931
|
-
return item.value if isinstance(item, State) else item
|
932
|
-
|
933
|
-
def get_item(self):
|
934
|
-
"""
|
935
|
-
Get the referenced item object itself, not its value.
|
936
|
-
|
937
|
-
Returns
|
938
|
-
-------
|
939
|
-
Any
|
940
|
-
The actual referenced item from the module, which could be a State
|
941
|
-
object or any other attribute.
|
942
|
-
"""
|
943
|
-
return _get_prefetch_item(self)
|
944
|
-
|
945
|
-
|
946
|
-
class PrefetchDelay(Node):
|
947
|
-
"""
|
948
|
-
Provides access to delayed versions of a prefetched state or variable.
|
949
|
-
|
950
|
-
This class acts as an intermediary for accessing delayed values of module variables.
|
951
|
-
It doesn't retrieve values directly but provides methods to specify the delay time
|
952
|
-
via the `at()` method.
|
953
|
-
|
954
|
-
Parameters
|
955
|
-
----------
|
956
|
-
module : Dynamics
|
957
|
-
The dynamics module that contains the referenced state or variable.
|
958
|
-
item : str
|
959
|
-
The name of the state or variable to access with delay.
|
960
|
-
|
961
|
-
Examples
|
962
|
-
--------
|
963
|
-
>>> import brainstate
|
964
|
-
>>> import brainunit as u
|
965
|
-
>>> neuron = brainstate.nn.LIF(10)
|
966
|
-
>>> # Access voltage delayed by 5ms
|
967
|
-
>>> delayed_v = neuron.prefetch('V').delay.at(5.0 * u.ms)
|
968
|
-
>>> delayed_value = delayed_v() # Get the delayed value
|
969
|
-
"""
|
970
|
-
|
971
|
-
def __init__(self, module: Dynamics, item: str):
|
972
|
-
self.module = module
|
973
|
-
self.item = item
|
974
|
-
|
975
|
-
def at(self, *delay_time):
|
976
|
-
"""
|
977
|
-
Specifies the delay time for accessing the variable.
|
978
|
-
|
979
|
-
Parameters
|
980
|
-
----------
|
981
|
-
time : ArrayLike
|
982
|
-
The amount of time to delay the variable access, typically in time units
|
983
|
-
(e.g., milliseconds).
|
984
|
-
|
985
|
-
Returns
|
986
|
-
-------
|
987
|
-
PrefetchDelayAt
|
988
|
-
An object that provides access to the variable at the specified delay time.
|
989
|
-
"""
|
990
|
-
return PrefetchDelayAt(self.module, self.item, delay_time)
|
991
|
-
|
992
|
-
|
993
|
-
class PrefetchDelayAt(Node):
|
994
|
-
"""
|
995
|
-
Provides access to a specific delayed state or variable value at the specific time.
|
996
|
-
|
997
|
-
This class represents the final step in the prefetch delay chain, providing
|
998
|
-
actual access to state values at a specific delay time. It converts the
|
999
|
-
specified time delay into steps and registers the delay with the appropriate
|
1000
|
-
StateWithDelay handler.
|
1001
|
-
|
1002
|
-
Parameters
|
1003
|
-
----------
|
1004
|
-
module : Dynamics
|
1005
|
-
The dynamics module that contains the referenced state or variable.
|
1006
|
-
item : str
|
1007
|
-
The name of the state or variable to access with delay.
|
1008
|
-
time : ArrayLike
|
1009
|
-
The amount of time to delay access by, typically in time units (e.g., milliseconds).
|
1010
|
-
|
1011
|
-
Examples
|
1012
|
-
--------
|
1013
|
-
>>> import brainstate
|
1014
|
-
>>> import brainunit as u
|
1015
|
-
>>> neuron = brainstate.nn.LIF(10)
|
1016
|
-
>>> # Create a reference to voltage delayed by 5ms
|
1017
|
-
>>> delayed_v = PrefetchDelayAt(neuron, 'V', 5.0 * u.ms)
|
1018
|
-
>>> # Get the delayed value
|
1019
|
-
>>> v_value = delayed_v()
|
1020
|
-
"""
|
1021
|
-
|
1022
|
-
def __init__(
|
1023
|
-
self,
|
1024
|
-
module: Dynamics,
|
1025
|
-
item: str,
|
1026
|
-
delay_time: Tuple,
|
1027
|
-
init: Callable = None
|
1028
|
-
):
|
1029
|
-
"""
|
1030
|
-
Initialize a PrefetchDelayAt object.
|
1031
|
-
|
1032
|
-
Parameters
|
1033
|
-
----------
|
1034
|
-
module : Dynamics
|
1035
|
-
The dynamics module that contains the referenced state or variable.
|
1036
|
-
item : str
|
1037
|
-
The name of the state or variable to access with delay.
|
1038
|
-
delay_time : Tuple
|
1039
|
-
The amount of time to delay access by, typically in time units (e.g., milliseconds).
|
1040
|
-
"""
|
1041
|
-
super().__init__()
|
1042
|
-
assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
|
1043
|
-
self.module = module
|
1044
|
-
self.item = item
|
1045
|
-
if not isinstance(delay_time, (tuple, list)):
|
1046
|
-
delay_time = (delay_time,)
|
1047
|
-
self.delay_time = delay_time
|
1048
|
-
if len(delay_time) > 0:
|
1049
|
-
key = _get_prefetch_delay_key(item)
|
1050
|
-
if not module._has_after_update(key):
|
1051
|
-
module._add_after_update(
|
1052
|
-
key,
|
1053
|
-
not_receive_update_output(
|
1054
|
-
StateWithDelay(module, item, init=init)
|
1055
|
-
)
|
1056
|
-
)
|
1057
|
-
self.state_delay: StateWithDelay = module._get_after_update(key)
|
1058
|
-
self.delay_info = self.state_delay.register_delay(*delay_time)
|
1059
|
-
|
1060
|
-
def __call__(self, *args, **kwargs):
|
1061
|
-
"""
|
1062
|
-
Retrieve the value of the state at the specified delay time.
|
1063
|
-
|
1064
|
-
Returns
|
1065
|
-
-------
|
1066
|
-
Any
|
1067
|
-
The value of the state or variable at the specified delay time.
|
1068
|
-
"""
|
1069
|
-
if len(self.delay_time) == 0:
|
1070
|
-
return _get_prefetch_item(self).value
|
1071
|
-
else:
|
1072
|
-
return self.state_delay.retrieve_at_step(*self.delay_info)
|
1073
|
-
|
1074
|
-
|
1075
|
-
class OutputDelayAt(Node):
|
1076
|
-
"""
|
1077
|
-
Provides access to a specific delayed state or variable value at the specific time.
|
1078
|
-
|
1079
|
-
This class represents the final step in the prefetch delay chain, providing
|
1080
|
-
actual access to state values at a specific delay time. It converts the
|
1081
|
-
specified time delay into steps and registers the delay with the appropriate
|
1082
|
-
StateWithDelay handler.
|
1083
|
-
|
1084
|
-
Parameters
|
1085
|
-
----------
|
1086
|
-
module : Dynamics
|
1087
|
-
The dynamics module that contains the referenced state or variable.
|
1088
|
-
time : ArrayLike
|
1089
|
-
The amount of time to delay access by, typically in time units (e.g., milliseconds).
|
1090
|
-
|
1091
|
-
Examples
|
1092
|
-
--------
|
1093
|
-
>>> import brainstate
|
1094
|
-
>>> import brainunit as u
|
1095
|
-
>>> neuron = brainstate.nn.LIF(10)
|
1096
|
-
>>> # Create a reference to voltage delayed by 5ms
|
1097
|
-
>>> delayed_spike = OutputDelayAt(neuron, 5.0 * u.ms)
|
1098
|
-
>>> # Get the delayed value
|
1099
|
-
>>> v_value = delayed_spike()
|
1100
|
-
"""
|
1101
|
-
|
1102
|
-
def __init__(
|
1103
|
-
self,
|
1104
|
-
module: Dynamics,
|
1105
|
-
delay_time: Tuple,
|
1106
|
-
):
|
1107
|
-
super().__init__()
|
1108
|
-
assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
|
1109
|
-
self.module = module
|
1110
|
-
key = _get_output_delay_key()
|
1111
|
-
if not module._has_after_update(key):
|
1112
|
-
delay = Delay(jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()), take_aware_unit=True)
|
1113
|
-
module._add_after_update(key, receive_update_output(delay))
|
1114
|
-
self.out_delay: Delay = module._get_after_update(key)
|
1115
|
-
self.delay_info = self.out_delay.register_delay(*delay_time)
|
1116
|
-
|
1117
|
-
def __call__(self, *args, **kwargs):
|
1118
|
-
return self.out_delay.retrieve_at_step(*self.delay_info)
|
1119
|
-
|
1120
|
-
|
1121
|
-
def _get_prefetch_delay_key(item) -> str:
|
1122
|
-
return f'{item}-prefetch-delay'
|
1123
|
-
|
1124
|
-
|
1125
|
-
def _get_output_delay_key() -> str:
|
1126
|
-
return f'output-delay'
|
1127
|
-
|
1128
|
-
|
1129
|
-
def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
|
1130
|
-
item = getattr(target.module, target.item, None)
|
1131
|
-
if item is None:
|
1132
|
-
raise AttributeError(f'The target {target.module} should have an `{target.item}` attribute.')
|
1133
|
-
return item
|
1134
|
-
|
1135
|
-
|
1136
|
-
def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDelayAt]) -> Delay:
|
1137
|
-
assert isinstance(target.module, Dynamics), (
|
1138
|
-
f'The target module should be an instance '
|
1139
|
-
f'of Dynamics. But got {target.module}.'
|
1140
|
-
)
|
1141
|
-
delay = target.module._get_after_update(_get_prefetch_delay_key(target.item))
|
1142
|
-
if not isinstance(delay, StateWithDelay):
|
1143
|
-
raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
|
1144
|
-
f'its delay. But got {delay}.')
|
1145
|
-
return delay
|
1146
|
-
|
1147
|
-
|
1148
|
-
def maybe_init_prefetch(target, *args, **kwargs):
|
1149
|
-
"""
|
1150
|
-
Initialize a prefetch target if needed, based on its type.
|
1151
|
-
|
1152
|
-
This function ensures that prefetch references are properly initialized
|
1153
|
-
and ready to use. It handles different types of prefetch objects by
|
1154
|
-
performing the appropriate initialization action:
|
1155
|
-
- For :py:class:`Prefetch` objects: retrieves the referenced item
|
1156
|
-
- For :py:class:`PrefetchDelay` objects: retrieves the delay handler
|
1157
|
-
- For :py:class:`PrefetchDelayAt` objects: registers the specified delay
|
1158
|
-
|
1159
|
-
Parameters
|
1160
|
-
----------
|
1161
|
-
target : Union[Prefetch, PrefetchDelay, PrefetchDelayAt]
|
1162
|
-
The prefetch target to initialize.
|
1163
|
-
*args : Any
|
1164
|
-
Additional positional arguments (unused).
|
1165
|
-
**kwargs : Any
|
1166
|
-
Additional keyword arguments (unused).
|
1167
|
-
|
1168
|
-
Returns
|
1169
|
-
-------
|
1170
|
-
None
|
1171
|
-
This function performs initialization side effects only.
|
1172
|
-
|
1173
|
-
Notes
|
1174
|
-
-----
|
1175
|
-
This function is typically called internally when prefetched references
|
1176
|
-
are used to ensure they are properly set up before access.
|
1177
|
-
"""
|
1178
|
-
if isinstance(target, Prefetch):
|
1179
|
-
_get_prefetch_item(target)
|
1180
|
-
|
1181
|
-
elif isinstance(target, PrefetchDelay):
|
1182
|
-
_get_prefetch_item_delay(target)
|
1183
|
-
|
1184
|
-
elif isinstance(target, PrefetchDelayAt):
|
1185
|
-
pass
|
1186
|
-
# delay = _get_prefetch_item_delay(target)
|
1187
|
-
# delay.register_delay(*target.delay_time)
|
1188
|
-
|
1189
|
-
|
1190
|
-
DynamicsGroup = Module
|
1191
|
-
|
1192
|
-
|
1193
|
-
def receive_update_output(cls: object):
|
1194
|
-
"""
|
1195
|
-
The decorator to mark the object (as the after updates) to receive the output of the update function.
|
1196
|
-
|
1197
|
-
That is, the `aft_update` will receive the return of the update function::
|
1198
|
-
|
1199
|
-
ret = model.update(*args, **kwargs)
|
1200
|
-
for fun in model.aft_updates:
|
1201
|
-
fun(ret)
|
1202
|
-
|
1203
|
-
"""
|
1204
|
-
# assert isinstance(cls, Module), 'The input class should be instance of Module.'
|
1205
|
-
if hasattr(cls, '_not_receive_update_output'):
|
1206
|
-
delattr(cls, '_not_receive_update_output')
|
1207
|
-
return cls
|
1208
|
-
|
1209
|
-
|
1210
|
-
def not_receive_update_output(cls: T) -> T:
|
1211
|
-
"""
|
1212
|
-
The decorator to mark the object (as the after updates) to not receive the output of the update function.
|
1213
|
-
|
1214
|
-
That is, the `aft_update` will not receive the return of the update function::
|
1215
|
-
|
1216
|
-
ret = model.update(*args, **kwargs)
|
1217
|
-
for fun in model.aft_updates:
|
1218
|
-
fun()
|
1219
|
-
|
1220
|
-
"""
|
1221
|
-
# assert isinstance(cls, Module), 'The input class should be instance of Module.'
|
1222
|
-
cls._not_receive_update_output = True
|
1223
|
-
return cls
|
1224
|
-
|
1225
|
-
|
1226
|
-
def receive_update_input(cls: object):
|
1227
|
-
"""
|
1228
|
-
The decorator to mark the object (as the before updates) to receive the input of the update function.
|
1229
|
-
|
1230
|
-
That is, the `bef_update` will receive the input of the update function::
|
1231
|
-
|
1232
|
-
|
1233
|
-
for fun in model.bef_updates:
|
1234
|
-
fun(*args, **kwargs)
|
1235
|
-
model.update(*args, **kwargs)
|
1236
|
-
|
1237
|
-
"""
|
1238
|
-
# assert isinstance(cls, Module), 'The input class should be instance of Module.'
|
1239
|
-
cls._receive_update_input = True
|
1240
|
-
return cls
|
1241
|
-
|
1242
|
-
|
1243
|
-
def not_receive_update_input(cls: object):
|
1244
|
-
"""
|
1245
|
-
The decorator to mark the object (as the before updates) to not receive the input of the update function.
|
1246
|
-
|
1247
|
-
That is, the `bef_update` will not receive the input of the update function::
|
1248
|
-
|
1249
|
-
for fun in model.bef_updates:
|
1250
|
-
fun()
|
1251
|
-
model.update()
|
1252
|
-
|
1253
|
-
"""
|
1254
|
-
# assert isinstance(cls, Module), 'The input class should be instance of Module.'
|
1255
|
-
if hasattr(cls, '_receive_update_input'):
|
1256
|
-
delattr(cls, '_receive_update_input')
|
1257
|
-
return cls
|
1258
|
-
|
1259
|
-
|
1260
|
-
def _input_label_start(label: str):
|
1261
|
-
# unify the input label repr.
|
1262
|
-
return f'{label} // '
|
1263
|
-
|
1264
|
-
|
1265
|
-
def _input_label_repr(name: str, label: Optional[str] = None):
|
1266
|
-
# unify the input label repr.
|
1267
|
-
return name if label is None else (_input_label_start(label) + str(name))
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
|
19
|
+
"""
|
20
|
+
All the basic dynamics class for the ``brainstate``.
|
21
|
+
|
22
|
+
For handling dynamical systems:
|
23
|
+
|
24
|
+
- ``DynamicsGroup``: The class for a group of modules, which update ``Projection`` first,
|
25
|
+
then ``Dynamics``, finally others.
|
26
|
+
- ``Projection``: The class for the synaptic projection.
|
27
|
+
- ``Dynamics``: The class for the dynamical system.
|
28
|
+
|
29
|
+
For handling the delays:
|
30
|
+
|
31
|
+
- ``Delay``: The class for all delays.
|
32
|
+
- ``DelayAccess``: The class for the delay access.
|
33
|
+
|
34
|
+
"""
|
35
|
+
|
36
|
+
from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, Tuple
|
37
|
+
|
38
|
+
import jax
|
39
|
+
import numpy as np
|
40
|
+
|
41
|
+
from brainstate import environ
|
42
|
+
from brainstate._state import State
|
43
|
+
from brainstate.graph import Node
|
44
|
+
from brainstate.typing import Size, ArrayLike
|
45
|
+
from ._delay import StateWithDelay, Delay
|
46
|
+
from ._module import Module
|
47
|
+
|
48
|
+
# Generic type variable used by decorators that return their argument unchanged.
T = TypeVar('T')

__all__ = [
    # Base class for dynamical modules.
    'Dynamics',

    # Decorators controlling which arguments before/after-update hooks receive.
    'receive_update_output',
    'not_receive_update_output',
    'receive_update_input',
    'not_receive_update_input',

    # Lazy (optionally delayed) references to module states and outputs.
    'Prefetch',
    'PrefetchDelay',
    'PrefetchDelayAt',
    'OutputDelayAt',
]
|
63
|
+
|
64
|
+
|
65
|
+
class Dynamics(Module):
    """
    Base class for implementing neural dynamics models in BrainState.

    Dynamics classes represent the core computational units in neural simulations,
    implementing the differential equations or update rules that govern neural activity.
    This base class does three things:

    - it records the population geometry;
    - it runs registered hooks around the ``update`` call;
    - it creates (optionally delayed) references to module states and to the
      module output.

    The Dynamics class serves several key purposes:

    1. Managing neuron population geometry and size information
    2. Supporting before/after update hooks for computational dependencies
    3. Providing access to delayed state variables through the prefetch mechanism
    4. Providing access to the delayed module output through :meth:`output_delay`

    Parameters
    ----------
    in_size : Size
        The geometry of the neuron population. Can be an integer (e.g., 10) for
        1D neuron arrays, or a non-empty tuple/list of integers (e.g., (10, 10))
        for multi-dimensional populations.
    name : Optional[str], default=None
        Optional name identifier for this dynamics module.

    Attributes
    ----------
    in_size : tuple
        The shape/geometry of the neuron population, always stored as a tuple.
    out_size : tuple
        The output shape; this base class sets it equal to ``in_size``.
    before_updates : Optional[Dict[Hashable, Callable]]
        Functions to call before the main update (``None`` until one is added).
    after_updates : Optional[Dict[Hashable, Callable]]
        Functions to call after the main update (``None`` until one is added).

    Raises
    ------
    ValueError
        If ``in_size`` is not an int or a non-empty tuple/list of ints.

    Notes
    -----
    In the BrainState execution sequence, Dynamics modules are updated after
    Projection modules and before other module types, reflecting the natural
    flow of information in neural systems.

    The geometry ``in_size`` describes how neurons are arranged:

    - ``(10,)`` denotes a line of neurons;
    - ``(10, 10)`` denotes a neuron group aligned in a 2D space;
    - ``(10, 15, 4)`` denotes a 3-dimensional neuron group.

    The flattened number of neurons is the product of the entries, e.g.
    ``(10, 15, 4)`` => 600.

    See Also
    --------
    Module : Parent class providing base module functionality
    Projection : Class for handling synaptic projections between neural populations
    DynamicsGroup : Container for organizing multiple dynamics modules
    """

    __module__ = 'brainstate.nn'

    graph_invisible_attrs = ()

    # before updates
    _before_updates: Optional[Dict[Hashable, Callable]]

    # after updates
    _after_updates: Optional[Dict[Hashable, Callable]]

    # NOTE(review): the two annotations below are declared but never assigned by
    # this class — TODO confirm whether anything still relies on them.
    # current inputs
    _current_inputs: Optional[Dict[str, ArrayLike | Callable]]

    # delta inputs
    _delta_inputs: Optional[Dict[str, ArrayLike | Callable]]

    def __init__(self, in_size: Size, name: Optional[str] = None):
        # initialize
        super().__init__(name=name)

        # Normalize the geometry of the neuron population to a tuple of ints.
        if isinstance(in_size, (list, tuple)):
            # Validate every entry, not just the first one, so that e.g.
            # ``(10, 'a')`` is rejected up front instead of failing later.
            if len(in_size) == 0 or not all(isinstance(s, (int, np.integer)) for s in in_size):
                raise ValueError(
                    f'"in_size" must be int, or a tuple/list of int. But we got {in_size!r}'
                )
            in_size = tuple(in_size)
        elif isinstance(in_size, (int, np.integer)):
            in_size = (in_size,)
        else:
            raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
        self.in_size = in_size

        # Hook dictionaries are created lazily by ``add_before_update`` /
        # ``add_after_update``.
        self._before_updates = None
        self._after_updates = None

        # in-/out- size of neuron population
        self.out_size = self.in_size

    @property
    def varshape(self):
        """
        Get the shape of variables in the neuron group.

        This property provides access to the geometry (shape) of the neuron population,
        which determines how variables and states are structured.

        Returns
        -------
        tuple
            A tuple representing the dimensional shape of the neuron group,
            matching the in_size parameter provided during initialization.

        See Also
        --------
        in_size : The input geometry specification for the neuron group
        """
        return self.in_size

    def prefetch(self, item: str) -> 'Prefetch':
        """
        Create a reference to a state or variable that may not be initialized yet.

        This method allows accessing module attributes or states before they are
        fully defined, acting as a placeholder that will be resolved when called.
        Particularly useful for creating references to variables that will be defined
        during initialization or runtime.

        Parameters
        ----------
        item : str
            The name of the attribute or state to reference.

        Returns
        -------
        Prefetch
            A Prefetch object that provides access to the referenced item.

        Examples
        --------
        >>> import brainstate
        >>> import brainunit as u
        >>> neuron = brainstate.nn.LIF(...)
        >>> v_ref = neuron.prefetch('V')  # Reference to voltage
        >>> v_value = v_ref()  # Get current value
        >>> delayed_v = v_ref.delay.at(5.0 * u.ms)  # Get delayed value
        """
        return Prefetch(self, item)

    def prefetch_delay(self, state: str, delay_time, init: Callable = None) -> 'PrefetchDelayAt':
        """
        Create a reference to a delayed state or variable in the module.

        This is a shortcut for ``self.prefetch(state).delay.at(delay_time)`` that
        additionally lets the caller pass an initializer for the delay buffer.

        Parameters
        ----------
        state : str
            The name of the state or variable to reference.
        delay_time : ArrayLike or tuple
            The amount of time to delay the variable access, typically in time
            units (e.g., milliseconds). A non-tuple/list value is treated as a
            single delay.
        init : Callable, optional
            Initializer used for the delay buffer if this is the first delayed
            reference to ``state`` on this module.

        Returns
        -------
        PrefetchDelayAt
            An object that provides access to the variable at the specified delay time.
        """
        return PrefetchDelayAt(self, state, delay_time, init=init)

    def output_delay(self, *delay_time) -> 'OutputDelayAt':
        """
        Create a reference to the delayed output of the module.

        This instantiates an :class:`OutputDelayAt` object, which can be used to
        retrieve the value returned by ``update`` at the specified delay time.

        Parameters
        ----------
        *delay_time : ArrayLike
            Delay specification(s), typically in time units (e.g., milliseconds).
            They are forwarded as a tuple to :class:`OutputDelayAt`, which
            registers them with the module's output delay buffer.

        Returns
        -------
        OutputDelayAt
            An object that provides access to the module's output at the specified delay time.
        """
        return OutputDelayAt(self, delay_time)

    @property
    def before_updates(self):
        """
        Get the dictionary of functions to execute before the module's update.

        Returns
        -------
        dict or None
            Dictionary mapping keys to callable functions that will be executed
            before the main update, or None if no before updates are registered.

        Notes
        -----
        Before updates are executed in the order they were registered whenever
        the module is called via __call__.
        """
        return self._before_updates

    @property
    def after_updates(self):
        """
        Get the dictionary of functions to execute after the module's update.

        Returns
        -------
        dict or None
            Dictionary mapping keys to callable functions that will be executed
            after the main update, or None if no after updates are registered.

        Notes
        -----
        After updates are executed in the order they were registered whenever
        the module is called via __call__, and may optionally receive the return
        value from the update method.
        """
        return self._after_updates

    def add_before_update(self, key: Any, fun: Callable):
        """
        Register a function to be executed before the module's update.

        Parameters
        ----------
        key : Any
            A unique identifier for the update function.
        fun : Callable
            The function to execute before the module's update.

        Raises
        ------
        KeyError
            If the key is already registered in before_updates.

        Notes
        -----
        Internal method used by the module system to register dependencies.
        """
        if self._before_updates is None:
            self._before_updates = dict()
        if key in self.before_updates:
            raise KeyError(f'{key} has been registered in before_updates of {self}')
        self.before_updates[key] = fun

    def add_after_update(self, key: Any, fun: Callable):
        """
        Register a function to be executed after the module's update.

        Parameters
        ----------
        key : Any
            A unique identifier for the update function.
        fun : Callable
            The function to execute after the module's update.

        Raises
        ------
        KeyError
            If the key is already registered in after_updates.

        Notes
        -----
        Internal method used by the module system to register dependencies.
        """
        if self._after_updates is None:
            self._after_updates = dict()
        if key in self.after_updates:
            raise KeyError(f'{key} has been registered in after_updates of {self}')
        self.after_updates[key] = fun

    def get_before_update(self, key: Any):
        """
        Retrieve a registered before-update function by its key.

        Parameters
        ----------
        key : Any
            The identifier of the before-update function to retrieve.

        Returns
        -------
        Callable
            The registered before-update function.

        Raises
        ------
        KeyError
            If the key is not registered in before_updates or if before_updates is None.
        """
        if self._before_updates is None or key not in self.before_updates:
            raise KeyError(f'{key} is not registered in before_updates of {self}')
        return self.before_updates.get(key)

    def get_after_update(self, key: Any):
        """
        Retrieve a registered after-update function by its key.

        Parameters
        ----------
        key : Any
            The identifier of the after-update function to retrieve.

        Returns
        -------
        Callable
            The registered after-update function.

        Raises
        ------
        KeyError
            If the key is not registered in after_updates or if after_updates is None.
        """
        if self._after_updates is None or key not in self.after_updates:
            raise KeyError(f'{key} is not registered in after_updates of {self}')
        return self.after_updates.get(key)

    def has_before_update(self, key: Any):
        """
        Check if a before-update function is registered with the given key.

        Parameters
        ----------
        key : Any
            The identifier to check for in the before_updates dictionary.

        Returns
        -------
        bool
            True if the key is registered in before_updates, False otherwise.
        """
        if self._before_updates is None:
            return False
        return key in self.before_updates

    def has_after_update(self, key: Any):
        """
        Check if an after-update function is registered with the given key.

        Parameters
        ----------
        key : Any
            The identifier to check for in the after_updates dictionary.

        Returns
        -------
        bool
            True if the key is registered in after_updates, False otherwise.
        """
        if self._after_updates is None:
            return False
        return key in self.after_updates

    def __call__(self, *args, **kwargs):
        """
        Run the before-update hooks, then ``update``, then the after-update hooks.

        A before-update hook receives ``*args, **kwargs`` only if it carries the
        ``_receive_update_input`` flag. An after-update hook receives the return
        value of ``update`` unless it carries the ``_not_receive_update_output``
        flag.

        Returns
        -------
        Any
            Whatever ``update`` returned.
        """

        # ``before_updates``
        if self.before_updates is not None:
            for model in self.before_updates.values():
                if hasattr(model, '_receive_update_input'):
                    model(*args, **kwargs)
                else:
                    model()

        # update the model self
        ret = self.update(*args, **kwargs)

        # ``after_updates``
        if self.after_updates is not None:
            for model in self.after_updates.values():
                if hasattr(model, '_not_receive_update_output'):
                    model()
                else:
                    model(ret)
        return ret
|
454
|
+
|
455
|
+
|
456
|
+
class Prefetch(Node):
    """
    Lazy reference to an attribute (typically a ``State``) of a module.

    The attribute does not need to exist when the reference is created; it is
    looked up by name on ``module`` every time the reference is resolved.

    Use cases:

    - Access variables within dynamics modules that will be defined later
    - Create references to states across module boundaries
    - Enable access to delayed states through the `.delay` property

    Parameters
    ----------
    module : Module
        The module that contains or will contain the referenced item.
    item : str
        The attribute name of the state or variable to prefetch.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(...)
    >>> v_reference = neuron.prefetch('V')  # Reference to voltage before initialization
    >>> v_value = v_reference()  # Get the current value
    >>> delay_ref = v_reference.delay.at(5.0 * u.ms)  # Reference voltage delayed by 5ms

    Notes
    -----
    Calling the object returns the current value of the referenced item
    (unwrapping ``State`` objects). Use the `.delay` property to access delayed
    versions of the state.
    """

    def __init__(self, module: Dynamics, item: str):
        """
        Store the owning module and the attribute name to resolve later.

        Parameters
        ----------
        module : Module
            The module that contains or will contain the referenced item.
        item : str
            The attribute name of the state or variable to prefetch.
        """
        super().__init__()
        self.module = module
        self.item = item

    @property
    def delay(self):
        """
        Entry point for delayed access to the prefetched item.

        Returns
        -------
        PrefetchDelay
            Object whose ``at(...)`` method selects the delay time.
        """
        return PrefetchDelay(self.module, self.item)

    def __call__(self, *args, **kwargs):
        """
        Resolve the reference and return its current value.

        Returns
        -------
        Any
            ``item.value`` when the resolved attribute is a ``State``, otherwise
            the resolved attribute itself.
        """
        resolved = _get_prefetch_item(self)
        if isinstance(resolved, State):
            return resolved.value
        return resolved

    def get_item_value(self):
        """
        Resolve the reference and return its current value.

        Equivalent to calling the object; provided as an explicitly named method.

        Returns
        -------
        Any
            ``item.value`` when the resolved attribute is a ``State``, otherwise
            the resolved attribute itself.
        """
        resolved = _get_prefetch_item(self)
        if isinstance(resolved, State):
            return resolved.value
        return resolved

    def get_item(self):
        """
        Resolve the reference and return the attribute object itself.

        Returns
        -------
        Any
            The attribute found on the module (a ``State`` is returned as-is,
            not unwrapped).
        """
        return _get_prefetch_item(self)
|
560
|
+
|
561
|
+
|
562
|
+
class PrefetchDelay(Node):
    """
    Provides access to delayed versions of a prefetched state or variable.

    This class acts as an intermediary for accessing delayed values of module variables.
    It doesn't retrieve values directly but provides methods to specify the delay time
    via the `at()` method.

    Parameters
    ----------
    module : Dynamics
        The dynamics module that contains the referenced state or variable.
    item : str
        The name of the state or variable to access with delay.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> # Access voltage delayed by 5ms
    >>> delayed_v = neuron.prefetch('V').delay.at(5.0 * u.ms)
    >>> delayed_value = delayed_v()  # Get the delayed value
    """

    def __init__(self, module: Dynamics, item: str):
        # Initialize the Node base class, as the sibling ``Prefetch`` and
        # ``PrefetchDelayAt`` classes do, before assigning attributes.
        super().__init__()
        self.module = module
        self.item = item

    def at(self, *delay_time):
        """
        Specifies the delay time for accessing the variable.

        Parameters
        ----------
        *delay_time : ArrayLike
            Delay specification(s), typically in time units (e.g., milliseconds).
            They are forwarded as a tuple to :class:`PrefetchDelayAt`; calling
            ``at()`` with no arguments yields a reference to the current value.

        Returns
        -------
        PrefetchDelayAt
            An object that provides access to the variable at the specified delay time.
        """
        return PrefetchDelayAt(self.module, self.item, delay_time)
|
607
|
+
|
608
|
+
|
609
|
+
class PrefetchDelayAt(Node):
    """
    Provides access to a specific delayed state or variable value at the specific time.

    This class represents the final step in the prefetch delay chain, providing
    actual access to state values at a specific delay time. On construction it
    registers the requested delay with a ``StateWithDelay`` buffer that is shared
    by all delayed references to the same item of the same module.

    Parameters
    ----------
    module : Dynamics
        The dynamics module that contains the referenced state or variable.
    item : str
        The name of the state or variable to access with delay.
    delay_time : Tuple or ArrayLike
        The delay specification(s), typically in time units (e.g., milliseconds).
        A value that is not a tuple/list is wrapped into a 1-tuple. An empty
        tuple means "no delay": calling the object returns the current value.
    init : Callable, optional
        Initializer passed to ``StateWithDelay`` when this is the first delayed
        reference to ``item`` on ``module``; ignored if the buffer already exists.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> # Create a reference to voltage delayed by 5ms
    >>> delayed_v = PrefetchDelayAt(neuron, 'V', 5.0 * u.ms)
    >>> # Get the delayed value
    >>> v_value = delayed_v()
    """

    def __init__(
        self,
        module: Dynamics,
        item: str,
        delay_time: Tuple,
        init: Callable = None
    ):
        """
        Initialize a PrefetchDelayAt object.

        Parameters
        ----------
        module : Dynamics
            The dynamics module that contains the referenced state or variable.
        item : str
            The name of the state or variable to access with delay.
        delay_time : Tuple
            The amount of time to delay access by, typically in time units (e.g., milliseconds).
        init : Callable, optional
            Initializer for a newly created ``StateWithDelay`` buffer.
        """
        super().__init__()
        assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
        self.module = module
        self.item = item
        if not isinstance(delay_time, (tuple, list)):
            delay_time = (delay_time,)
        self.delay_time = delay_time
        if len(delay_time) > 0:
            key = _get_prefetch_delay_key(item)
            if not module.has_after_update(key):
                # First delayed access to ``item``: create the shared buffer and
                # run it after every update without passing the update output.
                module.add_after_update(
                    key,
                    not_receive_update_output(
                        StateWithDelay(module, item, init=init)
                    )
                )
            self.state_delay: StateWithDelay = module.get_after_update(key)
            self.delay_info = self.state_delay.register_delay(*delay_time)

    def __call__(self, *args, **kwargs):
        """
        Retrieve the value of the state at the specified delay time.

        Returns
        -------
        Any
            The value of the state or variable at the specified delay time. With
            no delay, the current value: ``item.value`` for a ``State``, otherwise
            the attribute itself (same convention as ``Prefetch``).
        """
        if len(self.delay_time) == 0:
            item = _get_prefetch_item(self)
            # Do not assume a ``State``: plain attributes are returned as-is.
            return item.value if isinstance(item, State) else item
        return self.state_delay.retrieve_at_step(*self.delay_info)
|
689
|
+
|
690
|
+
|
691
|
+
class OutputDelayAt(Node):
    """
    Provides access to the output of a module at a specific delay time.

    On construction this registers the requested delay with a ``Delay`` buffer
    that records the return value of the module's ``update`` after every call.
    The buffer is shared by all ``OutputDelayAt`` objects of the same module.
    It is created on first use with shape ``module.out_size`` and dtype
    ``environ.dftype()``.

    Parameters
    ----------
    module : Dynamics
        The dynamics module whose output is delayed.
    delay_time : Tuple or ArrayLike
        The delay specification(s), typically in time units (e.g., milliseconds).
        A value that is not a tuple/list is treated as a single delay.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> # Create a reference to the neuron output delayed by 5ms
    >>> delayed_spike = OutputDelayAt(neuron, 5.0 * u.ms)
    >>> # Get the delayed value
    >>> spk_value = delayed_spike()
    """

    def __init__(
        self,
        module: Dynamics,
        delay_time: Tuple,
    ):
        super().__init__()
        assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
        self.module = module
        # Accept a single delay value as well as a tuple/list, matching
        # ``PrefetchDelayAt``; star-unpacking a single non-iterable value
        # (e.g. ``5.0 * u.ms``) below would otherwise fail.
        if not isinstance(delay_time, (tuple, list)):
            delay_time = (delay_time,)
        key = _get_output_delay_key()
        if not module.has_after_update(key):
            delay = Delay(jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()), take_aware_unit=True)
            # The buffer must be fed the value returned by ``update``.
            module.add_after_update(key, receive_update_output(delay))
        self.out_delay: Delay = module.get_after_update(key)
        self.delay_info = self.out_delay.register_delay(*delay_time)

    def __call__(self, *args, **kwargs):
        """Return the module output at the registered delay."""
        return self.out_delay.retrieve_at_step(*self.delay_info)
|
735
|
+
|
736
|
+
|
737
|
+
def _get_prefetch_delay_key(item) -> str:
|
738
|
+
return f'{item}-prefetch-delay'
|
739
|
+
|
740
|
+
|
741
|
+
def _get_output_delay_key() -> str:
|
742
|
+
return f'output-delay'
|
743
|
+
|
744
|
+
|
745
|
+
def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
|
746
|
+
item = getattr(target.module, target.item, None)
|
747
|
+
if item is None:
|
748
|
+
raise AttributeError(f'The target {target.module} should have an `{target.item}` attribute.')
|
749
|
+
return item
|
750
|
+
|
751
|
+
|
752
|
+
def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDelayAt]) -> Delay:
    """
    Return the ``StateWithDelay`` buffer registered for ``target.item``.

    Raises
    ------
    AssertionError
        If ``target.module`` is not a ``Dynamics`` instance.
    KeyError
        If no delay buffer has been registered for the item (raised by
        ``Dynamics.get_after_update``).
    TypeError
        If the registered after-update under the delay key is not a
        ``StateWithDelay``.
    """
    module = target.module
    assert isinstance(module, Dynamics), (
        f'The target module should be an instance '
        f'of Dynamics. But got {module}.'
    )
    registered = module.get_after_update(_get_prefetch_delay_key(target.item))
    if isinstance(registered, StateWithDelay):
        return registered
    raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
                    f'its delay. But got {registered}.')
|
762
|
+
|
763
|
+
|
764
|
+
def maybe_init_prefetch(target, *args, **kwargs):
    """
    Validate a prefetch target, based on its type.

    Different prefetch objects are checked in different ways:

    - For :py:class:`Prefetch` objects: resolves the referenced item, raising
      ``AttributeError`` if it is missing or ``None``.
    - For :py:class:`PrefetchDelay` objects: looks up the delay buffer, raising if
      none has been registered for the item or if it has the wrong type.
    - For :py:class:`PrefetchDelayAt` objects: nothing to do, because the delay
      was already registered when the object was constructed.

    Any other object is ignored.

    Parameters
    ----------
    target : Union[Prefetch, PrefetchDelay, PrefetchDelayAt]
        The prefetch target to check.
    *args : Any
        Additional positional arguments (unused).
    **kwargs : Any
        Additional keyword arguments (unused).

    Returns
    -------
    None
        This function only performs checks (and may raise).
    """
    if isinstance(target, Prefetch):
        _get_prefetch_item(target)

    elif isinstance(target, PrefetchDelay):
        _get_prefetch_item_delay(target)

    elif isinstance(target, PrefetchDelayAt):
        # Registration happens in ``PrefetchDelayAt.__init__``; nothing left to do.
        pass
|
804
|
+
|
805
|
+
|
806
|
+
def receive_update_output(cls: object):
    """
    Mark an after-update hook so that it receives the output of ``update``.

    This removes the ``_not_receive_update_output`` flag set by
    :func:`not_receive_update_output`. ``Dynamics.__call__`` then calls the
    hook with the return value of the update function::

        ret = model.update(*args, **kwargs)
        for fun in model.after_updates.values():
            fun(ret)

    Parameters
    ----------
    cls : object
        The hook object to mark; it is modified in place.

    Returns
    -------
    object
        The same object, so this can be used as a decorator.
    """
    # Only the presence of the flag matters, so deleting it restores the
    # default behavior of receiving the update output.
    if hasattr(cls, '_not_receive_update_output'):
        delattr(cls, '_not_receive_update_output')
    return cls
|
821
|
+
|
822
|
+
|
823
|
+
def not_receive_update_output(cls: T) -> T:
    """
    Mark an after-update hook so that it does not receive the output of ``update``.

    This sets the ``_not_receive_update_output`` flag. ``Dynamics.__call__``
    then calls the hook without arguments::

        ret = model.update(*args, **kwargs)
        for fun in model.after_updates.values():
            fun()

    Parameters
    ----------
    cls : T
        The hook object to mark; it is modified in place.

    Returns
    -------
    T
        The same object, so this can be used as a decorator.
    """
    cls._not_receive_update_output = True
    return cls
|
837
|
+
|
838
|
+
|
839
|
+
def receive_update_input(cls: object):
    """
    Mark a before-update hook so that it receives the inputs of ``update``.

    This sets the ``_receive_update_input`` flag. ``Dynamics.__call__`` then
    forwards the call arguments to the hook::

        for fun in model.before_updates.values():
            fun(*args, **kwargs)
        model.update(*args, **kwargs)

    Parameters
    ----------
    cls : object
        The hook object to mark; it is modified in place.

    Returns
    -------
    object
        The same object, so this can be used as a decorator.
    """
    cls._receive_update_input = True
    return cls
|
854
|
+
|
855
|
+
|
856
|
+
def not_receive_update_input(cls: object):
    """
    Mark a before-update hook so that it does not receive the inputs of ``update``.

    This removes the ``_receive_update_input`` flag set by
    :func:`receive_update_input`. ``Dynamics.__call__`` then calls the hook
    without arguments, while ``update`` itself still receives them::

        for fun in model.before_updates.values():
            fun()
        model.update(*args, **kwargs)

    Parameters
    ----------
    cls : object
        The hook object to mark; it is modified in place.

    Returns
    -------
    object
        The same object, so this can be used as a decorator.
    """
    # Only the presence of the flag matters, so deleting it restores the
    # default behavior of calling the hook without arguments.
    if hasattr(cls, '_receive_update_input'):
        delattr(cls, '_receive_update_input')
    return cls
|