tico 0.1.0.dev250917__py3-none-any.whl → 0.1.0.dev250921__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tico might be problematic. See the package registry's page for this release for more details.
- tico/__init__.py +1 -1
- tico/config/v1.py +3 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +2 -2
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +1 -1
- tico/experimental/quantization/config/__init__.py +1 -0
- tico/experimental/quantization/config/base.py +26 -0
- tico/experimental/quantization/config/gptq.py +29 -0
- tico/experimental/quantization/config/pt2e.py +25 -0
- tico/experimental/quantization/{config.py → config/smoothquant.py} +1 -35
- tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +191 -70
- tico/experimental/quantization/ptq/wrappers/fairseq/quant_decoder_layer.py +494 -0
- tico/experimental/quantization/ptq/wrappers/registry.py +1 -0
- tico/experimental/quantization/public_interface.py +1 -1
- tico/experimental/quantization/quantizer.py +1 -1
- tico/passes/convert_matmul_to_linear.py +200 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/op_mm.py +15 -132
- tico/utils/convert.py +6 -1
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250921.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250921.dist-info}/RECORD +25 -19
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250921.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250921.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250921.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250921.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
tico/config/v1.py
CHANGED
|
@@ -20,6 +20,9 @@ from tico.config.base import CompileConfigBase
|
|
|
20
20
|
@dataclass
|
|
21
21
|
class CompileConfigV1(CompileConfigBase):
|
|
22
22
|
legalize_causal_mask_value: bool = False
|
|
23
|
+
remove_constant_input: bool = False
|
|
24
|
+
convert_lhs_const_mm_to_fc: bool = False
|
|
25
|
+
convert_rhs_const_mm_to_fc: bool = True
|
|
23
26
|
|
|
24
27
|
def get(self, name: str):
|
|
25
28
|
return super().get(name)
|
|
@@ -25,7 +25,7 @@ from tico.experimental.quantization.algorithm.gptq.utils import (
|
|
|
25
25
|
gather_single_batch_from_dict,
|
|
26
26
|
gather_single_batch_from_list,
|
|
27
27
|
)
|
|
28
|
-
from tico.experimental.quantization.config import
|
|
28
|
+
from tico.experimental.quantization.config.gptq import GPTQConfig
|
|
29
29
|
from tico.experimental.quantization.quantizer import BaseQuantizer
|
|
30
30
|
|
|
31
31
|
|
|
@@ -44,7 +44,7 @@ class GPTQQuantizer(BaseQuantizer):
|
|
|
44
44
|
3) convert(model) to consume the collected data and apply GPTQ.
|
|
45
45
|
"""
|
|
46
46
|
|
|
47
|
-
def __init__(self, config:
|
|
47
|
+
def __init__(self, config: GPTQConfig):
|
|
48
48
|
super().__init__(config)
|
|
49
49
|
|
|
50
50
|
# cache_args[i] -> list of the i-th positional argument for each batch
|
|
@@ -23,7 +23,7 @@ from tico.experimental.quantization.algorithm.smoothquant.observer import (
|
|
|
23
23
|
from tico.experimental.quantization.algorithm.smoothquant.smooth_quant import (
|
|
24
24
|
apply_smoothing,
|
|
25
25
|
)
|
|
26
|
-
from tico.experimental.quantization.config import SmoothQuantConfig
|
|
26
|
+
from tico.experimental.quantization.config.smoothquant import SmoothQuantConfig
|
|
27
27
|
from tico.experimental.quantization.quantizer import BaseQuantizer
|
|
28
28
|
|
|
29
29
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# DO NOT REMOVE THIS FILE
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from abc import ABC, abstractmethod
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BaseConfig(ABC):
|
|
19
|
+
"""
|
|
20
|
+
Base configuration class for quantization.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
@property
|
|
24
|
+
@abstractmethod
|
|
25
|
+
def name(self) -> str:
|
|
26
|
+
pass
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from tico.experimental.quantization.config.base import BaseConfig
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class GPTQConfig(BaseConfig):
|
|
19
|
+
"""
|
|
20
|
+
Configuration for GPTQ.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
def __init__(self, verbose: bool = False, show_progress: bool = True):
|
|
24
|
+
self.verbose = verbose
|
|
25
|
+
self.show_progress = show_progress
|
|
26
|
+
|
|
27
|
+
@property
|
|
28
|
+
def name(self) -> str:
|
|
29
|
+
return "gptq"
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from tico.experimental.quantization.config.base import BaseConfig
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class PT2EConfig(BaseConfig):
|
|
19
|
+
"""
|
|
20
|
+
Configuration for pytorch 2.0 export quantization.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
@property
|
|
24
|
+
def name(self) -> str:
|
|
25
|
+
return "pt2e"
|
|
@@ -12,43 +12,9 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from abc import ABC, abstractmethod
|
|
16
15
|
from typing import Dict, Literal, Optional
|
|
17
16
|
|
|
18
|
-
|
|
19
|
-
class BaseConfig(ABC):
|
|
20
|
-
"""
|
|
21
|
-
Base configuration class for quantization.
|
|
22
|
-
"""
|
|
23
|
-
|
|
24
|
-
@property
|
|
25
|
-
@abstractmethod
|
|
26
|
-
def name(self) -> str:
|
|
27
|
-
pass
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
class PT2EConfig(BaseConfig):
|
|
31
|
-
"""
|
|
32
|
-
Configuration for pytorch 2.0 export quantization.
|
|
33
|
-
"""
|
|
34
|
-
|
|
35
|
-
@property
|
|
36
|
-
def name(self) -> str:
|
|
37
|
-
return "pt2e"
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
class GPTQConfig(BaseConfig):
|
|
41
|
-
"""
|
|
42
|
-
Configuration for GPTQ.
|
|
43
|
-
"""
|
|
44
|
-
|
|
45
|
-
def __init__(self, verbose: bool = False, show_progress: bool = True):
|
|
46
|
-
self.verbose = verbose
|
|
47
|
-
self.show_progress = show_progress
|
|
48
|
-
|
|
49
|
-
@property
|
|
50
|
-
def name(self) -> str:
|
|
51
|
-
return "gptq"
|
|
17
|
+
from tico.experimental.quantization.config.base import BaseConfig
|
|
52
18
|
|
|
53
19
|
|
|
54
20
|
class SmoothQuantConfig(BaseConfig):
|
|
@@ -24,6 +24,8 @@
|
|
|
24
24
|
# 6. Freeze all Q-params and compute Wikitext-2 perplexity.
|
|
25
25
|
# =============================================================================
|
|
26
26
|
|
|
27
|
+
import argparse
|
|
28
|
+
import sys
|
|
27
29
|
from typing import Any
|
|
28
30
|
|
|
29
31
|
import torch
|
|
@@ -32,7 +34,7 @@ from datasets import load_dataset
|
|
|
32
34
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
33
35
|
|
|
34
36
|
from tico.experimental.quantization import convert, prepare
|
|
35
|
-
from tico.experimental.quantization.config import GPTQConfig
|
|
37
|
+
from tico.experimental.quantization.config.gptq import GPTQConfig
|
|
36
38
|
from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
|
|
37
39
|
from tico.experimental.quantization.ptq.quant_config import QuantConfig
|
|
38
40
|
from tico.experimental.quantization.ptq.utils.introspection import build_fqn_map
|
|
@@ -42,12 +44,6 @@ from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
|
|
|
42
44
|
QuantModuleBase,
|
|
43
45
|
)
|
|
44
46
|
|
|
45
|
-
# -------------------------------------------------------------------------
|
|
46
|
-
# 0. Global configuration
|
|
47
|
-
# -------------------------------------------------------------------------
|
|
48
|
-
MODEL_NAME = "meta-llama/Meta-Llama-3-1B"
|
|
49
|
-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
50
|
-
STRIDE = 512
|
|
51
47
|
|
|
52
48
|
# Token-budget presets for activation calibration
|
|
53
49
|
TOKENS: dict[str, int] = {
|
|
@@ -58,7 +54,18 @@ TOKENS: dict[str, int] = {
|
|
|
58
54
|
# Production / 4-bit observer smoothing
|
|
59
55
|
"production": 200_000,
|
|
60
56
|
}
|
|
61
|
-
|
|
57
|
+
|
|
58
|
+
DTYPE_MAP = {
|
|
59
|
+
"float32": torch.float32,
|
|
60
|
+
"bfloat16": torch.bfloat16,
|
|
61
|
+
"float16": torch.float16,
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
# Hardcoded dataset settings
|
|
65
|
+
DATASET_NAME = "wikitext"
|
|
66
|
+
DATASET_CONFIG = "wikitext-2-raw-v1"
|
|
67
|
+
TRAIN_SPLIT = "train"
|
|
68
|
+
TEST_SPLIT = "test"
|
|
62
69
|
|
|
63
70
|
# -------------------------------------------------------------------------
|
|
64
71
|
# 1. Helper — copy GPTQ (scale, zp) into PTQ observers
|
|
@@ -89,77 +96,191 @@ def inject_gptq_qparams(
|
|
|
89
96
|
obs.load_qparams(quantizer.scale, quantizer.zero, lock=True)
|
|
90
97
|
|
|
91
98
|
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
model =
|
|
98
|
-
|
|
99
|
-
|
|
99
|
+
def main():
|
|
100
|
+
parser = argparse.ArgumentParser(
|
|
101
|
+
description="GPTQ+PTQ pipeline (weight-only + activation UINT8)"
|
|
102
|
+
)
|
|
103
|
+
parser.add_argument(
|
|
104
|
+
"--model", type=str, required=True, help="HF repo name or local path."
|
|
105
|
+
)
|
|
106
|
+
parser.add_argument(
|
|
107
|
+
"--device",
|
|
108
|
+
type=str,
|
|
109
|
+
default="cuda" if torch.cuda.is_available() else "cpu",
|
|
110
|
+
help="Device to run on (cuda|cpu|mps).",
|
|
111
|
+
)
|
|
112
|
+
parser.add_argument(
|
|
113
|
+
"--dtype",
|
|
114
|
+
choices=list(DTYPE_MAP.keys()),
|
|
115
|
+
default="float32",
|
|
116
|
+
help="Model dtype for load.",
|
|
117
|
+
)
|
|
118
|
+
parser.add_argument(
|
|
119
|
+
"--stride",
|
|
120
|
+
type=int,
|
|
121
|
+
default=512,
|
|
122
|
+
help="Sliding-window stride used for calibration and eval.",
|
|
123
|
+
)
|
|
124
|
+
parser.add_argument(
|
|
125
|
+
"--calib-preset",
|
|
126
|
+
choices=list(TOKENS.keys()),
|
|
127
|
+
default="debug",
|
|
128
|
+
help="Activation calibration token budget preset.",
|
|
129
|
+
)
|
|
130
|
+
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
|
|
131
|
+
parser.add_argument(
|
|
132
|
+
"--trust-remote-code",
|
|
133
|
+
action="store_true",
|
|
134
|
+
help="Enable only if you trust the model repo code.",
|
|
135
|
+
)
|
|
136
|
+
parser.add_argument(
|
|
137
|
+
"--hf-token",
|
|
138
|
+
type=str,
|
|
139
|
+
default=None,
|
|
140
|
+
help="Optional HF token for gated/private repos.",
|
|
141
|
+
)
|
|
142
|
+
parser.add_argument(
|
|
143
|
+
"--use-cache",
|
|
144
|
+
dest="use_cache",
|
|
145
|
+
action="store_true",
|
|
146
|
+
default=False,
|
|
147
|
+
help="Use model KV cache if enabled (off by default).",
|
|
148
|
+
)
|
|
149
|
+
parser.add_argument(
|
|
150
|
+
"--no-tqdm", action="store_true", help="Disable tqdm progress bars."
|
|
151
|
+
)
|
|
100
152
|
|
|
101
|
-
|
|
102
|
-
# 3. Run GPTQ (weight-only) pass
|
|
103
|
-
# -------------------------------------------------------------------------
|
|
104
|
-
print("Applying GPTQ …")
|
|
105
|
-
dataset = load_dataset("wikiText", "wikitext-2-raw-v1", split="test")
|
|
106
|
-
q_m = prepare(model, GPTQConfig(), inplace=True)
|
|
153
|
+
args = parser.parse_args()
|
|
107
154
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
155
|
+
# Basic setup
|
|
156
|
+
torch.manual_seed(args.seed)
|
|
157
|
+
device = torch.device(args.device)
|
|
158
|
+
dtype = DTYPE_MAP[args.dtype]
|
|
111
159
|
|
|
112
|
-
|
|
160
|
+
print("=== Config ===")
|
|
161
|
+
print(f"Model : {args.model}")
|
|
162
|
+
print(f"Device : {device.type}")
|
|
163
|
+
print(f"DType : {args.dtype}")
|
|
164
|
+
print(f"Stride : {args.stride}")
|
|
165
|
+
print(
|
|
166
|
+
f"Calib preset : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
|
|
167
|
+
)
|
|
168
|
+
print(f"Use HF cache? : {args.use_cache}")
|
|
169
|
+
print()
|
|
113
170
|
|
|
114
|
-
# -------------------------------------------------------------------------
|
|
115
|
-
#
|
|
116
|
-
# -------------------------------------------------------------------------
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
171
|
+
# -------------------------------------------------------------------------
|
|
172
|
+
# 2. Load the FP backbone and tokenizer
|
|
173
|
+
# -------------------------------------------------------------------------
|
|
174
|
+
print("Loading FP model …")
|
|
175
|
+
tokenizer = AutoTokenizer.from_pretrained(
|
|
176
|
+
args.model,
|
|
177
|
+
trust_remote_code=args.trust_remote_code,
|
|
178
|
+
token=args.hf_token,
|
|
179
|
+
)
|
|
180
|
+
model = (
|
|
181
|
+
AutoModelForCausalLM.from_pretrained(
|
|
182
|
+
args.model,
|
|
183
|
+
torch_dtype=dtype,
|
|
184
|
+
trust_remote_code=args.trust_remote_code,
|
|
185
|
+
token=args.hf_token,
|
|
186
|
+
)
|
|
187
|
+
.to(device)
|
|
188
|
+
.eval()
|
|
126
189
|
)
|
|
127
|
-
new_layers.append(q_layer)
|
|
128
190
|
|
|
129
|
-
|
|
191
|
+
model.config.use_cache = args.use_cache
|
|
130
192
|
|
|
131
|
-
#
|
|
132
|
-
|
|
133
|
-
# -------------------------------------------------------------------------
|
|
134
|
-
print("Calibrating UINT-8 observers …")
|
|
135
|
-
calib_txt = " ".join(
|
|
136
|
-
load_dataset("wikitext", "wikitext-2-raw-v1", split="train")["text"]
|
|
137
|
-
)[:CALIB_TOKENS]
|
|
138
|
-
ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(DEVICE)
|
|
193
|
+
# Build module -> FQN map BEFORE wrapping
|
|
194
|
+
m_to_fqn = build_fqn_map(model)
|
|
139
195
|
|
|
140
|
-
#
|
|
141
|
-
|
|
142
|
-
|
|
196
|
+
# -------------------------------------------------------------------------
|
|
197
|
+
# 3. Run GPTQ (weight-only) pass
|
|
198
|
+
# -------------------------------------------------------------------------
|
|
199
|
+
print("Applying GPTQ …")
|
|
200
|
+
dataset_test = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT)
|
|
201
|
+
q_m = prepare(model, GPTQConfig(), inplace=True)
|
|
143
202
|
|
|
144
|
-
|
|
145
|
-
|
|
203
|
+
it = (
|
|
204
|
+
dataset_test
|
|
205
|
+
if args.no_tqdm
|
|
206
|
+
else tqdm.tqdm(dataset_test, desc="GPTQ calibration")
|
|
207
|
+
)
|
|
208
|
+
for d in it:
|
|
209
|
+
ids = tokenizer(d["text"], return_tensors="pt").input_ids.to(device)
|
|
210
|
+
q_m(ids) # observers gather weight stats
|
|
146
211
|
|
|
147
|
-
|
|
148
|
-
for i in tqdm.trange(0, ids.size(1) - 1, STRIDE, desc="Act-calibration"):
|
|
149
|
-
q_m(ids[:, i : i + STRIDE]) # observers collect act. ranges
|
|
212
|
+
q_m = convert(q_m, inplace=True) # materialize INT-weight tensors
|
|
150
213
|
|
|
151
|
-
#
|
|
152
|
-
|
|
153
|
-
|
|
214
|
+
# -------------------------------------------------------------------------
|
|
215
|
+
# 4. Wrap every layer with PTQWrapper (activation UINT-8)
|
|
216
|
+
# -------------------------------------------------------------------------
|
|
217
|
+
print("Wrapping layers with PTQWrapper …")
|
|
218
|
+
layers = q_m.model.layers
|
|
219
|
+
if not isinstance(layers, (list, torch.nn.ModuleList)):
|
|
220
|
+
raise TypeError(f"'model.layers' must be list/ModuleList, got {type(layers)}")
|
|
154
221
|
|
|
155
|
-
#
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
222
|
+
qcfg = QuantConfig() # default: per-tensor UINT8
|
|
223
|
+
wrapped = torch.nn.ModuleList()
|
|
224
|
+
for idx, fp_layer in enumerate(layers):
|
|
225
|
+
layer_cfg = qcfg.child(f"layer{idx}")
|
|
226
|
+
wrapped.append(
|
|
227
|
+
PTQWrapper(
|
|
228
|
+
fp_layer,
|
|
229
|
+
qcfg=layer_cfg,
|
|
230
|
+
fp_name=m_to_fqn.get(fp_layer),
|
|
231
|
+
)
|
|
232
|
+
)
|
|
233
|
+
q_m.model.layers = wrapped
|
|
234
|
+
|
|
235
|
+
# -------------------------------------------------------------------------
|
|
236
|
+
# 5. Single-pass activation calibration
|
|
237
|
+
# -------------------------------------------------------------------------
|
|
238
|
+
print("Calibrating UINT-8 observers …")
|
|
239
|
+
CALIB_TOKENS = TOKENS[args.calib_preset]
|
|
240
|
+
print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
|
|
241
|
+
dataset_train = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
|
|
242
|
+
calib_txt = " ".join(dataset_train["text"])[:CALIB_TOKENS]
|
|
243
|
+
train_ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
|
|
244
|
+
|
|
245
|
+
# (a) Enable CALIB mode on every QuantModuleBase
|
|
246
|
+
for l in q_m.model.layers:
|
|
247
|
+
l.enable_calibration()
|
|
248
|
+
|
|
249
|
+
# (b) Overwrite weight observers with GPTQ statistics
|
|
250
|
+
if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict):
|
|
251
|
+
inject_gptq_qparams(q_m, q_m.quantizers)
|
|
252
|
+
else:
|
|
253
|
+
print(
|
|
254
|
+
"[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection."
|
|
255
|
+
)
|
|
256
|
+
|
|
257
|
+
# (c) Forward passes to collect activation ranges
|
|
258
|
+
iterator = range(0, train_ids.size(1) - 1, args.stride)
|
|
259
|
+
if not args.no_tqdm:
|
|
260
|
+
iterator = tqdm.tqdm(iterator, desc="Act-calibration")
|
|
261
|
+
with torch.no_grad():
|
|
262
|
+
for i in iterator:
|
|
263
|
+
q_m(train_ids[:, i : i + args.stride])
|
|
264
|
+
|
|
265
|
+
# (d) Freeze all Q-params (scale, zero-point)
|
|
266
|
+
for l in q_m.model.layers:
|
|
267
|
+
l.freeze_qparams()
|
|
268
|
+
|
|
269
|
+
# -------------------------------------------------------------------------
|
|
270
|
+
# 6. Evaluate perplexity on Wikitext-2
|
|
271
|
+
# -------------------------------------------------------------------------
|
|
272
|
+
print("\nCalculating perplexities …")
|
|
273
|
+
enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt")
|
|
274
|
+
ppl_uint8 = perplexity(q_m, enc, device, stride=args.stride)
|
|
275
|
+
|
|
276
|
+
print("\n┌── Wikitext-2 test perplexity ─────────────")
|
|
277
|
+
print(f"│ UINT-8 : {ppl_uint8:8.2f}")
|
|
278
|
+
print("└───────────────────────────────────────────")
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
if __name__ == "__main__":
|
|
282
|
+
try:
|
|
283
|
+
main()
|
|
284
|
+
except Exception as e:
|
|
285
|
+
print(f"\n[Error] {e}", file=sys.stderr)
|
|
286
|
+
sys.exit(1)
|