ai-edge-torch-nightly 0.3.0.dev20250114__py3-none-any.whl
- ai_edge_torch/__init__.py +32 -0
- ai_edge_torch/_config.py +69 -0
- ai_edge_torch/_convert/__init__.py +14 -0
- ai_edge_torch/_convert/conversion.py +153 -0
- ai_edge_torch/_convert/conversion_utils.py +64 -0
- ai_edge_torch/_convert/converter.py +270 -0
- ai_edge_torch/_convert/fx_passes/__init__.py +23 -0
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +288 -0
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +131 -0
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +258 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +50 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +18 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +68 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +216 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +449 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +303 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py +64 -0
- ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +52 -0
- ai_edge_torch/_convert/signature.py +66 -0
- ai_edge_torch/_convert/test/__init__.py +14 -0
- ai_edge_torch/_convert/test/test_convert.py +558 -0
- ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
- ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
- ai_edge_torch/_convert/test/test_to_channel_last_io.py +96 -0
- ai_edge_torch/_convert/to_channel_last_io.py +92 -0
- ai_edge_torch/conftest.py +20 -0
- ai_edge_torch/debug/__init__.py +17 -0
- ai_edge_torch/debug/culprit.py +496 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +140 -0
- ai_edge_torch/debug/test/test_search_model.py +51 -0
- ai_edge_torch/debug/utils.py +59 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/fx_pass_base.py +110 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/amd_llama_135m/__init__.py +14 -0
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +87 -0
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +70 -0
- ai_edge_torch/generative/examples/amd_llama_135m/verify.py +72 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/gemma/gemma1.py +107 -0
- ai_edge_torch/generative/examples/gemma/gemma2.py +295 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +43 -0
- ai_edge_torch/generative/examples/gemma/verify_util.py +157 -0
- ai_edge_torch/generative/examples/llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +91 -0
- ai_edge_torch/generative/examples/llama/llama.py +196 -0
- ai_edge_torch/generative/examples/llama/verify.py +88 -0
- ai_edge_torch/generative/examples/moonshine/__init__.py +14 -0
- ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py +50 -0
- ai_edge_torch/generative/examples/moonshine/moonshine.py +103 -0
- ai_edge_torch/generative/examples/openelm/__init__.py +14 -0
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +127 -0
- ai_edge_torch/generative/examples/openelm/verify.py +71 -0
- ai_edge_torch/generative/examples/paligemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +95 -0
- ai_edge_torch/generative/examples/paligemma/decoder.py +151 -0
- ai_edge_torch/generative/examples/paligemma/decoder2.py +177 -0
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +160 -0
- ai_edge_torch/generative/examples/paligemma/paligemma.py +179 -0
- ai_edge_torch/generative/examples/paligemma/verify.py +161 -0
- ai_edge_torch/generative/examples/paligemma/verify_decoder.py +75 -0
- ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
- ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +99 -0
- ai_edge_torch/generative/examples/phi/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/phi/phi2.py +107 -0
- ai_edge_torch/generative/examples/phi/phi3.py +219 -0
- ai_edge_torch/generative/examples/phi/verify.py +64 -0
- ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
- ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +93 -0
- ai_edge_torch/generative/examples/qwen/qwen.py +134 -0
- ai_edge_torch/generative/examples/qwen/verify.py +88 -0
- ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +125 -0
- ai_edge_torch/generative/examples/smollm/verify.py +86 -0
- ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +108 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +185 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +173 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +398 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +749 -0
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +119 -0
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +254 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +62 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +66 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +74 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +39 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +111 -0
- ai_edge_torch/generative/examples/stable_diffusion/util.py +77 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +138 -0
- ai_edge_torch/generative/examples/t5/t5.py +655 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +246 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +156 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +138 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +88 -0
- ai_edge_torch/generative/examples/tiny_llama/verify.py +72 -0
- ai_edge_torch/generative/fx_passes/__init__.py +30 -0
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +50 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +399 -0
- ai_edge_torch/generative/layers/attention_utils.py +210 -0
- ai_edge_torch/generative/layers/builder.py +160 -0
- ai_edge_torch/generative/layers/feed_forward.py +120 -0
- ai_edge_torch/generative/layers/kv_cache.py +204 -0
- ai_edge_torch/generative/layers/lora.py +557 -0
- ai_edge_torch/generative/layers/model_config.py +238 -0
- ai_edge_torch/generative/layers/normalization.py +222 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +94 -0
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +144 -0
- ai_edge_torch/generative/layers/unet/__init__.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +806 -0
- ai_edge_torch/generative/layers/unet/builder.py +50 -0
- ai_edge_torch/generative/layers/unet/model_config.py +282 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/example.py +47 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +154 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +62 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +56 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/test_custom_dus.py +107 -0
- ai_edge_torch/generative/test/test_kv_cache.py +120 -0
- ai_edge_torch/generative/test/test_loader.py +83 -0
- ai_edge_torch/generative/test/test_lora.py +147 -0
- ai_edge_torch/generative/test/test_model_conversion.py +191 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +362 -0
- ai_edge_torch/generative/test/test_quantize.py +183 -0
- ai_edge_torch/generative/test/utils.py +82 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/converter.py +215 -0
- ai_edge_torch/generative/utilities/dynamic_update_slice.py +56 -0
- ai_edge_torch/generative/utilities/loader.py +398 -0
- ai_edge_torch/generative/utilities/model_builder.py +180 -0
- ai_edge_torch/generative/utilities/moonshine_loader.py +154 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +1032 -0
- ai_edge_torch/generative/utilities/t5_loader.py +512 -0
- ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
- ai_edge_torch/generative/utilities/verifier.py +335 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +153 -0
- ai_edge_torch/hlfb/mark_pattern/fx_utils.py +69 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +288 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +185 -0
- ai_edge_torch/lowertools/__init__.py +18 -0
- ai_edge_torch/lowertools/_shim.py +86 -0
- ai_edge_torch/lowertools/common_utils.py +142 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +260 -0
- ai_edge_torch/lowertools/test_utils.py +62 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +301 -0
- ai_edge_torch/lowertools/translate_recipe.py +163 -0
- ai_edge_torch/model.py +177 -0
- ai_edge_torch/odml_torch/__init__.py +20 -0
- ai_edge_torch/odml_torch/_torch_future.py +88 -0
- ai_edge_torch/odml_torch/_torch_library.py +19 -0
- ai_edge_torch/odml_torch/composite/__init__.py +16 -0
- ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
- ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
- ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
- ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
- ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
- ai_edge_torch/odml_torch/export.py +403 -0
- ai_edge_torch/odml_torch/export_utils.py +157 -0
- ai_edge_torch/odml_torch/jax_bridge/__init__.py +18 -0
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +180 -0
- ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +27 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +294 -0
- ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +243 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +285 -0
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +87 -0
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +177 -0
- ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
- ai_edge_torch/odml_torch/lowerings/context.py +42 -0
- ai_edge_torch/odml_torch/lowerings/decomp.py +69 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +65 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +201 -0
- ai_edge_torch/odml_torch/passes/__init__.py +38 -0
- ai_edge_torch/odml_torch/tf_integration.py +156 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +466 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1061 -0
- ai_edge_torch/quantize/quant_config.py +85 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +145 -0
- ai_edge_torch/version.py +16 -0
- ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/METADATA +44 -0
- ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/RECORD +213 -0
- ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.3.0.dev20250114.dist-info/top_level.txt +1 -0
+++ ai_edge_torch/quantize/pt2e_quantizer_utils.py
@@ -0,0 +1,1061 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass
import itertools
import operator
from typing import Callable, Dict, List, NamedTuple, Optional

import torch
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.pt2e.utils import _conv1d_bn_example_inputs
from torch.ao.quantization.pt2e.utils import _conv2d_bn_example_inputs
from torch.ao.quantization.pt2e.utils import _get_aten_graph_module_for_pattern
from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.ao.quantization.quantizer import QuantizationSpec
from torch.ao.quantization.quantizer import QuantizationSpecBase
from torch.ao.quantization.quantizer import SharedQuantizationSpec
from torch.ao.quantization.quantizer.utils import _annotate_input_qspec_map
from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
from torch.fx import Node
from torch.fx.passes.utils.matcher_with_name_node_map_utils import SubgraphMatcherWithNameNodeMap
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
import torch.nn.functional as F

__all__ = [
    "OperatorConfig",
    "OperatorPatternType",
    "QuantizationConfig",
    "get_input_act_qspec",
    "get_output_act_qspec",
    "get_weight_qspec",
    "get_bias_qspec",
    "OP_TO_ANNOTATOR",
    "propagate_annotation",
]


@dataclass(eq=True, frozen=True)
class QuantizationConfig:
  input_activation: Optional[QuantizationSpec]
  output_activation: Optional[QuantizationSpec]
  weight: Optional[QuantizationSpec]
  bias: Optional[QuantizationSpec]
  fixed_qparams: Optional[QuantizationSpec]
  # TODO: remove, since we can use observer_or_fake_quant_ctr to express this
  is_qat: bool = False
  is_dynamic: bool = False


OperatorPatternType = List[Callable]
OperatorPatternType.__module__ = "ai_edge_torch.quantize.pt2e_quantizer_utils"

AnnotatorType = Callable[
    [
        torch.fx.GraphModule,
        Optional[QuantizationConfig],
        Optional[Callable[[Node], bool]],
    ],
    Optional[List[List[Node]]],
]
OP_TO_ANNOTATOR: Dict[str, AnnotatorType] = {}


def register_annotator(op: str):
  def decorator(annotator: AnnotatorType):
    OP_TO_ANNOTATOR[op] = annotator

  return decorator


class OperatorConfig(NamedTuple):
  # fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]]
  # Basically we are mapping a quantization config to some list of patterns.
  # a pattern is defined as a list of nn module, function or builtin function names
  # e.g. [nn.Conv2d, torch.relu, torch.add]
  # We have not resolved whether fusion can be considered internal details of the
  # quantizer hence it does not need communication to user.
  # Note this pattern is not really informative since it does not really
  # tell us the graph structure resulting from the list of ops.
  config: QuantizationConfig
  operators: List[OperatorPatternType]


def _is_annotated(nodes: List[Node]):
  """Checks if a list of nodes is annotated.

  Given a list of nodes (that represents an operator pattern), check if any of
  the node is annotated, return True if any of the node
  is annotated, otherwise return False
  """
  annotated = False
  for node in nodes:
    annotated = annotated or (
        "quantization_annotation" in node.meta
        and node.meta["quantization_annotation"]._annotated
    )
  return annotated


def _mark_nodes_as_annotated(nodes: List[Node]):
  for node in nodes:
    if node is not None:
      if "quantization_annotation" not in node.meta:
        node.meta["quantization_annotation"] = QuantizationAnnotation()
      node.meta["quantization_annotation"]._annotated = True


def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]):
  if quantization_config is None:
    return None
  if quantization_config.input_activation is None:
    return None
  quantization_spec: QuantizationSpec = quantization_config.input_activation
  assert quantization_spec.qscheme in [
      torch.per_tensor_affine,
      torch.per_tensor_symmetric,
  ]
  return quantization_spec


def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]):
  if quantization_config is None:
    return None
  if quantization_config.output_activation is None:
    return None
  quantization_spec: QuantizationSpec = quantization_config.output_activation
  assert quantization_spec.qscheme in [
      torch.per_tensor_affine,
      torch.per_tensor_symmetric,
  ]
  return quantization_spec


def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
  if quantization_config is None:
    return None
  assert quantization_config is not None
  if quantization_config.weight is None:
    return None
  quantization_spec: QuantizationSpec = quantization_config.weight
  if quantization_spec.qscheme not in [
      torch.per_tensor_symmetric,
      torch.per_channel_symmetric,
  ]:
    raise ValueError(
        f"Unsupported quantization_spec {quantization_spec} for weight"
    )
  return quantization_spec


def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):
  if quantization_config is None:
    return None
  assert quantization_config is not None
  if quantization_config.bias is None:
    return None
  quantization_spec: QuantizationSpec = quantization_config.bias
  assert (
      quantization_spec.dtype == torch.float
  ), "Only float dtype for bias is supported for bias right now"
  return quantization_spec


def get_fixed_qparams_qspec(quantization_config: Optional[QuantizationConfig]):
  if quantization_config is None:
    return None
  assert quantization_config is not None
  if quantization_config.fixed_qparams is None:
    return None
  quantization_spec: QuantizationSpec = quantization_config.fixed_qparams
  return quantization_spec


@register_annotator("linear")
def _annotate_linear(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  annotated_partitions = []
  input_act_qspec = get_input_act_qspec(quantization_config)
  output_act_qspec = get_output_act_qspec(quantization_config)
  weight_qspec = get_weight_qspec(quantization_config)
  bias_qspec = get_bias_qspec(quantization_config)
  for node in gm.graph.nodes:
    if (
        node.op != "call_function"
        or node.target != torch.ops.aten.linear.default
    ):
      continue
    if filter_fn and not filter_fn(node):
      continue
    act_node = node.args[0]
    weight_node = node.args[1]
    bias_node = None
    if len(node.args) > 2:
      bias_node = node.args[2]

    if _is_annotated([node]) is False:  # type: ignore[list-item]
      _annotate_input_qspec_map(
          node,
          act_node,
          input_act_qspec,
      )
      _annotate_input_qspec_map(
          node,
          weight_node,
          weight_qspec,
      )
      nodes_to_mark_annotated = [node, weight_node]
      if bias_node:
        _annotate_input_qspec_map(
            node,
            bias_node,
            bias_qspec,
        )
        nodes_to_mark_annotated.append(bias_node)
      _annotate_output_qspec(node, output_act_qspec)
      _mark_nodes_as_annotated(nodes_to_mark_annotated)
      annotated_partitions.append(nodes_to_mark_annotated)

  return annotated_partitions


@register_annotator("addmm")
def _annotate_addmm(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  annotated_partitions = []
  for n in gm.graph.nodes:
    if n.op != "call_function" or n.target not in [
        torch.ops.aten.addmm.default,
    ]:
      continue
    addm_node = n

    input_qspec_map = {}
    input_act = addm_node.args[0]
    assert isinstance(input_act, Node)
    is_bias = (
        len(list(input_act.meta["val"].size())) < 2
        and input_act.op == "get_attr"
        and "_param_constant" in input_act.target
    )
    input_qspec_map[input_act] = (
        get_bias_qspec(quantization_config)
        if is_bias
        else get_input_act_qspec(quantization_config)
    )

    mat1_act = addm_node.args[1]
    assert isinstance(mat1_act, Node)
    input_qspec_map[mat1_act] = get_input_act_qspec(quantization_config)

    mat2_act = addm_node.args[2]
    assert isinstance(mat2_act, Node)
    is_weight = False
    if mat2_act.op == "get_attr" and "_param_constant" in mat2_act.target:
      is_weight = True
    elif mat2_act.target == torch.ops.aten.t.default:
      t_in = mat2_act.args[0]
      if t_in.op == "get_attr" and "_param_constant" in t_in.target:
        is_weight = True
    input_qspec_map[mat2_act] = (
        get_weight_qspec(quantization_config)
        if is_weight
        else get_input_act_qspec(quantization_config)
    )

    partition = [addm_node, addm_node.args[1], addm_node.args[2]]

    if _is_annotated(partition):
      continue

    if filter_fn and any(not filter_fn(n) for n in partition):
      continue

    addm_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=get_output_act_qspec(quantization_config),
        _annotated=True,
    )
    _mark_nodes_as_annotated(partition)
    annotated_partitions.append(partition)
  return annotated_partitions


@register_annotator("conv")
def _annotate_conv(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  annotated_partitions = []
  for n in gm.graph.nodes:
    if n.op != "call_function" or n.target not in [
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.convolution.default,
    ]:
      continue
    conv_node = n

    input_qspec_map = {}
    input_act = conv_node.args[0]
    assert isinstance(input_act, Node)
    input_qspec_map[input_act] = get_input_act_qspec(quantization_config)

    weight = conv_node.args[1]
    assert isinstance(weight, Node)
    input_qspec_map[weight] = get_weight_qspec(quantization_config)

    # adding weight node to the partition as well
    partition = [conv_node, conv_node.args[1]]

    bias = conv_node.args[2] if len(conv_node.args) > 2 else None
    if isinstance(bias, Node):
      input_qspec_map[bias] = get_bias_qspec(quantization_config)
      partition.append(bias)

    if _is_annotated(partition):
      continue

    if filter_fn and any(not filter_fn(n) for n in partition):
      continue

    conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=get_output_act_qspec(quantization_config),
        _annotated=True,
    )
    _mark_nodes_as_annotated(partition)
    annotated_partitions.append(partition)
  return annotated_partitions


@register_annotator("conv_relu")
def _annotate_conv_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  annotated_partitions = []
  for n in gm.graph.nodes:
    if n.op != "call_function" or n.target not in [
        torch.ops.aten.relu.default,
        torch.ops.aten.relu_.default,
    ]:
      continue
    relu_node = n
    maybe_conv_node = n.args[0]
    if (
        not isinstance(maybe_conv_node, Node)
        or maybe_conv_node.op != "call_function"
        or maybe_conv_node.target
        not in [
            torch.ops.aten.conv1d.default,
            torch.ops.aten.conv2d.default,
            torch.ops.aten.convolution.default,
        ]
    ):
      continue
    conv_node = maybe_conv_node

    input_qspec_map = {}
    input_act = conv_node.args[0]
    assert isinstance(input_act, Node)
    input_qspec_map[input_act] = get_input_act_qspec(quantization_config)

    weight = conv_node.args[1]
    assert isinstance(weight, Node)
    input_qspec_map[weight] = get_weight_qspec(quantization_config)

    # adding weight node to the partition as well
    partition = [relu_node, conv_node, conv_node.args[1]]
    bias = conv_node.args[2] if len(conv_node.args) > 2 else None
    if isinstance(bias, Node):
      input_qspec_map[bias] = get_bias_qspec(quantization_config)
      partition.append(bias)

    if _is_annotated(partition):
      continue

    if filter_fn and any(not filter_fn(n) for n in partition):
      continue

    conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map, _annotated=True
    )
    relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
        output_qspec=get_output_act_qspec(
            quantization_config
        ),  # type: ignore[arg-type]
        _annotated=True,
    )
    _mark_nodes_as_annotated(partition)
    annotated_partitions.append(partition)
  return annotated_partitions


@register_annotator("conv_bn")
def _annotate_conv_bn(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  """Find conv + batchnorm parititions Note: This is only used for QAT.

  In PTQ, batchnorm should already be fused into the conv.
  """
  return _do_annotate_conv_bn(
      gm, quantization_config, filter_fn, has_relu=False
  )


@register_annotator("conv_bn_relu")
def _annotate_conv_bn_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  """Find conv + batchnorm + relu parititions Note: This is only used for QAT.

  In PTQ, batchnorm should already be fused into the conv.
  """
  return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True)


def _do_annotate_conv_bn(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]],
    has_relu: bool,
) -> List[List[Node]]:
  """Given a function that takes in a `conv_fn` and returns a conv-bn[-relu] pattern,

  return a list of annotated partitions.

  The output of the pattern must include a dictionary from string name to node
  for the following names: "input", "conv", "weight", "bias", and "output".
  """

  def get_pattern(conv_fn: Callable, relu_is_inplace: bool):
    def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
      conv = conv_fn(x, conv_weight, conv_bias)
      bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True)
      if has_relu:
        output = F.relu_(bn) if relu_is_inplace else F.relu(bn)
      else:
        output = bn
      return output, {
          "input": x,
          "conv": conv,
          "weight": conv_weight,
          "bias": conv_bias,
          "output": output,
      }

    return _conv_bn

  # Needed for matching, otherwise the matches gets filtered out due to unused
  # nodes returned by batch norm
  gm.graph.eliminate_dead_code()
  gm.recompile()

  matches = []
  combinations = [
      (F.conv1d, _conv1d_bn_example_inputs),
      (F.conv2d, _conv2d_bn_example_inputs),
  ]

  # Add `is_cuda` and `relu_is_inplace` dimensions
  combinations = itertools.product(
      combinations,
      [True, False] if torch.cuda.is_available() else [False],  # is_cuda
      [True, False] if has_relu else [False],  # relu_is_inplace
  )

  # Match against all conv dimensions and cuda variants
  for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:
    pattern = get_pattern(conv_fn, relu_is_inplace)
    pattern = _get_aten_graph_module_for_pattern(
        pattern, example_inputs, is_cuda
    )
    pattern.graph.eliminate_dead_code()
    pattern.recompile()
    matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
    matches.extend(matcher.match(gm.graph))

  # Annotate nodes returned in the matches
  annotated_partitions = []
  for match in matches:
    name_node_map = match.name_node_map
    input_node = name_node_map["input"]
    conv_node = name_node_map["conv"]
    weight_node = name_node_map["weight"]
    bias_node = name_node_map["bias"]
    output_node = name_node_map["output"]

    # TODO: annotate the uses of input, weight, and bias separately instead
    # of assuming they come from a single conv node. This is not possible today
    # because input may have multiple users, and we can't rely on the conv node
    # always being the first user. This was the case in models with skip
    # connections like resnet18

    # Validate conv args
    if conv_node.args[0] is not input_node:
      raise ValueError("Conv arg did not contain input node ", input_node)
    if conv_node.args[1] is not weight_node:
      raise ValueError("Conv arg did not contain weight node ", weight_node)
    if len(conv_node.args) > 2 and conv_node.args[2] is not bias_node:
      raise ValueError("Conv arg did not contain bias node ", bias_node)

    # Skip if the partition is already annotated or is filtered out by the user
    partition = [conv_node, weight_node]
    if bias_node is not None:
      partition.append(bias_node)
    if _is_annotated(partition):
      continue
    if filter_fn and any(not filter_fn(n) for n in partition):
      continue

    # Annotate conv inputs and pattern output
    input_qspec_map = {}
    input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
    input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
    if bias_node is not None:
      input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
    conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        _annotated=True,
    )
    output_node.meta["quantization_annotation"] = QuantizationAnnotation(
        output_qspec=get_output_act_qspec(
            quantization_config
        ),  # type: ignore[arg-type]
        _annotated=True,
    )
    _mark_nodes_as_annotated(partition)
    annotated_partitions.append(partition)
  return annotated_partitions


@register_annotator("gru_io_only")
def _annotate_gru_io_only(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU], filter_fn)
  gru_partitions = list(itertools.chain(*gru_partitions.values()))
  annotated_partitions = []
  for gru_partition in gru_partitions:
    annotated_partitions.append(gru_partition.nodes)
    output_nodes = gru_partition.output_nodes
    input_nodes = gru_partition.input_nodes
    # skip annotation if it is already annotated
    if _is_annotated(input_nodes + output_nodes):
      continue
    # inside each GRU partition, we should be able to annotate each linear
    # subgraph
    input_qspec_map: Dict[Node, QuantizationSpecBase] = {}
    input_act = input_nodes[0]
    input_act_user = next(iter(input_act.users.keys()))
    assert isinstance(input_act, Node)
    assert isinstance(input_act_user, Node)
    input_act_user.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={
            input_act: get_input_act_qspec(quantization_config),
        },
        _annotated=True,
    )

    hidden_state = input_nodes[1]
    hidden_state_user = next(iter(hidden_state.users.keys()))
    assert isinstance(hidden_state, Node)
    assert isinstance(hidden_state_user, Node)
    hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={
            hidden_state: get_input_act_qspec(quantization_config),
        },
        _annotated=True,
    )

    assert len(output_nodes) == 2, "expecting GRU to have two outputs"
    for output in output_nodes:
      output.meta["quantization_annotation"] = QuantizationAnnotation(
          output_qspec=get_output_act_qspec(quantization_config),
          _annotated=True,
      )
    nodes_to_mark_annotated = list(gru_partition.nodes)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)
  return annotated_partitions


@register_annotator("max_pool2d")
def _annotate_max_pool2d(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  module_partitions = get_source_partitions(
      gm.graph, [torch.nn.MaxPool2d, torch.nn.functional.max_pool2d], filter_fn
  )
  maxpool_partitions = list(itertools.chain(*module_partitions.values()))
  annotated_partitions = []
  for maxpool_partition in maxpool_partitions:
    annotated_partitions.append(maxpool_partition.nodes)
    output_node = maxpool_partition.output_nodes[0]
    maxpool_node = None
    for n in maxpool_partition.nodes:
      if (
          n.target == torch.ops.aten.max_pool2d.default
          or n.target == torch.ops.aten.max_pool2d_with_indices.default
      ):
        maxpool_node = n

    assert (
        maxpool_node is not None
    ), "PT2EQuantizer only works with torch.ops.aten.max_pool2d.default, "
    "please make sure you are exporting the model correctly"
    if _is_annotated([output_node, maxpool_node]):  # type: ignore[list-item]
      continue

    input_act = maxpool_node.args[0]  # type: ignore[union-attr]
    assert isinstance(input_act, Node)

    # only annotate maxpool when the output of the input node is annotated
    if (
        "quantization_annotation" not in input_act.meta
        or not input_act.meta["quantization_annotation"]._annotated
        or input_act.meta["quantization_annotation"].output_qspec is None
    ):
      continue
    # input and output of maxpool will share quantization parameter with input of maxpool
    act_qspec = SharedQuantizationSpec(input_act)
    # act_qspec = get_act_qspec(quantization_config)
    maxpool_node.meta[
        # type: ignore[union-attr]
        "quantization_annotation"
    ] = QuantizationAnnotation(
        input_qspec_map={
            input_act: act_qspec,
        },
        _annotated=True,
    )
    output_node.meta["quantization_annotation"] = QuantizationAnnotation(
        output_qspec=act_qspec,
        _annotated=True,
    )
  return annotated_partitions


@register_annotator("adaptive_avg_pool2d")
def _annotate_adaptive_avg_pool2d(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  """Always annotate adaptive_avg_pool2d op"""
  module_partitions = get_source_partitions(
      gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn
  )
  partitions = list(itertools.chain(*module_partitions.values()))
  annotated_partitions = []
  for partition in partitions:
    pool_node = partition.output_nodes[0]
    if pool_node.op != "call_function" or (
        pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default
        and pool_node.target != torch.ops.aten._adaptive_avg_pool2d.default
        and pool_node.target != torch.ops.aten.mean.dim
        and pool_node.target != torch.ops.aten.as_strided_.default
    ):
      raise ValueError(
          f"{pool_node} is not an aten adaptive_avg_pool2d operator"
      )

    if _is_annotated([pool_node]):
      continue

    annotated_partitions.append(partition.nodes)
    input_act = pool_node.args[0]
    assert isinstance(input_act, Node)

    # only annotate input output sharing operator
    # when the output of the input node is annotated
    if (
        "quantization_annotation" not in input_act.meta
        or not input_act.meta["quantization_annotation"]._annotated
        or input_act.meta["quantization_annotation"].output_qspec is None
    ):
      input_act_qspec = get_input_act_qspec(quantization_config)
    else:
      input_act_qspec = SharedQuantizationSpec(input_act)

    # output sharing with input
    output_act_qspec = SharedQuantizationSpec((input_act, pool_node))
    pool_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={
            input_act: input_act_qspec,
        },
        output_qspec=output_act_qspec,
        _annotated=True,
    )
  return annotated_partitions


@register_annotator("fixed_qparams")
def _annotate_fixed_qparams(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  annotated_partitions = []
  for node in gm.graph.nodes:
    if node.op != "call_function" or (
        node.target != torch.ops.aten.sigmoid.default
        and node.target != torch.ops.aten._softmax.default
    ):
      continue

    input_act = node.args[0]  # type: ignore[union-attr]
    assert isinstance(input_act, Node)

    # only annotate when the output of the input node is annotated
    if (
        "quantization_annotation" not in input_act.meta
        or not input_act.meta["quantization_annotation"]._annotated
        or input_act.meta["quantization_annotation"].output_qspec is None
    ):
      continue
    partition = [node]

    if _is_annotated(partition):
      continue

    if filter_fn and any(not filter_fn(n) for n in partition):
      continue

    node.meta["quantization_annotation"] = QuantizationAnnotation(
        output_qspec=get_fixed_qparams_qspec(quantization_config),
        _annotated=True,
    )
    _mark_nodes_as_annotated(partition)
    annotated_partitions.append(partition)

  return annotated_partitions


@register_annotator("add_relu")
def _annotate_add_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  fused_partitions = find_sequential_partitions(
      gm, [torch.add, torch.nn.ReLU], filter_fn
  )
  annotated_partitions = []
  for fused_partition in fused_partitions:
    add_partition, relu_partition = fused_partition
    annotated_partitions.append(add_partition.nodes + relu_partition.nodes)
    if len(relu_partition.output_nodes) > 1:
      raise ValueError("Relu partition has more than one output node")
    relu_node = relu_partition.output_nodes[0]
    if len(add_partition.output_nodes) > 1:
      raise ValueError("add partition has more than one output node")
    add_node = add_partition.output_nodes[0]

    if _is_annotated([relu_node, add_node]):
      continue

    input_act_qspec = get_input_act_qspec(quantization_config)
    output_act_qspec = get_output_act_qspec(quantization_config)

    input_qspec_map = {}
    input_act0 = add_node.args[0]
    if isinstance(input_act0, Node):
      input_qspec_map[input_act0] = input_act_qspec

    input_act1 = add_node.args[1]
    if isinstance(input_act1, Node):
      input_qspec_map[input_act1] = input_act_qspec

    add_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        _annotated=True,
    )
    relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
        output_qspec=output_act_qspec,
        _annotated=True,
    )
  return annotated_partitions


@register_annotator("add")
def _annotate_add(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  add_partitions = get_source_partitions(
      gm.graph, [operator.add, torch.add, operator.iadd], filter_fn
  )
  add_partitions = list(itertools.chain(*add_partitions.values()))
  annotated_partitions = []
  for add_partition in add_partitions:
    annotated_partitions.append(add_partition.nodes)
    add_node = add_partition.output_nodes[0]
    if _is_annotated([add_node]):
      continue

    input_act_qspec = get_input_act_qspec(quantization_config)
    output_act_qspec = get_output_act_qspec(quantization_config)

    input_qspec_map = {}
    input_act0 = add_node.args[0]
    if isinstance(input_act0, Node):
      input_qspec_map[input_act0] = input_act_qspec

    input_act1 = add_node.args[1]
    if isinstance(input_act1, Node):
      input_qspec_map[input_act1] = input_act_qspec

    add_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=output_act_qspec,
        _annotated=True,
    )
  return annotated_partitions


@register_annotator("mul_relu")
def _annotate_mul_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  fused_partitions = find_sequential_partitions(
      gm, [torch.mul, torch.nn.ReLU], filter_fn
  )
  annotated_partitions = []
  for fused_partition in fused_partitions:
    mul_partition, relu_partition = fused_partition
    annotated_partitions.append(mul_partition.nodes + relu_partition.nodes)
    if len(relu_partition.output_nodes) > 1:
      raise ValueError("Relu partition has more than one output node")
    relu_node = relu_partition.output_nodes[0]
    if len(mul_partition.output_nodes) > 1:
      raise ValueError("mul partition has more than one output node")
    mul_node = mul_partition.output_nodes[0]

    if _is_annotated([relu_node, mul_node]):
      continue

    input_act_qspec = get_input_act_qspec(quantization_config)
    output_act_qspec = get_output_act_qspec(quantization_config)

    input_qspec_map = {}
    input_act0 = mul_node.args[0]
    if isinstance(input_act0, Node):
      input_qspec_map[input_act0] = input_act_qspec

    input_act1 = mul_node.args[1]
    if isinstance(input_act1, Node):
      input_qspec_map[input_act1] = input_act_qspec

    mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        _annotated=True,
    )
    relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
        output_qspec=output_act_qspec,
        _annotated=True,
    )
  return annotated_partitions


@register_annotator("mul")
def _annotate_mul(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  mul_partitions = get_source_partitions(
      gm.graph,
      ["mul", "mul_", operator.mul, torch.mul, operator.imul],
      filter_fn,
  )
  mul_partitions = list(itertools.chain(*mul_partitions.values()))
  annotated_partitions = []
  for mul_partition in mul_partitions:
    annotated_partitions.append(mul_partition.nodes)
    mul_node = mul_partition.output_nodes[0]
    if _is_annotated([mul_node]):
      continue

    input_act_qspec = get_input_act_qspec(quantization_config)
    output_act_qspec = get_output_act_qspec(quantization_config)

    input_qspec_map = {}
    input_act0 = mul_node.args[0]
    if isinstance(input_act0, Node):
      input_qspec_map[input_act0] = input_act_qspec

    input_act1 = mul_node.args[1]
    if isinstance(input_act1, Node):
      input_qspec_map[input_act1] = input_act_qspec

    mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=output_act_qspec,
        _annotated=True,
    )
  return annotated_partitions


# TODO: remove Optional in return type, fix annotated_partitions logic
@register_annotator("cat")
def _annotate_cat(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
  cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn)
  cat_partitions = list(itertools.chain(*cat_partitions.values()))
  annotated_partitions = []
  for cat_partition in cat_partitions:
    cat_node = cat_partition.output_nodes[0]
    if _is_annotated([cat_node]):
      continue

    if cat_node.target != torch.ops.aten.cat.default:
      raise Exception(
          "Expected cat node: torch.ops.aten.cat.default, but found"
          f" {cat_node.target} please check if you are calling the correct"
          " capture API"
      )

    annotated_partitions.append(cat_partition.nodes)

    input_act_qspec = get_input_act_qspec(quantization_config)
    inputs = cat_node.args[0]

    input_qspec_map = {}
    input_act0 = inputs[0]
    if isinstance(input_act0, Node):
      input_qspec_map[input_act0] = input_act_qspec

    shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node))
    for input_act in inputs[1:]:
      input_qspec_map[input_act] = shared_with_input0_qspec

    output_act_qspec = shared_with_input0_qspec

    cat_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=output_act_qspec,
        _annotated=True,
    )
  return annotated_partitions


def _is_share_obs_or_fq_op(op: Callable) -> bool:
  return op in [
      torch.ops.aten.hardtanh.default,
      torch.ops.aten.hardtanh_.default,
      torch.ops.aten.mean.default,
      torch.ops.aten.mean.dim,
      torch.ops.aten.permute.default,
      torch.ops.aten.permute_copy.default,
      torch.ops.aten.squeeze.dim,
      torch.ops.aten.squeeze_copy.dim,
      torch.ops.aten.adaptive_avg_pool2d.default,
      torch.ops.aten.view_copy.default,
      torch.ops.aten.view.default,
      torch.ops.aten.slice_copy.Tensor,
      torch.ops.aten.flatten.using_ints,
  ]


def propagate_annotation(model: torch.fx.GraphModule) -> None:
  for n in model.graph.nodes:
    if n.op != "call_function" or not _is_share_obs_or_fq_op(n.target):
      continue

    prev_node = n.args[0]
    if not isinstance(prev_node, Node):
      continue

    quantization_annotation = prev_node.meta.get(
        "quantization_annotation", None
    )
    if not quantization_annotation:
      continue

    output_qspec = quantization_annotation.output_qspec
    if not output_qspec:
      continue

    # make sure current node is not annotated
    if (
        "quantization_annotation" in n.meta
        and n.meta["quantization_annotation"]._annotated
    ):
      continue

    shared_qspec = SharedQuantizationSpec(prev_node)
    # propagate the previous output_qspec to the current node
    n.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={
            prev_node: shared_qspec,
        },
        output_qspec=shared_qspec,
        _annotated=True,
    )


# TODO: make the list of ops customizable
def _convert_scalars_to_attrs(
    model: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
  for n in model.graph.nodes:
    if n.op != "call_function" or n.target not in [
        torch.ops.aten.add.Tensor,
        torch.ops.aten.mul.Tensor,
    ]:
      continue
    args = list(n.args)
    new_args = []
    for i in range(len(args)):
      if isinstance(args[i], torch.fx.Node):
        new_args.append(args[i])
        continue
      prefix = "_tensor_constant_"
      get_new_attr_name = get_new_attr_name_with_prefix(prefix)
      tensor_constant_name = get_new_attr_name(model)
      model.register_buffer(tensor_constant_name, torch.tensor(args[i]))
      with model.graph.inserting_before(n):
        get_attr_node = model.graph.create_node(
            "get_attr", tensor_constant_name, (), {}
        )
      new_args.append(get_attr_node)
    n.args = tuple(new_args)
  model.recompile()
  return model
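
These utilities are consumed by `pt2e_quantizer.py` in the same package, but the `OP_TO_ANNOTATOR` registry can also be exercised directly. Below is a minimal, hypothetical sketch (not part of the wheel): the `TinyLinear` model, the all-`None` `QuantizationConfig`, and the use of `torch.export.export_for_training` are illustrative assumptions rather than the package's documented workflow; real quantization specs would normally come from the quantizer's default configs.

```python
# Hypothetical usage sketch -- not shipped in the wheel.
import torch
from ai_edge_torch.quantize import pt2e_quantizer_utils as qu


class TinyLinear(torch.nn.Module):
  """Toy model whose captured graph contains torch.ops.aten.linear.default."""

  def __init__(self):
    super().__init__()
    self.fc = torch.nn.Linear(4, 2)

  def forward(self, x):
    return self.fc(x)


# Pre-dispatch export keeps aten.linear.default in the graph, which is the
# target the "linear" annotator matches on.
gm = torch.export.export_for_training(
    TinyLinear(), (torch.randn(1, 4),)
).module()

# An all-None config only exercises the annotation bookkeeping; each field
# would normally be a QuantizationSpec.
config = qu.QuantizationConfig(None, None, None, None, None)

partitions = qu.OP_TO_ANNOTATOR["linear"](gm, config)
qu.propagate_annotation(gm)
print(f"annotated {len(partitions)} linear partition(s)")
```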