tico 0.1.0.dev250904__py3-none-any.whl → 0.1.0.dev251109__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of tico may be problematic.
- tico/__init__.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_mixed_type_args.py +2 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -3
- tico/passes/ops.py +0 -1
- tico/passes/remove_redundant_expand.py +3 -1
- tico/quantization/__init__.py +6 -0
- tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +161 -0
- tico/quantization/algorithm/fpi_gptq/quantizer.py +179 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +14 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/fpi_gptq.py +29 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/{experimental/quantization/ptq/quant_config.py → quantization/config/ptq.py} +18 -10
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -37
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +6 -12
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/public_interface.py +11 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_linear.py +11 -10
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_attn.py +10 -12
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_decoder_layer.py +10 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_mlp.py +13 -13
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py +3 -3
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py +1 -1
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/mx.py +1 -1
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py +3 -5
- tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py +3 -2
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py +58 -21
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py +21 -13
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py +5 -7
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py +6 -7
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_linear.py +7 -8
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_silu.py +8 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/ptq_wrapper.py +4 -6
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_elementwise.py +55 -17
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_module_base.py +10 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/registry.py +17 -10
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_mm.py +15 -132
- tico/utils/convert.py +20 -15
- tico/utils/register_custom_op.py +6 -4
- tico/utils/signature.py +7 -8
- tico/utils/validate_args_kwargs.py +12 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/METADATA +48 -2
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/RECORD +128 -108
- tico/experimental/quantization/__init__.py +0 -6
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/ptq/examples/compare_ppl.py +0 -121
- tico/experimental/quantization/ptq/examples/debug_quant_outputs.py +0 -129
- tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +0 -165
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/gptq → quantization/algorithm/fpi_gptq}/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e → quantization/algorithm/gptq}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e/annotation → quantization/algorithm/pt2e}/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e/transformation → quantization/algorithm/pt2e/annotation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/smoothquant → quantization/algorithm/pt2e/transformation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/algorithm/smoothquant}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/metric.py +0 -0
- /tico/{experimental/quantization/ptq/examples → quantization/passes}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- /tico/{experimental/quantization/ptq/observers → quantization/wrapq}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/dtypes.py +0 -0
- /tico/{experimental/quantization/ptq/utils → quantization/wrapq/examples}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/mode.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers → quantization/wrapq/observers}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/qscheme.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers/llama → quantization/wrapq/utils}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/utils/reduce_utils.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers/nn → quantization/wrapq/wrappers}/__init__.py +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/top_level.txt +0 -0
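
Most of the entries above are a package restructuring: the public quantization API moves from tico.experimental.quantization to tico.quantization, the PTQ wrapper stack moves from .../ptq to tico.quantization.wrapq, and per-algorithm configs now live under tico.quantization.config (ptq, gptq, fpi_gptq, pt2e, smoothquant). As orientation, here is a minimal sketch of the relocated prepare → calibrate → convert flow, mirroring the quantize_linear.py example shown further down; the toy layer and tensor shapes are illustrative, not part of the diff.

import torch
import torch.nn as nn

from tico.quantization import convert, prepare
from tico.quantization.config.ptq import PTQConfig

layer = nn.Linear(16, 8)  # illustrative FP32 module

# 1. Wrap the module; observers are inserted for later calibration.
qlayer = prepare(layer, PTQConfig())

# 2. Calibrate with a few forward passes (observers collect activation ranges).
with torch.no_grad():
    for _ in range(16):
        _ = qlayer(torch.randn(4, 16))

# 3. Freeze (scale, zero-point) and switch the wrapper to quantized mode.
convert(qlayer)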
tico/quantization/wrapq/examples/debug_quant_outputs.py
ADDED
@@ -0,0 +1,224 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ============================================================================
+# LAYER-WISE DIFF DEBUGGING PIPELINE
+# ----------------------------------------------------------------------------
+# A quantization debugging pipeline that identifies accuracy regressions
+# by comparing UINT vs FP outputs at each layer.
+#
+# 1. Load a full-precision (FP) LLaMA-3-1B model.
+# 2. Wrap each Transformer block with PTQWrapper (activations → fake-quant).
+# 3. Capture reference FP layer outputs before quantization.
+# 4. Calibrate UINT-8 activation observers in a single pass.
+# 5. Freeze quantization parameters (scale, zero-point).
+# 6. Re-run inference and compare UINT-8 vs FP outputs per layer.
+# 7. Report where quantization hurts the most.
+#
+# Use this pipeline to trace precision loss layer by layer, and pinpoint
+# problematic modules during post-training quantization.
+# ============================================================================
+
+import argparse
+import sys
+
+import torch
+import tqdm
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from tico.quantization import convert, prepare
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.utils.introspection import (
+    build_fqn_map,
+    compare_layer_outputs,
+    save_fp_outputs,
+)
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+
+# Token-budget presets for activation calibration
+TOKENS: dict[str, int] = {
+    # Smoke test (<1 min turnaround on CPU/GPU)
+    "debug": 2_000,  # ≈16 × 128-seq batches
+    # Good default for 1-7B models (≲3 % ppl delta)
+    "baseline": 50_000,
+    # Production / 4-bit observer smoothing
+    "production": 200_000,
+}
+
+DTYPE_MAP = {
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+    "float16": torch.float16,
+}
+
+# Hardcoded dataset settings
+DATASET_NAME = "wikitext"
+DATASET_CONFIG = "wikitext-2-raw-v1"
+TRAIN_SPLIT = "train"
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Layer-wise diff debugging pipeline for PTQ"
+    )
+    parser.add_argument(
+        "--model", type=str, required=True, help="HF repo name or local path."
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="Device to run on (cuda|cpu|mps).",
+    )
+    parser.add_argument(
+        "--dtype",
+        choices=list(DTYPE_MAP.keys()),
+        default="float32",
+        help=f"Model dtype for load.",
+    )
+    parser.add_argument(
+        "--stride",
+        type=int,
+        default=512,
+        help="Sliding-window stride used during calibration.",
+    )
+    parser.add_argument(
+        "--calib-preset",
+        choices=list(TOKENS.keys()),
+        default="debug",
+        help="Calibration token budget preset.",
+    )
+    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
+    parser.add_argument(
+        "--trust-remote-code",
+        action="store_true",
+        help="Enable only if you trust the model repo code.",
+    )
+    parser.add_argument(
+        "--hf-token",
+        type=str,
+        default=None,
+        help="Optional HF token for gated/private repos.",
+    )
+    parser.add_argument(
+        "--use-cache",
+        dest="use_cache",
+        action="store_true",
+        default=False,
+        help="Use model KV cache if enabled (off by default).",
+    )
+    parser.add_argument(
+        "--no-tqdm", action="store_true", help="Disable tqdm progress bars."
+    )
+
+    args = parser.parse_args()
+
+    # Basic setup
+    torch.manual_seed(args.seed)
+    device = torch.device(args.device)
+    dtype = DTYPE_MAP[args.dtype]  # noqa: E999 (kept readable)
+
+    print("=== Config ===")
+    print(f"Model : {args.model}")
+    print(f"Device : {device.type}")
+    print(f"DType : {args.dtype}")
+    print(f"Stride : {args.stride}")
+    print(
+        f"Calib preset : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
+    )
+    print(f"Use HF cache? : {args.use_cache}")
+    print()
+
+    # -------------------------------------------------------------------------
+    # 1. Load the FP backbone and tokenizer
+    # -------------------------------------------------------------------------
+    print("Loading FP model …")
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model,
+        trust_remote_code=args.trust_remote_code,
+        token=args.hf_token,
+    )
+    model = (
+        AutoModelForCausalLM.from_pretrained(
+            args.model,
+            torch_dtype=dtype,
+            trust_remote_code=args.trust_remote_code,
+            token=args.hf_token,
+        )
+        .to(device)
+        .eval()
+    )
+
+    # Disable KV cache to force full forward passes for introspection
+    model.config.use_cache = args.use_cache
+
+    # Build module -> FQN map before wrapping
+    m_to_fqn = build_fqn_map(model)
+
+    # Prepare calibration inputs (HF Wikitext-2 train split)
+    CALIB_TOKENS = TOKENS[args.calib_preset]
+    print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
+    # Use Wikitext-2 train split for calibration.
+    dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
+
+    # -------------------------------------------------------------------------
+    # 2. Wrap every layer with PTQWrapper (UINT-8 activations)
+    # -------------------------------------------------------------------------
+    print("Wrapping layers with PTQWrapper …")
+    qcfg = PTQConfig()  # default: per-tensor UINT8
+    prepare(model, qcfg)
+
+    # -------------------------------------------------------------------------
+    # 3. Activation calibration plus FP-vs-UINT8 diffing
+    # -------------------------------------------------------------------------
+    print("Calibrating UINT-8 observers …")
+    calib_txt = " ".join(dataset["text"])[:CALIB_TOKENS]
+    ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
+
+    # Save reference FP activations before observers clamp/quantize
+    save_handles, act_cache = save_fp_outputs(model)
+
+    iterator = range(0, ids.size(1) - 1, args.stride)
+    if not args.no_tqdm:
+        iterator = tqdm.tqdm(iterator, desc="Act-Calibration")
+    with torch.no_grad():
+        for i in iterator:
+            inputs = ids[:, i : i + args.stride]
+            model(inputs)  # observers collect act. ranges
+
+    # Remove save hooks now that FP activations are cached
+    for h in save_handles:
+        h.remove()
+
+    # Freeze (scale, zero-point) after calibration
+    convert(model)
+
+    # Register diff hooks and measure per-layer deltas
+    cmp_handles = compare_layer_outputs(model, act_cache, metrics=["diff", "peir"])
+    # Use same inputs for comparison.
+    with torch.no_grad():
+        model(inputs)
+
+    assert isinstance(cmp_handles, list)
+    for h in cmp_handles:
+        h.remove()
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    except Exception as e:
+        print(f"\n[Error] {e}", file=sys.stderr)
+        sys.exit(1)
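
The introspection helpers used above (save_fp_outputs, compare_layer_outputs) can also be exercised on a small module in isolation. The sketch below condenses the flow of the script; the toy nn.Sequential and input tensor are illustrative assumptions (whether a given module is supported depends on the registered wrappers), not part of the diff.

import torch
import torch.nn as nn

from tico.quantization import convert, prepare
from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.utils.introspection import (
    compare_layer_outputs,
    save_fp_outputs,
)

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 8))  # illustrative toy model
x = torch.randn(4, 16)

prepare(model, PTQConfig())                       # insert activation observers
save_handles, act_cache = save_fp_outputs(model)  # hook FP reference outputs
with torch.no_grad():
    model(x)                                      # calibration pass; FP outputs are cached
for h in save_handles:
    h.remove()

convert(model)                                    # freeze (scale, zero-point)
cmp_handles = compare_layer_outputs(model, act_cache, metrics=["diff", "peir"])
with torch.no_grad():
    model(x)                                      # diff hooks compare quantized vs cached FP outputs
for h in cmp_handles:
    h.remove()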
tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_linear.py
RENAMED
@@ -29,13 +29,15 @@ import pathlib
 import torch
 import torch.nn as nn

-from tico.
-from tico.
-
-from tico.
-from tico.
+from tico.quantization import convert, prepare
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.evaluation.metric import compute_peir
+from tico.quantization.evaluation.utils import plot_two_outputs
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.wrappers.nn.quant_linear import QuantLinear
 from tico.utils.utils import SuppressWarning

+
 # -------------------------------------------------------------------------
 # 0. Define a toy model (1 Linear layer only)
 # -------------------------------------------------------------------------
@@ -60,20 +62,19 @@ fp32_layer = model.fc
 # -------------------------------------------------------------------------
 # 1. Replace the Linear with QuantLinear wrapper
 # -------------------------------------------------------------------------
-model.fc =
-# model.fc = PTQWrapper(fp32_layer) (Wrapping helper class)
+model.fc = prepare(fp32_layer, PTQConfig())  # type: ignore[assignment]
 qlayer = model.fc  # alias for brevity

 # -------------------------------------------------------------------------
 # 2. Single-pass calibration (collect activation ranges)
 # -------------------------------------------------------------------------
-assert isinstance(qlayer, QuantLinear)
+assert isinstance(qlayer.wrapped, QuantLinear)
 with torch.no_grad():
-    qlayer.enable_calibration()
     for _ in range(16):  # small toy batch
         x = torch.randn(4, 16)  # (batch=4, features=16)
         _ = model(x)
-
+
+convert(qlayer)

 assert qlayer._mode is Mode.QUANT, "Quantization mode should be active now."

tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_attn.py
RENAMED
@@ -17,13 +17,12 @@ import pathlib
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer

-from tico.
-from tico.
-
-from tico.
-from tico.
-
-)
+from tico.quantization import convert, prepare
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.evaluation.metric import compute_peir
+from tico.quantization.evaluation.utils import plot_two_outputs
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.wrappers.llama.quant_attn import QuantLlamaAttention
 from tico.utils.utils import SuppressWarning

 name = "Maykeye/TinyLLama-v0"
@@ -34,12 +33,11 @@ tokenizer = AutoTokenizer.from_pretrained(name)
 # 1. Replace layer-0’s MLP with QuantLlamaMLP
 # -------------------------------------------------------------------------
 orig_attn = model.model.layers[0].self_attn
-model.model.layers[0].self_attn =
-    orig_attn
-)  # PTQWrapper(orig_attn) is also fine
+model.model.layers[0].self_attn = prepare(orig_attn, PTQConfig())
 model.eval()

 attn_q = model.model.layers[0].self_attn  # quant wrapper
+assert isinstance(attn_q.wrapped, QuantLlamaAttention)
 rotary = model.model.rotary_emb

 # -------------------------------------------------------------------------
@@ -55,7 +53,6 @@ PROMPTS = [
 ]

 with torch.no_grad():
-    attn_q.enable_calibration()
     for prompt in PROMPTS:
         ids = tokenizer(prompt, return_tensors="pt")
         embeds = model.model.embed_tokens(ids["input_ids"])
@@ -63,7 +60,8 @@ with torch.no_grad():
         S = cos_sin[0].shape[1]
         float_mask = torch.zeros(1, 1, S, S)
         _ = attn_q(embeds, cos_sin)  # observers collect
-
+
+convert(attn_q)

 assert attn_q._mode is Mode.QUANT, "Quantization mode should be active now."

tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_decoder_layer.py
RENAMED
@@ -31,10 +31,12 @@ import pathlib
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer

-from tico.
-from tico.
-from tico.
-from tico.
+from tico.quantization import convert, prepare
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.evaluation.metric import compute_peir
+from tico.quantization.evaluation.utils import plot_two_outputs
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.wrappers.llama.quant_decoder_layer import (
     QuantLlamaDecoderLayer,
 )
 from tico.utils.utils import SuppressWarning
@@ -50,12 +52,11 @@ rotary = model.model.rotary_emb  # RoPE helper
 # 1. Swap in the quant wrapper
 # -------------------------------------------------------------------------
 fp32_layer = model.model.layers[0]  # keep a reference for diff check
-model.model.layers[0] =
-    fp32_layer
-)  # PTQWrapper(fp32_layer) is also fine
+model.model.layers[0] = prepare(fp32_layer, PTQConfig())
 model.eval()

 qlayer = model.model.layers[0]  # alias for brevity
+assert isinstance(qlayer.wrapped, QuantLlamaDecoderLayer)

 # -------------------------------------------------------------------------
 # 2. Single-pass calibration (gather activation ranges)
@@ -70,7 +71,6 @@ PROMPTS = [
 ]

 with torch.no_grad():
-    qlayer.enable_calibration()
     for prompt in PROMPTS:
         ids = tokenizer(prompt, return_tensors="pt")
         hidden = model.model.embed_tokens(ids["input_ids"])
@@ -78,7 +78,8 @@ with torch.no_grad():
         S = pos[0].shape[1]
         attn_mask = torch.zeros(1, 1, S, S)  # causal-mask placeholder
         _ = qlayer(hidden, attention_mask=attn_mask, position_embeddings=pos)
-
+
+convert(qlayer)

 assert qlayer._mode is Mode.QUANT, "Quantization mode should be active now."

tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_mlp.py
RENAMED
@@ -18,13 +18,14 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer

 import tico
-from tico.
-from tico.
-from tico.
-from tico.
-from tico.
-from tico.
-from tico.
+from tico.quantization import convert, prepare
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.evaluation.metric import compute_peir
+from tico.quantization.evaluation.utils import plot_two_outputs
+from tico.quantization.wrapq.dtypes import INT16
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.qscheme import QScheme
+from tico.quantization.wrapq.wrappers.llama.quant_mlp import QuantLlamaMLP
 from tico.utils.utils import SuppressWarning

 name = "Maykeye/TinyLLama-v0"
@@ -36,13 +37,13 @@ model.eval()
 # 1. Replace layer-0’s MLP with QuantLlamaMLP
 # -------------------------------------------------------------------------
 fp32_mlp = model.model.layers[0].mlp
-model.model.layers[0].mlp =
-    fp32_mlp,
-
-)  # PTQWrapper(fp32_mlp) is also fine
+model.model.layers[0].mlp = prepare(
+    fp32_mlp, PTQConfig(default_dtype=INT16, default_qscheme=QScheme.PER_TENSOR_SYMM)
+)
 model.eval()

 mlp_q = model.model.layers[0].mlp
+assert isinstance(mlp_q.wrapped, QuantLlamaMLP)

 # -------------------------------------------------------------------------
 # 2. Single-pass calibration
@@ -57,13 +58,12 @@ PROMPTS = [
 ]

 with torch.no_grad():
-    mlp_q.enable_calibration()
     for prompt in PROMPTS:
         enc = tokenizer(prompt, return_tensors="pt")
         emb = model.model.embed_tokens(enc["input_ids"])
         _ = mlp_q(emb)

-
+convert(mlp_q)

 assert mlp_q._mode is Mode.QUANT, "Quantization mode should be active now."

tico/quantization/wrapq/examples/quantize_with_gptq.py
ADDED
@@ -0,0 +1,265 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# =============================================================================
+# PTQ + GPTQ HYBRID QUANTIZATION PIPELINE
+# -----------------------------------------------------------------------------
+# This script shows how to:
+# 1. Load a pretrained FP Llama-3 model.
+# 2. Run GPTQ to quantize weights only.
+# 3. Wrap every Transformer layer with a PTQWrapper to quantize activations.
+# 4. Calibrate UINT-8 observers in a single pass over a text corpus.
+# 5. Inject GPTQ’s per-tensor weight scales / zero-points into the PTQ graph.
+# 6. Freeze all Q-params and compute Wikitext-2 perplexity.
+# =============================================================================
+
+import argparse
+import sys
+from typing import Any
+
+import torch
+import tqdm
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from tico.quantization import convert, prepare
+from tico.quantization.config.gptq import GPTQConfig
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
+from tico.quantization.wrapq.utils.introspection import build_fqn_map
+from tico.quantization.wrapq.utils.metrics import perplexity
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+
+
+# Token-budget presets for activation calibration
+TOKENS: dict[str, int] = {
+    # Smoke test (<1 min turnaround on CPU/GPU)
+    "debug": 2_000,  # ≈16 × 128-seq batches
+    # Good default for 1-7B models (≲3 % ppl delta)
+    "baseline": 50_000,
+    # Production / 4-bit observer smoothing
+    "production": 200_000,
+}
+
+DTYPE_MAP = {
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+    "float16": torch.float16,
+}
+
+# Hardcoded dataset settings
+DATASET_NAME = "wikitext"
+DATASET_CONFIG = "wikitext-2-raw-v1"
+TRAIN_SPLIT = "train"
+TEST_SPLIT = "test"
+
+# -------------------------------------------------------------------------
+# 1. Helper — copy GPTQ (scale, zp) into PTQ observers
+# -------------------------------------------------------------------------
+def inject_gptq_qparams(
+    root: torch.nn.Module,
+    gptq_quantizers: dict[str, Any],  # {fp_name: quantizer}
+    weight_obs_name: str = "weight",
+):
+    """
+    For every `QuantModuleBase` whose `fp_name` matches a GPTQ key,
+    locate the observer called `weight_obs_name` and overwrite its
+    (scale, zero-point), then lock them against further updates.
+    """
+    for m in root.modules():
+        if not isinstance(m, QuantModuleBase):
+            continue
+        if m.fp_name is None:
+            continue
+        quantizer = gptq_quantizers.get(m.fp_name)
+        if quantizer is None:
+            continue
+        obs = m.get_observer(weight_obs_name)
+        if obs is None:
+            continue
+        assert isinstance(obs, AffineObserverBase)
+        # GPTQ quantizer attributes
+        obs.load_qparams(quantizer.scale, quantizer.zero, lock=True)
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="GPTQ+PTQ pipeline (weight-only + activation UINT8)"
+    )
+    parser.add_argument(
+        "--model", type=str, required=True, help="HF repo name or local path."
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="Device to run on (cuda|cpu|mps).",
+    )
+    parser.add_argument(
+        "--dtype",
+        choices=list(DTYPE_MAP.keys()),
+        default="float32",
+        help="Model dtype for load.",
+    )
+    parser.add_argument(
+        "--stride",
+        type=int,
+        default=512,
+        help="Sliding-window stride used for calibration and eval.",
+    )
+    parser.add_argument(
+        "--calib-preset",
+        choices=list(TOKENS.keys()),
+        default="debug",
+        help="Activation calibration token budget preset.",
+    )
+    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
+    parser.add_argument(
+        "--trust-remote-code",
+        action="store_true",
+        help="Enable only if you trust the model repo code.",
+    )
+    parser.add_argument(
+        "--hf-token",
+        type=str,
+        default=None,
+        help="Optional HF token for gated/private repos.",
+    )
+    parser.add_argument(
+        "--use-cache",
+        dest="use_cache",
+        action="store_true",
+        default=False,
+        help="Use model KV cache if enabled (off by default).",
+    )
+    parser.add_argument(
+        "--no-tqdm", action="store_true", help="Disable tqdm progress bars."
+    )
+
+    args = parser.parse_args()
+
+    # Basic setup
+    torch.manual_seed(args.seed)
+    device = torch.device(args.device)
+    dtype = DTYPE_MAP[args.dtype]
+
+    print("=== Config ===")
+    print(f"Model : {args.model}")
+    print(f"Device : {device.type}")
+    print(f"DType : {args.dtype}")
+    print(f"Stride : {args.stride}")
+    print(
+        f"Calib preset : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
+    )
+    print(f"Use HF cache? : {args.use_cache}")
+    print()
+
+    # -------------------------------------------------------------------------
+    # 2. Load the FP backbone and tokenizer
+    # -------------------------------------------------------------------------
+    print("Loading FP model …")
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model,
+        trust_remote_code=args.trust_remote_code,
+        token=args.hf_token,
+    )
+    model = (
+        AutoModelForCausalLM.from_pretrained(
+            args.model,
+            torch_dtype=dtype,
+            trust_remote_code=args.trust_remote_code,
+            token=args.hf_token,
+        )
+        .to(device)
+        .eval()
+    )
+
+    model.config.use_cache = args.use_cache
+
+    # Build module -> FQN map BEFORE wrapping
+    m_to_fqn = build_fqn_map(model)
+
+    # -------------------------------------------------------------------------
+    # 3. Run GPTQ (weight-only) pass
+    # -------------------------------------------------------------------------
+    print("Applying GPTQ …")
+    dataset_test = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT)
+    q_m = prepare(model, GPTQConfig(), inplace=True)
+
+    it = (
+        dataset_test
+        if args.no_tqdm
+        else tqdm.tqdm(dataset_test, desc="GPTQ calibration")
+    )
+    for d in it:
+        ids = tokenizer(d["text"], return_tensors="pt").input_ids.to(device)
+        q_m(ids)  # observers gather weight stats
+
+    q_m = convert(q_m, inplace=True)  # materialize INT-weight tensors
+
+    # -------------------------------------------------------------------------
+    # 4. Wrap every layer with PTQWrapper (activation UINT-8)
+    # -------------------------------------------------------------------------
+    print("Wrapping layers with PTQWrapper …")
+    qcfg = PTQConfig()  # default: per-tensor UINT8
+    prepare(q_m, qcfg)
+
+    # -------------------------------------------------------------------------
+    # 5. Single-pass activation calibration
+    # -------------------------------------------------------------------------
+    print("Calibrating UINT-8 observers …")
+    CALIB_TOKENS = TOKENS[args.calib_preset]
+    print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
+    dataset_train = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
+    calib_txt = " ".join(dataset_train["text"])[:CALIB_TOKENS]
+    train_ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
+
+    # Overwrite weight observers with GPTQ statistics
+    if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict):
+        inject_gptq_qparams(q_m, q_m.quantizers)
+    else:
+        print(
+            "[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection."
+        )
+
+    # Forward passes to collect activation ranges
+    iterator = range(0, train_ids.size(1) - 1, args.stride)
+    if not args.no_tqdm:
+        iterator = tqdm.tqdm(iterator, desc="Act-calibration")
+    with torch.no_grad():
+        for i in iterator:
+            q_m(train_ids[:, i : i + args.stride])
+
+    # Freeze all Q-params (scale, zero-point)
+    convert(q_m)
+
+    # -------------------------------------------------------------------------
+    # 6. Evaluate perplexity on Wikitext-2
+    # -------------------------------------------------------------------------
+    print("\nCalculating perplexities …")
+    enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt")
+    ppl_uint8 = perplexity(q_m, enc, device, stride=args.stride)
+
+    print("\n┌── Wikitext-2 test perplexity ─────────────")
+    print(f"│ UINT-8 : {ppl_uint8:8.2f}")
+    print("└───────────────────────────────────────────")
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    except Exception as e:
+        print(f"\n[Error] {e}", file=sys.stderr)
+        sys.exit(1)
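
Stripped of argument parsing and logging, the hybrid flow above reduces to roughly the following sequence. This is a condensed, illustrative restatement of the script, not additional API: model, the two calibration batch iterables, and the call to inject_gptq_qparams (the helper defined in the script) are placeholders.

import torch

from tico.quantization import convert, prepare
from tico.quantization.config.gptq import GPTQConfig
from tico.quantization.config.ptq import PTQConfig

# 1. Weight-only GPTQ: wrap, run calibration text through the model, materialize INT weights.
q_m = prepare(model, GPTQConfig(), inplace=True)
for ids in gptq_calib_batches:        # token-id tensors, e.g. Wikitext-2 test split
    q_m(ids)
q_m = convert(q_m, inplace=True)

# 2. Activation PTQ: wrap again, reuse the GPTQ (scale, zero-point) pairs for the
#    weight observers, run one calibration sweep, then freeze everything.
prepare(q_m, PTQConfig())
if hasattr(q_m, "quantizers"):        # GPTQ quantizers keyed by FP module name
    inject_gptq_qparams(q_m, q_m.quantizers)
with torch.no_grad():
    for ids in act_calib_batches:     # sliding windows over the train split
        q_m(ids)
convert(q_m)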