ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl
This diff shows the content of publicly available package versions as released to their public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/__init__.py +31 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +400 -0
- ai_edge_torch/convert/converter.py +202 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +311 -0
- ai_edge_torch/convert/test/test_convert_composites.py +192 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
- ai_edge_torch/convert/to_channel_last_io.py +85 -0
- ai_edge_torch/debug/__init__.py +17 -0
- ai_edge_torch/debug/culprit.py +464 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/test/test_search_model.py +50 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
- ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/fx_passes/__init__.py +31 -0
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +354 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +131 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +158 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
- ai_edge_torch/generative/layers/unet/__init__.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
- ai_edge_torch/generative/layers/unet/builder.py +47 -0
- ai_edge_torch/generative/layers/unet/model_config.py +269 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/loader_test.py +80 -0
- ai_edge_torch/generative/test/test_model_conversion.py +235 -0
- ai_edge_torch/generative/test/test_quantize.py +162 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +328 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
- ai_edge_torch/generative/utilities/t5_loader.py +483 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +142 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +81 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
ai_edge_torch/model.py
ADDED
@@ -0,0 +1,142 @@

# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Represents an ai_edge_torch model.

PyTorch models can be converted to this representation through `ai_edge_torch.convert`.
"""
from __future__ import annotations

import abc

import numpy as np
import numpy.typing as npt
import tensorflow as tf

from ai_edge_torch.convert import conversion_utils as cutils


class Model(abc.ABC):
  """Represents an edge model."""

  @abc.abstractmethod
  def __call__(
      self,
      *args: npt.ArrayLike,
      signature_name: str = cutils.DEFAULT_SIGNATURE_NAME,
      **kwargs,
  ) -> npt.ArrayLike | tuple[npt.ArrayLike]:
    raise NotImplementedError()

  @abc.abstractmethod
  def export(self, path: str):
    raise NotImplementedError()

  @staticmethod
  def load(path: str) -> TfLiteModel:
    tflite_model = TfLiteModel.load(path)
    if tflite_model:
      return tflite_model

    raise ValueError(f'File format in {path} cannot be deserialized.')


class TfLiteModel(Model):
  """An edge model which uses TFLite under the hood."""

  def __init__(self, tflite_model):
    """Initializes the TfLiteModel instance using a TFLite serialized object.

    Args:
      tflite_model: A TFLite serialized object.
    """
    self._tflite_model = tflite_model

  def __call__(
      self,
      *args: npt.ArrayLike,
      signature_name: str = cutils.DEFAULT_SIGNATURE_NAME,
      **kwargs,
  ) -> npt.ArrayLike | tuple[npt.ArrayLike]:
    """Runs inference on the edge model using the provided arguments.

    Args:
      *args: The positional arguments to be passed to the model for inference.
      **kwargs: The keyword arguments to be passed to the model for inference.
      signature_name: The name of the signature to be used for inference.
        The default signature is used if not provided.
    """
    interpreter = tf.lite.Interpreter(model_content=self._tflite_model)
    interpreter.allocate_tensors()

    signature_list = interpreter.get_signature_list()
    if signature_name not in signature_list:
      raise ValueError(
          f"Invalid signature name provided. Available signatures: {', '.join(signature_list.keys())}"
      )

    try:
      runner = interpreter.get_signature_runner(signature_name)
    except ValueError as exception:
      if 'Invalid signature_key provided.' in str(exception):
        raise ValueError(
            f'Invalid signature key provided. Available signatures: {list(signature_list.keys())}'
        )
      else:
        raise exception

    if len(signature_list[signature_name]['inputs']) != len(args) + len(kwargs):
      raise ValueError(
          f"The model requires {len(signature_list[signature_name]['inputs'])} arguments but {len(args) + len(kwargs)} were provided."
      )

    # Gather the input dictionary based on the signature.
    inputs = {f'args_{idx}': args[idx] for idx in range(len(args))}
    inputs = {**inputs, **kwargs}
    outputs = runner(**inputs)

    return (
        outputs['output_0']
        if len(outputs) == 1
        else [outputs[f'output_{idx}'] for idx in range(len(outputs))]
    )

  def export(self, path: str) -> None:
    """Serializes the edge model to disk.

    Args:
      path: The path to the file to which the model is serialized.
    """
    with open(path, 'wb') as file_handle:
      file_handle.write(self._tflite_model)

  @staticmethod
  def load(path: str) -> TfLiteModel | None:
    """Returns an edge (tflite) model by reading it from the disk.

    Args:
      path: The path to the model.
    """
    with open(path, 'rb') as file_handle:
      model_content = file_handle.read()

    # Check if this is indeed a tflite model:
    try:
      interpreter = tf.lite.Interpreter(model_content=model_content)
      interpreter.get_signature_list()
    except:
      return None

    return TfLiteModel(model_content)
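For orientation, here is a minimal usage sketch of the `Model`/`TfLiteModel` API above. It is not part of the diff; the 'model.tflite' path and the (1, 4) float32 input are made-up placeholders for whatever a converted model actually expects.

# Hypothetical usage sketch; path and input shape are illustrative only.
import numpy as np
from ai_edge_torch.model import TfLiteModel

edge_model = TfLiteModel.load('model.tflite')  # returns None if not a TFLite flatbuffer
assert edge_model is not None

# Positional args are mapped to 'args_0', 'args_1', ... in the signature inputs.
output = edge_model(np.zeros((1, 4), dtype=np.float32))

# Serialize the flatbuffer back to disk.
edge_model.export('model_copy.tflite')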
ai_edge_torch/quantize/__init__.py
ADDED
@@ -0,0 +1,16 @@

# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from .pt2e_quantizer import PT2EQuantizer
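Because of this re-export, the quantizer can be imported from the subpackage root as well as from its defining module. A trivial check (not part of the diff):

# Both import paths resolve to the same class.
from ai_edge_torch.quantize import PT2EQuantizer as FromPackage
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer as FromModule

assert FromPackage is FromModule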
ai_edge_torch/quantize/pt2e_quantizer.py
ADDED
@@ -0,0 +1,438 @@

# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import copy
import functools
from typing import Any, Callable, Dict, List, Optional, Set

import torch
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.observer import MovingAverageMinMaxObserver
from torch.ao.quantization.observer import MovingAveragePerChannelMinMaxObserver  # NOQA
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.observer import PlaceholderObserver
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer import FixedQParamsQuantizationSpec
from torch.ao.quantization.quantizer import QuantizationSpec
from torch.ao.quantization.quantizer import Quantizer
from torch.fx import Node
import torch.nn.functional as F

from ai_edge_torch.quantize.pt2e_quantizer_utils import _convert_scalars_to_attrs  # NOQA
from ai_edge_torch.quantize.pt2e_quantizer_utils import OP_TO_ANNOTATOR
from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorConfig
from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorPatternType
from ai_edge_torch.quantize.pt2e_quantizer_utils import propagate_annotation
from ai_edge_torch.quantize.pt2e_quantizer_utils import QuantizationConfig

__all__ = [
    "PT2EQuantizer",
    "get_symmetric_quantization_config",
]


def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
  supported_operators: Dict[str, List[OperatorPatternType]] = {
      # Both conv and linear should be able to handle relu + hardtanh fusion
      # since those are clamp ops
      "conv2d": [
          [torch.nn.Conv2d, torch.nn.ReLU],
          [torch.nn.Conv2d, F.relu],
          [F.conv2d, torch.nn.ReLU],
          [F.conv2d, F.relu],
      ],
      "linear": [[torch.nn.Linear], [F.linear]],
      "add": [[torch.add]],
      "max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
      "adaptive_avg_pool2d": [
          [torch.nn.AdaptiveAvgPool2d],
          [F.adaptive_avg_pool2d],
      ],
  }
  return copy.deepcopy(supported_operators)


def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
  supported_config_and_operators: List[OperatorConfig] = []
  for quantization_config in [
      get_symmetric_quantization_config(),
      get_symmetric_quantization_config(is_qat=True),
      get_symmetric_quantization_config(is_per_channel=True),
      get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
  ]:
    ops = _supported_symmetric_quantized_operators()
    for pattern_list in ops.values():
      supported_config_and_operators.append(
          OperatorConfig(quantization_config, pattern_list)
      )
  return copy.deepcopy(supported_config_and_operators)


@functools.lru_cache
def get_symmetric_quantization_config(
    is_per_channel: bool = False,
    is_qat: bool = False,
    is_dynamic: bool = False,
):
  if is_qat:
    if is_dynamic:
      raise NotImplementedError("dynamic quantization for qat is not yet implemented.")
    act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
  else:
    if is_dynamic:
      act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
    else:
      act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]

  act_quantization_spec = QuantizationSpec(
      dtype=torch.int8,
      quant_min=-128,
      quant_max=127,
      qscheme=torch.per_tensor_affine,
      is_dynamic=is_dynamic,
      observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
  )
  qscheme = (
      torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
  )
  weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = MinMaxObserver
  if is_qat:
    weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
  elif is_per_channel:
    weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver

  extra_args: Dict[str, Any] = {"eps": 2**-12}
  if is_qat:
    if qscheme == torch.per_tensor_symmetric:
      extra_args["observer"] = MovingAverageMinMaxObserver
    else:
      extra_args["observer"] = MovingAveragePerChannelMinMaxObserver  # type: ignore[dict-item]
  weight_quantization_spec = QuantizationSpec(
      dtype=torch.int8,
      quant_min=-127,
      quant_max=127,
      qscheme=qscheme,
      ch_axis=0,
      is_dynamic=False,
      observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
          **extra_args
      ),
  )

  bias_quantization_spec = None

  # Some TFLite ops (e.g. Logistic, Softmax) have fixed qparams requirements
  fixed_qparams_spec = FixedQParamsQuantizationSpec(
      dtype=torch.int8,
      scale=1 / 256,
      zero_point=-128,
      quant_min=-128,
      quant_max=127,
      qscheme=torch.per_tensor_affine,
  )

  if is_dynamic:
    # Only valid for TFLite downstream to have no input activation quantization
    # because dynamic quantization should be legalized to TFLite DRQ kernels
    # which calculate quantization parameters during runtime inside the kernels
    quantization_config = QuantizationConfig(
        None,
        None,
        weight_quantization_spec,
        bias_quantization_spec,
        None,
        is_qat,
        True,
    )
  else:
    quantization_config = QuantizationConfig(
        act_quantization_spec,
        act_quantization_spec,
        weight_quantization_spec,
        bias_quantization_spec,
        fixed_qparams_spec,
        is_qat,
        False,
    )
  return quantization_config


def _get_supported_config_and_operators() -> List[OperatorConfig]:
  return _get_supported_symmetric_config_and_operators()


def _get_module_name_filter(module_name: str):
  """Gets a module_name_filter function for a given module name. The filter
  accepts a node and checks whether the node comes from a module with that
  module name.

  For example:
    node: linear_op = call_function[...](...)  # comes from a module with name blocks.sub.linear1

    >> module_name_filter = _get_module_name_filter("blocks.sub")
    >> print(module_name_filter(node))
    True  # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
  """

  def module_name_filter(n: Node) -> bool:
    # example: {
    #     'L__self___sub': ("L['self'].sub", <class '....Sub'>),
    #     'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
    # }
    # get_attr nodes don't have nn_module_stack?
    nn_module_stack = n.meta.get("nn_module_stack", {})
    names = [n[len("L__self___") :].replace("_", ".") for n in nn_module_stack.keys()]
    return module_name in names

  return module_name_filter


def _get_module_type_filter(tp: Callable):
  """Gets a module_type_filter function for a given module type. The filter
  accepts a node and checks whether the node comes from a module of that
  type.

  For example:
    node: linear_op = call_function[...](...)  # comes from a module with type Block -> Sub -> Linear

    >> module_type_filter = _get_module_type_filter(Sub)  # submodule with type `Sub`, under the `Block` submodule
    >> print(module_type_filter(node))
    True  # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
  """

  def module_type_filter(n: Node) -> bool:
    # example: {
    #     'L__self___sub': ("L['self'].sub", <class '....Sub'>),
    #     'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
    # }
    nn_module_stack = n.meta.get("nn_module_stack", {})
    types = [t for _, t in nn_module_stack.values()]
    return tp in types

  return module_type_filter


def _get_not_module_type_or_name_filter(
    tp_list: List[Callable], module_name_list: List[str]
) -> Callable[[Node], bool]:
  module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
  module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]

  def not_module_type_or_name_filter(n: Node) -> bool:
    return not any(f(n) for f in module_type_filters + module_name_list_filters)

  return not_module_type_or_name_filter


class PT2EQuantizer(Quantizer):
  supported_config_and_operators = _get_supported_config_and_operators()
  STATIC_QAT_ONLY_OPS = [
      "conv_bn_relu",
      "conv_bn",
  ]

  # static quantization ops (both PTQ and QAT)
  STATIC_OPS = [
      "linear",
      "addmm",
      "conv_relu",
      "conv",
      "adaptive_avg_pool2d",
      "gru_io_only",
      "max_pool2d",
      "add_relu",
      "add",
      "mul_relu",
      "mul",
      "cat",
      "fixed_qparams",
  ]

  DYNAMIC_OPS = [
      "linear",
      "addmm",
      "conv",
      "conv_relu",
  ]

  def __init__(self):
    super().__init__()
    self.global_config: Optional[QuantizationConfig] = None
    self.operator_type_config: Dict[
        torch._ops.OpOverloadPacket, Optional[QuantizationConfig]
    ] = {}
    self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
    self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}

  @classmethod
  def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
    op_configs: Set[QuantizationConfig] = set({})
    for spec, _ in cls.supported_config_and_operators:
      op_configs.add(spec)
    return list(op_configs)

  @classmethod
  def get_supported_operator_for_quantization_config(
      cls, quantization_config: Optional[QuantizationConfig]
  ) -> List[OperatorPatternType]:
    if quantization_config is None:
      all_ops = []
      for _, ops in cls.supported_config_and_operators:
        all_ops.extend(ops)
      return all_ops

    for config, ops in cls.supported_config_and_operators:
      # note: this assumes each entry in cls.supported_spec_and_operators
      # corresponds to one spec, e.g. we don't have
      # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
      # where the first and second entry have the same spec but did not
      # merge the op list
      if config == quantization_config:
        return ops
    return []

  def set_global(self, quantization_config: QuantizationConfig) -> PT2EQuantizer:
    self.global_config = quantization_config
    return self

  def set_operator_type(
      self,
      operator_type: torch._ops.OpOverloadPacket,
      quantization_config: QuantizationConfig,
  ) -> PT2EQuantizer:
    self.operator_type_config[operator_type] = quantization_config
    return self

  def set_module_type(
      self, module_type: Callable, quantization_config: QuantizationConfig
  ):
    """Sets the quantization_config for submodules of type `module_type`, for
    example: quantizer.set_module_type(Sub) or quantizer.set_module_type(nn.Linear).
    All supported operator/operator patterns in submodules of this type are
    quantized with the given `quantization_config`.
    """
    self.module_type_config[module_type] = quantization_config
    return self

  def set_module_name(
      self, module_name: str, quantization_config: Optional[QuantizationConfig]
  ):
    """Sets the quantization_config for the submodule named `module_name`, for
    example: quantizer.set_module_name("blocks.sub"). All supported
    operator/operator patterns in the submodule with this name are quantized
    with the given `quantization_config`.
    """
    assert (
        quantization_config is not None
    ), "quantization_config == None is not supported yet"
    self.module_name_config[module_name] = quantization_config
    return self

  def transform_for_annotation(
      self, model: torch.fx.GraphModule
  ) -> torch.fx.GraphModule:
    """Transforms scalar values to tensor attributes"""
    return _convert_scalars_to_attrs(model)

  def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """just handling global spec for now"""
    if self.global_config and not self.global_config.input_activation:  # type: ignore[union-attr]
      model = self._annotate_for_dynamic_quantization_config(model)
    else:
      model = self._annotate_for_static_quantization_config(model)
    propagate_annotation(model)
    return model

  def _annotate_all_static_patterns(
      self,
      model: torch.fx.GraphModule,
      quantization_config: Optional[QuantizationConfig],
      filter_fn: Optional[Callable[[Node], bool]] = None,
  ) -> torch.fx.GraphModule:
    if quantization_config is None:
      return model

    if quantization_config.is_qat:
      for op in self.STATIC_QAT_ONLY_OPS:
        OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
    for op in self.STATIC_OPS:
      OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
    return model

  def _annotate_all_dynamic_patterns(
      self,
      model: torch.fx.GraphModule,
      quantization_config: Optional[QuantizationConfig],
      filter_fn: Optional[Callable[[Node], bool]] = None,
  ) -> torch.fx.GraphModule:
    if quantization_config is None:
      return model

    for op in self.DYNAMIC_OPS:
      OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
    return model

  def _annotate_for_static_quantization_config(
      self, model: torch.fx.GraphModule
  ) -> torch.fx.GraphModule:
    module_name_list = list(self.module_name_config.keys())
    for module_name, config in self.module_name_config.items():
      self._annotate_all_static_patterns(
          model, config, _get_module_name_filter(module_name)
      )

    tp_list = list(self.module_type_config.keys())
    for module_type, config in self.module_type_config.items():
      self._annotate_all_static_patterns(
          model, config, _get_module_type_filter(module_type)
      )

    self._annotate_all_static_patterns(
        model,
        self.global_config,
        _get_not_module_type_or_name_filter(tp_list, module_name_list),
    )
    return model

  def _annotate_for_dynamic_quantization_config(
      self, model: torch.fx.GraphModule
  ) -> torch.fx.GraphModule:
    module_name_list = list(self.module_name_config.keys())
    for module_name, config in self.module_name_config.items():
      self._annotate_all_dynamic_patterns(
          model, config, _get_module_name_filter(module_name)
      )

    tp_list = list(self.module_type_config.keys())
    for module_type, config in self.module_type_config.items():
      self._annotate_all_dynamic_patterns(
          model, config, _get_module_type_filter(module_type)
      )

    self._annotate_all_dynamic_patterns(
        model,
        self.global_config,
        _get_not_module_type_or_name_filter(tp_list, module_name_list),
    )
    return model

  def validate(self, model: torch.fx.GraphModule) -> None:
    pass

  @classmethod
  def get_supported_operators(cls) -> List[OperatorConfig]:
    return cls.supported_config_and_operators