brainstate 0.1.7__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -146
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -470
- brainstate/nn/_delay_test.py +238 -0
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1361
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1120
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -208
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.7.dist-info/RECORD +0 -131
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,43 +1,43 @@
|
|
1
|
-
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
|
19
|
-
import brainstate
|
20
|
-
|
21
|
-
|
22
|
-
class Test_vmap_init_all_states:

    def test_vmap_init_all_states(self):
        """Initialize every state of a GRU cell under vmap with 10 mapped copies."""
        cell = brainstate.nn.GRUCell(1, 2)
        brainstate.nn.vmap_init_all_states(cell, axis_size=10)
        print(cell)

    def test_vmap_init_all_states_v2(self):
        """Same as above, but construction and initialization happen inside a jitted function."""

        def init():
            cell = brainstate.nn.GRUCell(1, 2)
            brainstate.nn.vmap_init_all_states(cell, axis_size=10)
            print(cell)

        # Apply the jit transform explicitly instead of via decorator syntax.
        jitted_init = brainstate.compile.jit(init)
        jitted_init()
|
37
|
-
|
38
|
-
|
39
|
-
class Test_init_all_states:

    def test_init_all_states(self):
        """Initialize every state of a GRU cell with an explicit batch size of 10."""
        cell = brainstate.nn.GRUCell(1, 2)
        brainstate.nn.init_all_states(cell, batch_size=10)
        print(cell)
|
1
|
+
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
|
19
|
+
import brainstate
|
20
|
+
|
21
|
+
|
22
|
+
class Test_vmap_init_all_states:

    def test_vmap_init_all_states(self):
        """Initialize every state of a GRU cell under vmap with 10 mapped copies."""
        cell = brainstate.nn.GRUCell(1, 2)
        brainstate.nn.vmap_init_all_states(cell, axis_size=10)
        print(cell)

    def test_vmap_init_all_states_v2(self):
        """Same as above, but construction and initialization happen inside a jitted function."""

        def init():
            cell = brainstate.nn.GRUCell(1, 2)
            brainstate.nn.vmap_init_all_states(cell, axis_size=10)
            print(cell)

        # Apply the jit transform explicitly instead of via decorator syntax.
        jitted_init = brainstate.compile.jit(init)
        jitted_init()
|
37
|
+
|
38
|
+
|
39
|
+
class Test_init_all_states:

    def test_init_all_states(self):
        """Initialize every state of a GRU cell with an explicit batch size of 10."""
        cell = brainstate.nn.GRUCell(1, 2)
        brainstate.nn.init_all_states(cell, batch_size=10)
        print(cell)
|
brainstate/nn/_common.py
CHANGED
@@ -1,178 +1,178 @@
|
|
1
|
-
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
from collections import defaultdict
|
19
|
-
from typing import Any, Sequence, Hashable, Dict
|
20
|
-
|
21
|
-
from brainstate import environ
|
22
|
-
from brainstate.augment._mapping import vmap
|
23
|
-
from brainstate.typing import Filter
|
24
|
-
from ._module import Module
|
25
|
-
|
26
|
-
# Type alias used for the ``axis_name`` parameter of ``Vmap``: any hashable value.
AxisName = Hashable

