ai-edge-torch-nightly 0.3.0.dev20240827__py3-none-any.whl → 0.3.0.dev20240829__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of ai-edge-torch-nightly has been flagged as potentially problematic.

Files changed (46)
  1. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +6 -1
  2. ai_edge_torch/_convert/test/test_convert.py +1 -1
  3. ai_edge_torch/_convert/test/test_convert_composites.py +1 -1
  4. ai_edge_torch/_convert/test/test_convert_multisig.py +71 -31
  5. ai_edge_torch/_convert/test/test_to_channel_last_io.py +1 -1
  6. ai_edge_torch/debug/test/test_culprit.py +1 -1
  7. ai_edge_torch/debug/test/test_search_model.py +1 -1
  8. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +43 -59
  9. ai_edge_torch/generative/test/test_experimental_ekv.py +1 -1
  10. ai_edge_torch/generative/test/test_loader.py +1 -1
  11. ai_edge_torch/generative/test/test_model_conversion.py +1 -1
  12. ai_edge_torch/generative/test/test_quantize.py +1 -1
  13. ai_edge_torch/hlfb/test/test_mark_pattern.py +1 -1
  14. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +1 -1
  15. ai_edge_torch/lowertools/odml_torch_utils.py +5 -1
  16. ai_edge_torch/lowertools/test_utils.py +1 -1
  17. ai_edge_torch/odml_torch/__init__.py +20 -0
  18. ai_edge_torch/odml_torch/_torch_future.py +61 -0
  19. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  20. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  21. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  22. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  23. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  24. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  25. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  26. ai_edge_torch/odml_torch/export.py +320 -0
  27. ai_edge_torch/odml_torch/export_utils.py +168 -0
  28. ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
  29. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +152 -0
  30. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  31. ai_edge_torch/odml_torch/lowerings/__init__.py +24 -0
  32. ai_edge_torch/odml_torch/lowerings/_basic.py +204 -0
  33. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  34. ai_edge_torch/odml_torch/lowerings/_convolution.py +119 -0
  35. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -0
  36. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  37. ai_edge_torch/odml_torch/lowerings/registry.py +87 -0
  38. ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
  39. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  40. ai_edge_torch/odml_torch/tf_integration.py +194 -0
  41. ai_edge_torch/version.py +1 -1
  42. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/METADATA +1 -1
  43. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/RECORD +46 -22
  44. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/LICENSE +0 -0
  45. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/WHEEL +0 -0
  46. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/top_level.txt +0 -0
ai_edge_torch/odml_torch/composite/mark_tensor.py
@@ -0,0 +1,120 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import json
+from typing import Sequence, Union
+
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import torch
+
+from .. import _torch_library
+from .. import lowerings
+
+CompositeAttrType = dict[
+    str,
+    Union[
+        int,
+        float,
+        bool,
+        str,
+        Sequence[int],
+        Sequence[float],
+        Sequence[bool],
+    ],
+]
+
+
+def _assert_valid_composite_attr(attr: CompositeAttrType):
+  if attr is None:
+    return
+  if not isinstance(attr, dict):
+    raise ValueError("Composite attr must be a Python dictionary.")
+
+  for k, v in attr.items():
+    if not isinstance(k, str):
+      raise ValueError("Composite attr name must be a Python str.")
+
+    invalid_attr_value_error = ValueError(
+        "Composite attr value must be either Python str, float, int, bool,"
+        " list[int], list[float], list[bool]."
+    )
+    if isinstance(v, (list, tuple)):
+      eltys = {type(el) for el in v}
+      if len(eltys) > 1 or next(iter(eltys)) not in (int, float, bool):
+        raise invalid_attr_value_error
+    elif type(v) not in (str, float, int, bool):
+      raise invalid_attr_value_error
+
+
+@torch._dynamo.assume_constant_result
+def serialize_composite_attr(attr: Union[CompositeAttrType, None]):
+  """Serialize the composite attr into a dynamo-tracable value."""
+  if attr is None:
+    return None
+  _assert_valid_composite_attr(attr)
+  return tuple(attr.items())
+
+
+@torch._dynamo.assume_constant_result
+def deserialize_composite_attr(serialized_attr) -> CompositeAttrType:
+  """Deserialize dynamo-tracable composite attribute into its raw value."""
+  if serialized_attr is None:
+    return None
+  return dict(serialized_attr)
+
+
+_torch_library.ODML_TORCH_LIB.define(
+    "mark_tensor(Tensor x, str name, int pos, str id, bool is_input, Any?"
+    " attr=None) -> Tensor"
+)
+
+mark_tensor_op = torch.ops.odml_torch.mark_tensor.default
+
+
+@torch.library.impl(
+    _torch_library.ODML_TORCH_LIB, "mark_tensor", "CompositeExplicitAutograd"
+)
+def mark_tensor(
+    x: torch.Tensor, name: str, pos: int, id: str, is_input: bool, attr=None
+):
+  return x
+
+
+@torch.library.impl(_torch_library.ODML_TORCH_LIB, "mark_tensor", "Meta")
+def mark_tensor_meta(
+    x: torch.Tensor, name: str, pos: int, id: str, is_input: bool, attr=None
+):
+  return torch.empty_like(x)
+
+
+@lowerings.lower(torch.ops.odml_torch.mark_tensor)
+def mark_tensor_lowering(
+    lctx, x: ir.Value, name: str, pos: int, id: str, is_input: bool, attr=None
+):
+  attr = deserialize_composite_attr(attr)
+  return stablehlo.custom_call(
+      [x.type],
+      inputs=[x],
+      call_target_name="mark_tensor",
+      backend_config=ir.StringAttr.get(
+          json.dumps({
+              "name": name,
+              "pos": pos,
+              "id": id,
+              "is_input": is_input,
+              "attr": attr,
+          })
+      ),
+  )
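
The serialize/deserialize pair above is what lets composite attributes survive Dynamo tracing: the attr dict is flattened to a tuple of items before tracing and rebuilt into a dict in the lowering. A minimal round-trip sketch (the attribute names and values are made up for illustration):

    from ai_edge_torch.odml_torch.composite import mark_tensor

    # Values may only be str/int/float/bool scalars or homogeneous
    # lists of int/float/bool; anything else raises ValueError.
    attr = {"epsilon": 1e-5, "reduce_dims": [1, 2], "keepdim": True}

    serialized = mark_tensor.serialize_composite_attr(attr)
    # -> (("epsilon", 1e-05), ("reduce_dims", [1, 2]), ("keepdim", True))

    restored = mark_tensor.deserialize_composite_attr(serialized)
    assert restored == attr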
ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py
@@ -0,0 +1,106 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import uuid
+
+import torch
+
+from . import mark_tensor
+
+
+@torch._dynamo.assume_constant_result
+def _get_uuid() -> str:
+  return uuid.uuid4().hex
+
+
+class StableHLOCompositeBuilder:
+  """Builder class for building a StableHLO composite in the lowering."""
+
+  def __init__(self, name: str, attr: mark_tensor.CompositeAttrType = None):
+    """Helper for building a StableHLO Composite by marking input and output tensors.
+
+    It should be used with the StableHLO converters from `torch_xla.stablehlo`.
+
+    Args:
+      name (str): The name of the built StableHLO Composite op.
+      attr (mark_tensor.CompositeAttrType): Attributes of the StableHLO
+        Composite op.
+    """
+
+    self.attr = attr
+    self.name = name
+    self.id = _get_uuid()
+    self._inputs = []
+    self._outputs = []
+
+  def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool):
+    """Mark the input/output tensors of the StableHLO Composite."""
+    marked_tensors = []
+    serialized_attr = (
+        mark_tensor.serialize_composite_attr(self.attr)
+        if not is_input
+        else None
+    )
+
+    for pos, tensor in enumerate(tensors):
+      if not isinstance(tensor, torch.Tensor):
+        raise ValueError(f"input must be a torch tensor. Got {type(tensor)}.")
+      marked_tensors.append(
+          mark_tensor.mark_tensor_op(
+              tensor,
+              name=self.name,
+              pos=pos,
+              id=self.id,
+              is_input=is_input,
+              attr=serialized_attr,
+          )
+      )
+
+    if len(marked_tensors) == 1:
+      return marked_tensors[0]
+    return tuple(marked_tensors)
+
+  def mark_inputs(self, *tensors: torch.Tensor):
+    """Mark the input tensors of the StableHLO Composite.
+
+    This method must only be called once per builder.
+
+    Args:
+      *tensors (torch.Tensor): Torch tensors to mark.
+
+    Returns:
+      marked_tensors (torch.Tensor or Tuple[torch.Tensor]):
+        Torch tensors marked as composite inputs. The tensor inputs of this
+        method
+        should be replaced by the marked tensors in later usages.
+    """
+
+    return self._mark_tensor(*tensors, is_input=True)
+
+  def mark_outputs(self, *tensors: torch.Tensor):
+    """Mark the output tensors of the StableHLO Composite.
+
+    This method must only be called once per builder.
+
+    Args:
+      *tensors (torch.Tensor): Torch tensors to mark.
+
+    Returns:
+      marked_tensors (torch.Tensor or Tuple[torch.Tensor]):
+        Torch tensors marked as composite outputs. The tensor inputs of this
+        method
+        should be replaced by the marked tensors in later usages.
+    """
+
+    return self._mark_tensor(*tensors, is_input=False)
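
Together with mark_tensor.py above, this builder lets a model author fence off a region of a forward pass so that it can be lowered as a single StableHLO composite op. A hedged usage sketch (the module, composite name, and attr values are illustrative):

    import torch
    from ai_edge_torch.odml_torch.composite.stablehlo_composite_builder import (
        StableHLOCompositeBuilder,
    )


    class ScaledGELU(torch.nn.Module):
      """Toy module whose body should lower to one composite op."""

      def forward(self, x):
        builder = StableHLOCompositeBuilder(
            name="example.scaled_gelu", attr={"scale": 2.0}
        )
        # Replace the original inputs with their marked counterparts...
        x = builder.mark_inputs(x)
        y = torch.nn.functional.gelu(x) * 2.0
        # ...and route every output through mark_outputs.
        return builder.mark_outputs(y)

In eager mode the marks are pass-throughs (mark_tensor returns its input), so the module still runs normally; the markers only take effect when the graph is exported and lowered.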
ai_edge_torch/odml_torch/debuginfo/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from ._build import build_mlir_debuginfo
+from ._op_polyfill import write_mlir_debuginfo_op
ai_edge_torch/odml_torch/debuginfo/_build.py
@@ -0,0 +1,43 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import torch
+
+
+def _class_fullname(cls):
+  module = cls.__module__
+  if module == "builtins":
+    return cls.__qualname__
+  return module + "." + cls.__qualname__
+
+
+def _get_hierarchy(node: torch.fx.Node):
+  nn_module_stack = node.meta.get("nn_module_stack", {})
+  layers = []
+  for name, layer in nn_module_stack.values():
+    iid = ("_" + name.split(".")[-1]) if name else ""
+    layer_str = layer if isinstance(layer, str) else _class_fullname(layer)
+    layers.append(layer_str + iid)
+
+  hierachy_str = "/".join(layers) + ";"
+  return hierachy_str
+
+
+def build_mlir_debuginfo(node: torch.fx.Node):
+  """Build the debuginfo string for the given node's lowerings in MLIR."""
+
+  if not hasattr(node, "meta") or "nn_module_stack" not in node.meta:
+    return None
+
+  return _get_hierarchy(node)
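
The hierarchy string is derived from the nn_module_stack metadata that torch.export records on FX nodes, so it reflects the module path each op came from. A rough sketch of inspecting it (the exact strings depend on the Torch version and module names):

    import torch
    from ai_edge_torch.odml_torch.debuginfo import build_mlir_debuginfo


    class Model(torch.nn.Module):

      def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

      def forward(self, x):
        return self.linear(x)


    ep = torch.export.export(Model(), (torch.randn(1, 4),))
    for node in ep.graph.nodes:
      info = build_mlir_debuginfo(node)
      if info is not None:
        # e.g. "__main__.Model/torch.nn.modules.linear.Linear_linear;"
        print(node.name, info)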
ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py
@@ -0,0 +1,55 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Polyfill op for torch.ops.xla.write_mlir_debuginfo.
+
+In odml-torch, MLIR debuginfo is generated in the lowering framework directly
+without the need of an additional torch op to write. This file register a no-op
+placeholder torch op to replace torch.ops.xla.write_mlir_debuginfo in
+ai-edge-torch.
+"""
+
+from jax._src.lib.mlir import ir
+import torch
+
+from .. import _torch_library
+from .. import lowerings
+
+
+_torch_library.ODML_TORCH_LIB.define(
+    "write_mlir_debuginfo(Tensor x, str data) -> Tensor"
+)
+
+write_mlir_debuginfo_op = torch.ops.odml_torch.write_mlir_debuginfo
+
+
+@torch.library.impl(
+    _torch_library.ODML_TORCH_LIB,
+    "write_mlir_debuginfo",
+    "CompositeExplicitAutograd",
+)
+def write_mlir_debuginfo(x: torch.Tensor, _: str):
+  return x
+
+
+@torch.library.impl(
+    _torch_library.ODML_TORCH_LIB, "write_mlir_debuginfo", "Meta"
+)
+def write_mlir_debuginfo_meta(x: torch.Tensor, _: str):
+  return torch.empty_like(x)
+
+
+@lowerings.lower(torch.ops.odml_torch.write_mlir_debuginfo)
+def write_mlir_debuginfo_lowering(lctx, x: ir.Value, _: str):
+  return x
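
Because every registration above is a pass-through, calling the op in eager mode simply returns its input; it exists only so that code written against torch.ops.xla.write_mlir_debuginfo keeps working under odml-torch. A quick sanity-check sketch (the debug string is arbitrary):

    import torch
    from ai_edge_torch.odml_torch.debuginfo import write_mlir_debuginfo_op

    x = torch.randn(2, 2)
    y = write_mlir_debuginfo_op(x, "decoder/layer_0")
    assert torch.equal(x, y)  # the op does not modify the tensor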
ai_edge_torch/odml_torch/export.py
@@ -0,0 +1,320 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""APIs to convert and lower a PyTorch ExportedProgram to MLIR."""
+
+import dataclasses
+import enum
+import io
+import operator
+from typing import Any, Callable, Optional
+
+from jax.lib import xla_extension
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import func
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import torch
+import torch.utils._pytree as pytree
+
+from . import _torch_future
+from . import debuginfo
+from . import export_utils
+from . import lowerings
+
+LoweringContext = lowerings.context.LoweringContext
+
+
+def _build_flat_inputs(
+    ctx: ir.Context, exported_program: torch.export.ExportedProgram
+):
+  """Build flattened inputs and metadata from exported program's signature."""
+  placeholder_nodes = [
+      n for n in exported_program.graph.nodes if n.op == "placeholder"
+  ]
+  export_flat_args = _torch_future.graph_module_flat_inputs(
+      exported_program, *exported_program.example_inputs
+  )
+
+  ir_inputs = []
+  tensor_metas = []
+  for node, arg in zip(placeholder_nodes, export_flat_args):
+    tensor_meta = node.meta.get("tensor_meta")
+    if tensor_meta is None:
+      raise RuntimeError(f"{type(arg)} (for {node.name}) is not a tensor")
+
+    tensor_metas.append(tensor_meta)
+    # Assume all dynamic dimensions are unbounded.
+    # TODO: Add checks for ep.range_constraints in MLIR.
+    shape = tuple(
+        export_utils.IR_DYNAMIC if export_utils.is_torch_dynamic(s) else s
+        for s in tensor_meta.shape
+    )
+    ir_inputs.append(
+        ir.RankedTensorType.get(
+            shape,
+            export_utils.torch_dtype_to_ir_element_type(ctx, tensor_meta.dtype),
+        )
+    )
+  return tuple(ir_inputs), tuple(export_flat_args), tuple(tensor_metas)
+
+
+def _get_output_metas(exported_program: torch.export.ExportedProgram):
+  """Get the output node's tensor_meta from the exported program."""
+  outputs = [n for n in exported_program.graph.nodes if n.op == "output"]
+  assert len(outputs) == 1
+  outputs, _ = pytree.tree_flatten(outputs[0].args[0])
+  assert all(isinstance(output, torch.fx.Node) for output in outputs)
+  return tuple(output.meta["tensor_meta"] for output in outputs)
+
+
+class LoweringInterpreter(torch.fx.Interpreter):
+  """The FX interpreter to iterate and invoke corresponding lowering for each PyTorch op in the graph."""
+
+  def __init__(self, module: torch.fx.GraphModule, lctx: LoweringContext):
+    super().__init__(module)
+    self.lctx = lctx
+    self.outputs = None
+
+  def _build_loc(self, node: torch.fx.Node):
+
+    info = debuginfo.build_mlir_debuginfo(node)
+    if info is None:
+      return ir.Location.unknown()
+
+    return ir.Location.name(name=info)
+
+  def run_node(self, node: torch.fx.Node):
+    loc = self._build_loc(node)
+    with loc:
+      self.lctx = self.lctx.replace(ir_location=loc, node=node)
+      res = super().run_node(node)
+      self.lctx = self.lctx.replace(ir_location=None, node=None)
+    return res
+
+  def call_function(self, target, args, kwargs):
+    if target is operator.getitem:
+      return super().call_function(target, args, kwargs)
+
+    if hasattr(target, "_schema"):
+      new_args = []
+      for arg, spec in zip(args, target._schema.arguments):
+        if isinstance(spec.type, torch.TensorType):
+          if isinstance(arg, int):
+            arg = lowerings.utils.splat(arg, ir.IntegerType.get_signless(32))
+          elif isinstance(arg, float):
+            arg = lowerings.utils.splat(arg, ir.F32Type.get())
+
+        new_args.append(arg)
+      args = tuple(new_args)
+
+    lowering = lowerings.lookup(target)
+    if lowering is None:
+      raise RuntimeError(f"Lowering not found: {target}")
+    return lowering(self.lctx, *args, **kwargs)
+
+  def output(self, target, args, kwargs):
+    flat_outputs = pytree.tree_flatten(args[0])[0]
+    self.outputs = flat_outputs
+
+
+@dataclasses.dataclass
+class InputSpec:
+
+  class VariableType(enum.Enum):
+    USER_INPUT = "user_input"
+    PARAMETER = "parameter"
+
+  type_: VariableType
+  i: int = -1
+  name: str = ""
+
+  @classmethod
+  def parameter(cls, name: str):
+    return cls(type_=cls.VariableType.PARAMETER, name=name)
+
+  @classmethod
+  def user_input(cls, i: int):
+    return cls(type_=cls.VariableType.USER_INPUT, i=i)
+
+  @property
+  def is_parameter(self):
+    return self.type_ == self.VariableType.PARAMETER
+
+  @property
+  def is_user_input(self):
+    return self.type_ == self.VariableType.USER_INPUT
+
+
+@dataclasses.dataclass
+class VariableSignature:  # either argument or parameters
+  shape: list[int]
+  dtype: str
+  input_spec: InputSpec = None
+
+
+@dataclasses.dataclass
+class MlirLowered:
+  """The lowered MLIR module, metadata, and weight tensors bundle from exported program."""
+
+  ctx: ir.Context
+  module: ir.Module
+  state_dict: dict[str, torch.Tensor]
+  input_signature: list[VariableSignature]
+  output_signature: list[VariableSignature]
+
+  _tf_function: Optional[Callable[Any, Any]] = None
+
+  def __str__(self):
+    return str(self.get_text(enable_debug_info=False))
+
+  def __repr__(self):
+    return str(self.get_text(enable_debug_info=False))
+
+  def get_text(self, enable_debug_info=False):
+    return str(
+        self.module.operation.get_asm(enable_debug_info=enable_debug_info)
+    )
+
+  @property
+  def module_bytecode(self) -> bytes:
+    output = io.BytesIO()
+    self.module.operation.write_bytecode(file=output)
+    return output.getvalue()
+
+  @property
+  def module_bytecode_vhlo(self) -> bytes:
+    # HACK: In OSS, we use MLIR pybinding and StableHLO dialect from JAX's
+    # build, which may not have the same StableHLO version as what used in
+    # TFLite converter. Therefore we always serialize MLIR module in VHLO.
+    # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
+    target_version = stablehlo.get_minimum_version()
+    module_bytecode = xla_extension.mlir.serialize_portable_artifact(
+        self.module_bytecode, target_version
+    )
+    return module_bytecode
+
+  @property
+  def tf_function(self):
+    # Lazy import
+    from . import tf_integration
+
+    if self._tf_function is None:
+      self._tf_function = tf_integration.mlir_to_tf_function(self)
+    return self._tf_function
+
+  def __call__(self, *args):
+    # Lazy importing TF when execution is needed.
+    return self.tf_function(*args)
+
+  def to_flatbuffer(self):
+    from . import tf_integration
+
+    return tf_integration.mlir_to_flatbuffer(self)
+
+
+def exported_program_to_mlir(
+    exported_program: torch.export.ExportedProgram,
+) -> MlirLowered:
+  """Lower the exported program to MLIR."""
+  if torch.__version__ >= "2.2":
+    # torch version 2.1 didn't expose this yet
+    exported_program = exported_program.run_decompositions()
+    exported_program = exported_program.run_decompositions(
+        lowerings.decompositions()
+    )
+
+  with export_utils.create_ir_context() as context, ir.Location.unknown():
+
+    module = ir.Module.create()
+    lctx = LoweringContext(context, module)
+    interpreter = LoweringInterpreter(exported_program.graph_module, lctx)
+    ir_flat_inputs, export_flat_args, tensor_metas = _build_flat_inputs(
+        context, exported_program
+    )
+
+    # HACK: OSS MLIR pybinding could mysteriously transform func.func under
+    # construction into a func.return op after calling ir.Module.parse(..)
+    # in the context, which happens in JAX bridge. This is a bug in MLIR
+    # pybinding.
+    # Workaround steps:
+    # 1. Create a temp func.func.
+    # 2. Create and insert ops to temp's entry block. During the process
+    # the temp func.func would be broken, but the ops in the block are fine.
+    # 3. Create the main func.func and copy all the ops in temp's entry block
+    # to main.
+    # 4. Erase the temp func.func.
+    temp_func = func.FuncOp(
+        "temp",
+        ir.FunctionType.get(ir_flat_inputs, []),
+        ip=ir.InsertionPoint.at_block_begin(module.body),
+    )
+    with ir.InsertionPoint(temp_func.add_entry_block()):
+      interpreter.run(*temp_func.arguments, enable_io_processing=False)
+      num_mutations = len(exported_program.graph_signature.buffers_to_mutate)
+      outputs = interpreter.outputs[num_mutations:]
+      func.ReturnOp(interpreter.outputs[num_mutations:])
+
+    main_func = func.FuncOp(
+        "main",
+        ir.FunctionType.get(ir_flat_inputs, [o.type for o in outputs]),
+        ip=ir.InsertionPoint.at_block_begin(module.body),
+    )
+    with ir.InsertionPoint(main_func.add_entry_block()):
+      outputs = export_utils.clone_func_body_ops(temp_func, main_func.arguments)
+      func.ReturnOp(outputs)
+
+    main_func.attributes["sym_visibility"] = ir.StringAttr.get("public")
+    temp_func.erase()
+
+    module.operation.verify()
+
+    input_signature = []
+    state_dict = {}
+
+    user_inputs_cnt = 0
+    for arg, tensor_meta, input_spec in zip(
+        export_flat_args,
+        tensor_metas,
+        exported_program.graph_signature.input_specs,
+    ):
+      # Assumption:
+      # All states comes first in the list of args, and user provided inputs
+      # comes later. Also there is no kwargs.
+      if input_spec.kind == torch.export.graph_signature.InputKind.USER_INPUT:
+        input_signature.append(
+            VariableSignature(
+                tensor_meta.shape,
+                tensor_meta.dtype,
+                input_spec=InputSpec.user_input(user_inputs_cnt),
+            )
+        )
+        user_inputs_cnt += 1
+      else:
+        # Parameter or constant
+        state_dict[input_spec.target] = arg
+        input_signature.append(
+            VariableSignature(
+                tensor_meta.shape,
+                tensor_meta.dtype,
+                input_spec=InputSpec.parameter(input_spec.target),
+            )
+        )
+
+    output_signature = [
+        VariableSignature(tensor_meta.shape, tensor_meta.dtype)
+        for tensor_meta in _get_output_metas(exported_program)
+    ]
+    return MlirLowered(
+        context, module, state_dict, input_signature, output_signature
+    )
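
Putting it together, exported_program_to_mlir takes a torch.export.ExportedProgram and returns the MlirLowered bundle defined above. A minimal driving sketch (the toy module is illustrative; executing the result or converting it to a flatbuffer additionally requires TensorFlow):

    import torch
    from ai_edge_torch.odml_torch import export as odml_export


    class AddOne(torch.nn.Module):

      def forward(self, x):
        return x + 1.0


    ep = torch.export.export(AddOne(), (torch.randn(2, 3),))
    lowered = odml_export.exported_program_to_mlir(ep)

    print(lowered.get_text())     # StableHLO MLIR assembly
    _ = lowered.module_bytecode   # raw MLIR bytecode of the module
    # lowered(torch.randn(2, 3))  # would run via the lazily built TF function
    # flatbuffer = lowered.to_flatbuffer()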