ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl → 0.3.0.dev20240926__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (169) hide show
  1. ai_edge_torch/__init__.py +5 -4
  2. ai_edge_torch/_convert/conversion.py +112 -0
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +94 -48
  5. ai_edge_torch/_convert/fx_passes/__init__.py +22 -0
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +107 -44
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +23 -20
  8. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +5 -6
  9. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -1
  10. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +39 -9
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +17 -8
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +9 -8
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +31 -18
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +2 -2
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +34 -24
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  19. ai_edge_torch/_convert/signature.py +66 -0
  20. ai_edge_torch/_convert/test/test_convert.py +495 -0
  21. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  22. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  23. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -5
  24. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +10 -3
  25. ai_edge_torch/config.py +27 -0
  26. ai_edge_torch/conftest.py +20 -0
  27. ai_edge_torch/debug/culprit.py +72 -40
  28. ai_edge_torch/debug/test/test_culprit.py +7 -5
  29. ai_edge_torch/debug/test/test_search_model.py +8 -7
  30. ai_edge_torch/debug/utils.py +14 -3
  31. ai_edge_torch/fx_pass_base.py +101 -0
  32. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +68 -0
  33. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +68 -0
  34. ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +69 -55
  35. ai_edge_torch/generative/examples/gemma/gemma2.py +267 -0
  36. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  37. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +57 -0
  38. ai_edge_torch/generative/examples/gemma/verify_util.py +143 -0
  39. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +68 -0
  40. ai_edge_torch/generative/examples/openelm/openelm.py +206 -0
  41. ai_edge_torch/generative/examples/openelm/verify.py +64 -0
  42. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  43. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  44. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +68 -0
  45. ai_edge_torch/generative/examples/{phi2 → phi}/phi2.py +70 -51
  46. ai_edge_torch/generative/examples/phi/phi3.py +286 -0
  47. ai_edge_torch/generative/examples/phi/verify.py +65 -0
  48. ai_edge_torch/generative/examples/phi/verify_phi3.py +70 -0
  49. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  50. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +68 -0
  51. ai_edge_torch/generative/examples/smollm/smollm.py +101 -0
  52. ai_edge_torch/generative/examples/smollm/verify.py +62 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  54. ai_edge_torch/generative/examples/stable_diffusion/clip.py +83 -13
  55. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +27 -14
  56. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +74 -9
  57. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +179 -37
  58. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  59. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +83 -58
  60. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  61. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  62. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  63. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  64. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  65. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  66. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +28 -25
  67. ai_edge_torch/generative/examples/t5/t5.py +208 -159
  68. ai_edge_torch/generative/examples/t5/t5_attention.py +45 -30
  69. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  70. ai_edge_torch/generative/examples/test_models/toy_model.py +69 -41
  71. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +50 -64
  72. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  73. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +41 -39
  74. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +67 -54
  75. ai_edge_torch/generative/examples/tiny_llama/verify.py +64 -0
  76. ai_edge_torch/generative/fx_passes/__init__.py +4 -5
  77. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +10 -7
  78. ai_edge_torch/generative/layers/attention.py +141 -102
  79. ai_edge_torch/generative/layers/attention_utils.py +53 -12
  80. ai_edge_torch/generative/layers/builder.py +37 -7
  81. ai_edge_torch/generative/layers/feed_forward.py +39 -14
  82. ai_edge_torch/generative/layers/kv_cache.py +162 -50
  83. ai_edge_torch/generative/layers/model_config.py +84 -30
  84. ai_edge_torch/generative/layers/normalization.py +185 -7
  85. ai_edge_torch/generative/layers/rotary_position_embedding.py +6 -4
  86. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +48 -21
  87. ai_edge_torch/generative/layers/unet/blocks_2d.py +136 -77
  88. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  89. ai_edge_torch/generative/layers/unet/model_config.py +17 -15
  90. ai_edge_torch/generative/quantize/example.py +7 -8
  91. ai_edge_torch/generative/quantize/quant_recipe.py +10 -7
  92. ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -1
  93. ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
  94. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  95. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +9 -7
  96. ai_edge_torch/generative/test/test_model_conversion.py +124 -188
  97. ai_edge_torch/generative/test/test_model_conversion_large.py +251 -0
  98. ai_edge_torch/generative/test/test_quantize.py +76 -60
  99. ai_edge_torch/generative/test/utils.py +54 -0
  100. ai_edge_torch/generative/utilities/converter.py +82 -0
  101. ai_edge_torch/generative/utilities/loader.py +120 -57
  102. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +165 -57
  103. ai_edge_torch/generative/utilities/t5_loader.py +110 -81
  104. ai_edge_torch/generative/utilities/verifier.py +247 -0
  105. ai_edge_torch/hlfb/__init__.py +1 -1
  106. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -7
  107. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  108. ai_edge_torch/hlfb/mark_pattern/pattern.py +39 -30
  109. ai_edge_torch/hlfb/test/test_mark_pattern.py +46 -20
  110. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +24 -11
  111. ai_edge_torch/lowertools/__init__.py +18 -0
  112. ai_edge_torch/lowertools/_shim.py +80 -0
  113. ai_edge_torch/lowertools/common_utils.py +142 -0
  114. ai_edge_torch/lowertools/odml_torch_utils.py +255 -0
  115. ai_edge_torch/lowertools/test_utils.py +60 -0
  116. ai_edge_torch/lowertools/torch_xla_utils.py +284 -0
  117. ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +29 -14
  118. ai_edge_torch/model.py +53 -18
  119. ai_edge_torch/odml_torch/__init__.py +20 -0
  120. ai_edge_torch/odml_torch/_torch_future.py +61 -0
  121. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  122. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  123. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  124. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  125. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  126. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  127. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  128. ai_edge_torch/odml_torch/export.py +357 -0
  129. ai_edge_torch/odml_torch/export_utils.py +168 -0
  130. ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
  131. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +150 -0
  132. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  133. ai_edge_torch/odml_torch/lowerings/__init__.py +25 -0
  134. ai_edge_torch/odml_torch/lowerings/_basic.py +258 -0
  135. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  136. ai_edge_torch/odml_torch/lowerings/_convolution.py +241 -0
  137. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +252 -0
  138. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  139. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  140. ai_edge_torch/odml_torch/lowerings/registry.py +96 -0
  141. ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
  142. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  143. ai_edge_torch/odml_torch/tf_integration.py +194 -0
  144. ai_edge_torch/quantize/pt2e_quantizer.py +52 -24
  145. ai_edge_torch/quantize/pt2e_quantizer_utils.py +43 -23
  146. ai_edge_torch/quantize/quant_config.py +13 -9
  147. ai_edge_torch/testing/model_coverage/model_coverage.py +29 -16
  148. ai_edge_torch/version.py +16 -0
  149. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/METADATA +7 -3
  150. ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/RECORD +177 -0
  151. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/WHEEL +1 -1
  152. ai_edge_torch/convert/conversion.py +0 -117
  153. ai_edge_torch/convert/conversion_utils.py +0 -400
  154. ai_edge_torch/convert/fx_passes/__init__.py +0 -59
  155. ai_edge_torch/convert/fx_passes/_pass_base.py +0 -49
  156. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +0 -37
  157. ai_edge_torch/convert/test/test_convert.py +0 -311
  158. ai_edge_torch/convert/test/test_convert_composites.py +0 -192
  159. ai_edge_torch/convert/test/test_convert_multisig.py +0 -139
  160. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +0 -66
  161. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -64
  162. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -161
  163. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  164. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +0 -121
  165. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  166. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  167. /ai_edge_torch/generative/examples/{phi2 → openelm}/__init__.py +0 -0
  168. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/LICENSE +0 -0
  169. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,357 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """APIs to convert and lower a PyTorch ExportedProgram to MLIR."""