# Public API of this module.
__all__ = ['EnvironContext', 'Vmap']
|
32
|
-
|
33
|
-
|
34
|
-
class EnvironContext(Module):
    """
    Run a wrapped layer inside a fixed ``environ`` context.

    Calling this module is equivalent to writing::

        import brainstate

        with brainstate.environ.context(**context):
            result = layer(*args, **kwargs)

    Attributes:
        layer (Module): The wrapped layer.
        context (dict): Keyword arguments forwarded to ``environ.context`` on every call.
    """

    def __init__(self, layer: Module, **context):
        """
        Store the layer to wrap and the environment settings to apply around it.

        Args:
            layer (Module): Layer to execute; must be a ``Module`` instance.
            **context: Environment settings forwarded to ``environ.context``.
        """
        super().__init__()
        assert isinstance(layer, Module), 'The layer must be an instance of Module.'
        self.layer = layer
        self.context = context

    def update(self, *args, **kwargs):
        """
        Call the wrapped layer while ``self.context`` is active.

        Args:
            *args: Positional arguments forwarded to the wrapped layer.
            **kwargs: Keyword arguments forwarded to the wrapped layer.

        Returns:
            Whatever the wrapped layer returns.
        """
        settings = self.context
        with environ.context(**settings):
            return self.layer(*args, **kwargs)

    def add_context(self, **context):
        """
        Merge extra settings into ``self.context``; new values override existing keys.

        Args:
            **context: Additional environment settings.
        """
        for key, value in context.items():
            self.context[key] = value
|
93
|
-
|
94
|
-
|
95
|
-
def _filter_states(
    module: 'Module',
    filters: 'Filter | Dict[Filter, int] | None',
) -> 'Dict | None':
    """
    Resolve a state-filter specification against ``module``'s states.

    Args:
        module: The module whose ``states(*filters)`` method is queried.
        filters: One of

            - ``None``: nothing is selected and ``None`` is returned.
            - a ``dict`` mapping each filter to an integer map axis. Filters that
              share an axis are grouped into a list, and each group is passed to
              ``module.states`` as one filter. The result is
              ``{axis: states_for_that_axis}``.
            - anything else: passed directly as ``module.states(filters)``.

    Returns:
        ``None``, a ``{axis: states}`` dict, or whatever ``module.states(filters)``
        returns for a single filter.

    Raises:
        AssertionError: If a map axis in the ``dict`` form is not an ``int``.
    """
    if filters is None:
        return None

    if not isinstance(filters, dict):
        return module.states(filters)

    # Group filters by their map axis, preserving first-seen axis order.
    # Iterate over (filter, axis) pairs: iterating the dict itself only yields keys.
    filters_by_axis = defaultdict(list)
    for filter_, axis in filters.items():
        assert isinstance(axis, int), 'The value of in_states must be the map axis, which should be an integer.'
        filters_by_axis[axis].append(filter_)

    if not filters_by_axis:
        return {}

    selected = module.states(*filters_by_axis.values())
    # With exactly one filter group, ``states`` may return a single collection
    # instead of a tuple. Wrap it so it is paired with its axis, rather than
    # having its keys iterated by ``zip``.
    if len(filters_by_axis) == 1 and not isinstance(selected, (tuple, list)):
        selected = (selected,)
    return dict(zip(filters_by_axis.keys(), selected))
