ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/__init__.py +31 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +400 -0
- ai_edge_torch/convert/converter.py +202 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +311 -0
- ai_edge_torch/convert/test/test_convert_composites.py +192 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
- ai_edge_torch/convert/to_channel_last_io.py +85 -0
- ai_edge_torch/debug/__init__.py +17 -0
- ai_edge_torch/debug/culprit.py +464 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/test/test_search_model.py +50 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
- ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/fx_passes/__init__.py +31 -0
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +354 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +131 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +158 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
- ai_edge_torch/generative/layers/unet/__init__.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
- ai_edge_torch/generative/layers/unet/builder.py +47 -0
- ai_edge_torch/generative/layers/unet/model_config.py +269 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/loader_test.py +80 -0
- ai_edge_torch/generative/test/test_model_conversion.py +235 -0
- ai_edge_torch/generative/test/test_quantize.py +162 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +328 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
- ai_edge_torch/generative/utilities/t5_loader.py +483 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +142 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +81 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1041 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
import itertools
|
|
18
|
+
import operator
|
|
19
|
+
from typing import Callable, Dict, List, NamedTuple, Optional
|
|
20
|
+
|
|
21
|
+
import torch
|
|
22
|
+
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
|
|
23
|
+
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
|
|
24
|
+
from torch.ao.quantization.pt2e.utils import _conv1d_bn_example_inputs
|
|
25
|
+
from torch.ao.quantization.pt2e.utils import _conv2d_bn_example_inputs
|
|
26
|
+
from torch.ao.quantization.pt2e.utils import _get_aten_graph_module_for_pattern
|
|
27
|
+
from torch.ao.quantization.quantizer import QuantizationAnnotation
|
|
28
|
+
from torch.ao.quantization.quantizer import QuantizationSpec
|
|
29
|
+
from torch.ao.quantization.quantizer import QuantizationSpecBase
|
|
30
|
+
from torch.ao.quantization.quantizer import SharedQuantizationSpec
|
|
31
|
+
from torch.ao.quantization.quantizer.utils import _annotate_input_qspec_map
|
|
32
|
+
from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
|
|
33
|
+
from torch.fx import Node
|
|
34
|
+
from torch.fx.passes.utils.matcher_with_name_node_map_utils import SubgraphMatcherWithNameNodeMap # NOQA
|
|
35
|
+
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
|
|
36
|
+
import torch.nn.functional as F
|
|
37
|
+
|
|
38
|
+
# Public API of this module, grouped by role.
__all__ = [
    # Configuration types.
    "OperatorConfig", "OperatorPatternType", "QuantizationConfig",
    # QuantizationSpec accessors.
    "get_input_act_qspec", "get_output_act_qspec",
    "get_weight_qspec", "get_bias_qspec",
    # Annotator registry and annotation propagation.
    "OP_TO_ANNOTATOR", "propagate_annotation",
]
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass(eq=True, frozen=True)
|
|
52
|
+
class QuantizationConfig:
|
|
53
|
+
input_activation: Optional[QuantizationSpec]
|
|
54
|
+
output_activation: Optional[QuantizationSpec]
|
|
55
|
+
weight: Optional[QuantizationSpec]
|
|
56
|
+
bias: Optional[QuantizationSpec]
|
|
57
|
+
fixed_qparams: Optional[QuantizationSpec]
|
|
58
|
+
# TODO: remove, since we can use observer_or_fake_quant_ctr to express this
|
|
59
|
+
is_qat: bool = False
|
|
60
|
+
is_dynamic: bool = False
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# A pattern is an ordered list of callables (nn modules, functions, builtins),
# e.g. [nn.Conv2d, torch.relu, torch.add]; see OperatorConfig.
OperatorPatternType = List[Callable]
# Report this alias as belonging to this module (e.g. for docs/reprs).
# NOTE(review): `typing` caches subscripted generics, so this may also change
# `__module__` of other `List[Callable]` uses in the process — confirm.
OperatorPatternType.__module__ = "ai_edge_torch.quantize.pt2e_quantizer_utils"
|
|
65
|
+
|
|
66
|
+
# Signature shared by every registered annotator:
#   (graph module, optional config, optional node filter)
#     -> list of annotated partitions (each a list of nodes), or None.
AnnotatorType = Callable[
    [torch.fx.GraphModule, Optional[QuantizationConfig], Optional[Callable[[Node], bool]]],
    Optional[List[List[Node]]],
]