16
+
17
+ import dataclasses
18
+ import enum
19
+ import io
20
+ import operator
21
+ from typing import Any, Callable, Optional
22
+
23
+ from jax.lib import xla_extension
24
+ from jax._src.lib.mlir import ir
25
+ from jax._src.lib.mlir.dialects import func
26
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
27
+ import torch
28
+ import torch.utils._pytree as pytree
29
+
30
+ from . import _torch_future
31
+ from . import debuginfo
32
+ from . import export_utils
33
+ from . import lowerings
34
+
35
# Short alias for the context object that is threaded through every lowering
# call (see `LoweringInterpreter.call_function`).
LoweringContext = lowerings.context.LoweringContext
36
+
37
+
38
def _build_flat_inputs(
    ctx: ir.Context, exported_program: torch.export.ExportedProgram
):
  """Builds flattened IR input types and metadata for an exported program.

  Each placeholder node of the program's graph is paired with the matching
  flattened example input, and an MLIR ranked tensor type is derived from the
  node's `tensor_meta`.

  Args:
    ctx: The MLIR context used to create element types.
    exported_program: The program whose placeholders are flattened.

  Returns:
    A tuple `(ir_inputs, export_flat_args, tensor_metas)`, each itself a
    tuple: the `ir.RankedTensorType` of every input, the flattened example
    inputs, and the placeholder nodes' tensor metadata.

  Raises:
    RuntimeError: If a placeholder node carries no `tensor_meta`.
  """
  placeholders = [
      node
      for node in exported_program.graph.nodes
      if node.op == "placeholder"
  ]
  flat_args = _torch_future.graph_module_flat_inputs(
      exported_program, *exported_program.example_inputs
  )

  def dim_to_ir(dim):
    # Assume all dynamic dimensions are unbounded.
    # TODO: Add checks for ep.range_constraints in MLIR.
    if export_utils.is_torch_dynamic(dim):
      return export_utils.IR_DYNAMIC
    return dim

  ir_inputs = []
  tensor_metas = []
  for placeholder, flat_arg in zip(placeholders, flat_args):
    meta = placeholder.meta.get("tensor_meta")
    if meta is None:
      raise RuntimeError(
          f"{type(flat_arg)} (for {placeholder.name}) is not a tensor"
      )
    tensor_metas.append(meta)

    ir_shape = tuple(dim_to_ir(dim) for dim in meta.shape)
    element_type = export_utils.torch_dtype_to_ir_element_type(ctx, meta.dtype)
    ir_inputs.append(ir.RankedTensorType.get(ir_shape, element_type))

  return tuple(ir_inputs), tuple(flat_args), tuple(tensor_metas)
