ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

Files changed (91)
  1. ai_edge_torch/__init__.py +30 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +330 -0
  5. ai_edge_torch/convert/converter.py +171 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
  9. ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +273 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +171 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/debug/__init__.py +16 -0
  27. ai_edge_torch/debug/culprit.py +423 -0
  28. ai_edge_torch/debug/test/__init__.py +14 -0
  29. ai_edge_torch/debug/test/test_culprit.py +133 -0
  30. ai_edge_torch/debug/utils.py +48 -0
  31. ai_edge_torch/experimental/__init__.py +14 -0
  32. ai_edge_torch/generative/__init__.py +14 -0
  33. ai_edge_torch/generative/examples/__init__.py +14 -0
  34. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  35. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  36. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  37. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  39. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  40. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  42. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  43. ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
  44. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
  46. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  47. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  48. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  49. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  50. ai_edge_torch/generative/layers/__init__.py +14 -0
  51. ai_edge_torch/generative/layers/attention.py +288 -0
  52. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  53. ai_edge_torch/generative/layers/builder.py +103 -0
  54. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  55. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  56. ai_edge_torch/generative/layers/model_config.py +135 -0
  57. ai_edge_torch/generative/layers/normalization.py +62 -0
  58. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  59. ai_edge_torch/generative/quantize/__init__.py +14 -0
  60. ai_edge_torch/generative/quantize/example.py +45 -0
  61. ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
  62. ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
  63. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  64. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  65. ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
  66. ai_edge_torch/generative/test/__init__.py +14 -0
  67. ai_edge_torch/generative/test/test_model_conversion.py +201 -0
  68. ai_edge_torch/generative/test/test_quantize.py +109 -0
  69. ai_edge_torch/generative/utilities/__init__.py +15 -0
  70. ai_edge_torch/generative/utilities/loader.py +290 -0
  71. ai_edge_torch/generative/utilities/t5_loader.py +467 -0
  72. ai_edge_torch/hlfb/__init__.py +16 -0
  73. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  74. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  75. ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
  76. ai_edge_torch/hlfb/test/__init__.py +14 -0
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  79. ai_edge_torch/model.py +134 -0
  80. ai_edge_torch/quantize/__init__.py +16 -0
  81. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  82. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  83. ai_edge_torch/quantize/quant_config.py +85 -0
  84. ai_edge_torch/testing/__init__.py +14 -0
  85. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  86. ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
  87. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
  88. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
  89. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
  90. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
  91. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
ai_edge_torch/model.py ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Represents an ai_edge_torch model.
17
+
18
+ PyTorch models can be converted to this representation through `ai_edge_torch.convert`.
19
+ """
20
+ from __future__ import annotations
21
+
22
+ import abc
23
+
24
+ import numpy as np
25
+ import numpy.typing as npt
26
+ import tensorflow as tf
27
+
28
+ from ai_edge_torch.convert import conversion_utils as cutils
29
+
30
+
31
class Model(abc.ABC):
  """Represents an edge model.

  Concrete subclasses implement inference (`__call__`) and serialization
  (`export`). Use `Model.load` to deserialize a model from disk.
  """

  @abc.abstractmethod
  def __call__(
      self, *args: npt.ArrayLike, signature_name: str = cutils.DEFAULT_SIGNATURE_NAME
  ) -> npt.ArrayLike | tuple[npt.ArrayLike]:
    """Runs inference on the model.

    Args:
      *args: The positional inputs to be passed to the model.
      signature_name: The name of the signature to run.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def export(self, path: str):
    """Serializes the model to the file at `path`."""
    raise NotImplementedError()

  @staticmethod
  def load(path: str) -> TfLiteModel:
    """Deserializes an edge model from `path`.

    Args:
      path: The path to the serialized model.

    Returns:
      The model produced by `TfLiteModel.load`.

    Raises:
      ValueError: If `TfLiteModel.load` returns None, i.e. the file could not
        be interpreted as a supported model format.
    """
    tflite_model = TfLiteModel.load(path)
    # Compare against None explicitly: "not loadable" is signalled by None,
    # not by the truthiness of the returned object.
    if tflite_model is not None:
      return tflite_model

    raise ValueError(f'File format in {path} cannot be deserialized.')