# Registry filled by `register_annotator`: operator name -> annotator function.
OP_TO_ANNOTATOR: Dict[str, AnnotatorType] = {}
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def register_annotator(op: str):
    """Build a decorator that registers an annotator under ``op``.

    The decorated function is stored in ``OP_TO_ANNOTATOR[op]``, replacing any
    earlier entry for the same key. It is then returned unchanged, so the
    module-level name still refers to the function.

    Args:
      op: Registry key, e.g. ``"linear"`` or ``"conv_bn_relu"``.

    Returns:
      A decorator that registers its argument and returns it.
    """

    def decorator(annotator: AnnotatorType) -> AnnotatorType:
        OP_TO_ANNOTATOR[op] = annotator
        # Return the annotator so `@register_annotator(...)` does not rebind
        # the decorated name (e.g. `_annotate_linear`) to None.
        return annotator

    return decorator
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class OperatorConfig(NamedTuple):
    """Pairs a QuantizationConfig with the operator patterns it applies to.

    Each pattern is a list of nn modules, functions or builtin functions,
    e.g. ``[nn.Conv2d, torch.relu, torch.add]``. Such a flat list only names
    the ops involved; it says nothing about how they are connected in the
    graph.
    """

    # TODO: tighten `operators` to
    # List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]].
    # It is still open whether fusions are an internal detail of the
    # quantizer or something that must be communicated to the user.
    config: QuantizationConfig
    operators: List[OperatorPatternType]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _is_annotated(nodes: List[Node]):
    """Report whether any node of an operator pattern is already annotated.

    A node counts as annotated when its ``meta`` holds a
    ``"quantization_annotation"`` entry whose ``_annotated`` flag is set.

    Args:
      nodes: The nodes that make up one operator pattern.

    Returns:
      True if at least one node is annotated, otherwise False.
    """
    return any(
        "quantization_annotation" in candidate.meta
        and candidate.meta["quantization_annotation"]._annotated
        for candidate in nodes
    )
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _mark_nodes_as_annotated(nodes: List[Node]):
|
|
113
|
+
for node in nodes:
|
|
114
|
+
if node is not None:
|
|
115
|
+
if "quantization_annotation" not in node.meta:
|
|
116
|
+
node.meta["quantization_annotation"] = QuantizationAnnotation()
|
|
117
|
+
node.meta["quantization_annotation"]._annotated = True
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]):
    """Return the input-activation spec of ``quantization_config``.

    Args:
      quantization_config: Config to read from, or None.

    Returns:
      ``quantization_config.input_activation``, or None when the config or
      its input-activation spec is None.

    Raises:
      ValueError: If the spec's qscheme is neither per-tensor affine nor
        per-tensor symmetric.
    """
    if quantization_config is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.input_activation
    if quantization_spec is None:
        return None
    # Raise instead of assert so the check survives `python -O`; this mirrors
    # the validation done in `get_weight_qspec`.
    if quantization_spec.qscheme not in [
        torch.per_tensor_affine,
        torch.per_tensor_symmetric,
    ]:
        raise ValueError(
            f"Unsupported quantization_spec {quantization_spec} for input activation"
        )
    return quantization_spec
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]):
    """Return the output-activation spec of ``quantization_config``.

    Args:
      quantization_config: Config to read from, or None.

    Returns:
      ``quantization_config.output_activation``, or None when the config or
      its output-activation spec is None.

    Raises:
      ValueError: If the spec's qscheme is neither per-tensor affine nor
        per-tensor symmetric.
    """
    if quantization_config is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.output_activation
    if quantization_spec is None:
        return None
    # Raise instead of assert so the check survives `python -O`; this mirrors
    # the validation done in `get_weight_qspec`.
    if quantization_spec.qscheme not in [
        torch.per_tensor_affine,
        torch.per_tensor_symmetric,
    ]:
        raise ValueError(
            f"Unsupported quantization_spec {quantization_spec} for output activation"
        )
    return quantization_spec
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
    """Return the weight spec of ``quantization_config``.

    Returns None when the config or its weight spec is None.

    Raises:
      ValueError: If the spec's qscheme is neither per-tensor nor per-channel
        symmetric.
    """
    spec = None if quantization_config is None else quantization_config.weight
    if spec is None:
        return None
    supported_qschemes = (torch.per_tensor_symmetric, torch.per_channel_symmetric)
    if spec.qscheme not in supported_qschemes:
        raise ValueError(f"Unsupported quantization_spec {spec} for weight")
    return spec
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):
    """Return the bias spec of ``quantization_config``.

    Args:
      quantization_config: Config to read from, or None.

    Returns:
      ``quantization_config.bias``, or None when the config or its bias spec
      is None.

    Raises:
      ValueError: If the bias spec's dtype is not ``torch.float``.
    """
    if quantization_config is None:
        return None
    quantization_spec: QuantizationSpec = quantization_config.bias
    if quantization_spec is None:
        return None
    # Raise instead of assert so the dtype check is not stripped by `python -O`.
    if quantization_spec.dtype != torch.float:
        raise ValueError("Only float dtype for bias is supported for bias right now")
    return quantization_spec
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def get_fixed_qparams_qspec(quantization_config: Optional[QuantizationConfig]):
    """Return ``quantization_config.fixed_qparams``, or None without a config.

    No validation is applied; a None ``fixed_qparams`` is returned as None.
    """
    if quantization_config is None:
        return None
    return quantization_config.fixed_qparams
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
@register_annotator("linear")
def _annotate_linear(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Annotate every ``aten.linear`` node of ``gm`` for quantization.

    For each linear node that passes ``filter_fn`` and is not yet annotated:
    ``args[0]`` gets the input-activation spec, ``args[1]`` the weight spec,
    ``args[2]`` (when present and truthy) the bias spec, and the node's
    output gets the output-activation spec.

    Returns:
      One ``[linear, weight]`` (plus ``bias`` when annotated) list per
      annotated node.
    """
    act_spec = get_input_act_qspec(quantization_config)
    out_spec = get_output_act_qspec(quantization_config)
    w_spec = get_weight_qspec(quantization_config)
    b_spec = get_bias_qspec(quantization_config)

    partitions = []
    for linear in gm.graph.nodes:
        is_linear = (
            linear.op == "call_function"
            and linear.target == torch.ops.aten.linear.default
        )
        if not is_linear:
            continue
        # Unlike the other annotators, the filter only sees the linear node.
        if filter_fn and not filter_fn(linear):
            continue
        act = linear.args[0]
        weight = linear.args[1]
        bias = linear.args[2] if len(linear.args) > 2 else None

        if _is_annotated([linear]) is not False:  # type: ignore[list-item]
            continue
        _annotate_input_qspec_map(linear, act, act_spec)
        _annotate_input_qspec_map(linear, weight, w_spec)
        marked = [linear, weight]
        if bias:
            _annotate_input_qspec_map(linear, bias, b_spec)
            marked.append(bias)
        _annotate_output_qspec(linear, out_spec)
        _mark_nodes_as_annotated(marked)
        partitions.append(marked)

    return partitions
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
@register_annotator("addmm")
def _annotate_addmm(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Annotate every ``aten.addmm`` node of ``gm`` for quantization.

    For ``addmm(input, mat1, mat2)``:

    * ``input`` gets the bias spec if it is a ``get_attr`` parameter constant
      of rank < 2, and the input-activation spec otherwise;
    * ``mat1`` gets the input-activation spec;
    * ``mat2`` gets the weight spec if it is a ``get_attr`` parameter
      constant, either directly or behind ``aten.t``, and the
      input-activation spec otherwise;
    * the addmm output gets the output-activation spec.

    A node is skipped when its partition ``[addmm, mat1, mat2]`` is already
    annotated, or when ``filter_fn`` rejects any member of it.

    Args:
      gm: Graph module whose nodes are annotated in place.
      quantization_config: Source of the specs; None yields None specs.
      filter_fn: Optional predicate that every partition member must satisfy.

    Returns:
      One ``[addmm, mat1, mat2]`` list per annotated node.
    """
    annotated_partitions = []
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target not in [
            torch.ops.aten.addmm.default,
        ]:
            continue
        addmm_node = node

        input_qspec_map = {}
        input_act = addmm_node.args[0]
        assert isinstance(input_act, Node)
        # Check the structural conditions first. meta["val"] is then read only
        # for parameter constants, so an activation input without "val"
        # metadata no longer raises KeyError.
        is_bias = (
            input_act.op == "get_attr"
            and "_param_constant" in input_act.target
            and "val" in input_act.meta
            and len(input_act.meta["val"].size()) < 2
        )
        input_qspec_map[input_act] = (
            get_bias_qspec(quantization_config)
            if is_bias
            else get_input_act_qspec(quantization_config)
        )

        mat1_act = addmm_node.args[1]
        assert isinstance(mat1_act, Node)
        input_qspec_map[mat1_act] = get_input_act_qspec(quantization_config)

        mat2_act = addmm_node.args[2]
        assert isinstance(mat2_act, Node)
        # mat2 is a weight when it is a parameter constant, possibly seen
        # through a transpose (`aten.t`).
        is_weight = False
        if mat2_act.op == "get_attr" and "_param_constant" in mat2_act.target:
            is_weight = True
        elif mat2_act.target == torch.ops.aten.t.default:
            t_in = mat2_act.args[0]
            if t_in.op == "get_attr" and "_param_constant" in t_in.target:
                is_weight = True
        input_qspec_map[mat2_act] = (
            get_weight_qspec(quantization_config)
            if is_weight
            else get_input_act_qspec(quantization_config)
        )

        partition = [addmm_node, mat1_act, mat2_act]

        if _is_annotated(partition):
            continue

        if filter_fn and any(not filter_fn(member) for member in partition):
            continue

        addmm_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=get_output_act_qspec(quantization_config),
            _annotated=True,
        )
        _mark_nodes_as_annotated(partition)
        annotated_partitions.append(partition)
    return annotated_partitions
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
@register_annotator("conv")
def _annotate_conv(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Annotate ``aten.conv1d`` / ``conv2d`` / ``convolution`` nodes of ``gm``.

    ``args[0]`` gets the input-activation spec, ``args[1]`` the weight spec,
    ``args[2]`` (only when it is a Node) the bias spec, and the conv output
    gets the output-activation spec. A conv is skipped when its partition is
    already annotated or any partition member is rejected by ``filter_fn``.

    Returns:
      One ``[conv, weight]`` (plus ``bias`` when it is a Node) list per
      annotated conv.
    """
    conv_targets = (
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.convolution.default,
    )
    partitions = []
    for conv_node in gm.graph.nodes:
        if conv_node.op != "call_function" or conv_node.target not in conv_targets:
            continue

        qspecs = {}
        act = conv_node.args[0]
        assert isinstance(act, Node)
        qspecs[act] = get_input_act_qspec(quantization_config)

        weight = conv_node.args[1]
        assert isinstance(weight, Node)
        qspecs[weight] = get_weight_qspec(quantization_config)

        # The weight node is part of the partition too.
        members = [conv_node, weight]
        if len(conv_node.args) > 2 and isinstance(conv_node.args[2], Node):
            bias = conv_node.args[2]
            qspecs[bias] = get_bias_qspec(quantization_config)
            members.append(bias)

        if _is_annotated(members):
            continue
        if filter_fn and not all(filter_fn(member) for member in members):
            continue

        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=qspecs,
            output_qspec=get_output_act_qspec(quantization_config),
            _annotated=True,
        )
        _mark_nodes_as_annotated(members)
        partitions.append(members)
    return partitions
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
@register_annotator("conv_relu")
def _annotate_conv_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Annotate ``conv -> relu`` chains in ``gm``.

    The conv node receives the input, weight and (Node-only) bias specs but
    no output spec. The relu (``aten.relu`` or ``aten.relu_``) node receives
    the output-activation spec.

    Returns:
      One ``[relu, conv, weight]`` (plus ``bias`` when it is a Node) list per
      annotated chain.
    """
    relu_targets = (torch.ops.aten.relu.default, torch.ops.aten.relu_.default)
    conv_targets = (
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.convolution.default,
    )
    partitions = []
    for relu_node in gm.graph.nodes:
        if relu_node.op != "call_function" or relu_node.target not in relu_targets:
            continue
        conv_node = relu_node.args[0]
        producer_is_conv = (
            isinstance(conv_node, Node)
            and conv_node.op == "call_function"
            and conv_node.target in conv_targets
        )
        if not producer_is_conv:
            continue

        qspecs = {}
        act = conv_node.args[0]
        assert isinstance(act, Node)
        qspecs[act] = get_input_act_qspec(quantization_config)

        weight = conv_node.args[1]
        assert isinstance(weight, Node)
        qspecs[weight] = get_weight_qspec(quantization_config)

        # The weight node is part of the partition too.
        members = [relu_node, conv_node, weight]
        if len(conv_node.args) > 2 and isinstance(conv_node.args[2], Node):
            bias = conv_node.args[2]
            qspecs[bias] = get_bias_qspec(quantization_config)
            members.append(bias)

        if _is_annotated(members):
            continue
        if filter_fn and not all(filter_fn(member) for member in members):
            continue

        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=qspecs, _annotated=True
        )
        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
            _annotated=True,
        )
        _mark_nodes_as_annotated(members)
        partitions.append(members)
    return partitions
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
@register_annotator("conv_bn")
def _annotate_conv_bn(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Find and annotate conv + batchnorm partitions (no trailing relu).

    Note: This is only used for QAT. In PTQ, batchnorm should already be
    fused into the conv.
    """
    return _do_annotate_conv_bn(
        gm, quantization_config, filter_fn=filter_fn, has_relu=False
    )
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
@register_annotator("conv_bn_relu")
def _annotate_conv_bn_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Find and annotate conv + batchnorm + relu partitions.

    Note: This is only used for QAT. In PTQ, batchnorm should already be
    fused into the conv.
    """
    return _do_annotate_conv_bn(
        gm, quantization_config, filter_fn=filter_fn, has_relu=True
    )
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
def _do_annotate_conv_bn(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]],
    has_relu: bool,
) -> List[List[Node]]:
    """
    Match conv-bn[-relu] subgraphs in `gm` and annotate them.

    A conv -> batch_norm(training=True) [-> relu / relu_] pattern is built
    for each of F.conv1d and F.conv2d. It is traced for CPU, and also for
    CUDA when CUDA is available. Each traced pattern is matched against
    `gm.graph`.

    The pattern returns a dictionary from string name to node for the
    following names: "input", "conv", "weight", "bias", and "output". For
    every match, the conv node gets input/weight/bias specs and the pattern
    output node gets the output-activation spec.

    Side effect: dead code is eliminated from `gm.graph` and `gm` is
    recompiled before matching.

    Args:
      gm: Graph module to match against and annotate in place.
      quantization_config: Source of the specs.
      filter_fn: Optional predicate that every partition member must satisfy.
      has_relu: Whether the pattern ends with a relu after the batch norm.

    Returns:
      One `[conv, weight]` (plus `bias` when not None) list per annotated
      match.

    Raises:
      ValueError: If a matched conv node's args do not line up with the
        matched input / weight / bias nodes.
    """

    # Builds the Python pattern function for one conv flavor. `has_relu` is
    # captured from the enclosing call.
    def get_pattern(conv_fn: Callable, relu_is_inplace: bool):
        def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
            conv = conv_fn(x, conv_weight, conv_bias)
            bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True)
            if has_relu:
                output = F.relu_(bn) if relu_is_inplace else F.relu(bn)
            else:
                output = bn
            # The dict names the nodes that the matcher exposes via
            # `match.name_node_map`.
            return output, {
                "input": x,
                "conv": conv,
                "weight": conv_weight,
                "bias": conv_bias,
                "output": output,
            }

        return _conv_bn

    # Needed for matching, otherwise the matches gets filtered out due to unused
    # nodes returned by batch norm
    gm.graph.eliminate_dead_code()
    gm.recompile()

    matches = []
    combinations = [
        (F.conv1d, _conv1d_bn_example_inputs),
        (F.conv2d, _conv2d_bn_example_inputs),
    ]

    # Add `is_cuda` and `relu_is_inplace` dimensions
    combinations = itertools.product(
        combinations,
        [True, False] if torch.cuda.is_available() else [False],  # is_cuda
        [True, False] if has_relu else [False],  # relu_is_inplace
    )

    # Match against all conv dimensions and cuda variants
    for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:
        pattern = get_pattern(conv_fn, relu_is_inplace)
        # Presumably traces the Python pattern into an aten-level GraphModule
        # using `example_inputs` (on CUDA when `is_cuda`) — see torch's
        # pt2e utils.
        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)
        pattern.graph.eliminate_dead_code()
        pattern.recompile()
        # ignore_literals=True lets matches differ in literal arguments.
        matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
        matches.extend(matcher.match(gm.graph))

    # Annotate nodes returned in the matches
    annotated_partitions = []
    for match in matches:
        name_node_map = match.name_node_map
        input_node = name_node_map["input"]
        conv_node = name_node_map["conv"]
        weight_node = name_node_map["weight"]
        bias_node = name_node_map["bias"]
        output_node = name_node_map["output"]

        # TODO: annotate the uses of input, weight, and bias separately instead
        # of assuming they come from a single conv node. This is not possible today
        # because input may have multiple users, and we can't rely on the conv node
        # always being the first user. This was the case in models with skip
        # connections like resnet18

        # Validate conv args
        # NOTE(review): these ValueErrors pass (message, node) as two args, so
        # str(exc) renders as a tuple rather than a formatted message.
        if conv_node.args[0] is not input_node:
            raise ValueError("Conv arg did not contain input node ", input_node)
        if conv_node.args[1] is not weight_node:
            raise ValueError("Conv arg did not contain weight node ", weight_node)
        if len(conv_node.args) > 2 and conv_node.args[2] is not bias_node:
            raise ValueError("Conv arg did not contain bias node ", bias_node)

        # Skip if the partition is already annotated or is filtered out by the user
        # NOTE(review): output_node is not part of `partition`, so it is
        # neither checked by `_is_annotated` nor passed to the filter.
        partition = [conv_node, weight_node]
        if bias_node is not None:
            partition.append(bias_node)
        if _is_annotated(partition):
            continue
        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        # Annotate conv inputs and pattern output
        input_qspec_map = {}
        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
        input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
        if bias_node is not None:
            input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            _annotated=True,
        )
        # The output spec goes on the last pattern node (bn or relu), not on
        # the conv.
        output_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=get_output_act_qspec(
                quantization_config
            ),  # type: ignore[arg-type]
            _annotated=True,
        )
        _mark_nodes_as_annotated(partition)
        annotated_partitions.append(partition)
    return annotated_partitions
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
@register_annotator("gru_io_only")
def _annotate_gru_io_only(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  """Annotates only the boundary (inputs/outputs) of each `torch.nn.GRU`.

  For every GRU source partition:
    * the first user of the input activation (`input_nodes[0]`) and the first
      user of the hidden state (`input_nodes[1]`) receive an input-activation
      qspec for that input;
    * each of the (exactly two) partition outputs receives an
      output-activation qspec;
    * all partition nodes are then marked as annotated.

  Args:
    gm: Graph module whose nodes are annotated in place (via `node.meta`).
    quantization_config: Config the input/output qspecs are derived from.
    filter_fn: Optional node predicate forwarded to `get_source_partitions`.

  Returns:
    The node lists of all GRU partitions found, including partitions that
    were skipped because they were already annotated.

  Raises:
    AssertionError: if a partition input or its first user is not a `Node`,
      or if the partition does not have exactly two outputs.
  """
  gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU], filter_fn)
  gru_partitions = list(itertools.chain(*gru_partitions.values()))
  annotated_partitions = []
  for gru_partition in gru_partitions:
    annotated_partitions.append(gru_partition.nodes)
    output_nodes = gru_partition.output_nodes
    input_nodes = gru_partition.input_nodes
    # Skip annotation if it is already annotated.
    if _is_annotated(input_nodes + output_nodes):
      continue

    # Annotate the entry points of the input activation and the hidden state.
    # Only the first user of each is touched.
    for boundary_input in (input_nodes[0], input_nodes[1]):
      # Check the type *before* dereferencing `.users`, so a malformed
      # partition trips this assertion rather than an AttributeError.
      assert isinstance(boundary_input, Node)
      first_user = next(iter(boundary_input.users.keys()))
      assert isinstance(first_user, Node)
      first_user.meta["quantization_annotation"] = QuantizationAnnotation(
          input_qspec_map={
              boundary_input: get_input_act_qspec(quantization_config),
          },
          _annotated=True,
      )

    assert len(output_nodes) == 2, "expecting GRU to have two outputs"
    for output in output_nodes:
      output.meta["quantization_annotation"] = QuantizationAnnotation(
          output_qspec=get_output_act_qspec(quantization_config),
          _annotated=True,
      )
    nodes_to_mark_annotated = list(gru_partition.nodes)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)
  return annotated_partitions