|
112
|
-
|
113
|
-
|
114
|
-
class Vmap(Module):
    """
    Wrap a module so that calling it runs under the ``vmap`` transform.

    The states to map are resolved once, at construction time, by filtering
    ``module``'s states. The resulting vectorized callable is stored as
    ``vmapped_fn`` and invoked by ``update``.

    Args:
        module (Module): The module to be vmapped.
        in_axes (int | None | Sequence[Any], optional): How inputs are mapped. Defaults to 0.
        out_axes (Any, optional): How outputs are mapped. Defaults to 0.
        vmap_states (Filter | Dict[Filter, int], optional): Filter, or ``{filter: axis}``
            mapping, selecting the states passed as ``in_states`` to ``vmap``. Defaults to None.
        vmap_out_states (Filter | Dict[Filter, int], optional): Same as ``vmap_states``,
            but passed as ``out_states``. Defaults to None.
        axis_name (AxisName | None, optional): Name of the mapped axis. Defaults to None.
        axis_size (int | None, optional): Size of the mapped axis. Defaults to None.
    """

    def __init__(
        self,
        module: Module,
        in_axes: int | None | Sequence[Any] = 0,
        out_axes: Any = 0,
        vmap_states: Filter | Dict[Filter, int] = None,
        vmap_out_states: Filter | Dict[Filter, int] = None,
        axis_name: AxisName | None = None,
        axis_size: int | None = None,
    ):
        super().__init__()

        # Keep the vmap configuration around for introspection.
        self.in_axes, self.out_axes = in_axes, out_axes
        self.axis_name, self.axis_size = axis_name, axis_size

        assert isinstance(module, Module), 'The module must be an instance of Module.'
        self.module = module

        # Resolve the state filters against the wrapped module once, up front.
        mapped_in = _filter_states(module, vmap_states)
        mapped_out = _filter_states(module, vmap_out_states)

        def vmap_run(*args, **kwargs):
            return module(*args, **kwargs)

        # ``vmap(...)`` without a function returns the transform; apply it explicitly.
        transform = vmap(
            in_axes=in_axes,
            out_axes=out_axes,
            in_states=mapped_in,
            out_states=mapped_out,
            axis_name=axis_name,
            axis_size=axis_size,
        )
        self.vmapped_fn = transform(vmap_run)

    def update(self, *args, **kwargs):
        """
        Invoke the vectorized wrapper built in ``__init__``.

        Args:
            *args: Positional arguments forwarded to the vmapped module.
            **kwargs: Keyword arguments forwarded to the vmapped module.

        Returns:
            The output of the vmapped module.
        """
        fn = self.vmapped_fn
        return fn(*args, **kwargs)
|
1
|
+
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
from collections import defaultdict
|
19
|
+
from typing import Any, Sequence, Hashable, Dict
|
20
|
+
|
21
|
+
from brainstate import environ
|
22
|
+
from brainstate.augment._mapping import vmap
|
23
|
+
from brainstate.typing import Filter
|
24
|
+
from ._module import Module
|
25
|
+
|
26
|
+
# Type alias used for the ``axis_name`` parameter of ``Vmap``: any hashable value.
AxisName = Hashable

# Public API of this module.
__all__ = ['EnvironContext', 'Vmap']
|
32
|
+
|
33
|
+
|
34
|
+
class EnvironContext(Module):
    """
    Run a wrapped layer inside a fixed ``environ`` context.

    Calling this module is equivalent to writing::

        import brainstate

        with brainstate.environ.context(**context):
            result = layer(*args, **kwargs)

    Attributes:
        layer (Module): The wrapped layer.
        context (dict): Keyword arguments forwarded to ``environ.context`` on every call.
    """

    def __init__(self, layer: Module, **context):
        """
        Store the layer to wrap and the environment settings to apply around it.

        Args:
            layer (Module): Layer to execute; must be a ``Module`` instance.
            **context: Environment settings forwarded to ``environ.context``.
        """
        super().__init__()
        assert isinstance(layer, Module), 'The layer must be an instance of Module.'
        self.layer = layer
        self.context = context

    def update(self, *args, **kwargs):
        """
        Call the wrapped layer while ``self.context`` is active.

        Args:
            *args: Positional arguments forwarded to the wrapped layer.
            **kwargs: Keyword arguments forwarded to the wrapped layer.

        Returns:
            Whatever the wrapped layer returns.
        """
        settings = self.context
        with environ.context(**settings):
            return self.layer(*args, **kwargs)

    def add_context(self, **context):
        """
        Merge extra settings into ``self.context``; new values override existing keys.

        Args:
            **context: Additional environment settings.
        """
        for key, value in context.items():
            self.context[key] = value
