tico 0.1.0.dev250916__py3-none-any.whl → 0.1.0.dev250918__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tico might be problematic.
- tico/__init__.py +1 -1
- tico/config/v1.py +3 -0
- tico/experimental/quantization/ptq/examples/compare_ppl.py +4 -2
- tico/experimental/quantization/ptq/examples/debug_quant_outputs.py +189 -79
- tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +190 -69
- tico/experimental/quantization/ptq/wrappers/fairseq/quant_decoder_layer.py +494 -0
- tico/experimental/quantization/ptq/wrappers/registry.py +1 -0
- tico/passes/convert_matmul_to_linear.py +200 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/op_mm.py +15 -132
- tico/utils/convert.py +6 -1
- tico/utils/signature.py +7 -8
- {tico-0.1.0.dev250916.dist-info → tico-0.1.0.dev250918.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250916.dist-info → tico-0.1.0.dev250918.dist-info}/RECORD +19 -17
- {tico-0.1.0.dev250916.dist-info → tico-0.1.0.dev250918.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250916.dist-info → tico-0.1.0.dev250918.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250916.dist-info → tico-0.1.0.dev250918.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250916.dist-info → tico-0.1.0.dev250918.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
tico/config/v1.py
CHANGED
@@ -20,6 +20,9 @@ from tico.config.base import CompileConfigBase
 @dataclass
 class CompileConfigV1(CompileConfigBase):
     legalize_causal_mask_value: bool = False
+    remove_constant_input: bool = False
+    convert_lhs_const_mm_to_fc: bool = False
+    convert_rhs_const_mm_to_fc: bool = True
 
     def get(self, name: str):
         return super().get(name)
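The three new flags line up with the new tico/passes/convert_matmul_to_linear.py pass listed above (note op_mm.py shrinking by roughly 120 lines). A minimal sketch of toggling them; the field names and defaults come from this diff, everything else is ordinary dataclass usage, and how the config reaches the pass is not shown here:

from tico.config.v1 import CompileConfigV1

# Opt in to rewriting mm(const, x) / mm(x, const) patterns as fully-connected ops.
cfg = CompileConfigV1(
    remove_constant_input=True,       # new, defaults to False
    convert_lhs_const_mm_to_fc=True,  # new, defaults to False
    convert_rhs_const_mm_to_fc=True,  # new, defaults to True
)
assert cfg.get("convert_rhs_const_mm_to_fc") is True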
tico/experimental/quantization/ptq/examples/compare_ppl.py
CHANGED
@@ -77,7 +77,7 @@ def main():
         "--dtype",
         choices=list(DTYPE_MAP.keys()),
         default="float32",
-        help="Model dtype for load
+        help=f"Model dtype for load.",
     )
     parser.add_argument(
         "--stride", type=int, default=512, help="Sliding-window stride for perplexity."
@@ -126,7 +126,9 @@
     print(f"DType         : {args.dtype}")
     print(f"Stride        : {args.stride}")
     print(f"Use HF cache? : {args.use_cache}")
-    print(
+    print(
+        f"Calib preset  : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
+    )
     print()
 
     # -------------------------------------------------------------------------
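For context, the sliding-window evaluation that --stride controls can be sketched like this (illustrative only; compare_ppl.py's own perplexity helper is not shown in this diff, and sliding_window_ppl is a hypothetical name):

import torch

def sliding_window_ppl(model, input_ids, stride=512):
    # Accumulate NLL over fixed-stride windows, then exponentiate the per-token mean.
    nlls = []
    with torch.no_grad():
        for i in range(0, input_ids.size(1), stride):
            chunk = input_ids[:, i : i + stride]
            out = model(chunk, labels=chunk)  # HF causal LMs shift labels internally
            nlls.append(out.loss * chunk.size(1))
    return torch.exp(torch.stack(nlls).sum() / input_ids.size(1))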
tico/experimental/quantization/ptq/examples/debug_quant_outputs.py
CHANGED
@@ -12,19 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import torch
-import tqdm
-from datasets import load_dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.utils.introspection import (
-    build_fqn_map,
-    compare_layer_outputs,
-    save_fp_outputs,
-)
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-
 # ============================================================================
 # LAYER-WISE DIFF DEBUGGING PIPELINE
 # ----------------------------------------------------------------------------
@@ -43,12 +30,21 @@ from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
 # problematic modules during post-training quantization.
 # ============================================================================
 
-
-
-
-
-
-
+import argparse
+import sys
+
+import torch
+import tqdm
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from tico.experimental.quantization.ptq.quant_config import QuantConfig
+from tico.experimental.quantization.ptq.utils.introspection import (
+    build_fqn_map,
+    compare_layer_outputs,
+    save_fp_outputs,
+)
+from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
 
 # Token-budget presets for activation calibration
 TOKENS: dict[str, int] = {
@@ -59,71 +55,185 @@
     # Production / 4-bit observer smoothing
     "production": 200_000,
 }
-CALIB_TOKENS = TOKENS["baseline"]
-print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
-
-# -------------------------------------------------------------------------
-# 1. Load the FP backbone
-# -------------------------------------------------------------------------
-print("Loading FP model …")
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
-model.config.use_cache = False  # disable KV-cache → full forward
-m_to_fqn = build_fqn_map(model)  # map modules → fully-qualified names
-
-# Use Wikitext-2 train split for calibration.
-dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
-
-# -------------------------------------------------------------------------
-# 2. Wrap every layer with PTQWrapper (UINT-8 activations)
-# -------------------------------------------------------------------------
-print("Wrapping layers with PTQWrapper …")
-qcfg = QuantConfig()  # default: per-tensor UINT8
-
-new_layers = torch.nn.ModuleList()
-for idx, fp_layer in enumerate(model.model.layers):
-    layer_cfg = qcfg.child(f"layer{idx}")
-    q_layer = PTQWrapper(
-        fp_layer,
-        qcfg=layer_cfg,
-        fp_name=m_to_fqn.get(fp_layer),
-    )
-    new_layers.append(q_layer)
 
-
-
-
-
-
-print("Calibrating UINT-8 observers …")
-calib_txt = " ".join(dataset["text"])[:CALIB_TOKENS]
-ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(DEVICE)
+DTYPE_MAP = {
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+    "float16": torch.float16,
+}
 
-#
-
-
+# Hardcoded dataset settings
+DATASET_NAME = "wikitext"
+DATASET_CONFIG = "wikitext-2-raw-v1"
+TRAIN_SPLIT = "train"
 
-# Save reference FP activations before observers clamp/quantize
-save_handles, act_cache = save_fp_outputs(model)
 
-
-
-
-
+def main():
+    parser = argparse.ArgumentParser(
+        description="Layer-wise diff debugging pipeline for PTQ"
+    )
+    parser.add_argument(
+        "--model", type=str, required=True, help="HF repo name or local path."
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="Device to run on (cuda|cpu|mps).",
+    )
+    parser.add_argument(
+        "--dtype",
+        choices=list(DTYPE_MAP.keys()),
+        default="float32",
+        help=f"Model dtype for load.",
+    )
+    parser.add_argument(
+        "--stride",
+        type=int,
+        default=512,
+        help="Sliding-window stride used during calibration.",
+    )
+    parser.add_argument(
+        "--calib-preset",
+        choices=list(TOKENS.keys()),
+        default="debug",
+        help="Calibration token budget preset.",
+    )
+    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
+    parser.add_argument(
+        "--trust-remote-code",
+        action="store_true",
+        help="Enable only if you trust the model repo code.",
+    )
+    parser.add_argument(
+        "--hf-token",
+        type=str,
+        default=None,
+        help="Optional HF token for gated/private repos.",
+    )
+    parser.add_argument(
+        "--use-cache",
+        dest="use_cache",
+        action="store_true",
+        default=False,
+        help="Use model KV cache if enabled (off by default).",
+    )
+    parser.add_argument(
+        "--no-tqdm", action="store_true", help="Disable tqdm progress bars."
+    )
 
-
-for h in save_handles:
-    h.remove()
+    args = parser.parse_args()
 
-#
-
-
+    # Basic setup
+    torch.manual_seed(args.seed)
+    device = torch.device(args.device)
+    dtype = DTYPE_MAP[args.dtype]  # noqa: E999 (kept readable)
 
-
-
-
-
+    print("=== Config ===")
+    print(f"Model         : {args.model}")
+    print(f"Device        : {device.type}")
+    print(f"DType         : {args.dtype}")
+    print(f"Stride        : {args.stride}")
+    print(
+        f"Calib preset  : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
+    )
+    print(f"Use HF cache? : {args.use_cache}")
+    print()
+
+    # -------------------------------------------------------------------------
+    # 1. Load the FP backbone and tokenizer
+    # -------------------------------------------------------------------------
+    print("Loading FP model …")
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model,
+        trust_remote_code=args.trust_remote_code,
+        token=args.hf_token,
+    )
+    model = (
+        AutoModelForCausalLM.from_pretrained(
+            args.model,
+            torch_dtype=dtype,
+            trust_remote_code=args.trust_remote_code,
+            token=args.hf_token,
+        )
+        .to(device)
+        .eval()
+    )
 
-
-
-
+    # Disable KV cache to force full forward passes for introspection
+    model.config.use_cache = args.use_cache
+
+    # Build module -> FQN map before wrapping
+    m_to_fqn = build_fqn_map(model)
+
+    # Prepare calibration inputs (HF Wikitext-2 train split)
+    CALIB_TOKENS = TOKENS[args.calib_preset]
+    print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
+    # Use Wikitext-2 train split for calibration.
+    dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
+
+    # -------------------------------------------------------------------------
+    # 2. Wrap every layer with PTQWrapper (UINT-8 activations)
+    # -------------------------------------------------------------------------
+    print("Wrapping layers with PTQWrapper …")
+    qcfg = QuantConfig()  # default: per-tensor UINT8
+
+    new_layers = torch.nn.ModuleList()
+    for idx, fp_layer in enumerate(model.model.layers):
+        layer_cfg = qcfg.child(f"layer{idx}")
+        q_layer = PTQWrapper(
+            fp_layer,
+            qcfg=layer_cfg,
+            fp_name=m_to_fqn.get(fp_layer),
+        )
+        new_layers.append(q_layer)
+
+    model.model.layers = new_layers  # swap in quant wrappers
+
+    # -------------------------------------------------------------------------
+    # 3. Activation calibration plus FP-vs-UINT8 diffing
+    # -------------------------------------------------------------------------
+    print("Calibrating UINT-8 observers …")
+    calib_txt = " ".join(dataset["text"])[:CALIB_TOKENS]
+    ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
+
+    # (a) Enable CALIB mode on every QuantModuleBase
+    for l in model.model.layers:
+        l.enable_calibration()
+
+    # Save reference FP activations before observers clamp/quantize
+    save_handles, act_cache = save_fp_outputs(model)
+
+    iterator = range(0, ids.size(1) - 1, args.stride)
+    if not args.no_tqdm:
+        iterator = tqdm.tqdm(iterator, desc="Act-Calibration")
+    with torch.no_grad():
+        for i in iterator:
+            inputs = ids[:, i : i + args.stride]
+            model(inputs)  # observers collect act. ranges
+
+    # Remove save hooks now that FP activations are cached
+    for h in save_handles:
+        h.remove()
+
+    # (b) Freeze (scale, zero-point) after calibration
+    for l in model.model.layers:
+        l.freeze_qparams()
+
+    # (c) Register diff hooks and measure per-layer deltas
+    cmp_handles = compare_layer_outputs(model, act_cache, metrics=["diff", "peir"])
+    # Use same inputs for comparison.
+    with torch.no_grad():
+        model(inputs)
+
+    assert isinstance(cmp_handles, list)
+    for h in cmp_handles:
+        h.remove()
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    except Exception as e:
+        print(f"\n[Error] {e}", file=sys.stderr)
+        sys.exit(1)
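The calibrate-then-freeze flow above (enable_calibration → forward passes → freeze_qparams) is built on activation-range observers. A rough sketch of what a per-tensor min/max observer does during those forwards; the class below is hypothetical, not tico's actual implementation:

import torch

class MinMaxObserver:
    """Tracks a running activation range, then derives affine UINT8 qparams."""

    def __init__(self, qmin: int = 0, qmax: int = 255):
        self.lo, self.hi = float("inf"), float("-inf")
        self.qmin, self.qmax = qmin, qmax

    def observe(self, x: torch.Tensor) -> None:
        # Called on every calibration forward; only the running range is kept.
        self.lo = min(self.lo, x.min().item())
        self.hi = max(self.hi, x.max().item())

    def freeze(self) -> tuple[float, int]:
        # Affine qparams covering [lo, hi]; guard against a degenerate range.
        scale = max(self.hi - self.lo, 1e-12) / (self.qmax - self.qmin)
        zero_point = int(round(self.qmin - self.lo / scale))
        return scale, zero_point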
tico/experimental/quantization/ptq/examples/quantize_with_gptq.py
CHANGED
@@ -24,6 +24,8 @@
 # 6. Freeze all Q-params and compute Wikitext-2 perplexity.
 # =============================================================================
 
+import argparse
+import sys
 from typing import Any
 
 import torch
@@ -42,12 +44,6 @@ from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
     QuantModuleBase,
 )
 
-# -------------------------------------------------------------------------
-# 0. Global configuration
-# -------------------------------------------------------------------------
-MODEL_NAME = "meta-llama/Meta-Llama-3-1B"
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-STRIDE = 512
 
 # Token-budget presets for activation calibration
 TOKENS: dict[str, int] = {
@@ -58,7 +54,18 @@ TOKENS: dict[str, int] = {
     # Production / 4-bit observer smoothing
     "production": 200_000,
 }
-
+
+DTYPE_MAP = {
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+    "float16": torch.float16,
+}
+
+# Hardcoded dataset settings
+DATASET_NAME = "wikitext"
+DATASET_CONFIG = "wikitext-2-raw-v1"
+TRAIN_SPLIT = "train"
+TEST_SPLIT = "test"
 
 # -------------------------------------------------------------------------
 # 1. Helper — copy GPTQ (scale, zp) into PTQ observers
@@ -89,77 +96,191 @@ def inject_gptq_qparams(
         obs.load_qparams(quantizer.scale, quantizer.zero, lock=True)
 
 
-
-
-
-
-
-model =
-
-
+def main():
+    parser = argparse.ArgumentParser(
+        description="GPTQ+PTQ pipeline (weight-only + activation UINT8)"
+    )
+    parser.add_argument(
+        "--model", type=str, required=True, help="HF repo name or local path."
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="Device to run on (cuda|cpu|mps).",
+    )
+    parser.add_argument(
+        "--dtype",
+        choices=list(DTYPE_MAP.keys()),
+        default="float32",
+        help="Model dtype for load.",
+    )
+    parser.add_argument(
+        "--stride",
+        type=int,
+        default=512,
+        help="Sliding-window stride used for calibration and eval.",
+    )
+    parser.add_argument(
+        "--calib-preset",
+        choices=list(TOKENS.keys()),
+        default="debug",
+        help="Activation calibration token budget preset.",
+    )
+    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
+    parser.add_argument(
+        "--trust-remote-code",
+        action="store_true",
+        help="Enable only if you trust the model repo code.",
+    )
+    parser.add_argument(
+        "--hf-token",
+        type=str,
+        default=None,
+        help="Optional HF token for gated/private repos.",
+    )
+    parser.add_argument(
+        "--use-cache",
+        dest="use_cache",
+        action="store_true",
+        default=False,
+        help="Use model KV cache if enabled (off by default).",
+    )
+    parser.add_argument(
+        "--no-tqdm", action="store_true", help="Disable tqdm progress bars."
+    )
 
-
-# 3. Run GPTQ (weight-only) pass
-# -------------------------------------------------------------------------
-print("Applying GPTQ …")
-dataset = load_dataset("wikiText", "wikitext-2-raw-v1", split="test")
-q_m = prepare(model, GPTQConfig(), inplace=True)
+    args = parser.parse_args()
 
-
-
-
+    # Basic setup
+    torch.manual_seed(args.seed)
+    device = torch.device(args.device)
+    dtype = DTYPE_MAP[args.dtype]
 
-
+    print("=== Config ===")
+    print(f"Model         : {args.model}")
+    print(f"Device        : {device.type}")
+    print(f"DType         : {args.dtype}")
+    print(f"Stride        : {args.stride}")
+    print(
+        f"Calib preset  : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
+    )
+    print(f"Use HF cache? : {args.use_cache}")
+    print()
 
-# -------------------------------------------------------------------------
-#
-# -------------------------------------------------------------------------
-
-
-
-
-
-
-
-
-
+    # -------------------------------------------------------------------------
+    # 2. Load the FP backbone and tokenizer
+    # -------------------------------------------------------------------------
+    print("Loading FP model …")
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model,
+        trust_remote_code=args.trust_remote_code,
+        token=args.hf_token,
+    )
+    model = (
+        AutoModelForCausalLM.from_pretrained(
+            args.model,
+            torch_dtype=dtype,
+            trust_remote_code=args.trust_remote_code,
+            token=args.hf_token,
+        )
+        .to(device)
+        .eval()
    )
-    new_layers.append(q_layer)
 
-
+    model.config.use_cache = args.use_cache
 
-#
-
-# -------------------------------------------------------------------------
-print("Calibrating UINT-8 observers …")
-calib_txt = " ".join(
-    load_dataset("wikitext", "wikitext-2-raw-v1", split="train")["text"]
-)[:CALIB_TOKENS]
-ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(DEVICE)
+    # Build module -> FQN map BEFORE wrapping
+    m_to_fqn = build_fqn_map(model)
 
-#
-
-
+    # -------------------------------------------------------------------------
+    # 3. Run GPTQ (weight-only) pass
+    # -------------------------------------------------------------------------
+    print("Applying GPTQ …")
+    dataset_test = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT)
+    q_m = prepare(model, GPTQConfig(), inplace=True)
 
-
-
+    it = (
+        dataset_test
+        if args.no_tqdm
+        else tqdm.tqdm(dataset_test, desc="GPTQ calibration")
+    )
+    for d in it:
+        ids = tokenizer(d["text"], return_tensors="pt").input_ids.to(device)
+        q_m(ids)  # observers gather weight stats
 
-
-for i in tqdm.trange(0, ids.size(1) - 1, STRIDE, desc="Act-calibration"):
-    q_m(ids[:, i : i + STRIDE])  # observers collect act. ranges
+    q_m = convert(q_m, inplace=True)  # materialize INT-weight tensors
 
-#
-
-
+    # -------------------------------------------------------------------------
+    # 4. Wrap every layer with PTQWrapper (activation UINT-8)
+    # -------------------------------------------------------------------------
+    print("Wrapping layers with PTQWrapper …")
+    layers = q_m.model.layers
+    if not isinstance(layers, (list, torch.nn.ModuleList)):
+        raise TypeError(f"'model.layers' must be list/ModuleList, got {type(layers)}")
 
-#
-
-
-
-
-
-
-
-
-
-
+    qcfg = QuantConfig()  # default: per-tensor UINT8
+    wrapped = torch.nn.ModuleList()
+    for idx, fp_layer in enumerate(layers):
+        layer_cfg = qcfg.child(f"layer{idx}")
+        wrapped.append(
+            PTQWrapper(
+                fp_layer,
+                qcfg=layer_cfg,
+                fp_name=m_to_fqn.get(fp_layer),
+            )
+        )
+    q_m.model.layers = wrapped
+
+    # -------------------------------------------------------------------------
+    # 5. Single-pass activation calibration
+    # -------------------------------------------------------------------------
+    print("Calibrating UINT-8 observers …")
+    CALIB_TOKENS = TOKENS[args.calib_preset]
+    print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
+    dataset_train = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
+    calib_txt = " ".join(dataset_train["text"])[:CALIB_TOKENS]
+    train_ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
+
+    # (a) Enable CALIB mode on every QuantModuleBase
+    for l in q_m.model.layers:
+        l.enable_calibration()
+
+    # (b) Overwrite weight observers with GPTQ statistics
+    if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict):
+        inject_gptq_qparams(q_m, q_m.quantizers)
+    else:
+        print(
+            "[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection."
+        )
+
+    # (c) Forward passes to collect activation ranges
+    iterator = range(0, train_ids.size(1) - 1, args.stride)
+    if not args.no_tqdm:
+        iterator = tqdm.tqdm(iterator, desc="Act-calibration")
+    with torch.no_grad():
+        for i in iterator:
+            q_m(train_ids[:, i : i + args.stride])
+
+    # (d) Freeze all Q-params (scale, zero-point)
+    for l in q_m.model.layers:
+        l.freeze_qparams()
+
+    # -------------------------------------------------------------------------
+    # 6. Evaluate perplexity on Wikitext-2
+    # -------------------------------------------------------------------------
+    print("\nCalculating perplexities …")
+    enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt")
+    ppl_uint8 = perplexity(q_m, enc, device, stride=args.stride)
+
+    print("\n┌── Wikitext-2 test perplexity ─────────────")
+    print(f"│ UINT-8 : {ppl_uint8:8.2f}")
+    print("└───────────────────────────────────────────")
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    except Exception as e:
+        print(f"\n[Error] {e}", file=sys.stderr)
+        sys.exit(1)
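Per the helper shown above, inject_gptq_qparams() copies each GPTQ quantizer's (scale, zero) pair into the matching PTQ weight observer via obs.load_qparams(..., lock=True). What a locked pair means numerically can be shown with a fake-quant round trip (illustrative; fake_quant_uint8 is a hypothetical name, not tico code):

import torch

def fake_quant_uint8(x: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # Snap to the UINT8 grid, then map back to float; the residual x - result
    # is the kind of error the diff/PEIR metrics in debug_quant_outputs.py report.
    q = torch.clamp(torch.round(x / scale) + zero_point, 0, 255)
    return (q - zero_point) * scale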