pare-quant 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pare/__init__.py +99 -0
- pare/__main__.py +191 -0
- pare/calibration/__init__.py +0 -0
- pare/calibration/hessian.py +73 -0
- pare/calibration/layerwise.py +241 -0
- pare/calibration/observer.py +72 -0
- pare/calibration/runner.py +96 -0
- pare/config.py +96 -0
- pare/core/__init__.py +16 -0
- pare/core/dtype.py +97 -0
- pare/core/functional.py +214 -0
- pare/core/pack.py +144 -0
- pare/core/scale.py +137 -0
- pare/eval/__init__.py +0 -0
- pare/eval/lambada.py +88 -0
- pare/eval/perplexity.py +100 -0
- pare/eval/throughput.py +157 -0
- pare/kernels/__init__.py +0 -0
- pare/kernels/matmul_int4.py +248 -0
- pare/layers/__init__.py +0 -0
- pare/layers/linear.py +281 -0
- pare/model/__init__.py +0 -0
- pare/model/io.py +216 -0
- pare/model/patcher.py +103 -0
- pare/schemes/__init__.py +0 -0
- pare/schemes/awq.py +384 -0
- pare/schemes/base.py +93 -0
- pare/schemes/gptq.py +289 -0
- pare/schemes/rtn.py +65 -0
- pare/schemes/smoothquant.py +278 -0
- pare/sensitivity.py +106 -0
- pare_quant-0.1.0.dist-info/METADATA +194 -0
- pare_quant-0.1.0.dist-info/RECORD +36 -0
- pare_quant-0.1.0.dist-info/WHEEL +4 -0
- pare_quant-0.1.0.dist-info/entry_points.txt +2 -0
- pare_quant-0.1.0.dist-info/licenses/LICENSE +176 -0
pare/__init__.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"""Pare — production-ready quantization for large language and multimodal models."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
|
|
7
|
+
from pare.config import QuantConfig
|
|
8
|
+
from pare.core.dtype import QuantDtype
|
|
9
|
+
from pare.model.io import load_quantized, save_quantized
|
|
10
|
+
|
|
11
|
+
__version__ = "0.1.0"

# Names re-exported as the package's public API (``from pare import *``).
__all__ = [
    "QuantConfig",
    "QuantDtype",
    "quantize",
    "save_quantized",
    "load_quantized",
    "__version__",
]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def quantize(
    model: nn.Module,
    config: QuantConfig,
    calibration_data: "list | None" = None,
    device: "str" = "cpu",
) -> nn.Module:
    """Quantize a model in-place using the scheme specified in ``config``.

    Args:
        model: Any ``nn.Module`` (HuggingFace, custom, etc.).
        config: A ``QuantConfig`` describing the target dtype, scheme,
            granularity, and scheme-specific hyperparameters.
        calibration_data: Required for GPTQ, AWQ, and SmoothQuant.
            A list of input_ids tensors, each shaped [batch, seq_len].
            For RTN it is only used by the optional mixed-precision
            sensitivity pass (``config.sensitive_bits``).
        device: Device for calibration forward passes. Forwarded to the
            GPTQ, AWQ and SmoothQuant quantizers and to sensitivity
            scoring; plain RTN does not use it.

    Returns:
        The same model object with ``nn.Linear`` layers replaced by
        ``QuantizedLinear`` instances.

    Raises:
        ValueError: If ``config.scheme`` is not one of ``"rtn"``, ``"gptq"``,
            ``"awq"``, ``"smoothquant"``, or if a calibration-based scheme is
            requested without a non-empty ``calibration_data``.

    Examples::

        # RTN — no calibration needed
        from pare import quantize, QuantConfig
        config = QuantConfig(bits=4, scheme="rtn", group_size=128)
        model = quantize(model, config)

        # GPTQ — calibration data required
        config = QuantConfig(bits=4, scheme="gptq", group_size=128)
        model = quantize(model, config, calibration_data=calib_ids, device="cuda")
    """
    import warnings

    scheme = config.scheme
    # Validate before any expensive work: sensitivity scoring below runs
    # calibration forward passes, which is wasted effort on a bad config.
    if scheme not in ("rtn", "gptq", "awq", "smoothquant"):
        raise ValueError(f"Unknown scheme: {scheme!r}")
    has_calibration = calibration_data is not None and len(calibration_data) > 0
    if scheme != "rtn" and not has_calibration:
        raise ValueError(
            f"Scheme {scheme!r} requires calibration_data: "
            "a non-empty list of input_ids tensors."
        )

    # Mixed-precision sensitivity: score layers BEFORE quantization so we
    # still have the original FP16 weights. RTN is used as a fast proxy.
    layer_bits_override: dict[str, int] = {}
    if config.sensitive_bits is not None:
        if has_calibration:
            from pare.sensitivity import score_layers

            scores = score_layers(
                model,
                calibration_data,
                bits=config.bits,
                granularity=config.granularity,
                group_size=config.group_size,
                device=device,
            )
            layer_bits_override = {
                name: config.sensitive_bits
                for name, err in scores.items()
                if err > config.sensitivity_threshold
            }
            n_total = len(scores)
            n_sensitive = len(layer_bits_override)
            print(
                f"[pare] Sensitivity: {n_total} layers scored, "
                f"{n_sensitive} above {config.sensitivity_threshold:.0%} threshold "
                f"→ {config.bits}-bit→{config.sensitive_bits}-bit"
            )
        else:
            # Previously this was skipped silently; make the no-op visible.
            warnings.warn(
                "config.sensitive_bits is set but no calibration_data was given; "
                "sensitivity scoring is skipped (no per-layer bit overrides).",
                stacklevel=2,
            )

    if scheme == "rtn":
        from pare.schemes.rtn import RTNQuantizer

        quantizer = RTNQuantizer(config, layer_bits_override=layer_bits_override)
        return quantizer.quantize_model(model)

    # The remaining schemes share one construction / call pattern.
    if scheme == "gptq":
        from pare.schemes.gptq import GPTQQuantizer as quantizer_cls
    elif scheme == "awq":
        from pare.schemes.awq import AWQQuantizer as quantizer_cls
    else:  # "smoothquant" — the only scheme left after validation above
        from pare.schemes.smoothquant import SmoothQuantQuantizer as quantizer_cls

    quantizer = quantizer_cls(config, layer_bits_override=layer_bits_override)
    return quantizer.quantize_model(model, calibration_data=calibration_data, device=device)
|
pare/__main__.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""CLI entry point: python -m pare <command> [options]
|
|
2
|
+
|
|
3
|
+
Commands
|
|
4
|
+
--------
|
|
5
|
+
quantize Quantize a HuggingFace model and save to disk.
|
|
6
|
+
eval Evaluate a quantized (or FP16) model's perplexity.
|
|
7
|
+
|
|
8
|
+
Examples
|
|
9
|
+
--------
|
|
10
|
+
# GPTQ INT4, 128 calibration sequences, save to ./llama2-7b-int4/
|
|
11
|
+
python -m pare quantize \\
|
|
12
|
+
--model meta-llama/Llama-2-7b-hf \\
|
|
13
|
+
--bits 4 \\
|
|
14
|
+
--scheme gptq \\
|
|
15
|
+
--output ./llama2-7b-int4
|
|
16
|
+
|
|
17
|
+
# AWQ INT4 (awq is also the default --scheme)
|
|
18
|
+
python -m pare quantize \\
|
|
19
|
+
--model meta-llama/Llama-2-7b-hf \\
|
|
20
|
+
--bits 4 \\
|
|
21
|
+
--scheme awq \\
|
|
22
|
+
--output ./llama2-7b-awq
|
|
23
|
+
|
|
24
|
+
# SmoothQuant INT8
|
|
25
|
+
python -m pare quantize \\
|
|
26
|
+
--model meta-llama/Llama-2-7b-hf \\
|
|
27
|
+
--bits 8 \\
|
|
28
|
+
--scheme smoothquant \\
|
|
29
|
+
--output ./llama2-7b-sq-int8
|
|
30
|
+
|
|
31
|
+
# Evaluate a saved quantized model
|
|
32
|
+
python -m pare eval \\
|
|
33
|
+
--model meta-llama/Llama-2-7b-hf \\
|
|
34
|
+
--quantized ./llama2-7b-int4 \\
|
|
35
|
+
--dataset wikitext2
|
|
36
|
+
|
|
37
|
+
# Evaluate FP16 baseline (no --quantized flag)
|
|
38
|
+
python -m pare eval \\
|
|
39
|
+
--model meta-llama/Llama-2-7b-hf \\
|
|
40
|
+
--dataset wikitext2
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
from __future__ import annotations
|
|
44
|
+
|
|
45
|
+
import argparse
|
|
46
|
+
import sys
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _build_parser() -> argparse.ArgumentParser:
|
|
50
|
+
parser = argparse.ArgumentParser(
|
|
51
|
+
prog="python -m pare",
|
|
52
|
+
description="Pare — post-training quantization for LLMs",
|
|
53
|
+
)
|
|
54
|
+
sub = parser.add_subparsers(dest="command", required=True)
|
|
55
|
+
|
|
56
|
+
# ── quantize ──────────────────────────────────────────────────────
|
|
57
|
+
p_q = sub.add_parser("quantize", help="Quantize a model and save to disk")
|
|
58
|
+
p_q.add_argument("--model", required=True, help="HuggingFace model ID or local path")
|
|
59
|
+
p_q.add_argument("--output", required=True, help="Directory to save the quantized model")
|
|
60
|
+
p_q.add_argument("--bits", type=int, default=4, choices=[2, 3, 4, 8],
|
|
61
|
+
help="Target bit-width (default: 4)")
|
|
62
|
+
p_q.add_argument("--scheme", default="awq",
|
|
63
|
+
choices=["rtn", "gptq", "awq", "smoothquant"],
|
|
64
|
+
help="Quantization scheme (default: awq)")
|
|
65
|
+
p_q.add_argument("--granularity", default="per_group",
|
|
66
|
+
choices=["per_tensor", "per_channel", "per_group"],
|
|
67
|
+
help="Scale granularity (default: per_group)")
|
|
68
|
+
p_q.add_argument("--group-size", type=int, default=128,
|
|
69
|
+
help="Group size for per_group granularity (default: 128)")
|
|
70
|
+
p_q.add_argument("--sym", action="store_true",
|
|
71
|
+
help="Symmetric quantization (no zero-point)")
|
|
72
|
+
p_q.add_argument("--smooth-alpha", type=float, default=0.5,
|
|
73
|
+
help="SmoothQuant migration strength α (default: 0.5)")
|
|
74
|
+
p_q.add_argument("--n-calib", type=int, default=128,
|
|
75
|
+
help="Number of calibration sequences (default: 128)")
|
|
76
|
+
p_q.add_argument("--calib-seqlen", type=int, default=2048,
|
|
77
|
+
help="Length of each calibration sequence (default: 2048)")
|
|
78
|
+
p_q.add_argument("--device", default="cuda",
|
|
79
|
+
help="Device for calibration forward passes (default: cuda)")
|
|
80
|
+
p_q.add_argument("--dtype", default="float16", choices=["float16", "bfloat16", "float32"],
|
|
81
|
+
help="Model load dtype (default: float16)")
|
|
82
|
+
p_q.add_argument("--token", default=None,
|
|
83
|
+
help="HuggingFace access token for gated models")
|
|
84
|
+
|
|
85
|
+
# ── eval ──────────────────────────────────────────────────────────
|
|
86
|
+
p_e = sub.add_parser("eval", help="Evaluate perplexity of a model")
|
|
87
|
+
p_e.add_argument("--model", required=True, help="HuggingFace model ID or local path")
|
|
88
|
+
p_e.add_argument("--quantized", default=None,
|
|
89
|
+
help="Path to a Pare-saved quantized model (omit for FP16 baseline)")
|
|
90
|
+
p_e.add_argument("--dataset", default="wikitext2",
|
|
91
|
+
choices=["wikitext2", "c4"],
|
|
92
|
+
help="Evaluation dataset (default: wikitext2)")
|
|
93
|
+
p_e.add_argument("--seq-len", type=int, default=2048,
|
|
94
|
+
help="Sequence length (default: 2048)")
|
|
95
|
+
p_e.add_argument("--n-samples", type=int, default=None,
|
|
96
|
+
help="Number of sequences to evaluate (default: full dataset)")
|
|
97
|
+
p_e.add_argument("--device", default="cuda")
|
|
98
|
+
p_e.add_argument("--dtype", default="float16", choices=["float16", "bfloat16", "float32"])
|
|
99
|
+
p_e.add_argument("--token", default=None)
|
|
100
|
+
|
|
101
|
+
return parser
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def cmd_quantize(args: argparse.Namespace) -> None:
    """Handle ``python -m pare quantize``.

    Loads ``args.model`` and its tokenizer, builds a ``QuantConfig`` from the
    CLI flags, and prepares WikiText-2 calibration windows for every scheme
    except RTN. It then quantizes the model and saves it to ``args.output``.

    Raises:
        SystemExit: If ``--n-calib`` is not positive (non-RTN schemes), or if
            the calibration corpus cannot yield a single full
            ``--calib-seqlen`` window.
    """
    import torch
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from pare import QuantConfig, quantize, save_quantized

    # Fail fast on impossible calibration settings, before the costly model load.
    if args.scheme != "rtn" and args.n_calib <= 0:
        raise SystemExit("[pare] --n-calib must be a positive integer.")

    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    torch_dtype = dtype_map[args.dtype]

    print(f"[pare] Loading {args.model} ...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model, token=args.token)
    model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype=torch_dtype, token=args.token, device_map="auto",
    )
    model.eval()

    config = QuantConfig(
        bits=args.bits,
        scheme=args.scheme,
        granularity=args.granularity,
        group_size=args.group_size,
        sym=args.sym,
        smooth_alpha=args.smooth_alpha,
    )

    calib_data = None
    if args.scheme != "rtn":
        print(f"[pare] Preparing {args.n_calib} calibration sequences ...", flush=True)
        ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train")
        text = "\n\n".join(ds["text"])
        tokens = tokenizer(text, return_tensors="pt").input_ids
        seq_len = args.calib_seqlen
        n_tokens = tokens.shape[1]

        # Only full-length windows are usable. Slicing past the end of the
        # corpus would produce short or empty tensors, which are useless or
        # inconsistent with the other calibration sequences.
        n_windows = n_tokens // seq_len if seq_len > 0 else 0
        if n_windows == 0:
            raise SystemExit(
                f"[pare] Calibration corpus has {n_tokens} tokens; "
                f"cannot cut any window of --calib-seqlen={seq_len}."
            )
        n_calib = min(args.n_calib, n_windows)
        if n_calib < args.n_calib:
            print(
                f"[pare] Warning: only {n_windows} full calibration windows available; "
                f"using {n_calib} instead of {args.n_calib}.",
                flush=True,
            )
        calib_data = [tokens[:, i * seq_len : (i + 1) * seq_len] for i in range(n_calib)]

    print(f"[pare] Quantizing with {args.scheme} INT{args.bits} ...", flush=True)
    quantize(model, config, calibration_data=calib_data, device=args.device)

    save_quantized(model, args.output)
    print(f"[pare] Done. Model saved to {args.output}", flush=True)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def cmd_eval(args: argparse.Namespace) -> None:
    """Handle ``python -m pare eval``: report perplexity of a model.

    The base model is loaded from ``args.model`` in ``args.dtype``. When
    ``args.quantized`` is set, ``load_quantized`` is applied to that model
    before perplexity is measured on ``args.dataset``.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from pare import load_quantized
    from pare.eval.perplexity import evaluate_perplexity

    load_dtype = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }[args.dtype]

    model_id, auth = args.model, args.token
    print(f"[pare] Loading {model_id} ...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=load_dtype,
        token=auth,
        device_map="auto",
    )
    model.eval()

    quant_dir = args.quantized
    if quant_dir:
        print(f"[pare] Loading quantized weights from {quant_dir} ...", flush=True)
        load_quantized(model, quant_dir)

    print(f"[pare] Evaluating PPL on {args.dataset} ...", flush=True)
    ppl = evaluate_perplexity(
        model,
        tokenizer,
        dataset=args.dataset,
        seq_len=args.seq_len,
        n_samples=args.n_samples,
        device=args.device,
    )

    # Label the result by the quantized checkpoint path, or by the load dtype
    # for an unquantized baseline.
    label = f"{quant_dir or args.dtype}"
    print(f"\n {args.model} [{label}] {args.dataset} PPL: {ppl:.2f}")
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def main() -> None:
    """Parse ``sys.argv`` and dispatch to the selected sub-command handler."""
    parser = _build_parser()
    args = parser.parse_args()

    handlers = {"quantize": cmd_quantize, "eval": cmd_eval}
    handler = handlers.get(args.command)
    if handler is None:
        # Safety net: should not normally be reached, because the parser
        # makes the sub-command required.
        parser.print_help()
        sys.exit(1)
    handler(args)


