brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/transform/_mapping.py
CHANGED
@@ -1,529 +1,607 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import functools
|
17
|
-
from typing import (
|
18
|
-
Any,
|
19
|
-
TypeVar,
|
20
|
-
Callable,
|
21
|
-
Hashable,
|
22
|
-
Sequence,
|
23
|
-
Iterable,
|
24
|
-
Tuple,
|
25
|
-
Union,
|
26
|
-
Optional,
|
27
|
-
Dict
|
28
|
-
)
|
29
|
-
|
30
|
-
import jax
|
31
|
-
|
32
|
-
from brainstate._compatible_import import Device
|
33
|
-
from brainstate._state import catch_new_states
|
34
|
-
from brainstate._utils import set_module_as
|
35
|
-
from brainstate.typing import Missing, Filter
|
36
|
-
from brainstate.util import NestedDict
|
37
|
-
from ._loop_collect_return import scan
|
38
|
-
from ._make_jaxpr import StatefulMapping
|
39
|
-
|
40
|
-
__all__ = [
|
41
|
-
'vmap',
|
42
|
-
'pmap',
|
43
|
-
'map',
|
44
|
-
'vmap_new_states',
|
45
|
-
]
|
46
|
-
|
47
|
-
F = TypeVar("F", bound=Callable)
|
48
|
-
AxisName = Hashable
|
49
|
-
|
50
|
-
|
51
|
-
@set_module_as('brainstate.transform')
|
52
|
-
def vmap(
|
53
|
-
fn: F | Missing = Missing(),
|
54
|
-
*,
|
55
|
-
# --- normal jax.vmap arguments --- #
|
56
|
-
in_axes: int | None | Sequence[Any] = 0,
|
57
|
-
out_axes: Any = 0,
|
58
|
-
axis_name: AxisName | None = None,
|
59
|
-
axis_size: int | None = None,
|
60
|
-
spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
|
61
|
-
# --- brainstate specific arguments --- #
|
62
|
-
state_in_axes: Union[Dict[AxisName, Filter], Filter] = None,
|
63
|
-
state_out_axes: Union[Dict[AxisName, Filter], Filter] = None,
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
)
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
donate_argnums=donate_argnums,
|
300
|
-
global_arg_shapes=global_arg_shapes,
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
)
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
)
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
)
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import functools
|
17
|
+
from typing import (
|
18
|
+
Any,
|
19
|
+
TypeVar,
|
20
|
+
Callable,
|
21
|
+
Hashable,
|
22
|
+
Sequence,
|
23
|
+
Iterable,
|
24
|
+
Tuple,
|
25
|
+
Union,
|
26
|
+
Optional,
|
27
|
+
Dict
|
28
|
+
)
|
29
|
+
|
30
|
+
import jax
|
31
|
+
|
32
|
+
from brainstate._compatible_import import Device
|
33
|
+
from brainstate._state import catch_new_states
|
34
|
+
from brainstate._utils import set_module_as
|
35
|
+
from brainstate.typing import Missing, Filter
|
36
|
+
from brainstate.util import NestedDict
|
37
|
+
from ._loop_collect_return import scan
|
38
|
+
from ._make_jaxpr import StatefulMapping
|
39
|
+
|
40
|
+
__all__ = [
|
41
|
+
'vmap',
|
42
|
+
'pmap',
|
43
|
+
'map',
|
44
|
+
'vmap_new_states',
|
45
|
+
]
|
46
|
+
|
47
|
+
F = TypeVar("F", bound=Callable)
|
48
|
+
AxisName = Hashable
|
49
|
+
|
50
|
+
|
51
|
+
@set_module_as('brainstate.transform')
|
52
|
+
def vmap(
|
53
|
+
fn: F | Missing = Missing(),
|
54
|
+
*,
|
55
|
+
# --- normal jax.vmap arguments --- #
|
56
|
+
in_axes: int | None | Sequence[Any] = 0,
|
57
|
+
out_axes: Any = 0,
|
58
|
+
axis_name: AxisName | None = None,
|
59
|
+
axis_size: int | None = None,
|
60
|
+
spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
|
61
|
+
# --- brainstate specific arguments --- #
|
62
|
+
state_in_axes: Union[Dict[AxisName, Filter], Filter] = None,
|
63
|
+
state_out_axes: Union[Dict[AxisName, Filter], Filter] = None,
|
64
|
+
unexpected_out_state_mapping: str = 'raise',
|
65
|
+
) -> StatefulMapping | Callable[[F], StatefulMapping]:
|
66
|
+
"""
|
67
|
+
Vectorize a callable while preserving BrainState state semantics.
|
68
|
+
|
69
|
+
This helper mirrors :func:`jax.vmap` but routes execution through
|
70
|
+
:class:`~brainstate.transform.StatefulMapping` so that reads and writes to
|
71
|
+
:class:`~brainstate.State` instances (including newly created random states)
|
72
|
+
are tracked correctly across the mapped axis. The returned object can be used
|
73
|
+
directly or as a decorator when ``fn`` is omitted.
|
74
|
+
|
75
|
+
Parameters
|
76
|
+
----------
|
77
|
+
fn : callable, optional
|
78
|
+
Function to be vectorised. If omitted, the function acts as a decorator.
|
79
|
+
in_axes : int | None | sequence, default 0
|
80
|
+
Mapping specification for positional arguments, following the semantics
|
81
|
+
of :func:`jax.vmap`.
|
82
|
+
out_axes : any, default 0
|
83
|
+
Placement of the mapped axis in the result. Must broadcast with the
|
84
|
+
structure of the outputs.
|
85
|
+
axis_name : hashable, optional
|
86
|
+
Name for the mapped axis so that collective primitives (e.g. ``lax.psum``)
|
87
|
+
can target it.
|
88
|
+
axis_size : int, optional
|
89
|
+
Explicit size of the mapped axis. If omitted, the size is inferred from
|
90
|
+
the arguments.
|
91
|
+
spmd_axis_name : hashable or tuple[hashable], optional
|
92
|
+
Axis labels used when the transformed function is itself executed inside
|
93
|
+
another SPMD transform (e.g. nested :func:`vmap` or :func:`pmap`).
|
94
|
+
state_in_axes : dict[AxisName, Filter] or Filter, optional
|
95
|
+
Filters identifying which :class:`State` objects should be batched on
|
96
|
+
input. Passing a single filter is shorthand for ``{0: filter}``. Filters
|
97
|
+
are converted with :func:`brainstate.util.filter.to_predicate`.
|
98
|
+
state_out_axes : dict[AxisName, Filter] or Filter, optional
|
99
|
+
Filters describing how written states are scattered back across the
|
100
|
+
mapped axis. Semantics mirror ``state_in_axes``.
|
101
|
+
unexpected_out_state_mapping : {'raise', 'warn', 'ignore'}, default 'raise'
|
102
|
+
Policy when a state is written during the mapped call but not matched by
|
103
|
+
``state_out_axes``. ``'raise'`` propagates a :class:`BatchAxisError`,
|
104
|
+
``'warn'`` emits a warning, and ``'ignore'`` silently accepts the state.
|
105
|
+
|
106
|
+
Returns
|
107
|
+
-------
|
108
|
+
StatefulMapping or callable
|
109
|
+
If ``fn`` is supplied, returns a :class:`StatefulMapping` instance that
|
110
|
+
behaves like ``fn`` but with batch semantics. Otherwise a decorator is
|
111
|
+
returned.
|
112
|
+
|
113
|
+
Raises
|
114
|
+
------
|
115
|
+
ValueError
|
116
|
+
If axis sizes are inconsistent or cannot be inferred.
|
117
|
+
BatchAxisError
|
118
|
+
If a state write violates ``state_out_axes`` and the policy is ``'raise'``.
|
119
|
+
|
120
|
+
Examples
|
121
|
+
--------
|
122
|
+
.. code-block:: python
|
123
|
+
|
124
|
+
>>> import brainstate as bst
|
125
|
+
>>> import jax.numpy as jnp
|
126
|
+
>>> from brainstate.util.filter import OfType
|
127
|
+
>>>
|
128
|
+
>>> counter = bst.ShortTermState(jnp.array(0.0))
|
129
|
+
>>>
|
130
|
+
>>> @bst.transform.vmap(
|
131
|
+
... in_axes=0,
|
132
|
+
... out_axes=0,
|
133
|
+
... state_in_axes={0: OfType(bst.ShortTermState)},
|
134
|
+
... state_out_axes={0: OfType(bst.ShortTermState)},
|
135
|
+
... )
|
136
|
+
... def accumulate(x):
|
137
|
+
... counter.value = counter.value + x
|
138
|
+
... return counter.value
|
139
|
+
>>>
|
140
|
+
>>> xs = jnp.arange(3.0)
|
141
|
+
>>> accumulate(xs)
|
142
|
+
Array([0., 1., 3.], dtype=float32)
|
143
|
+
>>> counter.value
|
144
|
+
Array(3., dtype=float32)
|
145
|
+
|
146
|
+
See Also
|
147
|
+
--------
|
148
|
+
brainstate.transform.StatefulMapping : Underlying state-aware mapping helper.
|
149
|
+
pmap : Parallel mapping variant for multiple devices.
|
150
|
+
vmap_new_states : Vectorize newly created states within ``fn``.
|
151
|
+
"""
|
152
|
+
|
153
|
+
if isinstance(fn, Missing):
|
154
|
+
return functools.partial(
|
155
|
+
vmap,
|
156
|
+
in_axes=in_axes,
|
157
|
+
out_axes=out_axes,
|
158
|
+
state_in_axes=state_in_axes,
|
159
|
+
state_out_axes=state_out_axes,
|
160
|
+
axis_name=axis_name,
|
161
|
+
axis_size=axis_size,
|
162
|
+
spmd_axis_name=spmd_axis_name,
|
163
|
+
unexpected_out_state_mapping=unexpected_out_state_mapping,
|
164
|
+
) # type: ignore[return-value]
|
165
|
+
|
166
|
+
return StatefulMapping(
|
167
|
+
fn,
|
168
|
+
in_axes=in_axes,
|
169
|
+
out_axes=out_axes,
|
170
|
+
state_in_axes=state_in_axes,
|
171
|
+
state_out_axes=state_out_axes,
|
172
|
+
axis_name=axis_name,
|
173
|
+
axis_size=axis_size,
|
174
|
+
unexpected_out_state_mapping=unexpected_out_state_mapping,
|
175
|
+
mapping_fn=functools.partial(jax.vmap, spmd_axis_name=spmd_axis_name),
|
176
|
+
name='vmap'
|
177
|
+
)
|
178
|
+
|
179
|
+
|
180
|
+
@set_module_as('brainstate.transform')
|
181
|
+
def pmap(
|
182
|
+
fn: Callable[[NestedDict, ...], Any] | Missing = Missing(),
|
183
|
+
axis_name: Optional[AxisName] = None,
|
184
|
+
*,
|
185
|
+
in_axes: Any = 0,
|
186
|
+
out_axes: Any = 0,
|
187
|
+
static_broadcasted_argnums: int | Iterable[int] = (),
|
188
|
+
devices: Optional[Sequence[Device]] = None, # noqa: F811
|
189
|
+
backend: Optional[str] = None,
|
190
|
+
axis_size: Optional[int] = None,
|
191
|
+
donate_argnums: int | Iterable[int] = (),
|
192
|
+
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
|
193
|
+
# --- brainstate specific arguments --- #
|
194
|
+
state_in_axes: Union[Dict[AxisName, Filter], Filter] = None,
|
195
|
+
state_out_axes: Union[Dict[AxisName, Filter], Filter] = None,
|
196
|
+
unexpected_out_state_mapping: str = 'raise',
|
197
|
+
) -> Callable[[F], F] | F:
|
198
|
+
"""
|
199
|
+
Parallel mapping with state-aware semantics across devices.
|
200
|
+
|
201
|
+
This function mirrors :func:`jax.pmap` but integrates with
|
202
|
+
:class:`~brainstate.transform.StatefulMapping` so that
|
203
|
+
:class:`~brainstate.State` objects (including random states) are replicated
|
204
|
+
and restored correctly on every device. When ``fn`` is omitted the function
|
205
|
+
can be used as a decorator.
|
206
|
+
|
207
|
+
Parameters
|
208
|
+
----------
|
209
|
+
fn : callable, optional
|
210
|
+
Function to execute in SPMD style. If omitted, a decorator is returned.
|
211
|
+
axis_name : hashable, optional
|
212
|
+
Name for the mapped axis used by collective primitives.
|
213
|
+
in_axes : any, default 0
|
214
|
+
Axis mapping for positional arguments, identical to :func:`jax.pmap`.
|
215
|
+
out_axes : any, default 0
|
216
|
+
Placement of the mapped axis in the outputs.
|
217
|
+
static_broadcasted_argnums : int or iterable[int], default ()
|
218
|
+
Indices of positional arguments to treat as compile-time constants.
|
219
|
+
devices : sequence[Device], optional
|
220
|
+
Explicit device list to map over. Must be identical on every host in
|
221
|
+
multi-host setups.
|
222
|
+
backend : str, optional
|
223
|
+
Backend identifier (``'cpu'``, ``'gpu'``, or ``'tpu'``).
|
224
|
+
axis_size : int, optional
|
225
|
+
Size of the mapped axis. Defaults to ``len(devices)`` or the local device
|
226
|
+
count when ``devices`` is ``None``.
|
227
|
+
donate_argnums : int or iterable[int], default ()
|
228
|
+
Positional arguments whose buffers may be donated to the computation.
|
229
|
+
global_arg_shapes : tuple[tuple[int, ...], ...], optional
|
230
|
+
Shapes for globally distributed arguments (i.e. arguments not replicated
|
231
|
+
across devices).
|
232
|
+
state_in_axes : dict[AxisName, Filter] or Filter, optional
|
233
|
+
Filters indicating which states should be treated as device-mapped inputs.
|
234
|
+
state_out_axes : dict[AxisName, Filter] or Filter, optional
|
235
|
+
Filters describing how state writes are scattered back to devices.
|
236
|
+
unexpected_out_state_mapping : {'raise', 'warn', 'ignore'}, default 'raise'
|
237
|
+
Policy applied when a state write is not covered by ``state_out_axes``.
|
238
|
+
rngs : Any, optional
|
239
|
+
Optional RNG seeds passed through to ``fn``. They are restored to their
|
240
|
+
original values after execution.
|
241
|
+
|
242
|
+
Returns
|
243
|
+
-------
|
244
|
+
StatefulMapping or callable
|
245
|
+
If ``fn`` is provided, returns a :class:`StatefulMapping` executing ``fn``
|
246
|
+
over devices. Otherwise returns a decorator that produces such an object.
|
247
|
+
|
248
|
+
Raises
|
249
|
+
------
|
250
|
+
ValueError
|
251
|
+
If ``axis_size`` or argument shapes are inconsistent.
|
252
|
+
BatchAxisError
|
253
|
+
If an unexpected state write occurs and the policy is ``'raise'``.
|
254
|
+
|
255
|
+
Examples
|
256
|
+
--------
|
257
|
+
.. code-block:: python
|
258
|
+
|
259
|
+
>>> import brainstate as bst
|
260
|
+
>>> import jax.numpy as jnp
|
261
|
+
>>> from brainstate.util.filter import OfType
|
262
|
+
>>>
|
263
|
+
>>> weights = bst.ParamState(jnp.ones((4,)))
|
264
|
+
>>>
|
265
|
+
>>> @bst.transform.pmap(
|
266
|
+
... axis_name='devices',
|
267
|
+
... in_axes=0,
|
268
|
+
... out_axes=0,
|
269
|
+
... state_in_axes={0: OfType(bst.ParamState)},
|
270
|
+
... state_out_axes={0: OfType(bst.ParamState)},
|
271
|
+
... )
|
272
|
+
... def update(delta):
|
273
|
+
... weights.value = weights.value + delta
|
274
|
+
... return weights.value
|
275
|
+
>>>
|
276
|
+
>>> deltas = jnp.arange(jax.local_device_count() * 4.).reshape(
|
277
|
+
... jax.local_device_count(), 4
|
278
|
+
... )
|
279
|
+
>>> updated = update(deltas)
|
280
|
+
>>> updated.shape
|
281
|
+
(jax.local_device_count(), 4)
|
282
|
+
|
283
|
+
See Also
|
284
|
+
--------
|
285
|
+
jax.pmap : Underlying JAX primitive.
|
286
|
+
vmap : Single-host vectorisation with the same state semantics.
|
287
|
+
"""
|
288
|
+
|
289
|
+
if isinstance(fn, Missing):
|
290
|
+
return functools.partial(
|
291
|
+
pmap,
|
292
|
+
axis_name=axis_name,
|
293
|
+
in_axes=in_axes,
|
294
|
+
out_axes=out_axes,
|
295
|
+
static_broadcasted_argnums=static_broadcasted_argnums,
|
296
|
+
devices=devices,
|
297
|
+
backend=backend,
|
298
|
+
axis_size=axis_size,
|
299
|
+
donate_argnums=donate_argnums,
|
300
|
+
global_arg_shapes=global_arg_shapes,
|
301
|
+
unexpected_out_state_mapping=unexpected_out_state_mapping,
|
302
|
+
) # type: ignore[return-value]
|
303
|
+
|
304
|
+
return StatefulMapping(
|
305
|
+
fn,
|
306
|
+
in_axes=in_axes,
|
307
|
+
out_axes=out_axes,
|
308
|
+
state_in_axes=state_in_axes,
|
309
|
+
state_out_axes=state_out_axes,
|
310
|
+
axis_name=axis_name,
|
311
|
+
axis_size=axis_size,
|
312
|
+
mapping_fn=functools.partial(
|
313
|
+
jax.pmap,
|
314
|
+
static_broadcasted_argnums=static_broadcasted_argnums,
|
315
|
+
devices=devices,
|
316
|
+
backend=backend,
|
317
|
+
donate_argnums=donate_argnums,
|
318
|
+
global_arg_shapes=global_arg_shapes,
|
319
|
+
),
|
320
|
+
unexpected_out_state_mapping=unexpected_out_state_mapping,
|
321
|
+
name='pmap'
|
322
|
+
)
|
323
|
+
|
324
|
+
|
325
|
+
def _batch_and_remainder(x, batch_size: int):
|
326
|
+
leaves, tree_def = jax.tree.flatten(x)
|
327
|
+
|
328
|
+
scan_leaves = []
|
329
|
+
remainder_leaves = []
|
330
|
+
|
331
|
+
length = None
|
332
|
+
for leaf in leaves:
|
333
|
+
if length is None:
|
334
|
+
length = leaf.shape[0]
|
335
|
+
if length != leaf.shape[0]:
|
336
|
+
raise ValueError(f"All inputs must have the same length. Got {length} and {leaf.shape[0]}.")
|
337
|
+
|
338
|
+
num_batches, num_remainder = divmod(length, batch_size)
|
339
|
+
for leaf in leaves:
|
340
|
+
total_batch_elems = num_batches * batch_size
|
341
|
+
scan_leaves.append(leaf[:total_batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:]))
|
342
|
+
if num_remainder:
|
343
|
+
remainder_leaves.append(leaf[total_batch_elems:])
|
344
|
+
|
345
|
+
scan_tree = tree_def.unflatten(scan_leaves)
|
346
|
+
if num_remainder:
|
347
|
+
remainder_tree = tree_def.unflatten(remainder_leaves)
|
348
|
+
return scan_tree, remainder_tree
|
349
|
+
else:
|
350
|
+
return scan_tree, None
|
351
|
+
|
352
|
+
|
353
|
+
@set_module_as('brainstate.transform')
|
354
|
+
def map(
|
355
|
+
f,
|
356
|
+
*xs,
|
357
|
+
batch_size: int | None = None,
|
358
|
+
):
|
359
|
+
"""
|
360
|
+
Apply a Python function over the leading axis of one or more pytrees.
|
361
|
+
|
362
|
+
Compared with :func:`jax.vmap`, this helper executes sequentially by default
|
363
|
+
(via :func:`jax.lax.scan`), making it useful when auto-vectorisation is
|
364
|
+
impractical or when memory usage must be reduced. Providing ``batch_size``
|
365
|
+
enables chunked evaluation that internally leverages :func:`vmap` to improve
|
366
|
+
throughput while keeping peak memory bounded.
|
367
|
+
|
368
|
+
Parameters
|
369
|
+
----------
|
370
|
+
f : callable
|
371
|
+
Function applied element-wise across the leading dimension. Its return
|
372
|
+
value must be a pytree whose leaves can be stacked along axis ``0``.
|
373
|
+
*xs : Any
|
374
|
+
Positional pytrees sharing the same length along their leading axis.
|
375
|
+
batch_size : int, optional
|
376
|
+
Size of vectorised blocks. When given, ``map`` first processes full
|
377
|
+
batches using :func:`vmap` then handles any remainder sequentially.
|
378
|
+
|
379
|
+
Returns
|
380
|
+
-------
|
381
|
+
Any
|
382
|
+
PyTree matching the structure of ``f``'s outputs with results stacked
|
383
|
+
along the leading dimension.
|
384
|
+
|
385
|
+
Raises
|
386
|
+
------
|
387
|
+
ValueError
|
388
|
+
If the inputs do not share the same leading length.
|
389
|
+
|
390
|
+
Examples
|
391
|
+
--------
|
392
|
+
.. code-block:: python
|
393
|
+
|
394
|
+
>>> import jax.numpy as jnp
|
395
|
+
>>> from brainstate.transform import map
|
396
|
+
>>>
|
397
|
+
>>> xs = jnp.arange(6).reshape(6, 1)
|
398
|
+
>>>
|
399
|
+
>>> def normalize(row):
|
400
|
+
... return row / (1.0 + jnp.linalg.norm(row))
|
401
|
+
>>>
|
402
|
+
>>> stacked = map(normalize, xs, batch_size=2)
|
403
|
+
>>> stacked.shape
|
404
|
+
(6, 1)
|
405
|
+
|
406
|
+
See Also
|
407
|
+
--------
|
408
|
+
vmap : Vectorised mapping with automatic batching.
|
409
|
+
jax.lax.scan : Primitive used for the sequential fallback.
|
410
|
+
"""
|
411
|
+
if batch_size is not None:
|
412
|
+
scan_xs, remainder_xs = _batch_and_remainder(xs, batch_size)
|
413
|
+
g = lambda _, x: ((), vmap(f)(*x))
|
414
|
+
_, scan_ys = scan(g, (), scan_xs)
|
415
|
+
if remainder_xs is None:
|
416
|
+
ys = jax.tree.map(lambda x: _flatten(x), scan_ys)
|
417
|
+
else:
|
418
|
+
remainder_ys = vmap(f)(*remainder_xs)
|
419
|
+
ys = jax.tree.map(
|
420
|
+
lambda x, y: jax.lax.concatenate([_flatten(x), y], dimension=0),
|
421
|
+
scan_ys,
|
422
|
+
remainder_ys,
|
423
|
+
)
|
424
|
+
else:
|
425
|
+
g = lambda _, x: ((), f(*x))
|
426
|
+
_, ys = scan(g, (), xs)
|
427
|
+
return ys
|
428
|
+
|
429
|
+
|
430
|
+
def _flatten(x):
|
431
|
+
return x.reshape(-1, *x.shape[2:])
|
432
|
+
|
433
|
+
|
434
|
+
def _vmap_new_states_transform(
|
435
|
+
fun: Callable[..., Any],
|
436
|
+
*,
|
437
|
+
# -- normal jax.vmap arguments -- #
|
438
|
+
in_axes: int | None | Sequence[Any] = 0,
|
439
|
+
out_axes: Any = 0,
|
440
|
+
axis_name: AxisName | None = None,
|
441
|
+
axis_size: int | None = None,
|
442
|
+
spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
|
443
|
+
# -- brainstate specific arguments -- #
|
444
|
+
state_tag: str | None = None,
|
445
|
+
state_to_exclude: Filter | None = None,
|
446
|
+
state_in_axes: Union[Dict[AxisName, Filter], Filter] = None,
|
447
|
+
state_out_axes: Union[Dict[AxisName, Filter], Filter] = None,
|
448
|
+
unexpected_out_state_mapping: str = 'raise',
|
449
|
+
):
|
450
|
+
# TODO: How about nested call ``vmap_new_states``?
|
451
|
+
if isinstance(axis_size, int) and axis_size <= 0:
|
452
|
+
raise ValueError(f"axis_size must be greater than 0, got {axis_size}.")
|
453
|
+
|
454
|
+
@vmap(
|
455
|
+
in_axes=in_axes,
|
456
|
+
out_axes=out_axes,
|
457
|
+
axis_name=axis_name,
|
458
|
+
axis_size=axis_size,
|
459
|
+
spmd_axis_name=spmd_axis_name,
|
460
|
+
state_in_axes=state_in_axes,
|
461
|
+
state_out_axes=state_out_axes,
|
462
|
+
unexpected_out_state_mapping=unexpected_out_state_mapping,
|
463
|
+
)
|
464
|
+
def new_fun(args):
|
465
|
+
# call the function
|
466
|
+
with catch_new_states(state_tag=state_tag, state_to_exclude=state_to_exclude) as catcher:
|
467
|
+
out = fun(*args)
|
468
|
+
|
469
|
+
# get vmap state values
|
470
|
+
vmap_state_vals = catcher.get_state_values()
|
471
|
+
|
472
|
+
return out, vmap_state_vals
|
473
|
+
|
474
|
+
@functools.wraps(fun)
|
475
|
+
def vmapped_fn(*args):
|
476
|
+
# vmapping
|
477
|
+
with catch_new_states(state_to_exclude=state_to_exclude) as catcher:
|
478
|
+
outs, vmap_state_vals = new_fun(args)
|
479
|
+
vmap_states = catcher.get_states()
|
480
|
+
|
481
|
+
# restore vmapped state values
|
482
|
+
for st_val, st in zip(vmap_state_vals, vmap_states):
|
483
|
+
st.restore_value(st_val)
|
484
|
+
# ------------------------------------------------
|
485
|
+
# --- this is CRUCIAL to avoid jax tracing leakage
|
486
|
+
# ------------------------------------------------
|
487
|
+
st.decrease_stack_level()
|
488
|
+
return outs
|
489
|
+
|
490
|
+
return vmapped_fn
|
491
|
+
|
492
|
+
|
493
|
+
@set_module_as('brainstate.transform')
def vmap_new_states(
    fun: Callable = Missing(),
    *,
    # -- normal jax.vmap arguments -- #
    in_axes: int | None | Sequence[Any] = 0,
    out_axes: Any = 0,
    axis_name: AxisName | None = None,
    axis_size: int | None = None,
    spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
    # -- brainstate specific arguments -- #
    state_tag: str | None = None,
    state_to_exclude: Filter = None,
    state_in_axes: Union[Dict[AxisName, Filter], Filter] = None,
    state_out_axes: Union[Dict[AxisName, Filter], Filter] = None,
    unexpected_out_state_mapping: str = 'raise',
):
    """
    Vectorise a function that creates new BrainState states on the fly.

    The helper wraps :func:`vmap` but also captures states instantiated inside
    ``fun`` via :func:`brainstate._state.catch_new_states`. Newly created states
    are materialised for each batch element and restored after execution so that
    their side effects persist exactly once. When ``fun`` is omitted the helper
    can be used as a decorator.

    Parameters
    ----------
    fun : callable, optional
        Function to transform. If omitted, :func:`vmap_new_states` returns a
        decorator expecting ``fun``.
    in_axes : int | None | sequence, default 0
        Mapping specification for positional arguments, following
        :func:`jax.vmap` semantics.
    out_axes : any, default 0
        Placement of the mapped axis in the outputs.
    axis_name : hashable, optional
        Name of the mapped axis for collective primitives.
    axis_size : int, optional
        Explicit size of the mapped axis. Must be positive when provided.
    spmd_axis_name : hashable or tuple[hashable], optional
        Axis labels used when nesting inside other SPMD transforms.
    state_tag : str, optional
        Tag used to limit which newly created states are tracked.
    state_to_exclude : Filter, optional
        Filter describing states that should *not* participate in the mapping.
    state_in_axes : dict[AxisName, Filter] or Filter, optional
        Filters indicating which existing states are batched on input.
    state_out_axes : dict[AxisName, Filter] or Filter, optional
        Filters describing how written states are scattered over the mapped axis.
    unexpected_out_state_mapping : {'raise', 'warn', 'ignore'}, default 'raise'
        Behaviour when a state write is not covered by ``state_out_axes``.

    Returns
    -------
    callable
        A function with vectorised semantics that also mirrors new state
        creation across the mapped axis.

    Raises
    ------
    ValueError
        If ``axis_size`` is provided and is not strictly positive.
    BatchAxisError
        If unexpected state writes occur and the policy is ``'raise'``.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bst
        >>> import jax.numpy as jnp
        >>> from brainstate.transform import vmap_new_states
        >>>
        >>> @vmap_new_states(in_axes=0, out_axes=0)
        ... def forward(x):
        ...     scratch = bst.ShortTermState(jnp.array(0.0), tag='scratch')
        ...     scratch.value = scratch.value + x
        ...     return scratch.value
        >>>
        >>> forward(jnp.arange(3.0))
        Array([0., 1., 2.], dtype=float32)

    See Also
    --------
    vmap : State-aware vectorisation for existing states.
    catch_new_states : Context manager used internally to intercept state creation.
    """
    if isinstance(fun, Missing):
        # Decorator usage: defer the transform until ``fun`` is supplied.
        return functools.partial(
            _vmap_new_states_transform,
            in_axes=in_axes,
            out_axes=out_axes,
            axis_name=axis_name,
            axis_size=axis_size,
            spmd_axis_name=spmd_axis_name,
            state_tag=state_tag,
            state_to_exclude=state_to_exclude,
            state_in_axes=state_in_axes,
            state_out_axes=state_out_axes,
            unexpected_out_state_mapping=unexpected_out_state_mapping,
        )
    else:
        # Direct usage: forward every keyword to the underlying transform.
        return _vmap_new_states_transform(
            fun,
            in_axes=in_axes,
            out_axes=out_axes,
            axis_name=axis_name,
            axis_size=axis_size,
            spmd_axis_name=spmd_axis_name,
            state_tag=state_tag,
            state_to_exclude=state_to_exclude,
            state_in_axes=state_in_axes,
            # BUGFIX: ``state_out_axes`` was previously dropped in this branch,
            # so direct calls silently ignored the caller's output-axis filters
            # while decorator usage honoured them. Forward it like the branch above.
            state_out_axes=state_out_axes,
            unexpected_out_state_mapping=unexpected_out_state_mapping,
        )
|