ai-edge-torch-nightly 0.3.0.dev20240827__py3-none-any.whl → 0.3.0.dev20240829__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +6 -1
- ai_edge_torch/_convert/test/test_convert.py +1 -1
- ai_edge_torch/_convert/test/test_convert_composites.py +1 -1
- ai_edge_torch/_convert/test/test_convert_multisig.py +71 -31
- ai_edge_torch/_convert/test/test_to_channel_last_io.py +1 -1
- ai_edge_torch/debug/test/test_culprit.py +1 -1
- ai_edge_torch/debug/test/test_search_model.py +1 -1
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +43 -59
- ai_edge_torch/generative/test/test_experimental_ekv.py +1 -1
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +1 -1
- ai_edge_torch/generative/test/test_quantize.py +1 -1
- ai_edge_torch/hlfb/test/test_mark_pattern.py +1 -1
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +1 -1
- ai_edge_torch/lowertools/odml_torch_utils.py +5 -1
- ai_edge_torch/lowertools/test_utils.py +1 -1
- ai_edge_torch/odml_torch/__init__.py +20 -0
- ai_edge_torch/odml_torch/_torch_future.py +61 -0
- ai_edge_torch/odml_torch/_torch_library.py +19 -0
- ai_edge_torch/odml_torch/composite/__init__.py +16 -0
- ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
- ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
- ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
- ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
- ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
- ai_edge_torch/odml_torch/export.py +320 -0
- ai_edge_torch/odml_torch/export_utils.py +168 -0
- ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +152 -0
- ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +24 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +204 -0
- ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +119 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -0
- ai_edge_torch/odml_torch/lowerings/context.py +42 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +87 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
- ai_edge_torch/odml_torch/passes/__init__.py +38 -0
- ai_edge_torch/odml_torch/tf_integration.py +194 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/RECORD +46 -22
- {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
import json
|
|
16
|
+
from typing import Sequence, Union
|
|
17
|
+
|
|
18
|
+
from jax._src.lib.mlir import ir
|
|
19
|
+
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
|
20
|
+
import torch
|
|
21
|
+
|
|
22
|
+
from .. import _torch_library
|
|
23
|
+
from .. import lowerings
|
|
24
|
+
|
|
25
|
+
# Type of a composite attribute dictionary: string keys mapped to a scalar
# (int, float, bool, str) or to a sequence of int, float, or bool values.
CompositeAttrType = dict[
    str,
    Union[
        int,
        float,
        bool,
        str,
        Sequence[int],
        Sequence[float],
        Sequence[bool],
    ],
]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _assert_valid_composite_attr(attr: CompositeAttrType):
  """Validate that `attr` is a well-formed composite attribute dictionary.

  Args:
    attr: The attribute dictionary to check. None is accepted and treated as
      "no attributes".

  Raises:
    ValueError: If `attr` is not a dict, if a key is not a str, or if a value
      is not a str, float, int, bool, or a non-empty list/tuple whose elements
      all share exactly one of the types int, float, or bool.
  """
  if attr is None:
    return
  if not isinstance(attr, dict):
    raise ValueError("Composite attr must be a Python dictionary.")

  invalid_attr_value_message = (
      "Composite attr value must be either Python str, float, int, bool,"
      " list[int], list[float], list[bool]."
  )
  for k, v in attr.items():
    if not isinstance(k, str):
      raise ValueError("Composite attr name must be a Python str.")

    if isinstance(v, (list, tuple)):
      # Exact element types are compared, so a sequence mixing e.g. bool and
      # int is rejected even though bool is a subclass of int.
      eltys = {type(el) for el in v}
      # An empty sequence has no element type to inspect. Reject it with the
      # same ValueError rather than letting next() raise StopIteration.
      if len(eltys) != 1 or next(iter(eltys)) not in (int, float, bool):
        raise ValueError(invalid_attr_value_message)
    elif type(v) not in (str, float, int, bool):
      raise ValueError(invalid_attr_value_message)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@torch._dynamo.assume_constant_result
|
|
62
|
+
def serialize_composite_attr(attr: Union[CompositeAttrType, None]):
|
|
63
|
+
"""Serialize the composite attr into a dynamo-tracable value."""
|
|
64
|
+
if attr is None:
|
|
65
|
+
return None
|
|
66
|
+
_assert_valid_composite_attr(attr)
|
|
67
|
+
return tuple(attr.items())
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@torch._dynamo.assume_constant_result
|
|
71
|
+
def deserialize_composite_attr(serialized_attr) -> CompositeAttrType:
|
|
72
|
+
"""Deserialize dynamo-tracable composite attribute into its raw value."""
|
|
73
|
+
if serialized_attr is None:
|
|
74
|
+
return None
|
|
75
|
+
return dict(serialized_attr)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Declare the schema of the `mark_tensor` op on the odml_torch library. The
# trailing `attr` argument is optional and defaults to None.
_torch_library.ODML_TORCH_LIB.define(
    "mark_tensor(Tensor x, str name, int pos, str id, bool is_input, Any?"
    " attr=None) -> Tensor"
)

# Handle to the default overload of the op declared above.
mark_tensor_op = torch.ops.odml_torch.mark_tensor.default
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@torch.library.impl(
    _torch_library.ODML_TORCH_LIB, "mark_tensor", "CompositeExplicitAutograd"
)
def mark_tensor(
    x: torch.Tensor, name: str, pos: int, id: str, is_input: bool, attr=None
):
  """Return `x` unchanged; the other arguments are not used here.

  Registered as the "CompositeExplicitAutograd" implementation of the
  `mark_tensor` op. The marker data is consumed by the lowering, not here.
  """
  return x
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@torch.library.impl(_torch_library.ODML_TORCH_LIB, "mark_tensor", "Meta")
def mark_tensor_meta(
    x: torch.Tensor, name: str, pos: int, id: str, is_input: bool, attr=None
):
  """Return `torch.empty_like(x)`; the other arguments are not used here.

  Registered as the "Meta" implementation of the `mark_tensor` op.
  """
  return torch.empty_like(x)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@lowerings.lower(torch.ops.odml_torch.mark_tensor)
def mark_tensor_lowering(
    lctx, x: ir.Value, name: str, pos: int, id: str, is_input: bool, attr=None
):
  """Lower `mark_tensor` to a `mark_tensor` StableHLO custom call.

  The marker fields (name, pos, id, is_input, and the deserialized attr) are
  JSON-encoded into the custom call's backend_config. The call takes `x` as
  its only input and has a single result of the same type as `x`.
  """
  marker = {
      "name": name,
      "pos": pos,
      "id": id,
      "is_input": is_input,
      "attr": deserialize_composite_attr(attr),
  }
  backend_config = ir.StringAttr.get(json.dumps(marker))
  return stablehlo.custom_call(
      [x.type],
      inputs=[x],
      call_target_name="mark_tensor",
      backend_config=backend_config,
  )
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
import uuid
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
|
|
19
|
+
from . import mark_tensor
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@torch._dynamo.assume_constant_result
def _get_uuid() -> str:
  """Return a fresh random UUID (uuid4) as a hex string."""
  # NOTE(review): presumably decorated so dynamo treats the generated id as a
  # constant while tracing -- confirm against the torch._dynamo docs.
  return uuid.uuid4().hex
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class StableHLOCompositeBuilder:
  """Marks tensors as the inputs and outputs of one StableHLO composite."""

  def __init__(self, name: str, attr: mark_tensor.CompositeAttrType = None):
    """Create a builder that describes a StableHLO Composite by marking tensors.

    It should be used with the StableHLO converters from `torch_xla.stablehlo`.

    Args:
      name (str): The name of the built StableHLO Composite op.
      attr (mark_tensor.CompositeAttrType): Attributes of the StableHLO
        Composite op.
    """
    self.name = name
    self.attr = attr
    # Every mark emitted by this builder carries the same id.
    self.id = _get_uuid()
    self._inputs = []
    self._outputs = []

  def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool):
    """Wrap each tensor in a `mark_tensor` op and return the marked tensors.

    Returns a single tensor when exactly one tensor was marked, otherwise a
    tuple. Raises ValueError when an argument is not a torch.Tensor.
    """
    # The composite attributes are only attached to output marks.
    if is_input:
      serialized_attr = None
    else:
      serialized_attr = mark_tensor.serialize_composite_attr(self.attr)

    marked = []
    for index, candidate in enumerate(tensors):
      if not isinstance(candidate, torch.Tensor):
        raise ValueError(
            f"input must be a torch tensor. Got {type(candidate)}."
        )
      marked_tensor = mark_tensor.mark_tensor_op(
          candidate,
          name=self.name,
          pos=index,
          id=self.id,
          is_input=is_input,
          attr=serialized_attr,
      )
      marked.append(marked_tensor)

    return marked[0] if len(marked) == 1 else tuple(marked)

  def mark_inputs(self, *tensors: torch.Tensor):
    """Mark tensors as the inputs of the StableHLO Composite.

    Call this method only once per builder.

    Args:
      *tensors (torch.Tensor): Torch tensors to mark.

    Returns:
      marked_tensors (torch.Tensor or Tuple[torch.Tensor]):
        The tensors marked as composite inputs. Later code should use these
        in place of the tensors that were passed in.
    """
    return self._mark_tensor(*tensors, is_input=True)

  def mark_outputs(self, *tensors: torch.Tensor):
    """Mark tensors as the outputs of the StableHLO Composite.

    Call this method only once per builder.

    Args:
      *tensors (torch.Tensor): Torch tensors to mark.

    Returns:
      marked_tensors (torch.Tensor or Tuple[torch.Tensor]):
        The tensors marked as composite outputs. Later code should use these
        in place of the tensors that were passed in.
    """
    return self._mark_tensor(*tensors, is_input=False)
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from ._build import build_mlir_debuginfo
|
|
16
|
+
from ._op_polyfill import write_mlir_debuginfo_op
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
import torch
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _class_fullname(cls):
  """Return the qualified name of `cls`, module-prefixed unless it is builtin."""
  owner = cls.__module__
  if owner == "builtins":
    parts = (cls.__qualname__,)
  else:
    parts = (owner, cls.__qualname__)
  return ".".join(parts)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _get_hierarchy(node: torch.fx.Node):
  """Build a "/"-joined module hierarchy string for `node`, ending in ";".

  Each value of `node.meta["nn_module_stack"]` is a (name, layer) pair and
  contributes one segment: the layer's class name, taken as-is when `layer`
  is already a str, followed by "_<last dotted component of name>" when the
  name is non-empty. A missing stack yields just ";".
  """
  segments = []
  for name, layer in node.meta.get("nn_module_stack", {}).values():
    suffix = "_" + name.split(".")[-1] if name else ""
    class_name = layer if isinstance(layer, str) else _class_fullname(layer)
    segments.append(class_name + suffix)
  return "/".join(segments) + ";"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def build_mlir_debuginfo(node: torch.fx.Node):
  """Build the debuginfo string for the given node's lowerings in MLIR.

  Returns None when the node has no `meta` attribute or its meta has no
  "nn_module_stack" entry; otherwise returns the node's hierarchy string.
  """
  has_module_stack = hasattr(node, "meta") and "nn_module_stack" in node.meta
  return _get_hierarchy(node) if has_module_stack else None
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Polyfill op for torch.ops.xla.write_mlir_debuginfo.
|
|
16
|
+
|
|
17
|
+
In odml-torch, MLIR debuginfo is generated in the lowering framework directly
|
|
18
|
+
without the need of an additional torch op to write. This file registers a no-op
|
|
19
|
+
placeholder torch op to replace torch.ops.xla.write_mlir_debuginfo in
|
|
20
|
+
ai-edge-torch.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from jax._src.lib.mlir import ir
|
|
24
|
+
import torch
|
|
25
|
+
|
|
26
|
+
from .. import _torch_library
|
|
27
|
+
from .. import lowerings
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Declare the schema of the placeholder op on the odml_torch library.
_torch_library.ODML_TORCH_LIB.define(
    "write_mlir_debuginfo(Tensor x, str data) -> Tensor"
)