51
+
52
+
53
class TfLiteModel(Model):
  """An edge model which uses tflite under-the-hood."""

  def __init__(self, tflite_model):
    """Initializes the TfLiteModel instance using a TFLite serialized object.

    Args:
      tflite_model: A TFlite serialized object.
    """
    self._tflite_model = tflite_model

  def __call__(
      self, *args: npt.ArrayLike, signature_name: str = cutils.DEFAULT_SIGNATURE_NAME
  ) -> npt.ArrayLike | tuple[npt.ArrayLike]:
    """Runs inference on the edge model using the provided arguments.

    A new interpreter is created from the serialized model on every call.

    Args:
      *args: The arguments to be passed to the model for inference. They are
        bound to the signature inputs named `args_0`, `args_1`, ...
      signature_name: The name of the signature to be used for inference.
        The default signature is used if not provided.

    Returns:
      The value of `output_0` when the signature produces exactly one output,
      otherwise a list of the outputs `output_0`, `output_1`, ... in order.

    Raises:
      ValueError: If `signature_name` is not a signature of the model, or if
        the number of positional arguments does not match the number of
        inputs of the selected signature.
    """
    interpreter = tf.lite.Interpreter(model_content=self._tflite_model)
    interpreter.allocate_tensors()

    signature_list = interpreter.get_signature_list()
    if signature_name not in signature_list:
      raise ValueError(
          f"Invalid signature name provided. Available signatures: {', '.join(signature_list.keys())}"
      )

    try:
      runner = interpreter.get_signature_runner(signature_name)
    except ValueError as exception:
      if 'Invalid signature_key provided.' in str(exception):
        # Keep the interpreter's error as the cause so the traceback is not lost.
        raise ValueError(
            f'Invalid signature key provided. Available signatures: {list(signature_list.keys())}'
        ) from exception
      # Any other ValueError is propagated unchanged.
      raise

    expected_input_count = len(signature_list[signature_name]['inputs'])
    if expected_input_count != len(args):
      raise ValueError(
          f"The model requires {expected_input_count} arguments but {len(args)} was provided."
      )

    # Gather the input dictionary based on the signature.
    inputs = {f'args_{idx}': arg for idx, arg in enumerate(args)}
    outputs = runner(**inputs)

    return (
        outputs['output_0']
        if len(outputs) == 1
        else [outputs[f'output_{idx}'] for idx in range(len(outputs))]
    )

  def export(self, path: str) -> None:
    """Serializes the edge model to disk.

    Args:
      path: The path to file to which the model is serialized.
    """
    with open(path, 'wb') as file_handle:
      file_handle.write(self._tflite_model)

  @staticmethod
  def load(path: str) -> TfLiteModel | None:
    """Returns an edge (tflite) model by reading it from the disk.

    Args:
      path: The path to the model.

    Returns:
      A `TfLiteModel` wrapping the file content, or None when the content
      cannot be loaded by the TFLite interpreter. Errors raised while
      opening or reading the file itself are propagated.
    """
    with open(path, 'rb') as file_handle:
      model_content = file_handle.read()

    # Check if this is indeed a tflite model:
    try:
      interpreter = tf.lite.Interpreter(model_content=model_content)
      interpreter.get_signature_list()
    except Exception:
      # Only regular errors mean "not a tflite model"; KeyboardInterrupt and
      # SystemExit must not be swallowed here.
      return None

    return TfLiteModel(model_content)
