ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/__init__.py +30 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +330 -0
- ai_edge_torch/convert/converter.py +171 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
- ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +273 -0
- ai_edge_torch/convert/test/test_convert_composites.py +171 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/debug/__init__.py +16 -0
- ai_edge_torch/debug/culprit.py +423 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +288 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +103 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +135 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/test_model_conversion.py +201 -0
- ai_edge_torch/generative/test/test_quantize.py +109 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +290 -0
- ai_edge_torch/generative/utilities/t5_loader.py +467 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +134 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +85 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from typing import Sequence, Union
|
|
17
|
+
|
|
18
|
+
from torch.export import ExportedProgram
|
|
19
|
+
from torch.fx.passes.infra.pass_manager import pass_result_wrapper
|
|
20
|
+
import torch.utils._pytree as pytree
|
|
21
|
+
|
|
22
|
+
from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
|
|
23
|
+
from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
|
|
24
|
+
from ai_edge_torch.convert.fx_passes._pass_base import FxPassBase
|
|
25
|
+
from ai_edge_torch.convert.fx_passes._pass_base import FxPassResult
|
|
26
|
+
from ai_edge_torch.convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass # NOQA
|
|
27
|
+
from ai_edge_torch.convert.fx_passes.build_upsample_bilinear2d_composite_pass import BuildUpsampleBilinear2DCompositePass # NOQA
|
|
28
|
+
from ai_edge_torch.convert.fx_passes.canonicalize_pass import CanonicalizePass
|
|
29
|
+
from ai_edge_torch.convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass # NOQA
|
|
30
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass # NOQA
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# TODO(cnchan): make a PassManager class.
def run_passes(
    exported_program: ExportedProgram,
    passes: Sequence[Union[ExportedProgramPassBase, FxPassBase]],
) -> ExportedProgram:
    """Apply a (possibly nested) sequence of passes to an ExportedProgram.

    ``passes`` is flattened with ``torch.utils._pytree`` first. Each
    ``ExportedProgramPassBase`` receives the whole program and its result's
    ``exported_program`` replaces the current one. Any other pass is wrapped
    with ``pass_result_wrapper`` and run on the program's ``graph_module``. A
    new ``ExportedProgram`` is built around the returned graph module only when
    the pass reports a modification AND returns a different module object.
    Otherwise the existing program is kept.

    Args:
        exported_program: The program to transform.
        passes: Passes to run, in order.

    Returns:
        The transformed ``ExportedProgram``.
    """
    flat_passes, _ = pytree.tree_flatten(passes)
    for current_pass in flat_passes:
        if isinstance(current_pass, ExportedProgramPassBase):
            exported_program = current_pass(exported_program).exported_program
            continue

        fx_pass = pass_result_wrapper(current_pass)
        new_gm, modified = fx_pass(exported_program.graph_module)
        # In-place edits leave the program already pointing at the updated
        # module, so only a replaced module needs re-wrapping.
        if not modified or new_gm is exported_program.graph_module:
            continue

        exported_program = ExportedProgram(
            root=new_gm,
            graph=new_gm.graph,
            graph_signature=exported_program.graph_signature,
            state_dict=exported_program.state_dict,
            range_constraints=exported_program.range_constraints,
            module_call_graph=exported_program.module_call_graph,
            example_inputs=exported_program.example_inputs,
            verifier=exported_program.verifier,
            constants=exported_program.constants,
        )
    return exported_program
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import abc
|
|
17
|
+
from collections import namedtuple
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
from torch.export import ExportedProgram
|
|
21
|
+
from torch.fx.passes.infra.pass_base import PassBase as FxPassBase
|
|
22
|
+
from torch.fx.passes.infra.pass_base import PassResult as FxPassResult
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ExportedProgramPassResult(
    namedtuple("ExportedProgramPassResult", ("exported_program", "modified"))
):
    """Outcome of running an ``ExportedProgramPassBase``.

    A two-field tuple: ``exported_program`` is the program produced by the
    pass, and ``modified`` is the flag the pass reported.
    """

    def __new__(cls, exported_program, modified):
        instance = super().__new__(cls, exported_program, modified)
        return instance
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ExportedProgramPassBase(abc.ABC):
    """Base class for passes that operate on a whole ``ExportedProgram``.

    Subclasses implement ``call`` and may override the ``requires`` and
    ``ensures`` hooks, which do nothing by default. Calling a pass instance
    runs ``requires`` on the input program, then ``call``, then ``ensures`` on
    the program carried by the returned result.
    """

    def __call__(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
        """Run the pass together with its pre-/post-condition hooks.

        Args:
            exported_program: The program to transform.

        Returns:
            The ``ExportedProgramPassResult`` produced by ``call``.
        """
        self.requires(exported_program)
        res = self.call(exported_program)
        # ``call`` may return a different ExportedProgram than it received, so
        # post-conditions are checked on the program the pass produced, not on
        # its input.
        self.ensures(res.exported_program)
        return res

    @abc.abstractmethod
    def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
        """Transform ``exported_program`` and return the pass result."""
        pass

    def requires(self, exported_program: ExportedProgram) -> None:
        """Pre-condition hook run on the input before ``call``; no-op by default."""
        pass

    def ensures(self, exported_program: ExportedProgram) -> None:
        """Post-condition hook run on the program ``call`` returned; no-op by default."""
        pass
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import copy
|
|
17
|
+
import functools
|
|
18
|
+
from typing import Any, Callable
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
from torch.fx import GraphModule
|
|
22
|
+
from torch.fx import Node
|
|
23
|
+
from torch.fx.passes.infra.pass_base import PassBase
|
|
24
|
+
from torch.fx.passes.infra.pass_base import PassResult
|
|
25
|
+
import torch.utils._pytree as pytree
|
|
26
|
+
|
|
27
|
+
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
|
|
28
|
+
|
|
29
|
+
# Registry consulted by BuildAtenCompositePass: maps a node target (an aten
# op overload) to the builder that rewrites nodes with that target.
_composite_builders: dict[Callable, Callable[[GraphModule, Node], None]] = {}


def _register_composite_builder(op):
    """Return a decorator that registers a composite builder for ``op``.

    Args:
        op: An ``OpOverload``, registered as-is, or an ``OpOverloadPacket``.
            For a packet, every overload it exposes via ``overloads()`` is
            registered, so lookups by a concrete overload target succeed.

    Returns:
        A decorator that records the decorated function in
        ``_composite_builders`` and returns it unchanged.
    """

    def inner(func):
        if isinstance(op, torch._ops.OpOverloadPacket):
            for overload_name in op.overloads():
                _composite_builders[getattr(op, overload_name)] = func
        else:
            _composite_builders[op] = func
        return func

    return inner
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _tree_map_to_composite_attr_values(values, *, stringify_incompatible_values=True):
    """Map each pytree leaf of ``values`` to a composite-attribute-friendly value.

    Per leaf: ``None`` becomes the string ``"py_None"``; ``str``, ``int``,
    ``float`` and ``bool`` values are returned unchanged; anything else is
    replaced by ``str(leaf)`` when ``stringify_incompatible_values`` is true and
    returned unchanged otherwise. Container structure is preserved by
    ``pytree.tree_map``.
    """
    passthrough_types = (str, int, float, bool)

    def to_attr_value(leaf):
        if leaf is None:
            return "py_None"
        needs_stringify = stringify_incompatible_values and not isinstance(
            leaf, passthrough_types
        )
        return str(leaf) if needs_stringify else leaf

    return pytree.tree_map(to_attr_value, values)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class TorchOpArgumentsMapper:
    """Resolve a torch op call's positional/keyword arguments by schema name.

    Args:
        op: An ``OpOverload``, or an ``OpOverloadPacket`` whose ``default``
            overload is used instead. The resolved op must have a ``_schema``.
    """

    def __init__(self, op):
        resolved_op = op.default if isinstance(op, torch._ops.OpOverloadPacket) else op
        assert hasattr(resolved_op, "_schema")
        self.op = resolved_op
        # (name, default_value) for each schema argument, in schema order.
        self.arg_specs = [
            (argument.name, argument.default_value)
            for argument in resolved_op._schema.arguments
        ]

    def get_full_kwargs(self, args, kwargs=None) -> dict[str, Any]:
        """Merge ``args`` and ``kwargs`` into one dict keyed by schema argument name.

        Positional values are bound to schema arguments in order. Schema
        arguments after the positional ones that are absent from ``kwargs``
        receive the schema's ``default_value``.
        """
        merged = {**(kwargs or {})}
        for (arg_name, _unused_default), arg_value in zip(self.arg_specs, args):
            merged[arg_name] = arg_value
        for arg_name, default in self.arg_specs[len(args):]:
            merged.setdefault(arg_name, default)
        return merged
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@_register_composite_builder(torch.ops.aten.hardswish.default)
def _aten_hardswish(gm: GraphModule, node: Node):
    """Wrap an ``aten.hardswish.default`` node in a StableHLO composite.

    ``node.target`` is replaced in place by a function that marks the op's
    input and output with a ``StableHLOCompositeBuilder`` named
    ``"aten.hardswish.default"`` around a call to the original op.

    Args:
        gm: The graph module owning ``node``; not used by this builder.
        node: The FX node whose ``target`` is rewritten.
    """
    original_op = node.target

    def hardswish(self: torch.Tensor):
        composite = StableHLOCompositeBuilder("aten.hardswish.default")
        marked_input = composite.mark_inputs(self)
        return composite.mark_outputs(original_op(marked_input))

    node.target = hardswish
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@_register_composite_builder(torch.ops.aten.avg_pool2d.default)
def _aten_avg_pool2d(gm: GraphModule, node: Node):
    """Wrap supported ``aten.avg_pool2d.default`` calls in a StableHLO composite.

    ``node.target`` is replaced in place by a function that, at call time,
    resolves all op arguments by schema name. It then either emits a composite
    named ``"aten.avg_pool2d.default"`` carrying the pooling parameters as
    attributes, or calls the original op unwrapped when the configuration is
    one the composite path does not cover.

    Args:
        gm: The graph module owning ``node``; not used by this builder.
        node: The FX node whose ``target`` is rewritten.
    """
    op = node.target
    args_mapper = TorchOpArgumentsMapper(op)

    def is_same_padding(
        input_shape: list[int],
        kernel_size: list[int],
        stride: list[int],
        padding: list[int],
    ) -> bool:
        """Return True if ``padding`` equals symmetric SAME-style padding.

        For each spatial dim the target output size is ``ceil(input / stride)``.
        The total padding needed to reach it must be even, and half of it must
        equal the given per-side ``padding`` entry.
        """
        for dim_input_size, dim_kernel_size, dim_stride, dim_padding in zip(
            input_shape, kernel_size, stride, padding
        ):
            # Exact integer ceil division (no detour through float).
            dim_output_size = (dim_input_size + dim_stride - 1) // dim_stride
            padding_needed = max(
                0, (dim_output_size - 1) * dim_stride + dim_kernel_size - dim_input_size
            )
            # An odd total would require asymmetric padding, which a single
            # per-dim padding value cannot express.
            if padding_needed % 2 != 0:
                return False
            if padding_needed // 2 != dim_padding:
                return False
        return True

    def is_valid_padding(padding: list[int]) -> bool:
        """Return True when every padding entry is zero (VALID padding)."""
        return not any(padding)

    def avg_pool2d(*args, **kwargs):
        full_kwargs = args_mapper.get_full_kwargs(args, kwargs)

        # We prefer to avoid passing empty arrays to composite attributes
        # as they will be lowered to an ArrayAttr, so canonicalize an empty
        # stride to the op's default behaviour (stride == kernel_size).
        if not full_kwargs["stride"]:
            full_kwargs["stride"] = full_kwargs["kernel_size"]

        # Only wrap in a composite when the underlying converter can handle it.
        # TODO We should be able to remove this if the converter can inline
        # composites when it can not handle them.

        # We don't cover any cases where ceil_mode is True or divisor_override
        # is set.
        if full_kwargs["ceil_mode"] or full_kwargs["divisor_override"] is not None:
            return op(*args, **kwargs)

        # We also can not cover a case where count_include_pad is False but the
        # padding is custom (neither VALID nor symmetric SAME).
        if (
            not full_kwargs["count_include_pad"]
            and not is_valid_padding(full_kwargs["padding"])
            and not is_same_padding(
                # The spatial (H, W) dims are the last two for both batched
                # (N, C, H, W) and unbatched (C, H, W) inputs; shape[2:] would
                # select the wrong dims for a 3-D input.
                list(full_kwargs["self"].shape)[-2:],
                full_kwargs["kernel_size"],
                full_kwargs["stride"],
                full_kwargs["padding"],
            )
        ):
            return op(*args, **kwargs)

        builder = StableHLOCompositeBuilder(
            "aten.avg_pool2d.default",
            attr=_tree_map_to_composite_attr_values(
                {
                    "kernel_size": full_kwargs["kernel_size"],
                    "stride": full_kwargs["stride"],
                    "padding": full_kwargs["padding"],
                    "ceil_mode": full_kwargs["ceil_mode"],
                    "count_include_pad": full_kwargs["count_include_pad"],
                    "divisor_override": full_kwargs["divisor_override"],
                }
            ),
        )

        full_kwargs["self"] = builder.mark_inputs(full_kwargs["self"])
        output = op(**full_kwargs)
        return builder.mark_outputs(output)

    node.target = avg_pool2d
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class BuildAtenCompositePass(PassBase):
    """FX pass that applies every registered aten composite builder.

    Each node whose ``target`` has an entry in ``_composite_builders`` is
    handed to that builder together with the graph module. The graph is then
    linted and recompiled, and the result is always reported as modified.
    """

    def call(self, graph_module: GraphModule):
        for fx_node in graph_module.graph.nodes:
            if fx_node.target not in _composite_builders:
                continue
            rewrite = _composite_builders[fx_node.target]
            rewrite(graph_module, fx_node)

        graph = graph_module.graph
        graph.lint()
        graph_module.recompile()
        return PassResult(graph_module, True)
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import functools
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
|
|
20
|
+
from ai_edge_torch.convert.fx_passes import FxPassBase
|
|
21
|
+
from ai_edge_torch.convert.fx_passes import FxPassResult
|
|
22
|
+
from ai_edge_torch.hlfb import mark_pattern
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _build_upsample_bilinear2d_pattern(align_corners: bool):
    """Build the ``odml.upsample_bilinear2d`` pattern for one ``align_corners`` value.

    The pattern is traced from a 2x bilinear ``interpolate`` on a
    ``(1, 3, 100, 100)`` example input. Its attribute builder records the
    matched output's last two dims and the ``align_corners`` flag.
    """
    pattern = mark_pattern.Pattern(
        "odml.upsample_bilinear2d",
        lambda x: torch.nn.functional.interpolate(
            x, scale_factor=2, mode="bilinear", align_corners=align_corners
        ),
        export_args=(torch.rand(1, 3, 100, 100),),
    )

    @pattern.register_attr_builder
    def attr_builder(pattern, graph_module, internal_match):
        # Only ``internal_match`` is read; the other parameters are accepted
        # for the attr-builder calling convention.
        output = internal_match.returning_nodes[0]
        output_h, output_w = output.meta["val"].shape[-2:]
        return {
            "output": (int(output_h), int(output_w)),
            "align_corners": align_corners,
        }

    return pattern


@functools.cache
def _get_upsample_bilinear2d_pattern():
    """Return the cached upsample pattern with ``align_corners=False``."""
    return _build_upsample_bilinear2d_pattern(align_corners=False)


@functools.cache
def _get_upsample_bilinear2d_align_corners_pattern():
    """Return the cached upsample pattern with ``align_corners=True``."""
    return _build_upsample_bilinear2d_pattern(align_corners=True)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class BuildUpsampleBilinear2DCompositePass(FxPassBase):
    """FX pass that marks bilinear 2x upsample subgraphs as composites.

    Runs ``mark_pattern.mark_pattern`` with the ``align_corners=False`` pattern
    and then the ``align_corners=True`` pattern. The graph is then linted and
    recompiled, and the result is always reported as modified.
    """

    def __init__(self):
        super().__init__()
        self._patterns = [
            _get_upsample_bilinear2d_pattern(),
            _get_upsample_bilinear2d_align_corners_pattern(),
        ]

    def call(self, graph_module: torch.fx.GraphModule):
        marked_module = graph_module
        for upsample_pattern in self._patterns:
            marked_module = mark_pattern.mark_pattern(marked_module, upsample_pattern)

        marked_module.graph.lint()
        marked_module.recompile()
        return FxPassResult(marked_module, True)
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
from torch.export import ExportedProgram
|
|
18
|
+
|
|
19
|
+
from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
|
|
20
|
+
from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
|
|
21
|
+
|
|
22
|
+
# A dummy decomposition table for running ExportedProgram.run_decompositions
# without any op decompositions, i.e. only the aot_export_module re-trace.
# Because of the check in run_decompositions (see Ref below), passing None or
# an empty dict as decomp_table makes it run the default ATen-to-Core-ATen
# decompositions, so a non-empty table is needed. Its only key is a freshly
# constructed OperatorBase instance that no graph node can target, so the
# lambda is presumably never invoked.
# Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
_dummy_decomp_table = {
    torch._ops.OperatorBase(): lambda: None,
}
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class CanonicalizePass(ExportedProgramPassBase):
    """Re-run ``run_decompositions`` with the no-op ``_dummy_decomp_table``.

    This re-traces the program without applying any real op decompositions.
    The returned result always reports the program as modified.
    """

    def call(self, exported_program: ExportedProgram):
        retraced = exported_program.run_decompositions(_dummy_decomp_table)
        return ExportedProgramPassResult(exported_program=retraced, modified=True)
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
from torch.fx.passes.infra.pass_base import PassBase
|
|
18
|
+
from torch.fx.passes.infra.pass_base import PassResult
|
|
19
|
+
import torch.utils._pytree as pytree
|
|
20
|
+
import torch_xla.experimental.xla_mlir_debuginfo # Import required to register torch.ops.xla.write_mlir_debuginfo
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _get_mlir_debuginfo(node: torch.fx.Node):
    """Build the MLIR debug-info string for ``node`` from its module stack.

    Each ``(path, layer)`` entry of ``node.meta["nn_module_stack"]`` is rendered
    as the layer's name plus ``"_" + <last dotted component of path>`` (the
    suffix is omitted for an empty path). A string layer is used verbatim;
    otherwise the class's module-qualified name is used, with the module
    omitted for builtins. Entries are joined with ``/`` and terminated with
    ``;``. A node without a module stack yields ``";"``.
    """

    def qualified_name(klass):
        if klass.__module__ == "builtins":
            return klass.__qualname__
        return klass.__module__ + "." + klass.__qualname__

    def render_entry(path, layer):
        suffix = ("_" + path.split(".")[-1]) if path else ""
        layer_name = layer if isinstance(layer, str) else qualified_name(layer)
        return layer_name + suffix

    module_stack = node.meta.get("nn_module_stack", {})
    # TODO(yijieyang): Encode aten op and attrs.
    hierarchy = "/".join(render_entry(path, layer) for path, layer in module_stack.values())
    return hierarchy + ";"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.Node):
    """Wrap a ``call_function`` node's target so its tensor outputs carry debug info.

    Nodes with any other ``op`` are left untouched. For ``call_function``
    nodes, the debug-info string is computed once from the node via
    ``_get_mlir_debuginfo``. ``node.target`` is then replaced by a function
    that calls the original target and passes every ``torch.Tensor`` in its
    (possibly nested) outputs through ``torch.ops.xla.write_mlir_debuginfo``.

    Args:
        node: The FX node to rewrite in place (previously mis-annotated as a
            GraphModule).
    """
    if node.op != "call_function":
        return

    target = node.target
    debuginfo = _get_mlir_debuginfo(node)

    def debuginfo_writer(*args, **kwargs):
        outputs = target(*args, **kwargs)
        return pytree.tree_map_only(
            torch.Tensor,
            lambda x: torch.ops.xla.write_mlir_debuginfo(x, debuginfo),
            outputs,
        )

    node.target = debuginfo_writer
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class InjectMlirDebuginfoPass(PassBase):
    """FX pass that attaches MLIR debug info to every ``call_function`` node.

    Each node is passed to ``_wrap_call_function_node_with_debuginfo_writer``.
    The graph is then linted and recompiled, and the result is always reported
    as modified.
    """

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        for fx_node in graph.nodes:
            _wrap_call_function_node_with_debuginfo_writer(fx_node)

        graph.lint()
        graph_module.recompile()
        return PassResult(graph_module, True)
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass.pass_body import OptimizeLayoutTransposesPass # NOQA
|