ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl → 0.3.0.dev20240926__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (169)
  1. ai_edge_torch/__init__.py +5 -4
  2. ai_edge_torch/_convert/conversion.py +112 -0
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +94 -48
  5. ai_edge_torch/_convert/fx_passes/__init__.py +22 -0
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +107 -44
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +23 -20
  8. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +5 -6
  9. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -1
  10. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +39 -9
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +17 -8
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +9 -8
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +31 -18
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +2 -2
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +34 -24
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  19. ai_edge_torch/_convert/signature.py +66 -0
  20. ai_edge_torch/_convert/test/test_convert.py +495 -0
  21. ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
  22. ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
  23. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -5
  24. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +10 -3
  25. ai_edge_torch/config.py +27 -0
  26. ai_edge_torch/conftest.py +20 -0
  27. ai_edge_torch/debug/culprit.py +72 -40
  28. ai_edge_torch/debug/test/test_culprit.py +7 -5
  29. ai_edge_torch/debug/test/test_search_model.py +8 -7
  30. ai_edge_torch/debug/utils.py +14 -3
  31. ai_edge_torch/fx_pass_base.py +101 -0
  32. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +68 -0
  33. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +68 -0
  34. ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +69 -55
  35. ai_edge_torch/generative/examples/gemma/gemma2.py +267 -0
  36. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
  37. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +57 -0
  38. ai_edge_torch/generative/examples/gemma/verify_util.py +143 -0
  39. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +68 -0
  40. ai_edge_torch/generative/examples/openelm/openelm.py +206 -0
  41. ai_edge_torch/generative/examples/openelm/verify.py +64 -0
  42. ai_edge_torch/generative/examples/phi/__init__.py +14 -0
  43. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  44. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +68 -0
  45. ai_edge_torch/generative/examples/{phi2 → phi}/phi2.py +70 -51
  46. ai_edge_torch/generative/examples/phi/phi3.py +286 -0
  47. ai_edge_torch/generative/examples/phi/verify.py +65 -0
  48. ai_edge_torch/generative/examples/phi/verify_phi3.py +70 -0
  49. ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
  50. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +68 -0
  51. ai_edge_torch/generative/examples/smollm/smollm.py +101 -0
  52. ai_edge_torch/generative/examples/smollm/verify.py +62 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  54. ai_edge_torch/generative/examples/stable_diffusion/clip.py +83 -13
  55. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +27 -14
  56. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +74 -9
  57. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +179 -37
  58. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  59. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +83 -58
  60. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  61. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  62. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  63. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  64. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  65. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  66. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +28 -25
  67. ai_edge_torch/generative/examples/t5/t5.py +208 -159
  68. ai_edge_torch/generative/examples/t5/t5_attention.py +45 -30
  69. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  70. ai_edge_torch/generative/examples/test_models/toy_model.py +69 -41
  71. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +50 -64
  72. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  73. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +41 -39
  74. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +67 -54
  75. ai_edge_torch/generative/examples/tiny_llama/verify.py +64 -0
  76. ai_edge_torch/generative/fx_passes/__init__.py +4 -5
  77. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +10 -7
  78. ai_edge_torch/generative/layers/attention.py +141 -102
  79. ai_edge_torch/generative/layers/attention_utils.py +53 -12
  80. ai_edge_torch/generative/layers/builder.py +37 -7
  81. ai_edge_torch/generative/layers/feed_forward.py +39 -14
  82. ai_edge_torch/generative/layers/kv_cache.py +162 -50
  83. ai_edge_torch/generative/layers/model_config.py +84 -30
  84. ai_edge_torch/generative/layers/normalization.py +185 -7
  85. ai_edge_torch/generative/layers/rotary_position_embedding.py +6 -4
  86. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +48 -21
  87. ai_edge_torch/generative/layers/unet/blocks_2d.py +136 -77
  88. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  89. ai_edge_torch/generative/layers/unet/model_config.py +17 -15
  90. ai_edge_torch/generative/quantize/example.py +7 -8
  91. ai_edge_torch/generative/quantize/quant_recipe.py +10 -7
  92. ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -1
  93. ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
  94. ai_edge_torch/generative/test/test_kv_cache.py +120 -0
  95. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +9 -7
  96. ai_edge_torch/generative/test/test_model_conversion.py +124 -188
  97. ai_edge_torch/generative/test/test_model_conversion_large.py +251 -0
  98. ai_edge_torch/generative/test/test_quantize.py +76 -60
  99. ai_edge_torch/generative/test/utils.py +54 -0
  100. ai_edge_torch/generative/utilities/converter.py +82 -0
  101. ai_edge_torch/generative/utilities/loader.py +120 -57
  102. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +165 -57
  103. ai_edge_torch/generative/utilities/t5_loader.py +110 -81
  104. ai_edge_torch/generative/utilities/verifier.py +247 -0
  105. ai_edge_torch/hlfb/__init__.py +1 -1
  106. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -7
  107. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  108. ai_edge_torch/hlfb/mark_pattern/pattern.py +39 -30
  109. ai_edge_torch/hlfb/test/test_mark_pattern.py +46 -20
  110. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +24 -11
  111. ai_edge_torch/lowertools/__init__.py +18 -0
  112. ai_edge_torch/lowertools/_shim.py +80 -0
  113. ai_edge_torch/lowertools/common_utils.py +142 -0
  114. ai_edge_torch/lowertools/odml_torch_utils.py +255 -0
  115. ai_edge_torch/lowertools/test_utils.py +60 -0
  116. ai_edge_torch/lowertools/torch_xla_utils.py +284 -0
  117. ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +29 -14
  118. ai_edge_torch/model.py +53 -18
  119. ai_edge_torch/odml_torch/__init__.py +20 -0
  120. ai_edge_torch/odml_torch/_torch_future.py +61 -0
  121. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  122. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  123. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  124. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  125. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  126. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  127. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  128. ai_edge_torch/odml_torch/export.py +357 -0
  129. ai_edge_torch/odml_torch/export_utils.py +168 -0
  130. ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
  131. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +150 -0
  132. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  133. ai_edge_torch/odml_torch/lowerings/__init__.py +25 -0
  134. ai_edge_torch/odml_torch/lowerings/_basic.py +258 -0
  135. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  136. ai_edge_torch/odml_torch/lowerings/_convolution.py +241 -0
  137. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +252 -0
  138. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  139. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  140. ai_edge_torch/odml_torch/lowerings/registry.py +96 -0
  141. ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
  142. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  143. ai_edge_torch/odml_torch/tf_integration.py +194 -0
  144. ai_edge_torch/quantize/pt2e_quantizer.py +52 -24
  145. ai_edge_torch/quantize/pt2e_quantizer_utils.py +43 -23
  146. ai_edge_torch/quantize/quant_config.py +13 -9
  147. ai_edge_torch/testing/model_coverage/model_coverage.py +29 -16
  148. ai_edge_torch/version.py +16 -0
  149. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/METADATA +7 -3
  150. ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/RECORD +177 -0
  151. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/WHEEL +1 -1
  152. ai_edge_torch/convert/conversion.py +0 -117
  153. ai_edge_torch/convert/conversion_utils.py +0 -400
  154. ai_edge_torch/convert/fx_passes/__init__.py +0 -59
  155. ai_edge_torch/convert/fx_passes/_pass_base.py +0 -49
  156. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +0 -37
  157. ai_edge_torch/convert/test/test_convert.py +0 -311
  158. ai_edge_torch/convert/test/test_convert_composites.py +0 -192
  159. ai_edge_torch/convert/test/test_convert_multisig.py +0 -139
  160. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +0 -66
  161. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -64
  162. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -161
  163. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  164. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +0 -121
  165. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  166. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  167. /ai_edge_torch/generative/examples/{phi2 → openelm}/__init__.py +0 -0
  168. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/LICENSE +0 -0
  169. {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,194 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """APIs to convert lowered MLIR from PyTorch to TensorFlow and TFLite artifacts."""
16
+
17
+ import re
18
+ import tempfile
19
+
20
+ import tensorflow as tf
21
+ import torch
22
+
23
+ from tensorflow.compiler.tf2xla.python import xla as tfxla
24
+
25
+ from . import export
26
+ from . import export_utils
27
+
28
+
29
def torch_dtype_to_tf(dtype):
  """Map a torch dtype to the equivalent TensorFlow dtype.

  Args:
    dtype: a `torch.dtype`.

  Returns:
    The matching `tf.DType`, or None when `dtype` has no mapping here. Note
    that the callers in this module pass the result straight into
    `tf.TensorSpec` / `tfxla.call_module` without checking for None.
  """
  return {
      torch.double: tf.float64,
      torch.float32: tf.float32,
      torch.half: tf.float16,
      torch.bfloat16: tf.bfloat16,
      torch.long: tf.int64,
      torch.int32: tf.int32,
      torch.int16: tf.int16,
      # 8-bit integer types previously fell through to None.
      torch.int8: tf.int8,
      torch.uint8: tf.uint8,
      torch.bool: tf.bool,
  }.get(dtype)
39
+
40
+
41
def _get_shape_with_dynamic(signature: export.VariableSignature):
  """Return `signature.shape` as a list usable as a TF shape.

  Every dimension for which `export_utils.is_torch_dynamic` is true is
  replaced by None (an unknown dimension in TF shapes); static dimensions are
  copied through unchanged.
  """
  is_dynamic = export_utils.is_torch_dynamic
  return [None if is_dynamic(dim) else dim for dim in signature.shape]
45
+
46
+
47
+ def _mangle_tf_root_scope_name(name):
48
+ r"""Build the mangled name for tf.Variable.
49
+
50
+ TF has more restricted constrain on the variable names at root scope. Root
51
+ scope name constrain: [A-Za-z0-9.][A-Za-z0-9_.\\-/]* Non-root scope name
52
+ constrain: [A-Za-z0-9_.\\-/]*
53
+ https://github.com/tensorflow/tensorflow/blob/51b601fa6bb7e801c0b6ae73c25580e40a8b5745/tensorflow/python/framework/ops.py#L3301-L3302
54
+ The state_dict key doesn't have such constrain, the name need to be mangled
55
+ when a root-scoped TF variable is created.
56
+
57
+ FX Graph Node may contain characters other than [A-Za-z0-9_.\\-/], replace
58
+ offending characters with '_'.
59
+
60
+ Args:
61
+ name: the tensor name to be mangled.
62
+
63
+ Returns:
64
+ Mangled name in str.
65
+ """
66
+ if name[0] in "._\\-/":
67
+ name = "k" + name
68
+ name = re.sub(r"[^^\w\-/\\]+", "_", name)
69
+ return name
70
+
71
+
72
def _build_tf_state_dict(
    lowered: export.MlirLowered,
) -> dict[str, tf.Variable]:
  """Build a dictionary of tf.Variable from the state_dict in lowered.

  Only entries of `lowered.input_signature` whose `input_spec.is_parameter` is
  true are materialized. Each becomes a non-trainable tf.Variable keyed by its
  original state_dict name. Its TF name is `_mangle_tf_root_scope_name(name)`,
  because TF restricts root-scope variable names.

  Args:
    lowered: the lowered module providing `input_signature` and `state_dict`.

  Returns:
    Mapping from state_dict key to a tf.Variable holding a copy of its value.
  """
  tf_state_dict = {}
  for sig in lowered.input_signature:
    if not sig.input_spec.is_parameter:
      continue
    name = sig.input_spec.name
    # `.numpy()` only accepts CPU tensors; `.cpu()` is a no-op for tensors
    # already on CPU and lets accelerator-resident parameters be converted.
    value = lowered.state_dict[name].detach().cpu().numpy()
    tf_state_dict[name] = tf.Variable(
        value,
        trainable=False,
        name=_mangle_tf_root_scope_name(name),
    )
  return tf_state_dict
86
+
87
+
88
def _extract_call_args(
    lowered: export.MlirLowered,
    args,
    tf_state_dict: dict[str, tf.Variable],
):
  """Assemble the flat argument list passed to the lowered module.

  Walks `lowered.input_signature` in order:
    * a user input is taken from `args` at position `input_spec.i`;
    * a parameter is looked up in `tf_state_dict` by `input_spec.name`;
    * any other entry is skipped.
  """

  def _resolve(spec):
    if spec.is_user_input:
      return args[spec.i]
    return tf_state_dict[spec.name]

  return [
      _resolve(sig.input_spec)
      for sig in lowered.input_signature
      if sig.input_spec.is_user_input or sig.input_spec.is_parameter
  ]
102
+
103
+
104
def _wrap_as_tf_func(lowered, tf_state_dict):
  """Return the Python callable that tf.function traces for `lowered`.

  The callable takes the user inputs positionally and merges them with the
  tf.Variables from `tf_state_dict` in `lowered.input_signature` order. It then
  invokes `lowered.module_bytecode` through `tfxla.call_module` (version 5,
  empty function_list). Output dtypes and shapes are declared from
  `lowered.output_signature`.
  """

  def inner(*args):
    out_sigs = lowered.output_signature
    out_dtypes = [torch_dtype_to_tf(s.dtype) for s in out_sigs]
    out_shapes = [_get_shape_with_dynamic(s) for s in out_sigs]
    module_args = tuple(_extract_call_args(lowered, args, tf_state_dict))
    return tfxla.call_module(
        module_args,
        version=5,
        Tout=out_dtypes,
        Sout=out_shapes,
        function_list=[],
        module=lowered.module_bytecode,
    )

  return inner
121
+
122
+
123
def _make_input_signatures(
    lowered: export.MlirLowered,
) -> list[tf.TensorSpec]:
  """Build the tf.TensorSpec input signature for the wrapping tf.function.

  Only user inputs are included, ordered by their positional index
  `input_spec.i`. Each spec is named `args_<i>` and uses None for dynamic
  dimensions.
  """
  user_inputs = [
      sig for sig in lowered.input_signature if sig.input_spec.is_user_input
  ]
  user_inputs.sort(key=lambda sig: sig.input_spec.i)
  return [
      tf.TensorSpec(
          shape=_get_shape_with_dynamic(sig),
          dtype=torch_dtype_to_tf(sig.dtype),
          name=f"args_{sig.input_spec.i}",
      )
      for sig in user_inputs
  ]
143
+
144
+
145
def mlir_to_tf_function(lowered: export.MlirLowered):
  """Convert the MLIR lowered to an executable tf.function.

  Parameters from the lowered state_dict are turned into non-trainable
  tf.Variables used by the function body. The function's input signature
  covers only the user inputs.
  """
  state_vars = _build_tf_state_dict(lowered)
  body = _wrap_as_tf_func(lowered, state_vars)
  signature = _make_input_signatures(lowered)
  return tf.function(body, input_signature=signature)
152
+
153
+
154
def mlir_to_flatbuffer(lowered: export.MlirLowered):
  """Convert the MLIR lowered to a TFLite flatbuffer binary.

  The lowered module is wrapped as a tf.function exposed under the default
  serving signature. It is saved as a SavedModel in a temporary directory and
  then converted with `tf.lite.TFLiteConverter.from_saved_model`.

  Returns:
    The result of `TFLiteConverter.convert()`.
  """
  state_vars = _build_tf_state_dict(lowered)
  # Parallel lists, one entry per exported signature (currently only the
  # default serving signature).
  sig_names = [tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  input_specs = [_make_input_signatures(lowered)]
  py_funcs = [_wrap_as_tf_func(lowered, state_vars)]

  module = tf.Module()
  module.f = []
  for spec, fn in zip(input_specs, py_funcs):
    module.f.append(tf.function(fn, input_signature=spec))

  # NOTE(review): assigns a private tf.Module attribute, presumably so the
  # SavedModel export tracks the parameter variables — confirm.
  module._variables = list(state_vars.values())

  concrete_fns = [
      fn.get_concrete_function(*spec) for fn, spec in zip(module.f, input_specs)
  ]

  # TFLite's from_concrete_functions cannot name each concrete function, so
  # round-trip through a temporary SavedModel whose signatures carry names.
  with tempfile.TemporaryDirectory() as saved_model_dir:
    tf.saved_model.save(
        module,
        saved_model_dir,
        signatures=dict(zip(sig_names, concrete_fns)),
    )
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    flatbuffer = converter.convert()

  return flatbuffer
@@ -19,6 +19,12 @@ import copy
19
19
  import functools
20
20
  from typing import Any, Callable, Dict, List, Optional, Set
21
21
 
22
+ from ai_edge_torch.quantize.pt2e_quantizer_utils import _convert_scalars_to_attrs # NOQA
23
+ from ai_edge_torch.quantize.pt2e_quantizer_utils import OP_TO_ANNOTATOR
24
+ from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorConfig
25
+ from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorPatternType
26
+ from ai_edge_torch.quantize.pt2e_quantizer_utils import propagate_annotation
27
+ from ai_edge_torch.quantize.pt2e_quantizer_utils import QuantizationConfig
22
28
  import torch
23
29
  from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
24
30
  from torch.ao.quantization.observer import HistogramObserver
@@ -34,20 +40,15 @@ from torch.ao.quantization.quantizer import Quantizer
34
40
  from torch.fx import Node
35
41
  import torch.nn.functional as F
36
42
 
37
- from ai_edge_torch.quantize.pt2e_quantizer_utils import _convert_scalars_to_attrs # NOQA
38
- from ai_edge_torch.quantize.pt2e_quantizer_utils import OP_TO_ANNOTATOR
39
- from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorConfig
40
- from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorPatternType
41
- from ai_edge_torch.quantize.pt2e_quantizer_utils import propagate_annotation
42
- from ai_edge_torch.quantize.pt2e_quantizer_utils import QuantizationConfig
43
-
44
43
  __all__ = [
45
44
  "PT2EQuantizer",
46
45
  "get_symmetric_quantization_config",
47
46
  ]
48
47
 
49
48
 
50
- def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
49
+ def _supported_symmetric_quantized_operators() -> (
50
+ Dict[str, List[OperatorPatternType]]
51
+ ):
51
52
  supported_operators: Dict[str, List[OperatorPatternType]] = {
52
53
  # Both conv and linear should be able to handle relu + hardtanh fusion since
53
54
  # those are clamp ops
@@ -92,7 +93,9 @@ def get_symmetric_quantization_config(
92
93
  ):
93
94
  if is_qat:
94
95
  if is_dynamic:
95
- raise NotImplementedError("dynamic quantization for qat is not yet implemented.")
96
+ raise NotImplementedError(
97
+ "dynamic quantization for qat is not yet implemented."
98
+ )
96
99
  act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
97
100
  else:
98
101
  if is_dynamic:
@@ -106,12 +109,18 @@ def get_symmetric_quantization_config(
106
109
  quant_max=127,
107
110
  qscheme=torch.per_tensor_affine,
108
111
  is_dynamic=is_dynamic,
109
- observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
112
+ observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
113
+ eps=2**-12
114
+ ),
110
115
  )
111
116
  qscheme = (
112
- torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
117
+ torch.per_channel_symmetric
118
+ if is_per_channel
119
+ else torch.per_tensor_symmetric
120
+ )
121
+ weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
122
+ MinMaxObserver
113
123
  )
114
- weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = MinMaxObserver
115
124
  if is_qat:
116
125
  weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
117
126
  elif is_per_channel:
@@ -179,15 +188,18 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]:
179
188
 
