brainstate 0.1.7__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -146
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -470
- brainstate/nn/_delay_test.py +238 -0
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1361
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1120
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -208
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.7.dist-info/RECORD +0 -131
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
brainstate/nn/_normalizations.py
CHANGED
@@ -1,918 +1,975 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
from typing import Callable, Union, Sequence, Optional, Any
|
19
|
-
|
20
|
-
import jax
|
21
|
-
import jax.numpy as jnp
|
22
|
-
|
23
|
-
|
24
|
-
from brainstate
|
25
|
-
from brainstate.
|
26
|
-
from .
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
'
|
31
|
-
'
|
32
|
-
'
|
33
|
-
'
|
34
|
-
'
|
35
|
-
'
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
#
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
#
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
The
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
)
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
)
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
823
|
-
|
824
|
-
|
825
|
-
|
826
|
-
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
|
846
|
-
|
847
|
-
|
848
|
-
|
849
|
-
|
850
|
-
|
851
|
-
|
852
|
-
|
853
|
-
|
854
|
-
|
855
|
-
|
856
|
-
|
857
|
-
|
858
|
-
|
859
|
-
|
860
|
-
|
861
|
-
|
862
|
-
|
863
|
-
|
864
|
-
|
865
|
-
|
866
|
-
self.
|
867
|
-
self.
|
868
|
-
|
869
|
-
|
870
|
-
|
871
|
-
|
872
|
-
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
877
|
-
is
|
878
|
-
|
879
|
-
|
880
|
-
|
881
|
-
|
882
|
-
|
883
|
-
|
884
|
-
|
885
|
-
|
886
|
-
|
887
|
-
|
888
|
-
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
897
|
-
|
898
|
-
|
899
|
-
|
900
|
-
|
901
|
-
|
902
|
-
|
903
|
-
|
904
|
-
|
905
|
-
|
906
|
-
|
907
|
-
|
908
|
-
|
909
|
-
|
910
|
-
|
911
|
-
|
912
|
-
|
913
|
-
|
914
|
-
|
915
|
-
|
916
|
-
|
917
|
-
|
918
|
-
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
from typing import Callable, Union, Sequence, Optional, Any
|
19
|
+
|
20
|
+
import jax
|
21
|
+
import jax.numpy as jnp
|
22
|
+
|
23
|
+
import brainunit as u
|
24
|
+
from brainstate import environ, init
|
25
|
+
from brainstate._state import ParamState, BatchState
|
26
|
+
from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
|
27
|
+
from ._module import Module
|
28
|
+
|
29
|
+
# Public API of this module: the batch-normalization variants first,
# followed by the remaining normalization layers.
__all__ = [
    'BatchNorm0d', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
    'LayerNorm', 'RMSNorm', 'GroupNorm',
]
|
38
|
+
|
39
|
+
|
40
|
+
def weight_standardization(
    w: ArrayLike,
    eps: float = 1e-4,
    gain: Optional[jax.Array] = None,
    out_axis: int = -1,
) -> Union[jax.Array, u.Quantity]:
    """
    Scaled Weight Standardization,
    see `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization <https://paperswithcode.com/paper/weight-standardization>`_.

    The weight is standardized over every axis except ``out_axis``. The
    per-output mean is subtracted and the result is divided by
    ``sqrt(max(var * fan_in, eps))``, where ``fan_in`` is the product of the
    sizes of the reduced axes. When ``gain`` is given, it multiplies that
    scale factor.

    Parameters
    ----------
    w : ArrayLike
        The weight tensor. It may be a ``u.Quantity`` carrying units.
    eps : float
        Lower bound applied to ``var * fan_in`` to avoid division by zero.
    gain : Array, optional
        Multiplicative gain applied to the standardization scale. Default None.
    out_axis : int
        The output axis, which is *not* reduced over. Default -1.

    Returns
    -------
    Union[jax.Array, u.Quantity]
        The standardized weight tensor.
    """
    w = u.maybe_custom_array(w)

    # Turn a negative output axis into its non-negative index.
    out_axis = out_axis + w.ndim if out_axis < 0 else out_axis

    # Every axis other than the output axis is reduced; fan-in is their total size.
    axes = [i for i in range(w.ndim) if i != out_axis]
    fan_in = 1
    for i in axes:
        fan_in *= w.shape[i]

    mean = u.math.mean(w, axis=axes, keepdims=True)
    var = u.math.var(w, axis=axes, keepdims=True)
    denom = u.math.maximum(var * fan_in, eps)

    if not isinstance(denom, u.Quantity):
        scale = jax.lax.rsqrt(denom)
    elif denom.unit.is_unitless:
        scale = jax.lax.rsqrt(denom.mantissa)
    else:
        # `rsqrt` works on raw numbers, so the inverse square-root unit is
        # re-attached explicitly.
        scale = u.Quantity(jax.lax.rsqrt(denom.mantissa), unit=1 / denom.unit ** 0.5)

    if gain is not None:
        scale = gain * scale
    return w * scale - mean * scale
