tico 0.1.0.dev250803__py3-none-any.whl → 0.1.0.dev251106__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_mixed_type_args.py +2 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
- tico/passes/ops.py +0 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/public_interface.py +11 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +60 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +128 -0
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_mm.py +15 -132
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/utils/convert.py +20 -15
- tico/utils/dtype.py +22 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/signature.py +247 -0
- tico/utils/utils.py +50 -53
- tico/utils/validate_args_kwargs.py +37 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/METADATA +49 -2
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/RECORD +130 -73
- tico/experimental/quantization/__init__.py +0 -6
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/top_level.txt +0 -0
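The dominant change above is the promotion of `tico/experimental/quantization` to `tico/quantization`. Downstream code imports from the new package root; a minimal migration sketch (the old import path is an assumption, inferred from the removed `tico/experimental/quantization/__init__.py` re-exports):

```python
# Old root (0.1.0.dev250803) — now removed:
#   from tico.experimental.quantization import convert, prepare
# New root (0.1.0.dev251106), as used by the examples in this release:
from tico.quantization import convert, prepare
```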
tico/quantization/evaluation/metric.py

@@ -0,0 +1,146 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Callable, Dict, List, Optional
+
+import numpy as np
+import torch
+
+
+def compute_max_abs_diff(base: torch.Tensor, target: torch.Tensor) -> float:
+    """
+    Return the *maximum* absolute element-wise difference between two tensors.
+    """
+    assert base.shape == target.shape, "shape mismatch"
+    return (base.detach() - target.detach()).abs().max().item()
+
+
+def compute_peir(base: torch.Tensor, target: torch.Tensor) -> float:
+    """
+    Peak-Error-to-Interval Ratio (PEIR).
+
+    PEIR = max(|base - target|) / (max(base) - min(base))
+
+    The interval denominator uses the reference (*base*) tensor only — this
+    makes PEIR independent of quantisation error in `target`.
+    """
+    assert base.shape == target.shape, "shape mismatch"
+    peak_error = (base.detach() - target.detach()).abs().max().item()
+    interval = (base.detach().max() - base.detach().min()).item()
+    interval = 1.0 if interval == 0.0 else interval  # avoid divide-by-zero
+    return peak_error / interval
+
+
+def mse(base: torch.Tensor, target: torch.Tensor) -> float:
+    """
+    Mean Squared Error (MSE).
+
+    Penalizes **larger** deviations more heavily than MAE by squaring each
+    difference — helpful to expose occasional large spikes.
+
+    Formula
+    -------
+    MSE = mean((base - target)²)
+
+    Returns
+    -------
+    float
+        Mean squared error. *Lower is better*.
+    """
+    return torch.mean((base.detach() - target.detach()) ** 2).item()
+
+
+class MetricCalculator:
+    """
+    Lightweight registry-and-dispatcher for **pair-wise tensor comparison metrics**.
+
+    Purpose
+    -------
+    Consolidate all metrics used to assess the discrepancy between a reference
+    (usually FP32) tensor and its quantized counterpart, while letting the caller
+    choose *at runtime* which subset to evaluate.
+
+    Built-in metrics
+    ----------------
+    Key                        Description
+    -----------------------    ---------------------------------------------
+    "diff" / "max_abs_diff"    Maximum absolute element-wise difference
+    "peir"                     Peak-Error-to-Interval Ratio
+    "mse"                      Mean Squared Error
+
+    Usage pattern
+    -------------
+    >>> calc = MetricCalculator(custom_metrics={'mae': mae_fn})
+    >>> stats = calc.compute(fp_outs, q_outs, metrics=['diff', 'mse'])
+
+    • **Instantiation** registers any extra user metrics
+      (signature: ``fn(base: Tensor, target: Tensor) -> float``).
+    • **compute(...)** takes two *equal-length* lists of tensors and an optional
+      list of metric names.
+      — If *metrics* is *None*, every registered metric is evaluated.
+      — Returns a dict: ``{metric_name -> [value for each tensor pair]}``.
+
+    Implementation notes
+    --------------------
+    * All tensors are detached before calculation to avoid autograd overhead.
+    * Registrations are stored in `self.registry` (str → callable).
+    * Duplicate metric names between built-ins and custom metrics raise an error
+      at construction time to prevent silent shadowing.
+    """
+
+    builtin_metrics: Dict[str, Callable[[torch.Tensor, torch.Tensor], float]] = {
+        "diff": compute_max_abs_diff,
+        "max_abs_diff": compute_max_abs_diff,
+        "peir": compute_peir,
+        "mse": mse,
+    }
+
+    def __init__(
+        self,
+        custom_metrics: Optional[
+            Dict[str, Callable[[torch.Tensor, torch.Tensor], float]]
+        ] = None,
+    ):
+        self.registry: Dict[str, Callable] = self.builtin_metrics.copy()
+        if custom_metrics:
+            dup = self.registry.keys() & custom_metrics.keys()
+            if dup:
+                raise RuntimeError(f"Duplicate metric names: {dup}")
+            self.registry.update(custom_metrics)
+
+    # ----------------------------------------------------------------- #
+    # Public API                                                        #
+    # ----------------------------------------------------------------- #
+    def compute(
+        self,
+        base_outputs: List[torch.Tensor],
+        target_outputs: List[torch.Tensor],
+        metrics: Optional[List[str]] = None,
+    ) -> Dict[str, List[Any]]:
+        """
+        Compute selected metrics for every (base, target) pair.
+
+        Parameters
+        ----------
+        metrics
+            List of metric names to evaluate **this call**.
+            • None → evaluate *all* registered metrics.
+        """
+        sel = metrics or list(self.registry)
+        unknown = set(sel) - self.registry.keys()
+        if unknown:
+            raise RuntimeError(f"Unknown metric(s): {unknown}")
+
+        results: Dict[str, List[Any]] = {m: [] for m in sel}
+        for base, tgt in zip(base_outputs, target_outputs):
+            for m in sel:
+                results[m].append(self.registry[m](base, tgt))
+        return results
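For orientation, a minimal usage sketch of the new `MetricCalculator` (the tensors here are illustrative stand-ins for model outputs):

```python
import torch
from tico.quantization.evaluation.metric import MetricCalculator

# Stand-ins for FP32 reference outputs and their quantized counterparts.
fp_outs = [torch.randn(4, 8)]
q_outs = [t + 0.01 * torch.randn_like(t) for t in fp_outs]

calc = MetricCalculator()  # built-ins: "diff"/"max_abs_diff", "peir", "mse"
stats = calc.compute(fp_outs, q_outs, metrics=["max_abs_diff", "peir"])
print(stats)  # {"max_abs_diff": [...], "peir": [...]} — one value per tensor pair
```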
tico/quantization/evaluation/utils.py

@@ -44,7 +44,7 @@ def quantize(
     data = np.array(data)
     # Perfrom quantization
     if not scale:
-        logger.
+        logger.warning("WARNING: scale value is 0. 1e-7 will be used instead.")
         scale = 1e-7
     rescaled = np.round(data / scale) + zero_point
     # Clamp the values
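The patched branch only changes the warning message; the surrounding affine mapping is `q = round(x / scale) + zero_point` followed by a clamp. A worked instance with illustrative uint8 qparams:

```python
import numpy as np

data = np.array([-0.50, 0.00, 0.25, 1.00])
scale, zero_point = 0.25, 128            # illustrative qparams
q = np.round(data / scale) + zero_point  # affine map, as in quantize()
q = np.clip(q, 0, 255)                   # clamp to the uint8 range
print(q)                                 # [126. 128. 129. 132.]
```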
tico/quantization/passes/__init__.py

@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
tico/quantization/public_interface.py

@@ -13,25 +13,17 @@
 # limitations under the License.
 
 import copy
-from typing import Any, Dict, Optional, Type
+from typing import Any, Dict, Optional
 
 import torch
 
-from tico.
-from tico.
-from tico.
-
-
-from tico.experimental.quantization.config import BaseConfig
-from tico.experimental.quantization.quantizer import BaseQuantizer
+from tico.quantization.algorithm.gptq.quantizer import GPTQQuantizer
+from tico.quantization.algorithm.pt2e.quantizer import PT2EQuantizer
+from tico.quantization.config.base import BaseConfig
+from tico.quantization.quantizer import BaseQuantizer
+from tico.quantization.quantizer_registry import get_quantizer
 
 
-config_to_quantizer: Dict[str, Type[BaseQuantizer]] = {
-    "pt2e": PT2EQuantizer,
-    "gptq": GPTQQuantizer,
-    "smooth_quant": SmoothQuantQuantizer,
-}
-
 QUANTIZER_ATTRIBUTE_NAME = "tico_quantizer"
 
 
@@ -40,7 +32,7 @@ def prepare(
     quant_config: BaseConfig,
     args: Optional[Any] = None,
     kwargs: Optional[Dict[str, Any]] = None,
-    inplace: Optional[bool] =
+    inplace: Optional[bool] = True,
 ):
     """
     Prepare the model for quantization using the provided configuration.

@@ -61,21 +53,22 @@ def prepare(
     """
     if hasattr(model, QUANTIZER_ATTRIBUTE_NAME):
         raise RuntimeError("prepare() already has been called.")
-
+    quantizer = get_quantizer(quant_config)
+
+    if isinstance(quantizer, PT2EQuantizer) and inplace:
         raise RuntimeError(
             "In-place is not supported for PT2E quantization due to limitation in the underlying Torch APIs. Please set 'inplace=False' to proceed."
         )
 
     model = model if inplace else copy.deepcopy(model)
 
-    quantizer = config_to_quantizer[quant_config.name](quant_config)
     model = quantizer.prepare(model, args, kwargs)
     setattr(model, QUANTIZER_ATTRIBUTE_NAME, quantizer)
 
     return model
 
 
-def convert(model, inplace: Optional[bool] =
+def convert(model, inplace: Optional[bool] = True):
     """
     Convert the prepared model to a quantized model using the provided configuration.
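Taken together with the registry below, the updated flow resolves a quantizer from the config *type* rather than a string table. A minimal sketch (the toy model and the all-default `PTQConfig` are illustrative; a real model needs a proper calibration set):

```python
import torch
from tico.quantization import convert, prepare
from tico.quantization.config.ptq import PTQConfig

model = torch.nn.Sequential(torch.nn.Linear(8, 8)).eval()  # toy model
prepare(model, PTQConfig())    # resolves the quantizer via get_quantizer(cfg)
with torch.no_grad():
    model(torch.randn(32, 8))  # calibration pass to collect activation ranges
convert(model)                 # freeze (scale, zero-point)
```

Note that `inplace` now defaults to `True`, and PT2E still rejects in-place preparation.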
tico/quantization/quantizer_registry.py

@@ -0,0 +1,73 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+from typing import Dict, Optional, Type, TypeVar
+
+from tico.quantization.config.base import BaseConfig
+from tico.quantization.quantizer import BaseQuantizer
+
+TQ = TypeVar("TQ", bound=BaseQuantizer)
+
+# Mapping: Config type -> Quantizer type
+_REGISTRY: Dict[Type[BaseConfig], Type[BaseQuantizer]] = {}
+
+
+def register_quantizer(config_cls: Type[BaseConfig]):
+    """
+    Decorator to register a quantizer for a given config class.
+
+    Usage:
+        @register_quantizer(GPTQConfig)
+        class GPTQQuantizer(BaseQuantizer): ...
+    """
+
+    def wrapper(quantizer_cls: Type[TQ]) -> Type[TQ]:
+        _REGISTRY[config_cls] = quantizer_cls
+        return quantizer_cls
+
+    return wrapper
+
+
+def _lookup(cfg: BaseConfig) -> Optional[Type[BaseQuantizer]]:
+    """Return a quantizer class only if the exact config type is registered."""
+    return _REGISTRY.get(type(cfg))
+
+
+def get_quantizer(cfg: BaseConfig) -> BaseQuantizer:
+    """Factory to return a quantizer instance for the given config."""
+    qcls = _lookup(cfg)
+    if qcls is not None:
+        return qcls(cfg)
+
+    # Lazy import by naming convention
+    name = getattr(cfg, "name", None)
+    if name:
+        if name == "ptq":
+            importlib.import_module("tico.quantization.wrapq.quantizer")
+        else:
+            try:
+                importlib.import_module(f"tico.quantization.algorithm.{name}.quantizer")
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to import quantizer module for config name='{name}': {e}"
+                )
+
+        qcls = _lookup(cfg)
+        if qcls is not None:
+            return qcls(cfg)
+
+    raise RuntimeError(
+        f"No quantizer registered for config type {type(cfg).__name__} "
+        f"(name='{getattr(cfg, 'name', None)}')."
+    )
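A sketch of extending the registry with a third-party pair (`MyConfig`/`MyQuantizer` are hypothetical names, and any abstract members the base classes may require are elided):

```python
from tico.quantization.config.base import BaseConfig
from tico.quantization.quantizer import BaseQuantizer
from tico.quantization.quantizer_registry import get_quantizer, register_quantizer

class MyConfig(BaseConfig):        # hypothetical config type
    name = "my_algo"

@register_quantizer(MyConfig)      # records type(cfg) -> quantizer class
class MyQuantizer(BaseQuantizer):  # hypothetical quantizer
    def __init__(self, cfg):
        self.cfg = cfg

q = get_quantizer(MyConfig())      # exact-type lookup hits _REGISTRY directly
assert isinstance(q, MyQuantizer)
```

For the shipped algorithms no explicit import is needed at call time: `get_quantizer` lazily imports `tico.quantization.algorithm.<name>.quantizer` (or the wrapq quantizer for `name == "ptq"`) and retries the lookup.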
tico/quantization/wrapq/__init__.py

@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
tico/quantization/wrapq/dtypes.py

@@ -0,0 +1,70 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class DType:
+    """
+    Self-contained integer dtypes for quantization.
+
+    A DType is just an immutable value-object with two fields:
+      - bits
+      - signed
+
+    Common presets (INT8, UINT4, ..) are provided as constants for convenience.
+    """
+
+    bits: int  # pylint: disable=used-before-assignment
+    signed: bool = False  # False -> unsigned
+
+    @property
+    def qmin(self) -> int:
+        assert self.bits is not None
+        if self.signed:
+            return -(1 << (self.bits - 1))
+        return 0
+
+    @property
+    def qmax(self) -> int:
+        assert self.bits is not None
+        if self.signed:
+            return (1 << (self.bits - 1)) - 1
+        return (1 << self.bits) - 1
+
+    def __str__(self) -> str:
+        prefix = "int" if self.signed else "uint"
+        return f"{prefix}{self.bits}"
+
+    # ────────────────────────────────
+    # Factory helpers
+    # ────────────────────────────────
+    @staticmethod
+    def int(bits: int):  # type: ignore[valid-type]
+        return DType(bits, signed=True)
+
+    @staticmethod
+    def uint(bits: int):  # type: ignore[valid-type]
+        return DType(bits, signed=False)
+
+
+# ---------------------------------------------------------------------
+# Convenient canned versions
+# ---------------------------------------------------------------------
+UINT4 = DType.uint(4)
+INT4 = DType.int(4)
+INT8 = DType.int(8)
+UINT8 = DType.uint(8)
+INT16 = DType.int(16)
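For reference, the ranges these presets yield follow directly from `qmin`/`qmax`:

```python
from tico.quantization.wrapq.dtypes import INT4, INT8, INT16, UINT4, UINT8

for dt in (UINT4, INT4, UINT8, INT8, INT16):
    print(f"{dt}: [{dt.qmin}, {dt.qmax}]")
# uint4: [0, 15]
# int4: [-8, 7]
# uint8: [0, 255]
# int8: [-128, 127]
# int16: [-32768, 32767]
```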
tico/quantization/wrapq/examples/__init__.py

@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
tico/quantization/wrapq/examples/compare_ppl.py

@@ -0,0 +1,230 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# =============================================================================
+# QUICK PTQ WORKFLOW (OPTIONAL FP32 BASELINE)
+# -----------------------------------------------------------------------------
+# Use --mode to choose between:
+#   • FP32 perplexity measurement only, OR
+#   • Full post-training UINT-8 flow (wrap → calibrate → eval).
+# =============================================================================
+
+import argparse
+import sys
+
+import torch
+import tqdm
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from tico.quantization import convert, prepare
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.utils.metrics import perplexity
+
+# Token-budget presets for activation calibration
+TOKENS: dict[str, int] = {
+    # Smoke test (<1 min turnaround on CPU/GPU)
+    "debug": 2_000,  # ≈16 × 128-seq batches
+    # Good default for 1-7B models (≲3 % ppl delta)
+    "baseline": 50_000,
+    # Production / 4-bit observer smoothing
+    "production": 200_000,
+}
+
+DTYPE_MAP = {
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+    "float16": torch.float16,
+}
+
+# Hardcoded dataset settings
+DATASET_NAME = "wikitext"
+DATASET_CONFIG = "wikitext-2-raw-v1"
+TRAIN_SPLIT = "train"
+TEST_SPLIT = "test"
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Quick PTQ example (FP or UINT8)")
+    parser.add_argument(
+        "--mode",
+        choices=["fp", "uint8"],
+        default="fp",
+        help="Choose FP baseline only or full UINT8 PTQ path.",
+    )
+    parser.add_argument(
+        "--model", type=str, required=True, help="HF repo name or local path."
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="Device to run on (cuda|cpu).",
+    )
+    parser.add_argument(
+        "--dtype",
+        choices=list(DTYPE_MAP.keys()),
+        default="float32",
+        help="Model dtype for load.",
+    )
+    parser.add_argument(
+        "--stride", type=int, default=512, help="Sliding-window stride for perplexity."
+    )
+    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
+    parser.add_argument(
+        "--trust-remote-code",
+        action="store_true",
+        help="Enable only if you trust the model repo code.",
+    )
+    parser.add_argument(
+        "--hf-token",
+        type=str,
+        default=None,
+        help="Optional HF token for gated/private models.",
+    )
+    parser.add_argument(
+        "--use-cache",
+        dest="use_cache",
+        action="store_true",
+        default=False,
+        help="Use model KV cache if enabled (off by default).",
+    )
+    parser.add_argument(
+        "--no-tqdm", action="store_true", help="Disable tqdm progress bars."
+    )
+    parser.add_argument(
+        "--calib-preset",
+        choices=list(TOKENS.keys()),
+        default="debug",
+        help="Calibration token budget preset.",
+    )
+
+    args = parser.parse_args()
+
+    # Basic setup
+    torch.manual_seed(args.seed)
+    device = torch.device(args.device)
+    dtype = DTYPE_MAP[args.dtype]
+
+    print("=== Config ===")
+    print(f"Mode          : {args.mode}")
+    print(f"Model         : {args.model}")
+    print(f"Device        : {device.type}")
+    print(f"DType         : {args.dtype}")
+    print(f"Stride        : {args.stride}")
+    print(f"Use HF cache? : {args.use_cache}")
+    print(
+        f"Calib preset  : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
+    )
+    print()
+
+    # -------------------------------------------------------------------------
+    # 1. Load model and tokenizer
+    # -------------------------------------------------------------------------
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model,
+        trust_remote_code=args.trust_remote_code,
+        token=args.hf_token,
+    )
+
+    model = (
+        AutoModelForCausalLM.from_pretrained(
+            args.model,
+            torch_dtype=dtype,
+            trust_remote_code=args.trust_remote_code,
+            token=args.hf_token,
+        )
+        .to(device)
+        .eval()
+    )
+
+    model.config.use_cache = args.use_cache
+
+    if args.mode == "fp":
+        fp_model = model
+    else:
+        # UINT8 PTQ path
+        uint8_model = model
+
+        CALIB_TOKENS = TOKENS[args.calib_preset]
+        print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
+
+        # ---------------------------------------------------------------------
+        # 2. Wrap every Transformer layer with PTQWrapper
+        # ---------------------------------------------------------------------
+        qcfg = PTQConfig()  # all-uint8 defaults
+        prepare(uint8_model, qcfg)
+
+        # ---------------------------------------------------------------------
+        # 3. Single-pass activation calibration
+        # ---------------------------------------------------------------------
+        print("Calibrating UINT-8 observers …")
+        calib_txt = " ".join(
+            load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)["text"]
+        )[:CALIB_TOKENS]
+        ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
+
+        # Run inference to collect ranges
+        iterator = range(0, ids.size(1) - 1, args.stride)
+        if not args.no_tqdm:
+            iterator = tqdm.tqdm(iterator, desc="Calibration")
+        with torch.no_grad():
+            for i in iterator:
+                uint8_model(ids[:, i : i + args.stride])
+
+        # Freeze (scale, zero-point)
+        convert(uint8_model)
+
+    # -------------------------------------------------------------------------
+    # 4. Evaluate perplexity
+    # -------------------------------------------------------------------------
+    print("\nCalculating perplexities …")
+    test_ds = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT)
+    enc = tokenizer("\n\n".join(test_ds["text"]), return_tensors="pt")
+
+    if args.mode == "fp":
+        ppl_fp = perplexity(
+            fp_model,
+            enc,
+            args.device,
+            stride=args.stride,
+            show_progress=not args.no_tqdm,
+        )
+    else:
+        ppl_int8 = perplexity(
+            uint8_model,
+            enc,
+            args.device,
+            stride=args.stride,
+            show_progress=not args.no_tqdm,
+        )
+
+    # -------------------------------------------------------------------------
+    # 5. Report
+    # -------------------------------------------------------------------------
+    print("\n┌── Wikitext-2 test perplexity ─────────────")
+    if args.mode == "fp":
+        print(f"│ FP     : {ppl_fp:8.2f}")
+    else:
+        print(f"│ UINT-8 : {ppl_int8:8.2f}")
+    print("└───────────────────────────────────────────")
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    except Exception as e:
+        print(f"\n[Error] {e}", file=sys.stderr)
+        sys.exit(1)
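The `perplexity` helper imported above lives in `tico/quantization/wrapq/utils/metrics.py` and is not shown in this diff. As a reference point, sliding-window perplexity is conventionally the exponential of the average per-token negative log-likelihood; a minimal sketch under that assumption (illustrative, not the packaged implementation):

```python
import torch

def sliding_window_ppl(model, input_ids: torch.Tensor, stride: int = 512) -> float:
    """Illustrative only: exp(mean token NLL) over non-overlapping windows."""
    nll_sum, n_tokens = 0.0, 0
    for i in range(0, input_ids.size(1) - 1, stride):
        chunk = input_ids[:, i : i + stride]
        if chunk.size(1) < 2:                 # nothing left to score
            break
        with torch.no_grad():
            out = model(chunk, labels=chunk)  # HF causal LMs return mean NLL as .loss
        n = chunk.size(1) - 1                 # tokens actually scored
        nll_sum += out.loss.item() * n
        n_tokens += n
    return float(torch.exp(torch.tensor(nll_sum / n_tokens)))
```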