# Handle to the op declared above (the op packet; no overload is selected).
write_mlir_debuginfo_op = torch.ops.odml_torch.write_mlir_debuginfo
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@torch.library.impl(
    _torch_library.ODML_TORCH_LIB,
    "write_mlir_debuginfo",
    "CompositeExplicitAutograd",
)
def write_mlir_debuginfo(x: torch.Tensor, _: str):
  """No-op implementation: return `x` unchanged and ignore the data string."""
  return x
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@torch.library.impl(
    _torch_library.ODML_TORCH_LIB, "write_mlir_debuginfo", "Meta"
)
def write_mlir_debuginfo_meta(x: torch.Tensor, _: str):
  """Meta implementation: return `torch.empty_like(x)`, ignoring the data."""
  return torch.empty_like(x)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@lowerings.lower(torch.ops.odml_torch.write_mlir_debuginfo)
def write_mlir_debuginfo_lowering(lctx, x: ir.Value, _: str):
  """Lower the placeholder op to nothing: the input value is returned as-is."""
  return x
|
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""APIs to convert and lower a PyTorch ExportedProgram to MLIR."""
|
|
16
|
+
|
|
17
|
+
import dataclasses
|
|
18
|
+
import enum
|
|
19
|
+
import io
|
|
20
|
+
import operator
|
|
21
|
+
from typing import Any, Callable, Optional
|
|
22
|
+
|
|
23
|
+
from jax.lib import xla_extension
|
|
24
|
+
from jax._src.lib.mlir import ir
|
|
25
|
+
from jax._src.lib.mlir.dialects import func
|
|
26
|
+
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
|
27
|
+
import torch
|
|
28
|
+
import torch.utils._pytree as pytree
|
|
29
|
+
|
|
30
|
+
from . import _torch_future
|
|
31
|
+
from . import debuginfo
|
|
32
|
+
from . import export_utils
|
|
33
|
+
from . import lowerings
|
|
34
|
+
|
|
35
|
+
# Short alias for the lowering context class used throughout this module.
LoweringContext = lowerings.context.LoweringContext
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _build_flat_inputs(
    ctx: ir.Context, exported_program: torch.export.ExportedProgram
):
  """Build flattened inputs and metadata from exported program's signature.

  Args:
    ctx: The MLIR context used to build the element types.
    exported_program: The exported program whose placeholders are inspected.

  Returns:
    A 3-tuple of tuples: the ranked tensor type of each placeholder, the
    flattened example arguments, and the placeholders' tensor_meta entries.

  Raises:
    RuntimeError: If a placeholder node has no "tensor_meta" in its meta.
  """
  placeholders = [
      node for node in exported_program.graph.nodes if node.op == "placeholder"
  ]
  flat_args = _torch_future.graph_module_flat_inputs(
      exported_program, *exported_program.example_inputs
  )

  input_types = []
  metas = []
  for placeholder, flat_arg in zip(placeholders, flat_args):
    meta = placeholder.meta.get("tensor_meta")
    if meta is None:
      raise RuntimeError(
          f"{type(flat_arg)} (for {placeholder.name}) is not a tensor"
      )
    metas.append(meta)

    # Assume all dynamic dimensions are unbounded.
    # TODO: Add checks for ep.range_constraints in MLIR.
    dims = []
    for dim in meta.shape:
      if export_utils.is_torch_dynamic(dim):
        dims.append(export_utils.IR_DYNAMIC)
      else:
        dims.append(dim)

    element_type = export_utils.torch_dtype_to_ir_element_type(ctx, meta.dtype)
    input_types.append(ir.RankedTensorType.get(tuple(dims), element_type))

  return tuple(input_types), tuple(flat_args), tuple(metas)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _get_output_metas(exported_program: torch.export.ExportedProgram):
|
|
73
|
+
"""Get the output node's tensor_meta from the exported program."""
|
|
74
|
+
outputs = [n for n in exported_program.graph.nodes if n.op == "output"]
|
|
75
|
+
assert len(outputs) == 1
|
|
76
|
+
outputs, _ = pytree.tree_flatten(outputs[0].args[0])
|
|
77
|
+
assert all(isinstance(output, torch.fx.Node) for output in outputs)
|
|
78
|
+
return tuple(output.meta["tensor_meta"] for output in outputs)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class LoweringInterpreter(torch.fx.Interpreter):
  """FX interpreter that invokes the registered lowering for each graph op.

  After a run, the flattened values passed to the graph's output node are
  available in `self.outputs`.
  """

  def __init__(self, module: torch.fx.GraphModule, lctx: LoweringContext):
    """Store the lowering context passed to every lowering call.

    Args:
      module: The graph module to interpret.
      lctx: The lowering context handed to the lowerings.
    """
    super().__init__(module)
    self.lctx = lctx
    self.outputs = None

  def _build_loc(self, node: torch.fx.Node):
    """Return a named MLIR location for `node`, or an unknown location."""
    info = debuginfo.build_mlir_debuginfo(node)
    if info is None:
      return ir.Location.unknown()

    return ir.Location.name(name=info)

  def run_node(self, node: torch.fx.Node):
    """Run `node` with its MLIR location active and recorded in the context.

    The context's location and node are cleared again afterwards, including
    when the node's lowering raises, so a failure does not leave the
    interpreter pointing at a stale node.
    """
    loc = self._build_loc(node)
    with loc:
      self.lctx = self.lctx.replace(ir_location=loc, node=node)
      try:
        return super().run_node(node)
      finally:
        self.lctx = self.lctx.replace(ir_location=None, node=None)

  def call_function(self, target, args, kwargs):
    """Dispatch a call_function node to the lowering registered for `target`.

    Raises:
      RuntimeError: If no lowering is registered for `target`.
    """
    # getitem is executed directly; it is not looked up as a lowering.
    if target is operator.getitem:
      return super().call_function(target, args, kwargs)

    if hasattr(target, "_schema"):
      # Python scalars passed where the op schema declares a Tensor argument
      # are converted with `lowerings.utils.splat` (i32 for int, f32 for
      # float) before the lowering is called.
      new_args = []
      for arg, spec in zip(args, target._schema.arguments):
        if isinstance(spec.type, torch.TensorType):
          if isinstance(arg, int):
            arg = lowerings.utils.splat(arg, ir.IntegerType.get_signless(32))
          elif isinstance(arg, float):
            arg = lowerings.utils.splat(arg, ir.F32Type.get())

        new_args.append(arg)
      args = tuple(new_args)

    lowering = lowerings.lookup(target)
    if lowering is None:
      raise RuntimeError(f"Lowering not found: {target}")
    return lowering(self.lctx, *args, **kwargs)

  def output(self, target, args, kwargs):
    """Record the flattened graph outputs in `self.outputs`."""
    flat_outputs = pytree.tree_flatten(args[0])[0]
    self.outputs = flat_outputs
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@dataclasses.dataclass
class InputSpec:
  """Describes where a lowered function input comes from.

  An input is either a user input, identified by its index `i`, or a
  parameter, identified by `name`.
  """

  class VariableType(enum.Enum):
    USER_INPUT = "user_input"
    PARAMETER = "parameter"

  # Kind of the input.
  type_: VariableType
  # Index of the user input; left at -1 by the `parameter` constructor.
  i: int = -1
  # Name of the parameter; left empty by the `user_input` constructor.
  name: str = ""

  @property
  def is_user_input(self):
    return self.type_ is self.VariableType.USER_INPUT

  @property
  def is_parameter(self):
    return self.type_ is self.VariableType.PARAMETER

  @classmethod
  def user_input(cls, i: int):
    """Create the spec of the `i`-th user input."""
    kind = cls.VariableType.USER_INPUT
    return cls(type_=kind, i=i)

  @classmethod
  def parameter(cls, name: str):
    """Create the spec of the parameter called `name`."""
    kind = cls.VariableType.PARAMETER
    return cls(type_=kind, name=name)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
