pare-quant 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pare/__init__.py ADDED
@@ -0,0 +1,99 @@
1
+ """Pare — production-ready quantization for large language and multimodal models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch.nn as nn
6
+
7
+ from pare.config import QuantConfig
8
+ from pare.core.dtype import QuantDtype
9
+ from pare.model.io import load_quantized, save_quantized
10
+
11
__version__ = "0.1.0"

# Names re-exported at package level; this is the supported public API.
__all__ = [
    "QuantConfig",
    "QuantDtype",
    "quantize",
    "save_quantized",
    "load_quantized",
    "__version__",
]
17
+
18
+
19
def quantize(
    model: nn.Module,
    config: QuantConfig,
    calibration_data: list | None = None,
    device: str = "cpu",
) -> nn.Module:
    """Quantize a model in-place using the scheme specified in ``config``.

    Args:
        model: Any ``nn.Module`` (HuggingFace, custom, etc.).
        config: A ``QuantConfig`` describing the target dtype, scheme,
            granularity, and scheme-specific hyperparameters.
        calibration_data: Required for GPTQ, AWQ, and SmoothQuant, and for
            mixed-precision sensitivity scoring (``config.sensitive_bits``).
            A list of input_ids tensors, each shaped [batch, seq_len].
            Not used by the RTN quantizer itself.
        device: Device for calibration forward passes (sensitivity scoring,
            GPTQ, AWQ and SmoothQuant).

    Returns:
        The same model object with ``nn.Linear`` layers replaced by
        ``QuantizedLinear`` instances.

    Raises:
        ValueError: If ``config.scheme`` is not one of ``"rtn"``, ``"gptq"``,
            ``"awq"``, ``"smoothquant"``, or if a calibration-based scheme is
            requested without (or with empty) ``calibration_data``.

    Examples::

        # RTN — no calibration needed
        from pare import quantize, QuantConfig
        config = QuantConfig(bits=4, scheme="rtn", group_size=128)
        model = quantize(model, config)

        # GPTQ — calibration data required
        config = QuantConfig(bits=4, scheme="gptq", group_size=128)
        model = quantize(model, config, calibration_data=calib_ids, device="cuda")
    """
    import warnings

    calibrated_schemes = ("gptq", "awq", "smoothquant")
    scheme = config.scheme

    # Validate before doing any work: sensitivity scoring below runs full
    # calibration forward passes, so a bad scheme or missing data should
    # fail fast instead of after minutes of computation.
    if scheme != "rtn" and scheme not in calibrated_schemes:
        raise ValueError(f"Unknown scheme: {scheme!r}")
    if scheme in calibrated_schemes and (
        calibration_data is None or len(calibration_data) == 0
    ):
        raise ValueError(
            f"Scheme {scheme!r} requires calibration_data "
            "(a non-empty list of input_ids tensors)"
        )

    # Mixed-precision sensitivity: score layers BEFORE quantization so we
    # still have the original FP16 weights. RTN is used as a fast proxy.
    layer_bits_override: dict[str, int] = {}
    if config.sensitive_bits is not None:
        if calibration_data is None:
            # Only reachable for RTN (calibrated schemes were rejected above).
            warnings.warn(
                "config.sensitive_bits is set but calibration_data is None; "
                "skipping sensitivity scoring, all layers use config.bits",
                stacklevel=2,
            )
        else:
            from pare.sensitivity import score_layers

            scores = score_layers(
                model,
                calibration_data,
                bits=config.bits,
                granularity=config.granularity,
                group_size=config.group_size,
                device=device,
            )
            layer_bits_override = {
                name: config.sensitive_bits
                for name, err in scores.items()
                if err > config.sensitivity_threshold
            }
            n_total = len(scores)
            n_sensitive = len(layer_bits_override)
            print(
                f"[pare] Sensitivity: {n_total} layers scored, "
                f"{n_sensitive} above {config.sensitivity_threshold:.0%} threshold "
                f"→ {config.bits}-bit→{config.sensitive_bits}-bit"
            )

    if scheme == "rtn":
        from pare.schemes.rtn import RTNQuantizer

        quantizer = RTNQuantizer(config, layer_bits_override=layer_bits_override)
        return quantizer.quantize_model(model)

    # Remaining schemes share one constructor/call signature.
    if scheme == "gptq":
        from pare.schemes.gptq import GPTQQuantizer as quantizer_cls
    elif scheme == "awq":
        from pare.schemes.awq import AWQQuantizer as quantizer_cls
    else:  # "smoothquant" — the only value left after validation above
        from pare.schemes.smoothquant import SmoothQuantQuantizer as quantizer_cls

    quantizer = quantizer_cls(config, layer_bits_override=layer_bits_override)
    return quantizer.quantize_model(model, calibration_data=calibration_data, device=device)
pare/__main__.py ADDED
@@ -0,0 +1,191 @@
1
+ """CLI entry point: python -m pare <command> [options]
2
+
3
+ Commands
4
+ --------
5
+ quantize Quantize a HuggingFace model and save to disk.
6
+ eval Evaluate a quantized (or FP16) model's perplexity.
7
+
8
+ Examples
9
+ --------
10
+ # GPTQ INT4, 128 calibration sequences, save to ./llama2-7b-int4/
11
+ python -m pare quantize \\
12
+ --model meta-llama/Llama-2-7b-hf \\
13
+ --bits 4 \\
14
+ --scheme gptq \\
15
+ --output ./llama2-7b-int4
16
+
17
+ # AWQ INT4 (the default scheme, default settings)
18
+ python -m pare quantize \\
19
+ --model meta-llama/Llama-2-7b-hf \\
20
+ --bits 4 \\
21
+ --scheme awq \\
22
+ --output ./llama2-7b-awq
23
+
24
+ # SmoothQuant INT8
25
+ python -m pare quantize \\
26
+ --model meta-llama/Llama-2-7b-hf \\
27
+ --bits 8 \\
28
+ --scheme smoothquant \\
29
+ --output ./llama2-7b-sq-int8
30
+
31
+ # Evaluate a saved quantized model
32
+ python -m pare eval \\
33
+ --model meta-llama/Llama-2-7b-hf \\
34
+ --quantized ./llama2-7b-int4 \\
35
+ --dataset wikitext2
36
+
37
+ # Evaluate FP16 baseline (no --quantized flag)
38
+ python -m pare eval \\
39
+ --model meta-llama/Llama-2-7b-hf \\
40
+ --dataset wikitext2
41
+ """
42
+
43
+ from __future__ import annotations
44
+
45
+ import argparse
46
+ import sys
47
+
48
+
49
def _positive_int(text: str) -> int:
    """``argparse`` type callback: parse ``text`` as an integer >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def _build_parser() -> argparse.ArgumentParser:
    """Build the ``python -m pare`` parser with ``quantize`` and ``eval`` sub-commands.

    Counts and sizes (group size, number of calibration/eval sequences,
    sequence lengths) are validated as positive integers at parse time, so
    bad values produce an argparse usage error instead of a failure deep
    inside calibration.
    """
    dtype_choices = ["float16", "bfloat16", "float32"]

    parser = argparse.ArgumentParser(
        prog="python -m pare",
        description="Pare — post-training quantization for LLMs",
    )
    sub = parser.add_subparsers(dest="command", required=True)

    # ── quantize ──────────────────────────────────────────────────────
    p_q = sub.add_parser("quantize", help="Quantize a model and save to disk")
    p_q.add_argument("--model", required=True, help="HuggingFace model ID or local path")
    p_q.add_argument("--output", required=True, help="Directory to save the quantized model")
    p_q.add_argument("--bits", type=int, default=4, choices=[2, 3, 4, 8],
                     help="Target bit-width (default: 4)")
    p_q.add_argument("--scheme", default="awq",
                     choices=["rtn", "gptq", "awq", "smoothquant"],
                     help="Quantization scheme (default: awq)")
    p_q.add_argument("--granularity", default="per_group",
                     choices=["per_tensor", "per_channel", "per_group"],
                     help="Scale granularity (default: per_group)")
    p_q.add_argument("--group-size", type=_positive_int, default=128,
                     help="Group size for per_group granularity (default: 128)")
    p_q.add_argument("--sym", action="store_true",
                     help="Symmetric quantization (no zero-point)")
    p_q.add_argument("--smooth-alpha", type=float, default=0.5,
                     help="SmoothQuant migration strength α (default: 0.5)")
    p_q.add_argument("--n-calib", type=_positive_int, default=128,
                     help="Number of calibration sequences (default: 128)")
    p_q.add_argument("--calib-seqlen", type=_positive_int, default=2048,
                     help="Length of each calibration sequence (default: 2048)")
    p_q.add_argument("--device", default="cuda",
                     help="Device for calibration forward passes (default: cuda)")
    p_q.add_argument("--dtype", default="float16", choices=dtype_choices,
                     help="Model load dtype (default: float16)")
    p_q.add_argument("--token", default=None,
                     help="HuggingFace access token for gated models")

    # ── eval ──────────────────────────────────────────────────────────
    p_e = sub.add_parser("eval", help="Evaluate perplexity of a model")
    p_e.add_argument("--model", required=True, help="HuggingFace model ID or local path")
    p_e.add_argument("--quantized", default=None,
                     help="Path to a Pare-saved quantized model (omit for FP16 baseline)")
    p_e.add_argument("--dataset", default="wikitext2",
                     choices=["wikitext2", "c4"],
                     help="Evaluation dataset (default: wikitext2)")
    p_e.add_argument("--seq-len", type=_positive_int, default=2048,
                     help="Sequence length (default: 2048)")
    p_e.add_argument("--n-samples", type=_positive_int, default=None,
                     help="Number of sequences to evaluate (default: full dataset)")
    p_e.add_argument("--device", default="cuda",
                     help="Device for evaluation forward passes (default: cuda)")
    p_e.add_argument("--dtype", default="float16", choices=dtype_choices,
                     help="Model load dtype (default: float16)")
    p_e.add_argument("--token", default=None,
                     help="HuggingFace access token for gated models")

    return parser