180
189
  def _get_module_name_filter(module_name: str):
181
190
  """Get the module_name_filter function for a given module name, the filter accepts
191
+
182
192
  a node and checks if the node comes from a module that has certain module name
183
193
 
184
194
  For example:
185
- node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1
195
+ node: linear_op = call_function[...](...) # comes from a module with name
196
+ blocks.sub.linear1
186
197
 
187
198
 
188
199
  >> module_name_filter = _get_module_name_filter("blocks.sub")
189
200
  >> print(module_name_filter(node))
190
- True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
201
+ True # the node is from "blocks.sub" based on the fully qualified name
202
+ "blocks.sub.linear1"
191
203
  """
192
204
 
193
205
  def module_name_filter(n: Node) -> bool:
@@ -197,7 +209,9 @@ def _get_module_name_filter(module_name: str):
197
209
  # }
198
210
  # get_attr nodes doesn't have nn_module_stack?
199
211
  nn_module_stack = n.meta.get("nn_module_stack", {})
200
- names = [n[len("L__self___") :].replace("_", ".") for n in nn_module_stack.keys()]
212
+ names = [
213
+ n[len("L__self___") :].replace("_", ".") for n in nn_module_stack.keys()
214
+ ]
201
215
  return module_name in names
202
216
 
203
217
  return module_name_filter
