ai-edge-torch-nightly 0.3.0.dev20240828__py3-none-any.whl → 0.3.0.dev20240829__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of ai-edge-torch-nightly has been flagged as potentially problematic.

Files changed (45):
  1. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +6 -1
  2. ai_edge_torch/_convert/test/test_convert.py +1 -1
  3. ai_edge_torch/_convert/test/test_convert_composites.py +1 -1
  4. ai_edge_torch/_convert/test/test_convert_multisig.py +1 -1
  5. ai_edge_torch/_convert/test/test_to_channel_last_io.py +1 -1
  6. ai_edge_torch/debug/test/test_culprit.py +1 -1
  7. ai_edge_torch/debug/test/test_search_model.py +1 -1
  8. ai_edge_torch/generative/test/test_experimental_ekv.py +1 -1
  9. ai_edge_torch/generative/test/test_loader.py +1 -1
  10. ai_edge_torch/generative/test/test_model_conversion.py +1 -1
  11. ai_edge_torch/generative/test/test_quantize.py +1 -1
  12. ai_edge_torch/hlfb/test/test_mark_pattern.py +1 -1
  13. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +1 -1
  14. ai_edge_torch/lowertools/odml_torch_utils.py +5 -1
  15. ai_edge_torch/lowertools/test_utils.py +1 -1
  16. ai_edge_torch/odml_torch/__init__.py +20 -0
  17. ai_edge_torch/odml_torch/_torch_future.py +61 -0
  18. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  19. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  20. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  21. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  22. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  23. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  24. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  25. ai_edge_torch/odml_torch/export.py +320 -0
  26. ai_edge_torch/odml_torch/export_utils.py +168 -0
  27. ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
  28. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +152 -0
  29. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  30. ai_edge_torch/odml_torch/lowerings/__init__.py +24 -0
  31. ai_edge_torch/odml_torch/lowerings/_basic.py +204 -0
  32. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  33. ai_edge_torch/odml_torch/lowerings/_convolution.py +119 -0
  34. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -0
  35. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  36. ai_edge_torch/odml_torch/lowerings/registry.py +87 -0
  37. ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
  38. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  39. ai_edge_torch/odml_torch/tf_integration.py +194 -0
  40. ai_edge_torch/version.py +1 -1
  41. {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/METADATA +1 -1
  42. {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/RECORD +45 -21
  43. {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/LICENSE +0 -0
  44. {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/WHEEL +0 -0
  45. {ai_edge_torch_nightly-0.3.0.dev20240828.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/top_level.txt +0 -0
ai_edge_torch/odml_torch/export.py
@@ -0,0 +1,320 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""APIs to convert and lower a PyTorch ExportedProgram to MLIR."""
+
+import dataclasses
+import enum
+import io
+import operator
+from typing import Any, Callable, Optional
+
+from jax.lib import xla_extension
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import func
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import torch
+import torch.utils._pytree as pytree
+
+from . import _torch_future
+from . import debuginfo
+from . import export_utils
+from . import lowerings
+
+LoweringContext = lowerings.context.LoweringContext
+
+
+def _build_flat_inputs(
+    ctx: ir.Context, exported_program: torch.export.ExportedProgram
+):
+  """Build flattened inputs and metadata from exported program's signature."""
+  placeholder_nodes = [
+      n for n in exported_program.graph.nodes if n.op == "placeholder"
+  ]
+  export_flat_args = _torch_future.graph_module_flat_inputs(
+      exported_program, *exported_program.example_inputs
+  )
+
+  ir_inputs = []
+  tensor_metas = []
+  for node, arg in zip(placeholder_nodes, export_flat_args):
+    tensor_meta = node.meta.get("tensor_meta")
+    if tensor_meta is None:
+      raise RuntimeError(f"{type(arg)} (for {node.name}) is not a tensor")
+
+    tensor_metas.append(tensor_meta)
+    # Assume all dynamic dimensions are unbounded.
+    # TODO: Add checks for ep.range_constraints in MLIR.
+    shape = tuple(
+        export_utils.IR_DYNAMIC if export_utils.is_torch_dynamic(s) else s
+        for s in tensor_meta.shape
+    )
+    ir_inputs.append(
+        ir.RankedTensorType.get(
+            shape,
+            export_utils.torch_dtype_to_ir_element_type(ctx, tensor_meta.dtype),
+        )
+    )
+  return tuple(ir_inputs), tuple(export_flat_args), tuple(tensor_metas)
+
+
+def _get_output_metas(exported_program: torch.export.ExportedProgram):
+  """Get the output node's tensor_meta from the exported program."""
+  outputs = [n for n in exported_program.graph.nodes if n.op == "output"]
+  assert len(outputs) == 1
+  outputs, _ = pytree.tree_flatten(outputs[0].args[0])
+  assert all(isinstance(output, torch.fx.Node) for output in outputs)
+  return tuple(output.meta["tensor_meta"] for output in outputs)
+
+
+class LoweringInterpreter(torch.fx.Interpreter):
+  """An FX interpreter that invokes the corresponding lowering for each PyTorch op in the graph."""
+
+  def __init__(self, module: torch.fx.GraphModule, lctx: LoweringContext):
+    super().__init__(module)
+    self.lctx = lctx
+    self.outputs = None
+
+  def _build_loc(self, node: torch.fx.Node):
+
+    info = debuginfo.build_mlir_debuginfo(node)
+    if info is None:
+      return ir.Location.unknown()
+
+    return ir.Location.name(name=info)
+
+  def run_node(self, node: torch.fx.Node):
+    loc = self._build_loc(node)
+    with loc:
+      self.lctx = self.lctx.replace(ir_location=loc, node=node)
+      res = super().run_node(node)
+      self.lctx = self.lctx.replace(ir_location=None, node=None)
+    return res
+
+  def call_function(self, target, args, kwargs):
+    if target is operator.getitem:
+      return super().call_function(target, args, kwargs)
+
+    if hasattr(target, "_schema"):
+      new_args = []
+      for arg, spec in zip(args, target._schema.arguments):
+        if isinstance(spec.type, torch.TensorType):
+          if isinstance(arg, int):
+            arg = lowerings.utils.splat(arg, ir.IntegerType.get_signless(32))
+          elif isinstance(arg, float):
+            arg = lowerings.utils.splat(arg, ir.F32Type.get())
+
+        new_args.append(arg)
+      args = tuple(new_args)
+
+    lowering = lowerings.lookup(target)
+    if lowering is None:
+      raise RuntimeError(f"Lowering not found: {target}")
+    return lowering(self.lctx, *args, **kwargs)
+
+  def output(self, target, args, kwargs):
+    flat_outputs = pytree.tree_flatten(args[0])[0]
+    self.outputs = flat_outputs
+
+
+@dataclasses.dataclass
+class InputSpec:
+
+  class VariableType(enum.Enum):
+    USER_INPUT = "user_input"
+    PARAMETER = "parameter"
+
+  type_: VariableType
+  i: int = -1
+  name: str = ""
+
+  @classmethod
+  def parameter(cls, name: str):
+    return cls(type_=cls.VariableType.PARAMETER, name=name)
+
+  @classmethod
+  def user_input(cls, i: int):
+    return cls(type_=cls.VariableType.USER_INPUT, i=i)
+
+  @property
+  def is_parameter(self):
+    return self.type_ == self.VariableType.PARAMETER
+
+  @property
+  def is_user_input(self):
+    return self.type_ == self.VariableType.USER_INPUT
+
+
+@dataclasses.dataclass
+class VariableSignature:  # either argument or parameters
+  shape: list[int]
+  dtype: str
+  input_spec: Optional[InputSpec] = None
+
+
+@dataclasses.dataclass
+class MlirLowered:
+  """The lowered MLIR module, its metadata, and the weight tensors from the exported program."""
+
+  ctx: ir.Context
+  module: ir.Module
+  state_dict: dict[str, torch.Tensor]
+  input_signature: list[VariableSignature]
+  output_signature: list[VariableSignature]
+
+  _tf_function: Optional[Callable[..., Any]] = None
+
+  def __str__(self):
+    return str(self.get_text(enable_debug_info=False))
+
+  def __repr__(self):
+    return str(self.get_text(enable_debug_info=False))
+
+  def get_text(self, enable_debug_info=False):
+    return str(
+        self.module.operation.get_asm(enable_debug_info=enable_debug_info)
+    )
+
+  @property
+  def module_bytecode(self) -> bytes:
+    output = io.BytesIO()
+    self.module.operation.write_bytecode(file=output)
+    return output.getvalue()
+
+  @property
+  def module_bytecode_vhlo(self) -> bytes:
+    # HACK: In OSS, we use the MLIR pybinding and StableHLO dialect from
+    # JAX's build, which may not have the same StableHLO version as the
+    # TFLite converter's. Therefore we always serialize the MLIR module in VHLO.
+    # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
+    target_version = stablehlo.get_minimum_version()
+    module_bytecode = xla_extension.mlir.serialize_portable_artifact(
+        self.module_bytecode, target_version
+    )
+    return module_bytecode
+
+  @property
+  def tf_function(self):
+    # Lazy import
+    from . import tf_integration
+
+    if self._tf_function is None:
+      self._tf_function = tf_integration.mlir_to_tf_function(self)
+    return self._tf_function
+
+  def __call__(self, *args):
+    # Lazy importing TF when execution is needed.
+    return self.tf_function(*args)
+
+  def to_flatbuffer(self):
+    from . import tf_integration
+
+    return tf_integration.mlir_to_flatbuffer(self)
+
+
+def exported_program_to_mlir(
+    exported_program: torch.export.ExportedProgram,
+) -> MlirLowered:
+  """Lower the exported program to MLIR."""
+  if torch.__version__ >= "2.2":
+    # torch version 2.1 didn't expose this yet
+    exported_program = exported_program.run_decompositions()
+  exported_program = exported_program.run_decompositions(
+      lowerings.decompositions()
+  )
+
+  with export_utils.create_ir_context() as context, ir.Location.unknown():
+
+    module = ir.Module.create()
+    lctx = LoweringContext(context, module)
+    interpreter = LoweringInterpreter(exported_program.graph_module, lctx)
+    ir_flat_inputs, export_flat_args, tensor_metas = _build_flat_inputs(
+        context, exported_program
+    )
+
+    # HACK: The OSS MLIR pybinding could mysteriously transform a func.func
+    # under construction into a func.return op after ir.Module.parse(..) is
+    # called in the context, which happens in the JAX bridge. This is a bug
+    # in the MLIR pybinding.
+    # Workaround steps:
+    # 1. Create a temp func.func.
+    # 2. Create and insert ops into temp's entry block. During the process
+    #    the temp func.func would be broken, but the ops in the block are fine.
+    # 3. Create the main func.func and copy all the ops in temp's entry block
+    #    to main.
+    # 4. Erase the temp func.func.
+    temp_func = func.FuncOp(
+        "temp",
+        ir.FunctionType.get(ir_flat_inputs, []),
+        ip=ir.InsertionPoint.at_block_begin(module.body),
+    )
+    with ir.InsertionPoint(temp_func.add_entry_block()):
+      interpreter.run(*temp_func.arguments, enable_io_processing=False)
+      num_mutations = len(exported_program.graph_signature.buffers_to_mutate)
+      outputs = interpreter.outputs[num_mutations:]
+      func.ReturnOp(outputs)
+
+    main_func = func.FuncOp(
+        "main",
+        ir.FunctionType.get(ir_flat_inputs, [o.type for o in outputs]),
+        ip=ir.InsertionPoint.at_block_begin(module.body),
+    )
+    with ir.InsertionPoint(main_func.add_entry_block()):
+      outputs = export_utils.clone_func_body_ops(temp_func, main_func.arguments)
+      func.ReturnOp(outputs)
+
+    main_func.attributes["sym_visibility"] = ir.StringAttr.get("public")
+    temp_func.erase()
+
+    module.operation.verify()
+
+    input_signature = []
+    state_dict = {}
+
+    user_inputs_cnt = 0
+    for arg, tensor_meta, input_spec in zip(
+        export_flat_args,
+        tensor_metas,
+        exported_program.graph_signature.input_specs,
+    ):
+      # Assumption:
+      # All states come first in the list of args, and user-provided inputs
+      # come later. Also, there are no kwargs.
+      if input_spec.kind == torch.export.graph_signature.InputKind.USER_INPUT:
+        input_signature.append(
+            VariableSignature(
+                tensor_meta.shape,
+                tensor_meta.dtype,
+                input_spec=InputSpec.user_input(user_inputs_cnt),
+            )
+        )
+        user_inputs_cnt += 1
+      else:
+        # Parameter or constant
+        state_dict[input_spec.target] = arg
+        input_signature.append(
+            VariableSignature(
+                tensor_meta.shape,
+                tensor_meta.dtype,
+                input_spec=InputSpec.parameter(input_spec.target),
+            )
+        )
+
+    output_signature = [
+        VariableSignature(tensor_meta.shape, tensor_meta.dtype)
+        for tensor_meta in _get_output_metas(exported_program)
+    ]
+    return MlirLowered(
+        context, module, state_dict, input_signature, output_signature
+    )
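
For orientation, here is a minimal sketch of how this new entry point could be driven end to end. It is illustrative only: the toy model, shapes, and variable names are not part of this diff, and it assumes the nightly wheel with the odml_torch package is installed.

import torch
from ai_edge_torch.odml_torch import export


class TinyModel(torch.nn.Module):
  # Any torch.export-able module would do here.

  def forward(self, x):
    return torch.nn.functional.relu(x) + 1.0


# torch.export produces the ExportedProgram that exported_program_to_mlir
# lowers to a StableHLO MLIR module.
exported = torch.export.export(TinyModel().eval(), (torch.randn(2, 3),))
lowered = export.exported_program_to_mlir(exported)

print(lowered.get_text())             # StableHLO assembly of the lowered module
vhlo = lowered.module_bytecode_vhlo   # VHLO bytecode for the TFLite converter
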
ai_edge_torch/odml_torch/export_utils.py
@@ -0,0 +1,168 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for ODML Torch export."""
+
+import functools
+import re
+from typing import Sequence, cast
+import jax._src.interpreters.mlir
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import func
+import torch
+
+# std::numeric_limits<int64_t>::min()
+IR_DYNAMIC = -9223372036854775808
+
+
+def is_ir_dynamic(v):
+  return v == IR_DYNAMIC
+
+
+def is_torch_dynamic(v):
+  return isinstance(v, torch.SymInt)
+
+
+def is_iterable(v):
+  try:
+    iter(v)
+  except TypeError:
+    return False
+  return True
+
+
+def create_ir_context():
+  # HACK: Use the ir context from JAX as a base for better stability in OSS.
+  # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
+  context = jax._src.interpreters.mlir.make_ir_context()
+  context.allow_unregistered_dialects = True
+
+  return context
+
+
+def inline(
+    symbol_table: ir.SymbolTable,
+    block: ir.Block,
+):
+  """Recursively inlines all func.call ops in the block.
+
+  The symbol_table must include all func.func called by func.call ops.
+  This inliner is implemented in Python because the MLIR inline pass from
+  JAX's MLIR pybinding build in OSS cannot properly inline func.call ops.
+  """
+  while True:
+    is_changed = False
+    for op in block.operations:
+      if op.OPERATION_NAME != func.CallOp.OPERATION_NAME:
+        continue
+
+      call_op = cast(func.CallOp, op)
+      func_op = cast(func.FuncOp, symbol_table[call_op.callee.value])
+      with ir.InsertionPoint(op):
+        new_results = clone_func_body_ops(func_op, call_op.operands)
+
+      for old_result, new_result in zip(call_op.results, new_results):
+        old_result = cast(ir.Value, old_result)
+        old_result.replace_all_uses_with(new_result)
+      call_op.erase()
+      is_changed = True
+
+    if not is_changed:
+      break
+
+  for op in block.operations:
+    for region in op.regions:
+      for block in region.blocks:
+        inline(symbol_table, block)
+
+
+def clone_func_body_ops(func_op: func.FuncOp, ir_inputs: Sequence[ir.Value]):
+  """Clone the ops in func_op's body one by one into the current insertion point."""
+  func_args = list(func_op.arguments)
+  ir_inputs = list(ir_inputs)
+  assert len(func_args) == len(ir_inputs)
+
+  value_mapping = {arg: ir_input for arg, ir_input in zip(func_args, ir_inputs)}
+
+  for op in list(func_op.entry_block.operations):
+    cloned_operands = [value_mapping[val] for val in op.operands]
+    if op.OPERATION_NAME == func.ReturnOp.OPERATION_NAME:
+      return cloned_operands
+
+    cloned = cast(ir.Operation, op.operation.clone())
+
+    for i in range(len(op.operands)):
+      cloned.operands[i] = cloned_operands[i]
+
+    for i in range(len(op.results)):
+      value_mapping[op.results[i]] = cloned.results[i]
+
+  return []
+
+
+def sanitize_aten_op_name(op, chars=":."):
+  return re.sub("[{}]".format(chars), "_", str(op))
+
+
+def build_ir_attr(val):
+  if val is None:
+    return ir.StringAttr.get("py_None")
+  if isinstance(val, bool):
+    return ir.BoolAttr.get(val)
+  if isinstance(val, int):
+    return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), val)
+  if isinstance(val, float):
+    return ir.FloatAttr.get(ir.F32Type.get(), val)
+  if isinstance(val, str):
+    return ir.StringAttr.get(val)
+  if isinstance(val, dict):
+    return ir.DictAttr.get({k: build_ir_attr(v) for k, v in val.items()})
+  if isinstance(val, (list, tuple)):
+    return ir.ArrayAttr.get([build_ir_attr(v) for v in val])
+
+  # Stringify the value to a StringAttr by default
+  return ir.StringAttr.get(str(val))
+
+
+def torch_dtype_to_ir_element_type(ctx, dtype):
+  ty_get = {
+      torch.double: ir.F64Type.get,
+      torch.float32: ir.F32Type.get,
+      torch.half: ir.F16Type.get,
+      torch.long: functools.partial(ir.IntegerType.get_signless, 64),
+      torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
+      torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
+      torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
+  }.get(dtype)
+  return ty_get(ctx)
+
+
+def ir_element_type_to_torch_dtype(ty):
+  if isinstance(ty, ir.F32Type):
+    return torch.float32
+  if isinstance(ty, ir.F64Type):
+    return torch.float64
+  if isinstance(ty, ir.F16Type):
+    return torch.half
+  if isinstance(ty, ir.IntegerType):
+    if ty.is_signless:
+      if ty.width == 64:
+        return torch.long
+      if ty.width == 32:
+        return torch.int32
+      if ty.width == 16:
+        return torch.int16
+      if ty.width == 1:
+        return torch.bool
+  raise RuntimeError(f"Unsupported ir element type: {ty}")
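
The helpers above are small and self-contained. Here is a short sketch of the dtype and attribute utilities; the values and names below are illustrative, not from this diff:

import torch
from jax._src.lib.mlir import ir
from ai_edge_torch.odml_torch import export_utils

ctx = export_utils.create_ir_context()
with ctx, ir.Location.unknown():
  # torch dtype -> MLIR element type, and back.
  f32 = export_utils.torch_dtype_to_ir_element_type(ctx, torch.float32)
  assert export_utils.ir_element_type_to_torch_dtype(f32) == torch.float32

  # Arbitrary Python values become MLIR attributes; unsupported types fall
  # back to their string representation.
  attr = export_utils.build_ir_attr({"approximate": "tanh", "inplace": False})
  print(attr)
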
ai_edge_torch/odml_torch/jax_bridge/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from ai_edge_torch.odml_torch.jax_bridge._wrap import wrap
ai_edge_torch/odml_torch/jax_bridge/_wrap.py
@@ -0,0 +1,152 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""APIs to wrap JAX functions for use in ODML Torch lowerings."""
+
+import functools
+import inspect
+from typing import Any, Callable, Optional, cast
+import uuid
+from ai_edge_torch.odml_torch import export_utils
+from ai_edge_torch.odml_torch import passes
+from ai_edge_torch.odml_torch.jax_bridge import utils
+import jax
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import func
+import torch.utils._pytree as pytree
+
+# JAX double (64-bit) precision is required to generate StableHLO MLIR with
+# i64/f64 tensors from JAX-bridged lowerings. If not set properly, all the
+# 64-bit tensors would be truncated to 32-bit dtypes and potentially break
+# the lowering.
+jax.config.update("jax_enable_x64", True)
+
+
+def _lower_to_ir_text(
+    jaxfn, args, kwargs, ir_input_names: Optional[list[str]] = None
+) -> tuple[str, list[ir.Value]]:
+  args = utils.tree_map_list_to_tuple(args)
+  kwargs = utils.tree_map_list_to_tuple(kwargs)
+
+  names_args = [
+      *zip(inspect.signature(jaxfn).parameters.keys(), args),
+      *kwargs.items(),
+  ]
+
+  static_argnames = []
+  jax_lower_static_kwargs = {}
+  jax_lower_args = []
+  jax_lower_argnames = []
+  ir_inputs = []
+
+  for i, (name, arg) in enumerate(names_args):
+    is_positional = i < len(args)
+    if not utils.is_ir_variable(arg):
+      static_argnames.append(name)
+      jax_lower_static_kwargs[name] = arg
+    else:
+      # Enforce that the arg order in the MLIR matches the lowering func.
+      jax_lower_args.append(utils.ir_variable_to_jax(arg))
+
+      if is_positional and len(jax_lower_args) == i + 1:
+        # The first N contiguous tensor args are passed to the lowering func
+        # as positional args when they are also passed to the bridged func
+        # as positional args.
+        jax_lower_argnames.append(None)
+      else:
+        # Otherwise pass the arg to the lowering func as a keyword arg.
+        jax_lower_argnames.append(name)
+
+      if ir_input_names is None or name in ir_input_names:
+        # An ir variable can be a nested tuple, while mlir args should be flat.
+        ir_inputs += [
+            x for x in pytree.tree_flatten(arg)[0] if isinstance(x, ir.Value)
+        ]
+
+  def new_lowering(*args, **jax_lower_static_kwargs):
+    jaxfn_args = []
+    jaxfn_kwargs = jax_lower_static_kwargs.copy()
+    for name, arg in zip(jax_lower_argnames, args):
+      if name is None:
+        jaxfn_args.append(arg)
+      else:
+        jaxfn_kwargs[name] = arg
+
+    return jaxfn(*jaxfn_args, **jaxfn_kwargs)
+
+  return (
+      jax.jit(new_lowering, static_argnames=static_argnames)
+      .lower(*jax_lower_args, **jax_lower_static_kwargs)
+      .as_text()
+  ), ir_inputs
+
+
+def wrap(jaxfn: Callable[..., Any], ir_input_names: Optional[list[str]] = None):
+  """Returns a wrapped JAX function to be used in ODML Torch lowerings.
+
+  If the given jaxfn has signature `jaxfn(*args, **kwargs) -> return`, the
+  wrapped function would:
+  - Have signature `wrapped(lctx: odml_torch.export.LoweringContext, *args,
+    **kwargs) -> return`.
+  - Accept mlir.ir.Value for all params expecting jax.Array as inputs.
+  - Return mlir.ir.Value for all jax.Array outputs from jaxfn.
+
+  Args:
+    jaxfn: The JAX function to be wrapped.
+    ir_input_names: The input (param) names of the JAX function to be used in
+      the MLIR lowering. This is useful when the JAX impl only depends on
+      specific inputs to the function. If not specified, all ir.Value passed
+      to the wrapped function are assumed to be used in the lowering.
+  """
+
+  @functools.wraps(jaxfn)
+  def wrapped(lctx, *args, **kwargs):
+
+    ir_text, ir_inputs = _lower_to_ir_text(
+        jaxfn,
+        args,
+        kwargs,
+        ir_input_names=ir_input_names,
+    )
+
+    module = ir.Module.parse(ir_text)
+    passes.strip_debuginfo(module)
+
+    symbol_table = ir.SymbolTable(module.operation)
+    main_func = symbol_table["main"]
+
+    with ir.InsertionPoint.at_block_begin(lctx.ir_module.body):
+      cloned_func = cast(func.FuncOp, main_func.clone())
+      cloned_func_name = f"{jaxfn.__name__}_{uuid.uuid4().hex[:8]}"
+      cloned_func.attributes["sym_name"] = ir.StringAttr.get(cloned_func_name)
+      cloned_func.attributes["sym_visibility"] = ir.StringAttr.get("private")
+
+    # HACK: Use the custom inliner implemented in Python because the MLIR
+    # inline pass from JAX's MLIR pybinding build in OSS cannot properly
+    # inline func.call ops.
+    # This should be switched to `passes.inline(module)` when we have our
+    # own MLIR pybinding build.
+    export_utils.inline(symbol_table, cloned_func.entry_block)
+
+    if not cloned_func.arguments:
+      # Known edge case: the lowering depends not on the input values but
+      # only on input metadata, such as shape or dtype.
+      ir_inputs = []
+
+    results = func.CallOp(cloned_func, ir_inputs).results
+    if len(results) == 1:
+      return results[0]
+    return results
+
+  return wrapped
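
`wrap` is what lets a lowering be written as a plain JAX function. A sketch of the intended usage, based on the docstring above; the op choice and function names are illustrative, and the real registrations live under ai_edge_torch/odml_torch/lowerings/:

import jax.numpy as jnp
from ai_edge_torch.odml_torch import jax_bridge


def _atan2(x, y):
  # A pure JAX implementation of the op being lowered; jax.jit traces it
  # to StableHLO under the hood.
  return jnp.arctan2(x, y)


# The wrapped callable takes (lctx, *mlir_values) and returns ir.Value
# results, so it can be registered as an ODML Torch lowering for an aten op.
lower_atan2 = jax_bridge.wrap(_atan2)
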