ai-edge-torch-nightly 0.3.0.dev20250121__py3-none-any.whl → 0.3.0.dev20250123__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/__init__.py +1 -1
- ai_edge_torch/_convert/conversion.py +6 -10
- ai_edge_torch/_convert/fx_passes/__init__.py +1 -1
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +6 -3
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +9 -11
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -3
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py +33 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +3 -3
- ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +4 -4
- ai_edge_torch/_convert/test/test_convert.py +2 -2
- ai_edge_torch/fx_infra/__init__.py +32 -0
- ai_edge_torch/fx_infra/_canonicalize_pass.py +27 -0
- ai_edge_torch/fx_infra/_safe_run_decompositions.py +57 -0
- ai_edge_torch/fx_infra/decomp.py +80 -0
- ai_edge_torch/fx_infra/graph_utils.py +42 -0
- ai_edge_torch/{fx_pass_base.py → fx_infra/pass_base.py} +0 -28
- ai_edge_torch/generative/fx_passes/__init__.py +3 -3
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -3
- ai_edge_torch/hlfb/mark_pattern/__init__.py +0 -1
- ai_edge_torch/hlfb/mark_pattern/fx_utils.py +5 -21
- ai_edge_torch/hlfb/mark_pattern/pattern.py +20 -11
- ai_edge_torch/hlfb/test/test_mark_pattern.py +18 -15
- ai_edge_torch/odml_torch/_torch_future.py +0 -27
- ai_edge_torch/odml_torch/export.py +6 -8
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
- ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +65 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/RECORD +34 -28
- ai_edge_torch/odml_torch/lowerings/decomp.py +0 -69
- {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py
CHANGED
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
16
15
|
from ai_edge_torch._config import config
|
17
16
|
from ai_edge_torch._convert.converter import convert
|
18
17
|
from ai_edge_torch._convert.converter import signature
|
@@ -20,6 +19,7 @@ from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
|
|
20
19
|
from ai_edge_torch.model import Model
|
21
20
|
from ai_edge_torch.version import __version__
|
22
21
|
|
22
|
+
|
23
23
|
def load(path: str) -> Model:
|
24
24
|
"""Imports an ai_edge_torch model from disk.
|
25
25
|
|
@@ -17,7 +17,7 @@ import logging
|
|
17
17
|
from typing import Any, Literal, Optional, Union
|
18
18
|
|
19
19
|
import ai_edge_torch
|
20
|
-
from ai_edge_torch import
|
20
|
+
from ai_edge_torch import fx_infra
|
21
21
|
from ai_edge_torch import lowertools
|
22
22
|
from ai_edge_torch import model
|
23
23
|
from ai_edge_torch._convert import fx_passes
|
@@ -53,7 +53,7 @@ def _run_convert_passes(
|
|
53
53
|
fx_passes.CanonicalizePass(),
|
54
54
|
]
|
55
55
|
|
56
|
-
exported_program =
|
56
|
+
exported_program = fx_infra.run_passes(exported_program, passes)
|
57
57
|
return exported_program
|
58
58
|
|
59
59
|
|
@@ -125,14 +125,10 @@ def convert_signatures(
|
|
125
125
|
else:
|
126
126
|
exported_program = torch.export.export(**kwargs, strict=True)
|
127
127
|
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
# implementation of torch.export changes.
|
133
|
-
exported_program = exported_program.run_decompositions(
|
134
|
-
torch._decomp._decomp_table_to_post_autograd_aten()
|
135
|
-
)
|
128
|
+
exported_program = fx_infra.safe_run_decompositions(
|
129
|
+
exported_program,
|
130
|
+
fx_infra.decomp.pre_convert_decomp(),
|
131
|
+
)
|
136
132
|
return exported_program
|
137
133
|
|
138
134
|
exported_programs: torch.export.ExportedProgram = [
|
@@ -20,4 +20,4 @@ from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import Bu
|
|
20
20
|
from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
|
21
21
|
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
|
22
22
|
from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
|
23
|
-
from ai_edge_torch.
|
23
|
+
from ai_edge_torch.fx_infra import CanonicalizePass
|
@@ -14,7 +14,7 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
from typing import Any, Callable
|
17
|
-
from ai_edge_torch import
|
17
|
+
from ai_edge_torch import fx_infra
|
18
18
|
from ai_edge_torch import lowertools
|
19
19
|
import torch
|
20
20
|
import torch.utils._pytree as pytree
|
@@ -25,6 +25,9 @@ _composite_builders: dict[
|
|
25
25
|
|
26
26
|
|
27
27
|
def _register_composite_builder(op):
|
28
|
+
# Remove op from pre_convert_decomp to keep this in the decomposed graph.
|
29
|
+
fx_infra.decomp.remove_pre_convert_decomp(op)
|
30
|
+
|
28
31
|
def inner(func):
|
29
32
|
if isinstance(op, torch._ops.OpOverloadPacket):
|
30
33
|
for overload in op.overloads():
|
@@ -276,7 +279,7 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
|
|
276
279
|
node.target = embedding
|
277
280
|
|
278
281
|
|
279
|
-
class BuildAtenCompositePass(
|
282
|
+
class BuildAtenCompositePass(fx_infra.PassBase):
|
280
283
|
|
281
284
|
def call(self, graph_module: torch.fx.GraphModule):
|
282
285
|
for node in graph_module.graph.nodes:
|
@@ -285,4 +288,4 @@ class BuildAtenCompositePass(fx_pass_base.PassBase):
|
|
285
288
|
|
286
289
|
graph_module.graph.lint()
|
287
290
|
graph_module.recompile()
|
288
|
-
return
|
291
|
+
return fx_infra.PassResult(graph_module, True)
|
@@ -16,7 +16,7 @@
|
|
16
16
|
|
17
17
|
import functools
|
18
18
|
|
19
|
-
from ai_edge_torch import
|
19
|
+
from ai_edge_torch import fx_infra
|
20
20
|
from ai_edge_torch.hlfb import mark_pattern
|
21
21
|
from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
|
22
22
|
import torch
|
@@ -41,7 +41,7 @@ def _get_upsample_bilinear2d_pattern():
|
|
41
41
|
x, scale_factor=2, mode="bilinear", align_corners=False
|
42
42
|
),
|
43
43
|
export_args=(torch.rand(1, 3, 100, 100),),
|
44
|
-
|
44
|
+
extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
|
45
45
|
)
|
46
46
|
|
47
47
|
@pattern.register_attr_builder
|
@@ -65,7 +65,7 @@ def _get_upsample_bilinear2d_align_corners_pattern():
|
|
65
65
|
x, scale_factor=2, mode="bilinear", align_corners=True
|
66
66
|
),
|
67
67
|
export_args=(torch.rand(1, 3, 100, 100),),
|
68
|
-
|
68
|
+
extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
|
69
69
|
)
|
70
70
|
|
71
71
|
@pattern.register_attr_builder
|
@@ -89,7 +89,7 @@ def _get_interpolate_nearest2d_pattern():
|
|
89
89
|
x, scale_factor=2, mode="nearest"
|
90
90
|
),
|
91
91
|
export_args=(torch.rand(1, 3, 100, 100),),
|
92
|
-
|
92
|
+
extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
|
93
93
|
)
|
94
94
|
|
95
95
|
@pattern.register_attr_builder
|
@@ -104,7 +104,7 @@ def _get_interpolate_nearest2d_pattern():
|
|
104
104
|
return pattern
|
105
105
|
|
106
106
|
|
107
|
-
class BuildInterpolateCompositePass(
|
107
|
+
class BuildInterpolateCompositePass(fx_infra.ExportedProgramPassBase):
|
108
108
|
|
109
109
|
def __init__(self):
|
110
110
|
super().__init__()
|
@@ -115,11 +115,9 @@ class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
|
|
115
115
|
]
|
116
116
|
|
117
117
|
def call(self, exported_program: torch.export.ExportedProgram):
|
118
|
-
exported_program =
|
119
|
-
exported_program,
|
120
|
-
|
121
|
-
exported_program = exported_program.run_decompositions(
|
122
|
-
_INTERPOLATE_DECOMPOSITIONS
|
118
|
+
exported_program = fx_infra.safe_run_decompositions(
|
119
|
+
exported_program,
|
120
|
+
_INTERPOLATE_DECOMPOSITIONS,
|
123
121
|
)
|
124
122
|
|
125
123
|
graph_module = exported_program.graph_module
|
@@ -128,4 +126,4 @@ class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase):
|
|
128
126
|
|
129
127
|
graph_module.graph.lint()
|
130
128
|
graph_module.recompile()
|
131
|
-
return
|
129
|
+
return fx_infra.ExportedProgramPassResult(exported_program, True)
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from ai_edge_torch import
|
16
|
+
from ai_edge_torch import fx_infra
|
17
17
|
from ai_edge_torch import lowertools
|
18
18
|
import torch
|
19
19
|
import torch.utils._pytree as pytree
|
@@ -61,7 +61,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
|
|
61
61
|
node.target = debuginfo_writer
|
62
62
|
|
63
63
|
|
64
|
-
class InjectMlirDebuginfoPass(
|
64
|
+
class InjectMlirDebuginfoPass(fx_infra.PassBase):
|
65
65
|
"""DEPRECATED: Debuginfo is injected automatically by odml_torch."""
|
66
66
|
|
67
67
|
def call(self, graph_module: torch.fx.GraphModule):
|
@@ -70,4 +70,4 @@ class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
|
|
70
70
|
|
71
71
|
graph_module.graph.lint()
|
72
72
|
graph_module.recompile()
|
73
|
-
return
|
73
|
+
return fx_infra.PassResult(graph_module, True)
|
@@ -13,4 +13,5 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import _decomp_registry
|
16
17
|
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass.pass_body import OptimizeLayoutTransposesPass # NOQA
|
@@ -0,0 +1,33 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Remove decompositions for ops to keep in layout optimization."""
|
16
|
+
from ai_edge_torch import fx_infra
|
17
|
+
import torch
|
18
|
+
|
19
|
+
__all__ = []
|
20
|
+
|
21
|
+
aten = torch.ops.aten
|
22
|
+
|
23
|
+
_OPS_TO_KEEP = [
|
24
|
+
aten.conv2d,
|
25
|
+
aten.max_pool2d,
|
26
|
+
aten._softmax.default,
|
27
|
+
aten.group_norm.default,
|
28
|
+
aten.native_group_norm.default,
|
29
|
+
aten.reflection_pad2d.default,
|
30
|
+
]
|
31
|
+
|
32
|
+
for op in _OPS_TO_KEEP:
|
33
|
+
fx_infra.decomp.remove_pre_convert_decomp(op)
|
@@ -18,7 +18,7 @@ import operator
|
|
18
18
|
import os
|
19
19
|
from typing import Union
|
20
20
|
|
21
|
-
from ai_edge_torch import
|
21
|
+
from ai_edge_torch import fx_infra
|
22
22
|
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
|
23
23
|
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
|
24
24
|
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners # NOQA
|
@@ -30,7 +30,7 @@ import torch.ao.quantization.quantize_pt2e
|
|
30
30
|
TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]
|
31
31
|
|
32
32
|
|
33
|
-
class OptimizeLayoutTransposesPass(
|
33
|
+
class OptimizeLayoutTransposesPass(fx_infra.ExportedProgramPassBase):
|
34
34
|
|
35
35
|
def get_source_meta(self, node: torch.fx.Node):
|
36
36
|
keys = ["stack_trace", "nn_module_stack", "source_fn_stack", "from_node"]
|
@@ -300,4 +300,4 @@ class OptimizeLayoutTransposesPass(fx_pass_base.ExportedProgramPassBase):
|
|
300
300
|
# Mark const node again for debugging
|
301
301
|
self.mark_const_nodes(exported_program)
|
302
302
|
|
303
|
-
return
|
303
|
+
return fx_infra.ExportedProgramPassResult(exported_program, True)
|
@@ -15,11 +15,11 @@
|
|
15
15
|
"""Pass to remove all non user outputs from exported program."""
|
16
16
|
|
17
17
|
|
18
|
-
from ai_edge_torch import
|
18
|
+
from ai_edge_torch import fx_infra
|
19
19
|
import torch
|
20
20
|
|
21
21
|
|
22
|
-
class RemoveNonUserOutputsPass(
|
22
|
+
class RemoveNonUserOutputsPass(fx_infra.ExportedProgramPassBase):
|
23
23
|
"""This pass removes all non user outputs from the exported program's output.
|
24
24
|
|
25
25
|
The FX graph may output more tensors/data than what user's original model
|
@@ -47,6 +47,6 @@ class RemoveNonUserOutputsPass(fx_pass_base.ExportedProgramPassBase):
|
|
47
47
|
node.args = (tuple(new_outputs),)
|
48
48
|
exported_program.graph_signature.output_specs = new_output_specs
|
49
49
|
|
50
|
-
exported_program.
|
50
|
+
exported_program.graph.eliminate_dead_code()
|
51
51
|
exported_program.graph_module.recompile()
|
52
|
-
return
|
52
|
+
return fx_infra.ExportedProgramPassResult(exported_program, True)
|
@@ -511,7 +511,7 @@ class TestConvert(googletest.TestCase):
|
|
511
511
|
# Step 1: export resnet18
|
512
512
|
args = (torch.randn(1, 3, 224, 224),)
|
513
513
|
m = torchvision.models.resnet18().eval()
|
514
|
-
m = torch.
|
514
|
+
m = torch.export.export_for_training(m, args).module()
|
515
515
|
|
516
516
|
# Step 2: Insert observers or fake quantize modules
|
517
517
|
quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
|
@@ -533,7 +533,7 @@ class TestConvert(googletest.TestCase):
|
|
533
533
|
# Step 1: export resnet18
|
534
534
|
args = (torch.randn(1, 3, 224, 224),)
|
535
535
|
m = torchvision.models.resnet18().eval()
|
536
|
-
m = torch.
|
536
|
+
m = torch.export.export_for_training(m, args).module()
|
537
537
|
|
538
538
|
# Step 2: Insert observers or fake quantize modules
|
539
539
|
quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
|
@@ -0,0 +1,32 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from ai_edge_torch.fx_infra import _canonicalize_pass
|
17
|
+
from ai_edge_torch.fx_infra import _safe_run_decompositions
|
18
|
+
from ai_edge_torch.fx_infra import decomp
|
19
|
+
from ai_edge_torch.fx_infra import graph_utils
|
20
|
+
from ai_edge_torch.fx_infra import pass_base
|
21
|
+
|
22
|
+
|
23
|
+
PassBase = pass_base.PassBase
|
24
|
+
PassResult = pass_base.PassResult
|
25
|
+
FxPassBase = pass_base.FxPassBase
|
26
|
+
FxPassResult = pass_base.FxPassResult
|
27
|
+
ExportedProgramPassBase = pass_base.ExportedProgramPassBase
|
28
|
+
ExportedProgramPassResult = pass_base.ExportedProgramPassResult
|
29
|
+
run_passes = pass_base.run_passes
|
30
|
+
|
31
|
+
CanonicalizePass = _canonicalize_pass.CanonicalizePass
|
32
|
+
safe_run_decompositions = _safe_run_decompositions.safe_run_decompositions
|
@@ -0,0 +1,27 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
from ai_edge_torch.fx_infra import _safe_run_decompositions
|
16
|
+
from ai_edge_torch.fx_infra import pass_base
|
17
|
+
import torch
|
18
|
+
|
19
|
+
|
20
|
+
class CanonicalizePass(pass_base.ExportedProgramPassBase):
|
21
|
+
|
22
|
+
def call(self, exported_program: torch.export.ExportedProgram):
|
23
|
+
exported_program = _safe_run_decompositions.safe_run_decompositions(
|
24
|
+
exported_program, {}
|
25
|
+
)
|
26
|
+
|
27
|
+
return pass_base.ExportedProgramPassResult(exported_program, True)
|
@@ -0,0 +1,57 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""ExportedProgram.run_decompositions wrapper to handle unexpected export behavior."""
|
16
|
+
import torch
|
17
|
+
|
18
|
+
|
19
|
+
# A dummy decomp table for running ExportedProgram.run_decompositions without
|
20
|
+
# any op decompositions but just aot_export_module. Due to the check in
|
21
|
+
# run_decompositions, if None or an empty dict is passed as decomp_table,
|
22
|
+
# it will run the default aten-coreaten decompositions. Therefore a non-empty
|
23
|
+
# dummy decomp table is needed.
|
24
|
+
# Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
|
25
|
+
_DUMMY_DECOMP_TABLE = {
|
26
|
+
torch._ops.OperatorBase(): lambda: None,
|
27
|
+
}
|
28
|
+
|
29
|
+
|
30
|
+
def safe_run_decompositions(exported_program, decomp_table=None):
|
31
|
+
"""Wrapper for ExportedProgram.run_decompositions to handle unexpected export behavior."""
|
32
|
+
|
33
|
+
if decomp_table is not None and not decomp_table:
|
34
|
+
# Empty decomp table means no op decompositions. Use dummy decomp table
|
35
|
+
# instead for backward compatibility.
|
36
|
+
decomp_table = _DUMMY_DECOMP_TABLE
|
37
|
+
|
38
|
+
for node in exported_program.graph.nodes:
|
39
|
+
if node.target == torch.ops.aten.view.default:
|
40
|
+
# Passes or torch.export may generate aten.view nodes not respecting the
|
41
|
+
# tensor memory format. Changes all the aten.view to torch.reshape
|
42
|
+
# for retracing. If the input memory format is already contiguous,
|
43
|
+
# retracing in run_decomposition below would decompose torch.reshape
|
44
|
+
# back to one aten.view.
|
45
|
+
node.target = lambda self, size: torch.reshape(self.contiguous(), size)
|
46
|
+
|
47
|
+
exported_program = exported_program.run_decompositions(decomp_table)
|
48
|
+
|
49
|
+
if hasattr(torch.ops.aten, "_assert_tensor_metadata"):
|
50
|
+
for node in exported_program.graph.nodes:
|
51
|
+
if node.target == torch.ops.aten._assert_tensor_metadata.default:
|
52
|
+
exported_program.graph.erase_node(node)
|
53
|
+
|
54
|
+
exported_program.graph.eliminate_dead_code()
|
55
|
+
exported_program.graph_module.recompile()
|
56
|
+
|
57
|
+
return exported_program
|
@@ -0,0 +1,80 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""ExportedProgram.run_decompositions wrapper to handle unexpected export behavior."""
|
16
|
+
import torch
|
17
|
+
|
18
|
+
# Decompositions to be run after torch.export before conversion and any
|
19
|
+
# passes. Remove ops from this decomposition table if they need to be preserved
|
20
|
+
# in passes.
|
21
|
+
_pre_convert_decomp = torch._decomp.core_aten_decompositions().copy()
|
22
|
+
|
23
|
+
|
24
|
+
# Decompositions to be run after conversion before odml_torch passes and
|
25
|
+
# lowerings.
|
26
|
+
_pre_lower_decomp = torch._decomp.core_aten_decompositions().copy()
|
27
|
+
|
28
|
+
|
29
|
+
def _get_ops(op):
|
30
|
+
if isinstance(op, torch._ops.OpOverloadPacket):
|
31
|
+
return [getattr(op, overload) for overload in op.overloads()]
|
32
|
+
else:
|
33
|
+
return [op]
|
34
|
+
|
35
|
+
|
36
|
+
def pre_convert_decomp():
|
37
|
+
"""Decompositions to be run after torch.export before conversion and any passes."""
|
38
|
+
return _pre_convert_decomp.copy()
|
39
|
+
|
40
|
+
|
41
|
+
def pre_lower_decomp():
|
42
|
+
"""Decompositions to be run after conversion before odml_torch passes and lowerings."""
|
43
|
+
return _pre_lower_decomp.copy()
|
44
|
+
|
45
|
+
|
46
|
+
def remove_pre_lower_decomp(op):
|
47
|
+
# Also remove from pre_convert_decomp which always runs before pre_lower_
|
48
|
+
# decomp.
|
49
|
+
remove_pre_convert_decomp(op)
|
50
|
+
|
51
|
+
for op_ in _get_ops(op):
|
52
|
+
_pre_lower_decomp.pop(op_, None)
|
53
|
+
|
54
|
+
|
55
|
+
def remove_pre_convert_decomp(op):
|
56
|
+
for op_ in _get_ops(op):
|
57
|
+
_pre_convert_decomp.pop(op_, None)
|
58
|
+
|
59
|
+
|
60
|
+
def add_pre_convert_decomp(op, decomp):
|
61
|
+
# Also add decomp to pre_lower_decomp which runs after pre_convert_decomp.
|
62
|
+
add_pre_lower_decomp(op, decomp)
|
63
|
+
|
64
|
+
for op_ in _get_ops(op):
|
65
|
+
_pre_convert_decomp[op_] = decomp
|
66
|
+
|
67
|
+
|
68
|
+
def add_pre_lower_decomp(op, decomp):
|
69
|
+
for op_ in _get_ops(op):
|
70
|
+
_pre_lower_decomp[op_] = decomp
|
71
|
+
|
72
|
+
|
73
|
+
def update_pre_convert_decomp(decomps):
|
74
|
+
for op, decomp in decomps.items():
|
75
|
+
add_pre_convert_decomp(op, decomp)
|
76
|
+
|
77
|
+
|
78
|
+
def update_pre_lower_decomp(decomps):
|
79
|
+
for op, decomp in decomps.items():
|
80
|
+
add_pre_lower_decomp(op, decomp)
|
@@ -0,0 +1,42 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""FX graph utilities."""
|
16
|
+
import torch
|
17
|
+
|
18
|
+
|
19
|
+
def remove_dangling_args(graph_module: torch.fx.GraphModule):
|
20
|
+
"""Removes dangling args from the graph."""
|
21
|
+
for node in graph_module.graph.nodes:
|
22
|
+
if node.op == "placeholder" and not node.users:
|
23
|
+
graph_module.graph.erase_node(node)
|
24
|
+
|
25
|
+
graph_module.graph.lint()
|
26
|
+
graph_module.recompile()
|
27
|
+
return graph_module
|
28
|
+
|
29
|
+
|
30
|
+
def remove_assert_tensor_metadata_nodes(graph_module: torch.fx.GraphModule):
|
31
|
+
"""Removes aten._assert_tensor_metadata nodes from the graph.
|
32
|
+
|
33
|
+
This op is inserted by torch.export to check tensor metadata on custom ops. It
|
34
|
+
can break pattern matching and lowering.
|
35
|
+
"""
|
36
|
+
for node in graph_module.graph.nodes:
|
37
|
+
if node.target == torch.ops.aten._assert_tensor_metadata.default:
|
38
|
+
graph_module.graph.erase_node(node)
|
39
|
+
|
40
|
+
graph_module.graph.lint()
|
41
|
+
graph_module.recompile()
|
42
|
+
return graph_module
|
@@ -80,31 +80,3 @@ def run_passes(
|
|
80
80
|
constants=exported_program.constants,
|
81
81
|
)
|
82
82
|
return exported_program
|
83
|
-
|
84
|
-
|
85
|
-
class CanonicalizePass(ExportedProgramPassBase):
|
86
|
-
|
87
|
-
# A dummy decomp table for running ExportedProgram.run_decompositions without
|
88
|
-
# any op decompositions but just aot_export_module. Due to the check in
|
89
|
-
# run_decompositions, if None or an empty dict is passed as decomp_table,
|
90
|
-
# it will run the default aten-coreaten decompositions. Therefore a non-empty
|
91
|
-
# dummy decomp table is needed.
|
92
|
-
# Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
|
93
|
-
_DUMMY_DECOMP_TABLE = {
|
94
|
-
torch._ops.OperatorBase(): lambda: None,
|
95
|
-
}
|
96
|
-
|
97
|
-
def call(self, exported_program: torch.export.ExportedProgram):
|
98
|
-
for node in exported_program.graph.nodes:
|
99
|
-
if node.target == torch.ops.aten.view.default:
|
100
|
-
# Passes or torch.export may generate aten.view nodes not respecting the
|
101
|
-
# tensor memory format. Changes all the aten.view to torch.reshape
|
102
|
-
# for retracing. If the input memory format is already contiguous,
|
103
|
-
# retracing in run_decomposition below would decompose torch.reshape
|
104
|
-
# back to one aten.view.
|
105
|
-
node.target = lambda self, size: torch.reshape(self.contiguous(), size)
|
106
|
-
|
107
|
-
exported_program = exported_program.run_decompositions(
|
108
|
-
self._DUMMY_DECOMP_TABLE
|
109
|
-
)
|
110
|
-
return ExportedProgramPassResult(exported_program, True)
|
@@ -12,8 +12,8 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from ai_edge_torch import
|
16
|
-
from ai_edge_torch.
|
15
|
+
from ai_edge_torch import fx_infra
|
16
|
+
from ai_edge_torch.fx_infra import CanonicalizePass
|
17
17
|
from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass
|
18
18
|
import torch
|
19
19
|
|
@@ -21,7 +21,7 @@ import torch
|
|
21
21
|
def run_generative_passes(
|
22
22
|
exported_program: torch.export.ExportedProgram,
|
23
23
|
) -> torch.export.ExportedProgram:
|
24
|
-
return
|
24
|
+
return fx_infra.run_passes(
|
25
25
|
exported_program,
|
26
26
|
[
|
27
27
|
RemoveSDPACompositeZeroMaskPass(),
|
@@ -12,12 +12,12 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from ai_edge_torch import
|
15
|
+
from ai_edge_torch import fx_infra
|
16
16
|
from ai_edge_torch import lowertools
|
17
17
|
import torch
|
18
18
|
|
19
19
|
|
20
|
-
class RemoveSDPACompositeZeroMaskPass(
|
20
|
+
class RemoveSDPACompositeZeroMaskPass(fx_infra.ExportedProgramPassBase):
|
21
21
|
|
22
22
|
def is_zero_tensor_node(self, node: torch.fx.Node):
|
23
23
|
return node.target == torch.ops.aten.zeros.default
|
@@ -47,4 +47,4 @@ class RemoveSDPACompositeZeroMaskPass(fx_pass_base.ExportedProgramPassBase):
|
|
47
47
|
|
48
48
|
exported_program.graph_module.graph.lint()
|
49
49
|
exported_program.graph_module.recompile()
|
50
|
-
return
|
50
|
+
return fx_infra.ExportedProgramPassResult(exported_program, True)
|
@@ -14,8 +14,13 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""FX graph utilities for pattern matching clean ups."""
|
16
16
|
|
17
|
+
from ai_edge_torch import fx_infra
|
17
18
|
import torch
|
18
19
|
|
20
|
+
remove_dangling_args = fx_infra.graph_utils.remove_dangling_args
|
21
|
+
remove_assert_tensor_metadata_nodes = (
|
22
|
+
fx_infra.graph_utils.remove_assert_tensor_metadata_nodes
|
23
|
+
)
|
19
24
|
|
20
25
|
def is_clone_op(node: torch.fx.Node) -> bool:
|
21
26
|
"""Checks if the node is a clone op."""
|
@@ -46,24 +51,3 @@ def remove_clone_ops(gm: torch.fx.GraphModule):
|
|
46
51
|
gm.graph.lint()
|
47
52
|
gm.recompile()
|
48
53
|
return gm
|
49
|
-
|
50
|
-
|
51
|
-
def remove_dangling_args(gm: torch.fx.GraphModule):
|
52
|
-
"""Removes dangling args from the graph.
|
53
|
-
|
54
|
-
Args:
|
55
|
-
gm: The graph module to remove dangling args from.
|
56
|
-
|
57
|
-
Returns:
|
58
|
-
The graph module with dangling args removed.
|
59
|
-
"""
|
60
|
-
nodes_to_erase = []
|
61
|
-
for node in gm.graph.nodes:
|
62
|
-
if node.op == "placeholder" and len(node.users) == 0:
|
63
|
-
nodes_to_erase.append(node)
|
64
|
-
for node in nodes_to_erase:
|
65
|
-
gm.graph.erase_node(node)
|
66
|
-
|
67
|
-
gm.graph.lint()
|
68
|
-
gm.recompile()
|
69
|
-
return gm
|
@@ -17,7 +17,7 @@
|
|
17
17
|
import dataclasses
|
18
18
|
from typing import Any, Callable, Optional, Union
|
19
19
|
|
20
|
-
from ai_edge_torch import
|
20
|
+
from ai_edge_torch import fx_infra
|
21
21
|
from ai_edge_torch.hlfb.mark_pattern import fx_utils
|
22
22
|
import torch
|
23
23
|
|
@@ -118,7 +118,7 @@ def _find_scalar_attr(
|
|
118
118
|
track_args[tracker.pattern_arg_pos] = source
|
119
119
|
ep = torch.export.export(pattern_module, tuple(track_args))
|
120
120
|
if decomp_table is not None:
|
121
|
-
ep =
|
121
|
+
ep = fx_infra.run_passes(ep, [fx_infra.CanonicalizePass()])
|
122
122
|
ep = ep.run_decompositions(decomp_table)
|
123
123
|
|
124
124
|
scalar_locs = set()
|
@@ -152,13 +152,15 @@ class Pattern:
|
|
152
152
|
self,
|
153
153
|
name: str,
|
154
154
|
module: Union[Callable, torch.nn.Module],
|
155
|
-
export_args: tuple[Any],
|
155
|
+
export_args: tuple[Any, ...],
|
156
156
|
*,
|
157
157
|
attr_builder: Callable[
|
158
158
|
["Pattern", GraphModule, InternalMatch], Optional[dict[str, Any]]
|
159
159
|
] = None,
|
160
160
|
scalar_attr_trackers: list[ScalarAttrTracker] = None,
|
161
|
-
|
161
|
+
extra_decomp_table: Optional[
|
162
|
+
dict[torch._ops.OperatorBase, Callable]
|
163
|
+
] = None,
|
162
164
|
):
|
163
165
|
"""The PyTorch computation pattern to match against a model.
|
164
166
|
|
@@ -177,8 +179,9 @@ class Pattern:
|
|
177
179
|
scalar_attr_trackers (list[ScalarAttrTracker]): the trackers for scalar
|
178
180
|
args in `export_args`, which are used to track the attr occurrence(s)
|
179
181
|
and retrieve their values from the matched subgraph.
|
180
|
-
|
181
|
-
decomposition
|
182
|
+
extra_decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]):
|
183
|
+
Extra decomposition to be run on the pattern's exported program (in
|
184
|
+
addition to the default pre_convert_decomp in fx_infra.decomp).
|
182
185
|
"""
|
183
186
|
if not isinstance(module, torch.nn.Module):
|
184
187
|
|
@@ -200,11 +203,14 @@ class Pattern:
|
|
200
203
|
)
|
201
204
|
|
202
205
|
exported_program = torch.export.export(module, export_args)
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
)
|
207
|
-
|
206
|
+
|
207
|
+
decomp_table = fx_infra.decomp.pre_convert_decomp()
|
208
|
+
if extra_decomp_table is not None:
|
209
|
+
decomp_table.update(extra_decomp_table)
|
210
|
+
|
211
|
+
exported_program = fx_infra.safe_run_decompositions(
|
212
|
+
exported_program, decomp_table
|
213
|
+
)
|
208
214
|
|
209
215
|
self.exported_program = exported_program
|
210
216
|
self.graph_module = self.exported_program.graph_module
|
@@ -222,6 +228,9 @@ class Pattern:
|
|
222
228
|
# sanitization.
|
223
229
|
self.graph_module = fx_utils.remove_clone_ops(self.graph_module)
|
224
230
|
self.graph_module = fx_utils.remove_dangling_args(self.graph_module)
|
231
|
+
self.graph_module = fx_utils.remove_assert_tensor_metadata_nodes(
|
232
|
+
self.graph_module
|
233
|
+
)
|
225
234
|
|
226
235
|
# Builds list of ordered input and output nodes.
|
227
236
|
self.graph_nodes_map = {}
|
@@ -14,6 +14,7 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Tests for mark_pattern."""
|
16
16
|
|
17
|
+
from ai_edge_torch import fx_infra
|
17
18
|
from ai_edge_torch import lowertools
|
18
19
|
from ai_edge_torch.hlfb import mark_pattern
|
19
20
|
from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
|
@@ -22,11 +23,13 @@ import torch
|
|
22
23
|
from absl.testing import absltest as googletest
|
23
24
|
|
24
25
|
|
25
|
-
def
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
26
|
+
def _export_and_decomp(mod, args):
|
27
|
+
ep = torch.export.export(mod, args)
|
28
|
+
ep = fx_infra.safe_run_decompositions(ep, fx_infra.decomp.pre_lower_decomp())
|
29
|
+
return ep
|
30
|
+
|
31
|
+
|
32
|
+
def _to_mlir(ep: torch.export.ExportedProgram):
|
30
33
|
return lowertools.exported_program_to_mlir_text(ep)
|
31
34
|
|
32
35
|
|
@@ -47,9 +50,9 @@ class TestMarkPattern(googletest.TestCase):
|
|
47
50
|
|
48
51
|
model = TestModel().eval()
|
49
52
|
args = (torch.rand(20, 20),)
|
50
|
-
exported_program =
|
53
|
+
exported_program = _export_and_decomp(model, args)
|
51
54
|
mark_pattern.mark_pattern(exported_program.graph_module, pattern)
|
52
|
-
mlir =
|
55
|
+
mlir = _to_mlir(exported_program)
|
53
56
|
|
54
57
|
lowertools.assert_string_count(
|
55
58
|
self,
|
@@ -73,9 +76,9 @@ class TestMarkPattern(googletest.TestCase):
|
|
73
76
|
|
74
77
|
model = TestModel().eval()
|
75
78
|
args = (torch.rand(20, 20),)
|
76
|
-
exported_program =
|
79
|
+
exported_program = _export_and_decomp(model, args)
|
77
80
|
mark_pattern.mark_pattern(exported_program.graph_module, pattern)
|
78
|
-
mlir =
|
81
|
+
mlir = _to_mlir(exported_program)
|
79
82
|
|
80
83
|
lowertools.assert_string_count(
|
81
84
|
self,
|
@@ -99,9 +102,9 @@ class TestMarkPattern(googletest.TestCase):
|
|
99
102
|
|
100
103
|
model = TestModel().eval()
|
101
104
|
args = (torch.rand(20, 20),)
|
102
|
-
exported_program =
|
105
|
+
exported_program = _export_and_decomp(model, args)
|
103
106
|
mark_pattern.mark_pattern(exported_program.graph_module, pattern)
|
104
|
-
mlir =
|
107
|
+
mlir = _to_mlir(exported_program)
|
105
108
|
|
106
109
|
lowertools.assert_string_count(
|
107
110
|
self,
|
@@ -137,9 +140,9 @@ class TestMarkPattern(googletest.TestCase):
|
|
137
140
|
|
138
141
|
model = TestModel().eval()
|
139
142
|
args = (torch.rand(10, 10),)
|
140
|
-
exported_program =
|
143
|
+
exported_program = _export_and_decomp(model, args)
|
141
144
|
mark_pattern.mark_pattern(exported_program.graph_module, pattern)
|
142
|
-
mlir =
|
145
|
+
mlir = _to_mlir(exported_program)
|
143
146
|
|
144
147
|
lowertools.assert_string_count(
|
145
148
|
self,
|
@@ -169,9 +172,9 @@ class TestMarkPattern(googletest.TestCase):
|
|
169
172
|
|
170
173
|
model = TestModel().eval()
|
171
174
|
args = (torch.rand(20, 20), torch.rand(20, 20))
|
172
|
-
exported_program =
|
175
|
+
exported_program = _export_and_decomp(model, args)
|
173
176
|
mark_pattern.mark_pattern(exported_program.graph_module, pattern)
|
174
|
-
mlir =
|
177
|
+
mlir = _to_mlir(exported_program)
|
175
178
|
|
176
179
|
lowertools.assert_string_count(
|
177
180
|
self,
|
@@ -59,30 +59,3 @@ def graph_module_flat_inputs(ep: torch.export.ExportedProgram, args, kwargs):
|
|
59
59
|
ordered_tensor_constants = tuple()
|
60
60
|
|
61
61
|
return (*param_buffer_values, *flat_args, *ordered_tensor_constants)
|
62
|
-
|
63
|
-
|
64
|
-
# TODO(b/331481564): Replace this with CanonicalizePass + run_decomposition
|
65
|
-
def safe_run_decompositions(exported_program, decomp_table=None):
|
66
|
-
for node in exported_program.graph.nodes:
|
67
|
-
if node.target == torch.ops.aten.view.default:
|
68
|
-
# Passes or torch.export may generate aten.view nodes not respecting the
|
69
|
-
# tensor memory format. Changes all the aten.view to torch.reshape
|
70
|
-
# for retracing. If the input memory format is already contiguous,
|
71
|
-
# retracing in run_decomposition below would decompose torch.reshape
|
72
|
-
# back to one aten.view.
|
73
|
-
node.target = lambda self, size: torch.reshape(self.contiguous(), size)
|
74
|
-
|
75
|
-
return exported_program.run_decompositions(decomp_table)
|
76
|
-
|
77
|
-
|
78
|
-
def dummy_decomp_table():
|
79
|
-
"""Build dummy decomp table for run_decompositions without any decompositions.
|
80
|
-
|
81
|
-
Compatible for torch<=2.5.
|
82
|
-
|
83
|
-
Returns:
|
84
|
-
Decomp table for ExportedProgram.run_decompositions.
|
85
|
-
"""
|
86
|
-
return {
|
87
|
-
torch._ops.OperatorBase(): lambda: None,
|
88
|
-
}
|
@@ -20,6 +20,7 @@ import io
|
|
20
20
|
import operator
|
21
21
|
from typing import Any, Callable, Optional
|
22
22
|
|
23
|
+
from ai_edge_torch import fx_infra
|
23
24
|
from jax.lib import xla_extension
|
24
25
|
from jax._src.lib.mlir import ir
|
25
26
|
from jax._src.lib.mlir.dialects import func
|
@@ -302,16 +303,13 @@ def exported_program_to_mlir(
|
|
302
303
|
exported_program: torch.export.ExportedProgram,
|
303
304
|
) -> MlirLowered:
|
304
305
|
"""Lower the exported program to MLIR."""
|
305
|
-
exported_program =
|
306
|
-
exported_program,
|
306
|
+
exported_program = fx_infra.safe_run_decompositions(
|
307
|
+
exported_program,
|
308
|
+
fx_infra.decomp.pre_lower_decomp(),
|
307
309
|
)
|
308
|
-
|
309
310
|
_convert_i64_to_i32(exported_program)
|
310
|
-
|
311
|
-
|
312
|
-
exported_program = _torch_future.safe_run_decompositions(
|
313
|
-
exported_program, _torch_future.dummy_decomp_table()
|
314
|
-
)
|
311
|
+
# Run decompositions for retracing and canonicalization.
|
312
|
+
exported_program = fx_infra.safe_run_decompositions(exported_program, {})
|
315
313
|
|
316
314
|
# Passes below mutate the exported program to a state not executable by torch.
|
317
315
|
# Do not call run_decompositions after applying the passes.
|
@@ -15,6 +15,7 @@
|
|
15
15
|
from . import _basic
|
16
16
|
from . import _batch_norm
|
17
17
|
from . import _convolution
|
18
|
+
from . import _decomp_registry
|
18
19
|
from . import _jax_lowerings
|
19
20
|
from . import _layer_norm
|
20
21
|
from . import _quantized_decomposed
|
@@ -22,6 +23,5 @@ from . import _rand
|
|
22
23
|
from . import context
|
23
24
|
from . import registry
|
24
25
|
from . import utils
|
25
|
-
from .decomp import decompositions
|
26
26
|
from .registry import lookup
|
27
27
|
from .registry import lower
|
@@ -0,0 +1,65 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Torch export decompositions to run before lowering."""
|
16
|
+
|
17
|
+
from ai_edge_torch import fx_infra
|
18
|
+
import torch
|
19
|
+
|
20
|
+
|
21
|
+
fx_infra.decomp.update_pre_lower_decomp(
|
22
|
+
torch._decomp.get_decompositions([
|
23
|
+
torch.ops.aten.upsample_nearest2d,
|
24
|
+
torch.ops.aten._native_batch_norm_legit.no_stats,
|
25
|
+
torch.ops.aten._native_batch_norm_legit_functional,
|
26
|
+
torch.ops.aten._adaptive_avg_pool2d,
|
27
|
+
torch.ops.aten._adaptive_avg_pool3d,
|
28
|
+
torch.ops.aten.grid_sampler_2d,
|
29
|
+
torch.ops.aten.native_group_norm,
|
30
|
+
torch.ops.aten.native_dropout,
|
31
|
+
torch.ops.aten.reflection_pad1d,
|
32
|
+
torch.ops.aten.reflection_pad2d,
|
33
|
+
torch.ops.aten.reflection_pad3d,
|
34
|
+
torch.ops.aten.replication_pad1d,
|
35
|
+
torch.ops.aten.replication_pad2d,
|
36
|
+
torch.ops.aten.replication_pad3d,
|
37
|
+
torch.ops.aten.addmm,
|
38
|
+
])
|
39
|
+
)
|
40
|
+
|
41
|
+
fx_infra.decomp.remove_pre_lower_decomp(torch.ops.aten.roll)
|
42
|
+
|
43
|
+
# Torch's default einsum impl/decompositions is less efficient and
|
44
|
+
# optimized through converter than JAX's impl. Disable einsum
|
45
|
+
# decomposition to use JAX bridge for a more efficient lowering.
|
46
|
+
fx_infra.decomp.remove_pre_lower_decomp(torch.ops.aten.einsum.default)
|
47
|
+
|
48
|
+
|
49
|
+
# Override noop aten op decompositions for faster run_decompositions.
|
50
|
+
fx_infra.decomp.add_pre_convert_decomp(
|
51
|
+
torch.ops.aten.alias.default, lambda x: x
|
52
|
+
)
|
53
|
+
fx_infra.decomp.add_pre_convert_decomp(
|
54
|
+
torch.ops.aten.detach.default, lambda x: x
|
55
|
+
)
|
56
|
+
|
57
|
+
# Override _safe_softmax decompositions with regular softmax.
|
58
|
+
# _safe_softmax introduces additional check-select ops to guard extreme
|
59
|
+
# input values to softmax, which could make the converted model inefficient
|
60
|
+
# on-device.
|
61
|
+
if hasattr(torch.ops.aten, "_safe_softmax"):
|
62
|
+
fx_infra.decomp.add_pre_convert_decomp(
|
63
|
+
torch.ops.aten._safe_softmax.default,
|
64
|
+
torch.softmax,
|
65
|
+
)
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20250123
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -1,32 +1,32 @@
|
|
1
|
-
ai_edge_torch/__init__.py,sha256=
|
1
|
+
ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,1208
|
2
2
|
ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
|
-
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=szrxg2aB7mcm59IL_QVIqapmbw9Nz8AQ28vc9684bqY,706
|
7
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
|
-
ai_edge_torch/_convert/conversion.py,sha256=
|
7
|
+
ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
|
9
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
10
9
|
ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
|
11
10
|
ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
|
12
11
|
ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
|
13
|
-
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=
|
14
|
-
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=
|
15
|
-
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=
|
16
|
-
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=
|
17
|
-
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=
|
18
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=
|
12
|
+
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=dG4WIICk0FqCH9euvbYHHsybRN7B1cYcuxN_OYxmjWo,1263
|
13
|
+
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=a1KhqLetFb_efRHjX4T-zH0vF-U37Ha5I1CPIAsIluE,9211
|
14
|
+
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=3JyjiHpn17Zhfq3yGQXK5LMH71DQPXHb_4GOkP9uAjY,4251
|
15
|
+
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
|
16
|
+
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
|
17
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=UKC-wM93-oe8spxyFqgybJ0TwnSRw8f-SOA2glCh2FA,890
|
18
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py,sha256=aWO_zHDF4j_hokoKJQNFIFmua4ysXztsgS6pcyBUht0,1082
|
19
19
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=S_Bniv6jY16oOoFUzlyECQ0I2HDjG2D1MOI-QYPk3jQ,8061
|
20
20
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
|
21
21
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=zoAZ2TXKvxUnWnT11U4tx2uF0J5kkNXydgaW7JzfkXI,13811
|
22
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=
|
23
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=
|
22
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=OhisegHY2j4cv_m9auCh9Mq9qmm1lUqpFLVO9X-oBlc,1032
|
23
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=lgoH32l6zAbWTCpa_4-RWkHjqbNaPsBnhSObLIX8dL4,10551
|
24
24
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
|
25
25
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=D8VX8SbCzfoyvPgMFHK7yxD7R-bzLxp2gfdKxgrWekA,742
|
26
26
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
|
27
27
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
|
28
28
|
ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
29
|
-
ai_edge_torch/_convert/test/test_convert.py,sha256=
|
29
|
+
ai_edge_torch/_convert/test/test_convert.py,sha256=o6tuJkD-ESaQxLxJpN104qpchm3LCtPmHinzQxe6PSg,17226
|
30
30
|
ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
|
31
31
|
ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
|
32
32
|
ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
|
@@ -37,6 +37,12 @@ ai_edge_torch/debug/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf
|
|
37
37
|
ai_edge_torch/debug/test/test_culprit.py,sha256=fRN-8jJicawJ2mhPRQNAQUZ8AdGg-s0tYMXyhnLAlWw,3875
|
38
38
|
ai_edge_torch/debug/test/test_search_model.py,sha256=-RuU0QsjqkfzZF2IbeA55MoeVOawhbgiSEu96PmioPE,1668
|
39
39
|
ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
40
|
+
ai_edge_torch/fx_infra/__init__.py,sha256=APjkSqEfwDxcnI8k53rGi3Ef-G2L-M8fdaPGpxXtuiI,1347
|
41
|
+
ai_edge_torch/fx_infra/_canonicalize_pass.py,sha256=GDRoDdPVQw--QQFTT5J_C3TVuphL31m6K6F1-67SE4s,1097
|
42
|
+
ai_edge_torch/fx_infra/_safe_run_decompositions.py,sha256=ZbWheeZ8ydsxCk2aVGUgUynrkEkBOMjBCzPhS5uq4sU,2595
|
43
|
+
ai_edge_torch/fx_infra/decomp.py,sha256=S58SCgwMHYVFl_hJwlJxvu2wcI-AGNn82gel3qmTPrU,2500
|
44
|
+
ai_edge_torch/fx_infra/graph_utils.py,sha256=3UZAOHWOUh2LCj1E2_AKQn3gRDILi9JCdqSScjyOd4M,1535
|
45
|
+
ai_edge_torch/fx_infra/pass_base.py,sha256=Ic2AlhSoRFscz6l7gJKvWVNMDLQFfAw5kRf84-ZR9qM,2904
|
40
46
|
ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
41
47
|
ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
42
48
|
ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -115,8 +121,8 @@ ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6
|
|
115
121
|
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=VU0c5pgvrUtaTboT1xuDBGjpKOM85aqtaB_hYfSBuEk,2544
|
116
122
|
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
|
117
123
|
ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
|
118
|
-
ai_edge_torch/generative/fx_passes/__init__.py,sha256=
|
119
|
-
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=
|
124
|
+
ai_edge_torch/generative/fx_passes/__init__.py,sha256=4rFrppMRKlTwwZeX1ON_cdp4yUqoTOES161IZQkJF6c,1143
|
125
|
+
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=n5TbXdhBZi8jQe4j7-rox_MugMVvW8ReOhkTA3pfQkw,1919
|
120
126
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
121
127
|
ai_edge_torch/generative/layers/attention.py,sha256=GrAy8CT1pEsgRoB8JQP6PlnNYk8kQ4U3YANfSiTJKn8,13776
|
122
128
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
@@ -159,11 +165,11 @@ ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yG
|
|
159
165
|
ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
|
160
166
|
ai_edge_torch/generative/utilities/verifier.py,sha256=6lnBU9Cy5GanB8JWK3-2_VU3PxqunDWGe-SgSLba5Yw,12065
|
161
167
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
162
|
-
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256
|
163
|
-
ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=
|
164
|
-
ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=
|
168
|
+
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=JsVmYrM_JEuN_smMHXUsRlo3Liapp7UyktbPpPARwDk,5386
|
169
|
+
ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=YCtMgu-4w2BQ5fpnlpWC6IauKPf_tVqc7Ff91OTqlSw,1796
|
170
|
+
ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=Ui6BrehF3zJJN7uTxKwbO2yCY9mYjbewlQzAxzZv9Es,10274
|
165
171
|
ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
166
|
-
ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256
|
172
|
+
ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=5kmOJWCc7sU1Hrqr1y17BtShUrssTfaV1sMyUvdMbsg,5573
|
167
173
|
ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
|
168
174
|
ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
|
169
175
|
ai_edge_torch/lowertools/common_utils.py,sha256=4HQtquPZ6oiId8vR_1ykW_uK4ELnyo5zo3MlX1QYW4c,4513
|
@@ -172,9 +178,9 @@ ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUG
|
|
172
178
|
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
|
173
179
|
ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
|
174
180
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
175
|
-
ai_edge_torch/odml_torch/_torch_future.py,sha256=
|
181
|
+
ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
|
176
182
|
ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
|
177
|
-
ai_edge_torch/odml_torch/export.py,sha256=
|
183
|
+
ai_edge_torch/odml_torch/export.py,sha256=YN7QPrQ8W6T3YVOdyIGadfSQuBroMjIqAMB9FeUa7Ho,13447
|
178
184
|
ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
|
179
185
|
ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
|
180
186
|
ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
|
@@ -186,16 +192,16 @@ ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW
|
|
186
192
|
ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
|
187
193
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
|
188
194
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
189
|
-
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=
|
195
|
+
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
|
190
196
|
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=8mZTp_ybcMO3tDRQdlDP68BVeTw560XsTR4XH-ldTdc,9987
|
191
197
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
192
198
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
199
|
+
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
|
193
200
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=CJHWkmY4aAVQ5dmFsVc3Ox9TPkoLSNOfa96psD4CLRo,11561
|
194
201
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
195
202
|
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
|
196
203
|
ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
|
197
204
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
198
|
-
ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=0dIkVNN9ubgfwSxeaBFfTmWarWWG8Q1M0HX_1FShHik,2610
|
199
205
|
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
|
200
206
|
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZmWcAHl0nTDgyI2g,6167
|
201
207
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
@@ -206,8 +212,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
206
212
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
207
213
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
208
214
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
209
|
-
ai_edge_torch_nightly-0.3.0.
|
210
|
-
ai_edge_torch_nightly-0.3.0.
|
211
|
-
ai_edge_torch_nightly-0.3.0.
|
212
|
-
ai_edge_torch_nightly-0.3.0.
|
213
|
-
ai_edge_torch_nightly-0.3.0.
|
215
|
+
ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
216
|
+
ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/METADATA,sha256=1IZCBOcKVCWbEfAQvEMgt39cuATDIzpK6AhW_gTnIY4,1966
|
217
|
+
ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
218
|
+
ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
219
|
+
ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/RECORD,,
|
@@ -1,69 +0,0 @@
|
|
1
|
-
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
"""Torch export decompositions to run before lowering."""
|
16
|
-
|
17
|
-
import functools
|
18
|
-
|
19
|
-
import torch
|
20
|
-
|
21
|
-
|
22
|
-
@functools.cache
|
23
|
-
def decompositions():
|
24
|
-
# Base: Core ATen decompositions
|
25
|
-
decompositions = torch._decomp.core_aten_decompositions()
|
26
|
-
|
27
|
-
decompositions.update(
|
28
|
-
torch._decomp.get_decompositions([
|
29
|
-
torch.ops.aten.upsample_nearest2d,
|
30
|
-
torch.ops.aten._native_batch_norm_legit.no_stats,
|
31
|
-
torch.ops.aten._native_batch_norm_legit_functional,
|
32
|
-
torch.ops.aten._adaptive_avg_pool2d,
|
33
|
-
torch.ops.aten._adaptive_avg_pool3d,
|
34
|
-
torch.ops.aten.grid_sampler_2d,
|
35
|
-
torch.ops.aten.native_group_norm,
|
36
|
-
torch.ops.aten.native_dropout,
|
37
|
-
torch.ops.aten.reflection_pad1d,
|
38
|
-
torch.ops.aten.reflection_pad2d,
|
39
|
-
torch.ops.aten.reflection_pad3d,
|
40
|
-
torch.ops.aten.replication_pad1d,
|
41
|
-
torch.ops.aten.replication_pad2d,
|
42
|
-
torch.ops.aten.replication_pad3d,
|
43
|
-
torch.ops.aten.addmm,
|
44
|
-
])
|
45
|
-
)
|
46
|
-
|
47
|
-
torch._decomp.remove_decompositions(
|
48
|
-
decompositions,
|
49
|
-
[
|
50
|
-
torch.ops.aten.roll,
|
51
|
-
# Torch's default einsum impl/decompositions is less efficient and
|
52
|
-
# optimized through converter than JAX's impl. Disable einsum
|
53
|
-
# decomposition to use JAX bridge for a more efficient lowering.
|
54
|
-
torch.ops.aten.einsum.default,
|
55
|
-
],
|
56
|
-
)
|
57
|
-
|
58
|
-
# Override noop aten op decompositions for faster run_decompositions.
|
59
|
-
decompositions[torch.ops.aten.alias.default] = lambda x: x
|
60
|
-
decompositions[torch.ops.aten.detach.default] = lambda x: x
|
61
|
-
|
62
|
-
# Override _safe_softmax decompositions with regular softmax.
|
63
|
-
# _safe_softmax introduces additional check-select ops to guard extreme
|
64
|
-
# input values to softmax, which could make the converted model inefficient
|
65
|
-
# on-device.
|
66
|
-
if hasattr(torch.ops.aten, "_safe_softmax"):
|
67
|
-
decompositions[torch.ops.aten._safe_softmax.default] = torch.softmax
|
68
|
-
|
69
|
-
return decompositions
|
File without changes
|
File without changes
|