tico-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +42 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +191 -0
- tico/passes/cast_mixed_type_args.py +187 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +160 -0
- tico/passes/convert_layout_op_to_reshape.py +85 -0
- tico/passes/convert_repeat_to_expand_copy.py +89 -0
- tico/passes/convert_to_relu6.py +181 -0
- tico/passes/decompose_addmm.py +124 -0
- tico/passes/decompose_batch_norm.py +192 -0
- tico/passes/decompose_fake_quantize.py +134 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
- tico/passes/decompose_group_norm.py +275 -0
- tico/passes/decompose_grouped_conv2d.py +209 -0
- tico/passes/decompose_slice_scatter.py +169 -0
- tico/passes/extract_dtype_kwargs.py +122 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +108 -0
- tico/passes/legalize_predefined_layout_operators.py +386 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
- tico/passes/lower_to_slice.py +230 -0
- tico/passes/merge_consecutive_cat.py +80 -0
- tico/passes/ops.py +78 -0
- tico/passes/remove_nop.py +84 -0
- tico/passes/remove_redundant_assert_nodes.py +51 -0
- tico/passes/remove_redundant_expand.py +66 -0
- tico/passes/remove_redundant_permute.py +122 -0
- tico/passes/remove_redundant_reshape.py +436 -0
- tico/passes/remove_redundant_slice.py +62 -0
- tico/passes/remove_redundant_to_copy.py +86 -0
- tico/passes/restore_linear.py +115 -0
- tico/passes/segment_index_select.py +145 -0
- tico/pt2_to_circle.py +105 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +319 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +240 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_abs.py +53 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +150 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +192 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +126 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +186 -0
- tico/serialize/operators/op_copy.py +164 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +95 -0
- tico/serialize/operators/op_depthwise_conv2d.py +199 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_leaky_relu.py +60 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +86 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_dim.py +70 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +177 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +141 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +100 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +99 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +94 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +296 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +282 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +609 -0
- tico/utils/serialize.py +42 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +406 -0
- tico/utils/validate_args_kwargs.py +1149 -0
- tico-0.1.0.dist-info/LICENSE +241 -0
- tico-0.1.0.dist-info/METADATA +354 -0
- tico-0.1.0.dist-info/RECORD +206 -0
- tico-0.1.0.dist-info/WHEEL +5 -0
- tico-0.1.0.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,225 @@
+# Copyright (c) 2024 Intel Corporation
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import types
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from tico.experimental.quantization.algorithm.gptq.gptq import GPTQ
+from tico.experimental.quantization.algorithm.gptq.utils import (
+    find_layers,
+    gather_single_batch_from_dict,
+    gather_single_batch_from_list,
+)
+from tico.experimental.quantization.config import BaseConfig, GPTQConfig
+from tico.experimental.quantization.quantizer import BaseQuantizer
+
+
+class GPTQQuantizer(BaseQuantizer):
+    """
+    Quantizer for applying the GPTQ algorithm (typically for weight quantization)
+    """
+
+    def __init__(self, config: BaseConfig):
+        super().__init__(config)
+
+        self.cache_args: List[Any] = []
+        self.cache_kwargs: Dict[str, Any] = {"batch_num": 0}
+
+    @torch.no_grad()
+    def prepare(
+        self,
+        model: torch.nn.Module,
+        args: Optional[Any] = None,
+        kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Overrides the forward method of the first LLaMA layer (layer 0) to capture the
+        input required for calibration.
+
+        This method modifies the original forward pass of LLaMA layer 0 so that the
+        inputs used during inference are intercepted and recorded. These captured inputs
+        are then utilized to calibrate the quantization parameters for the GPTQ.
+
+        Parameters:
+            model: The target PyTorch model.
+            args: Positional example inputs required for capturing graph.
+            kwargs: Keyword example inputs required for capturing graph.
+
+        Returns:
+            The model prepared for GPTQ quantization.
+        """
+        if args is None and kwargs is None:
+            raise RuntimeError(
+                "Either args or kwargs must be provided for captruing graph."
+            )
+        # Define a function to capture input activations and associated parameters.
+        def forward(layer, *args, **kwargs):
+            self.cache_kwargs["batch_num"] += 1
+            for idx, item in enumerate(args):
+                if (idx + 1) > len(self.cache_args):
+                    self.cache_args.append([])
+                self.cache_args[idx].append(item)
+            for arg in kwargs:
+                if self.cache_kwargs.get(arg, None) is None:
+                    self.cache_kwargs[arg] = []
+                self.cache_kwargs[arg].append(kwargs[arg])
+            # Raise an error to interrupt the forward pass after capturing data.
+            raise ValueError
+
+        # Replace the first layer with defined function to capture calibration data.
+        if hasattr(model, "model"):
+            assert hasattr(model.model, "layers")
+            assert isinstance(model.model.layers, torch.nn.ModuleList)
+            layer_forward_cache = model.model.layers[0].forward
+            model.model.layers[0].forward = types.MethodType(
+                forward, model.model.layers[0]
+            )
+        else:
+            assert hasattr(model, "forward")
+            layer_forward_cache = model.forward
+            model.forward = types.MethodType(forward, model.forward)
+
+        model_forward_cache = model.forward
+        # Replace model's forward to avoid ValueError
+        def model_forward(model, *args, **kwargs):
+            nonlocal model_forward_cache
+            try:
+                model_forward_cache(*args, **kwargs)
+            except ValueError:
+                pass
+
+        model.forward = types.MethodType(model_forward, model)
+        kwargs = kwargs or {}
+        model(*args, **kwargs) # type: ignore[misc]
+
+        # Recover original forward
+        model.forward = model_forward_cache
+        if hasattr(model, "model"):
+            assert hasattr(model.model, "layers")
+            assert isinstance(model.model.layers, torch.nn.ModuleList)
+            model.model.layers[0].forward = layer_forward_cache
+        else:
+            model.forward = layer_forward_cache
+
+        return model
+
+    @torch.no_grad()
+    def convert(self, model):
+        """
+        Convert the prepared model to its GPTQ quantized version.
+
+        Applies the GPTQ quantization on weights based on the collected statistics.
+
+        Parameters:
+            model: The prepared PyTorch model.
+
+        Returns:
+            The quantized model.
+        """
+        gptq_conf = self.config
+        assert isinstance(gptq_conf, GPTQConfig)
+
+        # Save the original cache setting and disable caching during calibration/inference.
+        if hasattr(model, "config"):
+            use_cache = model.config.use_cache
+            model.config.use_cache = False
+
+        quantizers = {}
+        if hasattr(model, "model"):
+            target_layers = model.model.layers
+        else:
+            target_layers = [model]
+        for l_idx, layer in enumerate(target_layers):
+            # Identify quantizable submodules within the layer.
+            full = find_layers(layer)
+
+            sequential = [list(full.keys())]
+            for names in sequential:
+                subset = {n: full[n] for n in names}
+
+                gptq: Dict[str, GPTQ] = {}
+                for name in subset:
+                    gptq[name] = GPTQ(subset[name])
+                    gptq[name].quantizer.configure(
+                        8, perchannel=True, sym=False, mse=False
+                    )
+                # Define a hook to collect input/output batches for quantizer calibration.
+                def add_batch(name):
+                    def tmp(_, inp, out):
+                        gptq[name].add_batch(inp[0].data, out.data)
+
+                    return tmp
+
+                handles = []
+                for name in subset:
+                    handles.append(subset[name].register_forward_hook(add_batch(name)))
+                # Run the current layer on the stored calibration inputs to capture activation stats.
+                batch_num = self.cache_kwargs.pop("batch_num")
+                for batch_idx in range(batch_num):
+                    cache_args_batch = gather_single_batch_from_list(
+                        self.cache_args, batch_idx
+                    )
+                    cache_kwargs_batch = gather_single_batch_from_dict(
+                        self.cache_kwargs, batch_idx
+                    )
+                    layer(*cache_args_batch, **cache_kwargs_batch)[0]
+                self.cache_kwargs["batch_num"] = batch_num
+                for h in handles:
+                    h.remove()
+                # Quantize each submodule using the collected calibration data.
+                for name in subset:
+                    if gptq_conf.verbose:
+                        print(l_idx, name)
+                        print("Quantizing ...")
+                    gptq[name].fasterquant(
+                        percdamp=0.01,
+                        groupsize=-1,
+                        actorder=True,
+                        static_groups=False,
+                        verbose=gptq_conf.verbose,
+                    )
+                    quantizers["model.layers.%d.%s" % (l_idx, name)] = gptq[
+                        name
+                    ].quantizer
+                    gptq[name].free()
+            """
+            Execute the quantized layer with the calibration inputs to obtain ouptuts
+            that will serve as inputs for the next layer.
+
+            This ensures that the quantization effects are correctly propagated to subsequent
+            layers.
+            """
+            batch_num = self.cache_kwargs.pop("batch_num")
+            for batch_idx in range(batch_num):
+                cache_args_batch = gather_single_batch_from_list(
+                    self.cache_args, batch_idx
+                )
+                cache_kwargs_batch = gather_single_batch_from_dict(
+                    self.cache_kwargs, batch_idx
+                )
+                outs = layer(*cache_args_batch, **cache_kwargs_batch)[0]
+                # Update inputs for next iteration.
+                self.cache_args[0][batch_idx] = outs
+            self.cache_kwargs["batch_num"] = batch_num
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        # Restore the original cache configuration.
+        if hasattr(model, "config"):
+            model.config.use_cache = use_cache
+
+        return model
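By its position in the file list and its +225 line count, the hunk above is tico/experimental/quantization/algorithm/gptq/quantizer.py. Its docstrings describe a two-step flow: prepare() hijacks the forward of the first decoder layer to record calibration inputs, and convert() replays those inputs layer by layer while GPTQ quantizes each Linear submodule. A minimal sketch of driving this quantizer directly follows, assuming a Hugging Face LLaMA-style model with model.model.layers and a default-constructible GPTQConfig (config.py is not part of this hunk); the checkpoint name and calibration text are placeholders.

```python
# Minimal sketch, assuming a LLaMA-style Hugging Face model and that GPTQConfig
# can be constructed with defaults (both assumptions; config.py and
# public_interface.py are not shown in this hunk).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from tico.experimental.quantization.algorithm.gptq.quantizer import GPTQQuantizer
from tico.experimental.quantization.config import GPTQConfig

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B").eval()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

quantizer = GPTQQuantizer(GPTQConfig())

# prepare() runs the model once and intercepts layer-0 inputs for calibration.
calib = tokenizer("example calibration text", return_tensors="pt")
model = quantizer.prepare(model, args=(calib["input_ids"],))

# convert() replays the captured batches through each layer and quantizes weights.
model = quantizer.convert(model)
```

The packaged entry point is more likely the prepare/convert wrappers in tico/experimental/quantization/public_interface.py (+108 in the listing), which falls outside this hunk.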
@@ -0,0 +1,65 @@
+# Copyright (c) 2024 Intel Corporation
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+def find_layers(module, layers=[torch.nn.Linear], name=""):
+    if type(module) in layers:
+        return {name: module}
+    res = {}
+    for name1, child in module.named_children():
+        res.update(
+            find_layers(
+                child, layers=layers, name=name + "." + name1 if name != "" else name1
+            )
+        )
+    return res
+
+
+def gather_single_batch_from_dict(data_dict, idx):
+    """
+    Gather single batch from a dict.
+
+    Args:
+        data_dict (dict): data dict.
+        idx (int): index
+
+    Returns:
+        dict: single batch.
+    """
+    # obtain a set of keyword input from cache
+    single_batch = {}
+    for k, v in data_dict.items():
+        single_batch[k] = data_dict[k][idx]
+    return single_batch
+
+
+def gather_single_batch_from_list(data_list, idx):
+    """
+    Gather single batch from a list.
+
+    Args:
+        data_dict (dict): data list.
+        idx (int): index
+
+    Returns:
+        list: single batch.
+    """
+    # obtain a set of keyword input from cache
+    single_batch = []
+    for data_item in data_list:
+        single_batch.append(data_item[idx])
+    return single_batch
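This +65 hunk matches tico/experimental/quantization/algorithm/gptq/utils.py. find_layers walks a module tree and returns the quantizable leaves keyed by their dotted names, which is how convert() above selects the Linear submodules inside each decoder layer. A small self-contained check of that behavior (the toy module and its names are illustrative only):

```python
import torch

from tico.experimental.quantization.algorithm.gptq.utils import find_layers

# A toy block with one top-level Linear and one nested Linear.
block = torch.nn.Sequential()
block.add_module("proj", torch.nn.Linear(8, 8))
block.add_module("mlp", torch.nn.Sequential(torch.nn.Linear(8, 4)))

layers = find_layers(block)
# Keys are dotted paths relative to the module passed in.
print(sorted(layers.keys()))  # ['mlp.0', 'proj']
```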
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
@@ -0,0 +1,215 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import functools
+from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+    from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+import torch
+from torch.ao.quantization.observer import (
+    MinMaxObserver,
+    MovingAverageMinMaxObserver,
+    MovingAveragePerChannelMinMaxObserver,
+    PerChannelMinMaxObserver,
+)
+from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
+from torch.ao.quantization.quantizer.utils import _get_module_name_filter
+
+from tico.experimental.quantization.algorithm.pt2e.annotation.op import *
+import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
+from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
+    QuantizationConfig,
+)
+from tico.experimental.quantization.algorithm.pt2e.transformation.convert_scalars_to_attrs import (
+    convert_scalars_to_attrs,
+)
+
+
+class PT2EAnnotator(Quantizer):
+    """
+    The class annotates quantization configurations on each nodes.
+
+    Observers would be attached according to those configurations in
+    'torch.prepare_pt2e'.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.global_config: Optional[QuantizationConfig] = None
+        self.operator_type_config: Dict[
+            torch._ops.OpOverloadPacket, QuantizationConfig
+        ] = {}
+        self.module_type_config: Dict[Callable, QuantizationConfig] = {}
+        self.module_name_config: Dict[str, QuantizationConfig] = {}
+
+    def set_global(self, quantization_config: QuantizationConfig) -> PT2EAnnotator:
+        """
+        Set quantization config globally.
+        """
+        assert quantization_config is not None
+        self.global_config = quantization_config
+        return self
+
+    def set_operator_type(
+        self,
+        operator_type: torch._ops.OpOverloadPacket,
+        quantization_config: QuantizationConfig,
+    ) -> PT2EAnnotator:
+        """
+        Set quantization config for given operator type.
+        """
+        assert quantization_config is not None
+        self.operator_type_config[operator_type] = quantization_config
+        return self
+
+    def set_module_type(
+        self, module_type: Callable, quantization_config: QuantizationConfig
+    ):
+        """
+        Set quantization config for given module type.
+
+        For example, let's say quantizer.set_module_type(nn.Linear).
+        It will quantize all 'nn.Linear' modules with the `quantization_config`.
+        """
+        assert quantization_config is not None
+        self.module_type_config[module_type] = quantization_config
+        return self
+
+    def set_module_name(
+        self, module_name: str, quantization_config: QuantizationConfig
+    ):
+        """
+        Set quantization config for given module name.
+
+        For example, let's say quantizer.set_module_name("blocks.sub").
+        It will quantize all nodes that come from a module whose name is "blocks.sub"
+        with the `quantization_config`.
+        """
+        assert quantization_config is not None
+        self.module_name_config[module_name] = quantization_config
+        return self
+
+    def transform_for_annotation(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        """Allows for user defined transforms to run before annotating the graph."""
+        model = convert_scalars_to_attrs(model)
+        return model
+
+    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        model = self._annotate_for_quantization(model)
+        annot_utils.propagate_annotation_forward(model)
+        return model
+
+    def validate(self, model: torch.fx.GraphModule) -> None:
+        # TODO Consider this method.
+        pass
+
+    def _annotate_by_config_and_filter(
+        self,
+        model: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
+    ) -> torch.fx.GraphModule:
+        assert quantization_config is not None
+
+        for node in model.graph.nodes:
+            if node.target not in annot_spec.OP_TO_ANNOTATOR:
+                continue
+            annot_spec.OP_TO_ANNOTATOR[node.target](
+                model, node, quantization_config, filter_fn
+            )
+        return model
+
+    def _annotate_for_quantization(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        # Annotate according to the given module names.
+        module_name_list = list(self.module_name_config.keys())
+        for module_name, config in self.module_name_config.items():
+            self._annotate_by_config_and_filter(
+                model, config, _get_module_name_filter(module_name)
+            )
+
+        # Annotate according to the given module types.
+        tp_list = list(self.module_type_config.keys())
+        for module_type, config in self.module_type_config.items():
+            self._annotate_by_config_and_filter(
+                model, config, quant_utils.get_module_type_filter(module_type)
+            )
+
+        # TODO Annotate according to the given operator types.
+
+        self._annotate_by_config_and_filter(
+            model,
+            self.global_config,
+            quant_utils.get_not_module_type_or_name_filter(tp_list, module_name_list),
+        )
+        return model
+
+
+@functools.lru_cache
+def get_asymmetric_quantization_config(
+    weight_is_per_channel: bool = True,
+    act_qmin: int = 0,
+    act_qmax: int = 255,
+    weight_qmin: int = 0,
+    weight_qmax: int = 255,
+) -> QuantizationConfig:
+    # activation
+    act_extra_args: Dict[str, Any] = {"eps": 2**-12}
+    act_observer = MinMaxObserver
+    act_qspec = QuantizationSpec(
+        dtype=torch.uint8,
+        quant_min=act_qmin,
+        quant_max=act_qmax,
+        qscheme=torch.per_tensor_affine,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=act_observer.with_args(
+            **act_extra_args,
+        ),
+    )
+    # weight
+    weight_extra_args: Dict[str, Any] = {"eps": 2**-12}
+    weight_qscheme = (
+        torch.per_channel_affine if weight_is_per_channel else torch.per_tensor_affine
+    )
+    weight_observer: _ObserverOrFakeQuantizeConstructor = (
+        PerChannelMinMaxObserver if weight_is_per_channel else MinMaxObserver
+    )
+    weight_qspec = QuantizationSpec(
+        dtype=torch.uint8,
+        quant_min=weight_qmin,
+        quant_max=weight_qmax,
+        qscheme=weight_qscheme,
+        ch_axis=0,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=weight_observer.with_args(**weight_extra_args),
+    )
+
+    # Set bias qspec in each annotation functions.
+    bias_qspec = None
+    quantization_config = QuantizationConfig(
+        act_qspec,
+        act_qspec,
+        weight_qspec,
+        bias_qspec,
+    )
+    return quantization_config
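This +215 hunk matches tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py. PT2EAnnotator subclasses torch.ao.quantization.quantizer.Quantizer, so it plugs into the PT2E prepare/convert flow its docstring refers to ("torch.prepare_pt2e"). A rough usage sketch follows, assuming a torch 2.x export pipeline; the toy model, the export call, and the calibration tensor are assumptions, not the package's documented entry point.

```python
# Sketch only: wiring PT2EAnnotator into torch's PT2E flow. The export step
# varies across torch versions (capture_pre_autograd_graph vs. torch.export),
# so treat the export call below as an assumption.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from tico.experimental.quantization.algorithm.pt2e.annotation.annotator import (
    PT2EAnnotator,
    get_asymmetric_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU6()).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

annotator = PT2EAnnotator().set_global(get_asymmetric_quantization_config())

exported = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(exported, annotator)  # attaches observers per annotation
prepared(*example_inputs)                     # calibration pass
quantized = convert_pt2e(prepared)            # folds observers into q/dq ops
```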
@@ -0,0 +1,26 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional
+
+from torch.ao.quantization.quantizer import QuantizationSpec
+
+
+@dataclass(eq=True, frozen=True)
+class QuantizationConfig:
+    input_activation: Optional[QuantizationSpec]
+    output_activation: Optional[QuantizationSpec]
+    weight: Optional[QuantizationSpec]
+    bias: Optional[QuantizationSpec]
@@ -0,0 +1,21 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import glob
+from os.path import basename, dirname, isfile, join
+
+modules = glob.glob(join(dirname(__file__), "*.py"))
+__all__ = [
+    basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py")
+]
@@ -0,0 +1,65 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch.ao.quantization.quantizer import SharedQuantizationSpec
+
+import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
+from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
+    QuantizationConfig,
+)
+from tico.utils.validate_args_kwargs import AdaptiveAvgPool2dArgs
+
+
+@annot_spec.register_annotator([torch.ops.aten.adaptive_avg_pool2d.default])
+def _annotate_adaptive_avg_pool2d(
+    gm: torch.fx.GraphModule,
+    node: torch.fx.Node,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
+):
+    if (
+        node.op != "call_function"
+        or node.target != torch.ops.aten.adaptive_avg_pool2d.default
+    ):
+        return
+    if filter_fn and not filter_fn(node):
+        return
+    if quant_utils.is_annotated(node):
+        return
+
+    args = AdaptiveAvgPool2dArgs(*node.args) # type: ignore[arg-type]
+    input = args.input
+
+    assert isinstance(input, torch.fx.Node)
+    if (
+        "quantization_annotation" not in input.meta
+        or not input.meta["quantization_annotation"]._annotated
+        or input.meta["quantization_annotation"].output_qspec is None
+    ):
+        input_act_qspec = quant_utils.get_input_act_qspec(quantization_config)
+    else:
+        input_act_qspec = SharedQuantizationSpec(input)
+    annot_utils.annotate_input_qspec_map(node, input, input_act_qspec)
+
+    output_act_qspec = SharedQuantizationSpec((input, node))
+    annot_utils.annotate_output_qspec(node, output_act_qspec)
+
+    annot_utils.mark_nodes_as_annotated(node)
@@ -0,0 +1,57 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+
+import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
+from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
+    QuantizationConfig,
+)
+from tico.utils.validate_args_kwargs import AddTensorArgs
+
+
+@annot_spec.register_annotator([torch.ops.aten.add.Tensor])
+def _annotate_add(
+    gm: torch.fx.GraphModule,
+    node: torch.fx.Node,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
+):
+    if node.op != "call_function" or node.target != torch.ops.aten.add.Tensor:
+        return
+    if filter_fn and not filter_fn(node):
+        return
+    if quant_utils.is_annotated(node):
+        return
+
+    args = AddTensorArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
+    input_ = args.input
+    other = args.other
+
+    input_act_qspec = quant_utils.get_input_act_qspec(quantization_config)
+    if isinstance(input_, torch.fx.Node):
+        annot_utils.annotate_input_qspec_map(node, input_, input_act_qspec)
+    if isinstance(other, torch.fx.Node):
+        annot_utils.annotate_input_qspec_map(node, other, input_act_qspec)
+
+    output_act_qspec = quant_utils.get_output_act_qspec(quantization_config)
+    annot_utils.annotate_output_qspec(node, output_act_qspec)
+
+    annot_utils.mark_nodes_as_annotated(node)
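The two op hunks above (+65 adaptive_avg_pool2d.py and +57 add.py under annotation/op/) share one shape: a register_annotator decorator keyed by aten overloads, early returns for non-matching, filtered, or already-annotated nodes, per-input qspec maps, an output qspec, and a final mark_nodes_as_annotated. New ops are evidently added by dropping another module into annotation/op/, which the glob-based __all__ in op/__init__.py picks up. Below is a hypothetical annotator for a unary op following that pattern; the aten target and the direct node.args access are illustrative assumptions, not part of the released package.

```python
# Hypothetical sketch of extending the annotator registry with another unary op,
# mirroring add.py / adaptive_avg_pool2d.py above. Not part of the released
# package; the chosen aten target and argument handling are assumptions.
from typing import Callable, Optional

import torch
import torch.fx

import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
    QuantizationConfig,
)


@annot_spec.register_annotator([torch.ops.aten.hardswish.default])
def _annotate_hardswish(
    gm: torch.fx.GraphModule,
    node: torch.fx.Node,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
):
    if node.op != "call_function" or node.target != torch.ops.aten.hardswish.default:
        return
    if filter_fn and not filter_fn(node):
        return
    if quant_utils.is_annotated(node):
        return

    # Single-input op: annotate the first positional argument if it is a Node.
    input_ = node.args[0]
    if isinstance(input_, torch.fx.Node):
        input_act_qspec = quant_utils.get_input_act_qspec(quantization_config)
        annot_utils.annotate_input_qspec_map(node, input_, input_act_qspec)

    output_act_qspec = quant_utils.get_output_act_qspec(quantization_config)
    annot_utils.annotate_output_qspec(node, output_act_qspec)

    annot_utils.mark_nodes_as_annotated(node)
```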