|
|
599
|
+
|
|
600
|
+
|
|
601
|
+
@register_annotator("max_pool2d")
def _annotate_max_pool2d(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  """Annotates max_pool2d so that its input and output share qparams.

  A max-pool partition is annotated only when the node feeding it already
  carries an `_annotated` annotation with a non-None `output_qspec`. Both the
  max-pool input and the partition output then use
  `SharedQuantizationSpec(input_act)`. `quantization_config` is not read by
  this annotator.

  Args:
    gm: Graph module whose nodes are annotated in place (via `node.meta`).
    quantization_config: Unused; kept for the common annotator signature.
    filter_fn: Optional node predicate forwarded to `get_source_partitions`.

  Returns:
    The node lists of all max-pool partitions found, including partitions
    that were skipped.

  Raises:
    AssertionError: if a partition contains no aten max_pool2d /
      max_pool2d_with_indices node, or the pool's first argument is not a
      `Node`.
  """
  module_partitions = get_source_partitions(
      gm.graph, [torch.nn.MaxPool2d, torch.nn.functional.max_pool2d], filter_fn
  )
  maxpool_partitions = list(itertools.chain(*module_partitions.values()))
  annotated_partitions = []
  for maxpool_partition in maxpool_partitions:
    annotated_partitions.append(maxpool_partition.nodes)
    output_node = maxpool_partition.output_nodes[0]
    maxpool_node = None
    for n in maxpool_partition.nodes:
      if (
          n.target == torch.ops.aten.max_pool2d.default
          or n.target == torch.ops.aten.max_pool2d_with_indices.default
      ):
        maxpool_node = n

    # The message is parenthesized so both literals are concatenated. Written
    # as two lines without parentheses, the second literal was a separate
    # no-op statement and was silently dropped from the message.
    assert maxpool_node is not None, (
        "PT2EQuantizer only works with torch.ops.aten.max_pool2d.default, "
        "please make sure you are exporting the model correctly"
    )
    if _is_annotated([output_node, maxpool_node]):  # type: ignore[list-item]
      continue

    input_act = maxpool_node.args[0]  # type: ignore[union-attr]
    assert isinstance(input_act, Node)

    # Only annotate maxpool when the output of the input node is annotated.
    if (
        "quantization_annotation" not in input_act.meta
        or not input_act.meta["quantization_annotation"]._annotated
        or input_act.meta["quantization_annotation"].output_qspec is None
    ):
      continue
    # Input and output of maxpool share the quantization parameters of the
    # maxpool's input.
    act_qspec = SharedQuantizationSpec(input_act)
    maxpool_node.meta["quantization_annotation"] = QuantizationAnnotation(  # type: ignore[union-attr]
        input_qspec_map={
            input_act: act_qspec,
        },
        _annotated=True,
    )
    output_node.meta["quantization_annotation"] = QuantizationAnnotation(
        output_qspec=act_qspec,
        _annotated=True,
    )
  return annotated_partitions
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
@register_annotator("adaptive_avg_pool2d")
def _annotate_adaptive_avg_pool2d(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  """Always annotate adaptive_avg_pool2d op.

  The pool's input gets `SharedQuantizationSpec(input_act)` when the producer
  of that input already has an `_annotated` annotation with a non-None
  `output_qspec`. Otherwise it gets a fresh input-activation qspec from
  `quantization_config`. The pool output always shares with the
  (input_act, pool_node) edge.

  Args:
    gm: Graph module whose nodes are annotated in place (via `node.meta`).
    quantization_config: Config used when the input is not yet annotated.
    filter_fn: Optional node predicate forwarded to `get_source_partitions`.

  Returns:
    Node lists of the partitions annotated by this call (already-annotated
    pools are not included).

  Raises:
    ValueError: if a partition's output node is not a call_function with one
      of the accepted aten targets.
  """
  supported_targets = (
      torch.ops.aten.adaptive_avg_pool2d.default,
      torch.ops.aten._adaptive_avg_pool2d.default,
      torch.ops.aten.mean.dim,
      torch.ops.aten.as_strided_.default,
  )
  partitions_by_source = get_source_partitions(
      gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn
  )
  candidates = list(itertools.chain.from_iterable(partitions_by_source.values()))

  result = []
  for candidate in candidates:
    pool = candidate.output_nodes[0]
    if pool.op != "call_function" or pool.target not in supported_targets:
      raise ValueError(f"{pool} is not an aten adaptive_avg_pool2d operator")

    if _is_annotated([pool]):
      continue

    result.append(candidate.nodes)
    producer = pool.args[0]
    assert isinstance(producer, Node)

    # Share with the producer only if the producer's output is already
    # annotated; otherwise start from a fresh input-activation qspec.
    producer_output_annotated = (
        "quantization_annotation" in producer.meta
        and producer.meta["quantization_annotation"]._annotated
        and producer.meta["quantization_annotation"].output_qspec is not None
    )
    if producer_output_annotated:
      pool_input_qspec = SharedQuantizationSpec(producer)
    else:
      pool_input_qspec = get_input_act_qspec(quantization_config)

    pool.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={producer: pool_input_qspec},
        # The output shares qparams with the pool's input edge.
        output_qspec=SharedQuantizationSpec((producer, pool)),
        _annotated=True,
    )
  return result
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
@register_annotator("fixed_qparams")
def _annotate_fixed_qparams(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  """Annotates aten sigmoid / _softmax nodes with a fixed-qparams output spec.

  A node is annotated only when:
    * its first argument's producer already has an `_annotated` annotation
      with a non-None `output_qspec`;
    * the node itself is not yet annotated; and
    * `filter_fn` (if given) accepts it.

  Args:
    gm: Graph module whose nodes are annotated in place (via `node.meta`).
    quantization_config: Config passed to `get_fixed_qparams_qspec`.
    filter_fn: Optional predicate; nodes it rejects are skipped.

  Returns:
    One single-node list per node annotated by this call.
  """
  fixed_qparams_targets = (
      torch.ops.aten.sigmoid.default,
      torch.ops.aten._softmax.default,
  )
  result = []
  for node in gm.graph.nodes:
    if node.op != "call_function" or node.target not in fixed_qparams_targets:
      continue

    producer = node.args[0]  # type: ignore[union-attr]
    assert isinstance(producer, Node)

    # Only annotate when the output of the producer is already annotated.
    producer_output_annotated = (
        "quantization_annotation" in producer.meta
        and producer.meta["quantization_annotation"]._annotated
        and producer.meta["quantization_annotation"].output_qspec is not None
    )
    if not producer_output_annotated:
      continue

    group = [node]
    if _is_annotated(group):
      continue
    if filter_fn and not all(filter_fn(member) for member in group):
      continue

    node.meta["quantization_annotation"] = QuantizationAnnotation(
        output_qspec=get_fixed_qparams_qspec(quantization_config),
        _annotated=True,
    )
    _mark_nodes_as_annotated(group)
    result.append(group)

  return result
|
|
750
|
+
|
|
751
|
+
|
|
752
|
+
@register_annotator("add_relu")
def _annotate_add_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  """Annotates `torch.add` followed by `torch.nn.ReLU` as one fused pattern.

  The add node's Node-valued first two operands get the input-activation
  qspec. The add node itself gets no output qspec; the ReLU node carries the
  output-activation qspec.

  Args:
    gm: Graph module whose nodes are annotated in place (via `node.meta`).
    quantization_config: Config the input/output qspecs are derived from.
    filter_fn: Optional node predicate forwarded to
      `find_sequential_partitions`.

  Returns:
    Concatenated add+relu node lists for every matched pair, including pairs
    that were skipped because they were already annotated.

  Raises:
    ValueError: if either partition has more than one output node.
  """
  matched_pairs = find_sequential_partitions(
      gm, [torch.add, torch.nn.ReLU], filter_fn
  )
  result = []
  for add_part, relu_part in matched_pairs:
    result.append(add_part.nodes + relu_part.nodes)
    if len(relu_part.output_nodes) > 1:
      raise ValueError("Relu partition has more than one output node")
    relu = relu_part.output_nodes[0]
    if len(add_part.output_nodes) > 1:
      raise ValueError("add partition has more than one output node")
    add = add_part.output_nodes[0]

    if _is_annotated([relu, add]):
      continue

    in_qspec = get_input_act_qspec(quantization_config)
    out_qspec = get_output_act_qspec(quantization_config)

    lhs, rhs = add.args[0], add.args[1]
    add.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={
            operand: in_qspec
            for operand in (lhs, rhs)
            if isinstance(operand, Node)
        },
        _annotated=True,
    )
    relu.meta["quantization_annotation"] = QuantizationAnnotation(
        output_qspec=out_qspec,
        _annotated=True,
    )
  return result
|
|
796
|
+
|
|
797
|
+
|
|
798
|
+
@register_annotator("add")
def _annotate_add(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  """Annotates add nodes (`operator.add`, `torch.add`, `operator.iadd`).

  Each Node-valued operand among the first two args gets the input-activation
  qspec, and the add node gets the output-activation qspec.

  Args:
    gm: Graph module whose nodes are annotated in place (via `node.meta`).
    quantization_config: Config the input/output qspecs are derived from.
    filter_fn: Optional node predicate forwarded to `get_source_partitions`.

  Returns:
    Node lists of all add partitions found, including partitions skipped
    because they were already annotated.
  """
  partitions_by_source = get_source_partitions(
      gm.graph, [operator.add, torch.add, operator.iadd], filter_fn
  )
  result = []
  for partition in list(itertools.chain.from_iterable(partitions_by_source.values())):
    result.append(partition.nodes)
    add = partition.output_nodes[0]
    if _is_annotated([add]):
      continue

    in_qspec = get_input_act_qspec(quantization_config)
    out_qspec = get_output_act_qspec(quantization_config)

    lhs, rhs = add.args[0], add.args[1]
    add.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={
            operand: in_qspec
            for operand in (lhs, rhs)
            if isinstance(operand, Node)
        },
        output_qspec=out_qspec,
        _annotated=True,
    )
  return result
|
|
833
|
+
|
|
834
|
+
|
|
835
|
+
@register_annotator("mul_relu")
def _annotate_mul_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  """Annotates `torch.mul` followed by `torch.nn.ReLU` as one fused pattern.

  The mul node's Node-valued first two operands get the input-activation
  qspec. The mul node itself gets no output qspec; the ReLU node carries the
  output-activation qspec.

  Args:
    gm: Graph module whose nodes are annotated in place (via `node.meta`).
    quantization_config: Config the input/output qspecs are derived from.
    filter_fn: Optional node predicate forwarded to
      `find_sequential_partitions`.

  Returns:
    Concatenated mul+relu node lists for every matched pair, including pairs
    that were skipped because they were already annotated.

  Raises:
    ValueError: if either partition has more than one output node.
  """
  matched_pairs = find_sequential_partitions(
      gm, [torch.mul, torch.nn.ReLU], filter_fn
  )
  result = []
  for mul_part, relu_part in matched_pairs:
    result.append(mul_part.nodes + relu_part.nodes)
    if len(relu_part.output_nodes) > 1:
      raise ValueError("Relu partition has more than one output node")
    relu = relu_part.output_nodes[0]
    if len(mul_part.output_nodes) > 1:
      raise ValueError("mul partition has more than one output node")
    mul = mul_part.output_nodes[0]

    if _is_annotated([relu, mul]):
      continue

    in_qspec = get_input_act_qspec(quantization_config)
    out_qspec = get_output_act_qspec(quantization_config)

    lhs, rhs = mul.args[0], mul.args[1]
    mul.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={
            operand: in_qspec
            for operand in (lhs, rhs)
            if isinstance(operand, Node)
        },
        _annotated=True,
    )
    relu.meta["quantization_annotation"] = QuantizationAnnotation(
        output_qspec=out_qspec,
        _annotated=True,
    )
  return result
