ai-edge-torch-nightly 0.3.0.dev20250114__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (213)
  1. ai_edge_torch/__init__.py +32 -0
  2. ai_edge_torch/_config.py +69 -0
  3. ai_edge_torch/_convert/__init__.py +14 -0
  4. ai_edge_torch/_convert/conversion.py +153 -0
  5. ai_edge_torch/_convert/conversion_utils.py +64 -0
  6. ai_edge_torch/_convert/converter.py +270 -0
  7. ai_edge_torch/_convert/fx_passes/__init__.py +23 -0
  8. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +288 -0
  9. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +131 -0
  10. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  11. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  12. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +258 -0
  13. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +50 -0
  14. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +18 -0
  15. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +68 -0
  16. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +216 -0
  17. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +449 -0
  18. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  19. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +303 -0
  20. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py +64 -0
  21. ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +52 -0
  22. ai_edge_torch/_convert/signature.py +66 -0
  23. ai_edge_torch/_convert/test/__init__.py +14 -0
  24. ai_edge_torch/_convert/test/test_convert.py +558 -0
  25. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  26. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  27. ai_edge_torch/_convert/test/test_to_channel_last_io.py +96 -0
  28. ai_edge_torch/_convert/to_channel_last_io.py +92 -0
  29. ai_edge_torch/conftest.py +20 -0
  30. ai_edge_torch/debug/__init__.py +17 -0
  31. ai_edge_torch/debug/culprit.py +496 -0
  32. ai_edge_torch/debug/test/__init__.py +14 -0
  33. ai_edge_torch/debug/test/test_culprit.py +140 -0
  34. ai_edge_torch/debug/test/test_search_model.py +51 -0
  35. ai_edge_torch/debug/utils.py +59 -0
  36. ai_edge_torch/experimental/__init__.py +14 -0
  37. ai_edge_torch/fx_pass_base.py +110 -0
  38. ai_edge_torch/generative/__init__.py +14 -0
  39. ai_edge_torch/generative/examples/__init__.py +14 -0
  40. ai_edge_torch/generative/examples/amd_llama_135m/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +87 -0
  42. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +70 -0
  43. ai_edge_torch/generative/examples/amd_llama_135m/verify.py +72 -0
  44. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +80 -0
  46. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +80 -0
  47. ai_edge_torch/generative/examples/gemma/gemma1.py +107 -0
  48. ai_edge_torch/generative/examples/gemma/gemma2.py +295 -0
  49. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  50. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +43 -0
  51. ai_edge_torch/generative/examples/gemma/verify_util.py +157 -0
  52. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +91 -0
  54. ai_edge_torch/generative/examples/llama/llama.py +196 -0
  55. ai_edge_torch/generative/examples/llama/verify.py +88 -0
  56. ai_edge_torch/generative/examples/moonshine/__init__.py +14 -0
  57. ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +50 -0
  58. ai_edge_torch/generative/examples/moonshine/moonshine.py +103 -0
  59. ai_edge_torch/generative/examples/openelm/__init__.py +14 -0
  60. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +80 -0
  61. ai_edge_torch/generative/examples/openelm/openelm.py +127 -0
  62. ai_edge_torch/generative/examples/openelm/verify.py +71 -0
  63. ai_edge_torch/generative/examples/paligemma/__init__.py +14 -0
  64. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +95 -0
  65. ai_edge_torch/generative/examples/paligemma/decoder.py +151 -0
  66. ai_edge_torch/generative/examples/paligemma/decoder2.py +177 -0
  67. ai_edge_torch/generative/examples/paligemma/image_encoder.py +160 -0
  68. ai_edge_torch/generative/examples/paligemma/paligemma.py +179 -0
  69. ai_edge_torch/generative/examples/paligemma/verify.py +161 -0
  70. ai_edge_torch/generative/examples/paligemma/verify_decoder.py +75 -0
  71. ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
  72. ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +99 -0
  73. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  74. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +80 -0
  75. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +80 -0
  76. ai_edge_torch/generative/examples/phi/phi2.py +107 -0
  77. ai_edge_torch/generative/examples/phi/phi3.py +219 -0
  78. ai_edge_torch/generative/examples/phi/verify.py +64 -0
  79. ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
  80. ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
  81. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +93 -0
  82. ai_edge_torch/generative/examples/qwen/qwen.py +134 -0
  83. ai_edge_torch/generative/examples/qwen/verify.py +88 -0
  84. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  85. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +80 -0
  86. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  87. ai_edge_torch/generative/examples/smollm/smollm.py +125 -0
  88. ai_edge_torch/generative/examples/smollm/verify.py +86 -0
  89. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  90. ai_edge_torch/generative/examples/stable_diffusion/attention.py +108 -0
  91. ai_edge_torch/generative/examples/stable_diffusion/clip.py +185 -0
  92. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +173 -0
  93. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +398 -0
  94. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +749 -0
  95. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +119 -0
  96. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +254 -0
  97. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  98. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +62 -0
  99. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +66 -0
  100. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +74 -0
  101. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +39 -0
  102. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +111 -0
  103. ai_edge_torch/generative/examples/stable_diffusion/util.py +77 -0
  104. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  105. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +138 -0
  106. ai_edge_torch/generative/examples/t5/t5.py +655 -0
  107. ai_edge_torch/generative/examples/t5/t5_attention.py +246 -0
  108. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  109. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  110. ai_edge_torch/generative/examples/test_models/toy_model.py +156 -0
  111. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +138 -0
  112. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  113. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +80 -0
  114. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +88 -0
  115. ai_edge_torch/generative/examples/tiny_llama/verify.py +72 -0
  116. ai_edge_torch/generative/fx_passes/__init__.py +30 -0
  117. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +50 -0
  118. ai_edge_torch/generative/layers/__init__.py +14 -0
  119. ai_edge_torch/generative/layers/attention.py +399 -0
  120. ai_edge_torch/generative/layers/attention_utils.py +210 -0
  121. ai_edge_torch/generative/layers/builder.py +160 -0
  122. ai_edge_torch/generative/layers/feed_forward.py +120 -0
  123. ai_edge_torch/generative/layers/kv_cache.py +204 -0
  124. ai_edge_torch/generative/layers/lora.py +557 -0
  125. ai_edge_torch/generative/layers/model_config.py +238 -0
  126. ai_edge_torch/generative/layers/normalization.py +222 -0
  127. ai_edge_torch/generative/layers/rotary_position_embedding.py +94 -0
  128. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +144 -0
  129. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  130. ai_edge_torch/generative/layers/unet/blocks_2d.py +806 -0
  131. ai_edge_torch/generative/layers/unet/builder.py +50 -0
  132. ai_edge_torch/generative/layers/unet/model_config.py +282 -0
  133. ai_edge_torch/generative/quantize/__init__.py +14 -0
  134. ai_edge_torch/generative/quantize/example.py +47 -0
  135. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  136. ai_edge_torch/generative/quantize/quant_recipe.py +154 -0
  137. ai_edge_torch/generative/quantize/quant_recipe_utils.py +62 -0
  138. ai_edge_torch/generative/quantize/quant_recipes.py +56 -0
  139. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  140. ai_edge_torch/generative/test/__init__.py +14 -0
  141. ai_edge_torch/generative/test/test_custom_dus.py +107 -0
  142. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  143. ai_edge_torch/generative/test/test_loader.py +83 -0
  144. ai_edge_torch/generative/test/test_lora.py +147 -0
  145. ai_edge_torch/generative/test/test_model_conversion.py +191 -0
  146. ai_edge_torch/generative/test/test_model_conversion_large.py +362 -0
  147. ai_edge_torch/generative/test/test_quantize.py +183 -0
  148. ai_edge_torch/generative/test/utils.py +82 -0
  149. ai_edge_torch/generative/utilities/__init__.py +15 -0
  150. ai_edge_torch/generative/utilities/converter.py +215 -0
  151. ai_edge_torch/generative/utilities/dynamic_update_slice.py +56 -0
  152. ai_edge_torch/generative/utilities/loader.py +398 -0
  153. ai_edge_torch/generative/utilities/model_builder.py +180 -0
  154. ai_edge_torch/generative/utilities/moonshine_loader.py +154 -0
  155. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +1032 -0
  156. ai_edge_torch/generative/utilities/t5_loader.py +512 -0
  157. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  158. ai_edge_torch/generative/utilities/verifier.py +335 -0
  159. ai_edge_torch/hlfb/__init__.py +16 -0
  160. ai_edge_torch/hlfb/mark_pattern/__init__.py +153 -0
  161. ai_edge_torch/hlfb/mark_pattern/fx_utils.py +69 -0
  162. ai_edge_torch/hlfb/mark_pattern/pattern.py +288 -0
  163. ai_edge_torch/hlfb/test/__init__.py +14 -0
  164. ai_edge_torch/hlfb/test/test_mark_pattern.py +185 -0
  165. ai_edge_torch/lowertools/__init__.py +18 -0
  166. ai_edge_torch/lowertools/_shim.py +86 -0
  167. ai_edge_torch/lowertools/common_utils.py +142 -0
  168. ai_edge_torch/lowertools/odml_torch_utils.py +260 -0
  169. ai_edge_torch/lowertools/test_utils.py +62 -0
  170. ai_edge_torch/lowertools/torch_xla_utils.py +301 -0
  171. ai_edge_torch/lowertools/translate_recipe.py +163 -0
  172. ai_edge_torch/model.py +177 -0
  173. ai_edge_torch/odml_torch/__init__.py +20 -0
  174. ai_edge_torch/odml_torch/_torch_future.py +88 -0
  175. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  176. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  177. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  178. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  179. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  180. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  181. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  182. ai_edge_torch/odml_torch/export.py +403 -0
  183. ai_edge_torch/odml_torch/export_utils.py +157 -0
  184. ai_edge_torch/odml_torch/jax_bridge/__init__.py +18 -0
  185. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +180 -0
  186. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  187. ai_edge_torch/odml_torch/lowerings/__init__.py +27 -0
  188. ai_edge_torch/odml_torch/lowerings/_basic.py +294 -0
  189. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  190. ai_edge_torch/odml_torch/lowerings/_convolution.py +243 -0
  191. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +285 -0
  192. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +87 -0
  193. ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +177 -0
  194. ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
  195. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  196. ai_edge_torch/odml_torch/lowerings/decomp.py +69 -0
  197. ai_edge_torch/odml_torch/lowerings/registry.py +65 -0
  198. ai_edge_torch/odml_torch/lowerings/utils.py +201 -0
  199. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  200. ai_edge_torch/odml_torch/tf_integration.py +156 -0
  201. ai_edge_torch/quantize/__init__.py +16 -0
  202. ai_edge_torch/quantize/pt2e_quantizer.py +466 -0
  203. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1061 -0
  204. ai_edge_torch/quantize/quant_config.py +85 -0
  205. ai_edge_torch/testing/__init__.py +14 -0
  206. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  207. ai_edge_torch/testing/model_coverage/model_coverage.py +145 -0
  208. ai_edge_torch/version.py +16 -0
  209. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/LICENSE +202 -0
  210. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/METADATA +44 -0
  211. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/RECORD +213 -0
  212. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/WHEEL +5 -0
  213. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/top_level.txt +1 -0
