tico 0.1.0.dev250803__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_mixed_type_args.py +2 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
- tico/passes/ops.py +0 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/public_interface.py +11 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +125 -0
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_mm.py +15 -132
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/utils/convert.py +20 -15
- tico/utils/dtype.py +22 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/signature.py +247 -0
- tico/utils/utils.py +50 -53
- tico/utils/validate_args_kwargs.py +37 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/RECORD +130 -73
- tico/experimental/quantization/__init__.py +0 -6
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
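The bulk of this release moves the quantization stack out of `tico.experimental.quantization` into a top-level `tico.quantization` package, splits the old monolithic `config.py` into per-algorithm config modules, and adds the new `wrapq` PTQ toolkit (observers, wrappers, examples). Not part of the diff — a minimal, hypothetical caller-side sketch of the rename, assuming `prepare`/`convert` were previously re-exported from the experimental package the way the new package and the example script below import them:

    # Before (0.1.0.dev250803) — old paths assumed from the removed experimental tree
    # from tico.experimental.quantization import prepare, convert

    # After (0.1.0.dev251102) — paths taken from the files added in this diff
    from tico.quantization import prepare, convert
    from tico.quantization.config.gptq import GPTQConfig
    from tico.quantization.config.ptq import PTQConfig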
tico/quantization/wrapq/examples/quantize_with_gptq.py
@@ -0,0 +1,265 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# =============================================================================
# PTQ + GPTQ HYBRID QUANTIZATION PIPELINE
# -----------------------------------------------------------------------------
# This script shows how to:
#   1. Load a pretrained FP Llama-3 model.
#   2. Run GPTQ to quantize weights only.
#   3. Wrap every Transformer layer with a PTQWrapper to quantize activations.
#   4. Calibrate UINT-8 observers in a single pass over a text corpus.
#   5. Inject GPTQ’s per-tensor weight scales / zero-points into the PTQ graph.
#   6. Freeze all Q-params and compute Wikitext-2 perplexity.
# =============================================================================

import argparse
import sys
from typing import Any

import torch
import tqdm
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from tico.quantization import convert, prepare
from tico.quantization.config.gptq import GPTQConfig
from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
from tico.quantization.wrapq.utils.introspection import build_fqn_map
from tico.quantization.wrapq.utils.metrics import perplexity
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase


# Token-budget presets for activation calibration
TOKENS: dict[str, int] = {
    # Smoke test (<1 min turnaround on CPU/GPU)
    "debug": 2_000,  # ≈16 × 128-seq batches
    # Good default for 1-7B models (≲3 % ppl delta)
    "baseline": 50_000,
    # Production / 4-bit observer smoothing
    "production": 200_000,
}

DTYPE_MAP = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
}

# Hardcoded dataset settings
DATASET_NAME = "wikitext"
DATASET_CONFIG = "wikitext-2-raw-v1"
TRAIN_SPLIT = "train"
TEST_SPLIT = "test"

# -------------------------------------------------------------------------
# 1. Helper — copy GPTQ (scale, zp) into PTQ observers
# -------------------------------------------------------------------------
def inject_gptq_qparams(
    root: torch.nn.Module,
    gptq_quantizers: dict[str, Any],  # {fp_name: quantizer}
    weight_obs_name: str = "weight",
):
    """
    For every `QuantModuleBase` whose `fp_name` matches a GPTQ key,
    locate the observer called `weight_obs_name` and overwrite its
    (scale, zero-point), then lock them against further updates.
    """
    for m in root.modules():
        if not isinstance(m, QuantModuleBase):
            continue
        if m.fp_name is None:
            continue
        quantizer = gptq_quantizers.get(m.fp_name)
        if quantizer is None:
            continue
        obs = m.get_observer(weight_obs_name)
        if obs is None:
            continue
        assert isinstance(obs, AffineObserverBase)
        # GPTQ quantizer attributes
        obs.load_qparams(quantizer.scale, quantizer.zero, lock=True)


def main():
    parser = argparse.ArgumentParser(
        description="GPTQ+PTQ pipeline (weight-only + activation UINT8)"
    )
    parser.add_argument(
        "--model", type=str, required=True, help="HF repo name or local path."
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run on (cuda|cpu|mps).",
    )
    parser.add_argument(
        "--dtype",
        choices=list(DTYPE_MAP.keys()),
        default="float32",
        help="Model dtype for load.",
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=512,
        help="Sliding-window stride used for calibration and eval.",
    )
    parser.add_argument(
        "--calib-preset",
        choices=list(TOKENS.keys()),
        default="debug",
        help="Activation calibration token budget preset.",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Enable only if you trust the model repo code.",
    )
    parser.add_argument(
        "--hf-token",
        type=str,
        default=None,
        help="Optional HF token for gated/private repos.",
    )
    parser.add_argument(
        "--use-cache",
        dest="use_cache",
        action="store_true",
        default=False,
        help="Use model KV cache if enabled (off by default).",
    )
    parser.add_argument(
        "--no-tqdm", action="store_true", help="Disable tqdm progress bars."
    )

    args = parser.parse_args()

    # Basic setup
    torch.manual_seed(args.seed)
    device = torch.device(args.device)
    dtype = DTYPE_MAP[args.dtype]

    print("=== Config ===")
    print(f"Model : {args.model}")
    print(f"Device : {device.type}")
    print(f"DType : {args.dtype}")
    print(f"Stride : {args.stride}")
    print(
        f"Calib preset : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
    )
    print(f"Use HF cache? : {args.use_cache}")
    print()

    # -------------------------------------------------------------------------
    # 2. Load the FP backbone and tokenizer
    # -------------------------------------------------------------------------
    print("Loading FP model …")
    tokenizer = AutoTokenizer.from_pretrained(
        args.model,
        trust_remote_code=args.trust_remote_code,
        token=args.hf_token,
    )
    model = (
        AutoModelForCausalLM.from_pretrained(
            args.model,
            torch_dtype=dtype,
            trust_remote_code=args.trust_remote_code,
            token=args.hf_token,
        )
        .to(device)
        .eval()
    )

    model.config.use_cache = args.use_cache

    # Build module -> FQN map BEFORE wrapping
    m_to_fqn = build_fqn_map(model)

    # -------------------------------------------------------------------------
    # 3. Run GPTQ (weight-only) pass
    # -------------------------------------------------------------------------
    print("Applying GPTQ …")
    dataset_test = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT)
    q_m = prepare(model, GPTQConfig(), inplace=True)

    it = (
        dataset_test
        if args.no_tqdm
        else tqdm.tqdm(dataset_test, desc="GPTQ calibration")
    )
    for d in it:
        ids = tokenizer(d["text"], return_tensors="pt").input_ids.to(device)
        q_m(ids)  # observers gather weight stats

    q_m = convert(q_m, inplace=True)  # materialize INT-weight tensors

    # -------------------------------------------------------------------------
    # 4. Wrap every layer with PTQWrapper (activation UINT-8)
    # -------------------------------------------------------------------------
    print("Wrapping layers with PTQWrapper …")
    qcfg = PTQConfig()  # default: per-tensor UINT8
    prepare(q_m, qcfg)

    # -------------------------------------------------------------------------
    # 5. Single-pass activation calibration
    # -------------------------------------------------------------------------
    print("Calibrating UINT-8 observers …")
    CALIB_TOKENS = TOKENS[args.calib_preset]
    print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
    dataset_train = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
    calib_txt = " ".join(dataset_train["text"])[:CALIB_TOKENS]
    train_ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)

    # Overwrite weight observers with GPTQ statistics
    if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict):
        inject_gptq_qparams(q_m, q_m.quantizers)
    else:
        print(
            "[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection."
        )

    # Forward passes to collect activation ranges
    iterator = range(0, train_ids.size(1) - 1, args.stride)
    if not args.no_tqdm:
        iterator = tqdm.tqdm(iterator, desc="Act-calibration")
    with torch.no_grad():
        for i in iterator:
            q_m(train_ids[:, i : i + args.stride])

    # Freeze all Q-params (scale, zero-point)
    convert(q_m)

    # -------------------------------------------------------------------------
    # 6. Evaluate perplexity on Wikitext-2
    # -------------------------------------------------------------------------
    print("\nCalculating perplexities …")
    enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt")
    ppl_uint8 = perplexity(q_m, enc, device, stride=args.stride)

    print("\n┌── Wikitext-2 test perplexity ─────────────")
    print(f"│ UINT-8 : {ppl_uint8:8.2f}")
    print("└───────────────────────────────────────────")


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        print(f"\n[Error] {e}", file=sys.stderr)
        sys.exit(1)
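Not part of the diff — a plausible invocation of the example above; only `--model` is required, the remaining flags fall back to the defaults declared in the argument parser, e.g. `python -m tico.quantization.wrapq.examples.quantize_with_gptq --model MY_LLAMA_REPO_OR_PATH --dtype bfloat16 --calib-preset debug`.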
tico/quantization/wrapq/mode.py
@@ -0,0 +1,32 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import auto, Enum


class Mode(Enum):
    """
    Mode — global FSM for PTQWrapper & Handlers.

    • NO_QUANT : pure pass-through (no stats, no fake-quant)
    • CALIB    : collect observer statistics only
    • QUANT    : use cached (scale, zero-point) → fake-quant enabled
    """

    NO_QUANT = auto()
    CALIB = auto()
    QUANT = auto()

    def __str__(self) -> str:
        return self.name.lower()
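Not part of the diff — a tiny sketch of the enum as defined above; the lower-case string form comes from `__str__` and is convenient for logging:

    from tico.quantization.wrapq.mode import Mode

    for mode in (Mode.NO_QUANT, Mode.CALIB, Mode.QUANT):
        print(f"phase: {mode}")  # phase: no_quant / calib / quant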
tico/quantization/wrapq/observers/__init__.py
@@ -0,0 +1 @@
# DO NOT REMOVE THIS FILE
tico/quantization/wrapq/observers/affine_base.py
@@ -0,0 +1,128 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Tuple

import torch

from tico.quantization.wrapq.dtypes import DType, UINT8
from tico.quantization.wrapq.observers.base import ObserverBase
from tico.quantization.wrapq.qscheme import QScheme


class AffineObserverBase(ObserverBase):
    """Base for affine observers (min/max → scale/zp)."""

    def __init__(
        self,
        *,
        name: str,
        dtype: DType = UINT8,
        qscheme: QScheme = QScheme.PER_TENSOR_ASYMM,
        channel_axis: Optional[int] = None,
    ):
        super().__init__(
            name=name, dtype=dtype, qscheme=qscheme, channel_axis=channel_axis
        )

    def reset(self) -> None:
        """
        Reset running min/max and drop cached qparams.
        """
        self.min_val: torch.Tensor = torch.tensor(math.inf)
        self.max_val: torch.Tensor = torch.tensor(-math.inf)
        if hasattr(self, "_cached_scale"):
            del self._cached_scale
        if hasattr(self, "_cached_zp"):
            del self._cached_zp

    def load_qparams(self, scale: torch.Tensor, zp: torch.Tensor, *, lock: bool = True):
        """
        Inject externally computed qparams and optionally lock the observer.

        When locked, subsequent `collect()` calls are ignored.
        """
        self._cached_scale = scale.detach()
        self._cached_zp = zp.to(torch.int)
        if lock:
            self.enabled = False

    @property
    def has_qparams(self) -> bool:
        return hasattr(self, "_cached_scale")

    def compute_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        qmin, qmax = self.dtype.qmin, self.dtype.qmax
        rng = self.max_val - self.min_val
        eps = 1e-12

        if self.qscheme.is_symmetric():
            max_abs = torch.maximum(self.max_val.abs(), self.min_val.abs())
            scale = torch.clamp(max_abs, min=eps) / qmax
            zp = torch.zeros_like(scale, dtype=torch.int)
            self._cached_scale, self._cached_zp = scale, zp
            return scale, zp

        if self.channel_axis is None:
            if torch.all(rng.abs() < 1e-8):
                C = self.min_val
                if torch.allclose(C, torch.zeros_like(C)):
                    scale = torch.ones_like(C)
                    zp = torch.zeros_like(C, dtype=torch.int)
                elif (C > 0).all():
                    scale = torch.clamp(C, min=eps)
                    zp = torch.zeros_like(C, dtype=torch.int)
                else:
                    scale = torch.clamp(C.abs(), min=eps)
                    zp = torch.full_like(C, qmax, dtype=torch.int)
            else:
                scale = torch.clamp(rng, min=eps) / (qmax - qmin)
                zp = (
                    torch.round(qmin - self.min_val / scale)
                    .clamp(qmin, qmax)
                    .to(torch.int)
                )
        else:
            scale = torch.clamp(rng, min=eps) / (qmax - qmin)
            zp = (
                torch.round(qmin - self.min_val / scale).clamp(qmin, qmax).to(torch.int)
            )

        self._cached_scale, self._cached_zp = scale, zp
        return scale, zp

    def fake_quant(self, x: torch.Tensor) -> torch.Tensor:
        if not self.has_qparams:
            raise RuntimeError(
                "Call compute_qparams()/freeze_qparams() or load_qparams() first."
            )
        scale, zp = self._cached_scale, self._cached_zp
        if self.channel_axis is None:
            return torch.fake_quantize_per_tensor_affine(
                x,
                scale=scale,
                zero_point=zp,
                quant_min=self.dtype.qmin,
                quant_max=self.dtype.qmax,
            )
        else:
            return torch.fake_quantize_per_channel_affine(
                x,
                scale=scale,
                zero_point=zp,
                axis=self.channel_axis,
                quant_min=self.dtype.qmin,
                quant_max=self.dtype.qmax,
            )
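Not part of the diff — a standalone numeric sketch (plain torch, assumed calibration values) of the asymmetric per-tensor branch of `compute_qparams` and `fake_quant` above:

    import torch

    # Assume calibration observed min = -1.0, max = 3.0 and the target is UINT8.
    min_val, max_val = torch.tensor(-1.0), torch.tensor(3.0)
    qmin, qmax = 0, 255
    scale = torch.clamp(max_val - min_val, min=1e-12) / (qmax - qmin)  # ≈ 0.0157
    zp = int(torch.round(qmin - min_val / scale).clamp(qmin, qmax))    # 64

    x = torch.tensor([-1.0, 0.0, 3.0])
    x_fq = torch.fake_quantize_per_tensor_affine(x, float(scale), zp, qmin, qmax)
    # x_fq ≈ [-1.0039, 0.0000, 2.9961] — inputs snapped to the 256-level UINT8 grid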
tico/quantization/wrapq/observers/base.py
@@ -0,0 +1,98 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch

from tico.quantization.wrapq.dtypes import DType, UINT8
from tico.quantization.wrapq.qscheme import QScheme


class ObserverBase(ABC):
    """
    Minimal abstract base for all observers/quantizers.

    Subclasses must implement:
      - reset()
      - collect(x)
      - fake_quant(x)
      - compute_qparams(): optional in practice for some observers (e.g., MX),
        but still part of the interface; those can return None.
    """

    def __init__(
        self,
        *,
        name: str,
        dtype: DType = UINT8,
        qscheme: QScheme = QScheme.PER_TENSOR_ASYMM,
        channel_axis: Optional[int] = None,  # None → per-tensor
    ):
        self.name = name
        self.dtype = dtype
        self.qscheme = qscheme
        self.channel_axis = channel_axis if qscheme.is_per_channel() else None
        self.enabled = True
        self.reset()

    @abstractmethod
    def reset(self) -> None:
        """Clear any running statistics or cached params."""
        raise NotImplementedError

    def collect(self, x: torch.Tensor) -> None:
        """
        Update running statistics with a new batch of data.

        This base implementation guards on `enabled` and then calls `_update_stats(x)`.
        Subclasses should implement `_update_stats(x)` instead of overriding `collect`.
        """
        if not self.enabled:
            return
        self._update_stats(x)

    @abstractmethod
    def _update_stats(self, x: torch.Tensor) -> None:
        """
        Update running statistics (min/max, hist, mse buffers, ...).

        Must be implemented by subclasses (e.g., MinMax, EMA, Histogram, MSE).
        """
        raise NotImplementedError

    @abstractmethod
    def fake_quant(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the observer's quantization.
        Implementations may or may not rely on qparams.
        """
        raise NotImplementedError

    @abstractmethod
    def compute_qparams(self) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Compute and (if applicable) cache quantization params.
        Affine observers typically return (scale, zero_point).
        Observers that do not use qparams (e.g., MX) may return None.
        """
        raise NotImplementedError

    # String repr helps debugging
    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(name={self.name}, dtype={str(self.dtype)}, "
            f"qscheme={str(self.qscheme)}, channel_axis={self.channel_axis}, enabled={self.enabled})"
        )
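Not part of the diff — a toy observer sketched against the contract above (the shipped `minmax.py` / `ema.py` observers fill this role in the package); only `_update_stats` is supplied, everything else comes from `AffineObserverBase`:

    import torch
    from tico.quantization.wrapq.observers.affine_base import AffineObserverBase

    class ToyMinMaxObserver(AffineObserverBase):
        """Track a running min/max over every tensor passed to collect()."""

        @torch.no_grad()
        def _update_stats(self, x: torch.Tensor) -> None:
            self.min_val = torch.minimum(self.min_val, x.min())
            self.max_val = torch.maximum(self.max_val, x.max())

    obs = ToyMinMaxObserver(name="act")
    obs.collect(torch.randn(4, 16))          # CALIB: gather statistics
    scale, zp = obs.compute_qparams()        # derive (scale, zero_point)
    y = obs.fake_quant(torch.randn(4, 16))   # QUANT: fake-quantize with cached qparams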
tico/quantization/wrapq/observers/ema.py
@@ -0,0 +1,62 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
from tico.quantization.wrapq.utils.reduce_utils import channelwise_minmax


class EMAObserver(AffineObserverBase):
    """
    Exponential-Moving-Average min/max tracker.

    Why?
    -----
    • Smoother than raw MinMax (reduces outlier shock).
    • Much cheaper than histogram/MSE observers.

    The update rule follows the common "momentum" form:

        ema = momentum * ema + (1 - momentum) * new_value

    With momentum → 0: FAST adaptation, momentum → 1: SLOW adaptation.
    """

    def __init__(
        self,
        *,
        momentum: float = 0.9,
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert 0.0 < momentum < 1.0, "momentum must be in (0, 1)"
        self.momentum = momentum

    @torch.no_grad()
    def _update_stats(self, x: torch.Tensor):
        if self.channel_axis is None:
            curr_min, curr_max = x.min(), x.max()
        else:
            curr_min, curr_max = channelwise_minmax(x, self.channel_axis)

        if (
            torch.isinf(self.min_val).any() and torch.isinf(self.max_val).any()
        ):  # first batch → hard init
            self.min_val, self.max_val = curr_min, curr_max
            return

        m = self.momentum
        self.min_val = m * self.min_val + (1 - m) * curr_min
        self.max_val = m * self.max_val + (1 - m) * curr_max
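Not part of the diff — a quick worked step of the update rule above: with the default `momentum = 0.9`, a running `max_val` of 2.0 that then sees a batch maximum of 4.0 moves to 0.9 * 2.0 + 0.1 * 4.0 = 2.2, so a single outlier batch closes only 10 % of the gap.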
tico/quantization/wrapq/observers/identity.py
@@ -0,0 +1,74 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
IdentityObserver: a "no-op" observer for FP-only modules.

Motivation
----------
Some layers should stay in full precision even when the rest of the model
is quantized. Attaching an `IdentityObserver` satisfies the wrapper API
(`_update_stats()`, `compute_qparams()`, `fake_quant()`) without actually
performing any statistics gathering or fake-quantization.
"""
import torch

from tico.quantization.wrapq.observers.affine_base import AffineObserverBase


class IdentityObserver(AffineObserverBase):
    """
    Passthrough observer that NEVER alters the tensor.

    • `_update_stats()`   → does nothing
    • `compute_qparams()` → returns (1.0, 0) "dummy" q-params
    • `fake_quant()`      → returns `x` unchanged
    """

    def __init__(self, **kwargs):
        # Call parent so the usual fields (`dtype`, `qscheme`, …) exist,
        # but immediately disable any stateful behaviour.
        super().__init__(**kwargs)

        # Deactivate statistics collection permanently.
        self.enabled = False

        # Pre-cache sentinel q-params so wrapper code that blindly
        # accesses them won't crash.
        self._cached_scale = torch.tensor(1.0)
        self._cached_zp = torch.tensor(0, dtype=torch.int)

    def reset(self) -> None:  # (simple override – nothing to do)
        """No internal state to reset."""
        pass

    def _update_stats(self, x: torch.Tensor) -> None:
        """Skip statistic collection entirely."""
        return

    def compute_qparams(self):
        """
        Return the pre-cached (scale, zero_point) tuple.

        Keeping the signature identical to other observers allows uniform
        lifecycle management in wrapper code.
        """
        return self._cached_scale, self._cached_zp

    def fake_quant(self, x: torch.Tensor):
        """Identity mapping — leaves `x` in FP."""
        return x

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"