|
93
|
+
|
94
|
+
|
95
|
+
|
96
|
+
def canonicalize_dtype(
|
97
|
+
*args,
|
98
|
+
dtype: jax.typing.DTypeLike | None = None,
|
99
|
+
inexact: bool = True
|
100
|
+
) -> jax.typing.DTypeLike:
|
101
|
+
"""Canonicalize an optional dtype to the definitive dtype.
|
102
|
+
|
103
|
+
If the ``dtype`` is None this function will infer the dtype. If it is not
|
104
|
+
None it will be returned unmodified or an exceptions is raised if the dtype
|
105
|
+
is invalid.
|
106
|
+
from the input arguments using ``jnp.result_type``.
|
107
|
+
|
108
|
+
Args:
|
109
|
+
*args: JAX array compatible values. None values
|
110
|
+
are ignored.
|
111
|
+
dtype: Optional dtype override. If specified the arguments are cast to
|
112
|
+
the specified dtype instead and dtype inference is disabled.
|
113
|
+
inexact: When True, the output dtype must be a subdtype
|
114
|
+
of `jnp.inexact`. Inexact dtypes are real or complex floating points. This
|
115
|
+
is useful when you want to apply operations that don't work directly on
|
116
|
+
integers like taking a mean for example.
|
117
|
+
Returns:
|
118
|
+
The dtype that *args should be cast to.
|
119
|
+
"""
|
120
|
+
if dtype is None:
|
121
|
+
args_filtered = [jnp.asarray(x) for x in args if x is not None]
|
122
|
+
dtype = jnp.result_type(*args_filtered)
|
123
|
+
if inexact and not jnp.issubdtype(dtype, jnp.inexact):
|
124
|
+
dtype = jnp.promote_types(jnp.float32, dtype)
|
125
|
+
if inexact and not jnp.issubdtype(dtype, jnp.inexact):
|
126
|
+
raise ValueError(f'Dtype must be inexact: {dtype}')
|
127
|
+
return dtype
|
128
|
+
|
129
|
+
|
130
|
+
def _canonicalize_axes(ndim: int, feature_axes: Union[int, Sequence[int]]):
    """Convert possibly-negative axis indices into non-negative ones.

    Args:
      ndim: Number of dimensions of the array the axes refer to.
      feature_axes: A single axis or a sequence of axes. Negative values
        count from the end, as in NumPy.

    Returns:
      A tuple of non-negative axis indices, in the order given.

    Raises:
      ValueError: If any axis lies outside ``[-ndim, ndim)``. The message
        reports the axis exactly as the caller passed it.
    """
    if isinstance(feature_axes, int):
        # Accept a bare axis as shorthand for a one-element sequence.
        feature_axes = (feature_axes,)
    axes = []
    for axis in feature_axes:
        canonical = axis + ndim if axis < 0 else axis
        if not 0 <= canonical < ndim:
            # Report the caller's value, not the shifted one, so the error is actionable.
            raise ValueError(f'Invalid axis {axis} for {ndim}D input')
        axes.append(canonical)
    return tuple(axes)
|
139
|
+
|
140
|
+
|
141
|
+
def _abs_sq(x):
    """Return ``|x|**2`` element-wise.

    For complex input this is ``real(x)**2 + imag(x)**2``, a real-valued
    result. For real input it is simply ``x**2``.
    """
    if not jnp.iscomplexobj(x):
        return jax.lax.square(x)
    real_part = jax.lax.real(x)
    imag_part = jax.lax.imag(x)
    return jax.lax.square(real_part) + jax.lax.square(imag_part)
|
147
|
+
|
148
|
+
|
149
|
+
class NormalizationParamState(ParamState):
    """Parameter state holding the affine parameters of a normalization layer.

    The stored value is a dict that may contain a ``'scale'`` entry, a
    ``'bias'`` entry, or both. This is a thin compatibility class. It exists so
    that normalization layers can be used like ``ETraceParam`` by the layers
    in "brainetrace".
    """

    def execute(self, x):
        """Apply the stored affine transform to ``x``.

        Returns ``x * scale + bias``. Whichever of the two entries is missing
        from the stored dict is skipped.
        """
        weights = self.value
        out = x * weights['scale'] if 'scale' in weights else x
        if 'bias' in weights:
            out = out + weights['bias']
        return out
|
159
|
+
|
160
|
+
|
161
|
+
def _compute_stats(
    x: ArrayLike,
    axes: Sequence[int],
    dtype: DTypeLike,
    axis_name: Optional[str] = None,
    axis_index_groups: Optional[Sequence[int]] = None,
    use_mean: bool = True,
    use_fast_variance: bool = True,
    mask: Optional[jax.Array] = None,
):
    """
    Compute the mean and variance of ``x`` over ``axes``.

    Details handled here:

    - Statistics are computed in at least float32, which avoids half-precision
      computation. Double and complex inputs keep their precision.
    - With ``use_fast_variance=True``, both moments come from one fused pass
      using ``Var = E[|x|^2] - |E[x]|^2`` instead of ``Var = E[|x - E[x]|^2]``.
      Negative values caused by round-off are clipped to zero, which avoids
      NaNs downstream.
    - When ``axis_name`` is given, the means are averaged across that parallel
      axis, or sub-groups of it, with ``jax.lax.pmean``. Several moments are
      stacked so that only a single ``pmean`` call is issued.

    Arguments:
      x: Input array.
      axes: The axes of ``x`` to compute mean and variance statistics over.
      dtype: Optional dtype specifying the minimal precision (default: dtype
        of ``x``). The statistics are always at least float32.
      axis_name: Optional name of the pmapped / shard-mapped axis to average
        over. This is only needed for pmap and shard map. For SPMD jit, make
        sure the axes are correctly annotated and XLA:SPMD inserts the
        necessary collectives.
      axis_index_groups: Optional axis index groups forwarded to ``pmean``.
      use_mean: If True, compute the mean from the input and use it for the
        variance. If False, the returned mean is zero and the variance is the
        raw second moment ``E[|x|^2]``.
      use_fast_variance: If True, use the faster, but less numerically
        stable, one-pass variance formula.
      mask: Optional binary array broadcastable to ``x``. It selects the
        positions used when computing the mean and variance.

    Returns:
      A pair ``(mean, var)``, each reduced over ``axes``.
    """
    # Work in at least float32 while preserving float64 / complex inputs.
    if dtype is None:
        dtype = jax.numpy.result_type(x)
    dtype = jax.numpy.promote_types(dtype, jnp.float32)
    x = jnp.asarray(x, dtype)
    axes = _canonicalize_axes(x.ndim, axes)

    def mean_over_axes(*arrays, mask=None):
        # Local means over `axes`, optionally synchronized across devices.
        local = tuple(a.mean(axes, where=mask) for a in arrays)
        if axis_name is None:
            return local if len(arrays) > 1 else local[0]
        if len(arrays) == 1:
            return jax.lax.pmean(
                local[0],
                axis_name,
                axis_index_groups=axis_index_groups
            )
        # Stack the moments so a single collective communicates all of them.
        synced = jax.lax.pmean(
            jnp.stack(local, axis=0),
            axis_name,
            axis_index_groups=axis_index_groups,
        )
        return tuple(synced[i] for i in range(len(arrays)))

    if not use_mean:
        var = mean_over_axes(_abs_sq(x), mask=mask)
        return jnp.zeros_like(var), var

    if use_fast_variance:
        mu, mu2 = mean_over_axes(x, _abs_sq(x), mask=mask)
        # E[|x|^2] - |E[x]|^2 can become slightly negative through round-off.
        var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
    else:
        mu = mean_over_axes(x, mask=mask)
        centered = x - jnp.expand_dims(mu, axes)
        var = mean_over_axes(_abs_sq(centered), mask=mask)
    return mu, var
