ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

Files changed (121)
  1. ai_edge_torch/__init__.py +31 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +400 -0
  5. ai_edge_torch/convert/converter.py +202 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
  9. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +311 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +192 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
  27. ai_edge_torch/convert/to_channel_last_io.py +85 -0
  28. ai_edge_torch/debug/__init__.py +17 -0
  29. ai_edge_torch/debug/culprit.py +464 -0
  30. ai_edge_torch/debug/test/__init__.py +14 -0
  31. ai_edge_torch/debug/test/test_culprit.py +133 -0
  32. ai_edge_torch/debug/test/test_search_model.py +50 -0
  33. ai_edge_torch/debug/utils.py +48 -0
  34. ai_edge_torch/experimental/__init__.py +14 -0
  35. ai_edge_torch/generative/__init__.py +14 -0
  36. ai_edge_torch/generative/examples/__init__.py +14 -0
  37. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  39. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  40. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  42. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  44. ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
  45. ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
  46. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
  47. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
  48. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
  49. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
  50. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
  51. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  52. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
  54. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
  55. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
  56. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
  57. ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
  58. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  59. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  60. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  61. ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
  62. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  63. ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
  64. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
  65. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  66. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  67. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  68. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  69. ai_edge_torch/generative/fx_passes/__init__.py +31 -0
  70. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
  71. ai_edge_torch/generative/layers/__init__.py +14 -0
  72. ai_edge_torch/generative/layers/attention.py +354 -0
  73. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  74. ai_edge_torch/generative/layers/builder.py +131 -0
  75. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  76. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  77. ai_edge_torch/generative/layers/model_config.py +158 -0
  78. ai_edge_torch/generative/layers/normalization.py +62 -0
  79. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  80. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
  81. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  82. ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
  83. ai_edge_torch/generative/layers/unet/builder.py +47 -0
  84. ai_edge_torch/generative/layers/unet/model_config.py +269 -0
  85. ai_edge_torch/generative/quantize/__init__.py +14 -0
  86. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  87. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
  88. ai_edge_torch/generative/quantize/example.py +45 -0
  89. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  90. ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
  91. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  92. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  93. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  94. ai_edge_torch/generative/test/__init__.py +14 -0
  95. ai_edge_torch/generative/test/loader_test.py +80 -0
  96. ai_edge_torch/generative/test/test_model_conversion.py +235 -0
  97. ai_edge_torch/generative/test/test_quantize.py +162 -0
  98. ai_edge_torch/generative/utilities/__init__.py +15 -0
  99. ai_edge_torch/generative/utilities/loader.py +328 -0
  100. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
  101. ai_edge_torch/generative/utilities/t5_loader.py +483 -0
  102. ai_edge_torch/hlfb/__init__.py +16 -0
  103. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  104. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  105. ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
  106. ai_edge_torch/hlfb/test/__init__.py +14 -0
  107. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  108. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  109. ai_edge_torch/model.py +142 -0
  110. ai_edge_torch/quantize/__init__.py +16 -0
  111. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  112. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  113. ai_edge_torch/quantize/quant_config.py +81 -0
  114. ai_edge_torch/testing/__init__.py +14 -0
  115. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  116. ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
  117. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
  118. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
  119. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
  120. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
  121. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