if __name__ == "__main__":
    main()
|
|
File without changes
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""Per-layer Hessian accumulation for GPTQ.
|
|
2
|
+
|
|
3
|
+
The GPTQ objective for one weight row w is:
|
|
4
|
+
|
|
5
|
+
(w - q)^T H (w - q)
|
|
6
|
+
|
|
7
|
+
where H = 2/n * X X^T, X shaped [in_features, n_tokens].
|
|
8
|
+
|
|
9
|
+
We accumulate H incrementally over batches so we never store all
|
|
10
|
+
activations in memory — only the running [in_features, in_features] sum.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
from torch import Tensor
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class HessianAccumulator:
    """Accumulates the per-layer Hessian ``H = 2/n * Σ xᵀx`` online.

    Each ``x`` is one token's input activation row (``in_features`` long).
    The raw sum of outer products is kept in float32; the ``2/n``
    normalisation is applied only in :meth:`finalize`.

    Usage::

        acc = HessianAccumulator()
        for batch_activations in ...:
            acc.accumulate(batch_activations)
        H = acc.finalize()  # [in_features, in_features]
    """

    def __init__(self) -> None:
        # Un-normalised running sum Σ xᵀx. It is allocated lazily on the first
        # batch, so its size and device follow the observed activations.
        self.H: Tensor | None = None
        # Total number of token rows folded into ``H`` so far.
        self.n_tokens: int = 0

    def accumulate(self, x: Tensor) -> None:
        """Add one batch of activations to the running sum.

        Args:
            x: Input activations with features on the last dim, e.g.
                ``[batch, seq_len, in_features]`` or ``[batch, in_features]``.
                Any extra leading dims are flattened into the token dim. The
                tensor is cast to float32 and moved to the running sum's device.

        Raises:
            ValueError: If ``x`` has fewer than 2 dims, or its feature size
                differs from earlier batches.
        """
        x = x.detach().float()
        if x.dim() < 2:
            raise ValueError(
                f"Expected an activation with at least 2 dims, got shape {tuple(x.shape)}"
            )

        # Flatten batch / sequence (and any further leading) dims → [n_tokens, in_features]
        if x.dim() > 2:
            x = x.reshape(-1, x.shape[-1])
        n_tokens, in_features = x.shape

        if self.H is None:
            self.H = torch.zeros(
                in_features, in_features,
                device=x.device, dtype=torch.float32,
            )
        elif self.H.shape[0] != in_features:
            raise ValueError(
                f"Activation has {in_features} features but earlier batches had "
                f"{self.H.shape[0]}"
            )
        # Keep the sum on a single device even if batches arrive from elsewhere.
        x = x.to(self.H.device)

        # H_unnorm += X_row^T @ X_row (= X_col @ X_col^T where X_col = x.T)
        self.H.addmm_(x.T, x)  # in-place fused multiply-add, no extra alloc
        self.n_tokens += n_tokens

    def finalize(self) -> Tensor:
        """Return the normalised Hessian ``H = 2/n * Σ xᵀx``.

        Returns a new tensor; the running sum is left untouched, so more
        batches may still be accumulated afterwards.

        Raises:
            RuntimeError: If no samples have been accumulated.
        """
        if self.H is None or self.n_tokens == 0:
            raise RuntimeError("HessianAccumulator has no samples — did you forget to call accumulate()?")
        return self.H * (2.0 / self.n_tokens)

    def reset(self) -> None:
        """Drop the running sum (releasing its memory) and the token count."""
        self.H = None
        self.n_tokens = 0