102
+
103
+
104
def cmd_quantize(args: argparse.Namespace) -> None:
    """Run the ``quantize`` sub-command: load, calibrate, quantize, save.

    Loads ``args.model`` and its tokenizer. For every scheme except RTN,
    builds ``args.n_calib`` consecutive, non-overlapping windows of
    ``args.calib_seqlen`` tokens from the WikiText-2 train split. Then
    quantizes the model in-place and writes it to ``args.output`` via
    ``save_quantized``.

    Raises:
        SystemExit: If the tokenized calibration corpus holds fewer than
            ``n_calib * calib_seqlen`` tokens.
    """
    import torch
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from pare import QuantConfig, quantize, save_quantized

    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    torch_dtype = dtype_map[args.dtype]

    print(f"[pare] Loading {args.model} ...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model, token=args.token)
    model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype=torch_dtype, token=args.token, device_map="auto",
    )
    model.eval()

    config = QuantConfig(
        bits=args.bits,
        scheme=args.scheme,
        granularity=args.granularity,
        group_size=args.group_size,
        sym=args.sym,
        smooth_alpha=args.smooth_alpha,
    )

    calib_data = None
    if args.scheme != "rtn":
        print(f"[pare] Preparing {args.n_calib} calibration sequences ...", flush=True)
        ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train")
        text = "\n\n".join(ds["text"])
        tokens = tokenizer(text, return_tensors="pt").input_ids
        seq_len = args.calib_seqlen
        needed = args.n_calib * seq_len
        available = tokens.shape[1]
        # Slicing past the end of `tokens` does not raise; it yields short or
        # empty [1, 0] windows, which would silently corrupt calibration.
        if needed > available:
            raise SystemExit(
                f"[pare] Calibration corpus has {available} tokens but "
                f"{args.n_calib} x {seq_len} = {needed} were requested; "
                "lower --n-calib or --calib-seqlen."
            )
        calib_data = [tokens[:, i * seq_len : (i + 1) * seq_len] for i in range(args.n_calib)]

    print(f"[pare] Quantizing with {args.scheme} INT{args.bits} ...", flush=True)
    quantize(model, config, calibration_data=calib_data, device=args.device)

    save_quantized(model, args.output)
    print(f"[pare] Done. Model saved to {args.output}", flush=True)