@@ -205,15 +219,19 @@ def _get_module_name_filter(module_name: str):
205
219
 
206
220
  def _get_module_type_filter(tp: Callable):
207
221
  """Get the module_type_filter function for a given module type, the filter accepts
222
+
208
223
  a node and checks if the node comes from a module that has certain module type
209
224
 
210
225
  For example:
211
- node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear
226
+ node: linear_op = call_function[...](...) # comes from a module with type
227
+ Block -> Sub -> Linear
212
228
 
213
229
 
214
- >> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule
230
+ >> module_type_filter = _get_module_type_filter(Sub) # submodule with type
231
+ `Sub`, under the `Block` submodule
215
232
  >> print(module_type_filter(node))
216
- True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
233
+ True # the node is from the submodule `Sub` (same for `Block` and `Linear` as
234
+ well)
217
235
  """
218
236
 
219
237
  def module_type_filter(n: Node) -> bool:
@@ -232,7 +250,9 @@ def _get_not_module_type_or_name_filter(
232
250
  tp_list: List[Callable], module_name_list: List[str]
233
251
  ) -> Callable[[Node], bool]:
234
252
  module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
235
- module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]
253
+ module_name_list_filters = [
254
+ _get_module_name_filter(m) for m in module_name_list
255
+ ]
236
256
 