|
93
|
+
|
94
|
+
|
95
|
+
def _filter_states(
    module: 'Module',
    filters: 'Filter | Dict[Filter, int] | None',
) -> 'Dict | None':
    """
    Resolve a state-filter specification against ``module``'s states.

    Args:
        module: The module whose ``states(*filters)`` method is queried.
        filters: One of

            - ``None``: nothing is selected and ``None`` is returned.
            - a ``dict`` mapping each filter to an integer map axis. Filters that
              share an axis are grouped into a list, and each group is passed to
              ``module.states`` as one filter. The result is
              ``{axis: states_for_that_axis}``.
            - anything else: passed directly as ``module.states(filters)``.

    Returns:
        ``None``, a ``{axis: states}`` dict, or whatever ``module.states(filters)``
        returns for a single filter.

    Raises:
        AssertionError: If a map axis in the ``dict`` form is not an ``int``.
    """
    if filters is None:
        return None

    if not isinstance(filters, dict):
        return module.states(filters)

    # Group filters by their map axis, preserving first-seen axis order.
    # Iterate over (filter, axis) pairs: iterating the dict itself only yields keys.
    filters_by_axis = defaultdict(list)
    for filter_, axis in filters.items():
        assert isinstance(axis, int), 'The value of in_states must be the map axis, which should be an integer.'
        filters_by_axis[axis].append(filter_)

    if not filters_by_axis:
        return {}

    selected = module.states(*filters_by_axis.values())
    # With exactly one filter group, ``states`` may return a single collection
    # instead of a tuple. Wrap it so it is paired with its axis, rather than
    # having its keys iterated by ``zip``.
    if len(filters_by_axis) == 1 and not isinstance(selected, (tuple, list)):
        selected = (selected,)
    return dict(zip(filters_by_axis.keys(), selected))
|
112
|
+
|
113
|
+
|
114
|
+
class Vmap(Module):
    """
    Wrap a module so that calling it runs under the ``vmap`` transform.

    The states to map are resolved once, at construction time, by filtering
    ``module``'s states. The resulting vectorized callable is stored as
    ``vmapped_fn`` and invoked by ``update``.

    Args:
        module (Module): The module to be vmapped.
        in_axes (int | None | Sequence[Any], optional): How inputs are mapped. Defaults to 0.
        out_axes (Any, optional): How outputs are mapped. Defaults to 0.
        vmap_states (Filter | Dict[Filter, int], optional): Filter, or ``{filter: axis}``
            mapping, selecting the states passed as ``in_states`` to ``vmap``. Defaults to None.
        vmap_out_states (Filter | Dict[Filter, int], optional): Same as ``vmap_states``,
            but passed as ``out_states``. Defaults to None.
        axis_name (AxisName | None, optional): Name of the mapped axis. Defaults to None.
        axis_size (int | None, optional): Size of the mapped axis. Defaults to None.
    """

    def __init__(
        self,
        module: Module,
        in_axes: int | None | Sequence[Any] = 0,
        out_axes: Any = 0,
        vmap_states: Filter | Dict[Filter, int] = None,
        vmap_out_states: Filter | Dict[Filter, int] = None,
        axis_name: AxisName | None = None,
        axis_size: int | None = None,
    ):
        super().__init__()

        # Keep the vmap configuration around for introspection.
        self.in_axes, self.out_axes = in_axes, out_axes
        self.axis_name, self.axis_size = axis_name, axis_size

        assert isinstance(module, Module), 'The module must be an instance of Module.'
        self.module = module

        # Resolve the state filters against the wrapped module once, up front.
        mapped_in = _filter_states(module, vmap_states)
        mapped_out = _filter_states(module, vmap_out_states)

        def vmap_run(*args, **kwargs):
            return module(*args, **kwargs)

        # ``vmap(...)`` without a function returns the transform; apply it explicitly.
        transform = vmap(
            in_axes=in_axes,
            out_axes=out_axes,
            in_states=mapped_in,
            out_states=mapped_out,
            axis_name=axis_name,
            axis_size=axis_size,
        )
        self.vmapped_fn = transform(vmap_run)

    def update(self, *args, **kwargs):
        """
        Invoke the vectorized wrapper built in ``__init__``.

        Args:
            *args: Positional arguments forwarded to the vmapped module.
            **kwargs: Keyword arguments forwarded to the vmapped module.

        Returns:
            The output of the vmapped module.
        """
        fn = self.vmapped_fn
        return fn(*args, **kwargs)
|