144
+
145
+
146
def cmd_eval(args: argparse.Namespace) -> None:
    """Run the ``eval`` sub-command and print the model's perplexity.

    The base model is loaded from ``args.model``. When ``args.quantized`` is
    given, Pare-saved quantized weights are applied on top with
    ``load_quantized`` before evaluation; otherwise the un-quantized model is
    evaluated. The final line is labelled with the quantized path, or with
    the load dtype for the baseline.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from pare import load_quantized
    from pare.eval.perplexity import evaluate_perplexity

    torch_dtype = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }[args.dtype]

    print(f"[pare] Loading {args.model} ...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model, token=args.token)
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch_dtype,
        token=args.token,
        device_map="auto",
    )
    model.eval()

    if args.quantized:
        print(f"[pare] Loading quantized weights from {args.quantized} ...", flush=True)
        load_quantized(model, args.quantized)

    print(f"[pare] Evaluating PPL on {args.dataset} ...", flush=True)
    ppl = evaluate_perplexity(
        model,
        tokenizer,
        dataset=args.dataset,
        seq_len=args.seq_len,
        n_samples=args.n_samples,
        device=args.device,
    )

    label = str(args.quantized) if args.quantized else str(args.dtype)
    print(f"\n {args.model} [{label}] {args.dataset} PPL: {ppl:.2f}")
175
+
176
+
177
def main(argv: list[str] | None = None) -> None:
    """Entry point for ``python -m pare``.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, so existing
            ``main()`` callers are unaffected; passing a list allows
            programmatic invocation and testing.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    handlers = {"quantize": cmd_quantize, "eval": cmd_eval}
    handler = handlers.get(args.command)
    if handler is None:
        # Defensive: unreachable while the subparsers are `required=True`.
        parser.print_help()
        sys.exit(1)
    handler(args)