|
|
879
|
+
|
|
880
|
+
|
|
881
|
+
@register_annotator("mul")
def _annotate_mul(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  """Annotates mul nodes ("mul", "mul_", `operator.mul`, `torch.mul`, `operator.imul`).

  Each Node-valued operand among the first two args gets the input-activation
  qspec, and the mul node gets the output-activation qspec.

  Args:
    gm: Graph module whose nodes are annotated in place (via `node.meta`).
    quantization_config: Config the input/output qspecs are derived from.
    filter_fn: Optional node predicate forwarded to `get_source_partitions`.

  Returns:
    Node lists of all mul partitions found, including partitions skipped
    because they were already annotated.
  """
  partitions_by_source = get_source_partitions(
      gm.graph, ["mul", "mul_", operator.mul, torch.mul, operator.imul], filter_fn
  )
  result = []
  for partition in list(itertools.chain.from_iterable(partitions_by_source.values())):
    result.append(partition.nodes)
    mul = partition.output_nodes[0]
    if _is_annotated([mul]):
      continue

    in_qspec = get_input_act_qspec(quantization_config)
    out_qspec = get_output_act_qspec(quantization_config)

    lhs, rhs = mul.args[0], mul.args[1]
    mul.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={
            operand: in_qspec
            for operand in (lhs, rhs)
            if isinstance(operand, Node)
        },
        output_qspec=out_qspec,
        _annotated=True,
    )
  return result