237
257
  def not_module_type_or_name_filter(n: Node) -> bool:
238
258
  return not any(f(n) for f in module_type_filters + module_name_list_filters)
@@ -307,7 +327,9 @@ class PT2EQuantizer(Quantizer):
307
327
  return ops
308
328
  return []
309
329
 
310
- def set_global(self, quantization_config: QuantizationConfig) -> PT2EQuantizer:
330
+ def set_global(
331
+ self, quantization_config: QuantizationConfig
332
+ ) -> PT2EQuantizer:
311
333
  self.global_config = quantization_config
312
334
  return self
313
335
 
@@ -323,8 +345,11 @@ class PT2EQuantizer(Quantizer):
323
345
  self, module_type: Callable, quantization_config: QuantizationConfig
324
346
  ):
325
347
  """Set quantization_config for a submodule with type: `module_type`, for example:
326
- quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator
327
- patterns in the submodule with this module type with the given `quantization_config`
348
+
349
+ quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it
350
+ will quantize all supported operator/operator
351
+ patterns in the submodule with this module type with the given
352
+ `quantization_config`
328
353
  """
329
354
  self.module_type_config[module_type] = quantization_config
330
355
  return self
@@ -333,8 +358,11 @@ class PT2EQuantizer(Quantizer):
333
358
  self, module_name: str, quantization_config: Optional[QuantizationConfig]