@@ -0,0 +1,202 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import Any, Dict, Optional, Tuple, Union
19
+
20
+ import torch
21
+
22
+ from ai_edge_torch import model
23
+ from ai_edge_torch.convert import conversion
24
+ from ai_edge_torch.convert import conversion_utils as cutils
25
+ from ai_edge_torch.quantize import quant_config as qcfg
26
+
27
+
28
class Converter:
    """Collects named signatures and converts them into a single edge model.

    Each signature pairs a name with a torch module and the sample inputs used
    to trace it. `convert` hands every collected signature to
    `conversion.convert_signatures` in one call.
    """

    def __init__(self):
        self._signatures: list[cutils.Signature] = []

    def signature(
        self,
        name: str,
        module: torch.nn.Module,
        sample_args=None,
        sample_kwargs=None,
        *,
        dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ) -> Converter:
        """Alias to `add_signature`."""
        return self.add_signature(
            name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
        )

    def add_signature(
        self,
        name: str,
        module: torch.nn.Module,
        sample_args=None,
        sample_kwargs=None,
        *,
        dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ) -> Converter:
        """Allows adding a new named torch model along with sample args to the conversion.

        Args:
          name: The name of the signature included in the converted edge model.
          module: The torch module to be converted.
          sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
          sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
          dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
            See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.

        Returns:
          This converter, so calls can be chained.

        Raises:
          ValueError: If a signature with the provided name already exists, or
            if neither sample_args nor sample_kwargs is provided.
        """
        if any(sig.name == name for sig in self._signatures):
            raise ValueError(f"A signature with the provided name ({name}) is already added.")

        if sample_args is None and sample_kwargs is None:
            raise ValueError("sample_args or sample_kwargs must be provided.")

        self._signatures.append(
            cutils.Signature(
                name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
            )
        )
        return self

    def convert(
        self,
        module: Optional[torch.nn.Module] = None,
        sample_args=None,
        sample_kwargs=None,
        *,
        quant_config: Optional[qcfg.QuantConfig] = None,
        dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
        _ai_edge_converter_flags: Optional[dict] = None,
    ) -> model.TfLiteModel:
        """Finalizes the conversion and produces an edge model.

        This could be called with no arguments as follows:

          edge_model = Converter().signature(name, module, args).convert()

        Or it could be used to set the default signature for the converted edge model:

          edge_model = Converter().convert(module, args)

        Args:
          module: The torch module to be converted. When given, it is added as
            the default signature (`cutils.DEFAULT_SIGNATURE_NAME`).
          sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
          sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
          quant_config: User-defined quantization method and scheme of the model.
          dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
            See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
          _ai_edge_converter_flags: A nested dictionary allowing setting flags for the underlying converter.
            This gives access to an implementation detail of this function and so needs to be treated as such.
            Please do not rely on this parameter except for local debugging as this can be removed in a future release.
            Defaults to an empty dict.

        Raises:
          ValueError: If the arguments are not provided as expected. See the example in this function's docstring.
        """
        # A fresh dict per call: a shared `{}` default would leak any mutation
        # made downstream into every later conversion.
        if _ai_edge_converter_flags is None:
            _ai_edge_converter_flags = {}

        if module is not None:
            if sample_args is None and sample_kwargs is None:
                # A module without any sample inputs cannot be traced.
                raise ValueError(
                    "sample_args or sample_kwargs must be provided if a module is specified."
                )
            self.add_signature(
                cutils.DEFAULT_SIGNATURE_NAME,
                module,
                sample_args,
                sample_kwargs,
                dynamic_shapes=dynamic_shapes,
            )

        return conversion.convert_signatures(
            self._signatures,
            quant_config=quant_config,
            _tfl_converter_flags=_ai_edge_converter_flags,
        )