|
|
916
|
+
|
|
917
|
+
|
|
918
|
+
# TODO: remove Optional in return type, fix annotated_partitions logic
@register_annotator("cat")
def _annotate_cat(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  """Annotates aten.cat so that all inputs and the output share one qparam set.

  The first concatenated tensor gets a fresh input-activation qspec (if it is
  a Node). Every other input, and the cat output, share with the
  (input0, cat_node) edge.

  Args:
    gm: Graph module whose nodes are annotated in place (via `node.meta`).
    quantization_config: Config the first input's qspec is derived from.
    filter_fn: Optional node predicate forwarded to `get_source_partitions`.

  Returns:
    Node lists of the cat partitions annotated by this call (already-annotated
    cats are not included).

  Raises:
    ValueError: if a cat partition's output node is not
      `torch.ops.aten.cat.default`. This is a subclass of the previously
      raised `Exception`, so existing handlers still catch it.
  """
  cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn)
  cat_partitions = list(itertools.chain(*cat_partitions.values()))
  annotated_partitions = []
  for cat_partition in cat_partitions:
    cat_node = cat_partition.output_nodes[0]
    if _is_annotated([cat_node]):
      continue

    if cat_node.target != torch.ops.aten.cat.default:
      raise ValueError(
          f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target}"
          " please check if you are calling the correct capture API"
      )

    annotated_partitions.append(cat_partition.nodes)

    input_act_qspec = get_input_act_qspec(quantization_config)
    inputs = cat_node.args[0]

    input_qspec_map = {}
    input_act0 = inputs[0]
    if isinstance(input_act0, Node):
      input_qspec_map[input_act0] = input_act_qspec

    shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node))
    for input_act in inputs[1:]:
      # The same tensor may be concatenated more than once (e.g.
      # `torch.cat([x, x])`). Overwriting an existing entry would replace
      # input0's concrete qspec with a spec that shares with input0 itself,
      # leaving no concrete qspec to derive qparams from.
      if input_act not in input_qspec_map:
        input_qspec_map[input_act] = shared_with_input0_qspec

    output_act_qspec = shared_with_input0_qspec

    cat_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=output_act_qspec,
        _annotated=True,
    )
  return annotated_partitions