334
359
  ):
335
360
  """Set quantization_config for a submodule with name: `module_name`, for example:
336
- quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
337
- patterns in the submodule with this module name with the given `quantization_config`
361
+
362
+ quantizer.set_module_name("blocks.sub"), it will quantize all supported
363
+ operator/operator
364
+ patterns in the submodule with this module name with the given
365
+ `quantization_config`
338
366
  """
339
367
  assert (
340
368
  quantization_config is not None
@@ -31,7 +31,7 @@ from torch.ao.quantization.quantizer import SharedQuantizationSpec
31
31
  from torch.ao.quantization.quantizer.utils import _annotate_input_qspec_map
32
32
  from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
33
33
  from torch.fx import Node
34
- from torch.fx.passes.utils.matcher_with_name_node_map_utils import SubgraphMatcherWithNameNodeMap # NOQA
34
+ from torch.fx.passes.utils.matcher_with_name_node_map_utils import SubgraphMatcherWithNameNodeMap
35
35
  from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
36
36
  import torch.nn.functional as F
37
37
 
@@ -95,9 +95,10 @@ class OperatorConfig(NamedTuple):
95
95
 
96
96
 
97
97
  def _is_annotated(nodes: List[Node]):
98
- """
99
- Given a list of nodes (that represents an operator pattern),
100
- check if any of the node is annotated, return True if any of the node
98
+ """Checks if a list of nodes is annotated.
99
+
100
+ Given a list of nodes (that represents an operator pattern), check if any of
101
+ the node is annotated, return True if any of the node
101
102
  is annotated, otherwise return False
102
103
  """
103
104
  annotated = False
@@ -154,7 +155,9 @@ def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
154
155
  torch.per_tensor_symmetric,
155
156
  torch.per_channel_symmetric,
156
157
  ]:
157
- raise ValueError(f"Unsupported quantization_spec {quantization_spec} for weight")
158
+ raise ValueError(
159
+ f"Unsupported quantization_spec {quantization_spec} for weight"
160
+ )
158
161
  return quantization_spec
159
162
 
160
163
 
@@ -193,7 +196,10 @@ def _annotate_linear(
193
196
  weight_qspec = get_weight_qspec(quantization_config)
194
197
  bias_qspec = get_bias_qspec(quantization_config)
195
198
  for node in gm.graph.nodes:
196
- if node.op != "call_function" or node.target != torch.ops.aten.linear.default:
199
+ if (
200
+ node.op != "call_function"
201
+ or node.target != torch.ops.aten.linear.default
202
+ ):
197
203
  continue
198
204
  if filter_fn and not filter_fn(node):
199
205
  continue
@@ -413,11 +419,13 @@ def _annotate_conv_bn(
413
419
  quantization_config: Optional[QuantizationConfig],
414
420
  filter_fn: Optional[Callable[[Node], bool]] = None,
415
421
  ) -> Optional[List[List[Node]]]:
422
+ """Find conv + batchnorm parititions Note: This is only used for QAT.
423
+
424
+ In PTQ, batchnorm should already be fused into the conv.
416
425
  """
417
- Find conv + batchnorm parititions
418
- Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
419
- """
420
- return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False)
426
+ return _do_annotate_conv_bn(
427
+ gm, quantization_config, filter_fn, has_relu=False
428
+ )
421
429
 