if __name__ == "__main__":
    main()
File without changes
@@ -0,0 +1,73 @@
1
+ """Per-layer Hessian accumulation for GPTQ.
2
+
3
+ The GPTQ objective for one weight row w is:
4
+
5
+ (w - q)^T H (w - q)
6
+
7
+ where H = 2/n * X X^T, X shaped [in_features, n_tokens].
8
+
9
+ We accumulate H incrementally over batches so we never store all
10
+ activations in memory — only the running [in_features, in_features] sum.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import torch
16
+ from torch import Tensor
17
+
18
+
19
class HessianAccumulator:
    """Accumulates the per-layer Hessian H = 2/n * X X^T online.

    Usage::

        acc = HessianAccumulator()
        for batch_activations in ...:
            acc.accumulate(batch_activations)
        H = acc.finalize()  # [in_features, in_features]
    """

    def __init__(self) -> None:
        # Running un-normalised sum of x^T x, [in_features, in_features],
        # float32; allocated lazily on the first accumulate() call.
        self.H: Tensor | None = None
        # Total token rows seen so far (the n in 2/n).
        self.n_tokens: int = 0

    def accumulate(self, x: Tensor) -> None:
        """Add one batch of activations to the running sum.

        Args:
            x: Input activations of shape ``[*, in_features]``. Any number of
                leading dims is accepted, e.g. ``[batch, seq_len, in_features]``,
                ``[batch, in_features]``, a 4-D vision activation, or a single
                ``[in_features]`` vector. These are the same shapes
                ``nn.Linear`` accepts. All leading dims are flattened into
                tokens. Cast to float32.

        Raises:
            ValueError: If ``x`` is a 0-d scalar.
        """
        x = x.detach().float()
        if x.dim() == 0:
            raise ValueError(f"Expected at least 1-D activation, got shape {tuple(x.shape)}")

        # Flatten all leading dims → [n_tokens, in_features].
        x = x.reshape(-1, x.shape[-1])
        n = x.shape[0]  # number of tokens in this batch

        if self.H is None:
            in_features = x.shape[1]
            self.H = torch.zeros(
                in_features, in_features,
                device=x.device, dtype=torch.float32,
            )

        # H_unnorm += X_row^T @ X_row (= X_col @ X_col^T where X_col = x.T)
        self.H.addmm_(x.T, x)  # in-place fused multiply-add, no extra alloc
        self.n_tokens += n

    def finalize(self) -> Tensor:
        """Return the normalised Hessian H = 2/n * ΣX^TX as a new tensor.

        Raises:
            RuntimeError: If no tokens have been accumulated.
        """
        if self.H is None or self.n_tokens == 0:
            raise RuntimeError("HessianAccumulator has no samples — did you forget to call accumulate()?")
        return self.H * (2.0 / self.n_tokens)

    def reset(self) -> None:
        """Drop the running sum and token count."""
        self.H = None
        self.n_tokens = 0
@@ -0,0 +1,241 @@
1
+ """Layerwise GPTQ for transformer block models (Llama, Mistral, Qwen, etc.).
2
+
3
+ Problem: For 7B+ models, accumulating Hessians for all layers simultaneously
4
+ requires ~39 GB on top of the 14 GB model — OOM on 48 GB GPUs.
5
+
6
+ Solution: process one transformer block at a time:
7
+ 1. Intercept the first block with a Catcher to capture embedding outputs.
8
+ 2. Precompute position_embeddings once from the shared rotary_emb.
9
+ 3. For each block i:
10
+ a. Load block to GPU.
11
+ b. Register hooks on its linear sublayers; run all calibration inputs.
12
+ c. GPTQ-quantize each sublayer immediately (peak VRAM: ~2 GB).
13
+ d. Collect outputs (= inputs to block i+1).
14
+ e. Offload block to CPU.
15
+
16
+ Supported architectures: any HF model with model.model.layers (Llama, Mistral,
17
+ Qwen2, Phi-3, Gemma, ...).
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ from torch import Tensor
25
+
26
+ from pare.calibration.hessian import HessianAccumulator
27
+
28
+
29
def is_supported(model: nn.Module) -> bool:
    """Report whether ``model`` exposes a non-empty ``model.model.layers`` of modules.

    The check requires three things: a ``model`` attribute, a ``layers``
    attribute on it, and that ``layers`` is non-empty with an ``nn.Module``
    as its first element.
    """
    _missing = object()

    inner = getattr(model, "model", _missing)
    if inner is _missing:
        return False

    blocks = getattr(inner, "layers", _missing)
    if blocks is _missing:
        return False

    if len(blocks) == 0:
        return False
    return isinstance(blocks[0], nn.Module)