70
+
71
+
72
+ def _get_output_metas(exported_program: torch.export.ExportedProgram):
73
+ """Get the output node's tensor_meta from the exported program."""
74
+ outputs = [n for n in exported_program.graph.nodes if n.op == "output"]
75
+ assert len(outputs) == 1
76
+ outputs, _ = pytree.tree_flatten(outputs[0].args[0])
77
+ assert all(isinstance(output, torch.fx.Node) for output in outputs)
78
+ return tuple(output.meta["tensor_meta"] for output in outputs)
79
+
80
+
81
class LoweringInterpreter(torch.fx.Interpreter):
  """FX interpreter that lowers each PyTorch op of a graph to MLIR.

  Every `call_function` node is dispatched to the lowering registered for its
  target in `lowerings`; the flattened values reaching the graph's output
  node are stored in `self.outputs`.
  """

  def __init__(self, module: torch.fx.GraphModule, lctx: LoweringContext):
    super().__init__(module)
    self.lctx = lctx
    # Flattened values of the graph's output node, populated by `output`.
    self.outputs = None

  def _build_loc(self, node: torch.fx.Node):
    """Returns a named MLIR location for `node`, or an unknown location."""
    debug_name = debuginfo.build_mlir_debuginfo(node)
    if debug_name is not None:
      return ir.Location.name(name=debug_name)
    return ir.Location.unknown()

  def run_node(self, node: torch.fx.Node):
    """Runs `node` with its location active and recorded on `self.lctx`."""
    node_loc = self._build_loc(node)
    with node_loc:
      # Expose the node being lowered (and its location) to the lowerings.
      self.lctx = self.lctx.replace(ir_location=node_loc, node=node)
      result = super().run_node(node)
      self.lctx = self.lctx.replace(ir_location=None, node=None)
    return result

  @staticmethod
  def _scalar_as_tensor(arg, spec):
    """Splats a Python int/float passed where the schema expects a Tensor."""
    if not isinstance(spec.type, torch.TensorType):
      return arg
    if isinstance(arg, int):
      return lowerings.utils.splat(arg, ir.IntegerType.get_signless(32))
    if isinstance(arg, float):
      return lowerings.utils.splat(arg, ir.F32Type.get())
    return arg

  def call_function(self, target, args, kwargs):
    """Dispatches a `call_function` node to its registered lowering.

    Raises:
      RuntimeError: If no lowering is registered for `target`.
    """
    # getitem is evaluated directly in Python on the value produced by an
    # earlier node instead of being lowered.
    if target is operator.getitem:
      return super().call_function(target, args, kwargs)

    if hasattr(target, "_schema"):
      args = tuple(
          self._scalar_as_tensor(arg, spec)
          for arg, spec in zip(args, target._schema.arguments)
      )

    lowering = lowerings.lookup(target)
    if lowering is None:
      raise RuntimeError(f"Lowering not found: {target}")
    return lowering(self.lctx, *args, **kwargs)

  def output(self, target, args, kwargs):
    """Records the flattened graph outputs in `self.outputs`."""
    self.outputs, _ = pytree.tree_flatten(args[0])
129
+
130
+
131
@dataclasses.dataclass
class InputSpec:
  """Describes where one flattened input of the lowered function comes from.

  An input is either the `i`-th user input, or a named parameter/constant of
  the exported program.
  """

  class VariableType(enum.Enum):
    """Kind of a lowered-function input."""

    USER_INPUT = "user_input"
    PARAMETER = "parameter"

  type_: VariableType
  i: int = -1  # Position among user inputs; stays -1 for parameters.
  name: str = ""  # Key of a parameter; stays empty for user inputs.

  @classmethod
  def parameter(cls, name: str):
    """Creates a spec for the parameter/constant stored under `name`."""
    return cls(cls.VariableType.PARAMETER, name=name)

  @classmethod
  def user_input(cls, i: int):
    """Creates a spec for the `i`-th user input."""
    return cls(cls.VariableType.USER_INPUT, i=i)

  @property
  def is_parameter(self):
    return self.type_ is self.VariableType.PARAMETER

  @property
  def is_user_input(self):
    return self.type_ is self.VariableType.USER_INPUT