ai_edge_torch/quantize/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from .pt2e_quantizer import PT2EQuantizer
ai_edge_torch/quantize/pt2e_quantizer.py ADDED
@@ -0,0 +1,438 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import copy
19
+ import functools
20
+ from typing import Any, Callable, Dict, List, Optional, Set
21
+
22
+ import torch
23
+ from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
24
+ from torch.ao.quantization.observer import HistogramObserver
25
+ from torch.ao.quantization.observer import MinMaxObserver
26
+ from torch.ao.quantization.observer import MovingAverageMinMaxObserver
27
+ from torch.ao.quantization.observer import MovingAveragePerChannelMinMaxObserver # NOQA
28
+ from torch.ao.quantization.observer import PerChannelMinMaxObserver
29
+ from torch.ao.quantization.observer import PlaceholderObserver
30
+ from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
31
+ from torch.ao.quantization.quantizer import FixedQParamsQuantizationSpec
32
+ from torch.ao.quantization.quantizer import QuantizationSpec
33
+ from torch.ao.quantization.quantizer import Quantizer
34
+ from torch.fx import Node
35
+ import torch.nn.functional as F
36
+
37
+ from ai_edge_torch.quantize.pt2e_quantizer_utils import _convert_scalars_to_attrs # NOQA
38
+ from ai_edge_torch.quantize.pt2e_quantizer_utils import OP_TO_ANNOTATOR
39
+ from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorConfig
40
+ from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorPatternType
41
+ from ai_edge_torch.quantize.pt2e_quantizer_utils import propagate_annotation
42
+ from ai_edge_torch.quantize.pt2e_quantizer_utils import QuantizationConfig
43
+
44
+ __all__ = [
45
+ "PT2EQuantizer",
46
+ "get_symmetric_quantization_config",
47
+ ]
48
+
49
+
50
def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
  """Returns a fresh mapping from operator name to its supported patterns.

  Each pattern is a list of module classes / functional ops that may be
  matched together (e.g. a conv2d followed by a relu).
  """
  # Both conv and linear should be able to handle relu + hardtanh fusion since
  # those are clamp ops
  conv2d_patterns = [
      [conv, relu]
      for conv in (torch.nn.Conv2d, F.conv2d)
      for relu in (torch.nn.ReLU, F.relu)
  ]
  operator_table: Dict[str, List[OperatorPatternType]] = {}
  operator_table["conv2d"] = conv2d_patterns
  operator_table["linear"] = [[torch.nn.Linear], [F.linear]]
  operator_table["add"] = [[torch.add]]
  operator_table["max_pool2d"] = [[torch.nn.MaxPool2d], [F.max_pool2d]]
  operator_table["adaptive_avg_pool2d"] = [
      [torch.nn.AdaptiveAvgPool2d],
      [F.adaptive_avg_pool2d],
  ]
  return copy.deepcopy(operator_table)
69
+
70
+
71
def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
  """Pairs every symmetric quantization config with every operator pattern list.

  Four configs are used: the default, QAT, per-channel, and per-channel QAT.
  One `OperatorConfig` is produced per (config, pattern list) combination and
  a deep copy of the resulting list is returned.
  """
  symmetric_configs = [
      get_symmetric_quantization_config(),
      get_symmetric_quantization_config(is_qat=True),
      get_symmetric_quantization_config(is_per_channel=True),
      get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
  ]
  # The operator table is rebuilt once per config, as each call returns a
  # fresh copy.
  config_operator_pairs: List[OperatorConfig] = [
      OperatorConfig(config, patterns)
      for config in symmetric_configs
      for patterns in _supported_symmetric_quantized_operators().values()
  ]
  return copy.deepcopy(config_operator_pairs)