37
+
38
+
39
class LayerwiseGPTQ:
    """Block-by-block GPTQ. Keeps peak GPU usage to O(one_block + calib_acts).

    Only one decoder block and the Hessians of its linear sublayers are on
    ``device`` at a time. Inter-block activations are kept on CPU.
    """

    def run(
        self,
        model: nn.Module,
        calibration_data: list[Tensor],
        quantizer,
        device: str | torch.device,
    ) -> nn.Module:
        """Quantize model in-place using layerwise Hessian collection.

        Args:
            model: Model to quantize; must expose ``model.model.layers``
                (see ``is_supported``).
            calibration_data: Non-empty list of input_ids tensors, each
                [batch, seq_len]. Sequences may differ in length: a separate
                (cos, sin) table is precomputed per distinct length.
            quantizer: GPTQQuantizer-like object. This method uses its
                ``_should_quantize(name, module)``, its ``_hessians`` dict,
                and ``quantize_layer(linear, name)``.
            device: Device for block forward passes, e.g. "cuda".

        Returns:
            The same model, quantized in-place and left on CPU.

        Raises:
            ValueError: If ``calibration_data`` is empty.
        """
        if len(calibration_data) == 0:
            raise ValueError("LayerwiseGPTQ.run needs at least one calibration sequence")

        device = torch.device(device)
        layers = model.model.layers

        # ── 1. Capture embedding outputs (inputs to layer 0) ───────────────
        inps = _capture_embeddings(model, layers, calibration_data, device)

        # ── 2. Precompute shared position_embeddings ───────────────────────
        # One (cos, sin) pair per distinct sequence length of the captured
        # block inputs (assumed [batch, seq_len, hidden] — TODO confirm for
        # non-HF layouts), so variable-length calibration works.
        pes = {
            seq_len: _precompute_pe(model, layers, inps, seq_len, device)
            for seq_len in {x.shape[1] for x in inps}
        }
        model.cpu()
        torch.cuda.empty_cache()

        # ── 3. Resolve the dotted path to model.model.layers ──────────────
        layers_path = _find_layers_path(model)

        # ── 4. Process one block at a time ─────────────────────────────────
        n_layers = len(layers)
        for li, layer in enumerate(layers):
            layer_prefix = f"{layers_path}.{li}"
            layer.to(device)

            # Register hooks on every quantizable linear sublayer in this block.
            subs: dict[str, nn.Linear] = {}
            accs: dict[str, HessianAccumulator] = {}
            hooks = []
            try:
                for rel_name, mod in layer.named_modules():
                    if not isinstance(mod, nn.Linear):
                        continue
                    full_name = f"{layer_prefix}.{rel_name}"
                    if not quantizer._should_quantize(full_name, mod):
                        continue
                    acc = HessianAccumulator()
                    subs[full_name] = mod
                    accs[full_name] = acc
                    hooks.append(mod.register_forward_hook(_make_hook(acc)))

                with torch.no_grad():
                    for x in inps:
                        _call_layer(layer, x.to(device), pes[x.shape[1]], device)
            finally:
                # Always detach the collectors, even if a forward pass fails,
                # so the model is not left with stray hooks.
                for h in hooks:
                    h.remove()

            # Quantize each sublayer immediately — no simultaneous Hessian storage.
            for full_name, linear in subs.items():
                H = accs[full_name].finalize()
                # Expose the Hessian so quantize_layer can find it.
                quantizer._hessians[full_name] = H.cpu()
                q_layer = quantizer.quantize_layer(linear, full_name)
                _set_submodule(layer, full_name[len(layer_prefix) + 1:], q_layer)
                # Clean up immediately to keep memory flat.
                del quantizer._hessians[full_name], H
                torch.cuda.empty_cache()

            # Collect this block's (now quantized) outputs → inputs for the
            # next block, so later blocks calibrate on quantized activations.
            outs: list[Tensor] = []
            with torch.no_grad():
                for x in inps:
                    out = _call_layer(layer, x.to(device), pes[x.shape[1]], device)
                    outs.append(out.cpu())
            inps = outs

            layer.cpu()
            torch.cuda.empty_cache()
            print(f"[pare] layer {li + 1}/{n_layers} quantized", flush=True)

        return model