422
430
 
423
431
  @register_annotator("conv_bn_relu")
@@ -426,9 +434,9 @@ def _annotate_conv_bn_relu(
426
434
  quantization_config: Optional[QuantizationConfig],
427
435
  filter_fn: Optional[Callable[[Node], bool]] = None,
428
436
  ) -> Optional[List[List[Node]]]:
429
- """
430
- Find conv + batchnorm + relu parititions
431
- Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
437
+ """Find conv + batchnorm + relu parititions Note: This is only used for QAT.
438
+
439
+ In PTQ, batchnorm should already be fused into the conv.
432
440
  """
433
441
  return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True)
434
442
 
@@ -439,8 +447,8 @@ def _do_annotate_conv_bn(
439
447
  filter_fn: Optional[Callable[[Node], bool]],
440
448
  has_relu: bool,
441
449
  ) -> List[List[Node]]:
442
- """
443
- Given a function that takes in a `conv_fn` and returns a conv-bn[-relu] pattern,
450
+ """Given a function that takes in a `conv_fn` and returns a conv-bn[-relu] pattern,
451
+
444
452
  return a list of annotated partitions.
445
453
 
446
454
  The output of the pattern must include a dictionary from string name to node
@@ -486,7 +494,9 @@ def _do_annotate_conv_bn(
486
494
  # Match against all conv dimensions and cuda variants
487
495
  for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:
488
496
  pattern = get_pattern(conv_fn, relu_is_inplace)
489
- pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)
497
+ pattern = _get_aten_graph_module_for_pattern(
498
+ pattern, example_inputs, is_cuda
499
+ )
490
500
  pattern.graph.eliminate_dead_code()
491
501
  pattern.recompile()
492
502
  matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
@@ -676,7 +686,9 @@ def _annotate_adaptive_avg_pool2d(
676
686
  and pool_node.target != torch.ops.aten.mean.dim
677
687
  and pool_node.target != torch.ops.aten.as_strided_.default
678
688
  ):
679
- raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator")
689
+ raise ValueError(
690
+ f"{pool_node} is not an aten adaptive_avg_pool2d operator"
691
+ )
680
692
 
681
693
  if _is_annotated([pool_node]):
682
694
  continue
@@ -741,7 +753,8 @@ def _annotate_fixed_qparams(
741
753
  continue
742
754
 
743
755
  node.meta["quantization_annotation"] = QuantizationAnnotation(
744
- output_qspec=get_fixed_qparams_qspec(quantization_config), _annotated=True
756
+ output_qspec=get_fixed_qparams_qspec(quantization_config),
757
+ _annotated=True,
745
758
  )
746
759
  _mark_nodes_as_annotated(partition)
747
760
  annotated_partitions.append(partition)
@@ -885,7 +898,9 @@ def _annotate_mul(
885
898
  filter_fn: Optional[Callable[[Node], bool]] = None,
886
899
  ) -> Optional[List[List[Node]]]:
887
900
  mul_partitions = get_source_partitions(
888
- gm.graph, ["mul", "mul_", operator.mul, torch.mul, operator.imul], filter_fn
901
+ gm.graph,
902
+ ["mul", "mul_", operator.mul, torch.mul, operator.imul],
903
+ filter_fn,
889
904
  )
890
905
  mul_partitions = list(itertools.chain(*mul_partitions.values()))
891
906
  annotated_partitions = []
@@ -932,8 +947,9 @@ def _annotate_cat(
932
947
 
933
948
  if cat_node.target != torch.ops.aten.cat.default:
934
949
  raise Exception(
935
- f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target}"
936
- " please check if you are calling the correct capture API"
950
+ "Expected cat node: torch.ops.aten.cat.default, but found"
951
+ f" {cat_node.target} please check if you are calling the correct"
952
+ " capture API"
937
953
  )
938
954
 
939
955
  annotated_partitions.append(cat_partition.nodes)
@@ -987,7 +1003,9 @@ def propagate_annotation(model: torch.fx.GraphModule) -> None:
987
1003
  if not isinstance(prev_node, Node):
988
1004
  continue
989
1005
 
990
- quantization_annotation = prev_node.meta.get("quantization_annotation", None)
1006
+ quantization_annotation = prev_node.meta.get(
1007
+ "quantization_annotation", None
1008
+ )
991
1009
  if not quantization_annotation:
992
1010
  continue
993
1011
 
@@ -1014,7 +1032,9 @@ def propagate_annotation(model: torch.fx.GraphModule) -> None:
1014
1032
 
1015
1033
 
1016
1034
  # TODO: make the list of ops customizable
1017
- def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
1035
+ def _convert_scalars_to_attrs(
1036
+ model: torch.fx.GraphModule,
1037
+ ) -> torch.fx.GraphModule:
1018
1038
  for n in model.graph.nodes:
1019
1039
  if n.op != "call_function" or n.target not in [
1020
1040
  torch.ops.aten.add.Tensor,
@@ -13,27 +13,27 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from dataclasses import dataclass
16
+ import dataclasses
17
17
  import enum
18
18
  from typing import Optional
19
19
 
20
- from ai_edge_torch.generative.quantize import quant_attrs
21
20
  from ai_edge_torch.generative.quantize import quant_recipe
22
21
  from ai_edge_torch.quantize import pt2e_quantizer as pt2eq
23
22
 
24
23
 
25
- @dataclass(frozen=True)
24
+ @dataclasses.dataclass(frozen=True)
26
25
  class QuantConfig:
27
- """
26
+ """Encapsulates a quantization configuration.
27
+
28
28
  Encapsulates all different quantization methods and schemes available for
29
29
  models converted with ai_edge_torch.
30
30
 
31
- Args:
31
+ Attributes:
32
32
  pt2e_quantizer: The instance of PT2EQuantizer used to quantize the model
33
33
  with PT2E quantization. This method of quantization is not applicable to
34
34
  models created with the Edge Generative API.
35
- generative_recipe: Quantization recipe to be applied on a model created
36
- with the Edge Generative API.
35
+ generative_recipe: Quantization recipe to be applied on a model created with
36
+ the Edge Generative API.
37
37
  """
