ai-edge-torch-nightly 0.3.0.dev20250114__py3-none-any.whl

Files changed (213)
  1. ai_edge_torch/__init__.py +32 -0
  2. ai_edge_torch/_config.py +69 -0
  3. ai_edge_torch/_convert/__init__.py +14 -0
  4. ai_edge_torch/_convert/conversion.py +153 -0
  5. ai_edge_torch/_convert/conversion_utils.py +64 -0
  6. ai_edge_torch/_convert/converter.py +270 -0
  7. ai_edge_torch/_convert/fx_passes/__init__.py +23 -0
  8. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +288 -0
  9. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +131 -0
  10. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  11. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  12. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +258 -0
  13. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +50 -0
  14. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +18 -0
  15. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +68 -0
  16. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +216 -0
  17. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +449 -0
  18. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  19. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +303 -0
  20. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py +64 -0
  21. ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +52 -0
  22. ai_edge_torch/_convert/signature.py +66 -0
  23. ai_edge_torch/_convert/test/__init__.py +14 -0
  24. ai_edge_torch/_convert/test/test_convert.py +558 -0
  25. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  26. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  27. ai_edge_torch/_convert/test/test_to_channel_last_io.py +96 -0
  28. ai_edge_torch/_convert/to_channel_last_io.py +92 -0
  29. ai_edge_torch/conftest.py +20 -0
  30. ai_edge_torch/debug/__init__.py +17 -0
  31. ai_edge_torch/debug/culprit.py +496 -0
  32. ai_edge_torch/debug/test/__init__.py +14 -0
  33. ai_edge_torch/debug/test/test_culprit.py +140 -0
  34. ai_edge_torch/debug/test/test_search_model.py +51 -0
  35. ai_edge_torch/debug/utils.py +59 -0
  36. ai_edge_torch/experimental/__init__.py +14 -0
  37. ai_edge_torch/fx_pass_base.py +110 -0
  38. ai_edge_torch/generative/__init__.py +14 -0
  39. ai_edge_torch/generative/examples/__init__.py +14 -0
  40. ai_edge_torch/generative/examples/amd_llama_135m/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +87 -0
  42. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +70 -0
  43. ai_edge_torch/generative/examples/amd_llama_135m/verify.py +72 -0
  44. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +80 -0
  46. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +80 -0
  47. ai_edge_torch/generative/examples/gemma/gemma1.py +107 -0
  48. ai_edge_torch/generative/examples/gemma/gemma2.py +295 -0
  49. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  50. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +43 -0
  51. ai_edge_torch/generative/examples/gemma/verify_util.py +157 -0
  52. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +91 -0
  54. ai_edge_torch/generative/examples/llama/llama.py +196 -0
  55. ai_edge_torch/generative/examples/llama/verify.py +88 -0
  56. ai_edge_torch/generative/examples/moonshine/__init__.py +14 -0
  57. ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +50 -0
  58. ai_edge_torch/generative/examples/moonshine/moonshine.py +103 -0
  59. ai_edge_torch/generative/examples/openelm/__init__.py +14 -0
  60. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +80 -0
  61. ai_edge_torch/generative/examples/openelm/openelm.py +127 -0
  62. ai_edge_torch/generative/examples/openelm/verify.py +71 -0
  63. ai_edge_torch/generative/examples/paligemma/__init__.py +14 -0
  64. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +95 -0
  65. ai_edge_torch/generative/examples/paligemma/decoder.py +151 -0
  66. ai_edge_torch/generative/examples/paligemma/decoder2.py +177 -0
  67. ai_edge_torch/generative/examples/paligemma/image_encoder.py +160 -0
  68. ai_edge_torch/generative/examples/paligemma/paligemma.py +179 -0
  69. ai_edge_torch/generative/examples/paligemma/verify.py +161 -0
  70. ai_edge_torch/generative/examples/paligemma/verify_decoder.py +75 -0
  71. ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
  72. ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +99 -0
  73. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  74. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +80 -0
  75. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +80 -0
  76. ai_edge_torch/generative/examples/phi/phi2.py +107 -0
  77. ai_edge_torch/generative/examples/phi/phi3.py +219 -0
  78. ai_edge_torch/generative/examples/phi/verify.py +64 -0
  79. ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
  80. ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
  81. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +93 -0
  82. ai_edge_torch/generative/examples/qwen/qwen.py +134 -0
  83. ai_edge_torch/generative/examples/qwen/verify.py +88 -0
  84. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  85. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +80 -0
  86. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  87. ai_edge_torch/generative/examples/smollm/smollm.py +125 -0
  88. ai_edge_torch/generative/examples/smollm/verify.py +86 -0
  89. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  90. ai_edge_torch/generative/examples/stable_diffusion/attention.py +108 -0
  91. ai_edge_torch/generative/examples/stable_diffusion/clip.py +185 -0
  92. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +173 -0
  93. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +398 -0
  94. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +749 -0
  95. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +119 -0
  96. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +254 -0
  97. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  98. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +62 -0
  99. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +66 -0
  100. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +74 -0
  101. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +39 -0
  102. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +111 -0
  103. ai_edge_torch/generative/examples/stable_diffusion/util.py +77 -0
  104. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  105. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +138 -0
  106. ai_edge_torch/generative/examples/t5/t5.py +655 -0
  107. ai_edge_torch/generative/examples/t5/t5_attention.py +246 -0
  108. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  109. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  110. ai_edge_torch/generative/examples/test_models/toy_model.py +156 -0
  111. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +138 -0
  112. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  113. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +80 -0
  114. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +88 -0
  115. ai_edge_torch/generative/examples/tiny_llama/verify.py +72 -0
  116. ai_edge_torch/generative/fx_passes/__init__.py +30 -0
  117. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +50 -0
  118. ai_edge_torch/generative/layers/__init__.py +14 -0
  119. ai_edge_torch/generative/layers/attention.py +399 -0
  120. ai_edge_torch/generative/layers/attention_utils.py +210 -0
  121. ai_edge_torch/generative/layers/builder.py +160 -0
  122. ai_edge_torch/generative/layers/feed_forward.py +120 -0
  123. ai_edge_torch/generative/layers/kv_cache.py +204 -0
  124. ai_edge_torch/generative/layers/lora.py +557 -0
  125. ai_edge_torch/generative/layers/model_config.py +238 -0
  126. ai_edge_torch/generative/layers/normalization.py +222 -0
  127. ai_edge_torch/generative/layers/rotary_position_embedding.py +94 -0
  128. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +144 -0
  129. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  130. ai_edge_torch/generative/layers/unet/blocks_2d.py +806 -0
  131. ai_edge_torch/generative/layers/unet/builder.py +50 -0
  132. ai_edge_torch/generative/layers/unet/model_config.py +282 -0
  133. ai_edge_torch/generative/quantize/__init__.py +14 -0
  134. ai_edge_torch/generative/quantize/example.py +47 -0
  135. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  136. ai_edge_torch/generative/quantize/quant_recipe.py +154 -0
  137. ai_edge_torch/generative/quantize/quant_recipe_utils.py +62 -0
  138. ai_edge_torch/generative/quantize/quant_recipes.py +56 -0
  139. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  140. ai_edge_torch/generative/test/__init__.py +14 -0
  141. ai_edge_torch/generative/test/test_custom_dus.py +107 -0
  142. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  143. ai_edge_torch/generative/test/test_loader.py +83 -0
  144. ai_edge_torch/generative/test/test_lora.py +147 -0
  145. ai_edge_torch/generative/test/test_model_conversion.py +191 -0
  146. ai_edge_torch/generative/test/test_model_conversion_large.py +362 -0
  147. ai_edge_torch/generative/test/test_quantize.py +183 -0
  148. ai_edge_torch/generative/test/utils.py +82 -0
  149. ai_edge_torch/generative/utilities/__init__.py +15 -0
  150. ai_edge_torch/generative/utilities/converter.py +215 -0
  151. ai_edge_torch/generative/utilities/dynamic_update_slice.py +56 -0
  152. ai_edge_torch/generative/utilities/loader.py +398 -0
  153. ai_edge_torch/generative/utilities/model_builder.py +180 -0
  154. ai_edge_torch/generative/utilities/moonshine_loader.py +154 -0
  155. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +1032 -0
  156. ai_edge_torch/generative/utilities/t5_loader.py +512 -0
  157. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  158. ai_edge_torch/generative/utilities/verifier.py +335 -0
  159. ai_edge_torch/hlfb/__init__.py +16 -0
  160. ai_edge_torch/hlfb/mark_pattern/__init__.py +153 -0
  161. ai_edge_torch/hlfb/mark_pattern/fx_utils.py +69 -0
  162. ai_edge_torch/hlfb/mark_pattern/pattern.py +288 -0
  163. ai_edge_torch/hlfb/test/__init__.py +14 -0
  164. ai_edge_torch/hlfb/test/test_mark_pattern.py +185 -0
  165. ai_edge_torch/lowertools/__init__.py +18 -0
  166. ai_edge_torch/lowertools/_shim.py +86 -0
  167. ai_edge_torch/lowertools/common_utils.py +142 -0
  168. ai_edge_torch/lowertools/odml_torch_utils.py +260 -0
  169. ai_edge_torch/lowertools/test_utils.py +62 -0
  170. ai_edge_torch/lowertools/torch_xla_utils.py +301 -0
  171. ai_edge_torch/lowertools/translate_recipe.py +163 -0
  172. ai_edge_torch/model.py +177 -0
  173. ai_edge_torch/odml_torch/__init__.py +20 -0
  174. ai_edge_torch/odml_torch/_torch_future.py +88 -0
  175. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  176. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  177. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  178. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  179. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  180. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  181. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  182. ai_edge_torch/odml_torch/export.py +403 -0
  183. ai_edge_torch/odml_torch/export_utils.py +157 -0
  184. ai_edge_torch/odml_torch/jax_bridge/__init__.py +18 -0
  185. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +180 -0
  186. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  187. ai_edge_torch/odml_torch/lowerings/__init__.py +27 -0
  188. ai_edge_torch/odml_torch/lowerings/_basic.py +294 -0
  189. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  190. ai_edge_torch/odml_torch/lowerings/_convolution.py +243 -0
  191. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +285 -0
  192. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +87 -0
  193. ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +177 -0
  194. ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
  195. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  196. ai_edge_torch/odml_torch/lowerings/decomp.py +69 -0
  197. ai_edge_torch/odml_torch/lowerings/registry.py +65 -0
  198. ai_edge_torch/odml_torch/lowerings/utils.py +201 -0
  199. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  200. ai_edge_torch/odml_torch/tf_integration.py +156 -0
  201. ai_edge_torch/quantize/__init__.py +16 -0
  202. ai_edge_torch/quantize/pt2e_quantizer.py +466 -0
  203. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1061 -0
  204. ai_edge_torch/quantize/quant_config.py +85 -0
  205. ai_edge_torch/testing/__init__.py +14 -0
  206. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  207. ai_edge_torch/testing/model_coverage/model_coverage.py +145 -0
  208. ai_edge_torch/version.py +16 -0
  209. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/LICENSE +202 -0
  210. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/METADATA +44 -0
  211. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/RECORD +213 -0
  212. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/WHEEL +5 -0
  213. ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/top_level.txt +1 -0
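
ai_edge_torch/quantize/pt2e_quantizer_utils.py (new file, 1061 lines added):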
@@ -0,0 +1,1061 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from dataclasses import dataclass
17
+ import itertools
18
+ import operator
19
+ from typing import Callable, Dict, List, NamedTuple, Optional
20
+
21
+ import torch
22
+ from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
23
+ from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
24
+ from torch.ao.quantization.pt2e.utils import _conv1d_bn_example_inputs
25
+ from torch.ao.quantization.pt2e.utils import _conv2d_bn_example_inputs
26
+ from torch.ao.quantization.pt2e.utils import _get_aten_graph_module_for_pattern
27
+ from torch.ao.quantization.quantizer import QuantizationAnnotation
28
+ from torch.ao.quantization.quantizer import QuantizationSpec
29
+ from torch.ao.quantization.quantizer import QuantizationSpecBase
30
+ from torch.ao.quantization.quantizer import SharedQuantizationSpec
31
+ from torch.ao.quantization.quantizer.utils import _annotate_input_qspec_map
32
+ from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
33
+ from torch.fx import Node
34
+ from torch.fx.passes.utils.matcher_with_name_node_map_utils import SubgraphMatcherWithNameNodeMap
35
+ from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
36
+ import torch.nn.functional as F
37
+
38
+ __all__ = [
39
+ "OperatorConfig",
40
+ "OperatorPatternType",
41
+ "QuantizationConfig",
42
+ "get_input_act_qspec",
43
+ "get_output_act_qspec",
44
+ "get_weight_qspec",
45
+ "get_bias_qspec",
46
+ "OP_TO_ANNOTATOR",
47
+ "propagate_annotation",
48
+ ]
49
+
50
+
51
+ @dataclass(eq=True, frozen=True)
52
+ class QuantizationConfig:
53
+ input_activation: Optional[QuantizationSpec]
54
+ output_activation: Optional[QuantizationSpec]
55
+ weight: Optional[QuantizationSpec]
56
+ bias: Optional[QuantizationSpec]
57
+ fixed_qparams: Optional[QuantizationSpec]
58
+ # TODO: remove, since we can use observer_or_fake_quant_ctr to express this
59
+ is_qat: bool = False
60
+ is_dynamic: bool = False
61
+
62
+
63
+ OperatorPatternType = List[Callable]
64
+ OperatorPatternType.__module__ = "ai_edge_torch.quantize.pt2e_quantizer_utils"
65
+
66
+ AnnotatorType = Callable[
67
+ [
68
+ torch.fx.GraphModule,
69
+ Optional[QuantizationConfig],
70
+ Optional[Callable[[Node], bool]],
71
+ ],
72
+ Optional[List[List[Node]]],
73
+ ]
74
+ OP_TO_ANNOTATOR: Dict[str, AnnotatorType] = {}
75
+
76
+
77
+ def register_annotator(op: str):
78
+ def decorator(annotator: AnnotatorType):
79
+ OP_TO_ANNOTATOR[op] = annotator
80
+
81
+ return decorator
82
+
83
+
84
+ class OperatorConfig(NamedTuple):
85
+ # fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]]
86
+ # Basically we are mapping a quantization config to some list of patterns.
87
+ # a pattern is defined as a list of nn module, function or builtin function names
88
+ # e.g. [nn.Conv2d, torch.relu, torch.add]
89
+ # We have not resolved whether fusion can be considered internal details of the
90
+ # quantizer hence it does not need communication to user.
91
+ # Note this pattern is not really informative since it does not really
92
+ # tell us the graph structure resulting from the list of ops.
93
+ config: QuantizationConfig
94
+ operators: List[OperatorPatternType]
95
+
96
+
97
+ def _is_annotated(nodes: List[Node]):
98
+ """Checks if a list of nodes is annotated.
99
+
100
+ Given a list of nodes (that represents an operator pattern), check whether
100
+ any of the nodes is annotated: return True if any node is annotated,
101
+ otherwise return False.
103
+ """
104
+ annotated = False
105
+ for node in nodes:
106
+ annotated = annotated or (
107
+ "quantization_annotation" in node.meta
108
+ and node.meta["quantization_annotation"]._annotated
109
+ )
110
+ return annotated
111
+
112
+
113
+ def _mark_nodes_as_annotated(nodes: List[Node]):
114
+ for node in nodes:
115
+ if node is not None:
116
+ if "quantization_annotation" not in node.meta:
117
+ node.meta["quantization_annotation"] = QuantizationAnnotation()
118
+ node.meta["quantization_annotation"]._annotated = True
119
+
120
+
121
+ def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]):
122
+ if quantization_config is None:
123
+ return None
124
+ if quantization_config.input_activation is None:
125
+ return None
126
+ quantization_spec: QuantizationSpec = quantization_config.input_activation
127
+ assert quantization_spec.qscheme in [
128
+ torch.per_tensor_affine,
129
+ torch.per_tensor_symmetric,
130
+ ]
131
+ return quantization_spec
132
+
133
+
134
+ def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]):
135
+ if quantization_config is None:
136
+ return None
137
+ if quantization_config.output_activation is None:
138
+ return None
139
+ quantization_spec: QuantizationSpec = quantization_config.output_activation
140
+ assert quantization_spec.qscheme in [
141
+ torch.per_tensor_affine,
142
+ torch.per_tensor_symmetric,
143
+ ]
144
+ return quantization_spec
145
+
146
+
147
+ def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
148
+ if quantization_config is None:
149
+ return None
150
+ assert quantization_config is not None
151
+ if quantization_config.weight is None:
152
+ return None
153
+ quantization_spec: QuantizationSpec = quantization_config.weight
154
+ if quantization_spec.qscheme not in [
155
+ torch.per_tensor_symmetric,
156
+ torch.per_channel_symmetric,
157
+ ]:
158
+ raise ValueError(
159
+ f"Unsupported quantization_spec {quantization_spec} for weight"
160
+ )
161
+ return quantization_spec
162
+
163
+
164
+ def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):
165
+ if quantization_config is None:
166
+ return None
167
+ assert quantization_config is not None
168
+ if quantization_config.bias is None:
169
+ return None
170
+ quantization_spec: QuantizationSpec = quantization_config.bias
171
+ assert (
172
+ quantization_spec.dtype == torch.float
173
+ ), "Only float dtype is supported for bias right now"
174
+ return quantization_spec
175
+
176
+
177
+ def get_fixed_qparams_qspec(quantization_config: Optional[QuantizationConfig]):
178
+ if quantization_config is None:
179
+ return None
180
+ assert quantization_config is not None
181
+ if quantization_config.fixed_qparams is None:
182
+ return None
183
+ quantization_spec: QuantizationSpec = quantization_config.fixed_qparams
184
+ return quantization_spec
185
+
186
+
187
+ @register_annotator("linear")
188
+ def _annotate_linear(
189
+ gm: torch.fx.GraphModule,
190
+ quantization_config: Optional[QuantizationConfig],
191
+ filter_fn: Optional[Callable[[Node], bool]] = None,
192
+ ) -> Optional[List[List[Node]]]:
193
+ annotated_partitions = []
194
+ input_act_qspec = get_input_act_qspec(quantization_config)
195
+ output_act_qspec = get_output_act_qspec(quantization_config)
196
+ weight_qspec = get_weight_qspec(quantization_config)
197
+ bias_qspec = get_bias_qspec(quantization_config)
198
+ for node in gm.graph.nodes:
199
+ if (
200
+ node.op != "call_function"
201
+ or node.target != torch.ops.aten.linear.default
202
+ ):
203
+ continue
204
+ if filter_fn and not filter_fn(node):
205
+ continue
206
+ act_node = node.args[0]
207
+ weight_node = node.args[1]
208
+ bias_node = None
209
+ if len(node.args) > 2:
210
+ bias_node = node.args[2]
211
+
212
+ if _is_annotated([node]) is False: # type: ignore[list-item]
213
+ _annotate_input_qspec_map(
214
+ node,
215
+ act_node,
216
+ input_act_qspec,
217
+ )
218
+ _annotate_input_qspec_map(
219
+ node,
220
+ weight_node,
221
+ weight_qspec,
222
+ )
223
+ nodes_to_mark_annotated = [node, weight_node]
224
+ if bias_node:
225
+ _annotate_input_qspec_map(
226
+ node,
227
+ bias_node,
228
+ bias_qspec,
229
+ )
230
+ nodes_to_mark_annotated.append(bias_node)
231
+ _annotate_output_qspec(node, output_act_qspec)
232
+ _mark_nodes_as_annotated(nodes_to_mark_annotated)
233
+ annotated_partitions.append(nodes_to_mark_annotated)
234
+
235
+ return annotated_partitions
236
+
237
+
238
+ @register_annotator("addmm")
239
+ def _annotate_addmm(
240
+ gm: torch.fx.GraphModule,
241
+ quantization_config: Optional[QuantizationConfig],
242
+ filter_fn: Optional[Callable[[Node], bool]] = None,
243
+ ) -> Optional[List[List[Node]]]:
244
+ annotated_partitions = []
245
+ for n in gm.graph.nodes:
246
+ if n.op != "call_function" or n.target not in [
247
+ torch.ops.aten.addmm.default,
248
+ ]:
249
+ continue
250
+ addm_node = n
251
+
252
+ input_qspec_map = {}
253
+ input_act = addm_node.args[0]
254
+ assert isinstance(input_act, Node)
255
+ is_bias = (
256
+ len(list(input_act.meta["val"].size())) < 2
257
+ and input_act.op == "get_attr"
258
+ and "_param_constant" in input_act.target
259
+ )
260
+ input_qspec_map[input_act] = (
261
+ get_bias_qspec(quantization_config)
262
+ if is_bias
263
+ else get_input_act_qspec(quantization_config)
264
+ )
265
+
266
+ mat1_act = addm_node.args[1]
267
+ assert isinstance(mat1_act, Node)
268
+ input_qspec_map[mat1_act] = get_input_act_qspec(quantization_config)
269
+
270
+ mat2_act = addm_node.args[2]
271
+ assert isinstance(mat2_act, Node)
272
+ is_weight = False
273
+ if mat2_act.op == "get_attr" and "_param_constant" in mat2_act.target:
274
+ is_weight = True
275
+ elif mat2_act.target == torch.ops.aten.t.default:
276
+ t_in = mat2_act.args[0]
277
+ if t_in.op == "get_attr" and "_param_constant" in t_in.target:
278
+ is_weight = True
279
+ input_qspec_map[mat2_act] = (
280
+ get_weight_qspec(quantization_config)
281
+ if is_weight
282
+ else get_input_act_qspec(quantization_config)
283
+ )
284
+
285
+ partition = [addm_node, addm_node.args[1], addm_node.args[2]]
286
+
287
+ if _is_annotated(partition):
288
+ continue
289
+
290
+ if filter_fn and any(not filter_fn(n) for n in partition):
291
+ continue
292
+
293
+ addm_node.meta["quantization_annotation"] = QuantizationAnnotation(
294
+ input_qspec_map=input_qspec_map,
295
+ output_qspec=get_output_act_qspec(quantization_config),
296
+ _annotated=True,
297
+ )
298
+ _mark_nodes_as_annotated(partition)
299
+ annotated_partitions.append(partition)
300
+ return annotated_partitions
301
+
302
+
303
+ @register_annotator("conv")
304
+ def _annotate_conv(
305
+ gm: torch.fx.GraphModule,
306
+ quantization_config: Optional[QuantizationConfig],
307
+ filter_fn: Optional[Callable[[Node], bool]] = None,
308
+ ) -> Optional[List[List[Node]]]:
309
+ annotated_partitions = []
310
+ for n in gm.graph.nodes:
311
+ if n.op != "call_function" or n.target not in [
312
+ torch.ops.aten.conv1d.default,
313
+ torch.ops.aten.conv2d.default,
314
+ torch.ops.aten.convolution.default,
315
+ ]:
316
+ continue
317
+ conv_node = n
318
+
319
+ input_qspec_map = {}
320
+ input_act = conv_node.args[0]
321
+ assert isinstance(input_act, Node)
322
+ input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
323
+
324
+ weight = conv_node.args[1]
325
+ assert isinstance(weight, Node)
326
+ input_qspec_map[weight] = get_weight_qspec(quantization_config)
327
+
328
+ # adding weight node to the partition as well
329
+ partition = [conv_node, conv_node.args[1]]
330
+
331
+ bias = conv_node.args[2] if len(conv_node.args) > 2 else None
332
+ if isinstance(bias, Node):
333
+ input_qspec_map[bias] = get_bias_qspec(quantization_config)
334
+ partition.append(bias)
335
+
336
+ if _is_annotated(partition):
337
+ continue
338
+
339
+ if filter_fn and any(not filter_fn(n) for n in partition):
340
+ continue
341
+
342
+ conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
343
+ input_qspec_map=input_qspec_map,
344
+ output_qspec=get_output_act_qspec(quantization_config),
345
+ _annotated=True,
346
+ )
347
+ _mark_nodes_as_annotated(partition)
348
+ annotated_partitions.append(partition)
349
+ return annotated_partitions
350
+
351
+
352
+ @register_annotator("conv_relu")
353
+ def _annotate_conv_relu(
354
+ gm: torch.fx.GraphModule,
355
+ quantization_config: Optional[QuantizationConfig],
356
+ filter_fn: Optional[Callable[[Node], bool]] = None,
357
+ ) -> Optional[List[List[Node]]]:
358
+ annotated_partitions = []
359
+ for n in gm.graph.nodes:
360
+ if n.op != "call_function" or n.target not in [
361
+ torch.ops.aten.relu.default,
362
+ torch.ops.aten.relu_.default,
363
+ ]:
364
+ continue
365
+ relu_node = n
366
+ maybe_conv_node = n.args[0]
367
+ if (
368
+ not isinstance(maybe_conv_node, Node)
369
+ or maybe_conv_node.op != "call_function"
370
+ or maybe_conv_node.target
371
+ not in [
372
+ torch.ops.aten.conv1d.default,
373
+ torch.ops.aten.conv2d.default,
374
+ torch.ops.aten.convolution.default,
375
+ ]
376
+ ):
377
+ continue
378
+ conv_node = maybe_conv_node
379
+
380
+ input_qspec_map = {}
381
+ input_act = conv_node.args[0]
382
+ assert isinstance(input_act, Node)
383
+ input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
384
+
385
+ weight = conv_node.args[1]
386
+ assert isinstance(weight, Node)
387
+ input_qspec_map[weight] = get_weight_qspec(quantization_config)
388
+
389
+ # adding weight node to the partition as well
390
+ partition = [relu_node, conv_node, conv_node.args[1]]
391
+ bias = conv_node.args[2] if len(conv_node.args) > 2 else None
392
+ if isinstance(bias, Node):
393
+ input_qspec_map[bias] = get_bias_qspec(quantization_config)
394
+ partition.append(bias)
395
+
396
+ if _is_annotated(partition):
397
+ continue
398
+
399
+ if filter_fn and any(not filter_fn(n) for n in partition):
400
+ continue
401
+
402
+ conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
403
+ input_qspec_map=input_qspec_map, _annotated=True
404
+ )
405
+ relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
406
+ output_qspec=get_output_act_qspec(
407
+ quantization_config
408
+ ), # type: ignore[arg-type]
409
+ _annotated=True,
410
+ )
411
+ _mark_nodes_as_annotated(partition)
412
+ annotated_partitions.append(partition)
413
+ return annotated_partitions
414
+
415
+
416
+ @register_annotator("conv_bn")
417
+ def _annotate_conv_bn(
418
+ gm: torch.fx.GraphModule,
419
+ quantization_config: Optional[QuantizationConfig],
420
+ filter_fn: Optional[Callable[[Node], bool]] = None,
421
+ ) -> Optional[List[List[Node]]]:
422
+ """Find conv + batchnorm partitions. Note: This is only used for QAT.
423
+
424
+ In PTQ, batchnorm should already be fused into the conv.
425
+ """
426
+ return _do_annotate_conv_bn(
427
+ gm, quantization_config, filter_fn, has_relu=False
428
+ )
429
+
430
+
431
+ @register_annotator("conv_bn_relu")
432
+ def _annotate_conv_bn_relu(
433
+ gm: torch.fx.GraphModule,
434
+ quantization_config: Optional[QuantizationConfig],
435
+ filter_fn: Optional[Callable[[Node], bool]] = None,
436
+ ) -> Optional[List[List[Node]]]:
437
+ """Find conv + batchnorm + relu partitions. Note: This is only used for QAT.
438
+
439
+ In PTQ, batchnorm should already be fused into the conv.
440
+ """
441
+ return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True)
442
+
443
+
444
+ def _do_annotate_conv_bn(
445
+ gm: torch.fx.GraphModule,
446
+ quantization_config: Optional[QuantizationConfig],
447
+ filter_fn: Optional[Callable[[Node], bool]],
448
+ has_relu: bool,
449
+ ) -> List[List[Node]]:
450
+ """Finds conv-bn[-relu] partitions using a pattern built from `conv_fn`, and
451
+
452
+ returns a list of annotated partitions.
453
+
454
+ The output of the pattern must include a dictionary from string name to node
455
+ for the following names: "input", "conv", "weight", "bias", and "output".
456
+ """
457
+
458
+ def get_pattern(conv_fn: Callable, relu_is_inplace: bool):
459
+ def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
460
+ conv = conv_fn(x, conv_weight, conv_bias)
461
+ bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True)
462
+ if has_relu:
463
+ output = F.relu_(bn) if relu_is_inplace else F.relu(bn)
464
+ else:
465
+ output = bn
466
+ return output, {
467
+ "input": x,
468
+ "conv": conv,
469
+ "weight": conv_weight,
470
+ "bias": conv_bias,
471
+ "output": output,
472
+ }
473
+
474
+ return _conv_bn
475
+
476
+ # Needed for matching, otherwise the matches gets filtered out due to unused
477
+ # nodes returned by batch norm
478
+ gm.graph.eliminate_dead_code()
479
+ gm.recompile()
480
+
481
+ matches = []
482
+ combinations = [
483
+ (F.conv1d, _conv1d_bn_example_inputs),
484
+ (F.conv2d, _conv2d_bn_example_inputs),
485
+ ]
486
+
487
+ # Add `is_cuda` and `relu_is_inplace` dimensions
488
+ combinations = itertools.product(
489
+ combinations,
490
+ [True, False] if torch.cuda.is_available() else [False], # is_cuda
491
+ [True, False] if has_relu else [False], # relu_is_inplace
492
+ )
493
+
494
+ # Match against all conv dimensions and cuda variants
495
+ for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:
496
+ pattern = get_pattern(conv_fn, relu_is_inplace)
497
+ pattern = _get_aten_graph_module_for_pattern(
498
+ pattern, example_inputs, is_cuda
499
+ )
500
+ pattern.graph.eliminate_dead_code()
501
+ pattern.recompile()
502
+ matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
503
+ matches.extend(matcher.match(gm.graph))
504
+
505
+ # Annotate nodes returned in the matches
506
+ annotated_partitions = []
507
+ for match in matches:
508
+ name_node_map = match.name_node_map
509
+ input_node = name_node_map["input"]
510
+ conv_node = name_node_map["conv"]
511
+ weight_node = name_node_map["weight"]
512
+ bias_node = name_node_map["bias"]
513
+ output_node = name_node_map["output"]
514
+
515
+ # TODO: annotate the uses of input, weight, and bias separately instead
516
+ # of assuming they come from a single conv node. This is not possible today
517
+ # because input may have multiple users, and we can't rely on the conv node
518
+ # always being the first user. This was the case in models with skip
519
+ # connections like resnet18
520
+
521
+ # Validate conv args
522
+ if conv_node.args[0] is not input_node:
523
+ raise ValueError("Conv arg did not contain input node ", input_node)
524
+ if conv_node.args[1] is not weight_node:
525
+ raise ValueError("Conv arg did not contain weight node ", weight_node)
526
+ if len(conv_node.args) > 2 and conv_node.args[2] is not bias_node:
527
+ raise ValueError("Conv arg did not contain bias node ", bias_node)
528
+
529
+ # Skip if the partition is already annotated or is filtered out by the user
530
+ partition = [conv_node, weight_node]
531
+ if bias_node is not None:
532
+ partition.append(bias_node)
533
+ if _is_annotated(partition):
534
+ continue
535
+ if filter_fn and any(not filter_fn(n) for n in partition):
536
+ continue
537
+
538
+ # Annotate conv inputs and pattern output
539
+ input_qspec_map = {}
540
+ input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
541
+ input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
542
+ if bias_node is not None:
543
+ input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
544
+ conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
545
+ input_qspec_map=input_qspec_map,
546
+ _annotated=True,
547
+ )
548
+ output_node.meta["quantization_annotation"] = QuantizationAnnotation(
549
+ output_qspec=get_output_act_qspec(
550
+ quantization_config
551
+ ), # type: ignore[arg-type]
552
+ _annotated=True,
553
+ )
554
+ _mark_nodes_as_annotated(partition)
555
+ annotated_partitions.append(partition)
556
+ return annotated_partitions
557
+
558
+
559
+ @register_annotator("gru_io_only")
560
+ def _annotate_gru_io_only(
561
+ gm: torch.fx.GraphModule,
562
+ quantization_config: Optional[QuantizationConfig],
563
+ filter_fn: Optional[Callable[[Node], bool]] = None,
564
+ ) -> Optional[List[List[Node]]]:
565
+ gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU], filter_fn)
566
+ gru_partitions = list(itertools.chain(*gru_partitions.values()))
567
+ annotated_partitions = []
568
+ for gru_partition in gru_partitions:
569
+ annotated_partitions.append(gru_partition.nodes)
570
+ output_nodes = gru_partition.output_nodes
571
+ input_nodes = gru_partition.input_nodes
572
+ # skip annotation if it is already annotated
573
+ if _is_annotated(input_nodes + output_nodes):
574
+ continue
575
+ # inside each GRU partition, we should be able to annotate each linear
576
+ # subgraph
577
+ input_qspec_map: Dict[Node, QuantizationSpecBase] = {}
578
+ input_act = input_nodes[0]
579
+ input_act_user = next(iter(input_act.users.keys()))
580
+ assert isinstance(input_act, Node)
581
+ assert isinstance(input_act_user, Node)
582
+ input_act_user.meta["quantization_annotation"] = QuantizationAnnotation(
583
+ input_qspec_map={
584
+ input_act: get_input_act_qspec(quantization_config),
585
+ },
586
+ _annotated=True,
587
+ )
588
+
589
+ hidden_state = input_nodes[1]
590
+ hidden_state_user = next(iter(hidden_state.users.keys()))
591
+ assert isinstance(hidden_state, Node)
592
+ assert isinstance(hidden_state_user, Node)
593
+ hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation(
594
+ input_qspec_map={
595
+ hidden_state: get_input_act_qspec(quantization_config),
596
+ },
597
+ _annotated=True,
598
+ )
599
+
600
+ assert len(output_nodes) == 2, "expecting GRU to have two outputs"
601
+ for output in output_nodes:
602
+ output.meta["quantization_annotation"] = QuantizationAnnotation(
603
+ output_qspec=get_output_act_qspec(quantization_config),
604
+ _annotated=True,
605
+ )
606
+ nodes_to_mark_annotated = list(gru_partition.nodes)
607
+ _mark_nodes_as_annotated(nodes_to_mark_annotated)
608
+ return annotated_partitions
609
+
610
+
611
+ @register_annotator("max_pool2d")
612
+ def _annotate_max_pool2d(
613
+ gm: torch.fx.GraphModule,
614
+ quantization_config: Optional[QuantizationConfig],
615
+ filter_fn: Optional[Callable[[Node], bool]] = None,
616
+ ) -> Optional[List[List[Node]]]:
617
+ module_partitions = get_source_partitions(
618
+ gm.graph, [torch.nn.MaxPool2d, torch.nn.functional.max_pool2d], filter_fn
619
+ )
620
+ maxpool_partitions = list(itertools.chain(*module_partitions.values()))
621
+ annotated_partitions = []
622
+ for maxpool_partition in maxpool_partitions:
623
+ annotated_partitions.append(maxpool_partition.nodes)
624
+ output_node = maxpool_partition.output_nodes[0]
625
+ maxpool_node = None
626
+ for n in maxpool_partition.nodes:
627
+ if (
628
+ n.target == torch.ops.aten.max_pool2d.default
629
+ or n.target == torch.ops.aten.max_pool2d_with_indices.default
630
+ ):
631
+ maxpool_node = n
632
+
633
+ assert maxpool_node is not None, (
634
+ "PT2EQuantizer only works with torch.ops.aten.max_pool2d.default, "
635
+ "please make sure you are exporting the model correctly"
636
+ )
637
+ if _is_annotated([output_node, maxpool_node]): # type: ignore[list-item]
638
+ continue
639
+
640
+ input_act = maxpool_node.args[0] # type: ignore[union-attr]
641
+ assert isinstance(input_act, Node)
642
+
643
+ # only annotate maxpool when the output of the input node is annotated
644
+ if (
645
+ "quantization_annotation" not in input_act.meta
646
+ or not input_act.meta["quantization_annotation"]._annotated
647
+ or input_act.meta["quantization_annotation"].output_qspec is None
648
+ ):
649
+ continue
650
+ # input and output of maxpool will share quantization parameter with input of maxpool
651
+ act_qspec = SharedQuantizationSpec(input_act)
652
+ # act_qspec = get_act_qspec(quantization_config)
653
+ maxpool_node.meta[
654
+ # type: ignore[union-attr]
655
+ "quantization_annotation"
656
+ ] = QuantizationAnnotation(
657
+ input_qspec_map={
658
+ input_act: act_qspec,
659
+ },
660
+ _annotated=True,
661
+ )
662
+ output_node.meta["quantization_annotation"] = QuantizationAnnotation(
663
+ output_qspec=act_qspec,
664
+ _annotated=True,
665
+ )
666
+ return annotated_partitions
667
+
668
+
669
+ @register_annotator("adaptive_avg_pool2d")
670
+ def _annotate_adaptive_avg_pool2d(
671
+ gm: torch.fx.GraphModule,
672
+ quantization_config: Optional[QuantizationConfig],
673
+ filter_fn: Optional[Callable[[Node], bool]] = None,
674
+ ) -> Optional[List[List[Node]]]:
675
+ """Always annotate adaptive_avg_pool2d op"""
676
+ module_partitions = get_source_partitions(
677
+ gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn
678
+ )
679
+ partitions = list(itertools.chain(*module_partitions.values()))
680
+ annotated_partitions = []
681
+ for partition in partitions:
682
+ pool_node = partition.output_nodes[0]
683
+ if pool_node.op != "call_function" or (
684
+ pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default
685
+ and pool_node.target != torch.ops.aten._adaptive_avg_pool2d.default
686
+ and pool_node.target != torch.ops.aten.mean.dim
687
+ and pool_node.target != torch.ops.aten.as_strided_.default
688
+ ):
689
+ raise ValueError(
690
+ f"{pool_node} is not an aten adaptive_avg_pool2d operator"
691
+ )
692
+
693
+ if _is_annotated([pool_node]):
694
+ continue
695
+
696
+ annotated_partitions.append(partition.nodes)
697
+ input_act = pool_node.args[0]
698
+ assert isinstance(input_act, Node)
699
+
700
+ # only annotate input output sharing operator
701
+ # when the output of the input node is annotated
702
+ if (
703
+ "quantization_annotation" not in input_act.meta
704
+ or not input_act.meta["quantization_annotation"]._annotated
705
+ or input_act.meta["quantization_annotation"].output_qspec is None
706
+ ):
707
+ input_act_qspec = get_input_act_qspec(quantization_config)
708
+ else:
709
+ input_act_qspec = SharedQuantizationSpec(input_act)
710
+
711
+ # output sharing with input
712
+ output_act_qspec = SharedQuantizationSpec((input_act, pool_node))
713
+ pool_node.meta["quantization_annotation"] = QuantizationAnnotation(
714
+ input_qspec_map={
715
+ input_act: input_act_qspec,
716
+ },
717
+ output_qspec=output_act_qspec,
718
+ _annotated=True,
719
+ )
720
+ return annotated_partitions
721
+
722
+
723
+ @register_annotator("fixed_qparams")
724
+ def _annotate_fixed_qparams(
725
+ gm: torch.fx.GraphModule,
726
+ quantization_config: Optional[QuantizationConfig],
727
+ filter_fn: Optional[Callable[[Node], bool]] = None,
728
+ ) -> Optional[List[List[Node]]]:
729
+ annotated_partitions = []
730
+ for node in gm.graph.nodes:
731
+ if node.op != "call_function" or (
732
+ node.target != torch.ops.aten.sigmoid.default
733
+ and node.target != torch.ops.aten._softmax.default
734
+ ):
735
+ continue
736
+
737
+ input_act = node.args[0] # type: ignore[union-attr]
738
+ assert isinstance(input_act, Node)
739
+
740
+ # only annotate when the output of the input node is annotated
741
+ if (
742
+ "quantization_annotation" not in input_act.meta
743
+ or not input_act.meta["quantization_annotation"]._annotated
744
+ or input_act.meta["quantization_annotation"].output_qspec is None
745
+ ):
746
+ continue
747
+ partition = [node]
748
+
749
+ if _is_annotated(partition):
750
+ continue
751
+
752
+ if filter_fn and any(not filter_fn(n) for n in partition):
753
+ continue
754
+
755
+ node.meta["quantization_annotation"] = QuantizationAnnotation(
756
+ output_qspec=get_fixed_qparams_qspec(quantization_config),
757
+ _annotated=True,
758
+ )
759
+ _mark_nodes_as_annotated(partition)
760
+ annotated_partitions.append(partition)
761
+
762
+ return annotated_partitions
763
+
764
+
765
+ @register_annotator("add_relu")
766
+ def _annotate_add_relu(
767
+ gm: torch.fx.GraphModule,
768
+ quantization_config: Optional[QuantizationConfig],
769
+ filter_fn: Optional[Callable[[Node], bool]] = None,
770
+ ) -> Optional[List[List[Node]]]:
771
+ fused_partitions = find_sequential_partitions(
772
+ gm, [torch.add, torch.nn.ReLU], filter_fn
773
+ )
774
+ annotated_partitions = []
775
+ for fused_partition in fused_partitions:
776
+ add_partition, relu_partition = fused_partition
777
+ annotated_partitions.append(add_partition.nodes + relu_partition.nodes)
778
+ if len(relu_partition.output_nodes) > 1:
779
+ raise ValueError("Relu partition has more than one output node")
780
+ relu_node = relu_partition.output_nodes[0]
781
+ if len(add_partition.output_nodes) > 1:
782
+ raise ValueError("add partition has more than one output node")
783
+ add_node = add_partition.output_nodes[0]
784
+
785
+ if _is_annotated([relu_node, add_node]):
786
+ continue
787
+
788
+ input_act_qspec = get_input_act_qspec(quantization_config)
789
+ output_act_qspec = get_output_act_qspec(quantization_config)
790
+
791
+ input_qspec_map = {}
792
+ input_act0 = add_node.args[0]
793
+ if isinstance(input_act0, Node):
794
+ input_qspec_map[input_act0] = input_act_qspec
795
+
796
+ input_act1 = add_node.args[1]
797
+ if isinstance(input_act1, Node):
798
+ input_qspec_map[input_act1] = input_act_qspec
799
+
800
+ add_node.meta["quantization_annotation"] = QuantizationAnnotation(
801
+ input_qspec_map=input_qspec_map,
802
+ _annotated=True,
803
+ )
804
+ relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
805
+ output_qspec=output_act_qspec,
806
+ _annotated=True,
807
+ )
808
+ return annotated_partitions
809
+
810
+
811
+ @register_annotator("add")
812
+ def _annotate_add(
813
+ gm: torch.fx.GraphModule,
814
+ quantization_config: Optional[QuantizationConfig],
815
+ filter_fn: Optional[Callable[[Node], bool]] = None,
816
+ ) -> Optional[List[List[Node]]]:
817
+ add_partitions = get_source_partitions(
818
+ gm.graph, [operator.add, torch.add, operator.iadd], filter_fn
819
+ )
820
+ add_partitions = list(itertools.chain(*add_partitions.values()))
821
+ annotated_partitions = []
822
+ for add_partition in add_partitions:
823
+ annotated_partitions.append(add_partition.nodes)
824
+ add_node = add_partition.output_nodes[0]
825
+ if _is_annotated([add_node]):
826
+ continue
827
+
828
+ input_act_qspec = get_input_act_qspec(quantization_config)
829
+ output_act_qspec = get_output_act_qspec(quantization_config)
830
+
831
+ input_qspec_map = {}
832
+ input_act0 = add_node.args[0]
833
+ if isinstance(input_act0, Node):
834
+ input_qspec_map[input_act0] = input_act_qspec
835
+
836
+ input_act1 = add_node.args[1]
837
+ if isinstance(input_act1, Node):
838
+ input_qspec_map[input_act1] = input_act_qspec
839
+
840
+ add_node.meta["quantization_annotation"] = QuantizationAnnotation(
841
+ input_qspec_map=input_qspec_map,
842
+ output_qspec=output_act_qspec,
843
+ _annotated=True,
844
+ )
845
+ return annotated_partitions
846
+
847
+
848
+ @register_annotator("mul_relu")
849
+ def _annotate_mul_relu(
850
+ gm: torch.fx.GraphModule,
851
+ quantization_config: Optional[QuantizationConfig],
852
+ filter_fn: Optional[Callable[[Node], bool]] = None,
853
+ ) -> Optional[List[List[Node]]]:
854
+ fused_partitions = find_sequential_partitions(
855
+ gm, [torch.mul, torch.nn.ReLU], filter_fn
856
+ )
857
+ annotated_partitions = []
858
+ for fused_partition in fused_partitions:
859
+ mul_partition, relu_partition = fused_partition
860
+ annotated_partitions.append(mul_partition.nodes + relu_partition.nodes)
861
+ if len(relu_partition.output_nodes) > 1:
862
+ raise ValueError("Relu partition has more than one output node")
863
+ relu_node = relu_partition.output_nodes[0]
864
+ if len(mul_partition.output_nodes) > 1:
865
+ raise ValueError("mul partition has more than one output node")
866
+ mul_node = mul_partition.output_nodes[0]
867
+
868
+ if _is_annotated([relu_node, mul_node]):
869
+ continue
870
+
871
+ input_act_qspec = get_input_act_qspec(quantization_config)
872
+ output_act_qspec = get_output_act_qspec(quantization_config)
873
+
874
+ input_qspec_map = {}
875
+ input_act0 = mul_node.args[0]
876
+ if isinstance(input_act0, Node):
877
+ input_qspec_map[input_act0] = input_act_qspec
878
+
879
+ input_act1 = mul_node.args[1]
880
+ if isinstance(input_act1, Node):
881
+ input_qspec_map[input_act1] = input_act_qspec
882
+
883
+ mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
884
+ input_qspec_map=input_qspec_map,
885
+ _annotated=True,
886
+ )
887
+ relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
888
+ output_qspec=output_act_qspec,
889
+ _annotated=True,
890
+ )
891
+ return annotated_partitions
892
+
893
+
894
+ @register_annotator("mul")
895
+ def _annotate_mul(
896
+ gm: torch.fx.GraphModule,
897
+ quantization_config: Optional[QuantizationConfig],
898
+ filter_fn: Optional[Callable[[Node], bool]] = None,
899
+ ) -> Optional[List[List[Node]]]:
900
+ mul_partitions = get_source_partitions(
901
+ gm.graph,
902
+ ["mul", "mul_", operator.mul, torch.mul, operator.imul],
903
+ filter_fn,
904
+ )
905
+ mul_partitions = list(itertools.chain(*mul_partitions.values()))
906
+ annotated_partitions = []
907
+ for mul_partition in mul_partitions:
908
+ annotated_partitions.append(mul_partition.nodes)
909
+ mul_node = mul_partition.output_nodes[0]
910
+ if _is_annotated([mul_node]):
911
+ continue
912
+
913
+ input_act_qspec = get_input_act_qspec(quantization_config)
914
+ output_act_qspec = get_output_act_qspec(quantization_config)
915
+
916
+ input_qspec_map = {}
917
+ input_act0 = mul_node.args[0]
918
+ if isinstance(input_act0, Node):
919
+ input_qspec_map[input_act0] = input_act_qspec
920
+
921
+ input_act1 = mul_node.args[1]
922
+ if isinstance(input_act1, Node):
923
+ input_qspec_map[input_act1] = input_act_qspec
924
+
925
+ mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
926
+ input_qspec_map=input_qspec_map,
927
+ output_qspec=output_act_qspec,
928
+ _annotated=True,
929
+ )
930
+ return annotated_partitions
931
+
932
+
933
+ # TODO: remove Optional in return type, fix annotated_partitions logic
934
+ @register_annotator("cat")
935
+ def _annotate_cat(
936
+ gm: torch.fx.GraphModule,
937
+ quantization_config: Optional[QuantizationConfig],
938
+ filter_fn: Optional[Callable[[Node], bool]] = None,
939
+ ) -> Optional[List[List[Node]]]:
940
+ cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn)
941
+ cat_partitions = list(itertools.chain(*cat_partitions.values()))
942
+ annotated_partitions = []
943
+ for cat_partition in cat_partitions:
944
+ cat_node = cat_partition.output_nodes[0]
945
+ if _is_annotated([cat_node]):
946
+ continue
947
+
948
+ if cat_node.target != torch.ops.aten.cat.default:
949
+ raise Exception(
950
+ "Expected cat node: torch.ops.aten.cat.default, but found"
951
+ f" {cat_node.target} please check if you are calling the correct"
952
+ " capture API"
953
+ )
954
+
955
+ annotated_partitions.append(cat_partition.nodes)
956
+
957
+ input_act_qspec = get_input_act_qspec(quantization_config)
958
+ inputs = cat_node.args[0]
959
+
960
+ input_qspec_map = {}
961
+ input_act0 = inputs[0]
962
+ if isinstance(input_act0, Node):
963
+ input_qspec_map[input_act0] = input_act_qspec
964
+
965
+ shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node))
966
+ for input_act in inputs[1:]:
967
+ input_qspec_map[input_act] = shared_with_input0_qspec
968
+
969
+ output_act_qspec = shared_with_input0_qspec
970
+
971
+ cat_node.meta["quantization_annotation"] = QuantizationAnnotation(
972
+ input_qspec_map=input_qspec_map,
973
+ output_qspec=output_act_qspec,
974
+ _annotated=True,
975
+ )
976
+ return annotated_partitions
977
+
978
+
979
+ def _is_share_obs_or_fq_op(op: Callable) -> bool:
980
+ return op in [
981
+ torch.ops.aten.hardtanh.default,
982
+ torch.ops.aten.hardtanh_.default,
983
+ torch.ops.aten.mean.default,
984
+ torch.ops.aten.mean.dim,
985
+ torch.ops.aten.permute.default,
986
+ torch.ops.aten.permute_copy.default,
987
+ torch.ops.aten.squeeze.dim,
988
+ torch.ops.aten.squeeze_copy.dim,
989
+ torch.ops.aten.adaptive_avg_pool2d.default,
990
+ torch.ops.aten.view_copy.default,
991
+ torch.ops.aten.view.default,
992
+ torch.ops.aten.slice_copy.Tensor,
993
+ torch.ops.aten.flatten.using_ints,
994
+ ]
995
+
996
+
997
+ def propagate_annotation(model: torch.fx.GraphModule) -> None:
998
+ for n in model.graph.nodes:
999
+ if n.op != "call_function" or not _is_share_obs_or_fq_op(n.target):
1000
+ continue
1001
+
1002
+ prev_node = n.args[0]
1003
+ if not isinstance(prev_node, Node):
1004
+ continue
1005
+
1006
+ quantization_annotation = prev_node.meta.get(
1007
+ "quantization_annotation", None
1008
+ )
1009
+ if not quantization_annotation:
1010
+ continue
1011
+
1012
+ output_qspec = quantization_annotation.output_qspec
1013
+ if not output_qspec:
1014
+ continue
1015
+
1016
+ # make sure current node is not annotated
1017
+ if (
1018
+ "quantization_annotation" in n.meta
1019
+ and n.meta["quantization_annotation"]._annotated
1020
+ ):
1021
+ continue
1022
+
1023
+ shared_qspec = SharedQuantizationSpec(prev_node)
1024
+ # propagate the previous output_qspec to the current node
1025
+ n.meta["quantization_annotation"] = QuantizationAnnotation(
1026
+ input_qspec_map={
1027
+ prev_node: shared_qspec,
1028
+ },
1029
+ output_qspec=shared_qspec,
1030
+ _annotated=True,
1031
+ )
1032
+
1033
+
1034
+ # TODO: make the list of ops customizable
1035
+ def _convert_scalars_to_attrs(
1036
+ model: torch.fx.GraphModule,
1037
+ ) -> torch.fx.GraphModule:
1038
+ for n in model.graph.nodes:
1039
+ if n.op != "call_function" or n.target not in [
1040
+ torch.ops.aten.add.Tensor,
1041
+ torch.ops.aten.mul.Tensor,
1042
+ ]:
1043
+ continue
1044
+ args = list(n.args)
1045
+ new_args = []
1046
+ for i in range(len(args)):
1047
+ if isinstance(args[i], torch.fx.Node):
1048
+ new_args.append(args[i])
1049
+ continue
1050
+ prefix = "_tensor_constant_"
1051
+ get_new_attr_name = get_new_attr_name_with_prefix(prefix)
1052
+ tensor_constant_name = get_new_attr_name(model)
1053
+ model.register_buffer(tensor_constant_name, torch.tensor(args[i]))
1054
+ with model.graph.inserting_before(n):
1055
+ get_attr_node = model.graph.create_node(
1056
+ "get_attr", tensor_constant_name, (), {}
1057
+ )
1058
+ new_args.append(get_attr_node)
1059
+ n.args = tuple(new_args)
1060
+ model.recompile()
1061
+ return model
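
For orientation, the sketch below shows one way the helpers exported in `__all__` above (QuantizationConfig, OP_TO_ANNOTATOR, propagate_annotation) could be wired together. It is not part of the wheel; the model, the observer choices, and the capture call are illustrative assumptions, and the appropriate PT2E capture API (torch.export.export_for_training or an older equivalent) depends on the installed torch version.

# Minimal usage sketch (assumptions: observer choices, toy model, capture API).
import torch
from torch.ao.quantization.observer import (
    HistogramObserver,
    PerChannelMinMaxObserver,
    PlaceholderObserver,
)
from torch.ao.quantization.quantizer import QuantizationSpec

from ai_edge_torch.quantize.pt2e_quantizer_utils import (
    OP_TO_ANNOTATOR,
    QuantizationConfig,
    propagate_annotation,
)

# Per-tensor int8 activations, per-channel int8 weights, float bias:
# these satisfy the qscheme/dtype checks in get_*_qspec above.
act_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=HistogramObserver,
)
weight_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
    observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
)
bias_qspec = QuantizationSpec(
    dtype=torch.float,
    observer_or_fake_quant_ctr=PlaceholderObserver,
)
config = QuantizationConfig(
    input_activation=act_qspec,
    output_activation=act_qspec,
    weight=weight_qspec,
    bias=bias_qspec,
    fixed_qparams=None,
)

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
example_inputs = (torch.randn(1, 4),)
# Capture to an aten-level GraphModule; the right capture API depends on the
# installed torch version (export_for_training on newer releases).
gm = torch.export.export_for_training(model, example_inputs).module()

# Annotate every aten.linear node with the config, then let view-like ops
# (reshape, permute, mean, ...) share the observers of their producers.
OP_TO_ANNOTATOR["linear"](gm, config)
propagate_annotation(gm)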