ai-edge-torch-nightly 0.3.0.dev20250121__py3-none-any.whl → 0.3.0.dev20250123__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/__init__.py +1 -1
- ai_edge_torch/_convert/conversion.py +6 -10
- ai_edge_torch/_convert/fx_passes/__init__.py +1 -1
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +6 -3
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +9 -11
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -3
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py +33 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +3 -3
- ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +4 -4
- ai_edge_torch/_convert/test/test_convert.py +2 -2
- ai_edge_torch/fx_infra/__init__.py +32 -0
- ai_edge_torch/fx_infra/_canonicalize_pass.py +27 -0
- ai_edge_torch/fx_infra/_safe_run_decompositions.py +57 -0
- ai_edge_torch/fx_infra/decomp.py +80 -0
- ai_edge_torch/fx_infra/graph_utils.py +42 -0
- ai_edge_torch/{fx_pass_base.py → fx_infra/pass_base.py} +0 -28
- ai_edge_torch/generative/fx_passes/__init__.py +3 -3
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -3
- ai_edge_torch/hlfb/mark_pattern/__init__.py +0 -1
- ai_edge_torch/hlfb/mark_pattern/fx_utils.py +5 -21
- ai_edge_torch/hlfb/mark_pattern/pattern.py +20 -11
- ai_edge_torch/hlfb/test/test_mark_pattern.py +18 -15
- ai_edge_torch/odml_torch/_torch_future.py +0 -27
- ai_edge_torch/odml_torch/export.py +6 -8
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
- ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +65 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/RECORD +34 -28
- ai_edge_torch/odml_torch/lowerings/decomp.py +0 -69
- {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py
CHANGED
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
16
15
|
from ai_edge_torch._config import config
|
17
16
|
from ai_edge_torch._convert.converter import convert
|
18
17
|
from ai_edge_torch._convert.converter import signature
|
@@ -20,6 +19,7 @@ from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
|
|
20
19
|
from ai_edge_torch.model import Model
|
21
20
|
from ai_edge_torch.version import __version__
|
22
21
|
|
22
|
+
|
23
23
|
def load(path: str) -> Model:
|
24
24
|
"""Imports an ai_edge_torch model from disk.
|
25
25
|
|
@@ -17,7 +17,7 @@ import logging
|
|
17
17
|
from typing import Any, Literal, Optional, Union
|
18
18
|
|
19
19
|
import ai_edge_torch
|
20
|
-
from ai_edge_torch import
|
20
|
+
from ai_edge_torch import fx_infra
|
21
21
|
from ai_edge_torch import lowertools
|
22
22
|
from ai_edge_torch import model
|
23
23
|
from ai_edge_torch._convert import fx_passes
|
@@ -53,7 +53,7 @@ def _run_convert_passes(
|
|
53
53
|
fx_passes.CanonicalizePass(),
|
54
54
|
]
|
55
55
|
|
56
|
-
exported_program =
|
56
|
+
exported_program = fx_infra.run_passes(exported_program, passes)
|
57
57
|
return exported_program
|
58
58
|
|
59
59
|
|
@@ -125,14 +125,10 @@ def convert_signatures(
|
|
125
125
|
else:
|
126
126
|
exported_program = torch.export.export(**kwargs, strict=True)
|
127
127
|
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
# implementation of torch.export changes.
|
133
|
-
exported_program = exported_program.run_decompositions(
|
134
|
-
torch._decomp._decomp_table_to_post_autograd_aten()
|
135
|
-
)
|
128
|
+
exported_program = fx_infra.safe_run_decompositions(
|
129
|
+
exported_program,
|
130
|
+
fx_infra.decomp.pre_convert_decomp(),
|
131
|
+
)
|
136
132
|
return exported_program
|
137
133
|
|
138
134
|
exported_programs: torch.export.ExportedProgram = [
|
@@ -20,4 +20,4 @@ from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import Bu
|
|
20
20
|
from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
|
21
21
|
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
|
22
22
|
from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
|
23
|
-
from ai_edge_torch.
|
23
|
+
from ai_edge_torch.fx_infra import CanonicalizePass
|
@@ -14,7 +14,7 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
from typing import Any, Callable
|
17
|
-
from ai_edge_torch import
|
17
|
+
from ai_edge_torch import fx_infra
|
18
18
|
from ai_edge_torch import lowertools
|
19
19
|
import torch
|
20
20
|
import torch.utils._pytree as pytree
|
@@ -25,6 +25,9 @@ _composite_builders: dict[
|
|
25
25
|
|
26
26
|
|
27
27
|
def _register_composite_builder(op):
|
28
|
+
# Remove op from pre_convert_decomp to keep this in the decomposed graph.
|
29
|
+
fx_infra.decomp.remove_pre_convert_decomp(op)
|
30
|
+
|
28
31
|
def inner(func):
|
29
32
|
if isinstance(op, torch._ops.OpOverloadPacket):
|
30
33
|
for overload in op.overloads():
|
@@ -276,7 +279,7 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
|
|
276
279
|
node.target = embedding
|
277
280
|
|
278
281
|
|
279
|
-
class BuildAtenCompositePass(fx_pass_base.PassBase):
|
282
|
+
class BuildAtenCompositePass(fx_infra.PassBase):
|
280
283
|
|
281
284
|
def call(self, graph_module: torch.fx.GraphModule):
|
282
285
|
for node in graph_module.graph.nodes:
|
@@ -285,4 +288,4 @@ class BuildAtenCompositePass(fx_pass_base.PassBase):
|
|
285
288
|
|
286
289
|
graph_module.graph.lint()
|
287
290
|
graph_module.recompile()
|
288
|
-
return
|
291
|
+
return fx_infra.PassResult(graph_module, True)
|
@@ -16,7 +16,7 @@
|
|
16
16
|
|
17
17
|
import functools
|
18
18
|
|
19
|
-
from ai_edge_torch import
|
19
|
+
from ai_edge_torch import fx_infra
|
20
20
|
from ai_edge_torch.hlfb import mark_pattern
|
21
21
|
from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
|
22
22
|
import torch
|
@@ -41,7 +41,7 @@ def _get_upsample_bilinear2d_pattern():
|
|
41
41
|
x, scale_factor=2, mode="bilinear", align_corners=False
|
42
42
|
),
|
43
43
|
export_args=(torch.rand(1, 3, 100, 100),),
|
44
|
-
|
44
|
+
extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
|
45
45
|
)
|
46
46
|
|
47
47
|
@pattern.register_attr_builder
|
@@ -65,7 +65,7 @@ def _get_upsample_bilinear2d_align_corners_pattern():
|
|
65
65
|
x, scale_factor=2, mode="bilinear", align_corners=True
|
66
66
|
),
|
67
67
|
export_args=(torch.rand(1, 3, 100, 100),),
|
68
|
-
|
68
|
+
extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
|
69
69
|
)
|
70
70
|
|
71
71
|
@pattern.register_attr_builder
|
@@ -89,7 +89,7 @@ def _get_interpolate_nearest2d_pattern():
|
|
89
89
|
x, scale_factor=2, mode="nearest"
|
90
90
|
),
|
91
91
|
export_args=(torch.rand(1, 3, 100, 100),),
|
92
|
-
|
92
|
+
extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
|
93
93
|
)
|
94
94
|
|
95
95
|
@pattern.register_attr_builder
|
@@ -104,7 +104,7 @@ def _get_interpolate_nearest2d_pattern():
|
|
104
104
|
return pattern
|
105
105
|
|
106
106
|
|
107
|
-
class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
|
107
|
+
class BuildInterpolateCompositePass(fx_infra.ExportedProgramPassBase):
|
108
108
|
|
109
109
|
def __init__(self):
|
110
110
|
super().__init__()
|
@@ -115,11 +115,9 @@ class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
|
|
115
115
|
]
|
116
116
|
|
117
117
|
def call(self, exported_program: torch.export.ExportedProgram):
|
118
|
-
exported_program =
|
119
|
-
exported_program,
|
120
|
-
|
121
|
-
exported_program = exported_program.run_decompositions(
|
122
|
-
_INTERPOLATE_DECOMPOSITIONS
|
118
|
+
exported_program = fx_infra.safe_run_decompositions(
|
119
|
+
exported_program,
|
120
|
+
_INTERPOLATE_DECOMPOSITIONS,
|
123
121
|
)
|
124
122
|
|
125
123
|
graph_module = exported_program.graph_module
|
@@ -128,4 +126,4 @@ class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
|
|
128
126
|
|
129
127
|
graph_module.graph.lint()
|
130
128
|
graph_module.recompile()
|
131
|
-
return
|
129
|
+
return fx_infra.ExportedProgramPassResult(exported_program, True)
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from ai_edge_torch import
|
16
|
+
from ai_edge_torch import fx_infra
|
17
17
|
from ai_edge_torch import lowertools
|
18
18
|
import torch
|
19
19
|
import torch.utils._pytree as pytree
|
@@ -61,7 +61,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
|
|
61
61
|
node.target = debuginfo_writer
|
62
62
|
|
63
63
|
|
64
|
-
class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
|
64
|
+
class InjectMlirDebuginfoPass(fx_infra.PassBase):
|
65
65
|
"""DEPRECATED: Debuginfo is injected automatically by odml_torch."""
|
66
66
|
|
67
67
|
def call(self, graph_module: torch.fx.GraphModule):
|
@@ -70,4 +70,4 @@ class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
|
|
70
70
|
|
71
71
|
graph_module.graph.lint()
|
72
72
|
graph_module.recompile()
|
73
|
-
return
|
73
|
+
return fx_infra.PassResult(graph_module, True)
|
@@ -13,4 +13,5 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import _decomp_registry
|
16
17
|
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass.pass_body import OptimizeLayoutTransposesPass # NOQA
|
@@ -0,0 +1,33 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Remove decompositions for ops to keep in layout optimization."""
|
16
|
+
from ai_edge_torch import fx_infra
|
17
|
+
import torch
|
18
|
+
|
19
|
+
__all__ = []
|
20
|
+
|
21
|
+
aten = torch.ops.aten
|
22
|
+
|
23
|
+
_OPS_TO_KEEP = [
|
24
|
+
aten.conv2d,
|
25
|
+
aten.max_pool2d,
|
26
|
+
aten._softmax.default,
|
27
|
+
aten.group_norm.default,
|
28
|
+
aten.native_group_norm.default,
|
29
|
+
aten.reflection_pad2d.default,
|
30
|
+
]
|
31
|
+
|
32
|
+
for op in _OPS_TO_KEEP:
|
33
|
+
fx_infra.decomp.remove_pre_convert_decomp(op)
|
@@ -18,7 +18,7 @@ import operator
|
|
18
18
|
import os
|
19
19
|
from typing import Union
|
20
20
|
|
21
|
-
from ai_edge_torch import
|
21
|
+
from ai_edge_torch import fx_infra
|
22
22
|
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
|
23
23
|
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
|
24
24
|
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners # NOQA
|
@@ -30,7 +30,7 @@ import torch.ao.quantization.quantize_pt2e
|
|
30
30
|
TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]
|
31
31
|
|
32
32
|
|
33
|
-
class OptimizeLayoutTransposesPass(fx_pass_base.ExportedProgramPassBase):
|
33
|
+
class OptimizeLayoutTransposesPass(fx_infra.ExportedProgramPassBase):
|
34
34
|
|
35
35
|
def get_source_meta(self, node: torch.fx.Node):
|
36
36
|
keys = ["stack_trace", "nn_module_stack", "source_fn_stack", "from_node"]
|
@@ -300,4 +300,4 @@ class OptimizeLayoutTransposesPass(fx_pass_base.ExportedProgramPassBase):
|
|
300
300
|
# Mark const node again for debugging
|
301
301
|
self.mark_const_nodes(exported_program)
|
302
302
|
|
303
|
-
return
|
303
|
+
return fx_infra.ExportedProgramPassResult(exported_program, True)
|
@@ -15,11 +15,11 @@
|
|
15
15
|
"""Pass to remove all non user outputs from exported program."""
|
16
16
|
|
17
17
|
|
18
|
-
from ai_edge_torch import
|
18
|
+
from ai_edge_torch import fx_infra
|
19
19
|
import torch
|
20
20
|
|
21
21
|
|
22
|
-
class RemoveNonUserOutputsPass(fx_pass_base.ExportedProgramPassBase):
|
22
|
+
class RemoveNonUserOutputsPass(fx_infra.ExportedProgramPassBase):
|
23
23
|
"""This pass removes all non user outputs from the exported program's output.
|
24
24
|
|
25
25
|
The FX graph may output more tensors/data than what user's original model
|
@@ -47,6 +47,6 @@ class RemoveNonUserOutputsPass(fx_pass_base.ExportedProgramPassBase):
|
|
47
47
|
node.args = (tuple(new_outputs),)
|
48
48
|
exported_program.graph_signature.output_specs = new_output_specs
|
49
49
|
|
50
|
-
exported_program.
|
50
|
+
exported_program.graph.eliminate_dead_code()
|
51
51
|
exported_program.graph_module.recompile()
|
52
|
-
return
|
52
|
+
return fx_infra.ExportedProgramPassResult(exported_program, True)
|
@@ -511,7 +511,7 @@ class TestConvert(googletest.TestCase):
|
|
511
511
|
# Step 1: export resnet18
|
512
512
|
args = (torch.randn(1, 3, 224, 224),)
|
513
513
|
m = torchvision.models.resnet18().eval()
|
514
|
-
m = torch.
|
514
|
+
m = torch.export.export_for_training(m, args).module()
|
515
515
|
|
516
516
|
# Step 2: Insert observers or fake quantize modules
|
517
517
|
quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
|
@@ -533,7 +533,7 @@ class TestConvert(googletest.TestCase):
|
|
533
533
|
# Step 1: export resnet18
|
534
534
|
args = (torch.randn(1, 3, 224, 224),)
|
535
535
|
m = torchvision.models.resnet18().eval()
|
536
|
-
m = torch.
|
536
|
+
m = torch.export.export_for_training(m, args).module()
|
537
537
|
|
538
538
|
# Step 2: Insert observers or fake quantize modules
|
539
539
|
quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
|
@@ -0,0 +1,32 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from ai_edge_torch.fx_infra import _canonicalize_pass
|
17
|
+
from ai_edge_torch.fx_infra import _safe_run_decompositions
|
18
|
+
from ai_edge_torch.fx_infra import decomp
|
19
|
+
from ai_edge_torch.fx_infra import graph_utils
|
20
|
+
from ai_edge_torch.fx_infra import pass_base
|
21
|
+
|
22
|
+
|
23
|
+
PassBase = pass_base.PassBase
|
24
|
+
PassResult = pass_base.PassResult
|
25
|
+
FxPassBase = pass_base.FxPassBase
|
26
|
+
FxPassResult = pass_base.FxPassResult
|
27
|
+
ExportedProgramPassBase = pass_base.ExportedProgramPassBase
|
28
|
+
ExportedProgramPassResult = pass_base.ExportedProgramPassResult
|
29
|
+
run_passes = pass_base.run_passes
|
30
|
+
|
31
|
+
CanonicalizePass = _canonicalize_pass.CanonicalizePass
|
32
|
+
safe_run_decompositions = _safe_run_decompositions.safe_run_decompositions
|
@@ -0,0 +1,27 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
from ai_edge_torch.fx_infra import _safe_run_decompositions
|
16
|
+
from ai_edge_torch.fx_infra import pass_base
|
17
|
+
import torch
|
18
|
+
|
19
|
+
|
20
|
+
class CanonicalizePass(pass_base.ExportedProgramPassBase):
|
21
|
+
|
22
|
+
def call(self, exported_program: torch.export.ExportedProgram):
|
23
|
+
exported_program = _safe_run_decompositions.safe_run_decompositions(
|
24
|
+
exported_program, {}
|
25
|
+
)
|
26
|
+
|
27
|
+
return pass_base.ExportedProgramPassResult(exported_program, True)
|
@@ -0,0 +1,57 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""ExportedProgram.run_decompositions wrapper to handle unexpected export behavior."""
|
16
|
+
import torch
|
17
|
+
|
18
|
+
|
19
|
+
# A dummy decomp table for running ExportedProgram.run_decompositions without
|
20
|
+
# any op decompositions but just aot_export_module. Due to the check in
|
21
|
+
# run_decompositions, if None or an empty dict is passed as decomp_table,
|
22
|
+
# it will run the default aten-coreaten decompositions. Therefore a non-empty
|
23
|
+
# dummy decomp table is needed.
|
24
|
+
# Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
|
25
|
+
_DUMMY_DECOMP_TABLE = {
|
26
|
+
torch._ops.OperatorBase(): lambda: None,
|
27
|
+
}
|
28
|
+
|
29
|
+
|
30
|
+
def safe_run_decompositions(exported_program, decomp_table=None):
|
31
|
+
"""Wrapper for ExportedProgram.run_decompositions to handle unexpected export behavior."""
|
32
|
+
|
33
|
+
if decomp_table is not None and not decomp_table:
|
34
|
+
# Empty decomp table means no op decompositions. Use dummy decomp table
|
35
|
+
# instead for backward compatibility.
|
36
|
+
decomp_table = _DUMMY_DECOMP_TABLE
|
37
|
+
|
38
|
+
for node in exported_program.graph.nodes:
|
39
|
+
if node.target == torch.ops.aten.view.default:
|
40
|
+
# Passes or torch.export may generate aten.view nodes not respecting the
|
41
|
+
# tensor memory format. Changes all the aten.view to torch.reshape
|
42
|
+
# for retracing. If the input memory format is already contiguous,
|
43
|
+
# retracing in run_decomposition below would decompose torch.reshape
|
44
|
+
# back to one aten.view.
|
45
|
+
node.target = lambda self, size: torch.reshape(self.contiguous(), size)
|
46
|
+
|
47
|
+
exported_program = exported_program.run_decompositions(decomp_table)
|
48
|
+
|
49
|
+
if hasattr(torch.ops.aten, "_assert_tensor_metadata"):
|
50
|
+
for node in exported_program.graph.nodes:
|
51
|
+
if node.target == torch.ops.aten._assert_tensor_metadata.default:
|
52
|
+
exported_program.graph.erase_node(node)
|
53
|
+
|
54
|
+
exported_program.graph.eliminate_dead_code()
|
55
|
+
exported_program.graph_module.recompile()
|
56
|
+
|
57
|
+
return exported_program
|
@@ -0,0 +1,80 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Decomposition tables run before conversion and before lowering, with helpers to add or remove entries."""
|
16
|
+
import torch
|
17
|
+
|
18
|
+
# Decompositions to be run after torch.export before conversion and any
|
19
|
+
# passes. Remove ops from this decomposition table if they need to be preserved
|
20
|
+
# in passes.
|
21
|
+
_pre_convert_decomp = torch._decomp.core_aten_decompositions().copy()
|
22
|
+
|
23
|
+
|
24
|
+
# Decompositions to be run after conversion before odml_torch passes and
|
25
|
+
# lowerings.
|
26
|
+
_pre_lower_decomp = torch._decomp.core_aten_decompositions().copy()
|
27
|
+
|
28
|
+
|
29
|
+
def _get_ops(op):
|
30
|
+
if isinstance(op, torch._ops.OpOverloadPacket):
|
31
|
+
return [getattr(op, overload) for overload in op.overloads()]
|
32
|
+
else:
|
33
|
+
return [op]
|
34
|
+
|
35
|
+
|
36
|
+
def pre_convert_decomp():
|
37
|
+
"""Decompositions to be run after torch.export before conversion and any passes."""
|
38
|
+
return _pre_convert_decomp.copy()
|
39
|
+
|
40
|
+
|
41
|
+
def pre_lower_decomp():
|
42
|
+
"""Decompositions to be run after conversion before odml_torch passes and lowerings."""
|
43
|
+
return _pre_lower_decomp.copy()
|
44
|
+
|
45
|
+
|
46
|
+
def remove_pre_lower_decomp(op):
|
47
|
+
# Also remove from pre_convert_decomp, which always runs before pre_lower_
|
48
|
+
# decomp.
|
49
|
+
remove_pre_convert_decomp(op)
|
50
|
+
|
51
|
+
for op_ in _get_ops(op):
|
52
|
+
_pre_lower_decomp.pop(op_, None)
|
53
|
+
|
54
|
+
|
55
|
+
def remove_pre_convert_decomp(op):
|
56
|
+
for op_ in _get_ops(op):
|
57
|
+
_pre_convert_decomp.pop(op_, None)
|
58
|
+
|
59
|
+
|
60
|
+
def add_pre_convert_decomp(op, decomp):
|
61
|
+
# Also add decomp to pre_lower_decomp which runs after pre_convert_decomp.
|
62
|
+
add_pre_lower_decomp(op, decomp)
|
63
|
+
|
64
|
+
for op_ in _get_ops(op):
|
65
|
+
_pre_convert_decomp[op_] = decomp
|
66
|
+
|
67
|
+
|
68
|
+
def add_pre_lower_decomp(op, decomp):
|
69
|
+
for op_ in _get_ops(op):
|
70
|
+
_pre_lower_decomp[op_] = decomp
|
71
|
+
|
72
|
+
|
73
|
+
def update_pre_convert_decomp(decomps):
|
74
|
+
for op, decomp in decomps.items():
|
75
|
+
add_pre_convert_decomp(op, decomp)
|
76
|
+
|
77
|
+
|
78
|
+
def update_pre_lower_decomp(decomps):
|
79
|
+
for op, decomp in decomps.items():
|
80
|
+
add_pre_lower_decomp(op, decomp)
|
@@ -0,0 +1,42 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""FX graph utilities."""
|
16
|
+
import torch
|
17
|
+
|
18
|
+
|
19
|
+
def remove_dangling_args(graph_module: torch.fx.GraphModule):
|
20
|
+
"""Removes dangling args from the graph."""
|
21
|
+
for node in graph_module.graph.nodes:
|
22
|
+
if node.op == "placeholder" and not node.users:
|
23
|
+
graph_module.graph.erase_node(node)
|
24
|
+
|
25
|
+
graph_module.graph.lint()
|
26
|
+
graph_module.recompile()
|
27
|
+
return graph_module
|
28
|
+
|
29
|
+
|
30
|
+
def remove_assert_tensor_metadata_nodes(graph_module: torch.fx.GraphModule):
|
31
|
+
"""Removes aten._assert_tensor_metadata nodes from the graph.
|
32
|
+
|
33
|
+
This op is inserted by torch.export to check tensor metadata on custom ops. It
|
34
|
+
can break pattern matching and lowering.
|
35
|
+
"""
|
36
|
+
for node in graph_module.graph.nodes:
|
37
|
+
if node.target == torch.ops.aten._assert_tensor_metadata.default:
|
38
|
+
graph_module.graph.erase_node(node)
|
39
|
+
|
40
|
+
graph_module.graph.lint()
|
41
|
+
graph_module.recompile()
|
42
|
+
return graph_module
|
@@ -80,31 +80,3 @@ def run_passes(
|
|
80
80
|
constants=exported_program.constants,
|
81
81
|
)
|
82
82
|
return exported_program
|
83
|
-
|
84
|
-
|
85
|
-
class CanonicalizePass(ExportedProgramPassBase):
|
86
|
-
|
87
|
-
# A dummy decomp table for running ExportedProgram.run_decompositions without
|
88
|
-
# any op decompositions but just aot_export_module. Due to the check in
|
89
|
-
# run_decompositions, if None or an empty dict is passed as decomp_table,
|
90
|
-
# it will run the default aten-coreaten decompositions. Therefore a non-empty
|
91
|
-
# dummy decomp table is needed.
|
92
|
-
# Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
|
93
|
-
_DUMMY_DECOMP_TABLE = {
|
94
|
-
torch._ops.OperatorBase(): lambda: None,
|
95
|
-
}
|
96
|
-
|
97
|
-
def call(self, exported_program: torch.export.ExportedProgram):
|
98
|
-
for node in exported_program.graph.nodes:
|
99
|
-
if node.target == torch.ops.aten.view.default:
|
100
|
-
# Passes or torch.export may generate aten.view nodes not respecting the
|
101
|
-
# tensor memory format. Changes all the aten.view to torch.reshape
|
102
|
-
# for retracing. If the input memory format is already contiguous,
|
103
|
-
# retracing in run_decomposition below would decompose torch.reshape
|
104
|
-
# back to one aten.view.
|
105
|
-
node.target = lambda self, size: torch.reshape(self.contiguous(), size)
|
106
|
-
|
107
|
-
exported_program = exported_program.run_decompositions(
|
108
|
-
self._DUMMY_DECOMP_TABLE
|
109
|
-
)
|
110
|
-
return ExportedProgramPassResult(exported_program, True)
|
@@ -12,8 +12,8 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from ai_edge_torch import
|
16
|
-
from ai_edge_torch.
|
15
|
+
from ai_edge_torch import fx_infra
|
16
|
+
from ai_edge_torch.fx_infra import CanonicalizePass
|
17
17
|
from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass
|
18
18
|
import torch
|
19
19
|
|
@@ -21,7 +21,7 @@ import torch
|
|
21
21
|
def run_generative_passes(
|
22
22
|
exported_program: torch.export.ExportedProgram,
|
23
23
|
) -> torch.export.ExportedProgram:
|
24
|
-
return
|
24
|
+
return fx_infra.run_passes(
|
25
25
|
exported_program,
|
26
26
|
[
|
27
27
|
RemoveSDPACompositeZeroMaskPass(),
|
@@ -12,12 +12,12 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from ai_edge_torch import
|
15
|
+
from ai_edge_torch import fx_infra
|
16
16
|
from ai_edge_torch import lowertools
|
17
17
|
import torch
|
18
18
|
|
19
19
|
|
20
|
-
class RemoveSDPACompositeZeroMaskPass(fx_pass_base.ExportedProgramPassBase):
|
20
|
+
class RemoveSDPACompositeZeroMaskPass(fx_infra.ExportedProgramPassBase):
|
21
21
|
|
22
22
|
def is_zero_tensor_node(self, node: torch.fx.Node):
|
23
23
|
return node.target == torch.ops.aten.zeros.default
|
@@ -47,4 +47,4 @@ class RemoveSDPACompositeZeroMaskPass(fx_pass_base.ExportedProgramPassBase):
|
|
47
47
|
|
48
48
|
exported_program.graph_module.graph.lint()
|
49
49
|
exported_program.graph_module.recompile()
|
50
|
-
return
|
50
|
+
return fx_infra.ExportedProgramPassResult(exported_program, True)
|
@@ -14,8 +14,13 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""FX graph utilities for pattern matching clean ups."""
|
16
16
|
|
17
|
+
from ai_edge_torch import fx_infra
|
17
18
|
import torch
|
18
19
|
|
20
|
+
remove_dangling_args = fx_infra.graph_utils.remove_dangling_args
|
21
|
+
remove_assert_tensor_metadata_nodes = (
|
22
|
+
fx_infra.graph_utils.remove_assert_tensor_metadata_nodes
|
23
|
+
)
|
19
24
|
|
20
25
|
def is_clone_op(node: torch.fx.Node) -> bool:
|
21
26
|
"""Checks if the node is a clone op."""
|
@@ -46,24 +51,3 @@ def remove_clone_ops(gm: torch.fx.GraphModule):
|
|
46
51
|
gm.graph.lint()
|
47
52
|
gm.recompile()
|
48
53
|
return gm
|
49
|
-
|
50
|
-
|
51
|
-
def remove_dangling_args(gm: torch.fx.GraphModule):
|
52
|
-
"""Removes dangling args from the graph.
|
53
|
-
|
54
|
-
Args:
|
55
|
-
gm: The graph module to remove dangling args from.
|
56
|
-
|
57
|
-
Returns:
|
58
|
-
The graph module with dangling args removed.
|
59
|
-
"""
|
60
|
-
nodes_to_erase = []
|
61
|
-
for node in gm.graph.nodes:
|
62
|
-
if node.op == "placeholder" and len(node.users) == 0:
|
63
|
-
nodes_to_erase.append(node)
|
64
|
-
for node in nodes_to_erase:
|
65
|
-
gm.graph.erase_node(node)
|
66
|
-
|
67
|
-
gm.graph.lint()
|
68
|
-
gm.recompile()
|
69
|
-
return gm
|
@@ -17,7 +17,7 @@
|
|
17
17
|
import dataclasses
|
18
18
|
from typing import Any, Callable, Optional, Union
|
19
19
|
|
20
|
-
from ai_edge_torch import
|
20
|
+
from ai_edge_torch import fx_infra
|
21
21
|
from ai_edge_torch.hlfb.mark_pattern import fx_utils
|
22
22
|
import torch
|
23
23
|
|
@@ -118,7 +118,7 @@ def _find_scalar_attr(
|
|
118
118
|
track_args[tracker.pattern_arg_pos] = source
|
119
119
|
ep = torch.export.export(pattern_module, tuple(track_args))
|
120
120
|
if decomp_table is not None:
|
121
|
-
ep =
|
121
|
+
ep = fx_infra.run_passes(ep, [fx_infra.CanonicalizePass()])
|
122
122
|
ep = ep.run_decompositions(decomp_table)
|
123
123
|
|
124
124
|
scalar_locs = set()
|
@@ -152,13 +152,15 @@ class Pattern:
|
|
152
152
|
self,
|
153
153
|
name: str,
|
154
154
|
module: Union[Callable, torch.nn.Module],
|
155
|
-
export_args: tuple[Any],
|
155
|
+
export_args: tuple[Any, ...],
|
156
156
|
*,
|
157
157
|
attr_builder: Callable[
|
158
158
|
["Pattern", GraphModule, InternalMatch], Optional[dict[str, Any]]
|
159
159
|
] = None,
|
160
160
|
scalar_attr_trackers: list[ScalarAttrTracker] = None,
|
161
|
-
|
161
|
+
extra_decomp_table: Optional[
|
162
|
+
dict[torch._ops.OperatorBase, Callable]
|
163
|
+
] = None,
|
162
164
|
):
|
163
165
|
"""The PyTorch computation pattern to match against a model.
|
164
166
|
|
@@ -177,8 +179,9 @@ class Pattern:
|
|
177
179
|
scalar_attr_trackers (list[ScalarAttrTracker]): the trackers for scalar
|
178
180
|
args in `export_args`, which are used to track the attr occurrence(s)
|
179
181
|
and retrieve their values from the matched subgraph.
|
180
|
-
|
181
|
-
decomposition
|
182
|
+
extra_decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]):
|
183
|
+
Extra decomposition to be run on the pattern's exported program (in
|
184
|
+
addition to the default pre_convert_decomp in fx_infra.decomp).
|
182
185
|
"""
|
183
186
|
if not isinstance(module, torch.nn.Module):
|
184
187
|
|
@@ -200,11 +203,14 @@ class Pattern:
|
|
200
203
|
)
|
201
204
|
|
202
205
|
exported_program = torch.export.export(module, export_args)
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
)
|
207
|
-
|
206
|
+
|
207
|
+
decomp_table = fx_infra.decomp.pre_convert_decomp()
|
208
|
+
if extra_decomp_table is not None:
|
209
|
+
decomp_table.update(extra_decomp_table)
|
210
|
+
|
211
|
+
exported_program = fx_infra.safe_run_decompositions(
|
212
|
+
exported_program, decomp_table
|
213
|
+
)
|
208
214
|
|
209
215
|
self.exported_program = exported_program
|
210
216
|
self.graph_module = self.exported_program.graph_module
|
@@ -222,6 +228,9 @@ class Pattern:
|
|
222
228
|
# sanitization.
|
223
229
|
self.graph_module = fx_utils.remove_clone_ops(self.graph_module)
|
224
230
|
self.graph_module = fx_utils.remove_dangling_args(self.graph_module)
|
231
|
+
self.graph_module = fx_utils.remove_assert_tensor_metadata_nodes(
|
232
|
+
self.graph_module
|
233
|
+
)
|
225
234
|
|
226
235
|
# Builds list of ordered input and output nodes.
|
227
236
|
self.graph_nodes_map = {}
|
@@ -14,6 +14,7 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Tests for mark_pattern."""
|
16
16
|
|
17
|
+
from ai_edge_torch import fx_infra
|
17
18
|
from ai_edge_torch import lowertools
|
18
19
|
from ai_edge_torch.hlfb import mark_pattern
|
19
20
|
from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
|
@@ -22,11 +23,13 @@ import torch
|
|
22
23
|
from absl.testing import absltest as googletest
|
23
24
|
|
24
25
|
|
25
|
-
def
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
26
|
+
def _export_and_decomp(mod, args):
|
27
|
+
ep = torch.export.export(mod, args)
|
28
|
+
ep = fx_infra.safe_run_decompositions(ep, fx_infra.decomp.pre_lower_decomp())
|
29
|
+
return ep
|
30
|
+
|
31
|
+
|
32
|
+
def _to_mlir(ep: torch.export.ExportedProgram):
|
30
33
|
return lowertools.exported_program_to_mlir_text(ep)
|
31
34
|
|
32
35
|
|
@@ -47,9 +50,9 @@ class TestMarkPattern(googletest.TestCase):
|
|
47
50
|
|
48
51
|
model = TestModel().eval()
|
49
52
|
args = (torch.rand(20, 20),)
|
50
|
-
exported_program =
|
53
|
+
exported_program = _export_and_decomp(model, args)
|
51
54
|
mark_pattern.mark_pattern(exported_program.graph_module, pattern)
|
52
|
-
mlir =
|
55
|
+
mlir = _to_mlir(exported_program)
|
53
56
|
|
54
57
|
lowertools.assert_string_count(
|
55
58
|
self,
|
@@ -73,9 +76,9 @@ class TestMarkPattern(googletest.TestCase):
|
|
73
76
|
|
74
77
|
model = TestModel().eval()
|
75
78
|
args = (torch.rand(20, 20),)
|
76
|
-
exported_program =
|
79
|
+
exported_program = _export_and_decomp(model, args)
|
77
80
|
mark_pattern.mark_pattern(exported_program.graph_module, pattern)
|
78
|
-
mlir =
|
81
|
+
mlir = _to_mlir(exported_program)
|
79
82
|
|
80
83
|
lowertools.assert_string_count(
|
81
84
|
self,
|
@@ -99,9 +102,9 @@ class TestMarkPattern(googletest.TestCase):
|
|
99
102
|
|
100
103
|
model = TestModel().eval()
|
101
104
|
args = (torch.rand(20, 20),)
|
102
|
-
exported_program =
|
105
|
+
exported_program = _export_and_decomp(model, args)
|
103
106
|
mark_pattern.mark_pattern(exported_program.graph_module, pattern)
|
104
|
-
mlir =
|
107
|
+
mlir = _to_mlir(exported_program)
|
105
108
|
|
106
109
|
lowertools.assert_string_count(
|
107
110
|
self,
|
@@ -137,9 +140,9 @@ class TestMarkPattern(googletest.TestCase):
|
|
137
140
|
|
138
141
|
model = TestModel().eval()
|
139
142
|
args = (torch.rand(10, 10),)
|
140
|
-
exported_program =
|
143
|
+
exported_program = _export_and_decomp(model, args)
|
141
144
|
mark_pattern.mark_pattern(exported_program.graph_module, pattern)
|
142
|
-
mlir =
|
145
|
+
mlir = _to_mlir(exported_program)
|
143
146
|
|
144
147
|
lowertools.assert_string_count(
|
145
148
|
self,
|
@@ -169,9 +172,9 @@ class TestMarkPattern(googletest.TestCase):
|
|
169
172
|
|
170
173
|
model = TestModel().eval()
|
171
174
|
args = (torch.rand(20, 20), torch.rand(20, 20))
|
172
|
-
exported_program =
|
175
|
+
exported_program = _export_and_decomp(model, args)
|
173
176
|
mark_pattern.mark_pattern(exported_program.graph_module, pattern)
|
174
|
-
mlir =
|
177
|
+
mlir = _to_mlir(exported_program)
|
175
178
|
|
176
179
|
lowertools.assert_string_count(
|
177
180
|
self,
|
@@ -59,30 +59,3 @@ def graph_module_flat_inputs(ep: torch.export.ExportedProgram, args, kwargs):
|
|
59
59
|
ordered_tensor_constants = tuple()
|
60
60
|
|
61
61
|
return (*param_buffer_values, *flat_args, *ordered_tensor_constants)
|
62
|
-
|
63
|
-
|
64
|
-
# TODO(b/331481564): Replace this with CanonicalizePass + run_decomposition
|
65
|
-
def safe_run_decompositions(exported_program, decomp_table=None):
|
66
|
-
for node in exported_program.graph.nodes:
|
67
|
-
if node.target == torch.ops.aten.view.default:
|
68
|
-
# Passes or torch.export may generate aten.view nodes not respecting the
|
69
|
-
# tensor memory format. Changes all the aten.view to torch.reshape
|
70
|
-
# for retracing. If the input memory format is already contiguous,
|
71
|
-
# retracing in run_decomposition below would decompose torch.reshape
|
72
|
-
# back to one aten.view.
|
73
|
-
node.target = lambda self, size: torch.reshape(self.contiguous(), size)
|
74
|
-
|
75
|
-
return exported_program.run_decompositions(decomp_table)
|
76
|
-
|
77
|
-
|
78
|
-
def dummy_decomp_table():
|
79
|
-
"""Build dummy decomp table for run_decompositions without any decompositions.
|
80
|
-
|
81
|
-
Compatible for torch<=2.5.
|
82
|
-
|
83
|
-
Returns:
|
84
|
-
Decomp table for ExportedProgram.run_decompositions.
|
85
|
-
"""
|
86
|
-
return {
|
87
|
-
torch._ops.OperatorBase(): lambda: None,
|
88
|
-
}
|
@@ -20,6 +20,7 @@ import io
|
|
20
20
|
import operator
|
21
21
|
from typing import Any, Callable, Optional
|
22
22
|
|
23
|
+
from ai_edge_torch import fx_infra
|
23
24
|
from jax.lib import xla_extension
|
24
25
|
from jax._src.lib.mlir import ir
|
25
26
|
from jax._src.lib.mlir.dialects import func
|
@@ -302,16 +303,13 @@ def exported_program_to_mlir(
|
|
302
303
|
exported_program: torch.export.ExportedProgram,
|
303
304
|
) -> MlirLowered:
|
304
305
|
"""Lower the exported program to MLIR."""
|
305
|
-
exported_program =
|
306
|
-
exported_program,
|
306
|
+
exported_program = fx_infra.safe_run_decompositions(
|
307
|
+
exported_program,
|
308
|
+
fx_infra.decomp.pre_lower_decomp(),
|
307
309
|
)
|
308
|
-
|
309
310
|
_convert_i64_to_i32(exported_program)
|
310
|
-
|
311
|
-
|
312
|
-
exported_program = _torch_future.safe_run_decompositions(
|
313
|
-
exported_program, _torch_future.dummy_decomp_table()
|
314
|
-
)
|
311
|
+
# Run decompositions for retracing and canonicalization.
|
312
|
+
exported_program = fx_infra.safe_run_decompositions(exported_program, {})
|
315
313
|
|
316
314
|
# Passes below mutate the exported program to a state not executable by torch.
|
317
315
|
# Do not call run_decompositions after applying the passes.
|
@@ -15,6 +15,7 @@
|
|
15
15
|
from . import _basic
|
16
16
|
from . import _batch_norm
|
17
17
|
from . import _convolution
|
18
|
+
from . import _decomp_registry
|
18
19
|
from . import _jax_lowerings
|
19
20
|
from . import _layer_norm
|
20
21
|
from . import _quantized_decomposed
|
@@ -22,6 +23,5 @@ from . import _rand
|
|
22
23
|
from . import context
|
23
24
|
from . import registry
|
24
25
|
from . import utils
|
25
|
-
from .decomp import decompositions
|
26
26
|
from .registry import lookup
|
27
27
|
from .registry import lower
|
@@ -0,0 +1,65 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Torch export decompositions to run before lowering."""
|
16
|
+
|
17
|
+
from ai_edge_torch import fx_infra
|
18
|
+
import torch
|
19
|
+
|
20
|
+
|
21
|
+
fx_infra.decomp.update_pre_lower_decomp(
|
22
|
+
torch._decomp.get_decompositions([
|
23
|
+
torch.ops.aten.upsample_nearest2d,
|
24
|
+
torch.ops.aten._native_batch_norm_legit.no_stats,
|
25
|
+
torch.ops.aten._native_batch_norm_legit_functional,
|
26
|
+
torch.ops.aten._adaptive_avg_pool2d,
|
27
|
+
torch.ops.aten._adaptive_avg_pool3d,
|
28
|
+
torch.ops.aten.grid_sampler_2d,
|
29
|
+
torch.ops.aten.native_group_norm,
|
30
|
+
torch.ops.aten.native_dropout,
|
31
|
+
torch.ops.aten.reflection_pad1d,
|
32
|
+
torch.ops.aten.reflection_pad2d,
|
33
|
+
torch.ops.aten.reflection_pad3d,
|
34
|
+
torch.ops.aten.replication_pad1d,
|
35
|
+
torch.ops.aten.replication_pad2d,
|
36
|
+
torch.ops.aten.replication_pad3d,
|
37
|
+
torch.ops.aten.addmm,
|
38
|
+
])
|
39
|
+
)
|
40
|
+
|
41
|
+
fx_infra.decomp.remove_pre_lower_decomp(torch.ops.aten.roll)
|
42
|
+
|
43
|
+
# Torch's default einsum impl/decompositions is less efficient and
|
44
|
+
# optimized through converter than JAX's impl. Disable einsum
|
45
|
+
# decomposition to use JAX bridge for a more efficient lowering.
|
46
|
+
fx_infra.decomp.remove_pre_lower_decomp(torch.ops.aten.einsum.default)
|
47
|
+
|
48
|
+
|
49
|
+
# Override noop aten op decompositions for faster run_decompositions.
|
50
|
+
fx_infra.decomp.add_pre_convert_decomp(
|
51
|
+
torch.ops.aten.alias.default, lambda x: x
|
52
|
+
)
|
53
|
+
fx_infra.decomp.add_pre_convert_decomp(
|
54
|
+
torch.ops.aten.detach.default, lambda x: x
|
55
|
+
)
|
56
|
+
|
57
|
+
# Override _safe_softmax decompositions with regular softmax.
|
58
|
+
# _safe_softmax introduces additional check-select ops to guard extreme
|
59
|
+
# input values to softmax, which could make the converted model inefficient
|
60
|
+
# on-device.
|
61
|
+
if hasattr(torch.ops.aten, "_safe_softmax"):
|
62
|
+
fx_infra.decomp.add_pre_convert_decomp(
|
63
|
+
torch.ops.aten._safe_softmax.default,
|
64
|
+
torch.softmax,
|
65
|
+
)
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20250123
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -1,32 +1,32 @@
|
|
1
|
-
ai_edge_torch/__init__.py,sha256=
|
1
|
+
ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,1208
|
2
2
|
ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
|
-
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=szrxg2aB7mcm59IL_QVIqapmbw9Nz8AQ28vc9684bqY,706
|
7
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
|
-
ai_edge_torch/_convert/conversion.py,sha256=
|
7
|
+
ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
|
9
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
10
9
|
ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
|
11
10
|
ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
|
12
11
|
ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
|
13
|
-
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=
|
14
|
-
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=
|
15
|
-
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=
|
16
|
-
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=
|
17
|
-
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=
|
18
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=
|
12
|
+
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=dG4WIICk0FqCH9euvbYHHsybRN7B1cYcuxN_OYxmjWo,1263
|
13
|
+
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=a1KhqLetFb_efRHjX4T-zH0vF-U37Ha5I1CPIAsIluE,9211
|
14
|
+
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=3JyjiHpn17Zhfq3yGQXK5LMH71DQPXHb_4GOkP9uAjY,4251
|
15
|
+
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
|
16
|
+
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
|
17
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=UKC-wM93-oe8spxyFqgybJ0TwnSRw8f-SOA2glCh2FA,890
|
18
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py,sha256=aWO_zHDF4j_hokoKJQNFIFmua4ysXztsgS6pcyBUht0,1082
|
19
19
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=S_Bniv6jY16oOoFUzlyECQ0I2HDjG2D1MOI-QYPk3jQ,8061
|
20
20
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
|
21
21
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=zoAZ2TXKvxUnWnT11U4tx2uF0J5kkNXydgaW7JzfkXI,13811
|
22
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=
|
23
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=
|
22
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=OhisegHY2j4cv_m9auCh9Mq9qmm1lUqpFLVO9X-oBlc,1032
|
23
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=lgoH32l6zAbWTCpa_4-RWkHjqbNaPsBnhSObLIX8dL4,10551
|
24
24
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
|
25
25
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=D8VX8SbCzfoyvPgMFHK7yxD7R-bzLxp2gfdKxgrWekA,742
|
26
26
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
|
27
27
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
|
28
28
|
ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
29
|
-
ai_edge_torch/_convert/test/test_convert.py,sha256=
|
29
|
+
ai_edge_torch/_convert/test/test_convert.py,sha256=o6tuJkD-ESaQxLxJpN104qpchm3LCtPmHinzQxe6PSg,17226
|
30
30
|
ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
|
31
31
|
ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
|
32
32
|
ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
|
@@ -37,6 +37,12 @@ ai_edge_torch/debug/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf
|
|
37
37
|
ai_edge_torch/debug/test/test_culprit.py,sha256=fRN-8jJicawJ2mhPRQNAQUZ8AdGg-s0tYMXyhnLAlWw,3875
|
38
38
|
ai_edge_torch/debug/test/test_search_model.py,sha256=-RuU0QsjqkfzZF2IbeA55MoeVOawhbgiSEu96PmioPE,1668
|
39
39
|
ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
40
|
+
ai_edge_torch/fx_infra/__init__.py,sha256=APjkSqEfwDxcnI8k53rGi3Ef-G2L-M8fdaPGpxXtuiI,1347
|
41
|
+
ai_edge_torch/fx_infra/_canonicalize_pass.py,sha256=GDRoDdPVQw--QQFTT5J_C3TVuphL31m6K6F1-67SE4s,1097
|
42
|
+
ai_edge_torch/fx_infra/_safe_run_decompositions.py,sha256=ZbWheeZ8ydsxCk2aVGUgUynrkEkBOMjBCzPhS5uq4sU,2595
|
43
|
+
ai_edge_torch/fx_infra/decomp.py,sha256=S58SCgwMHYVFl_hJwlJxvu2wcI-AGNn82gel3qmTPrU,2500
|
44
|
+
ai_edge_torch/fx_infra/graph_utils.py,sha256=3UZAOHWOUh2LCj1E2_AKQn3gRDILi9JCdqSScjyOd4M,1535
|
45
|
+
ai_edge_torch/fx_infra/pass_base.py,sha256=Ic2AlhSoRFscz6l7gJKvWVNMDLQFfAw5kRf84-ZR9qM,2904
|
40
46
|
ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
41
47
|
ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
42
48
|
ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -115,8 +121,8 @@ ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6
|
|
115
121
|
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=VU0c5pgvrUtaTboT1xuDBGjpKOM85aqtaB_hYfSBuEk,2544
|
116
122
|
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
|
117
123
|
ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
|
118
|
-
ai_edge_torch/generative/fx_passes/__init__.py,sha256=
|
119
|
-
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=
|
124
|
+
ai_edge_torch/generative/fx_passes/__init__.py,sha256=4rFrppMRKlTwwZeX1ON_cdp4yUqoTOES161IZQkJF6c,1143
|
125
|
+
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=n5TbXdhBZi8jQe4j7-rox_MugMVvW8ReOhkTA3pfQkw,1919
|
120
126
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
121
127
|
ai_edge_torch/generative/layers/attention.py,sha256=GrAy8CT1pEsgRoB8JQP6PlnNYk8kQ4U3YANfSiTJKn8,13776
|
122
128
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
@@ -159,11 +165,11 @@ ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yG
|
|
159
165
|
ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
|
160
166
|
ai_edge_torch/generative/utilities/verifier.py,sha256=6lnBU9Cy5GanB8JWK3-2_VU3PxqunDWGe-SgSLba5Yw,12065
|
161
167
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
162
|
-
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256
|
163
|
-
ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=
|
164
|
-
ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=
|
168
|
+
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=JsVmYrM_JEuN_smMHXUsRlo3Liapp7UyktbPpPARwDk,5386
|
169
|
+
ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=YCtMgu-4w2BQ5fpnlpWC6IauKPf_tVqc7Ff91OTqlSw,1796
|
170
|
+
ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=Ui6BrehF3zJJN7uTxKwbO2yCY9mYjbewlQzAxzZv9Es,10274
|
165
171
|
ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
166
|
-
ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256
|
172
|
+
ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=5kmOJWCc7sU1Hrqr1y17BtShUrssTfaV1sMyUvdMbsg,5573
|
167
173
|
ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
|
168
174
|
ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
|
169
175
|
ai_edge_torch/lowertools/common_utils.py,sha256=4HQtquPZ6oiId8vR_1ykW_uK4ELnyo5zo3MlX1QYW4c,4513
|
@@ -172,9 +178,9 @@ ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUG
|
|
172
178
|
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
|
173
179
|
ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
|
174
180
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
175
|
-
ai_edge_torch/odml_torch/_torch_future.py,sha256=
|
181
|
+
ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
|
176
182
|
ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
|
177
|
-
ai_edge_torch/odml_torch/export.py,sha256=
|
183
|
+
ai_edge_torch/odml_torch/export.py,sha256=YN7QPrQ8W6T3YVOdyIGadfSQuBroMjIqAMB9FeUa7Ho,13447
|
178
184
|
ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
|
179
185
|
ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
|
180
186
|
ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
|
@@ -186,16 +192,16 @@ ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW
|
|
186
192
|
ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
|
187
193
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
|
188
194
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
189
|
-
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=
|
195
|
+
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
|
190
196
|
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=8mZTp_ybcMO3tDRQdlDP68BVeTw560XsTR4XH-ldTdc,9987
|
191
197
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
192
198
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
199
|
+
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
|
193
200
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=CJHWkmY4aAVQ5dmFsVc3Ox9TPkoLSNOfa96psD4CLRo,11561
|
194
201
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
195
202
|
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
|
196
203
|
ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
|
197
204
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
198
|
-
ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=0dIkVNN9ubgfwSxeaBFfTmWarWWG8Q1M0HX_1FShHik,2610
|
199
205
|
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
|
200
206
|
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZmWcAHl0nTDgyI2g,6167
|
201
207
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
@@ -206,8 +212,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
206
212
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
207
213
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
208
214
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
209
|
-
ai_edge_torch_nightly-0.3.0.
|
210
|
-
ai_edge_torch_nightly-0.3.0.
|
211
|
-
ai_edge_torch_nightly-0.3.0.
|
212
|
-
ai_edge_torch_nightly-0.3.0.
|
213
|
-
ai_edge_torch_nightly-0.3.0.
|
215
|
+
ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
216
|
+
ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/METADATA,sha256=1IZCBOcKVCWbEfAQvEMgt39cuATDIzpK6AhW_gTnIY4,1966
|
217
|
+
ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
218
|
+
ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
219
|
+
ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/RECORD,,
|
@@ -1,69 +0,0 @@
|
|
1
|
-
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
"""Torch export decompositions to run before lowering."""
|
16
|
-
|
17
|
-
import functools
|
18
|
-
|
19
|
-
import torch
|
20
|
-
|
21
|
-
|
22
|
-
@functools.cache
|
23
|
-
def decompositions():
|
24
|
-
# Base: Core ATen decompositions
|
25
|
-
decompositions = torch._decomp.core_aten_decompositions()
|
26
|
-
|
27
|
-
decompositions.update(
|
28
|
-
torch._decomp.get_decompositions([
|
29
|
-
torch.ops.aten.upsample_nearest2d,
|
30
|
-
torch.ops.aten._native_batch_norm_legit.no_stats,
|
31
|
-
torch.ops.aten._native_batch_norm_legit_functional,
|
32
|
-
torch.ops.aten._adaptive_avg_pool2d,
|
33
|
-
torch.ops.aten._adaptive_avg_pool3d,
|
34
|
-
torch.ops.aten.grid_sampler_2d,
|
35
|
-
torch.ops.aten.native_group_norm,
|
36
|
-
torch.ops.aten.native_dropout,
|
37
|
-
torch.ops.aten.reflection_pad1d,
|
38
|
-
torch.ops.aten.reflection_pad2d,
|
39
|
-
torch.ops.aten.reflection_pad3d,
|
40
|
-
torch.ops.aten.replication_pad1d,
|
41
|
-
torch.ops.aten.replication_pad2d,
|
42
|
-
torch.ops.aten.replication_pad3d,
|
43
|
-
torch.ops.aten.addmm,
|
44
|
-
])
|
45
|
-
)
|
46
|
-
|
47
|
-
torch._decomp.remove_decompositions(
|
48
|
-
decompositions,
|
49
|
-
[
|
50
|
-
torch.ops.aten.roll,
|
51
|
-
# Torch's default einsum impl/decompositions is less efficient and
|
52
|
-
# optimized through converter than JAX's impl. Disable einsum
|
53
|
-
# decomposition to use JAX bridge for a more efficient lowering.
|
54
|
-
torch.ops.aten.einsum.default,
|
55
|
-
],
|
56
|
-
)
|
57
|
-
|
58
|
-
# Override noop aten op decompositions for faster run_decompositions.
|
59
|
-
decompositions[torch.ops.aten.alias.default] = lambda x: x
|
60
|
-
decompositions[torch.ops.aten.detach.default] = lambda x: x
|
61
|
-
|
62
|
-
# Override _safe_softmax decompositions with regular softmax.
|
63
|
-
# _safe_softmax introduces additional check-select ops to guard extreme
|
64
|
-
# input values to softmax, which could make the converted model inefficient
|
65
|
-
# on-device.
|
66
|
-
if hasattr(torch.ops.aten, "_safe_softmax"):
|
67
|
-
decompositions[torch.ops.aten._safe_softmax.default] = torch.softmax
|
68
|
-
|
69
|
-
return decompositions
|
File without changes
|
File without changes
|