157
+
158
+
159
@dataclasses.dataclass
class VariableSignature:  # either argument or parameters
  """Shape, dtype and (for inputs) origin of one lowered-function variable."""

  # In practice taken from `tensor_meta.shape` (a torch.Size); dynamic dims
  # may be symbolic (torch.SymInt).
  shape: list[int]
  # A torch dtype taken from `tensor_meta.dtype` (previously mis-annotated as
  # `str`).
  dtype: torch.dtype
  # Set for inputs only; output signatures leave it as None.
  input_spec: Optional["InputSpec"] = None
164
+
165
+
166
+ @dataclasses.dataclass
167
+ class MlirLowered:
168
+ """The lowered MLIR module, metadata, and weight tensors bundle from exported program."""
169
+
170
+ ctx: ir.Context
171
+ module: ir.Module
172
+ state_dict: dict[str, torch.Tensor]
173
+ input_signature: list[VariableSignature]
174
+ output_signature: list[VariableSignature]
175
+
176
+ _tf_function: Optional[Callable[Any, Any]] = None
177
+
178
+ def __str__(self):
179
+ return str(self.get_text(enable_debug_info=False))
180
+
181
+ def __repr__(self):
182
+ return str(self.get_text(enable_debug_info=False))
183
+
184
+ def get_text(self, enable_debug_info=False):
185
+ return str(
186
+ self.module.operation.get_asm(enable_debug_info=enable_debug_info)
187
+ )
188
+
189
+ @property
190
+ def module_bytecode(self) -> bytes:
191
+ output = io.BytesIO()
192
+ self.module.operation.write_bytecode(file=output)
193
+ return output.getvalue()
194
+
195
+ @property
196
+ def module_bytecode_vhlo(self) -> bytes:
197
+ # HACK: In OSS, we use MLIR pybinding and StableHLO dialect from JAX's
198
+ # build, which may not have the same StableHLO version as what used in
199
+ # TFLite converter. Therefore we always serialize MLIR module in VHLO.
200
+ # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
201
+ target_version = stablehlo.get_minimum_version()
202
+ module_bytecode = xla_extension.mlir.serialize_portable_artifact(
203
+ self.module_bytecode, target_version
204
+ )
205
+ return module_bytecode
206
+
207
+ @property
208
+ def tf_function(self):
209
+ # Lazy import
210
+ from . import tf_integration
211
+
212
+ if self._tf_function is None:
213
+ self._tf_function = tf_integration.mlir_to_tf_function(self)
214
+ return self._tf_function
215
+
216
+ def __call__(self, *args):
217
+ # Lazy importing TF when execution is needed.
218
+ return self.tf_function(*args)
219
+
220
+ def to_flatbuffer(self):
221
+ from . import tf_integration
222
+
223
+ return tf_integration.mlir_to_flatbuffer(self)
224
+
225
+
226
+ # TODO(b/331481564) Make this a ai_edge_torch FX pass.
227
+ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
228
+ """Convert internal constant aten ops' output from int64 to int32.
229
+
230
+ Int32 generally has better performance and compatibility than int64 in
231
+ runtime. This pass converts aten op where the output(s) are int64 constant
232
+ tensors to return int32 constant tensors.
233
+
234
+ Args:
235
+ exported_program: The exported program to apply the pass.
236
+ """
237
+
238
+ def in_i32(x: int):
239
+ return -2147483648 <= x <= 2147483647
240
+
241
+ def rewrite_arange(node: torch.fx.Node):
242
+ tensor_meta = node.meta.get("tensor_meta", None)
243
+ if not tensor_meta:
244
+ return
245
+
246
+ start, end = node.args[:2]
247
+ if tensor_meta.dtype != torch.int64:
248
+ return
249
+ if not (in_i32(start) and in_i32(end)):
250
+ return
251
+ op = node.target
252
+ node.target = lambda *args, **kwargs: op(*args, **kwargs).type(torch.int32)
253
+
254
+ graph_module = exported_program.graph_module
255
+ for node in graph_module.graph.nodes:
256
+
257
+ if node.target == torch.ops.aten.arange.start_step:
258
+ rewrite_arange(node)
259
+
260
+
261
def _build_input_signature(export_flat_args, tensor_metas, input_specs):
  """Builds the lowered input signature and the state dict of non-user inputs.

  Assumption: all states come first in the list of args, user-provided inputs
  come later, and there are no kwargs.

  Returns:
    A tuple `(input_signature, state_dict)`.
  """
  input_signature = []
  state_dict = {}
  num_user_inputs = 0
  for arg, tensor_meta, spec in zip(
      export_flat_args, tensor_metas, input_specs
  ):
    if spec.kind == torch.export.graph_signature.InputKind.USER_INPUT:
      origin = InputSpec.user_input(num_user_inputs)
      num_user_inputs += 1
    else:
      # Any non-user input (e.g. parameter or constant): keep its value.
      state_dict[spec.target] = arg
      origin = InputSpec.parameter(spec.target)
    input_signature.append(
        VariableSignature(
            tensor_meta.shape, tensor_meta.dtype, input_spec=origin
        )
    )
  return input_signature, state_dict