|
245
|
+
|
246
|
+
|
247
|
+
def _normalize(
    x: ArrayLike,
    mean: Optional[ArrayLike],
    var: Optional[ArrayLike],
    weights: Optional[NormalizationParamState],
    reduction_axes: Axes,
    feature_axes: Axes,
    dtype: DTypeLike,
    epsilon: jax.typing.ArrayLike,
):
    """Normalize ``x`` with the given statistics and optionally apply a learned scale and bias.

    Computes ``(x - mean) * rsqrt(var + epsilon)``. If ``weights`` is given,
    the result is then passed through ``weights.execute``. When ``mean`` and
    ``var`` are both None, ``x`` is returned without normalization, only cast
    to ``dtype``.

    Arguments:
      x: The input.
      mean: Mean to use for normalization. It is reshaped so that every axis
        in ``reduction_axes`` has size 1. It must be None exactly when ``var``
        is None.
      var: Variance to use for normalization, reshaped like ``mean``.
      weights: The scale and bias parameters, or None. Must be None when no
        statistics are given.
      reduction_axes: The axes of ``x`` that the statistics were reduced over.
      feature_axes: The feature axes of ``x``. They are validated against
        ``x.ndim`` but not otherwise used, because the entries of ``weights``
        are applied by plain broadcasting and must already broadcast against
        ``x``.
      dtype: The dtype of the result. If None and ``weights`` is given, it is
        inferred from ``x`` and the parameters.
      epsilon: Normalization epsilon added to ``var``.

    Returns:
      The normalized input, cast to ``dtype``.

    Raises:
      AssertionError: (when assertions are enabled) if ``mean``/``var`` are
        inconsistently None, or if ``weights`` is given without statistics.
      ValueError: If an axis is out of range for ``x``.
    """
    if mean is None:
        assert var is None, 'mean and var must be both None or not None.'
        assert weights is None, 'scale and bias are not supported without mean and var'
        return jnp.asarray(x, dtype)

    assert var is not None, 'mean and var must be both None or not None.'
    reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
    # Validation only: the affine parameters broadcast on their own, so the
    # canonical feature axes are not needed for the computation.
    _canonicalize_axes(x.ndim, feature_axes)

    # Give the statistics singleton dimensions along the reduced axes so they
    # broadcast against `x`.
    stats_shape = list(x.shape)
    for axis in reduction_axes:
        stats_shape[axis] = 1
    mean = mean.reshape(stats_shape)
    var = var.reshape(stats_shape)

    y = (x - mean) * jax.lax.rsqrt(var + epsilon)
    if weights is not None:
        y = weights.execute(y)
        dtype = canonicalize_dtype(x, *jax.tree.leaves(weights.value), dtype=dtype)
    return jnp.asarray(y, dtype)
