ai-edge-torch-nightly 0.3.0.dev20240828__py3-none-any.whl → 0.3.0.dev20240830__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Note: this release of ai-edge-torch-nightly has been flagged as potentially problematic.
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +6 -1
- ai_edge_torch/_convert/test/test_convert.py +1 -1
- ai_edge_torch/_convert/test/test_convert_composites.py +1 -1
- ai_edge_torch/_convert/test/test_convert_multisig.py +1 -1
- ai_edge_torch/_convert/test/test_to_channel_last_io.py +1 -1
- ai_edge_torch/debug/test/test_culprit.py +1 -1
- ai_edge_torch/debug/test/test_search_model.py +1 -1
- ai_edge_torch/generative/test/test_experimental_ekv.py +1 -1
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +1 -1
- ai_edge_torch/generative/test/test_quantize.py +1 -1
- ai_edge_torch/hlfb/test/test_mark_pattern.py +1 -1
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +1 -1
- ai_edge_torch/lowertools/odml_torch_utils.py +5 -1
- ai_edge_torch/lowertools/test_utils.py +1 -1
- ai_edge_torch/odml_torch/__init__.py +20 -0
- ai_edge_torch/odml_torch/_torch_future.py +61 -0
- ai_edge_torch/odml_torch/_torch_library.py +19 -0
- ai_edge_torch/odml_torch/composite/__init__.py +16 -0
- ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
- ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
- ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
- ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
- ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
- ai_edge_torch/odml_torch/export.py +320 -0
- ai_edge_torch/odml_torch/export_utils.py +168 -0
- ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +152 -0
- ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +24 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +204 -0
- ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +119 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -0
- ai_edge_torch/odml_torch/lowerings/context.py +42 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +87 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
- ai_edge_torch/odml_torch/passes/__init__.py +38 -0
- ai_edge_torch/odml_torch/tf_integration.py +194 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240830.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240830.dist-info}/RECORD +45 -21
- {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240830.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240830.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240830.dist-info}/top_level.txt +0 -0
ai_edge_torch/odml_torch/export.py
@@ -0,0 +1,320 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""APIs to convert and lower a PyTorch ExportedProgram to MLIR."""
+
+import dataclasses
+import enum
+import io
+import operator
+from typing import Any, Callable, Optional
+
+from jax.lib import xla_extension
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import func
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import torch
+import torch.utils._pytree as pytree
+
+from . import _torch_future
+from . import debuginfo
+from . import export_utils
+from . import lowerings
+
+LoweringContext = lowerings.context.LoweringContext
+
+
+def _build_flat_inputs(
+    ctx: ir.Context, exported_program: torch.export.ExportedProgram
+):
+  """Build flattened inputs and metadata from exported program's signature."""
+  placeholder_nodes = [
+      n for n in exported_program.graph.nodes if n.op == "placeholder"
+  ]
+  export_flat_args = _torch_future.graph_module_flat_inputs(
+      exported_program, *exported_program.example_inputs
+  )
+
+  ir_inputs = []
+  tensor_metas = []
+  for node, arg in zip(placeholder_nodes, export_flat_args):
+    tensor_meta = node.meta.get("tensor_meta")
+    if tensor_meta is None:
+      raise RuntimeError(f"{type(arg)} (for {node.name}) is not a tensor")
+
+    tensor_metas.append(tensor_meta)
+    # Assume all dynamic dimensions are unbounded.
+    # TODO: Add checks for ep.range_constraints in MLIR.
+    shape = tuple(
+        export_utils.IR_DYNAMIC if export_utils.is_torch_dynamic(s) else s
+        for s in tensor_meta.shape
+    )
+    ir_inputs.append(
+        ir.RankedTensorType.get(
+            shape,
+            export_utils.torch_dtype_to_ir_element_type(ctx, tensor_meta.dtype),
+        )
+    )
+  return tuple(ir_inputs), tuple(export_flat_args), tuple(tensor_metas)
+
+
+def _get_output_metas(exported_program: torch.export.ExportedProgram):
+  """Get the output node's tensor_meta from the exported program."""
+  outputs = [n for n in exported_program.graph.nodes if n.op == "output"]
+  assert len(outputs) == 1
+  outputs, _ = pytree.tree_flatten(outputs[0].args[0])
+  assert all(isinstance(output, torch.fx.Node) for output in outputs)
+  return tuple(output.meta["tensor_meta"] for output in outputs)
+
+
+class LoweringInterpreter(torch.fx.Interpreter):
+  """The FX interpreter to iterate and invoke corresponding lowering for each PyTorch op in the graph."""
+
+  def __init__(self, module: torch.fx.GraphModule, lctx: LoweringContext):
+    super().__init__(module)
+    self.lctx = lctx
+    self.outputs = None
+
+  def _build_loc(self, node: torch.fx.Node):
+
+    info = debuginfo.build_mlir_debuginfo(node)
+    if info is None:
+      return ir.Location.unknown()
+
+    return ir.Location.name(name=info)
+
+  def run_node(self, node: torch.fx.Node):
+    loc = self._build_loc(node)
+    with loc:
+      self.lctx = self.lctx.replace(ir_location=loc, node=node)
+      res = super().run_node(node)
+      self.lctx = self.lctx.replace(ir_location=None, node=None)
+    return res
+
+  def call_function(self, target, args, kwargs):
+    if target is operator.getitem:
+      return super().call_function(target, args, kwargs)
+
+    if hasattr(target, "_schema"):
+      new_args = []
+      for arg, spec in zip(args, target._schema.arguments):
+        if isinstance(spec.type, torch.TensorType):
+          if isinstance(arg, int):
+            arg = lowerings.utils.splat(arg, ir.IntegerType.get_signless(32))
+          elif isinstance(arg, float):
+            arg = lowerings.utils.splat(arg, ir.F32Type.get())
+
+        new_args.append(arg)
+      args = tuple(new_args)
+
+    lowering = lowerings.lookup(target)
+    if lowering is None:
+      raise RuntimeError(f"Lowering not found: {target}")
+    return lowering(self.lctx, *args, **kwargs)
+
+  def output(self, target, args, kwargs):
+    flat_outputs = pytree.tree_flatten(args[0])[0]
+    self.outputs = flat_outputs
+
+
+@dataclasses.dataclass
+class InputSpec:
+
+  class VariableType(enum.Enum):
+    USER_INPUT = "user_input"
+    PARAMETER = "parameter"
+
+  type_: VariableType
+  i: int = -1
+  name: str = ""
+
+  @classmethod
+  def parameter(cls, name: str):
+    return cls(type_=cls.VariableType.PARAMETER, name=name)
+
+  @classmethod
+  def user_input(cls, i: int):
+    return cls(type_=cls.VariableType.USER_INPUT, i=i)
+
+  @property
+  def is_parameter(self):
+    return self.type_ == self.VariableType.PARAMETER
+
+  @property
+  def is_user_input(self):
+    return self.type_ == self.VariableType.USER_INPUT
+
+
+@dataclasses.dataclass
+class VariableSignature:  # either argument or parameters
+  shape: list[int]
+  dtype: str
+  input_spec: InputSpec = None
+
+
+@dataclasses.dataclass
+class MlirLowered:
+  """The lowered MLIR module, metadata, and weight tensors bundle from exported program."""
+
+  ctx: ir.Context
+  module: ir.Module
+  state_dict: dict[str, torch.Tensor]
+  input_signature: list[VariableSignature]
+  output_signature: list[VariableSignature]
+
+  _tf_function: Optional[Callable[Any, Any]] = None
+
+  def __str__(self):
+    return str(self.get_text(enable_debug_info=False))
+
+  def __repr__(self):
+    return str(self.get_text(enable_debug_info=False))
+
+  def get_text(self, enable_debug_info=False):
+    return str(
+        self.module.operation.get_asm(enable_debug_info=enable_debug_info)
+    )
+
+  @property
+  def module_bytecode(self) -> bytes:
+    output = io.BytesIO()
+    self.module.operation.write_bytecode(file=output)
+    return output.getvalue()
+
+  @property
+  def module_bytecode_vhlo(self) -> bytes:
+    # HACK: In OSS, we use MLIR pybinding and StableHLO dialect from JAX's
+    # build, which may not have the same StableHLO version as what used in
+    # TFLite converter. Therefore we always serialize MLIR module in VHLO.
+    # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
+    target_version = stablehlo.get_minimum_version()
+    module_bytecode = xla_extension.mlir.serialize_portable_artifact(
+        self.module_bytecode, target_version
+    )
+    return module_bytecode
+
+  @property
+  def tf_function(self):
+    # Lazy import
+    from . import tf_integration
+
+    if self._tf_function is None:
+      self._tf_function = tf_integration.mlir_to_tf_function(self)
+    return self._tf_function
+
+  def __call__(self, *args):
+    # Lazy importing TF when execution is needed.
+    return self.tf_function(*args)
+
+  def to_flatbuffer(self):
+    from . import tf_integration
+
+    return tf_integration.mlir_to_flatbuffer(self)
+
+
+def exported_program_to_mlir(
+    exported_program: torch.export.ExportedProgram,
+) -> MlirLowered:
+  """Lower the exported program to MLIR."""
+  if torch.__version__ >= "2.2":
+    # torch version 2.1 didn't expose this yet
+    exported_program = exported_program.run_decompositions()
+    exported_program = exported_program.run_decompositions(
+        lowerings.decompositions()
+    )
+
+  with export_utils.create_ir_context() as context, ir.Location.unknown():
+
+    module = ir.Module.create()
+    lctx = LoweringContext(context, module)
+    interpreter = LoweringInterpreter(exported_program.graph_module, lctx)
+    ir_flat_inputs, export_flat_args, tensor_metas = _build_flat_inputs(
+        context, exported_program
+    )
+
+    # HACK: OSS MLIR pybinding could mysteriously transform func.func under
+    # construction into a func.return op after calling ir.Module.parse(..)
+    # in the context, which happens in JAX bridge. This is a bug in MLIR
+    # pybinding.
+    # Workaround steps:
+    # 1. Create a temp func.func.
+    # 2. Create and insert ops to temp's entry block. During the process
+    # the temp func.func would be broken, but the ops in the block are fine.
+    # 3. Create the main func.func and copy all the ops in temp's entry block
+    # to main.
+    # 4. Erase the temp func.func.
+    temp_func = func.FuncOp(
+        "temp",
+        ir.FunctionType.get(ir_flat_inputs, []),
+        ip=ir.InsertionPoint.at_block_begin(module.body),
+    )
+    with ir.InsertionPoint(temp_func.add_entry_block()):
+      interpreter.run(*temp_func.arguments, enable_io_processing=False)
+      num_mutations = len(exported_program.graph_signature.buffers_to_mutate)
+      outputs = interpreter.outputs[num_mutations:]
+      func.ReturnOp(interpreter.outputs[num_mutations:])
+
+    main_func = func.FuncOp(
+        "main",
+        ir.FunctionType.get(ir_flat_inputs, [o.type for o in outputs]),
+        ip=ir.InsertionPoint.at_block_begin(module.body),
+    )
+    with ir.InsertionPoint(main_func.add_entry_block()):
+      outputs = export_utils.clone_func_body_ops(temp_func, main_func.arguments)
+      func.ReturnOp(outputs)
+
+    main_func.attributes["sym_visibility"] = ir.StringAttr.get("public")
+    temp_func.erase()
+
+    module.operation.verify()
+
+  input_signature = []
+  state_dict = {}
+
+  user_inputs_cnt = 0
+  for arg, tensor_meta, input_spec in zip(
+      export_flat_args,
+      tensor_metas,
+      exported_program.graph_signature.input_specs,
+  ):
+    # Assumption:
+    # All states comes first in the list of args, and user provided inputs
+    # comes later. Also there is no kwargs.
+    if input_spec.kind == torch.export.graph_signature.InputKind.USER_INPUT:
+      input_signature.append(
+          VariableSignature(
+              tensor_meta.shape,
+              tensor_meta.dtype,
+              input_spec=InputSpec.user_input(user_inputs_cnt),
+          )
+      )
+      user_inputs_cnt += 1
+    else:
+      # Parameter or constant
+      state_dict[input_spec.target] = arg
+      input_signature.append(
+          VariableSignature(
+              tensor_meta.shape,
+              tensor_meta.dtype,
+              input_spec=InputSpec.parameter(input_spec.target),
+          )
+      )
+
+  output_signature = [
+      VariableSignature(tensor_meta.shape, tensor_meta.dtype)
+      for tensor_meta in _get_output_metas(exported_program)
+  ]
+  return MlirLowered(
+      context, module, state_dict, input_signature, output_signature
+  )
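Taken together, the new `export.py` gives the package a direct ExportedProgram-to-MLIR path. A minimal usage sketch of the `exported_program_to_mlir` entry point defined above (the toy model, shapes, and prints are illustrative, not part of the release):

```python
import torch

from ai_edge_torch.odml_torch import export as odml_export


class AddOne(torch.nn.Module):
  # Toy model used only to illustrate the new API.

  def forward(self, x):
    return x + 1


# torch.export produces the ExportedProgram that exported_program_to_mlir
# then lowers to a StableHLO MLIR module.
exported = torch.export.export(AddOne().eval(), (torch.rand(4, 4),))
lowered = odml_export.exported_program_to_mlir(exported)

print(lowered.get_text())        # StableHLO assembly of the lowered module
print(lowered.output_signature)  # [VariableSignature(shape=..., dtype=...)]
```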
ai_edge_torch/odml_torch/export_utils.py
@@ -0,0 +1,168 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for ODML Torch export."""
+
+import functools
+import re
+from typing import Sequence, cast
+import jax._src.interpreters.mlir
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import func
+import torch
+
+# std::numeric_limits<int64_t>::min()
+IR_DYNAMIC = -9223372036854775808
+
+
+def is_ir_dynamic(v):
+  return v == IR_DYNAMIC
+
+
+def is_torch_dynamic(v):
+  return isinstance(v, torch.SymInt)
+
+
+def is_iterable(v):
+  try:
+    iter(v)
+  except TypeError:
+    return False
+  return True
+
+
+def create_ir_context():
+  # HACK: Use ir context from JAX as base for better stability in OSS.
+  # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
+  context = jax._src.interpreters.mlir.make_ir_context()
+  context.allow_unregistered_dialects = True
+
+  return context
+
+
+def inline(
+    symbol_table: ir.SymbolTable,
+    block: ir.Block,
+):
+  """Recursively inlines all func.call ops in the block.
+
+  The symbol_table must include all func.func called by func.call ops.
+  This inliner in Python is implemented because MLIR inline pass from JAX's
+  MLIR pybinding build in OSS cannot properly inline func.call ops.
+  """
+  while True:
+    is_changed = False
+    for op in block.operations:
+      if op.OPERATION_NAME != func.CallOp.OPERATION_NAME:
+        continue
+
+      call_op = cast(func.CallOp, op)
+      func_op = cast(func.FuncOp, symbol_table[call_op.callee.value])
+      with ir.InsertionPoint(op):
+        new_results = clone_func_body_ops(func_op, call_op.operands)
+
+      for old_result, new_result in zip(call_op.results, new_results):
+        old_result = cast(ir.Value, old_result)
+        old_result.replace_all_uses_with(new_result)
+      call_op.erase()
+      is_changed = True
+
+    if not is_changed:
+      break
+
+  for op in block.operations:
+    for region in op.regions:
+      for block in region.blocks:
+        inline(symbol_table, block)
+
+
+def clone_func_body_ops(func_op: func.FuncOp, ir_inputs: Sequence[ir.Value]):
+  """Clone operations in the func_op's body by one into the current context."""
+  func_args = list(func_op.arguments)
+  ir_inputs = list(ir_inputs)
+  assert len(func_args) == len(ir_inputs)
+
+  value_mapping = {arg: ir_input for arg, ir_input in zip(func_args, ir_inputs)}
+
+  for op in list(func_op.entry_block.operations):
+    cloned_operands = [value_mapping[val] for val in op.operands]
+    if op.OPERATION_NAME == func.ReturnOp.OPERATION_NAME:
+      return cloned_operands
+
+    cloned = cast(ir.Operation, op.operation.clone())
+
+    for i in range(len(op.operands)):
+      cloned.operands[i] = cloned_operands[i]
+
+    for i in range(len(op.results)):
+      value_mapping[op.results[i]] = cloned.results[i]
+
+  return []
+
+
+def sanitize_aten_op_name(op, chars=":."):
+  return re.sub("[{}]".format(chars), "_", str(op))
+
+
+def build_ir_attr(val):
+  if val is None:
+    return ir.StringAttr.get("py_None")
+  if isinstance(val, bool):
+    return ir.BoolAttr.get(val)
+  if isinstance(val, int):
+    return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), val)
+  if isinstance(val, float):
+    return ir.FloatAttr.get(ir.F32Type.get(), val)
+  if isinstance(val, str):
+    return ir.StringAttr.get(val)
+  if isinstance(val, dict):
+    return ir.DictAttr.get({k: build_ir_attr(v) for k, v in val.items()})
+  if isinstance(val, (list, tuple)):
+    return ir.ArrayAttr.get([build_ir_attr(v) for v in val])
+
+  # Stringify the value to a StringAttr by default
+  return ir.StringAttr.get(str(val))
+
+
+def torch_dtype_to_ir_element_type(ctx, dtype):
+  ty_get = {
+      torch.double: ir.F64Type.get,
+      torch.float32: ir.F32Type.get,
+      torch.half: ir.F16Type.get,
+      torch.long: functools.partial(ir.IntegerType.get_signless, 64),
+      torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
+      torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
+      torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
+  }.get(dtype)
+  return ty_get(ctx)
+
+
+def ir_element_type_to_torch_dtype(ty):
+  if isinstance(ty, ir.F32Type):
+    return torch.float32
+  if isinstance(ty, ir.F64Type):
+    return torch.float64
+  if isinstance(ty, ir.F16Type):
+    return torch.half
+  if isinstance(ty, ir.IntegerType):
+    if ty.is_signless:
+      if ty.width == 64:
+        return torch.long
+      if ty.width == 32:
+        return torch.int32
+      if ty.width == 16:
+        return torch.int16
+      if ty.width == 1:
+        return torch.bool
+  raise RuntimeError(f"Unsupported ir element type: {ty}")
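The dtype helpers in `export_utils.py` are intended to round-trip between torch dtypes and MLIR element types. A small sketch of that property (assuming a JAX install whose MLIR pybindings match the `jax._src` imports above):

```python
import torch

from ai_edge_torch.odml_torch import export_utils

# create_ir_context() returns a JAX-based MLIR context; the element types
# are created against it explicitly, so no `with` block is needed here.
ctx = export_utils.create_ir_context()

for dtype in (torch.float32, torch.long, torch.bool):
  ty = export_utils.torch_dtype_to_ir_element_type(ctx, dtype)
  assert export_utils.ir_element_type_to_torch_dtype(ty) == dtype
```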
ai_edge_torch/odml_torch/jax_bridge/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from ai_edge_torch.odml_torch.jax_bridge._wrap import wrap
ai_edge_torch/odml_torch/jax_bridge/_wrap.py
@@ -0,0 +1,152 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""APIs to wrap JAX functions for using in ODML Torch lowerings."""
+
+import functools
+import inspect
+from typing import Any, Callable, cast
+import uuid
+from ai_edge_torch.odml_torch import export_utils
+from ai_edge_torch.odml_torch import passes
+from ai_edge_torch.odml_torch.jax_bridge import utils
+import jax
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import func
+import torch.utils._pytree as pytree
+
+# Jax double (64bit) precision is required to generate StableHLO mlir with
+# i64/f64 tensors from Jax bridged lowerings. If not set properly, all the
+# 64bit tensors would be truncated to 32bit dtype and potentially break the
+# lowering.
+jax.config.update("jax_enable_x64", True)
+
+
+def _lower_to_ir_text(
+    jaxfn, args, kwargs, ir_input_names: list[str] = None
+) -> str:
+  args = utils.tree_map_list_to_tuple(args)
+  kwargs = utils.tree_map_list_to_tuple(kwargs)
+
+  names_args = [
+      *zip(inspect.signature(jaxfn).parameters.keys(), args),
+      *kwargs.items(),
+  ]
+
+  static_argnames = []
+  jax_lower_static_kwargs = {}
+  jax_lower_args = []
+  jax_lower_argnames = []
+  ir_inputs = []
+
+  for i, (name, arg) in enumerate(names_args):
+    is_positional = i < len(args)
+    if not utils.is_ir_variable(arg):
+      static_argnames.append(name)
+      jax_lower_static_kwargs[name] = arg
+    else:
+      # Enforce the arg order in the mlir is the same as the lowering func
+      jax_lower_args.append(utils.ir_variable_to_jax(arg))
+
+      if is_positional and len(jax_lower_args) == i + 1:
+        # The first N continuous tensor args are passed to the lowering func
+        # as positional args, when they passed to the bridged func as
+        # positional args also.
+        jax_lower_argnames.append(None)
+      else:
+        # Otherwise pass the arg to the lowering func as keyword arg.
+        jax_lower_argnames.append(name)
+
+      if ir_input_names is None or name in ir_input_names:
+        # ir variable can be a nested tuple, while mlir args should be flat.
+        ir_inputs += [
+            x for x in pytree.tree_flatten(arg)[0] if isinstance(x, ir.Value)
+        ]
+
+  def new_lowering(*args, **jax_lower_static_kwargs):
+    jaxfn_args = []
+    jaxfn_kwargs = jax_lower_static_kwargs.copy()
+    for name, arg in zip(jax_lower_argnames, args):
+      if name is None:
+        jaxfn_args.append(arg)
+      else:
+        jaxfn_kwargs[name] = arg
+
+    return jaxfn(*jaxfn_args, **jaxfn_kwargs)
+
+  return (
+      jax.jit(new_lowering, static_argnames=static_argnames)
+      .lower(*jax_lower_args, **jax_lower_static_kwargs)
+      .as_text()
+  ), ir_inputs
+
+
+def wrap(jaxfn: Callable[Any, Any], ir_input_names: list[str] = None):
+  """Return the wrapped JAX function to be used in ODMLTorch lowerings.
+
+  If the given jaxfn has signature `jaxfn(*args, **kwargs) -> return`, the
+  wrapped function would:
+  - Have signature `wrapped(lctx: odml_torch.export.LoweringContext, *args,
+    **kwargs) -> return`.
+  - Accept mlir.ir.Value for all params expecting jax.Array as inputs.
+  - Return mlir.ir.Value for all jax.Array outputs from jaxfn.
+
+  Args:
+    jaxfn: The JAX function to be wrapped.
+    ir_input_names: The input (param) names of the JAX function to be used in
+      the MLIR lowering. This is useful when the JAX impl only depends on
+      specific inputs to the function. If not specified, all ir.Value passed to
+      the wrapped function are assumed to be used in the lowering.
+  """
+
+  @functools.wraps(jaxfn)
+  def wrapped(lctx, *args, **kwargs):
+
+    ir_text, ir_inputs = _lower_to_ir_text(
+        jaxfn,
+        args,
+        kwargs,
+        ir_input_names=ir_input_names,
+    )
+
+    module = ir.Module.parse(ir_text)
+    passes.strip_debuginfo(module)
+
+    symbol_table = ir.SymbolTable(module.operation)
+    main_func = symbol_table["main"]
+
+    with ir.InsertionPoint.at_block_begin(lctx.ir_module.body):
+      cloned_func = cast(func.FuncOp, main_func.clone())
+      cloned_func_name = f"{jaxfn.__name__}_{uuid.uuid4().hex[:8]}"
+      cloned_func.attributes["sym_name"] = ir.StringAttr.get(cloned_func_name)
+      cloned_func.attributes["sym_visibility"] = ir.StringAttr.get("private")
+
+    # HACK: Use the custom inliner implemented in Python because MLIR inline
+    # pass from JAX's MLIR pybinding build in OSS cannot properly inline
+    # func.call ops.
+    # This should be switched to `passes.inline(module)` when we have our own
+    # MLIR pybinding build.
+    export_utils.inline(symbol_table, cloned_func.entry_block)
+
+    if not cloned_func.arguments:
+      # Known edge case: when the lowering does not depend on input but
+      # just the meta of input like shape or dtype.
+      ir_inputs = []
+
+    results = func.CallOp(cloned_func, ir_inputs).results
+    if len(results) == 1:
+      return results[0]
+    return results
+
+  return wrapped
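`jax_bridge.wrap` is the mechanism behind the new `_jax_lowerings.py`: an op lowering is written as a plain JAX function and wrapped into one that consumes and produces `mlir.ir.Value`. A hedged sketch of the pattern (the softplus body is illustrative only; the actual registrations live in `ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py`):

```python
import jax.numpy as jnp

from ai_edge_torch.odml_torch import jax_bridge


def _softplus(x):
  # Lowering body written in ordinary JAX; chosen here purely as an example.
  return jnp.logaddexp(x, 0.0)


# Per the wrap() docstring above, the returned function has signature
# (lctx, *args, **kwargs): it accepts mlir.ir.Value where _softplus expects
# jax.Array, traces the JAX body to StableHLO, and splices the result into
# lctx.ir_module as a private func that is then called in place.
lowered_softplus = jax_bridge.wrap(_softplus)
```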