|
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
"""Layerwise GPTQ for transformer block models (Llama, Mistral, Qwen, etc.).
|
|
2
|
+
|
|
3
|
+
Problem: For 7B+ models, accumulating Hessians for all layers simultaneously
|
|
4
|
+
requires ~39 GB on top of the 14 GB model — OOM on 48 GB GPUs.
|
|
5
|
+
|
|
6
|
+
Solution: process one transformer block at a time:
|
|
7
|
+
1. Intercept the first block with a Catcher to capture embedding outputs.
|
|
8
|
+
2. Precompute position_embeddings once from the shared rotary_emb.
|
|
9
|
+
3. For each block i:
|
|
10
|
+
a. Load block to GPU.
|
|
11
|
+
b. Register hooks on its linear sublayers; run all calibration inputs.
|
|
12
|
+
c. GPTQ-quantize each sublayer immediately (peak VRAM: ~2 GB).
|
|
13
|
+
d. Collect outputs (= inputs to block i+1).
|
|
14
|
+
e. Offload block to CPU.
|
|
15
|
+
|
|
16
|
+
Supported architectures: any HF model with model.model.layers (Llama, Mistral,
|
|
17
|
+
Qwen2, Phi-3, Gemma, ...).
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import torch
|
|
23
|
+
import torch.nn as nn
|
|
24
|
+
from torch import Tensor
|
|
25
|
+
|
|
26
|
+
from pare.calibration.hessian import HessianAccumulator
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def is_supported(model: nn.Module) -> bool:
    """Report whether ``model`` has the block layout the layerwise path needs.

    That layout is a ``model.model.layers`` container that exists, is
    non-empty, and whose first entry is an ``nn.Module``.
    """
    if not hasattr(model, "model"):
        return False
    inner = model.model
    if not hasattr(inner, "layers"):
        return False
    blocks = inner.layers
    if len(blocks) == 0:
        return False
    return isinstance(blocks[0], nn.Module)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class LayerwiseGPTQ:
    """Block-by-block GPTQ. Keeps peak GPU usage to O(one_block + calib_acts)."""

    def run(
        self,
        model: nn.Module,
        calibration_data: list[Tensor],
        quantizer,
        device: str | torch.device,
    ) -> nn.Module:
        """Quantize model in-place using layerwise Hessian collection.

        Args:
            model: Model to quantize (Llama/Mistral-family HF model exposing
                ``model.model.layers``).
            calibration_data: Non-empty list of input_ids tensors [1, seq_len].
                Sequences may differ in length; position embeddings are
                precomputed once per distinct length.
            quantizer: GPTQQuantizer instance. ``_should_quantize`` selects the
                sublayers, and ``quantize_layer`` is called for each one while
                its Hessian is exposed as ``quantizer._hessians[name]``.
            device: Device for calibration forward passes, e.g. "cuda".

        Returns:
            The same model, quantized in-place and left on CPU.

        Raises:
            ValueError: If ``calibration_data`` is empty.
            RuntimeError: If no inputs to the first decoder layer were captured.
        """
        if len(calibration_data) == 0:
            raise ValueError("calibration_data must contain at least one sequence")

        device = torch.device(device)
        layers = model.model.layers

        # ── 1. Capture embedding outputs (inputs to layer 0) ───────────────
        inps = _capture_embeddings(model, layers, calibration_data, device)
        if not inps:
            raise RuntimeError(
                "No inputs to the first decoder layer were captured; "
                "is model.model.layers[0] called during the forward pass?"
            )

        # ── 2. Precompute position embeddings once per sequence length ─────
        # This must happen while the model (and its rotary module) is still on
        # `device`; the whole model is offloaded right after.
        pes: dict[int, tuple[Tensor, Tensor] | None] = {}
        for x in inps:
            length = x.shape[1]
            if length not in pes:
                pes[length] = _precompute_pe(model, layers, [x], length, device)
        model.cpu()
        torch.cuda.empty_cache()

        # ── 3. Resolve the dotted path to model.model.layers ──────────────
        layers_path = _find_layers_path(model)

        # ── 4. Process one block at a time ─────────────────────────────────
        n_layers = len(layers)
        for li, layer in enumerate(layers):
            layer_prefix = f"{layers_path}.{li}"
            layer.to(device)

            subs: dict[str, nn.Linear] = {}
            accs: dict[str, HessianAccumulator] = {}
            hooks = []
            try:
                # Hook every quantizable linear sublayer in this block.
                for rel_name, mod in layer.named_modules():
                    if not isinstance(mod, nn.Linear):
                        continue
                    full_name = f"{layer_prefix}.{rel_name}"
                    if not quantizer._should_quantize(full_name, mod):
                        continue
                    acc = HessianAccumulator()
                    subs[full_name] = mod
                    accs[full_name] = acc
                    hooks.append(mod.register_forward_hook(_make_hook(acc)))

                with torch.no_grad():
                    for x in inps:
                        _call_layer(layer, x.to(device), pes[x.shape[1]], device)
            finally:
                # Always detach the hooks, even if a forward pass failed, so the
                # model is not left with stray accumulators attached.
                for h in hooks:
                    h.remove()

            # All of this block's accumulators were live during the pass above;
            # release each one as soon as its sublayer has been quantized.
            for full_name, linear in subs.items():
                acc = accs.pop(full_name)
                H = acc.finalize().cpu()
                acc.reset()  # free the raw device-side sum now, not at block end
                del acc
                # Expose the Hessian so quantize_layer can find it.
                quantizer._hessians[full_name] = H
                try:
                    q_layer = quantizer.quantize_layer(linear, full_name)
                finally:
                    quantizer._hessians.pop(full_name, None)
                _set_submodule(layer, full_name[len(layer_prefix) + 1:], q_layer)
                del H
                torch.cuda.empty_cache()

            # Collect this block's outputs → inputs for the next block.
            outs: list[Tensor] = []
            with torch.no_grad():
                for x in inps:
                    out = _call_layer(layer, x.to(device), pes[x.shape[1]], device)
                    outs.append(out.cpu())
            inps = outs

            layer.cpu()
            torch.cuda.empty_cache()
            print(f"[pare] layer {li + 1}/{n_layers} quantized", flush=True)

        return model
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
# ---------------------------------------------------------------------------
|
|
132
|
+
# Internal helpers
|
|
133
|
+
# ---------------------------------------------------------------------------
|
|
134
|
+
|
|
135
|
+
class _Catcher(nn.Module):
|
|
136
|
+
"""Replaces layer 0 temporarily to intercept embedding outputs."""
|
|
137
|
+
def __init__(self, layer: nn.Module, store: list[Tensor]) -> None:
|
|
138
|
+
super().__init__()
|
|
139
|
+
self._layer = layer
|
|
140
|
+
self._store = store
|
|
141
|
+
|
|
142
|
+
def forward(self, x: Tensor, **_kw) -> Tensor:
|
|
143
|
+
self._store.append(x.cpu())
|
|
144
|
+
raise StopIteration
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _capture_embeddings(
    model: nn.Module,
    layers: nn.ModuleList,
    calibration_data: list[Tensor],
    device: torch.device,
) -> list[Tensor]:
    """Collect the hidden states that reach decoder block 0.

    Block 0 is temporarily swapped for a ``_Catcher``. The whole model is
    moved to ``device`` (and left there), and every calibration sequence is
    pushed through it until the Catcher records its input and aborts with
    ``StopIteration``. The original block 0 is restored even if a forward
    pass raises something else.

    Returns:
        The hidden states captured by the Catcher, stored on CPU.
    """
    captured: list[Tensor] = []
    first_block = layers[0]
    layers[0] = _Catcher(first_block, captured)
    model.to(device)
    try:
        with torch.no_grad():
            for ids in calibration_data:
                try:
                    model(ids.to(device))
                except StopIteration:
                    # Expected: the Catcher aborts the forward after recording.
                    continue
    finally:
        layers[0] = first_block
    return captured
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def _precompute_pe(
|
|
171
|
+
model: nn.Module,
|
|
172
|
+
layers: nn.ModuleList,
|
|
173
|
+
inps: list[Tensor],
|
|
174
|
+
seq_len: int,
|
|
175
|
+
device: torch.device,
|
|
176
|
+
) -> tuple[Tensor, Tensor] | None:
|
|
177
|
+
"""Precompute (cos, sin) position embeddings for seq_len tokens.
|
|
178
|
+
|
|
179
|
+
In transformers >= 4.43, rotary_emb is a shared module at model.model
|
|
180
|
+
level. Older versions attach it per attention layer. Returns None if
|
|
181
|
+
the model doesn't use RoPE.
|
|
182
|
+
"""
|
|
183
|
+
rotary = getattr(model.model, "rotary_emb", None)
|
|
184
|
+
if rotary is None:
|
|
185
|
+
# Try old-style per-layer rotary (transformers < 4.43).
|
|
186
|
+
rotary = getattr(getattr(layers[0], "self_attn", None), "rotary_emb", None)
|
|
187
|
+
if rotary is None:
|
|
188
|
+
return None
|
|
189
|
+
|
|
190
|
+
with torch.no_grad():
|
|
191
|
+
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
|
|
192
|
+
cos, sin = rotary(inps[0].to(device), position_ids)
|
|
193
|
+
return cos.cpu(), sin.cpu()
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _call_layer(
|
|
197
|
+
layer: nn.Module,
|
|
198
|
+
x: Tensor,
|
|
199
|
+
pe: tuple[Tensor, Tensor] | None,
|
|
200
|
+
device: torch.device,
|
|
201
|
+
) -> Tensor:
|
|
202
|
+
"""Call a decoder layer, handling both new and old transformers APIs.
|
|
203
|
+
|
|
204
|
+
- New (>= 4.46): requires position_embeddings=(cos, sin) kwarg; returns Tensor.
|
|
205
|
+
- Old (< 4.46): uses position_ids internally; returns (hidden_states, ...) tuple.
|
|
206
|
+
"""
|
|
207
|
+
kwargs: dict = {"use_cache": False}
|
|
208
|
+
if pe is not None:
|
|
209
|
+
kwargs["position_embeddings"] = (pe[0].to(device), pe[1].to(device))
|
|
210
|
+
|
|
211
|
+
out = layer(x, **kwargs)
|
|
212
|
+
|
|
213
|
+
# Normalise output: newer transformers returns a plain Tensor,
|
|
214
|
+
# older versions return a tuple (hidden_states, ...).
|
|
215
|
+
if isinstance(out, tuple):
|
|
216
|
+
out = out[0]
|
|
217
|
+
return out
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def _find_layers_path(model: nn.Module) -> str:
|
|
221
|
+
"""Return the dotted name of model.model.layers in named_modules()."""
|
|
222
|
+
target = model.model.layers
|
|
223
|
+
for name, mod in model.named_modules():
|
|
224
|
+
if mod is target:
|
|
225
|
+
return name
|
|
226
|
+
raise RuntimeError("Could not find model.model.layers in named_modules()")
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def _set_submodule(parent: nn.Module, rel_path: str, new_mod: nn.Module) -> None:
|
|
230
|
+
"""Set parent.a.b.c = new_mod given rel_path='a.b.c'."""
|
|
231
|
+
parts = rel_path.split(".")
|
|
232
|
+
m = parent
|
|
233
|
+
for part in parts[:-1]:
|
|
234
|
+
m = getattr(m, part)
|
|
235
|
+
setattr(m, parts[-1], new_mod)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def _make_hook(acc: HessianAccumulator):
|
|
239
|
+
def hook(module: nn.Module, inputs: tuple, output: Tensor) -> None:
|
|
240
|
+
acc.accumulate(inputs[0].detach())
|
|
241
|
+
return hook
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
"""Per-channel activation magnitude observer for AWQ.
|
|
2
|
+
|
|
3
|
+
AWQ needs to know which input channels carry large activations so it can
|
|
4
|
+
protect those weight columns during quantization. This module accumulates
|
|
5
|
+
the per-channel mean |x| online — just like HessianAccumulator but O(in)
|
|
6
|
+
instead of O(in²).
|
|
7
|
+
|
|
8
|
+
Usage::
|
|
9
|
+
|
|
10
|
+
obs = ActivationObserver()
|
|
11
|
+
hook = layer.register_forward_hook(
|
|
12
|
+
lambda m, inp, out: obs.accumulate(inp[0])
|
|
13
|
+
)
|
|
14
|
+
# ... run calibration forward passes ...
|
|
15
|
+
hook.remove()
|
|
16
|
+
x_mean = obs.finalize()  # [in_features] mean |x| per input channel
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
import torch
|
|
22
|
+
from torch import Tensor
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ActivationObserver:
    """Online accumulator for per-channel activation statistics.

    Tracks both mean |x| (used by AWQ) and max |x| (used by SmoothQuant)
    in a single pass over calibration data.

    Accepts any number of batches with features on the last dim, e.g.
    ``[batch, seq_len, in_features]`` or ``[batch, in_features]``. Extra
    leading dims are flattened into the token dim.
    """

    def __init__(self) -> None:
        # Running per-channel Σ|x|, allocated lazily on the first non-empty batch.
        self._sum: Tensor | None = None
        # Running per-channel max |x| (zero-initialised: |x| is never negative).
        self._max: Tensor | None = None
        # Number of token rows folded into the statistics so far.
        self._n_tokens: int = 0

    def accumulate(self, x: Tensor) -> None:
        """Fold one batch of activations into the running statistics.

        Empty batches are ignored. Inputs are cast to float32 and moved to
        the device of the running statistics.

        Raises:
            ValueError: If ``x`` has fewer than 2 dims, or its feature size
                differs from earlier batches.
        """
        x = x.detach().float()
        if x.dim() < 2:
            raise ValueError(f"Expected an activation with at least 2 dims, got {tuple(x.shape)}")
        if x.dim() > 2:
            x = x.reshape(-1, x.shape[-1])  # [n_tokens, in_features]
        if x.shape[0] == 0:
            # amax over an empty dim raises, and an empty batch adds no information.
            return

        n_tokens, in_features = x.shape
        if self._sum is None:
            self._sum = torch.zeros(in_features, device=x.device, dtype=torch.float32)
            self._max = torch.zeros(in_features, device=x.device, dtype=torch.float32)
        elif self._sum.shape[0] != in_features:
            raise ValueError(
                f"Activation has {in_features} features but earlier batches had "
                f"{self._sum.shape[0]}"
            )

        x_abs = x.to(self._sum.device).abs()
        self._sum.add_(x_abs.sum(dim=0))
        torch.maximum(self._max, x_abs.amax(dim=0), out=self._max)
        self._n_tokens += n_tokens

    def finalize(self) -> Tensor:
        """Return mean |x| per input channel, shape [in_features]. Used by AWQ."""
        if self._sum is None or self._n_tokens == 0:
            raise RuntimeError("ActivationObserver has no samples")
        return self._sum / self._n_tokens

    def max_abs(self) -> Tensor:
        """Return max |x| per input channel, shape [in_features]. Used by SmoothQuant."""
        if self._max is None or self._n_tokens == 0:
            raise RuntimeError("ActivationObserver has no samples")
        return self._max.clone()

    def reset(self) -> None:
        """Discard all accumulated statistics."""
        self._sum = None
        self._max = None
        self._n_tokens = 0
|