tico 0.1.0.dev250917__py3-none-any.whl → 0.1.0.dev250918__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tico might be problematic. Click here for more details.
- tico/__init__.py +1 -1
- tico/config/v1.py +3 -0
- tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +190 -69
- tico/experimental/quantization/ptq/wrappers/fairseq/quant_decoder_layer.py +494 -0
- tico/experimental/quantization/ptq/wrappers/registry.py +1 -0
- tico/passes/convert_matmul_to_linear.py +200 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/op_mm.py +15 -132
- tico/utils/convert.py +6 -1
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250918.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250918.dist-info}/RECORD +16 -14
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250918.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250918.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250918.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250918.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
tico/config/v1.py
CHANGED
|
@@ -20,6 +20,9 @@ from tico.config.base import CompileConfigBase
|
|
|
20
20
|
@dataclass
|
|
21
21
|
class CompileConfigV1(CompileConfigBase):
|
|
22
22
|
legalize_causal_mask_value: bool = False
|
|
23
|
+
remove_constant_input: bool = False
|
|
24
|
+
convert_lhs_const_mm_to_fc: bool = False
|
|
25
|
+
convert_rhs_const_mm_to_fc: bool = True
|
|
23
26
|
|
|
24
27
|
def get(self, name: str):
|
|
25
28
|
return super().get(name)
|
|
@@ -24,6 +24,8 @@
|
|
|
24
24
|
# 6. Freeze all Q-params and compute Wikitext-2 perplexity.
|
|
25
25
|
# =============================================================================
|
|
26
26
|
|
|
27
|
+
import argparse
|
|
28
|
+
import sys
|
|
27
29
|
from typing import Any
|
|
28
30
|
|
|
29
31
|
import torch
|
|
@@ -42,12 +44,6 @@ from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
|
|
|
42
44
|
QuantModuleBase,
|
|
43
45
|
)
|
|
44
46
|
|
|
45
|
-
# -------------------------------------------------------------------------
|
|
46
|
-
# 0. Global configuration
|
|
47
|
-
# -------------------------------------------------------------------------
|
|
48
|
-
MODEL_NAME = "meta-llama/Meta-Llama-3-1B"
|
|
49
|
-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
50
|
-
STRIDE = 512
|
|
51
47
|
|
|
52
48
|
# Token-budget presets for activation calibration
|
|
53
49
|
TOKENS: dict[str, int] = {
|
|
@@ -58,7 +54,18 @@ TOKENS: dict[str, int] = {
|
|
|
58
54
|
# Production / 4-bit observer smoothing
|
|
59
55
|
"production": 200_000,
|
|
60
56
|
}
|
|
61
|
-
|
|
57
|
+
|
|
58
|
+
DTYPE_MAP = {
|
|
59
|
+
"float32": torch.float32,
|
|
60
|
+
"bfloat16": torch.bfloat16,
|
|
61
|
+
"float16": torch.float16,
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
# Hardcoded dataset settings
|
|
65
|
+
DATASET_NAME = "wikitext"
|
|
66
|
+
DATASET_CONFIG = "wikitext-2-raw-v1"
|
|
67
|
+
TRAIN_SPLIT = "train"
|
|
68
|
+
TEST_SPLIT = "test"
|
|
62
69
|
|
|
63
70
|
# -------------------------------------------------------------------------
|
|
64
71
|
# 1. Helper — copy GPTQ (scale, zp) into PTQ observers
|
|
@@ -89,77 +96,191 @@ def inject_gptq_qparams(
|
|
|
89
96
|
obs.load_qparams(quantizer.scale, quantizer.zero, lock=True)
|
|
90
97
|
|
|
91
98
|
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
model =
|
|
98
|
-
|
|
99
|
-
|
|
99
|
+
def main():
|
|
100
|
+
parser = argparse.ArgumentParser(
|
|
101
|
+
description="GPTQ+PTQ pipeline (weight-only + activation UINT8)"
|
|
102
|
+
)
|
|
103
|
+
parser.add_argument(
|
|
104
|
+
"--model", type=str, required=True, help="HF repo name or local path."
|
|
105
|
+
)
|
|
106
|
+
parser.add_argument(
|
|
107
|
+
"--device",
|
|
108
|
+
type=str,
|
|
109
|
+
default="cuda" if torch.cuda.is_available() else "cpu",
|
|
110
|
+
help="Device to run on (cuda|cpu|mps).",
|
|
111
|
+
)
|
|
112
|
+
parser.add_argument(
|
|
113
|
+
"--dtype",
|
|
114
|
+
choices=list(DTYPE_MAP.keys()),
|
|
115
|
+
default="float32",
|
|
116
|
+
help="Model dtype for load.",
|
|
117
|
+
)
|
|
118
|
+
parser.add_argument(
|
|
119
|
+
"--stride",
|
|
120
|
+
type=int,
|
|
121
|
+
default=512,
|
|
122
|
+
help="Sliding-window stride used for calibration and eval.",
|
|
123
|
+
)
|
|
124
|
+
parser.add_argument(
|
|
125
|
+
"--calib-preset",
|
|
126
|
+
choices=list(TOKENS.keys()),
|
|
127
|
+
default="debug",
|
|
128
|
+
help="Activation calibration token budget preset.",
|
|
129
|
+
)
|
|
130
|
+
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
|
|
131
|
+
parser.add_argument(
|
|
132
|
+
"--trust-remote-code",
|
|
133
|
+
action="store_true",
|
|
134
|
+
help="Enable only if you trust the model repo code.",
|
|
135
|
+
)
|
|
136
|
+
parser.add_argument(
|
|
137
|
+
"--hf-token",
|
|
138
|
+
type=str,
|
|
139
|
+
default=None,
|
|
140
|
+
help="Optional HF token for gated/private repos.",
|
|
141
|
+
)
|
|
142
|
+
parser.add_argument(
|
|
143
|
+
"--use-cache",
|
|
144
|
+
dest="use_cache",
|
|
145
|
+
action="store_true",
|
|
146
|
+
default=False,
|
|
147
|
+
help="Use model KV cache if enabled (off by default).",
|
|
148
|
+
)
|
|
149
|
+
parser.add_argument(
|
|
150
|
+
"--no-tqdm", action="store_true", help="Disable tqdm progress bars."
|
|
151
|
+
)
|
|
100
152
|
|
|
101
|
-
|
|
102
|
-
# 3. Run GPTQ (weight-only) pass
|
|
103
|
-
# -------------------------------------------------------------------------
|
|
104
|
-
print("Applying GPTQ …")
|
|
105
|
-
dataset = load_dataset("wikiText", "wikitext-2-raw-v1", split="test")
|
|
106
|
-
q_m = prepare(model, GPTQConfig(), inplace=True)
|
|
153
|
+
args = parser.parse_args()
|
|
107
154
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
155
|
+
# Basic setup
|
|
156
|
+
torch.manual_seed(args.seed)
|
|
157
|
+
device = torch.device(args.device)
|
|
158
|
+
dtype = DTYPE_MAP[args.dtype]
|
|
111
159
|
|
|
112
|
-
|
|
160
|
+
print("=== Config ===")
|
|
161
|
+
print(f"Model : {args.model}")
|
|
162
|
+
print(f"Device : {device.type}")
|
|
163
|
+
print(f"DType : {args.dtype}")
|
|
164
|
+
print(f"Stride : {args.stride}")
|
|
165
|
+
print(
|
|
166
|
+
f"Calib preset : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
|
|
167
|
+
)
|
|
168
|
+
print(f"Use HF cache? : {args.use_cache}")
|
|
169
|
+
print()
|
|
113
170
|
|
|
114
|
-
# -------------------------------------------------------------------------
|
|
115
|
-
#
|
|
116
|
-
# -------------------------------------------------------------------------
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
171
|
+
# -------------------------------------------------------------------------
|
|
172
|
+
# 2. Load the FP backbone and tokenizer
|
|
173
|
+
# -------------------------------------------------------------------------
|
|
174
|
+
print("Loading FP model …")
|
|
175
|
+
tokenizer = AutoTokenizer.from_pretrained(
|
|
176
|
+
args.model,
|
|
177
|
+
trust_remote_code=args.trust_remote_code,
|
|
178
|
+
token=args.hf_token,
|
|
179
|
+
)
|
|
180
|
+
model = (
|
|
181
|
+
AutoModelForCausalLM.from_pretrained(
|
|
182
|
+
args.model,
|
|
183
|
+
torch_dtype=dtype,
|
|
184
|
+
trust_remote_code=args.trust_remote_code,
|
|
185
|
+
token=args.hf_token,
|
|
186
|
+
)
|
|
187
|
+
.to(device)
|
|
188
|
+
.eval()
|
|
126
189
|
)
|
|
127
|
-
new_layers.append(q_layer)
|
|
128
190
|
|
|
129
|
-
|
|
191
|
+
model.config.use_cache = args.use_cache
|
|
130
192
|
|
|
131
|
-
#
|
|
132
|
-
|
|
133
|
-
# -------------------------------------------------------------------------
|
|
134
|
-
print("Calibrating UINT-8 observers …")
|
|
135
|
-
calib_txt = " ".join(
|
|
136
|
-
load_dataset("wikitext", "wikitext-2-raw-v1", split="train")["text"]
|
|
137
|
-
)[:CALIB_TOKENS]
|
|
138
|
-
ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(DEVICE)
|
|
193
|
+
# Build module -> FQN map BEFORE wrapping
|
|
194
|
+
m_to_fqn = build_fqn_map(model)
|
|
139
195
|
|
|
140
|
-
#
|
|
141
|
-
|
|
142
|
-
|
|
196
|
+
# -------------------------------------------------------------------------
|
|
197
|
+
# 3. Run GPTQ (weight-only) pass
|
|
198
|
+
# -------------------------------------------------------------------------
|
|
199
|
+
print("Applying GPTQ …")
|
|
200
|
+
dataset_test = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT)
|
|
201
|
+
q_m = prepare(model, GPTQConfig(), inplace=True)
|
|
143
202
|
|
|
144
|
-
|
|
145
|
-
|
|
203
|
+
it = (
|
|
204
|
+
dataset_test
|
|
205
|
+
if args.no_tqdm
|
|
206
|
+
else tqdm.tqdm(dataset_test, desc="GPTQ calibration")
|
|
207
|
+
)
|
|
208
|
+
for d in it:
|
|
209
|
+
ids = tokenizer(d["text"], return_tensors="pt").input_ids.to(device)
|
|
210
|
+
q_m(ids) # observers gather weight stats
|
|
146
211
|
|
|
147
|
-
|
|
148
|
-
for i in tqdm.trange(0, ids.size(1) - 1, STRIDE, desc="Act-calibration"):
|
|
149
|
-
q_m(ids[:, i : i + STRIDE]) # observers collect act. ranges
|
|
212
|
+
q_m = convert(q_m, inplace=True) # materialize INT-weight tensors
|
|
150
213
|
|
|
151
|
-
#
|
|
152
|
-
|
|
153
|
-
|
|
214
|
+
# -------------------------------------------------------------------------
|
|
215
|
+
# 4. Wrap every layer with PTQWrapper (activation UINT-8)
|
|
216
|
+
# -------------------------------------------------------------------------
|
|
217
|
+
print("Wrapping layers with PTQWrapper …")
|
|
218
|
+
layers = q_m.model.layers
|
|
219
|
+
if not isinstance(layers, (list, torch.nn.ModuleList)):
|
|
220
|
+
raise TypeError(f"'model.layers' must be list/ModuleList, got {type(layers)}")
|
|
154
221
|
|
|
155
|
-
#
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
222
|
+
qcfg = QuantConfig() # default: per-tensor UINT8
|
|
223
|
+
wrapped = torch.nn.ModuleList()
|
|
224
|
+
for idx, fp_layer in enumerate(layers):
|
|
225
|
+
layer_cfg = qcfg.child(f"layer{idx}")
|
|
226
|
+
wrapped.append(
|
|
227
|
+
PTQWrapper(
|
|
228
|
+
fp_layer,
|
|
229
|
+
qcfg=layer_cfg,
|
|
230
|
+
fp_name=m_to_fqn.get(fp_layer),
|
|
231
|
+
)
|
|
232
|
+
)
|
|
233
|
+
q_m.model.layers = wrapped
|
|
234
|
+
|
|
235
|
+
# -------------------------------------------------------------------------
|
|
236
|
+
# 5. Single-pass activation calibration
|
|
237
|
+
# -------------------------------------------------------------------------
|
|
238
|
+
print("Calibrating UINT-8 observers …")
|
|
239
|
+
CALIB_TOKENS = TOKENS[args.calib_preset]
|
|
240
|
+
print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
|
|
241
|
+
dataset_train = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
|
|
242
|
+
calib_txt = " ".join(dataset_train["text"])[:CALIB_TOKENS]
|
|
243
|
+
train_ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
|
|
244
|
+
|
|
245
|
+
# (a) Enable CALIB mode on every QuantModuleBase
|
|
246
|
+
for l in q_m.model.layers:
|
|
247
|
+
l.enable_calibration()
|
|
248
|
+
|
|
249
|
+
# (b) Overwrite weight observers with GPTQ statistics
|
|
250
|
+
if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict):
|
|
251
|
+
inject_gptq_qparams(q_m, q_m.quantizers)
|
|
252
|
+
else:
|
|
253
|
+
print(
|
|
254
|
+
"[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection."
|
|
255
|
+
)
|
|
256
|
+
|
|
257
|
+
# (c) Forward passes to collect activation ranges
|
|
258
|
+
iterator = range(0, train_ids.size(1) - 1, args.stride)
|
|
259
|
+
if not args.no_tqdm:
|
|
260
|
+
iterator = tqdm.tqdm(iterator, desc="Act-calibration")
|
|
261
|
+
with torch.no_grad():
|
|
262
|
+
for i in iterator:
|
|
263
|
+
q_m(train_ids[:, i : i + args.stride])
|
|
264
|
+
|
|
265
|
+
# (d) Freeze all Q-params (scale, zero-point)
|
|
266
|
+
for l in q_m.model.layers:
|
|
267
|
+
l.freeze_qparams()
|
|
268
|
+
|
|
269
|
+
# -------------------------------------------------------------------------
|
|
270
|
+
# 6. Evaluate perplexity on Wikitext-2
|
|
271
|
+
# -------------------------------------------------------------------------
|
|
272
|
+
print("\nCalculating perplexities …")
|
|
273
|
+
enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt")
|
|
274
|
+
ppl_uint8 = perplexity(q_m, enc, device, stride=args.stride)
|
|
275
|
+
|
|
276
|
+
print("\n┌── Wikitext-2 test perplexity ─────────────")
|
|
277
|
+
print(f"│ UINT-8 : {ppl_uint8:8.2f}")
|
|
278
|
+
print("└───────────────────────────────────────────")
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
if __name__ == "__main__":
|
|
282
|
+
try:
|
|
283
|
+
main()
|
|
284
|
+
except Exception as e:
|
|
285
|
+
print(f"\n[Error] {e}", file=sys.stderr)
|
|
286
|
+
sys.exit(1)
|