85
+
86
+
87
+ @functools.lru_cache
88
+ def get_symmetric_quantization_config(
89
+ is_per_channel: bool = False,
90
+ is_qat: bool = False,
91
+ is_dynamic: bool = False,
92
+ ):
93
+ if is_qat:
94
+ if is_dynamic:
95
+ raise NotImplementedError("dynamic quantization for qat is not yet implemented.")
96
+ act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
97
+ else:
98
+ if is_dynamic:
99
+ act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment]
100
+ else:
101
+ act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment]
102
+
103
+ act_quantization_spec = QuantizationSpec(
104
+ dtype=torch.int8,
105
+ quant_min=-128,
106
+ quant_max=127,
107
+ qscheme=torch.per_tensor_affine,
108
+ is_dynamic=is_dynamic,
109
+ observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
110
+ )
111
+ qscheme = (
112
+ torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
113
+ )
114
+ weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = MinMaxObserver
115
+ if is_qat:
116
+ weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
117
+ elif is_per_channel:
118
+ weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
119
+
120
+ extra_args: Dict[str, Any] = {"eps": 2**-12}
121
+ if is_qat:
122
+ if qscheme == torch.per_tensor_symmetric:
123
+ extra_args["observer"] = MovingAverageMinMaxObserver
124
+ else:
125
+ extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item]
126
+ weight_quantization_spec = QuantizationSpec(
127
+ dtype=torch.int8,
128
+ quant_min=-127,
129
+ quant_max=127,
130
+ qscheme=qscheme,
131
+ ch_axis=0,
132
+ is_dynamic=False,
133
+ observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
134
+ **extra_args
135
+ ),
136
+ )
137
+
138
+ bias_quantization_spec = None
139
+
140
+ # Some TFLite ops (e.g. Logistic, Softmax) have fixed qparams requirements
141
+ fixed_qparams_spec = FixedQParamsQuantizationSpec(
142
+ dtype=torch.int8,
143
+ scale=1 / 256,
144
+ zero_point=-128,
145
+ quant_min=-128,
146
+ quant_max=127,
147
+ qscheme=torch.per_tensor_affine,
148
+ )
149
+
150
+ if is_dynamic:
151
+ # Only valid for TFLite downstream to have no input activation quantization
152
+ # because dynamic quantization should be legalized to TFLite DRQ kernels
153
+ # which calculate quantization parameters during runtime inside the kernels
154
+ quantization_config = QuantizationConfig(
155
+ None,
156
+ None,
157
+ weight_quantization_spec,
158
+ bias_quantization_spec,
159
+ None,
160
+ is_qat,
161
+ True,
162
+ )
163
+ else:
164
+ quantization_config = QuantizationConfig(
165
+ act_quantization_spec,
166
+ act_quantization_spec,
167
+ weight_quantization_spec,
168
+ bias_quantization_spec,
169
+ fixed_qparams_spec,
170
+ is_qat,
171
+ False,
172
+ )
173
+ return quantization_config
174
+
175
+
176
def _get_supported_config_and_operators() -> List[OperatorConfig]:
  """Returns all supported (quantization config, operator patterns) pairs.

  Currently this is exactly the set of symmetric configurations.
  """
  supported = _get_supported_symmetric_config_and_operators()
  return supported
178
+
179
+
180
def _get_module_name_filter(module_name: str):
  """Get the module_name_filter function for a given module name, the filter accepts
  a node and checks if the node comes from a module that has certain module name

  For example:
    node: linear_op = call_function[...](...)  # comes from a module with name blocks.sub.linear1


  >> module_name_filter = _get_module_name_filter("blocks.sub")
  >> print(module_name_filter(node))
  True  # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"

  Args:
    module_name: Dotted module name to look for, e.g. "blocks.sub".

  Returns:
    A predicate taking an fx `Node` and returning True when `module_name` is
    one of the module names recorded in the node's `nn_module_stack` metadata.
  """
  key_prefix_len = len("L__self___")
  path_prefix = "L['self']."

  def module_name_filter(n: Node) -> bool:
    # example: {
    #    'L__self___sub': ("L['self'].sub", <class '....Sub'>),
    #    'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
    # }
    # get_attr nodes doesn't have nn_module_stack?
    nn_module_stack = n.meta.get("nn_module_stack", {})
    names = set()
    for key, entry in nn_module_stack.items():
      # Name reconstructed from the mangled key (original behavior). This maps
      # every "_" to ".", so it cannot represent module names that themselves
      # contain underscores.
      names.add(key[key_prefix_len:].replace("_", "."))
      # Also use the qualified path stored in the entry, which keeps
      # underscores intact, so modules such as "my_block" can be matched.
      if isinstance(entry, (tuple, list)) and entry and isinstance(entry[0], str):
        qualified = entry[0]
        if qualified.startswith(path_prefix):
          qualified = qualified[len(path_prefix):]
        names.add(qualified)
    return module_name in names

  return module_name_filter
204
+
205
+
206
def _get_module_type_filter(tp: Callable):
  """Builds a node predicate that tests membership in a module of type `tp`.

  The returned filter accepts a node and reports whether `tp` is among the
  module types recorded in the node's `nn_module_stack` metadata.

  For example:
    node: linear_op = call_function[...](...)  # comes from a module with type Block -> Sub -> Linear


  >> module_type_filter = _get_module_type_filter(Sub)  # submodule with type `Sub`, under the `Block` submodule
  >> print(module_type_filter(node))
  True  # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
  """

  def module_type_filter(n: Node) -> bool:
    # The stack maps a mangled key to a (qualified path, module type) pair,
    # for example:
    # {
    #    'L__self___sub': ("L['self'].sub", <class '....Sub'>),
    #    'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
    # }
    stack_entries = n.meta.get("nn_module_stack", {}).values()
    enclosing_types = []
    for _path, module_type in stack_entries:
      enclosing_types.append(module_type)
    return tp in enclosing_types

  return module_type_filter