@@ -0,0 +1,243 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Provides lowering for coreaten to stablehlo for Convolution."""
16
+
17
+ import math
18
+ from typing import Optional
19
+
20
+ from ai_edge_torch.odml_torch.lowerings import registry
21
+ from jax._src.lib.mlir import ir
22
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
23
+ import torch
24
+
25
+
26
def make_padding(padding):
  """Converts PyTorch-style padding into StableHLO (low, high) pairs.

  Aten convolutions pad symmetrically and carry a single value per spatial
  dimension, whereas StableHLO accepts a separate start and end amount for
  each dimension. Each value is therefore duplicated into a `(p, p)` pair.

  Args:
    padding: Iterable of per-dimension padding amounts.

  Returns:
    A tuple holding one `(p, p)` pair per input value, in input order.
  """
  pairs = []
  for amount in padding:
    pairs.append((amount, amount))
  return tuple(pairs)
39
+
40
+
41
def create_conv_dimension_numbers(lhs, transposed: bool = False):
  """Builds the StableHLO dimension numbers for an N, C, *spatial convolution.

  Input and output share one layout: batch at dim 0, features at dim 1 and
  spatial dims from 2 onward. The kernel uses the same spatial dims; its
  feature dims depend on `transposed`.

  Args:
    lhs: The input value; only `len(lhs.type.shape)` is read.
    transposed: Whether the kernel is laid out for a transposed convolution.

  Returns:
    A `stablehlo.ConvDimensionNumbers` describing the layouts above.
  """
  spatial = list(range(2, len(lhs.type.shape)))

  if transposed:
    # Transposed-conv kernels are IOHW.
    kernel_in, kernel_out = 0, 1
  else:
    # Regular kernels are OIHW.
    kernel_in, kernel_out = 1, 0

  return stablehlo.ConvDimensionNumbers.get(
    input_batch_dimension=0,
    input_feature_dimension=1,
    input_spatial_dimensions=spatial,
    kernel_input_feature_dimension=kernel_in,
    kernel_output_feature_dimension=kernel_out,
    kernel_spatial_dimensions=spatial,
    output_batch_dimension=0,
    output_feature_dimension=1,
    output_spatial_dimensions=spatial,
  )
70
+
71
+
72
def infer_output_shape(
  lhs,
  rhs,
  stride,
  dilation,
  padding,
  transposed: bool = False,
  output_padding: Optional[list[int]] = None,
):
  """Infer the output shape of the convolution.

  Args:
    lhs: The input value; only `lhs.type.shape` is read. Layout (N, C, *spatial).
    rhs: The kernel value; only `rhs.type.shape` is read. Layout is
      (O, I, *spatial) for a regular conv and (I, O, *spatial) for a
      transposed conv.
    stride: Per-spatial-dim stride (the input dilation of a transposed conv).
    dilation: Per-spatial-dim kernel dilation.
    padding: Per-spatial-dim symmetric padding.
    transposed: Whether the convolution is transposed.
    output_padding: Per-spatial-dim extra output size for a transposed conv.
      `None` (the default) means no output padding. A plain int is also
      accepted and applied to every spatial dim. Ignored for regular convs.

  Returns:
    The output shape as a list of ints: [N, C_out, *spatial].
  """
  lhs_shape = lhs.type.shape
  rhs_shape = rhs.type.shape
  num_spatial_dims = len(lhs_shape) - 2

  # The default used to be the int 0, which crashed when indexed on the
  # transposed path; normalize "not given" / scalar to a per-dim list.
  if output_padding is None:
    output_padding = 0
  if isinstance(output_padding, int):
    output_padding = [output_padding] * num_spatial_dims

  # Output channels come from kernel dim 0 (OIHW) for a regular conv and from
  # kernel dim 1 (IOHW) for a transposed conv.
  out_channels = rhs_shape[1] if transposed else rhs_shape[0]
  output_shape = [lhs_shape[0], out_channels]

  # Spatial dims start at index 2 (after batch and features).
  for i in range(num_spatial_dims):
    dim_size = lhs_shape[i + 2]
    kernel_dim_size = rhs_shape[i + 2]
    dilated_kernel_span = dilation[i] * (kernel_dim_size - 1)

    if transposed:
      output_dim_size = (
        (dim_size - 1) * stride[i]
        - 2 * padding[i]
        + dilated_kernel_span
        + output_padding[i]
        + 1
      )
    else:
      # Integer floor division: exact for arbitrarily large ints (float
      # division can round), and floors negative numerators like math.floor.
      output_dim_size = (
        dim_size + 2 * padding[i] - dilated_kernel_span - 1
      ) // stride[i] + 1

    output_shape.append(output_dim_size)

  return output_shape
140
+
141
+
142
def build_transpose_conv(
  lctx,
  output_type: ir.RankedTensorType,
  lhs: ir.Value,
  rhs: ir.Value,
  stride: list[int],
  padding: list[int],
  dilation: list[int],
  output_padding: list[int],
  groups: int,
):
  """Emits a transposed convolution as a regular StableHLO convolution.

  The kernel is reversed along its spatial dims, the input is dilated by
  `stride` (via `lhs_dilation`), and each spatial dim is padded on both
  sides so that the result matches `output_type`.

  Args:
    lctx: The lowering context (not read here).
    output_type: The result type of the convolution.
    lhs: The input value, laid out as (N, C, *spatial).
    rhs: The kernel value, laid out as (I, O, *spatial).
    stride: Per-spatial-dim stride; applied as input dilation.
    padding: Per-spatial-dim padding of the original transposed conv.
    dilation: Per-spatial-dim kernel dilation.
    output_padding: Not read here.
    groups: Passed through as `feature_group_count`.

  Returns:
    The value produced by `stablehlo.convolution`.
  """
  spatial_rank = len(lhs.type.shape) - 2
  spatial_axes = list(range(2, 2 + spatial_rank))

  # Flip the kernel along every spatial axis.
  flipped_rhs = stablehlo.reverse(rhs, spatial_axes)
  kernel_shape = flipped_rhs.type.shape

  # Extra input padding per side: the dilated kernel span minus the
  # transposed conv's own padding, so the output reaches the right size.
  edge_padding = []
  for axis in range(spatial_rank):
    dilated_span = dilation[axis] * (kernel_shape[axis + 2] - 1)
    edge_padding.append(dilated_span - padding[axis])

  return stablehlo.convolution(
    result=output_type,
    lhs=lhs,
    rhs=flipped_rhs,
    dimension_numbers=create_conv_dimension_numbers(lhs, transposed=True),
    feature_group_count=groups,
    batch_group_count=1,
    padding=make_padding(edge_padding),
    lhs_dilation=stride,
    rhs_dilation=dilation,
  )
174
+
175
+
176
# convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride,
#   SymInt[] padding, SymInt[] dilation, bool transposed,
#   SymInt[] output_padding, SymInt groups) -> Tensor
@registry.lower(torch.ops.aten.convolution)
def _aten_convolution(
  lctx,
  lhs: ir.Value,
  rhs: ir.Value,
  bias: Optional[ir.Value],
  stride: list[int],
  padding: list[int],
  dilation: list[int],
  transposed: bool,
  output_padding: list[int],
  groups: int,
):
  """Lowers `aten.convolution` to a StableHLO convolution plus optional bias.

  Args:
    lctx: The lowering context, forwarded to `build_transpose_conv`.
    lhs: The input value, laid out as (N, C, *spatial).
    rhs: The weight value: (O, I, *spatial) for a regular convolution and
      (I, O, *spatial) for a transposed one.
    bias: Optional per-output-channel bias, broadcast along dimension 1.
    stride: Per-spatial-dim stride.
    padding: Per-spatial-dim symmetric padding.
    dilation: Per-spatial-dim kernel dilation.
    transposed: Whether this is a transposed convolution.
    output_padding: Per-spatial-dim output padding; must be all zeros.
    groups: Number of feature groups.

  Returns:
    The convolution result, with `bias` added when it is given.

  Raises:
    NotImplementedError: If any `output_padding` entry is non-zero, or if a
      transposed convolution uses `groups != 1`.
  """
  # TODO(b/365559296) Add support for output_padding
  if any(output_padding):
    raise NotImplementedError(
      "Output padding on convolution is not implemented."
    )

  if transposed and groups != 1:
    # The transposed kernel is emitted with kernel_input_feature_dimension=0,
    # so for a grouped StableHLO convolution its dim 0 would have to hold
    # input_channels / groups. PyTorch's transposed weight keeps all input
    # channels in dim 0 (I, O / groups, *spatial), which makes the emitted
    # convolution invalid. Fail fast with a clear error instead.
    raise NotImplementedError(
      "Grouped transposed convolution is not implemented."
    )

  lhs_type: ir.RankedTensorType = lhs.type
  output_shape = infer_output_shape(
    lhs, rhs, stride, dilation, padding, transposed, output_padding
  )
  output_type = ir.RankedTensorType.get(
    output_shape,
    lhs_type.element_type,
  )

  if transposed:
    res = build_transpose_conv(
      lctx,
      output_type,
      lhs,
      rhs,
      stride,
      padding,
      dilation,
      output_padding,
      groups,
    )
  else:
    res = stablehlo.convolution(
      result=output_type,
      lhs=lhs,
      rhs=rhs,
      dimension_numbers=create_conv_dimension_numbers(lhs),
      feature_group_count=groups,
      batch_group_count=1,
      window_strides=stride,
      padding=make_padding(padding),
      rhs_dilation=dilation,
    )

  if bias is not None:
    # Broadcast the [C] bias into the full output shape along dim 1.
    broadcasted_bias = stablehlo.broadcast_in_dim(
      output_type, bias, ir.DenseI64ArrayAttr.get([1])
    )
    res = stablehlo.add(
      lhs=res,
      rhs=broadcasted_bias,
    )

  return res
@@ -0,0 +1,285 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import functools
16
+ import logging
17
+
18
+ from ai_edge_torch.odml_torch import jax_bridge
19
+ from ai_edge_torch.odml_torch.lowerings import context
20
+ from ai_edge_torch.odml_torch.lowerings import registry
21
+ import jax.numpy as jnp
22
+ from jax._src.lib.mlir import ir
23
+ import torch
24
+ import torch_xla2.ops.jaten # Import to load torch_xla2 ops
25
+ import torch_xla2.ops.ops_registry # Import to load torch_xla2 ops
26
+
27
# Module-level alias for `context.LoweringContext`; used below as the type
# annotation of the `lctx` parameter of lowering functions.
LoweringContext = context.LoweringContext
28
+
29
+
30
+ @functools.cache
31
+ def _log_usage(op):
32
+ logging.warning("Use jax lowering: %s", str(op))
33
+
34
+
35
def lower_by_jax(op, ir_input_names=None):
  """Returns a decorator that registers a JAX function as the lowering of `op`.

  The decorated function is bridged with `jax_bridge.wrap(fn, ir_input_names)`.
  A wrapper that first calls `_log_usage(op)` and then forwards to the bridged
  function is registered via `registry.lower(op)`. The decorator hands back
  the original JAX function unchanged.

  Args:
    op: The torch op to register a lowering for.
    ir_input_names: Forwarded to `jax_bridge.wrap`.

  Returns:
    A decorator that takes a JAX function and returns it unchanged.
  """

  def decorator(jax_fn):
    bridged_fn = jax_bridge.wrap(jax_fn, ir_input_names)

    def _jax_lowering(lctx, *args, **kwargs):
      _log_usage(op)
      return bridged_fn(lctx, *args, **kwargs)

    registry.lower(op)(_jax_lowering)
    return jax_fn

  return decorator
47
+
48
+
49
def _collect_torch_xla2_impls():
  """Maps torch ops to the JAX implementations registered by torch_xla2.

  Only registry entries whose `is_jax_function` is true are kept. An
  `OpOverloadPacket` contributes each of its overloads and then the packet
  itself, all mapped to the same implementation.

  Returns:
    A dict from torch op (overload or packet) to the JAX function.
  """
  impls = {}
  for aten_op, xla2_op in torch_xla2.ops.ops_registry.all_aten_ops.items():
    if not xla2_op.is_jax_function:
      continue
    if isinstance(aten_op, torch._ops.OpOverloadPacket):
      keys = [getattr(aten_op, name) for name in aten_op.overloads()]
      keys.append(aten_op)
    else:
      keys = [aten_op]
    for key in keys:
      impls[key] = xla2_op.func
  return impls


# Built inside a helper so loop variables do not leak into the module
# namespace.
_TORCH_XLA2_IMPLS = _collect_torch_xla2_impls()
61
+
62
+
63
def lower_by_torch_xla2(op):
  """Registers torch_xla2's JAX implementation of `op` as its lowering.

  Args:
    op: The torch op; must be a key of `_TORCH_XLA2_IMPLS`.

  Returns:
    Whatever the `lower_by_jax(op)` decorator returns for the implementation.

  Raises:
    KeyError: If `op` has no entry in `_TORCH_XLA2_IMPLS`.
  """
  register = lower_by_jax(op)
  jax_impl = _TORCH_XLA2_IMPLS[op]
  return register(jax_impl)
65
+
66
+
67
# Register torch_xla2's JAX implementations as the lowerings of these ops.
# Keep this list sorted and free of duplicate entries.
lower_by_torch_xla2(torch.ops.aten._adaptive_avg_pool2d)
lower_by_torch_xla2(torch.ops.aten._adaptive_avg_pool3d)
lower_by_torch_xla2(torch.ops.aten._cdist_forward)
lower_by_torch_xla2(torch.ops.aten._local_scalar_dense)
lower_by_torch_xla2(torch.ops.aten._log_softmax)
lower_by_torch_xla2(torch.ops.aten._native_batch_norm_legit)
lower_by_torch_xla2(torch.ops.aten._native_batch_norm_legit_no_training)
lower_by_torch_xla2(torch.ops.aten._pdist_forward)
lower_by_torch_xla2(torch.ops.aten._softmax)
lower_by_torch_xla2(torch.ops.aten._to_copy)
lower_by_torch_xla2(torch.ops.aten._unsafe_index)
lower_by_torch_xla2(torch.ops.aten._unsafe_view)
lower_by_torch_xla2(torch.ops.aten.abs)
lower_by_torch_xla2(torch.ops.aten.acos)
lower_by_torch_xla2(torch.ops.aten.acosh)
lower_by_torch_xla2(torch.ops.aten.add.Scalar)
lower_by_torch_xla2(torch.ops.aten.add.Tensor)
lower_by_torch_xla2(torch.ops.aten.addbmm.default)
lower_by_torch_xla2(torch.ops.aten.addmm)
lower_by_torch_xla2(torch.ops.aten.addmv)
lower_by_torch_xla2(torch.ops.aten.alias)
lower_by_torch_xla2(torch.ops.aten.allclose)
lower_by_torch_xla2(torch.ops.aten.amax)
lower_by_torch_xla2(torch.ops.aten.amin)
lower_by_torch_xla2(torch.ops.aten.any)
lower_by_torch_xla2(torch.ops.aten.arange.default)
lower_by_torch_xla2(torch.ops.aten.arange.start)
lower_by_torch_xla2(torch.ops.aten.arange.start_step)
lower_by_torch_xla2(torch.ops.aten.argmax)
lower_by_torch_xla2(torch.ops.aten.argmin)
lower_by_torch_xla2(torch.ops.aten.as_strided)
lower_by_torch_xla2(torch.ops.aten.as_strided_copy)
lower_by_torch_xla2(torch.ops.aten.asin)
lower_by_torch_xla2(torch.ops.aten.asinh)
lower_by_torch_xla2(torch.ops.aten.atan)
lower_by_torch_xla2(torch.ops.aten.atan2)
lower_by_torch_xla2(torch.ops.aten.atanh)
lower_by_torch_xla2(torch.ops.aten.avg_pool2d)
lower_by_torch_xla2(torch.ops.aten.avg_pool3d)
lower_by_torch_xla2(torch.ops.aten.bitwise_and)
lower_by_torch_xla2(torch.ops.aten.bitwise_not)
lower_by_torch_xla2(torch.ops.aten.bitwise_or)
lower_by_torch_xla2(torch.ops.aten.bitwise_xor)
lower_by_torch_xla2(torch.ops.aten.bmm)
lower_by_torch_xla2(torch.ops.aten.ceil)
lower_by_torch_xla2(torch.ops.aten.clamp.Tensor)
lower_by_torch_xla2(torch.ops.aten.clamp.default)
lower_by_torch_xla2(torch.ops.aten.clone)
lower_by_torch_xla2(torch.ops.aten.clone.default)
lower_by_torch_xla2(torch.ops.aten.constant_pad_nd)
lower_by_torch_xla2(torch.ops.aten.cos)
lower_by_torch_xla2(torch.ops.aten.cosh)
lower_by_torch_xla2(torch.ops.aten.cumsum)
lower_by_torch_xla2(torch.ops.aten.detach)
lower_by_torch_xla2(torch.ops.aten.diagonal)
lower_by_torch_xla2(torch.ops.aten.div)
lower_by_torch_xla2(torch.ops.aten.dot)
lower_by_torch_xla2(torch.ops.aten.embedding)
lower_by_torch_xla2(torch.ops.aten.empty)
lower_by_torch_xla2(torch.ops.aten.eq)
lower_by_torch_xla2(torch.ops.aten.erf)
lower_by_torch_xla2(torch.ops.aten.exp)
lower_by_torch_xla2(torch.ops.aten.expand)
lower_by_torch_xla2(torch.ops.aten.expand_copy)
lower_by_torch_xla2(torch.ops.aten.expm1)
lower_by_torch_xla2(torch.ops.aten.fill)
lower_by_torch_xla2(torch.ops.aten.flip)
lower_by_torch_xla2(torch.ops.aten.fmod)
lower_by_torch_xla2(torch.ops.aten.full)
lower_by_torch_xla2(torch.ops.aten.full_like)
lower_by_torch_xla2(torch.ops.aten.gather)
lower_by_torch_xla2(torch.ops.aten.ge)
lower_by_torch_xla2(torch.ops.aten.gelu)
lower_by_torch_xla2(torch.ops.aten.glu)
lower_by_torch_xla2(torch.ops.aten.glu.default)
lower_by_torch_xla2(torch.ops.aten.gt)
lower_by_torch_xla2(torch.ops.aten.hardtanh)
lower_by_torch_xla2(torch.ops.aten.index)
lower_by_torch_xla2(torch.ops.aten.index.Tensor)
lower_by_torch_xla2(torch.ops.aten.index_copy)
lower_by_torch_xla2(torch.ops.aten.index_put)
lower_by_torch_xla2(torch.ops.aten.index_select)
lower_by_torch_xla2(torch.ops.aten.isinf)
lower_by_torch_xla2(torch.ops.aten.isnan)
lower_by_torch_xla2(torch.ops.aten.le)
lower_by_torch_xla2(torch.ops.aten.leaky_relu)
lower_by_torch_xla2(torch.ops.aten.lift_fresh_copy)
lower_by_torch_xla2(torch.ops.aten.linalg_vector_norm)
lower_by_torch_xla2(torch.ops.aten.log)
lower_by_torch_xla2(torch.ops.aten.log10)
lower_by_torch_xla2(torch.ops.aten.log1p)
lower_by_torch_xla2(torch.ops.aten.log2)
lower_by_torch_xla2(torch.ops.aten.logical_and)
lower_by_torch_xla2(torch.ops.aten.logical_not)
lower_by_torch_xla2(torch.ops.aten.logical_or)
lower_by_torch_xla2(torch.ops.aten.logical_xor)
lower_by_torch_xla2(torch.ops.aten.lt)
lower_by_torch_xla2(torch.ops.aten.max)
lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices)
lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices_backward)
lower_by_torch_xla2(torch.ops.aten.max_pool3d_with_indices)
lower_by_torch_xla2(torch.ops.aten.maximum)
lower_by_torch_xla2(torch.ops.aten.mean)
lower_by_torch_xla2(torch.ops.aten.min)
lower_by_torch_xla2(torch.ops.aten.minimum)
lower_by_torch_xla2(torch.ops.aten.mm)
lower_by_torch_xla2(torch.ops.aten.mul.Scalar)
lower_by_torch_xla2(torch.ops.aten.mul.Tensor)
lower_by_torch_xla2(torch.ops.aten.native_batch_norm)
lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward)
lower_by_torch_xla2(torch.ops.aten.ne)
lower_by_torch_xla2(torch.ops.aten.neg)
lower_by_torch_xla2(torch.ops.aten.nonzero)
lower_by_torch_xla2(torch.ops.aten.outer)
lower_by_torch_xla2(torch.ops.aten.permute)
lower_by_torch_xla2(torch.ops.aten.permute_copy)
lower_by_torch_xla2(torch.ops.aten.pixel_shuffle)
lower_by_torch_xla2(torch.ops.aten.pow)
lower_by_torch_xla2(torch.ops.aten.prod)
lower_by_torch_xla2(torch.ops.aten.reciprocal)
lower_by_torch_xla2(torch.ops.aten.reflection_pad1d)
lower_by_torch_xla2(torch.ops.aten.relu)
lower_by_torch_xla2(torch.ops.aten.remainder)
lower_by_torch_xla2(torch.ops.aten.repeat)
lower_by_torch_xla2(torch.ops.aten.reshape)
lower_by_torch_xla2(torch.ops.aten.roll)
lower_by_torch_xla2(torch.ops.aten.round)
lower_by_torch_xla2(torch.ops.aten.rsqrt)
lower_by_torch_xla2(torch.ops.aten.scalar_tensor)
lower_by_torch_xla2(torch.ops.aten.scatter.src)
lower_by_torch_xla2(torch.ops.aten.scatter.value)
lower_by_torch_xla2(torch.ops.aten.scatter_add)
lower_by_torch_xla2(torch.ops.aten.scatter_reduce)
lower_by_torch_xla2(torch.ops.aten.select)
lower_by_torch_xla2(torch.ops.aten.select_copy)
lower_by_torch_xla2(torch.ops.aten.select_scatter)
lower_by_torch_xla2(torch.ops.aten.sigmoid)
lower_by_torch_xla2(torch.ops.aten.sign)
lower_by_torch_xla2(torch.ops.aten.silu)
lower_by_torch_xla2(torch.ops.aten.sin)
lower_by_torch_xla2(torch.ops.aten.sinh)
lower_by_torch_xla2(torch.ops.aten.slice)
lower_by_torch_xla2(torch.ops.aten.slice_copy)
lower_by_torch_xla2(torch.ops.aten.sort)
lower_by_torch_xla2(torch.ops.aten.split)
lower_by_torch_xla2(torch.ops.aten.split_copy)
lower_by_torch_xla2(torch.ops.aten.split_with_sizes)
lower_by_torch_xla2(torch.ops.aten.sqrt)
lower_by_torch_xla2(torch.ops.aten.squeeze)
lower_by_torch_xla2(torch.ops.aten.squeeze_copy)
lower_by_torch_xla2(torch.ops.aten.stack)
lower_by_torch_xla2(torch.ops.aten.sub.Scalar)
lower_by_torch_xla2(torch.ops.aten.sub.Tensor)
lower_by_torch_xla2(torch.ops.aten.sum)
lower_by_torch_xla2(torch.ops.aten.sym_size)
lower_by_torch_xla2(torch.ops.aten.t)
lower_by_torch_xla2(torch.ops.aten.tan)
lower_by_torch_xla2(torch.ops.aten.tanh)
lower_by_torch_xla2(torch.ops.aten.tensor_split.sections)
lower_by_torch_xla2(torch.ops.aten.to.device)
lower_by_torch_xla2(torch.ops.aten.to.dtype)
lower_by_torch_xla2(torch.ops.aten.topk)
lower_by_torch_xla2(torch.ops.aten.transpose)
lower_by_torch_xla2(torch.ops.aten.transpose_copy)
lower_by_torch_xla2(torch.ops.aten.triu)
lower_by_torch_xla2(torch.ops.aten.true_divide)
lower_by_torch_xla2(torch.ops.aten.trunc)
lower_by_torch_xla2(torch.ops.aten.unbind_copy)
lower_by_torch_xla2(torch.ops.aten.unsqueeze)
lower_by_torch_xla2(torch.ops.aten.unsqueeze.default)
lower_by_torch_xla2(torch.ops.aten.unsqueeze_copy)
lower_by_torch_xla2(torch.ops.aten.var.correction)
lower_by_torch_xla2(torch.ops.aten.var_mean.correction)
lower_by_torch_xla2(torch.ops.aten.view)
lower_by_torch_xla2(torch.ops.aten.view_as_complex)
lower_by_torch_xla2(torch.ops.aten.view_as_real)
lower_by_torch_xla2(torch.ops.aten.view_copy)
lower_by_torch_xla2(torch.ops.aten.where.ScalarOther)
lower_by_torch_xla2(torch.ops.aten.where.ScalarSelf)
lower_by_torch_xla2(torch.ops.aten.where.self)
lower_by_torch_xla2(torch.ops.prims.broadcast_in_dim)
lower_by_torch_xla2(torch.ops.prims.var)
253
+
254
+
255
@lower_by_jax(torch.ops.aten.unbind)
def _aten_unbind(self, *args, **kwargs):
  """Lowers `aten.unbind` using torch_xla2's `aten.unbind_copy` implementation.

  All arguments are forwarded unchanged to the `unbind_copy` implementation.
  """
  # Named after the op it lowers so it does not share (and get redefined
  # under) the module-level name of the `aten.copy` lowering.
  return _TORCH_XLA2_IMPLS[torch.ops.aten.unbind_copy](self, *args, **kwargs)
258
+
259
+
260
@lower_by_jax(torch.ops.aten.copy, ir_input_names=["src"])
def _aten_copy(self, src, **kwargs):
  """Lowers `aten.copy` using torch_xla2's implementation.

  Only `self` and `src` are forwarded; extra keyword arguments are accepted
  and dropped.
  """
  del kwargs  # Unused.
  copy_impl = _TORCH_XLA2_IMPLS[torch.ops.aten.copy]
  return copy_impl(self, src)
263
+
264
+
265
# Schema:
#   - aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None)
#       -> Tensor
# Torch Reference:
#   - https://pytorch.org/docs/stable/generated/torch.einsum.html
#   - https://github.com/pytorch/pytorch/blob/1b3f8b75896720e88362cbec7db32abc52afa83e/aten/src/ATen/native/Linear.cpp#L255
@registry.lower(torch.ops.aten.einsum.default)
def _aten_einsum_default(
  lctx: LoweringContext,
  equation: str,
  tensors: list[ir.Value],
  path=None,
):
  """Lowers `aten.einsum.default` through `jnp.einsum`.

  The `path` argument is ignored; JAX chooses the contraction order itself
  (`optimize="optimal"`).

  Args:
    lctx: The lowering context, passed to the bridged JAX function.
    equation: The einsum subscript string.
    tensors: The operand values.
    path: Ignored.

  Returns:
    The result of the bridged `jnp.einsum` call.
  """
  _log_usage(torch.ops.aten.einsum.default)

  def jax_lowering(operands):
    return jnp.einsum(equation, *operands, optimize="optimal")

  bridged = jax_bridge.wrap(jax_lowering)
  operand_tuple = tuple(tensors)
  return bridged(lctx, operand_tuple)
@@ -0,0 +1,87 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Provides lowering for coreaten to stablehlo for LayerNorm."""
16
+
17
+ import math
18
+ from typing import Optional
19
+ from ai_edge_torch.odml_torch.lowerings import registry
20
+ from ai_edge_torch.odml_torch.lowerings import utils
21
+ from jax._src.lib.mlir import ir
22
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
23
+ import numpy as np
24
+ import torch
25
+
26
+
27
def _broadcast_to_trailing_dims(
    value: ir.Value,
    data_type: ir.RankedTensorType,
    normalized_rank: int,
) -> ir.Value:
  """Broadcasts ``value`` onto the last ``normalized_rank`` dims of ``data_type``."""
  data_rank = len(data_type.shape)
  return stablehlo.broadcast_in_dim(
      data_type,
      value,
      ir.DenseI64ArrayAttr.get(
          list(range(data_rank - normalized_rank, data_rank))
      ),
  )