@dataclasses.dataclass
class VariableSignature:  # either argument or parameters
  """Shape, dtype, and optional origin of one lowered input or output."""

  # NOTE(review): exported_program_to_mlir passes tensor_meta.shape and
  # tensor_meta.dtype here, which look like torch.Size / torch.dtype rather
  # than list[int] / str -- confirm and adjust the annotations if so.
  shape: list[int]
  dtype: str
  # Origin of the variable; stays None when no input spec is supplied.
  input_spec: InputSpec = None
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@dataclasses.dataclass
class MlirLowered:
  """The lowered MLIR module, metadata, and weight tensors bundle from exported program."""

  # MLIR context passed in together with `module`.
  ctx: ir.Context
  # The lowered MLIR module.
  module: ir.Module
  # Tensors of the non-user inputs, keyed by name.
  state_dict: dict[str, torch.Tensor]
  # Signatures of the lowered function's inputs and outputs.
  input_signature: list[VariableSignature]
  output_signature: list[VariableSignature]

  # Cache for the lazily built TF function; see `tf_function`.
  _tf_function: Optional[Callable[..., Any]] = None

  def __str__(self):
    return str(self.get_text(enable_debug_info=False))

  def __repr__(self):
    return str(self.get_text(enable_debug_info=False))

  def get_text(self, enable_debug_info=False):
    """Return the module's textual assembly.

    Args:
      enable_debug_info: Forwarded to the module operation's `get_asm`.
    """
    return str(
        self.module.operation.get_asm(enable_debug_info=enable_debug_info)
    )

  @property
  def module_bytecode(self) -> bytes:
    """The module serialized with the operation's `write_bytecode`."""
    output = io.BytesIO()
    self.module.operation.write_bytecode(file=output)
    return output.getvalue()

  @property
  def module_bytecode_vhlo(self) -> bytes:
    """The module bytecode re-serialized as a portable artifact."""
    # HACK: In OSS, we use the MLIR pybinding and StableHLO dialect from JAX's
    # build, which may not have the same StableHLO version as the one used in
    # the TFLite converter. Therefore we always serialize the MLIR module in
    # VHLO.
    # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
    target_version = stablehlo.get_minimum_version()
    module_bytecode = xla_extension.mlir.serialize_portable_artifact(
        self.module_bytecode, target_version
    )
    return module_bytecode

  @property
  def tf_function(self):
    """The TF function for this module, built on first access and cached."""
    # Lazy import
    from . import tf_integration

    if self._tf_function is None:
      self._tf_function = tf_integration.mlir_to_tf_function(self)
    return self._tf_function

  def __call__(self, *args):
    """Call `tf_function` with the given positional arguments."""
    # Lazy importing TF when execution is needed.
    return self.tf_function(*args)

  def to_flatbuffer(self):
    """Return the result of `tf_integration.mlir_to_flatbuffer` for self."""
    from . import tf_integration

    return tf_integration.mlir_to_flatbuffer(self)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def exported_program_to_mlir(
    exported_program: torch.export.ExportedProgram,
) -> MlirLowered:
  """Lower the exported program to MLIR.

  Args:
    exported_program: The torch.export ExportedProgram to lower.

  Returns:
    An MlirLowered holding the IR context, the module containing a public
    `main` function, the state dict of the non-user inputs keyed by their
    input spec target, and the input and output signatures.

  Raises:
    RuntimeError: Propagated from the helpers used here, e.g. when a
      placeholder has no tensor_meta or an op has no registered lowering.
  """
  if torch.__version__ >= "2.2":
    # torch version 2.1 didn't expose this yet
    exported_program = exported_program.run_decompositions()
    exported_program = exported_program.run_decompositions(
        lowerings.decompositions()
    )

  with export_utils.create_ir_context() as context, ir.Location.unknown():

    module = ir.Module.create()
    lctx = LoweringContext(context, module)
    interpreter = LoweringInterpreter(exported_program.graph_module, lctx)
    ir_flat_inputs, export_flat_args, tensor_metas = _build_flat_inputs(
        context, exported_program
    )

    # HACK: OSS MLIR pybinding could mysteriously transform func.func under
    # construction into a func.return op after calling ir.Module.parse(..)
    # in the context, which happens in JAX bridge. This is a bug in MLIR
    # pybinding.
    # Workaround steps:
    #   1. Create a temp func.func.
    #   2. Create and insert ops to temp's entry block. During the process
    #   the temp func.func would be broken, but the ops in the block are fine.
    #   3. Create the main func.func and copy all the ops in temp's entry block
    #   to main.
    #   4. Erase the temp func.func.
    temp_func = func.FuncOp(
        "temp",
        ir.FunctionType.get(ir_flat_inputs, []),
        ip=ir.InsertionPoint.at_block_begin(module.body),
    )
    with ir.InsertionPoint(temp_func.add_entry_block()):
      interpreter.run(*temp_func.arguments, enable_io_processing=False)
      # The leading outputs that correspond to mutated buffers are dropped;
      # only the remaining outputs are returned from the function.
      num_mutations = len(exported_program.graph_signature.buffers_to_mutate)
      outputs = interpreter.outputs[num_mutations:]
      func.ReturnOp(interpreter.outputs[num_mutations:])

    # The result types of main are taken from the values produced in temp.
    main_func = func.FuncOp(
        "main",
        ir.FunctionType.get(ir_flat_inputs, [o.type for o in outputs]),
        ip=ir.InsertionPoint.at_block_begin(module.body),
    )
    with ir.InsertionPoint(main_func.add_entry_block()):
      outputs = export_utils.clone_func_body_ops(temp_func, main_func.arguments)
      func.ReturnOp(outputs)

    main_func.attributes["sym_visibility"] = ir.StringAttr.get("public")
    temp_func.erase()

    module.operation.verify()

  input_signature = []
  state_dict = {}

  # Running index assigned to user inputs in the order they appear.
  user_inputs_cnt = 0
  for arg, tensor_meta, input_spec in zip(
      export_flat_args,
      tensor_metas,
      exported_program.graph_signature.input_specs,
  ):
    # Assumption:
    # All states come first in the list of args, and user-provided inputs
    # come later. Also there are no kwargs.
    if input_spec.kind == torch.export.graph_signature.InputKind.USER_INPUT:
      input_signature.append(
          VariableSignature(
              tensor_meta.shape,
              tensor_meta.dtype,
              input_spec=InputSpec.user_input(user_inputs_cnt),
          )
      )
      user_inputs_cnt += 1
    else:
      # Parameter or constant
      state_dict[input_spec.target] = arg
      input_signature.append(
          VariableSignature(
              tensor_meta.shape,
              tensor_meta.dtype,
              input_spec=InputSpec.parameter(input_spec.target),
          )
      )

  output_signature = [
      VariableSignature(tensor_meta.shape, tensor_meta.dtype)
      for tensor_meta in _get_output_metas(exported_program)
  ]
  return MlirLowered(
      context, module, state_dict, input_signature, output_signature
  )
|