129
+
130
+
131
+ # ---------------------------------------------------------------------------
132
+ # Internal helpers
133
+ # ---------------------------------------------------------------------------
134
+
135
class _StopForward(Exception):
    """Private signal raised by ``_Catcher`` to abort a forward pass.

    A dedicated type is used so that a genuine ``StopIteration`` raised by
    model code (e.g. ``next(module.parameters())`` on an empty module) is
    not swallowed by mistake.
    """


class _Catcher(nn.Module):
    """Replaces layer 0 temporarily to intercept embedding outputs."""

    def __init__(self, layer: nn.Module, store: list[Tensor]) -> None:
        super().__init__()
        self._layer = layer  # registered as a child, so model.to() moves it too
        self._store = store

    def __getattr__(self, name: str):
        # Some model forwards read attributes of the decoder layer (e.g. an
        # ``attention_type``) before calling it; fall back to the wrapped
        # layer so those lookups keep working while the Catcher is installed.
        try:
            return super().__getattr__(name)
        except AttributeError:
            layer = self.__dict__.get("_modules", {}).get("_layer")
            if layer is None:
                raise
            return getattr(layer, name)

    def forward(self, *args, **kwargs) -> Tensor:
        # hidden_states is normally the first positional arg; accept the
        # keyword form as well.
        hidden = args[0] if args else kwargs["hidden_states"]
        self._store.append(hidden.cpu())
        raise _StopForward