|
|
961
|
+
|
|
962
|
+
|
|
963
|
+
def _is_share_obs_or_fq_op(op: Callable) -> bool:
  """Returns whether `op` is in the fixed set of observer-sharing aten ops.

  The set covers hardtanh(_), mean (default/dim), permute(_copy),
  squeeze(_copy).dim, adaptive_avg_pool2d, view(_copy), slice_copy.Tensor and
  flatten.using_ints. Membership is tested with `in` over an ordered
  collection, exactly as a list lookup would.
  """
  observer_sharing_ops = (
      torch.ops.aten.hardtanh.default,
      torch.ops.aten.hardtanh_.default,
      torch.ops.aten.mean.default,
      torch.ops.aten.mean.dim,
      torch.ops.aten.permute.default,
      torch.ops.aten.permute_copy.default,
      torch.ops.aten.squeeze.dim,
      torch.ops.aten.squeeze_copy.dim,
      torch.ops.aten.adaptive_avg_pool2d.default,
      torch.ops.aten.view_copy.default,
      torch.ops.aten.view.default,
      torch.ops.aten.slice_copy.Tensor,
      torch.ops.aten.flatten.using_ints,
  )
  return op in observer_sharing_ops
|
|
979
|
+
|
|
980
|
+
|
|
981
|
+
def propagate_annotation(model: torch.fx.GraphModule) -> None:
  """Copies producer output qspecs onto observer-sharing ops, in place.

  For each call_function node whose target satisfies
  `_is_share_obs_or_fq_op`, if its first argument is a Node with a truthy
  annotation that has a truthy `output_qspec`, and the node itself is not
  already `_annotated`, the node's input (from that producer) and its output
  are both set to `SharedQuantizationSpec(producer)`.

  Args:
    model: Graph module whose node `meta` dicts are updated in place.
  """
  for node in model.graph.nodes:
    if node.op != "call_function":
      continue
    if not _is_share_obs_or_fq_op(node.target):
      continue

    producer = node.args[0]
    if not isinstance(producer, Node):
      continue

    producer_annotation = producer.meta.get("quantization_annotation", None)
    if not producer_annotation or not producer_annotation.output_qspec:
      continue

    # Leave nodes that were already annotated untouched.
    already_annotated = (
        "quantization_annotation" in node.meta
        and node.meta["quantization_annotation"]._annotated
    )
    if already_annotated:
      continue

    # Both the incoming edge and the output reuse the producer's qparams.
    share_with_producer = SharedQuantizationSpec(producer)
    node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={producer: share_with_producer},
        output_qspec=share_with_producer,
        _annotated=True,
    )