139
+
140
+
141
def signature(
    name: str,
    module: torch.nn.Module,
    sample_args=None,
    sample_kwargs=None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> Converter:
    """Creates a new Converter that already holds one named signature.

    Args:
      name: The name of the signature included in the converted edge model.
      module: The torch module to be converted.
      sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
      sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
      dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
        See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.

    Returns:
      The newly created Converter; more signatures may be chained onto it.

    Example:
      converter = ai_edge_torch.signature(name, module, args)
      edge_model = converter.convert()
    """
    converter = Converter()
    return converter.signature(
        name,
        module,
        sample_args,
        sample_kwargs,
        dynamic_shapes=dynamic_shapes,
    )
166
+
167
+
168
def convert(
    module: Optional[torch.nn.Module] = None,
    sample_args=None,
    sample_kwargs=None,
    *,
    quant_config: Optional[qcfg.QuantConfig] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    _ai_edge_converter_flags: Optional[dict] = None,
) -> model.TfLiteModel:
    """Allows converting a PyTorch model to an edge model with one default signature in one step.

    Args:
      module: The torch module to be converted.
      sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
      sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
      quant_config: User-defined quantization method and scheme of the model.
      dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
        See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
      _ai_edge_converter_flags: A nested dictionary allowing setting flags for the underlying converter.
        This gives access to an implementation detail of this function and so needs to be treated as such.
        Please do not rely on this parameter except for local debugging as this can be removed in a future release.
        Defaults to an empty dict.

    Returns:
      The converted edge model produced by `Converter.convert`.

    Example:
      edge_model = ai_edge_torch.convert(module, args)
    """
    # Build a fresh dict per call instead of sharing a mutable `{}` default
    # across every invocation of this function.
    if _ai_edge_converter_flags is None:
        _ai_edge_converter_flags = {}

    return Converter().convert(
        module,
        sample_args,
        sample_kwargs,
        quant_config=quant_config,
        dynamic_shapes=dynamic_shapes,
        _ai_edge_converter_flags=_ai_edge_converter_flags,
    )
@@ -0,0 +1,59 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Sequence, Union
17
+
18
+ from torch.export import ExportedProgram
19
+ from torch.fx.passes.infra.pass_manager import pass_result_wrapper
20
+ import torch.utils._pytree as pytree
21
+
22
+ from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
23
+ from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
24
+ from ai_edge_torch.convert.fx_passes._pass_base import FxPassBase
25
+ from ai_edge_torch.convert.fx_passes._pass_base import FxPassResult
26
+ from ai_edge_torch.convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass # NOQA
27
+ from ai_edge_torch.convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass # NOQA
28
+ from ai_edge_torch.convert.fx_passes.canonicalize_pass import CanonicalizePass
29
+ from ai_edge_torch.convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass # NOQA
30
+ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass # NOQA
31
+
32
+
33
# TODO(cnchan): make a PassManager class.
def run_passes(
    exported_program: ExportedProgram,
    passes: Sequence[Union[ExportedProgramPassBase, FxPassBase]],
) -> ExportedProgram:
    """Applies `passes` to `exported_program` in order and returns the result.

    `passes` may be nested; it is flattened with `torch.utils._pytree` first.
    An `ExportedProgramPassBase` receives the whole program and its result's
    `exported_program` replaces the current one. Any other pass is wrapped
    with `pass_result_wrapper` and invoked on the program's `graph_module`.
    When such a pass reports a modification and returns a different
    GraphModule object, a new `ExportedProgram` is built around that module.
    It reuses the current program's graph signature, state dict, range
    constraints, module call graph, example inputs, verifier and constants.

    Args:
      exported_program: The program to transform.
      passes: Possibly nested sequence of passes to run, in order.

    Returns:
      The program after all passes have run.
    """
    flat_passes, _ = pytree.tree_flatten(passes)
    for current in flat_passes:
        if isinstance(current, ExportedProgramPassBase):
            exported_program = current(exported_program).exported_program
            continue

        fx_pass = pass_result_wrapper(current)
        new_gm, modified = fx_pass(exported_program.graph_module)
        if not modified or new_gm is exported_program.graph_module:
            # Nothing to rebuild: either no change was reported or the same
            # GraphModule object came back (presumably edited in place).
            continue

        exported_program = ExportedProgram(
            root=new_gm,
            graph=new_gm.graph,
            graph_signature=exported_program.graph_signature,
            state_dict=exported_program.state_dict,
            range_constraints=exported_program.range_constraints,
            module_call_graph=exported_program.module_call_graph,
            example_inputs=exported_program.example_inputs,
            verifier=exported_program.verifier,
            constants=exported_program.constants,
        )
    return exported_program
@@ -0,0 +1,49 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import abc
17
+ from collections import namedtuple
18
+
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+ from torch.fx.passes.infra.pass_base import PassBase as FxPassBase
22
+ from torch.fx.passes.infra.pass_base import PassResult as FxPassResult
23
+
24
+
25
class ExportedProgramPassResult(
    namedtuple(
        "ExportedProgramPassResult",
        ("exported_program", "modified"),
    )
):
    """Pair returned by an `ExportedProgramPassBase` pass.

    Fields:
      exported_program: The program produced by the pass.
      modified: Flag set by the pass to report whether it changed the program.
    """

    def __new__(cls, exported_program, modified):
        return super(ExportedProgramPassResult, cls).__new__(
            cls, exported_program, modified
        )
31
+
32
+
33
+ class ExportedProgramPassBase(abc.ABC):
34
+
35
+ def __call__(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
36
+ self.requires(exported_program)
37
+ res = self.call(exported_program)
38
+ self.ensures(exported_program)
39
+ return res
40
+
41
+ @abc.abstractmethod
42
+ def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
43
+ pass
44
+
45
+ def requires(self, exported_program: ExportedProgram) -> None:
46
+ pass
47
+
48
+ def ensures(self, exported_program: ExportedProgram) -> None:
49
+ pass
@@ -0,0 +1,225 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import copy
17
+ import functools
18
+ from typing import Any, Callable
19
+
20
+ import torch
21
+ from torch.fx import GraphModule
22
+ from torch.fx import Node
23
+ from torch.fx.passes.infra.pass_base import PassBase
24
+ from torch.fx.passes.infra.pass_base import PassResult
25
+ import torch.utils._pytree as pytree
26
+
27
+ from ai_edge_torch.hlfb import StableHLOCompositeBuilder
28
+
29
# Registry mapping an op (the value matched against `node.target`) to the
# function that rewrites that node into a StableHLO composite. It is filled by
# `_register_composite_builder` and consumed by `BuildAtenCompositePass.call`.
_composite_builders: dict[Callable, Callable[[GraphModule, Node], None]] = {}
30
+
31
+
32
def _register_composite_builder(op):
    """Returns a decorator registering a composite builder for `op`.

    Args:
      op: Either a single op overload (e.g. `torch.ops.aten.gelu.default`) or
        an `OpOverloadPacket` (e.g. `torch.ops.aten.gelu`). For a packet, the
        builder is registered for every one of its overloads.

    Returns:
      A decorator that records the decorated function in
      `_composite_builders` and returns it unchanged.
    """

    def inner(func):
        if isinstance(op, torch._ops.OpOverloadPacket):
            # Register each concrete overload of the packet. The registrations
            # in this module key on overloads such as `.default`, which is
            # presumably what graph nodes target.
            for overload_name in op.overloads():
                _composite_builders[getattr(op, overload_name)] = func
        else:
            _composite_builders[op] = func
        return func

    return inner
42
+
43
+
44
def _tree_map_to_composite_attr_values(values, *, stringify_incompatible_values=True):
    """Maps every pytree leaf of `values` to a composite-attribute-friendly value.

    Per leaf:
      * `None` becomes the string "py_None".
      * `str`, `int`, `float` and `bool` are kept as-is.
      * Anything else becomes `str(leaf)` when `stringify_incompatible_values`
        is true, and is kept unchanged otherwise.
    """
    passthrough_types = (str, int, float, bool)

    def _to_attr(leaf):
        if leaf is None:
            return "py_None"
        if isinstance(leaf, passthrough_types) or not stringify_incompatible_values:
            return leaf
        return str(leaf)

    return pytree.tree_map(_to_attr, values)
58
+
59
+
60
class TorchOpArgumentsMapper:
    """Normalizes a torch op's call arguments into one kwargs dict using its schema."""

    def __init__(self, op):
        # A packet is resolved to its `.default` overload, whose schema is used.
        if isinstance(op, torch._ops.OpOverloadPacket):
            op = op.default

        assert hasattr(op, "_schema")
        self.op = op
        # (argument name, schema default) pairs in declaration order.
        self.arg_specs = [
            (argument.name, argument.default_value)
            for argument in op._schema.arguments
        ]

    def get_full_kwargs(self, args, kwargs=None) -> dict[str, Any]:
        """Merges positional `args`, `kwargs` and schema defaults into one dict.

        Positional arguments are bound to schema names in order and take
        precedence over `kwargs`. Any schema argument after the positional ones
        that is missing from `kwargs` is filled with its schema default. The
        caller's `kwargs` dict is not modified.
        """
        full_kwargs = dict(kwargs or {})

        schema_names = (spec_name for spec_name, _ in self.arg_specs)
        for spec_name, value in zip(schema_names, args):
            full_kwargs[spec_name] = value

        for spec_name, spec_default in self.arg_specs[len(args):]:
            full_kwargs.setdefault(spec_name, spec_default)

        return full_kwargs
85
+
86
+
87
@_register_composite_builder(torch.ops.aten.hardswish.default)
def _aten_hardswish(gm: GraphModule, node: Node):
    """Retargets `node` to a wrapper that emits an `aten.hardswish.default` composite."""
    wrapped_op = node.target

    def hardswish(self: torch.Tensor):
        builder = StableHLOCompositeBuilder("aten.hardswish.default")
        marked_input = builder.mark_inputs(self)
        return builder.mark_outputs(wrapped_op(marked_input))

    node.target = hardswish
100
+
101
+
102
@_register_composite_builder(torch.ops.aten.gelu.default)
def _aten_gelu(gm: GraphModule, node: Node):
    """Retargets `node` to a wrapper that emits an `aten.gelu.default` composite.

    The composite carries the `approximate` mode as an attribute. Modes other
    than "none" (exact) and "tanh" are left as the plain op.
    """
    op = node.target
    args_mapper = TorchOpArgumentsMapper(op)

    def gelu(*args, **kwargs):
        full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
        approximate = full_kwargs["approximate"]

        # TFLite supports exact and tanh approximate.
        if approximate not in ("none", "tanh"):
            return op(*args, **kwargs)

        builder = StableHLOCompositeBuilder(
            "aten.gelu.default",
            attr=_tree_map_to_composite_attr_values({"approximate": approximate}),
        )
        marked_input = builder.mark_inputs(full_kwargs["self"])
        # Forward `approximate`. Without it, the composite body for "tanh"
        # would compute exact GELU and disagree with its own attribute.
        output = op(marked_input, approximate=approximate)
        return builder.mark_outputs(output)

    node.target = gelu
130
+
131
+
132
@_register_composite_builder(torch.ops.aten.avg_pool2d.default)
def _aten_avg_pool2d(gm: GraphModule, node: Node):
    """Retargets `node` to a wrapper that may emit an `aten.avg_pool2d.default` composite.

    At call time the wrapper emits the plain op for argument combinations the
    underlying converter cannot handle as a composite. These are: any
    `divisor_override`, `ceil_mode` combined with `count_include_pad=False`, and
    custom padding with `count_include_pad=False`. Otherwise it wraps the op in
    a composite carrying all pooling attributes.
    """
    op = node.target
    args_mapper = TorchOpArgumentsMapper(op)

    def is_same_padding(
        input_shape: list[int],
        kernel_size: list[int],
        stride: list[int],
        padding: list[int],
    ) -> bool:
        """Returns True if `padding` equals symmetric SAME padding on every dim."""
        for dim_input_size, dim_kernel_size, dim_stride, dim_padding in zip(
            input_shape, kernel_size, stride, padding
        ):
            # ceil(input / stride) in exact integer arithmetic, rather than
            # via float division and truncation.
            dim_output_size = (dim_input_size + dim_stride - 1) // dim_stride
            padding_needed = max(
                0,
                (dim_output_size - 1) * dim_stride + dim_kernel_size - dim_input_size,
            )
            # An odd total cannot be split into one symmetric padding value.
            if padding_needed % 2 != 0:
                return False

            if padding_needed // 2 != dim_padding:
                return False
        return True

    def is_valid_padding(padding: list[int]) -> bool:
        """Returns True if no padding is applied at all."""
        return not any(padding)

    def avg_pool2d(*args, **kwargs):
        full_kwargs = args_mapper.get_full_kwargs(args, kwargs)

        # We prefer to avoid passing empty arrays to composite attributes
        # as they will be lowered to an ArrayAttr so canonicalizing according
        # to the default behaviour here.
        if not full_kwargs["stride"]:
            full_kwargs["stride"] = full_kwargs["kernel_size"]

        # Only wrap in a composite when the underlying converter can handle it.
        # TODO We should be able to remove this if the converter can inline composites when it can not handle them.

        # We don't cover any cases where the divisor_override is set.
        if full_kwargs["divisor_override"] is not None:
            return op(*args, **kwargs)

        if full_kwargs["ceil_mode"] and not full_kwargs["count_include_pad"]:
            return op(*args, **kwargs)

        # We also can not cover a case where count_include_pad is False but the padding is custom.
        # The spatial dims are taken as shape[2:] (assumes NCHW input — TODO confirm).
        if (
            not full_kwargs["count_include_pad"]
            and not is_valid_padding(full_kwargs["padding"])
            and not is_same_padding(
                list(full_kwargs["self"].shape)[2:],
                full_kwargs["kernel_size"],
                full_kwargs["stride"],
                full_kwargs["padding"],
            )
        ):
            return op(*args, **kwargs)

        builder = StableHLOCompositeBuilder(
            "aten.avg_pool2d.default",
            attr=_tree_map_to_composite_attr_values(
                {
                    "kernel_size": full_kwargs["kernel_size"],
                    "stride": full_kwargs["stride"],
                    "padding": full_kwargs["padding"],
                    "ceil_mode": full_kwargs["ceil_mode"],
                    "count_include_pad": full_kwargs["count_include_pad"],
                    "divisor_override": full_kwargs["divisor_override"],
                }
            ),
        )

        full_kwargs["self"] = builder.mark_inputs(full_kwargs["self"])
        output = op(**full_kwargs)
        return builder.mark_outputs(output)

    node.target = avg_pool2d
214
+
215
+
216
class BuildAtenCompositePass(PassBase):
    """FX pass that hands every node whose target is registered to its composite builder."""

    def call(self, graph_module: GraphModule):
        """Runs the registered builders over `graph_module`, then lints and recompiles it.

        Always reports the graph as modified.
        """
        registry = _composite_builders
        for fx_node in graph_module.graph.nodes:
            build = registry.get(fx_node.target)
            if build is not None:
                build(graph_module, fx_node)

        graph_module.graph.lint()
        graph_module.recompile()
        return PassResult(graph_module, True)
@@ -0,0 +1,123 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import functools
17
+
18
+ import torch
19
+
20
+ from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
21
+ from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
22
+ from ai_edge_torch.hlfb import mark_pattern
23
+
24
# For torch nightly released after mid June 2024,
# torch.nn.functional.interpolate no longer gets exported into a decomposed
# graph, but into a single aten op (torch.ops.aten.upsample_nearest2d.vec /
# torch.ops.aten.upsample_bilinear2d.vec). This behavior would break our
# pattern matching based composite builder. For backward compatibility, both
# the patterns and the model graph are therefore decomposed with this table
# first.
_INTERPOLATE_DECOMPOSITIONS = torch._decomp.get_decompositions(
    [
        torch.ops.aten.upsample_bilinear2d.vec,
        torch.ops.aten.upsample_nearest2d.vec,
    ]
)
35
+
36
+
37
@functools.cache
def _get_upsample_bilinear2d_pattern():
    """Builds (once) the `odml.upsample_bilinear2d` pattern for align_corners=False.

    The attribute builder records the matched output's spatial size (last two
    dims) as "output" and fixes "align_corners" to False.
    """
    pattern = mark_pattern.Pattern(
        "odml.upsample_bilinear2d",
        lambda x: torch.nn.functional.interpolate(
            x, scale_factor=2, mode="bilinear", align_corners=False
        ),
        export_args=(torch.rand(1, 3, 100, 100),),
        decomp_table=_INTERPOLATE_DECOMPOSITIONS,
    )

    @pattern.register_attr_builder
    def attr_builder(pattern, graph_module, internal_match):
        out_val = internal_match.returning_nodes[0].meta["val"]
        out_h, out_w = out_val.shape[-2:]
        return {"output": (int(out_h), int(out_w)), "align_corners": False}

    return pattern
58
+
59
+
60
@functools.cache
def _get_upsample_bilinear2d_align_corners_pattern():
    """Builds (once) the `odml.upsample_bilinear2d` pattern for align_corners=True.

    The attribute builder records the matched output's spatial size (last two
    dims) as "output" and fixes "align_corners" to True.
    """
    pattern = mark_pattern.Pattern(
        "odml.upsample_bilinear2d",
        lambda x: torch.nn.functional.interpolate(
            x, scale_factor=2, mode="bilinear", align_corners=True
        ),
        export_args=(torch.rand(1, 3, 100, 100),),
        decomp_table=_INTERPOLATE_DECOMPOSITIONS,
    )

    # Parameter order matches the sibling pattern builders in this module,
    # which declare (pattern, graph_module, internal_match). Previously the
    # first two names were swapped.
    @pattern.register_attr_builder
    def attr_builder(pattern, graph_module, internal_match):
        output = internal_match.returning_nodes[0]
        output_h, output_w = output.meta["val"].shape[-2:]
        return {
            "output": (int(output_h), int(output_w)),
            "align_corners": True,
        }

    return pattern
81
+
82
+
83
@functools.cache
def _get_interpolate_nearest2d_pattern():
    """Builds (once) the `tfl.resize_nearest_neighbor` pattern for nearest 2x upsampling.

    The attribute builder records the matched output's spatial size (last two
    dims) as "size" and sets "is_nchw_op" to True.
    """
    pattern = mark_pattern.Pattern(
        "tfl.resize_nearest_neighbor",
        lambda x: torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest"),
        export_args=(torch.rand(1, 3, 100, 100),),
        decomp_table=_INTERPOLATE_DECOMPOSITIONS,
    )

    @pattern.register_attr_builder
    def attr_builder(pattern, graph_module, internal_match):
        out_val = internal_match.returning_nodes[0].meta["val"]
        out_h, out_w = out_val.shape[-2:]
        return {"size": (int(out_h), int(out_w)), "is_nchw_op": True}

    return pattern
102
+
103
+
104
class BuildInterpolateCompositePass(ExportedProgramPassBase):
    """Marks interpolate/upsample subgraphs as composites via pattern matching.

    The program is first decomposed with `_INTERPOLATE_DECOMPOSITIONS`. Each
    pattern is then applied to the decomposed graph module in order.
    """

    def __init__(self):
        super().__init__()
        pattern_factories = (
            _get_upsample_bilinear2d_pattern,
            _get_upsample_bilinear2d_align_corners_pattern,
            _get_interpolate_nearest2d_pattern,
        )
        self._patterns = [make_pattern() for make_pattern in pattern_factories]

    def call(self, exported_program: torch.export.ExportedProgram):
        decomposed = exported_program.run_decompositions(_INTERPOLATE_DECOMPOSITIONS)

        gm = decomposed.graph_module
        for pattern in self._patterns:
            gm = mark_pattern.mark_pattern(gm, pattern)

        gm.graph.lint()
        gm.recompile()
        return ExportedProgramPassResult(decomposed, True)
@@ -0,0 +1,37 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import torch
17
+ from torch.export import ExportedProgram
18
+
19
+ from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
20
+ from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
21
+
22
# A dummy decomp table for running ExportedProgram.run_decompositions without
# any op decompositions, i.e. only re-running aot_export_module. Because of the
# check in run_decompositions, passing None or an empty dict as decomp_table
# makes it run the default ATen -> Core ATen decompositions. A non-empty dummy
# table is therefore needed.
# NOTE(review): the single key is a bare OperatorBase instance, presumably
# never the target of any graph node, so its `lambda: None` decomposition
# should never actually run — confirm against torch internals.
# Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
_dummy_decomp_table = {
    torch._ops.OperatorBase(): lambda: None,
}
31
+
32
+
33
class CanonicalizePass(ExportedProgramPassBase):
    """Re-runs `run_decompositions` with a dummy table so no ops are decomposed.

    Always reports the program as modified.
    """

    def call(self, exported_program: ExportedProgram):
        canonicalized = exported_program.run_decompositions(_dummy_decomp_table)
        return ExportedProgramPassResult(canonicalized, True)