# native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight,
#   Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
@registry.lower(torch.ops.aten.native_layer_norm)
def _aten_native_layer_norm(
    lctx,
    data: ir.Value,
    normalized_shape: list[int],
    weight: Optional[ir.Value],
    bias: Optional[ir.Value],
    eps: float,
):
  """Lowers ``aten.native_layer_norm`` to StableHLO via ``batch_norm_training``.

  The input is reshaped to ``[1, rows, cols]``. Here ``rows`` is the product of
  the leading (un-normalized) dims and ``cols`` is the product of
  ``normalized_shape``. ``batch_norm_training`` with feature index 1 then
  yields per-row mean/variance. The scale is 1 and the offset is 0, so the
  normalized output is un-affine.

  Args:
    lctx: Lowering context (unused here).
    data: Input value; assumed to be a ranked tensor whose trailing dims match
      ``normalized_shape`` (as PyTorch requires).
    normalized_shape: Shape of the trailing dims to normalize over.
    weight: Optional elementwise scale, broadcast over the trailing dims.
    bias: Optional elementwise offset, broadcast over the trailing dims.
    eps: Epsilon added to the variance for numerical stability.

  Returns:
    Tuple ``(output, mean, rstd)``. ``output`` has ``data``'s type. ``mean``
    and ``rstd`` have the leading dims of ``data`` followed by
    ``len(normalized_shape)`` ones.
  """
  data_type: ir.RankedTensorType = data.type
  data_shape = list(data_type.shape)
  data_rank = len(data_shape)
  normalized_rank = len(normalized_shape)

  # Derive the row count from the leading dims instead of dividing total size
  # by the normalized size: division raises ZeroDivisionError when a
  # normalized dim is 0.
  leading_shape = data_shape[: data_rank - normalized_rank]
  unnormalized_count = math.prod(leading_shape)
  normalized_count = math.prod(normalized_shape)

  dest_type = ir.RankedTensorType.get(
      [1, unnormalized_count, normalized_count], data_type.element_type
  )
  reshaped_data = stablehlo.reshape(dest_type, data)

  one = utils.splat(1, data_type.element_type, [unnormalized_count])
  zero = utils.splat(0, data_type.element_type, [unnormalized_count])
  output, mean, var = stablehlo.batch_norm_training(
      reshaped_data, one, zero, eps, 1
  )
  # Reciprocal standard deviation, as returned by aten.native_layer_norm.
  eps_splat = utils.splat(eps, var.type.element_type, var.type.shape)
  rstd = stablehlo.rsqrt(stablehlo.add(var, eps_splat))

  # Stats keep the leading dims and collapse each normalized dim to size 1.
  stats_shape = leading_shape + [1] * normalized_rank
  stats_type = ir.RankedTensorType.get(stats_shape, data_type.element_type)
  mean = stablehlo.reshape(stats_type, mean)
  rstd = stablehlo.reshape(stats_type, rstd)

  output = stablehlo.reshape(data_type, output)

  if weight is not None:
    weight = _broadcast_to_trailing_dims(weight, data_type, normalized_rank)
    output = stablehlo.multiply(weight, output)
  if bias is not None:
    bias = _broadcast_to_trailing_dims(bias, data_type, normalized_rank)
    output = stablehlo.add(bias, output)

  return output, mean, rstd