38
38
 
39
39
  pt2e_quantizer: pt2eq.PT2EQuantizer = None
@@ -76,6 +76,10 @@ class QuantConfig:
76
76
  elif generative_recipe is not None:
77
77
  generative_recipe.verify()
78
78
  object.__setattr__(self, 'generative_recipe', generative_recipe)
79
- object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.AI_EDGE_QUANTIZER)
79
+ object.__setattr__(
80
+ self, '_quantizer_mode', self._QuantizerMode.AI_EDGE_QUANTIZER
81
+ )
80
82
  else:
81
- raise ValueError('Either pt2e_quantizer or generative_recipe must be set.')
83
+ raise ValueError(
84
+ 'Either pt2e_quantizer or generative_recipe must be set.'
85
+ )
@@ -13,26 +13,33 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- """Utility Functions to test TFLite models exported from PyTorch"""
16
+ """Contains utility functions to test TFLite models exported from PyTorch."""
17
17
 
18
18
  from collections.abc import Callable
19
19
 
20
+ from ai_edge_torch import model
20
21
  import numpy as np
21
22
  import torch
22
23
  from torch.utils import _pytree as pytree
23
24
 
24
- from ai_edge_torch.model import Model
25
-
26
25
 
27
26
  # Utility to flatten the order to make it deterministic.