|
|
1014
|
+
|
|
1015
|
+
|
|
1016
|
+
# TODO: make the list of ops customizable
def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
  """Replaces scalar operands of aten add/mul with registered tensor buffers.

  For every `aten.add.Tensor` / `aten.mul.Tensor` call_function node, each
  positional argument that is not an fx Node is:
    1. wrapped with `torch.tensor(...)`;
    2. registered as a buffer named `_tensor_constant_<i>` (first free index);
    3. read back through a `get_attr` node inserted just before the op.
  The graph is recompiled afterwards.

  Args:
    model: Graph module to rewrite in place.

  Returns:
    The same `model` object, after `recompile()`.
  """
  # The generated name function is stateless (it probes `model` for the first
  # free index), so it is created once rather than once per scalar operand.
  get_new_attr_name = get_new_attr_name_with_prefix("_tensor_constant_")
  for n in model.graph.nodes:
    if n.op != "call_function" or n.target not in [
        torch.ops.aten.add.Tensor,
        torch.ops.aten.mul.Tensor,
    ]:
      continue
    new_args = []
    for arg in n.args:
      if isinstance(arg, torch.fx.Node):
        new_args.append(arg)
        continue
      tensor_constant_name = get_new_attr_name(model)
      model.register_buffer(tensor_constant_name, torch.tensor(arg))
      with model.graph.inserting_before(n):
        get_attr_node = model.graph.create_node(
            "get_attr", tensor_constant_name, (), {}
        )
      new_args.append(get_attr_node)
    n.args = tuple(new_args)
  model.recompile()
  return model
|