|
295
|
+
|
296
|
+
|
297
|
+
class _BatchNorm(Module):
    """Shared implementation of ``BatchNorm0d`` ... ``BatchNorm3d``.

    Subclasses only set ``num_spatial_dims``. ``update`` accepts input either
    with a leading batch axis (``num_spatial_dims + 2`` dims) or without one
    (``num_spatial_dims + 1`` dims). In both cases the non-batch part of the
    shape must equal ``in_size``.

    With ``track_running_stats=True``, the behavior depends on the phase:

    - Fitting (``environ.get('fit')`` truthy): batch statistics are computed
      over all non-feature axes and used for normalization. They are also
      blended into the running statistics as
      ``running = momentum * running + (1 - momentum) * batch``.
    - Otherwise: the stored running statistics are used instead.

    NOTE(review): with ``track_running_stats=False``, ``update`` passes no
    statistics to ``_normalize``, which then returns the input unnormalized
    (only cast to ``dtype``). Other libraries (e.g. PyTorch) fall back to batch
    statistics in that case -- confirm this is intended.
    """
    __module__ = 'brainstate.nn'
    num_spatial_dims: int

    def __init__(
        self,
        in_size: Size,
        feature_axis: Axes = -1,
        *,
        track_running_stats: bool = True,
        epsilon: float = 1e-5,
        momentum: float = 0.99,
        affine: bool = True,
        bias_initializer: Union[ArrayLike, Callable] = init.Constant(0.),
        scale_initializer: Union[ArrayLike, Callable] = init.Constant(1.),
        axis_name: Optional[Union[str, Sequence[str]]] = None,
        axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
        use_fast_variance: bool = True,
        name: Optional[str] = None,
        dtype: Any = None,
        param_type: type = NormalizationParamState,
        mean_type: type = BatchState,
    ):
        super().__init__(name=name)

        # ---- configuration ----
        self.in_size = in_size
        self.out_size = in_size  # normalization keeps the shape unchanged
        self.affine = affine
        self.bias_initializer = bias_initializer
        self.scale_initializer = scale_initializer
        self.dtype = dtype or environ.dftype()
        self.track_running_stats = track_running_stats
        self.momentum = jnp.asarray(momentum, dtype=self.dtype)
        self.epsilon = jnp.asarray(epsilon, dtype=self.dtype)
        self.use_fast_variance = use_fast_variance

        # ---- axes ----
        if isinstance(feature_axis, int):
            feature_axis = (feature_axis,)
        # Feature axes are relative to `in_size`, i.e. they exclude any batch axis.
        self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
        self.axis_name = axis_name
        self.axis_index_groups = axis_index_groups

        # Statistics and affine parameters keep the size of the feature axes
        # and have size 1 everywhere else.
        feature_shape = tuple(
            size if axis in self.feature_axes else 1
            for axis, size in enumerate(self.in_size)
        )

        # ---- running statistics ----
        if not self.track_running_stats:
            self.running_mean = None
            self.running_var = None
        else:
            self.running_mean = mean_type(jnp.zeros(feature_shape, dtype=self.dtype))
            self.running_var = mean_type(jnp.ones(feature_shape, dtype=self.dtype))

        # ---- affine parameters ----
        if not self.affine:
            self.weight = None
        else:
            # `_normalize` rejects scale/bias when no statistics are available.
            assert track_running_stats, "Affine parameters are not needed when track_running_stats is False."
            bias = init.param(self.bias_initializer, feature_shape)
            scale = init.param(self.scale_initializer, feature_shape)
            self.weight = param_type({'bias': bias, 'scale': scale})

    def update(self, x, mask: Optional[jax.Array] = None):
        n_spatial = self.num_spatial_dims

        # Decide whether `x` carries a leading batch axis.
        if x.ndim == n_spatial + 2:
            batch, x_shape = True, x.shape[1:]
        elif x.ndim == n_spatial + 1:
            batch, x_shape = False, x.shape
        else:
            raise ValueError(f"expected {n_spatial + 2}D (with batch) or "
                             f"{n_spatial + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
        if self.in_size != x_shape:
            raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")

        # Reduce over every axis that is not a feature axis; with a batch axis
        # present, the stored feature axes are shifted by one.
        offset = 1 if batch else 0
        reduction_axes = tuple(
            i for i in range(x.ndim) if (i - offset) not in self.feature_axes
        )

        fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')

        if not self.track_running_stats:
            mean = var = None
        elif fit_phase:
            # Normalize with the batch statistics and blend them into the running ones.
            mean, var = _compute_stats(
                x,
                reduction_axes,
                dtype=self.dtype,
                axis_name=self.axis_name,
                axis_index_groups=self.axis_index_groups,
                use_fast_variance=self.use_fast_variance,
                mask=mask,
            )
            m = self.momentum
            self.running_mean.value = m * self.running_mean.value + (1 - m) * mean
            self.running_var.value = m * self.running_var.value + (1 - m) * var
        else:
            mean, var = self.running_mean.value, self.running_var.value

        return _normalize(
            x,
            mean=mean,
            var=var,
            weights=self.weight,
            reduction_axes=reduction_axes,
            feature_axes=self.feature_axes,
            dtype=self.dtype,
            epsilon=self.epsilon,
        )
|
413
|
+
|
414
|
+
|
415
|
+
class BatchNorm0d(_BatchNorm):
|
416
|
+
r"""0-D batch normalization [1]_.
|
417
|
+
|
418
|
+
The data should be of `(b, c)`, where `b` is the batch dimension, and `c` is the channel dimension.
|
419
|
+
|
420
|
+
%s
|
421
|
+
"""
|
422
|
+
__module__ = 'brainstate.nn'
|
423
|
+
num_spatial_dims: int = 0
|
424
|
+
|
425
|
+
|
426
|
+
class BatchNorm1d(_BatchNorm):
|
427
|
+
r"""1-D batch normalization [1]_.
|
428
|
+
|
429
|
+
The data should be of `(b, l, c)`, where `b` is the batch dimension,
|
430
|
+
`l` is the layer dimension, and `c` is the channel dimension.
|
431
|
+
|
432
|
+
%s
|
433
|
+
"""
|
434
|
+
__module__ = 'brainstate.nn'
|
435
|
+
num_spatial_dims: int = 1
|
436
|
+
|
437
|
+
|
438
|
+
class BatchNorm2d(_BatchNorm):
|
439
|
+
r"""2-D batch normalization [1]_.
|
440
|
+
|
441
|
+
The data should be of `(b, h, w, c)`, where `b` is the batch dimension,
|
442
|
+
`h` is the height dimension, `w` is the width dimension, and `c` is the
|
443
|
+
channel dimension.
|
444
|
+
|
445
|
+
%s
|
446
|
+
"""
|
447
|
+
__module__ = 'brainstate.nn'
|
448
|
+
num_spatial_dims: int = 2
|
449
|
+
|
450
|
+
|
451
|
+
class BatchNorm3d(_BatchNorm):
|
452
|
+
r"""3-D batch normalization [1]_.
|
453
|
+
|
454
|
+
The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
|
455
|
+
`h` is the height dimension, `w` is the width dimension, `d` is the depth
|
456
|
+
dimension, and `c` is the channel dimension.
|
457
|
+
|
458
|
+
%s
|
459
|
+
"""
|
460
|
+
__module__ = 'brainstate.nn'
|
461
|
+
num_spatial_dims: int = 3
|
462
|
+
|
463
|
+
|
464
|
+
_bn_doc = r'''
|
465
|
+
|
466
|
+
This layer aims to reduce the internal covariant shift of data. It
|
467
|
+
normalizes a batch of data by fixing the mean and variance of inputs
|
468
|
+
on each feature (channel). Most commonly, the first axis of the data
|
469
|
+
is the batch, and the last is the channel. However, users can specify
|
470
|
+
the axes to be normalized.
|
471
|
+
|
472
|
+
.. math::
|
473
|
+
y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta
|
474
|
+
|
475
|
+
.. note::
|
476
|
+
This :attr:`momentum` argument is different from one used in optimizer
|
477
|
+
classes and the conventional notion of momentum. Mathematically, the
|
478
|
+
update rule for running statistics here is
|
479
|
+
:math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`,
|
480
|
+
where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
|
481
|
+
new observed value.
|
482
|
+
|
483
|
+
Parameters
|
484
|
+
----------
|
485
|
+
in_size: sequence of int
|
486
|
+
The input shape, without batch size.
|
487
|
+
feature_axis: int, tuple, list
|
488
|
+
The feature or non-batch axis of the input.
|
489
|
+
track_running_stats: bool
|
490
|
+
A boolean value that when set to ``True``, this module tracks the running mean and variance,
|
491
|
+
and when set to ``False``, this module does not track such statistics, and initializes
|
492
|
+
statistics buffers ``running_mean`` and ``running_var`` as ``None``. When these buffers are ``None``,
|
493
|
+
this module always uses batch statistics. in both training and eval modes. Default: ``True``.
|
494
|
+
momentum: float
|
495
|
+
The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99
|
496
|
+
epsilon: float
|
497
|
+
A value added to the denominator for numerical stability. Default: 1e-5
|
498
|
+
affine: bool
|
499
|
+
A boolean value that when set to ``True``, this module has
|
500
|
+
learnable affine parameters. Default: ``True``
|
501
|
+
bias_initializer: ArrayLike, Callable
|
502
|
+
An initializer generating the original translation matrix. If not ``None``, bias (beta) is added.
|
503
|
+
Default: ``init.Constant(0.)``
|
504
|
+
scale_initializer: ArrayLike, Callable
|
505
|
+
An initializer generating the original scaling matrix. If not ``None``, multiply by scale (gamma).
|
506
|
+
Default: ``init.Constant(1.)``
|
507
|
+
axis_name: optional, str, sequence of str
|
508
|
+
If not ``None``, it should be a string (or sequence of
|
509
|
+
strings) representing the axis name(s) over which this module is being
|
510
|
+
run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this
|
511
|
+
argument means that batch statistics are calculated across all replicas
|
512
|
+
on the named axes.
|
513
|
+
axis_index_groups: optional, sequence
|
514
|
+
Specifies how devices are grouped. Valid
|
515
|
+
only within ``jax.pmap`` collectives.
|
516
|
+
Groups of axis indices within that named axis
|
517
|
+
representing subsets of devices to reduce over (default: None). For
|
518
|
+
example, `[[0, 1], [2, 3]]` would independently batch-normalize over
|
519
|
+
the examples on the first two and last two devices. See `jax.lax.psum`
|
520
|
+
for more details.
|
521
|
+
use_fast_variance: If true, use a faster, but less numerically stable,
|
522
|
+
calculation for the variance.
|
523
|
+
|
524
|
+
|
525
|
+
References
|
526
|
+
----------
|
527
|
+
.. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training
|
528
|
+
by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag.
|
529
|
+
|
530
|
+
'''
|
531
|
+
|
532
|
+
BatchNorm1d.__doc__ = BatchNorm1d.__doc__ % _bn_doc
|
533
|
+
BatchNorm2d.__doc__ = BatchNorm2d.__doc__ % _bn_doc
|
534
|
+
BatchNorm3d.__doc__ = BatchNorm3d.__doc__ % _bn_doc
|
535
|
+
|
536
|
+
|
537
|
+
class LayerNorm(Module):
|
538
|
+
"""
|
539
|
+
Layer normalization (https://arxiv.org/abs/1607.06450).
|
540
|
+
|
541
|
+
LayerNorm normalizes the activations of the layer for each given example in a
|
542
|
+
batch independently, rather than across a batch like Batch Normalization.
|
543
|
+
i.e. applies a transformation that maintains the mean activation within
|
544
|
+
each example close to 0 and the activation standard deviation close to 1.
|
545
|
+
|
546
|
+
Example usage::
|
547
|
+
|
548
|
+
>>> import brainstate as brainstate
|
549
|
+
>>> x = brainstate.random.normal(size=(3, 4, 5, 6))
|
550
|
+
>>> layer = brainstate.nn.LayerNorm(x.shape)
|
551
|
+
>>> layer.states()
|
552
|
+
>>> y = layer(x)
|
553
|
+
|
554
|
+
Attributes:
|
555
|
+
in_size: The input shape, without batch size.
|
556
|
+
epsilon: A small float added to variance to avoid dividing by zero.
|
557
|
+
dtype: the dtype of the result (default: infer from input and params).
|
558
|
+
use_bias: If True, bias (beta) is added.
|
559
|
+
use_scale: If True, multiply by scale (gamma). When the next layer is linear
|
560
|
+
(also e.g. nnx.relu), this can be disabled since the scaling will be done
|
561
|
+
by the next layer.
|
562
|
+
bias_init: Initializer for bias, by default, zero.
|
563
|
+
scale_init: Initializer for scale, by default, one.
|
564
|
+
reduction_axes: Axes for computing normalization statistics. It is recommended
|
565
|
+
to use the negative integer, since when the batch dimension is used,
|
566
|
+
the reduction_axes may be wrong when using the positive integer.
|
567
|
+
feature_axes: Feature axes for learned bias and scaling.
|
568
|
+
axis_name: the axis name used to combine batch statistics from multiple
|
569
|
+
devices. See ``jax.pmap`` for a description of axis names (default: None).
|
570
|
+
This is only needed if the model is subdivided across devices, i.e. the
|
571
|
+
array being normalized is sharded across devices within a pmap.
|
572
|
+
axis_index_groups: groups of axis indices within that named axis
|
573
|
+
representing subsets of devices to reduce over (default: None). For
|
574
|
+
example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
|
575
|
+
the examples on the first two and last two devices. See ``jax.lax.psum``
|
576
|
+
for more details.
|
577
|
+
use_fast_variance: If true, use a faster, but less numerically stable,
|
578
|
+
calculation for the variance.
|
579
|
+
"""
|
580
|
+
|
581
|
+
def __init__(
|
582
|
+
self,
|
583
|
+
in_size: Size,
|
584
|
+
reduction_axes: Axes = -1,
|
585
|
+
feature_axes: Axes = -1,
|
586
|
+
*,
|
587
|
+
epsilon: float = 1e-6,
|
588
|
+
use_bias: bool = True,
|
589
|
+
use_scale: bool = True,
|
590
|
+
bias_init: Callable = init.ZeroInit(),
|
591
|
+
scale_init: Callable = init.Constant(1.0),
|
592
|
+
axis_name: Optional[str] = None,
|
593
|
+
axis_index_groups: Any = None,
|
594
|
+
use_fast_variance: bool = True,
|
595
|
+
dtype: Optional[jax.typing.DTypeLike] = None,
|
596
|
+
param_type: type = NormalizationParamState,
|
597
|
+
):
|
598
|
+
super().__init__()
|
599
|
+
|
600
|
+
self.in_size = in_size
|
601
|
+
self.out_size = in_size
|
602
|
+
|
603
|
+
# parameters about axis
|
604
|
+
feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
|
605
|
+
self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
|
606
|
+
self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
|
607
|
+
self.axis_name = axis_name
|
608
|
+
self.axis_index_groups = axis_index_groups
|
609
|
+
|
610
|
+
# variables
|
611
|
+
feature_shape = tuple([(ax if i in self.feature_axes else 1)
|
612
|
+
for i, ax in enumerate(self.in_size)])
|
613
|
+
|
614
|
+
weights = dict()
|
615
|
+
if use_scale:
|
616
|
+
weights['scale'] = init.param(scale_init, feature_shape)
|
617
|
+
if use_bias:
|
618
|
+
weights['bias'] = init.param(bias_init, feature_shape)
|
619
|
+
if len(weights):
|
620
|
+
self.weight = param_type(weights)
|
621
|
+
else:
|
622
|
+
self.weight = None
|
623
|
+
|
624
|
+
# parameters
|
625
|
+
self.epsilon = epsilon
|
626
|
+
self.dtype = dtype or environ.dftype()
|
627
|
+
self.use_bias = use_bias
|
628
|
+
self.use_scale = use_scale
|
629
|
+
self.bias_init = bias_init
|
630
|
+
self.scale_init = scale_init
|
631
|
+
self.use_fast_variance = use_fast_variance
|
632
|
+
|
633
|
+
def update(self, x, *, mask: Optional[jax.Array] = None):
|
634
|
+
"""Applies layer normalization on the input.
|
635
|
+
|
636
|
+
Args:
|
637
|
+
x: the inputs
|
638
|
+
|
639
|
+
Returns:
|
640
|
+
Normalized inputs (the same shape as inputs).
|
641
|
+
"""
|
642
|
+
mean, var = _compute_stats(
|
643
|
+
x,
|
644
|
+
self.reduction_axes,
|
645
|
+
dtype=self.dtype,
|
646
|
+
axis_name=self.axis_name,
|
647
|
+
axis_index_groups=self.axis_index_groups,
|
648
|
+
use_fast_variance=self.use_fast_variance,
|
649
|
+
mask=mask,
|
650
|
+
)
|
651
|
+
|
652
|
+
return _normalize(
|
653
|
+
x,
|
654
|
+
mean=mean,
|
655
|
+
var=var,
|
656
|
+
weights=self.weight,
|
657
|
+
reduction_axes=self.reduction_axes,
|
658
|
+
feature_axes=self.feature_axes,
|
659
|
+
dtype=self.dtype,
|
660
|
+
epsilon=self.epsilon,
|
661
|
+
)
|
662
|
+
|
663
|
+
|
664
|
+
class RMSNorm(Module):
|
665
|
+
"""
|
666
|
+
RMS Layer normalization (https://arxiv.org/abs/1910.07467).
|
667
|
+
|
668
|
+
RMSNorm normalizes the activations of the layer for each given example in a
|
669
|
+
batch independently, rather than across a batch like Batch Normalization.
|
670
|
+
Unlike LayerNorm which re-centers the mean to be 0 and normalizes by the
|
671
|
+
standard deviation of the activations, RMSNorm does not re-center at all
|
672
|
+
and instead normalizes by the root mean square of the activations.
|
673
|
+
|
674
|
+
Example usage::
|
675
|
+
|
676
|
+
>>> import brainstate as brainstate
|
677
|
+
>>> x = brainstate.random.normal(size=(5, 6))
|
678
|
+
>>> layer = brainstate.nn.RMSNorm(num_features=6)
|
679
|
+
>>> layer.states()
|
680
|
+
>>> y = layer(x)
|
681
|
+
|
682
|
+
Attributes:
|
683
|
+
in_size: The input shape, without batch size.
|
684
|
+
epsilon: A small float added to variance to avoid dividing by zero.
|
685
|
+
dtype: the dtype of the result (default: infer from input and params).
|
686
|
+
use_scale: If True, multiply by scale (gamma). When the next layer is linear
|
687
|
+
(also e.g. nn.relu), this can be disabled since the scaling will be done
|
688
|
+
by the next layer.
|
689
|
+
scale_init: Initializer for scale, by default, one.
|
690
|
+
reduction_axes: Axes for computing normalization statistics. It is recommended
|
691
|
+
to use the negative integer, since when the batch dimension is used,
|
692
|
+
the reduction_axes may be wrong when using the positive integer.
|
693
|
+
feature_axes: Feature axes for learned bias and scaling.
|
694
|
+
axis_name: the axis name used to combine batch statistics from multiple
|
695
|
+
devices. See ``jax.pmap`` for a description of axis names (default: None).
|
696
|
+
This is only needed if the model is subdivided across devices, i.e. the
|
697
|
+
array being normalized is sharded across devices within a pmap.
|
698
|
+
axis_index_groups: groups of axis indices within that named axis
|
699
|
+
representing subsets of devices to reduce over (default: None). For
|
700
|
+
example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
|
701
|
+
the examples on the first two and last two devices. See ``jax.lax.psum``
|
702
|
+
for more details.
|
703
|
+
use_fast_variance: If true, use a faster, but less numerically stable,
|
704
|
+
calculation for the variance.
|
705
|
+
"""
|
706
|
+
|
707
|
+
def __init__(
|
708
|
+
self,
|
709
|
+
in_size: Size,
|
710
|
+
*,
|
711
|
+
epsilon: float = 1e-6,
|
712
|
+
dtype: Optional[jax.typing.DTypeLike] = None,
|
713
|
+
use_scale: bool = True,
|
714
|
+
scale_init: Callable = init.Constant(1.0),
|
715
|
+
reduction_axes: Axes = -1,
|
716
|
+
feature_axes: Axes = -1,
|
717
|
+
axis_name: Optional[str] = None,
|
718
|
+
axis_index_groups: Any = None,
|
719
|
+
use_fast_variance: bool = True,
|
720
|
+
param_type: type = NormalizationParamState,
|
721
|
+
):
|
722
|
+
super().__init__()
|
723
|
+
|
724
|
+
self.in_size = in_size
|
725
|
+
self.out_size = in_size
|
726
|
+
|
727
|
+
# parameters about axis
|
728
|
+
feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
|
729
|
+
self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
|
730
|
+
self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
|
731
|
+
self.axis_name = axis_name
|
732
|
+
self.axis_index_groups = axis_index_groups
|
733
|
+
|
734
|
+
# variables
|
735
|
+
feature_shape = tuple([(ax if i in self.feature_axes else 1)
|
736
|
+
for i, ax in enumerate(self.in_size)])
|
737
|
+
if use_scale:
|
738
|
+
self.scale = param_type({'scale': init.param(scale_init, feature_shape)})
|
739
|
+
else:
|
740
|
+
self.scale = None
|
741
|
+
|
742
|
+
# parameters
|
743
|
+
self.epsilon = epsilon
|
744
|
+
self.dtype = dtype or environ.dftype()
|
745
|
+
self.use_scale = use_scale
|
746
|
+
self.scale_init = scale_init
|
747
|
+
self.use_fast_variance = use_fast_variance
|
748
|
+
|
749
|
+
def update(self, x, *, mask: Optional[jax.Array] = None):
|
750
|
+
"""Applies layer normalization on the input.
|
751
|
+
|
752
|
+
Args:
|
753
|
+
x: the inputs
|
754
|
+
mask: the mask
|
755
|
+
|
756
|
+
Returns:
|
757
|
+
Normalized inputs (the same shape as inputs).
|
758
|
+
"""
|
759
|
+
mean, var = _compute_stats(
|
760
|
+
x,
|
761
|
+
self.reduction_axes,
|
762
|
+
dtype=self.dtype,
|
763
|
+
axis_name=self.axis_name,
|
764
|
+
axis_index_groups=self.axis_index_groups,
|
765
|
+
use_mean=False,
|
766
|
+
use_fast_variance=self.use_fast_variance,
|
767
|
+
mask=mask,
|
768
|
+
)
|
769
|
+
|
770
|
+
return _normalize(
|
771
|
+
x,
|
772
|
+
mean=mean,
|
773
|
+
var=var,
|
774
|
+
weights=self.scale,
|
775
|
+
reduction_axes=self.reduction_axes,
|
776
|
+
feature_axes=self.feature_axes,
|
777
|
+
dtype=self.dtype,
|
778
|
+
epsilon=self.epsilon,
|
779
|
+
)
|
780
|
+
|
781
|
+
|
782
|
+
class GroupNorm(Module):
|
783
|
+
"""
|
784
|
+
Group normalization (arxiv.org/abs/1803.08494).
|
785
|
+
|
786
|
+
This op is similar to batch normalization, but statistics are shared across
|
787
|
+
equally-sized groups of channels and not shared across batch dimension.
|
788
|
+
Thus, group normalization does not depend on the batch composition and does
|
789
|
+
not require maintaining internal state for storing statistics.
|
790
|
+
The user should either specify the total number of channel groups or the
|
791
|
+
number of channels per group.
|
792
|
+
|
793
|
+
.. note::
|
794
|
+
LayerNorm is a special case of GroupNorm where ``num_groups=1``.
|
795
|
+
|
796
|
+
Example usage::
|
797
|
+
|
798
|
+
>>> import numpy as np
|
799
|
+
>>> import brainstate as brainstate
|
800
|
+
...
|
801
|
+
>>> x = brainstate.random.normal(size=(3, 4, 5, 6))
|
802
|
+
>>> layer = brainstate.nn.GroupNorm(x.shape, num_groups=3)
|
803
|
+
>>> layer.states()
|
804
|
+
>>> y = layer(x)
|
805
|
+
>>> y = brainstate.nn.GroupNorm(x.shape, num_groups=1)(x)
|
806
|
+
>>> y2 = brainstate.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
|
807
|
+
>>> np.testing.assert_allclose(y, y2)
|
808
|
+
|
809
|
+
Attributes:
|
810
|
+
in_size: The input shape, without batch size.
|
811
|
+
num_groups: the total number of channel groups. The default value of 32 is
|
812
|
+
proposed by the original group normalization paper.
|
813
|
+
group_size: the number of channels in a group.
|
814
|
+
epsilon: A small float added to variance to avoid dividing by zero.
|
815
|
+
dtype: the dtype of the result (default: infer from input and params).
|
816
|
+
use_bias: If True, bias (beta) is added.
|
817
|
+
use_scale: If True, multiply by scale (gamma). When the next layer is linear
|
818
|
+
(also e.g. nn.relu), this can be disabled since the scaling will be done
|
819
|
+
by the next layer.
|
820
|
+
bias_init: Initializer for bias, by default, zero.
|
821
|
+
scale_init: Initializer for scale, by default, one.
|
822
|
+
reduction_axes: List of axes used for computing normalization statistics.
|
823
|
+
This list must include the final dimension, which is assumed to be the
|
824
|
+
feature axis. Furthermore, if the input used at call time has additional
|
825
|
+
leading axes compared to the data used for initialisation, for example due
|
826
|
+
to batching, then the reduction axes need to be defined explicitly.
|
827
|
+
It is recommended to use the negative integer, since when the batch dimension is used,
|
828
|
+
the reduction_axes may be wrong when using the positive integer.
|
829
|
+
axis_name: the axis name used to combine batch statistics from multiple
|
830
|
+
devices. See ``jax.pmap`` for a description of axis names (default: None).
|
831
|
+
This is only needed if the model is subdivided across devices, i.e. the
|
832
|
+
array being normalized is sharded across devices within a pmap or shard
|
833
|
+
map. For SPMD jit, you do not need to manually synchronize. Just make sure
|
834
|
+
that the axes are correctly annotated and XLA:SPMD will insert the
|
835
|
+
necessary collectives.
|
836
|
+
axis_index_groups: groups of axis indices within that named axis
|
837
|
+
representing subsets of devices to reduce over (default: None). For
|
838
|
+
example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
|
839
|
+
examples on the first two and last two devices. See ``jax.lax.psum`` for
|
840
|
+
more details.
|
841
|
+
use_fast_variance: If true, use a faster, but less numerically stable,
|
842
|
+
calculation for the variance.
|
843
|
+
"""
|
844
|
+
|
845
|
+
def __init__(
|
846
|
+
self,
|
847
|
+
in_size: Size,
|
848
|
+
feature_axis: Axes = -1,
|
849
|
+
num_groups: Optional[int] = 32,
|
850
|
+
group_size: Optional[int] = None,
|
851
|
+
*,
|
852
|
+
epsilon: float = 1e-6,
|
853
|
+
dtype: Optional[jax.typing.DTypeLike] = None,
|
854
|
+
use_bias: bool = True,
|
855
|
+
use_scale: bool = True,
|
856
|
+
bias_init: Callable = init.ZeroInit(),
|
857
|
+
scale_init: Callable = init.Constant(1.),
|
858
|
+
reduction_axes: Optional[Axes] = None,
|
859
|
+
axis_name: Optional[str] = None,
|
860
|
+
axis_index_groups: Any = None,
|
861
|
+
use_fast_variance: bool = True,
|
862
|
+
param_type: type = NormalizationParamState,
|
863
|
+
):
|
864
|
+
super().__init__()
|
865
|
+
|
866
|
+
self.in_size = in_size
|
867
|
+
self.out_size = in_size
|
868
|
+
|
869
|
+
# parameters about axis
|
870
|
+
feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
|
871
|
+
self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
|
872
|
+
self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
|
873
|
+
self.axis_name = axis_name
|
874
|
+
self.axis_index_groups = axis_index_groups
|
875
|
+
|
876
|
+
if (num_groups is None and group_size is None) or (
|
877
|
+
num_groups is not None and group_size is not None
|
878
|
+
):
|
879
|
+
raise ValueError(
|
880
|
+
'Either `num_groups` or `group_size` should be '
|
881
|
+
'specified. If `group_size` is to be specified, '
|
882
|
+
'pass `num_groups=None` as argument to override '
|
883
|
+
'the default `num_groups` value of 32.'
|
884
|
+
)
|
885
|
+
|
886
|
+
feature_shape = tuple([(ax if i in self.feature_axes else 1)
|
887
|
+
for i, ax in enumerate(self.in_size)])
|
888
|
+
assert len(feature_shape) == 1, 'GroupNorm only supports 1D feature axis.'
|
889
|
+
num_features = feature_shape[0]
|
890
|
+
if group_size is not None:
|
891
|
+
if num_features % group_size != 0:
|
892
|
+
raise ValueError(
|
893
|
+
'Number of features ({}) is not multiple of the '
|
894
|
+
'group size ({}).'.format(num_features, group_size)
|
895
|
+
)
|
896
|
+
self.num_groups = num_features // group_size
|
897
|
+
self.group_size = group_size
|
898
|
+
else:
|
899
|
+
if not isinstance(num_groups, int) or num_groups <= 0 or (
|
900
|
+
num_features % num_groups != 0
|
901
|
+
):
|
902
|
+
raise ValueError(
|
903
|
+
'Number of groups ({}) does not divide the number'
|
904
|
+
' of channels ({}).'.format(num_groups, num_features)
|
905
|
+
)
|
906
|
+
self.num_groups = num_groups
|
907
|
+
self.group_size = num_features // num_groups
|
908
|
+
|
909
|
+
# variables
|
910
|
+
weights = dict()
|
911
|
+
if use_scale:
|
912
|
+
weights['scale'] = init.param(scale_init, feature_shape)
|
913
|
+
if use_bias:
|
914
|
+
weights['bias'] = init.param(bias_init, feature_shape)
|
915
|
+
if len(weights):
|
916
|
+
self.weight = param_type(weights)
|
917
|
+
else:
|
918
|
+
self.weight = None
|
919
|
+
|
920
|
+
# parameters
|
921
|
+
self.epsilon = epsilon
|
922
|
+
self.dtype = dtype
|
923
|
+
self.use_bias = use_bias
|
924
|
+
self.use_scale = use_scale
|
925
|
+
self.bias_init = bias_init
|
926
|
+
self.scale_init = scale_init
|
927
|
+
self.use_fast_variance = use_fast_variance
|
928
|
+
|
929
|
+
def update(self, x, *, mask: Optional[jax.Array] = None):
|
930
|
+
"""Applies group normalization to the input (arxiv.org/abs/1803.08494).
|
931
|
+
|
932
|
+
Args:
|
933
|
+
x: the input of shape ``...self.num_features`` where ``self.num_features``
|
934
|
+
is a channels dimension and ``...`` represents an arbitrary number of
|
935
|
+
extra dimensions that can be used to accumulate statistics over. If no
|
936
|
+
reduction axes have been specified then all additional dimensions ``...``
|
937
|
+
will be used to accumulate statistics apart from the leading dimension
|
938
|
+
which is assumed to represent the batch.
|
939
|
+
mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating
|
940
|
+
the positions for which the mean and variance should be computed.
|
941
|
+
|
942
|
+
Returns:
|
943
|
+
Normalized inputs (the same shape as inputs).
|
944
|
+
"""
|
945
|
+
if self.reduction_axes is not None:
|
946
|
+
reduction_axes = self.reduction_axes
|
947
|
+
else:
|
948
|
+
reduction_axes = list(range(1, x.ndim - 1)) + [-1]
|
949
|
+
reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
|
950
|
+
|
951
|
+
group_shape = x.shape[:-1] + (self.num_groups, self.group_size)
|
952
|
+
if mask is not None:
|
953
|
+
mask = mask.reshape(mask.shape[:-1] + (self.num_groups, self.group_size))
|
954
|
+
|
955
|
+
mean, var = _compute_stats(
|
956
|
+
x.reshape(group_shape),
|
957
|
+
list(reduction_axes[:-1]) + [-1],
|
958
|
+
dtype=self.dtype,
|
959
|
+
axis_name=self.axis_name,
|
960
|
+
axis_index_groups=self.axis_index_groups,
|
961
|
+
use_fast_variance=self.use_fast_variance,
|
962
|
+
mask=mask,
|
963
|
+
)
|
964
|
+
mean = jnp.repeat(mean, self.group_size, axis=1)
|
965
|
+
var = jnp.repeat(var, self.group_size, axis=1)
|
966
|
+
return _normalize(
|
967
|
+
x,
|
968
|
+
mean=mean,
|
969
|
+
var=var,
|
970
|
+
weights=self.weight,
|
971
|
+
reduction_axes=reduction_axes[:-1],
|
972
|
+
feature_axes=self.feature_axes,
|
973
|
+
dtype=self.dtype,
|
974
|
+
epsilon=self.epsilon,
|
975
|
+
)
|