def exported_program_to_mlir(
    exported_program: torch.export.ExportedProgram,
) -> MlirLowered:
  """Lowers an exported program to an MLIR module.

  The program is decomposed with `lowerings.decompositions()`, its int64
  aranges are narrowed to int32, and it is decomposed again. Every node is
  then lowered into a public `main` func.func whose arguments are the
  flattened program inputs (states and user inputs) and whose results are the
  user outputs (outputs for mutated buffers are dropped).

  Args:
    exported_program: The program to lower.

  Returns:
    An `MlirLowered` bundling the MLIR context and module, the state dict of
    non-user inputs, and the input/output signatures.
  """
  exported_program = exported_program.run_decompositions(
      lowerings.decompositions()
  )
  _convert_i64_to_i32(exported_program)
  exported_program = exported_program.run_decompositions(
      lowerings.decompositions()
  )

  with export_utils.create_ir_context() as context, ir.Location.unknown():
    module = ir.Module.create()
    interpreter = LoweringInterpreter(
        exported_program.graph_module, LoweringContext(context, module)
    )
    ir_flat_inputs, export_flat_args, tensor_metas = _build_flat_inputs(
        context, exported_program
    )

    # HACK: The OSS MLIR pybinding can mysteriously turn a func.func under
    # construction into a func.return op after ir.Module.parse(..) is called
    # in the context, which happens in the JAX bridge. This is a bug in the
    # MLIR pybinding. Workaround:
    #   1. Create a temp func.func.
    #   2. Create and insert ops into temp's entry block. The temp func.func
    #      may get broken along the way, but the ops in its block are fine.
    #   3. Create the main func.func and copy every op of temp's entry block
    #      into main.
    #   4. Erase the temp func.func.
    temp_func = func.FuncOp(
        "temp",
        ir.FunctionType.get(ir_flat_inputs, []),
        ip=ir.InsertionPoint.at_block_begin(module.body),
    )
    with ir.InsertionPoint(temp_func.add_entry_block()):
      interpreter.run(*temp_func.arguments, enable_io_processing=False)
      # The leading outputs (one per mutated buffer) are not returned.
      num_mutations = len(exported_program.graph_signature.buffers_to_mutate)
      user_outputs = interpreter.outputs[num_mutations:]
      func.ReturnOp(user_outputs)

    main_func = func.FuncOp(
        "main",
        ir.FunctionType.get(ir_flat_inputs, [out.type for out in user_outputs]),
        ip=ir.InsertionPoint.at_block_begin(module.body),
    )
    with ir.InsertionPoint(main_func.add_entry_block()):
      cloned_outputs = export_utils.clone_func_body_ops(
          temp_func, main_func.arguments
      )
      func.ReturnOp(cloned_outputs)

    main_func.attributes["sym_visibility"] = ir.StringAttr.get("public")
    temp_func.erase()

    module.operation.verify()

    input_signature, state_dict = _build_input_signature(
        export_flat_args,
        tensor_metas,
        exported_program.graph_signature.input_specs,
    )
    output_signature = [
        VariableSignature(meta.shape, meta.dtype)
        for meta in _get_output_metas(exported_program)
    ]
    return MlirLowered(
        context, module, state_dict, input_signature, output_signature
    )