28
27
  # Ordering is done in left-to-right depth-first tree traversal.
29
28
  def _flatten(data):
30
- out, spec = pytree.tree_flatten(data)
29
+ out, _ = pytree.tree_flatten(data)
31
30
  return out
32
31
 
33
32
 
34
33
  # Convert a Torch Tensor to a numpy array
35
34
  def _torch_tensors_to_np(*argv):
35
+ """Converts a Torch Tensor to a numpy array.
36
+
37
+ Args:
38
+ *argv: A list of torch.tensor or a single torch.tensor.
39
+
40
+ Returns:
41
+ A list of numpy array or a single numpy array.
42
+ """
36
43
  if len(argv) > 1:
37
44
  data = list(argv)
38
45
  else:
@@ -58,7 +65,7 @@ def _torch_tensors_to_np(*argv):
58
65
 
59
66
 
60
67
  def compare_tflite_torch(
61
- edge_model: Model,
68
+ edge_model: model.Model,
62
69
  torch_eval_func: Callable,
63
70
  args=None,
64
71
  kwargs=None,
@@ -69,15 +76,17 @@ def compare_tflite_torch(
69
76
  rtol: float = 1e-5
70
77
  ):
71
78
  """Compares torch models and TFLite models.
79
+
72
80
  Args:
73
81
  edge_model: Serialized ai_edge_torch.model.Model object.
74
82
  torch_eval_func: Callable function to evaluate torch model.
75
- args: torch.tensor array or a callable to generate a torch.tensor array
76
- with random data, to pass into models during inference. (default None).
83
+ args: torch.tensor array or a callable to generate a torch.tensor array with
84
+ random data, to pass into models during inference. (default None).
77
85
  kwargs: dict of str to torch.tensor, or a callable to generate such.
78
- num_valid_inputs: Defines the number of times the random inputs will be generated (if a callable is provided for input_data).
79
- signature_name: If provided, specifies the name for the signature of the edge_model to run.
80
- Calls the default signature if not provided.
86
+ num_valid_inputs: Defines the number of times the random inputs will be
87
+ generated (if a callable is provided for input_data).
88
+ signature_name: If provided, specifies the name for the signature of the
89
+ edge_model to run. Calls the default signature if not provided.
81
90
  atol: Absolute tolerance (see `numpy.allclose`)
82
91
  rtol: Relative tolerance (see `numpy.allclose`)
83
92
  """
@@ -94,7 +103,9 @@ def compare_tflite_torch(
94
103
  )
95
104
  for _ in range(num_valid_inputs)
96
105
  ]
97
- torch_outputs = [torch_eval_func(*args, **kwargs) for args, kwargs in torch_inputs]
106
+ torch_outputs = [
107
+ torch_eval_func(*args, **kwargs) for args, kwargs in torch_inputs
108
+ ]
98
109
  np_inputs = [
99
110
  (_torch_tensors_to_np(args), _torch_tensors_to_np(kwargs))
100
111
  for args, kwargs in torch_inputs
@@ -110,12 +121,13 @@ def compare_tflite_torch(
110
121
  if signature_name is None:
111
122
  return _flatten(edge_model(*args, **kwargs))
112
123
  else:
113
- return _flatten(edge_model(*args, **kwargs, signature_name=signature_name))
124
+ return _flatten(
125
+ edge_model(*args, **kwargs, signature_name=signature_name)
126
+ )
114
127
 
115
128
  for idx, np_input in enumerate(np_inputs):
116
129
  output = get_edge_output(np_input)
117
130
  golden_output = np_outputs[idx]
118
-
119
131
  is_output_len_eq = len(golden_output) == len(output)
120
132
 
121
133
  output = [v.astype(np.float32) for v in output]
@@ -123,9 +135,10 @@ def compare_tflite_torch(
123
135
 
124
136
  # Append the results of each invoke to a function-global variable
125
137
  # used to store the comparison final results
126
- is_equal = is_output_len_eq and all(
127
- [equal_fn(out, golden_out) for out, golden_out in zip(output, golden_output)]
128
- )
138
+ is_equal = is_output_len_eq and all([
139
+ equal_fn(out, golden_out)
140
+ for out, golden_out in zip(output, golden_output)
141
+ ])
129
142
  if not is_equal:
130
143
  return False
131
144