229
+
230
+
231
def _get_not_module_type_or_name_filter(
    tp_list: List[Callable], module_name_list: List[str]
) -> Callable[[Node], bool]:
  """Builds a node predicate that rejects nodes matched by any given type or name.

  Args:
    tp_list: Module types to exclude.
    module_name_list: Module names to exclude.

  Returns:
    A predicate that is True only for nodes matched by none of the module
    type filters and none of the module name filters.
  """
  # Type filters are consulted before name filters.
  exclusion_filters = [_get_module_type_filter(tp) for tp in tp_list]
  exclusion_filters += [_get_module_name_filter(name) for name in module_name_list]

  def not_module_type_or_name_filter(n: Node) -> bool:
    return all(not matches(n) for matches in exclusion_filters)

  return not_module_type_or_name_filter
241
+
242
+
243
class PT2EQuantizer(Quantizer):
  """Quantizer that annotates an FX graph module with quantization configs.

  Configs can be registered globally, per module type, or per module name.
  `annotate` applies module-name configs first, then module-type configs, and
  finally the global config to every node not covered by the former two.
  Configs registered through `set_operator_type` are stored but are not
  consulted by `annotate` in this class.
  """

  supported_config_and_operators = _get_supported_config_and_operators()
  STATIC_QAT_ONLY_OPS = [
      "conv_bn_relu",
      "conv_bn",
  ]

  # static quantization ops (both PTQ and QAT)
  STATIC_OPS = [
      "linear",
      "addmm",
      "conv_relu",
      "conv",
      "adaptive_avg_pool2d",
      "gru_io_only",
      "max_pool2d",
      "add_relu",
      "add",
      "mul_relu",
      "mul",
      "cat",
      "fixed_qparams",
  ]

  DYNAMIC_OPS = [
      "linear",
      "addmm",
      "conv",
      "conv_relu",
  ]

  def __init__(self):
    super().__init__()
    self.global_config: Optional[QuantizationConfig] = None
    self.operator_type_config: Dict[
        torch._ops.OpOverloadPacket, Optional[QuantizationConfig]
    ] = {}
    self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
    self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}

  @classmethod
  def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
    """Returns the distinct quantization configs supported by this quantizer."""
    op_configs: Set[QuantizationConfig] = {
        spec for spec, _ in cls.supported_config_and_operators
    }
    return list(op_configs)

  @classmethod
  def get_supported_operator_for_quantization_config(
      cls, quantization_config: Optional[QuantizationConfig]
  ) -> List[OperatorPatternType]:
    """Returns the operator patterns supported for `quantization_config`.

    Args:
      quantization_config: The config to look up. When None, the patterns of
        all supported configs are concatenated and returned.

    Returns:
      The pattern list of the first supported entry whose config equals
      `quantization_config`, or an empty list if there is none.
    """
    if quantization_config is None:
      all_ops = []
      for _, ops in cls.supported_config_and_operators:
        all_ops.extend(ops)
      return all_ops

    for config, ops in cls.supported_config_and_operators:
      # note: this assumes each entry in cls.supported_config_and_operators
      # corresponds to one spec, e.g. we don't have
      # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
      # where the first and second entry have the same spec but did not
      # merge the op list
      if config == quantization_config:
        return ops
    return []

  def set_global(self, quantization_config: QuantizationConfig) -> PT2EQuantizer:
    """Sets the config applied to nodes not covered by a module type/name config."""
    self.global_config = quantization_config
    return self

  def set_operator_type(
      self,
      operator_type: torch._ops.OpOverloadPacket,
      quantization_config: QuantizationConfig,
  ) -> PT2EQuantizer:
    """Records `quantization_config` for the given operator type."""
    self.operator_type_config[operator_type] = quantization_config
    return self

  def set_module_type(
      self, module_type: Callable, quantization_config: QuantizationConfig
  ) -> PT2EQuantizer:
    """Set quantization_config for a submodule with type: `module_type`, for example:
    quantizer.set_module_type(Sub) or quantizer.set_module_type(nn.Linear), it will quantize all supported operator/operator
    patterns in the submodule with this module type with the given `quantization_config`
    """
    self.module_type_config[module_type] = quantization_config
    return self

  def set_module_name(
      self, module_name: str, quantization_config: Optional[QuantizationConfig]
  ) -> PT2EQuantizer:
    """Set quantization_config for a submodule with name: `module_name`, for example:
    quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
    patterns in the submodule with this module name with the given `quantization_config`
    """
    assert (
        quantization_config is not None
    ), " quantization_config == None is not supported yet"
    self.module_name_config[module_name] = quantization_config
    return self

  def transform_for_annotation(
      self, model: torch.fx.GraphModule
  ) -> torch.fx.GraphModule:
    """Transforms scalar values to tensor attributes"""
    return _convert_scalars_to_attrs(model)

  def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Annotates `model` with the registered quantization configs.

    The dynamic-quantization patterns are used when a global config is set and
    it has no input activation spec; otherwise the static patterns are used.
    Annotations are then propagated with `propagate_annotation`.
    """
    if self.global_config and not self.global_config.input_activation:  # type: ignore[union-attr]
      model = self._annotate_for_dynamic_quantization_config(model)
    else:
      model = self._annotate_for_static_quantization_config(model)
    propagate_annotation(model)
    return model

  def _annotate_all_static_patterns(
      self,
      model: torch.fx.GraphModule,
      quantization_config: Optional[QuantizationConfig],
      filter_fn: Optional[Callable[[Node], bool]] = None,
  ) -> torch.fx.GraphModule:
    """Runs every static-op annotator (QAT-only ones first when `is_qat`)."""
    if quantization_config is None:
      return model

    if quantization_config.is_qat:
      for op in self.STATIC_QAT_ONLY_OPS:
        OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
    for op in self.STATIC_OPS:
      OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
    return model

  def _annotate_all_dynamic_patterns(
      self,
      model: torch.fx.GraphModule,
      quantization_config: Optional[QuantizationConfig],
      filter_fn: Optional[Callable[[Node], bool]] = None,
  ) -> torch.fx.GraphModule:
    """Runs every dynamic-op annotator with the given config and filter."""
    if quantization_config is None:
      return model

    for op in self.DYNAMIC_OPS:
      OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
    return model

  def _annotate_by_scope(
      self,
      model: torch.fx.GraphModule,
      annotate_patterns: Callable[..., torch.fx.GraphModule],
  ) -> torch.fx.GraphModule:
    """Applies `annotate_patterns` per module name, per module type, then globally.

    Shared by the static and dynamic entry points, which differ only in the
    pattern annotator they use.
    """
    module_name_list = list(self.module_name_config.keys())
    for module_name, config in self.module_name_config.items():
      annotate_patterns(model, config, _get_module_name_filter(module_name))

    tp_list = list(self.module_type_config.keys())
    for module_type, config in self.module_type_config.items():
      annotate_patterns(model, config, _get_module_type_filter(module_type))

    # The global config only applies to nodes that were not matched by any of
    # the module type or module name configs above.
    annotate_patterns(
        model,
        self.global_config,
        _get_not_module_type_or_name_filter(tp_list, module_name_list),
    )
    return model

  def _annotate_for_static_quantization_config(
      self, model: torch.fx.GraphModule
  ) -> torch.fx.GraphModule:
    """Annotates `model` using the static quantization patterns."""
    return self._annotate_by_scope(model, self._annotate_all_static_patterns)

  def _annotate_for_dynamic_quantization_config(
      self, model: torch.fx.GraphModule
  ) -> torch.fx.GraphModule:
    """Annotates `model` using the dynamic quantization patterns."""
    return self._annotate_by_scope(model, self._annotate_all_dynamic_patterns)

  def validate(self, model: torch.fx.GraphModule) -> None:
    """No validation is performed."""
    pass

  @classmethod
  def get_supported_operators(cls) -> List[OperatorConfig]:
    """Returns the class-level list of supported (config, operators) entries."""
    return cls.supported_config_and_operators