+ )
@@ -0,0 +1,168 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Utilities for ODML Torch export."""
16
+
17
+ import functools
18
+ import re
19
+ from typing import Sequence, cast
20
+ import jax._src.interpreters.mlir
21
+ from jax._src.lib.mlir import ir
22
+ from jax._src.lib.mlir.dialects import func
23
+ import torch
24
+
25
# Value used as a dynamic dimension in MLIR shapes:
# std::numeric_limits<int64_t>::min().
IR_DYNAMIC = -(2**63)


def is_ir_dynamic(v):
  """Returns True if `v` equals the MLIR dynamic-dimension value."""
  dynamic = v == IR_DYNAMIC
  return dynamic
31
+
32
+
33
def is_torch_dynamic(v):
  """Returns True if `v` is a symbolic torch integer (`torch.SymInt`)."""
  symbolic = isinstance(v, torch.SymInt)
  return symbolic
35
+
36
+
37
def is_iterable(v):
  """Returns whether `iter(v)` succeeds without raising TypeError."""
  try:
    iter(v)
  except TypeError:
    return False
  else:
    return True
43
+
44
+
45
def create_ir_context():
  """Creates an MLIR context for lowering, built on JAX's IR context.

  The returned context allows unregistered dialects.
  """
  # HACK: Use ir context from JAX as base for better stability in OSS.
  # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
  ctx = jax._src.interpreters.mlir.make_ir_context()
  ctx.allow_unregistered_dialects = True
  return ctx
52
+
53
+
54
def inline(
    symbol_table: ir.SymbolTable,
    block: ir.Block,
):
  """Recursively inlines all func.call ops in the block.

  The symbol_table must include every func.func called by a func.call op.
  This Python inliner exists because the MLIR inline pass from JAX's MLIR
  pybinding build in OSS cannot properly inline func.call ops.
  """
  call_op_name = func.CallOp.OPERATION_NAME

  # Rescan the block until a full pass finds no call left to inline, so calls
  # introduced by inlined bodies are handled too.
  changed = True
  while changed:
    changed = False
    for op in block.operations:
      if op.OPERATION_NAME != call_op_name:
        continue

      call_op = cast(func.CallOp, op)
      callee = cast(func.FuncOp, symbol_table[call_op.callee.value])
      # Clone the callee's body right before the call, fed by its operands.
      with ir.InsertionPoint(op):
        inlined_results = clone_func_body_ops(callee, call_op.operands)

      for old_result, new_result in zip(call_op.results, inlined_results):
        cast(ir.Value, old_result).replace_all_uses_with(new_result)
      call_op.erase()
      changed = True

  # Recurse into the regions of the remaining ops.
  for op in block.operations:
    for region in op.regions:
      for nested_block in region.blocks:
        inline(symbol_table, nested_block)
88
+
89
+
90
def clone_func_body_ops(func_op: func.FuncOp, ir_inputs: Sequence[ir.Value]):
  """Clones the ops of `func_op`'s entry block, one by one.

  The clones are created at the current insertion point; callers set one, e.g.
  with `ir.InsertionPoint`. The function's arguments are substituted by
  `ir_inputs`, and each cloned op's operands are remapped to the clones of
  their producers. Only operands of top-level ops are remapped.

  Args:
    func_op: The function whose entry-block ops are cloned.
    ir_inputs: Values replacing `func_op`'s arguments, in order.

  Returns:
    The remapped operands of the entry block's `func.return` op (which itself
    is not cloned), or an empty list if no `func.return` is reached.
  """
  func_args = list(func_op.arguments)
  ir_inputs = list(ir_inputs)
  assert len(func_args) == len(ir_inputs)

  value_mapping = dict(zip(func_args, ir_inputs))

  for op in list(func_op.entry_block.operations):
    new_operands = [value_mapping[operand] for operand in op.operands]
    if op.OPERATION_NAME == func.ReturnOp.OPERATION_NAME:
      return new_operands

    cloned = cast(ir.Operation, op.operation.clone())
    for idx, operand in enumerate(new_operands):
      cloned.operands[idx] = operand
    value_mapping.update(zip(op.results, cloned.results))

  return []
112
+
113
+
114
def sanitize_aten_op_name(op, chars=":."):
  """Returns `str(op)` with every character in `chars` replaced by "_".

  For example, "aten::add.Tensor" becomes "aten__add_Tensor" with the default
  `chars`. Characters in `chars` are matched literally: they are escaped
  before being placed in a regex character class, so "-", "^", "]" or "\\"
  are safe to pass. An empty `chars` leaves the name unchanged.
  """
  if not chars:
    # "[]" is not a valid regex character class.
    return str(op)
  return re.sub("[{}]".format(re.escape(chars)), "_", str(op))
116
+
117
+
118
def build_ir_attr(val):
  """Converts a Python value into an MLIR attribute.

  Uses the current MLIR context. Mapping:
    None -> StringAttr "py_None"; bool -> BoolAttr; int -> i64 IntegerAttr;
    float -> f32 FloatAttr; str -> StringAttr; dict -> DictAttr and
    list/tuple -> ArrayAttr (both converted recursively). Any other value is
    stringified into a StringAttr.
  """
  if val is None:
    return ir.StringAttr.get("py_None")
  # bool must be tested before int because bool is a subclass of int.
  if isinstance(val, bool):
    return ir.BoolAttr.get(val)
  if isinstance(val, int):
    return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), val)
  if isinstance(val, float):
    # A FloatAttr keeps the numeric value; a BoolAttr would collapse every
    # float into true/false.
    return ir.FloatAttr.get_f32(val)
  if isinstance(val, str):
    return ir.StringAttr.get(val)
  if isinstance(val, dict):
    return ir.DictAttr.get({k: build_ir_attr(v) for k, v in val.items()})
  if isinstance(val, (list, tuple)):
    return ir.ArrayAttr.get([build_ir_attr(v) for v in val])

  # Stringify the value to a StringAttr by default
  return ir.StringAttr.get(str(val))
136
+
137
+
138
def torch_dtype_to_ir_element_type(ctx, dtype):
  """Returns the MLIR element type corresponding to a torch dtype.

  Args:
    ctx: The MLIR context the type is created in.
    dtype: The torch dtype to convert.

  Raises:
    RuntimeError: If `dtype` has no MLIR mapping (instead of the opaque
      "'NoneType' object is not callable" error raised previously).
  """
  ty_get = {
      torch.double: ir.F64Type.get,
      torch.float32: ir.F32Type.get,
      torch.half: ir.F16Type.get,
      torch.long: functools.partial(ir.IntegerType.get_signless, 64),
      torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
      torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
      torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
  }.get(dtype)
  if ty_get is None:
    raise RuntimeError(f"Unsupported torch dtype: {dtype}")
  return ty_get(ctx)
149
+
150
+
151
def ir_element_type_to_torch_dtype(ty):
  """Returns the torch dtype corresponding to an MLIR element type.

  Supports f32/f64/f16 and signless i64/i32/i16/i1.

  Raises:
    RuntimeError: If `ty` has no torch mapping.
  """
  float_types = (
      (ir.F32Type, torch.float32),
      (ir.F64Type, torch.float64),
      (ir.F16Type, torch.half),
  )
  for ir_cls, dtype in float_types:
    if isinstance(ty, ir_cls):
      return dtype

  if isinstance(ty, ir.IntegerType) and ty.is_signless:
    signless_by_width = {
        64: torch.long,
        32: torch.int32,
        16: torch.int16,
        1: torch.bool,
    }
    dtype = signless_by_width.get(ty.width)
    if dtype is not None:
      return dtype

  raise RuntimeError(f"Unsupported ir element type: {ty}")
@@ -0,0 +1,15 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from ai_edge_torch.odml_torch.jax_bridge._wrap import wrap
@@ -0,0 +1,150 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """APIs to wrap JAX functions for use in ODML Torch lowerings."""
16
+
17
+ import functools
18
+ import inspect
19
+ from typing import Any, Callable, cast
20
+ import uuid
21
+ from ai_edge_torch.odml_torch import export_utils
22
+ from ai_edge_torch.odml_torch import passes
23
+ from ai_edge_torch.odml_torch.jax_bridge import utils
24
+ import jax
25
+ from jax._src.lib.mlir import ir
26
+ from jax._src.lib.mlir.dialects import func
27
+ import torch.utils._pytree as pytree
28
+
29
# JAX double (64-bit) precision is required to generate StableHLO MLIR with
# i64/f64 tensors from JAX-bridged lowerings. If it is not enabled, all 64-bit
# tensors would be truncated to 32-bit dtypes, which can break the lowering.
# NOTE: this runs as an import-time side effect of this module and updates
# JAX's global configuration.
jax.config.update("jax_enable_x64", True)
34
+
35
+
36
def _lower_to_ir_text(
    jaxfn, args, kwargs, ir_input_names: list[str] = None
) -> tuple[str, list[ir.Value]]:
  """Lower `jaxfn` applied to `args`/`kwargs` into MLIR text via `jax.jit`.

  Each argument is classified with `utils.is_ir_variable`:
    - Non-IR arguments are bound by name as static keyword arguments. They
      are closed over by the function handed to `jax.jit`, not passed to
      `lower()`.
    - IR arguments are converted with `utils.ir_variable_to_jax` and passed to
      `jax.jit(...).lower(...)` in the order they appear.
    - A leading run of positional IR arguments (with no static argument
      before them) is forwarded to `jaxfn` positionally. Every other IR
      argument is forwarded by keyword.

  Args:
    jaxfn: The JAX function to lower.
    args: Positional arguments. They are paired with `jaxfn`'s parameter names
      in signature order.
    kwargs: Keyword arguments for `jaxfn`.
    ir_input_names: If given, only the `ir.Value` leaves of IR arguments whose
      names are listed are collected into the returned inputs. If None, the
      leaves of every IR argument are collected.

  Returns:
    A tuple `(ir_text, ir_inputs)`: the text of the lowered module, and the
    flat list of `ir.Value` leaves collected from the IR arguments.
  """
  args = utils.tree_map_list_to_tuple(args)
  kwargs = utils.tree_map_list_to_tuple(kwargs)

  # NOTE(review): zip() stops at the shorter of the parameter names and
  # `args`, so surplus positional args (e.g. for a jaxfn taking *args) are
  # silently dropped. Confirm jaxfn never relies on that.
  names_args = [
      *zip(inspect.signature(jaxfn).parameters.keys(), args),
      *kwargs.items(),
  ]

  jax_lower_static_kwargs = {}
  jax_lower_args = []
  # Parallel to jax_lower_args: None means "forward positionally", otherwise
  # the keyword name under which the traced arg is forwarded to jaxfn.
  jax_lower_argnames = []
  ir_inputs = []

  for i, (name, arg) in enumerate(names_args):
    is_positional = i < len(args)
    if not utils.is_ir_variable(arg):
      jax_lower_static_kwargs[name] = arg
    else:
      # Enforce the arg order in the mlir is the same as the lowering func
      jax_lower_args.append(utils.ir_variable_to_jax(arg))

      if is_positional and len(jax_lower_args) == i + 1:
        # The first N continuous tensor args are passed to the lowering func
        # as positional args, when they are passed to the bridged func as
        # positional args too.
        jax_lower_argnames.append(None)
      else:
        # Otherwise pass the arg to the lowering func as keyword arg.
        jax_lower_argnames.append(name)

      if ir_input_names is None or name in ir_input_names:
        # ir variable can be a nested tuple, while mlir args should be flat.
        ir_inputs += [
            x for x in pytree.tree_flatten(arg)[0] if isinstance(x, ir.Value)
        ]

  def lower_wrapper(*lowered_args):
    # Rebuild the jaxfn call: start from the static kwargs, then route each
    # traced arg positionally or by keyword as decided above.
    jaxfn_args = []
    jaxfn_kwargs = jax_lower_static_kwargs.copy()
    for argname, lowered_arg in zip(jax_lower_argnames, lowered_args):
      if argname is None:
        jaxfn_args.append(lowered_arg)
      else:
        jaxfn_kwargs[argname] = lowered_arg
    return jaxfn(*jaxfn_args, **jaxfn_kwargs)

  return jax.jit(lower_wrapper).lower(*jax_lower_args).as_text(), ir_inputs
91
+
92
+
93
def wrap(jaxfn: Callable[..., Any], ir_input_names: list[str] = None):
  """Return the wrapped JAX function to be used in ODMLTorch lowerings.

  If the given jaxfn has signature `jaxfn(*args, **kwargs) -> return`, the
  wrapped function would:
    - Have signature `wrapped(lctx: odml_torch.export.LoweringContext, *args,
      **kwargs) -> return`.
    - Accept mlir.ir.Value for all params expecting jax.Array as inputs.
    - Return mlir.ir.Value for all jax.Array outputs from jaxfn.

  On each call, the wrapped function does the following:
    - Lowers `jaxfn` with `jax.jit`.
    - Clones the lowered `main` function into the start of
      `lctx.ir_module.body` as a private function with a randomized name.
    - Inlines that function's calls.
    - Emits a `func.call` to it at the insertion point that is active when
      the wrapped function is called.

  Args:
    jaxfn: The JAX function to be wrapped.
    ir_input_names: The input (param) names of the JAX function to be used in
      the MLIR lowering. This is useful when the JAX impl only depends on
      specific inputs to the function. If not specified, all ir.Value passed to
      the wrapped function are assumed to be used in the lowering.

  Returns:
    The wrapped function. It returns a single ir.Value when the emitted call
    has exactly one result, and otherwise the call's result list.
  """
  # Callables such as functools.partial objects have no __name__. Fall back to
  # a generic prefix so the cloned function's symbol name can still be built.
  func_name_prefix = getattr(jaxfn, "__name__", "jax_func")

  @functools.wraps(jaxfn)
  def wrapped(lctx, *args, **kwargs):
    ir_text, ir_inputs = _lower_to_ir_text(
        jaxfn,
        args,
        kwargs,
        ir_input_names=ir_input_names,
    )

    module = ir.Module.parse(ir_text)
    passes.strip_debuginfo(module)

    symbol_table = ir.SymbolTable(module.operation)
    main_func = symbol_table["main"]

    with ir.InsertionPoint.at_block_begin(lctx.ir_module.body):
      cloned_func = cast(func.FuncOp, main_func.clone())
      # The random suffix keeps the symbol distinct when the same jaxfn is
      # lowered more than once into one module.
      cloned_func_name = f"{func_name_prefix}_{uuid.uuid4().hex[:8]}"
      cloned_func.attributes["sym_name"] = ir.StringAttr.get(cloned_func_name)
      cloned_func.attributes["sym_visibility"] = ir.StringAttr.get("private")

    # HACK: Use the custom inliner implemented in Python because MLIR inline
    # pass from JAX's MLIR pybinding build in OSS cannot properly inline
    # func.call ops.
    # This should be switched to `passes.inline(module)` when we have our own
    # MLIR pybinding build.
    export_utils.inline(symbol_table, cloned_func.entry_block)

    if not cloned_func.arguments:
      # Known edge case: when the lowering does not depend on input but
      # just the meta of input like shape or dtype.
      ir_inputs = []

    # Created outside the `with` above, so the call lands at the caller's
    # active insertion point rather than at the module body's start.
    results = func.CallOp(cloned_func, ir_inputs).results
    if len(results) == 1:
      return results[0]
    return results

  return wrapped