def _capture_embeddings(
    model: nn.Module,
    layers: nn.ModuleList,
    calibration_data: list[Tensor],
    device: torch.device,
) -> list[Tensor]:
    """Run the embedding + pre-block path, capture inputs to layer 0.

    Moves the whole model to ``device``. Returns one CPU tensor per
    calibration sequence: the hidden states that would have entered
    ``layers[0]``. ``layers[0]`` is always restored, even on error.

    Raises:
        RuntimeError: If any forward pass finished without reaching layer 0.
    """
    inps: list[Tensor] = []
    original = layers[0]
    layers[0] = _Catcher(original, inps)
    try:
        # Inside the try so a failed move still restores layers[0].
        model.to(device)
        with torch.no_grad():
            for input_ids in calibration_data:
                try:
                    model(input_ids.to(device))
                except _StopForward:
                    pass
    finally:
        layers[0] = original

    if len(inps) != len(calibration_data):
        raise RuntimeError(
            f"Captured {len(inps)} layer-0 inputs for {len(calibration_data)} "
            "calibration sequences; the model forward did not reach layers[0]"
        )
    return inps
168
+
169
+
170
def _precompute_pe(
    model: nn.Module,
    layers: nn.ModuleList,
    inps: list[Tensor],
    seq_len: int,
    device: torch.device,
) -> tuple[Tensor, Tensor] | None:
    """Build the rotary ``(cos, sin)`` pair for positions ``0 .. seq_len - 1``.

    The rotary module is looked up first as ``model.model.rotary_emb`` and,
    failing that, as ``layers[0].self_attn.rotary_emb``. It is called with
    ``inps[0]`` (moved to ``device``) and a ``[1, seq_len]`` position-id
    tensor, under ``torch.no_grad()``.

    Returns:
        ``(cos, sin)`` moved to CPU, or ``None`` when no rotary module is
        found in either place.
    """
    rotary = getattr(model.model, "rotary_emb", None)
    if rotary is None:
        # Fallback: a rotary module attached to the first layer's attention.
        first_attn = getattr(layers[0], "self_attn", None)
        rotary = getattr(first_attn, "rotary_emb", None)
    if rotary is None:
        return None

    with torch.no_grad():
        positions = torch.arange(seq_len, device=device)[None, :]
        cos, sin = rotary(inps[0].to(device), positions)
    return cos.cpu(), sin.cpu()
194
+
195
+
196
def _call_layer(
    layer: nn.Module,
    x: Tensor,
    pe: tuple[Tensor, Tensor] | None,
    device: torch.device,
) -> Tensor:
    """Run one decoder block on ``x`` and return its hidden-state output.

    The block is always called with ``use_cache=False``. When ``pe`` is
    given, ``position_embeddings=(cos, sin)`` (moved to ``device``) is passed
    as well. If the block returns a tuple, only its first element is
    returned; otherwise the return value is passed through unchanged.

    NOTE(review): no ``attention_mask`` is passed, so whether attention is
    causal here depends on how the layer's attention implementation treats a
    missing mask — confirm for eager (non-SDPA) attention.
    """
    call_kwargs: dict = dict(use_cache=False)
    if pe is not None:
        cos, sin = pe[0], pe[1]
        call_kwargs["position_embeddings"] = (cos.to(device), sin.to(device))

    result = layer(x, **call_kwargs)
    return result[0] if isinstance(result, tuple) else result
218
+
219
+
220
def _find_layers_path(model: nn.Module) -> str:
    """Return the ``named_modules()`` name under which ``model.model.layers`` appears.

    Raises:
        RuntimeError: If that exact module object is not found.
    """
    wanted = model.model.layers
    found = next(
        (path for path, module in model.named_modules() if module is wanted),
        None,
    )
    if found is None:
        raise RuntimeError("Could not find model.model.layers in named_modules()")
    return found
227
+
228
+
229
def _set_submodule(parent: nn.Module, rel_path: str, new_mod: nn.Module) -> None:
    """Assign ``new_mod`` at dotted path ``rel_path`` below ``parent``.

    Every component but the last is resolved with ``getattr``, so numeric
    container keys such as ``"0"`` work. The last component is assigned with
    ``setattr``.
    """
    *hops, leaf = rel_path.split(".")
    owner = parent
    for attr in hops:
        owner = getattr(owner, attr)
    setattr(owner, leaf, new_mod)
236
+
237
+
238
def _make_hook(acc: HessianAccumulator):
    """Build a forward hook that feeds the module's first positional input to ``acc``.

    The hook returns ``None``, so it never alters the module's output.
    """

    def _record_input(module: nn.Module, inputs: tuple, output: Tensor) -> None:
        first_input = inputs[0]
        acc.accumulate(first_input.detach())

    return _record_input
@@ -0,0 +1,72 @@
1
+ """Per-channel activation magnitude observer for AWQ.
2
+
3
+ AWQ needs to know which input channels carry large activations so it can
4
+ protect those weight columns during quantization. This module accumulates
5
+ the per-channel mean |x| online — just like HessianAccumulator but O(in)
6
+ instead of O(in²).
7
+
8
+ Usage::
9
+
10
+ obs = ActivationObserver()
11
+ hook = layer.register_forward_hook(
12
+ lambda m, inp, out: obs.accumulate(inp[0])
13
+ )
14
+ # ... run calibration forward passes ...
15
+ hook.remove()
16
+ x_mean = obs.finalize() # [in_features] mean |x| per input channel
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import torch
22
+ from torch import Tensor
23
+
24
+
25
class ActivationObserver:
    """Online accumulator for per-channel activation statistics.

    Tracks both mean |x| (used by AWQ) and max |x| (used by SmoothQuant)
    in a single pass over calibration data.

    Accumulates over any number of batches of shape ``[*, in_features]``,
    e.g. ``[batch, seq_len, in_features]`` or ``[batch, in_features]``.
    """

    def __init__(self) -> None:
        self._sum: Tensor | None = None  # running Σ|x| per channel, float32
        self._max: Tensor | None = None  # running max |x| per channel, float32
        self._n_tokens: int = 0  # token rows accumulated so far

    def accumulate(self, x: Tensor) -> None:
        """Fold one batch of activations into the running statistics.

        Args:
            x: Activations of shape ``[*, in_features]``. All leading dims are
                flattened into tokens; cast to float32. A batch with zero
                tokens is ignored.

        Raises:
            ValueError: If ``x`` is a 0-d scalar.
        """
        x = x.detach().float()
        if x.dim() == 0:
            raise ValueError(f"Expected at least 1-D activation, got {tuple(x.shape)}")
        x = x.reshape(-1, x.shape[-1])  # [n_tokens, in_features]
        if x.shape[0] == 0:
            # amax() over an empty dim raises, and an empty batch carries no
            # information, so skip it entirely.
            return

        x_abs = x.abs()

        if self._sum is None:
            self._sum = torch.zeros(x.shape[1], device=x.device, dtype=torch.float32)
            self._max = torch.zeros(x.shape[1], device=x.device, dtype=torch.float32)

        self._sum.add_(x_abs.sum(dim=0))
        torch.maximum(self._max, x_abs.amax(dim=0), out=self._max)
        self._n_tokens += x.shape[0]

    def finalize(self) -> Tensor:
        """Return mean |x| per input channel, shape [in_features]. Used by AWQ.

        Raises:
            RuntimeError: If no tokens have been accumulated.
        """
        if self._sum is None or self._n_tokens == 0:
            raise RuntimeError("ActivationObserver has no samples")
        return self._sum / self._n_tokens

    def max_abs(self) -> Tensor:
        """Return max |x| per input channel, shape [in_features]. Used by SmoothQuant.

        Raises:
            RuntimeError: If no tokens have been accumulated.
        """
        if self._max is None or self._n_tokens == 0:
            raise RuntimeError("ActivationObserver has no samples")
        return self._max.clone()

    def reset(self) -> None:
        """Discard all accumulated statistics."""
        self._sum = None
        self._max = None
        self._n_tokens = 0