tico-0.1.0.dev250902-py3-none-any.whl → tico-0.1.0.dev250904-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/experimental/quantization/ptq/__init__.py +1 -13
- tico/experimental/quantization/ptq/examples/compare_ppl.py +1 -1
- tico/experimental/quantization/ptq/examples/debug_quant_outputs.py +129 -0
- tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +165 -0
- tico/experimental/quantization/ptq/observers/__init__.py +1 -15
- tico/experimental/quantization/ptq/observers/ema.py +1 -1
- tico/experimental/quantization/ptq/observers/minmax.py +1 -1
- tico/experimental/quantization/ptq/utils/__init__.py +1 -7
- tico/experimental/quantization/ptq/utils/introspection.py +169 -0
- tico/experimental/quantization/ptq/wrappers/__init__.py +1 -0
- tico/experimental/quantization/ptq/wrappers/llama/__init__.py +1 -9
- tico/experimental/quantization/ptq/wrappers/llama/quant_attn.py +4 -1
- tico/experimental/quantization/ptq/wrappers/nn/__init__.py +1 -11
- tico/experimental/quantization/ptq/wrappers/registry.py +12 -9
- {tico-0.1.0.dev250902.dist-info → tico-0.1.0.dev250904.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250902.dist-info → tico-0.1.0.dev250904.dist-info}/RECORD +21 -18
- {tico-0.1.0.dev250902.dist-info → tico-0.1.0.dev250904.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250902.dist-info → tico-0.1.0.dev250904.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250902.dist-info → tico-0.1.0.dev250904.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250902.dist-info → tico-0.1.0.dev250904.dist-info}/top_level.txt +0 -0
tico/experimental/quantization/ptq/__init__.py
CHANGED
@@ -1,13 +1 @@
-"""
-Public PTQ API — re-export the most common symbols.
-"""
-
-from tico.experimental.quantization.ptq.dtypes import DType
-from tico.experimental.quantization.ptq.mode import Mode
-from tico.experimental.quantization.ptq.qscheme import QScheme
-
-__all__ = [
-    "DType",
-    "Mode",
-    "QScheme",
-]
+# DO NOT REMOVE THIS FILE
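Since the re-exports are gone, callers that imported these symbols from the package root need to target the defining submodules instead. A minimal migration sketch, using only the paths visible in the removed lines:

    # Before (0.1.0.dev250902): re-exported at the package root
    # from tico.experimental.quantization.ptq import DType, Mode, QScheme

    # After (0.1.0.dev250904): import from the defining submodules
    from tico.experimental.quantization.ptq.dtypes import DType
    from tico.experimental.quantization.ptq.mode import Mode
    from tico.experimental.quantization.ptq.qscheme import QScheme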
tico/experimental/quantization/ptq/examples/compare_ppl.py
CHANGED
@@ -26,7 +26,7 @@ from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.utils import perplexity
+from tico.experimental.quantization.ptq.utils.metrics import perplexity
 from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
 
 # -------------------------------------------------------------------------
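Only the import path changes; the helper itself is untouched. A call site under the new path, with the argument pattern taken from the quantize_with_gptq example added in this release (`model`, `enc`, `DEVICE`, and `STRIDE` are names from the surrounding scripts):

    from tico.experimental.quantization.ptq.utils.metrics import perplexity

    ppl = perplexity(model, enc, DEVICE, stride=STRIDE)
    print(f"Wikitext-2 perplexity: {ppl:.2f}")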
tico/experimental/quantization/ptq/examples/debug_quant_outputs.py
ADDED
@@ -0,0 +1,129 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import tqdm
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from tico.experimental.quantization.ptq.quant_config import QuantConfig
+from tico.experimental.quantization.ptq.utils.introspection import (
+    build_fqn_map,
+    compare_layer_outputs,
+    save_fp_outputs,
+)
+from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
+
+# ============================================================================
+# LAYER-WISE DIFF DEBUGGING PIPELINE
+# ----------------------------------------------------------------------------
+# A quantization debugging pipeline that identifies accuracy regressions
+# by comparing UINT vs FP outputs at each layer.
+#
+# 1. Load a full-precision (FP) LLaMA-3-1B model.
+# 2. Wrap each Transformer block with PTQWrapper (activations → fake-quant).
+# 3. Capture reference FP layer outputs before quantization.
+# 4. Calibrate UINT-8 activation observers in a single pass.
+# 5. Freeze quantization parameters (scale, zero-point).
+# 6. Re-run inference and compare UINT-8 vs FP outputs per layer.
+# 7. Report where quantization hurts the most.
+#
+# Use this pipeline to trace precision loss layer by layer, and pinpoint
+# problematic modules during post-training quantization.
+# ============================================================================
+
+# -------------------------------------------------------------------------
+# 0. Global configuration
+# -------------------------------------------------------------------------
+MODEL_NAME = "meta-llama/Meta-Llama-3-1B"
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+STRIDE = 512
+
+# Token-budget presets for activation calibration
+TOKENS: dict[str, int] = {
+    # Smoke test (<1 min turnaround on CPU/GPU)
+    "debug": 2_000,  # ≈16 × 128-seq batches
+    # Good default for 1-7B models (≲3 % ppl delta)
+    "baseline": 50_000,
+    # Production / 4-bit observer smoothing
+    "production": 200_000,
+}
+CALIB_TOKENS = TOKENS["baseline"]
+print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
+
+# -------------------------------------------------------------------------
+# 1. Load the FP backbone
+# -------------------------------------------------------------------------
+print("Loading FP model …")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
+model.config.use_cache = False  # disable KV-cache → full forward
+m_to_fqn = build_fqn_map(model)  # map modules → fully-qualified names
+
+# Use Wikitext-2 train split for calibration.
+dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
+
+# -------------------------------------------------------------------------
+# 2. Wrap every layer with PTQWrapper (UINT-8 activations)
+# -------------------------------------------------------------------------
+print("Wrapping layers with PTQWrapper …")
+qcfg = QuantConfig()  # default: per-tensor UINT8
+
+new_layers = torch.nn.ModuleList()
+for idx, fp_layer in enumerate(model.model.layers):
+    layer_cfg = qcfg.child(f"layer{idx}")
+    q_layer = PTQWrapper(
+        fp_layer,
+        qcfg=layer_cfg,
+        fp_name=m_to_fqn.get(fp_layer),
+    )
+    new_layers.append(q_layer)
+
+model.model.layers = new_layers  # swap in quant wrappers
+
+# -------------------------------------------------------------------------
+# 3. Activation calibration plus FP-vs-UINT8 diffing
+# -------------------------------------------------------------------------
+print("Calibrating UINT-8 observers …")
+calib_txt = " ".join(dataset["text"])[:CALIB_TOKENS]
+ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(DEVICE)
+
+# (a) Enable CALIB mode on every QuantModuleBase
+for l in model.model.layers:
+    l.enable_calibration()
+
+# Save reference FP activations before observers clamp/quantize
+save_handles, act_cache = save_fp_outputs(model)
+
+with torch.no_grad():
+    for i in tqdm.trange(0, ids.size(1) - 1, STRIDE, desc="Act-calibration"):
+        inputs = ids[:, i : i + STRIDE]
+        model(inputs)  # observers collect act. ranges
+
+# Remove save hooks now that FP activations are cached
+for h in save_handles:
+    h.remove()
+
+# (b) Freeze (scale, zero-point) after calibration
+for l in model.model.layers:
+    l.freeze_qparams()
+
+# (c) Register diff hooks and measure per-layer deltas
+cmp_handles = compare_layer_outputs(model, act_cache, metrics=["diff", "peir"])
+# Use same inputs for comparison.
+model(inputs)
+
+assert isinstance(cmp_handles, list)
+for h in cmp_handles:
+    h.remove()
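The script prints one report line per layer. For programmatic use, `compare_layer_outputs` also accepts `collect=True` (see its docstring in the new introspection module below), returning the per-layer metrics instead of printing them. A hedged variation on the last step above, assuming the `"diff"` metric yields one scalar per layer:

    # Collect per-layer metrics into a dict instead of printing them.
    cmp_handles, results = compare_layer_outputs(
        model, act_cache, metrics=["diff", "peir"], collect=True
    )
    model(inputs)  # same input slice as above

    # Rank layers by quantization error and report the worst offender.
    worst_name, worst_metrics = max(results.items(), key=lambda kv: kv[1]["diff"])
    print(f"Largest per-layer diff: {worst_name} -> {worst_metrics}")

    for h in cmp_handles:
        h.remove()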
tico/experimental/quantization/ptq/examples/quantize_with_gptq.py
ADDED
@@ -0,0 +1,165 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# =============================================================================
+# PTQ + GPTQ HYBRID QUANTIZATION PIPELINE
+# -----------------------------------------------------------------------------
+# This script shows how to:
+# 1. Load a pretrained FP Llama-3 model.
+# 2. Run GPTQ to quantize weights only.
+# 3. Wrap every Transformer layer with a PTQWrapper to quantize activations.
+# 4. Calibrate UINT-8 observers in a single pass over a text corpus.
+# 5. Inject GPTQ’s per-tensor weight scales / zero-points into the PTQ graph.
+# 6. Freeze all Q-params and compute Wikitext-2 perplexity.
+# =============================================================================
+
+from typing import Any
+
+import torch
+import tqdm
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from tico.experimental.quantization import convert, prepare
+from tico.experimental.quantization.config import GPTQConfig
+from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
+from tico.experimental.quantization.ptq.quant_config import QuantConfig
+from tico.experimental.quantization.ptq.utils.introspection import build_fqn_map
+from tico.experimental.quantization.ptq.utils.metrics import perplexity
+from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
+from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
+    QuantModuleBase,
+)
+
+# -------------------------------------------------------------------------
+# 0. Global configuration
+# -------------------------------------------------------------------------
+MODEL_NAME = "meta-llama/Meta-Llama-3-1B"
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+STRIDE = 512
+
+# Token-budget presets for activation calibration
+TOKENS: dict[str, int] = {
+    # Smoke test (<1 min turnaround on CPU/GPU)
+    "debug": 2_000,  # ≈16 × 128-seq batches
+    # Good default for 1-7B models (≲3 % ppl delta)
+    "baseline": 50_000,
+    # Production / 4-bit observer smoothing
+    "production": 200_000,
+}
+CALIB_TOKENS = TOKENS["baseline"]
+
+# -------------------------------------------------------------------------
+# 1. Helper — copy GPTQ (scale, zp) into PTQ observers
+# -------------------------------------------------------------------------
+def inject_gptq_qparams(
+    root: torch.nn.Module,
+    gptq_quantizers: dict[str, Any],  # {fp_name: quantizer}
+    weight_obs_name: str = "weight",
+):
+    """
+    For every `QuantModuleBase` whose `fp_name` matches a GPTQ key,
+    locate the observer called `weight_obs_name` and overwrite its
+    (scale, zero-point), then lock them against further updates.
+    """
+    for m in root.modules():
+        if not isinstance(m, QuantModuleBase):
+            continue
+        if m.fp_name is None:
+            continue
+        quantizer = gptq_quantizers.get(m.fp_name)
+        if quantizer is None:
+            continue
+        obs = m.get_observer(weight_obs_name)
+        if obs is None:
+            continue
+        assert isinstance(obs, AffineObserverBase)
+        # GPTQ quantizer attributes
+        obs.load_qparams(quantizer.scale, quantizer.zero, lock=True)
+
+
+# -------------------------------------------------------------------------
+# 2. Load the FP backbone
+# -------------------------------------------------------------------------
+print("Loading FP model …")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
+model.config.use_cache = False  # disable KV-cache → full forward
+m_to_fqn = build_fqn_map(model)  # map modules → fully-qualified names
+
+# -------------------------------------------------------------------------
+# 3. Run GPTQ (weight-only) pass
+# -------------------------------------------------------------------------
+print("Applying GPTQ …")
+dataset = load_dataset("wikiText", "wikitext-2-raw-v1", split="test")
+q_m = prepare(model, GPTQConfig(), inplace=True)
+
+for d in tqdm.tqdm(dataset, desc="GPTQ calibration"):
+    ids = tokenizer(d["text"], return_tensors="pt").input_ids.to(DEVICE)
+    q_m(ids)  # observers gather weight stats
+
+q_m = convert(q_m, inplace=True)  # materialize INT-weight tensors
+
+# -------------------------------------------------------------------------
+# 4. Wrap every layer with PTQWrapper (activation UINT-8)
+# -------------------------------------------------------------------------
+qcfg = QuantConfig()  # default: per-tensor UINT8
+new_layers = torch.nn.ModuleList()
+
+for idx, fp_layer in enumerate(q_m.model.layers):
+    layer_cfg = qcfg.child(f"layer{idx}")
+    q_layer = PTQWrapper(
+        fp_layer,
+        qcfg=layer_cfg,
+        fp_name=m_to_fqn.get(fp_layer),
+    )
+    new_layers.append(q_layer)
+
+q_m.model.layers = new_layers
+
+# -------------------------------------------------------------------------
+# 5. Single-pass activation calibration
+# -------------------------------------------------------------------------
+print("Calibrating UINT-8 observers …")
+calib_txt = " ".join(
+    load_dataset("wikitext", "wikitext-2-raw-v1", split="train")["text"]
+)[:CALIB_TOKENS]
+ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(DEVICE)
+
+# (a) Enable CALIB mode on every QuantModuleBase
+for l in q_m.model.layers:
+    l.enable_calibration()
+
+# (b) Overwrite weight observers with GPTQ statistics
+inject_gptq_qparams(q_m, q_m.quantizers)
+
+with torch.no_grad():
+    for i in tqdm.trange(0, ids.size(1) - 1, STRIDE, desc="Act-calibration"):
+        q_m(ids[:, i : i + STRIDE])  # observers collect act. ranges
+
+# (c) Freeze all Q-params (scale, zp)
+for l in q_m.model.layers:
+    l.freeze_qparams()
+
+# -------------------------------------------------------------------------
+# 6. Evaluate perplexity on Wikitext-2
+# -------------------------------------------------------------------------
+print("\nCalculating perplexities …")
+test_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+enc = tokenizer("\n\n".join(test_ds["text"]), return_tensors="pt")
+ppl_uint8 = perplexity(q_m, enc, DEVICE, stride=STRIDE)
+
+print("\n┌── Wikitext-2 test perplexity ─────────────")
+print(f"│ UINT-8 : {ppl_uint8:8.2f}")
+print("└───────────────────────────────────────────")
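The script reports only the quantized perplexity. A natural follow-up, sketched here under the same configuration and not part of the released file, is an FP32 baseline scored with the same helper so the GPTQ + UINT-8 delta is visible:

    # Illustrative only: reload an unquantized copy and score the same encoding.
    fp_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
    fp_model.config.use_cache = False
    ppl_fp32 = perplexity(fp_model, enc, DEVICE, stride=STRIDE)
    print(f"│ FP32   : {ppl_fp32:8.2f}   (Δ vs UINT-8: {ppl_uint8 - ppl_fp32:+.2f})")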
tico/experimental/quantization/ptq/observers/__init__.py
CHANGED
@@ -1,15 +1 @@
-from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
-from tico.experimental.quantization.ptq.observers.base import ObserverBase
-from tico.experimental.quantization.ptq.observers.ema import EMAObserver
-from tico.experimental.quantization.ptq.observers.identity import IdentityObserver
-from tico.experimental.quantization.ptq.observers.minmax import MinMaxObserver
-from tico.experimental.quantization.ptq.observers.mx import MXObserver
-
-__all__ = [
-    "AffineObserverBase",
-    "ObserverBase",
-    "EMAObserver",
-    "IdentityObserver",
-    "MinMaxObserver",
-    "MXObserver",
-]
+# DO NOT REMOVE THIS FILE
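As with the ptq package root, the observer classes are no longer re-exported, so imports move to the defining modules, for example:

    # Before: from tico.experimental.quantization.ptq.observers import MinMaxObserver
    from tico.experimental.quantization.ptq.observers.ema import EMAObserver
    from tico.experimental.quantization.ptq.observers.minmax import MinMaxObserver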
tico/experimental/quantization/ptq/observers/ema.py
CHANGED
@@ -15,7 +15,7 @@
 import torch
 
 from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
-from tico.experimental.quantization.ptq.utils import channelwise_minmax
+from tico.experimental.quantization.ptq.utils.reduce_utils import channelwise_minmax
 
 
 class EMAObserver(AffineObserverBase):
tico/experimental/quantization/ptq/observers/minmax.py
CHANGED
@@ -15,7 +15,7 @@
 import torch
 
 from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
-from tico.experimental.quantization.ptq.utils import channelwise_minmax
+from tico.experimental.quantization.ptq.utils.reduce_utils import channelwise_minmax
 
 
 class MinMaxObserver(AffineObserverBase):
tico/experimental/quantization/ptq/utils/introspection.py
ADDED
@@ -0,0 +1,169 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Dict, List, Optional, Tuple
+
+import torch
+
+from tico.experimental.quantization.evaluation.metric import MetricCalculator
+from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
+from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
+    QuantModuleBase,
+)
+
+
+def build_fqn_map(root: torch.nn.Module) -> dict[torch.nn.Module, str]:
+    """
+    Return {module_object: full_qualified_name} without touching the modules.
+    """
+    return {m: n for n, m in root.named_modules()}
+
+
+def save_fp_outputs(
+    model: torch.nn.Module,
+) -> Tuple[List[torch.utils.hooks.RemovableHandle], Dict[str, torch.Tensor]]:
+    """
+    Register forward-hooks on every `QuantModuleBase` wrapper itself (not the
+    wrapped `module`) and cache its output while the wrapper runs in CALIB mode.
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        The model whose wrappers are already switched to CALIB mode
+        (`enable_calibration()` has been called).
+
+    Returns
+    -------
+    handles : list[RemovableHandle]
+        Hook handles; call `.remove()` on each one to detach the hooks.
+    cache : dict[str, torch.Tensor]
+        Mapping "wrapper-name → cached FP32 activation" captured from the first
+        forward pass. Keys default to `wrapper.fp_name`; if that attribute is
+        `None`, the `id(wrapper)` string is used instead.
+    """
+    cache: Dict[str, torch.Tensor] = {}
+    handles: List[torch.utils.hooks.RemovableHandle] = []
+
+    def _save(name: str):
+        def hook(_, __, out: torch.Tensor | Tuple):
+            if isinstance(out, tuple):
+                out = out[0]
+            assert isinstance(out, torch.Tensor)
+            cache[name] = out.detach()
+
+        return hook
+
+    for m in model.modules():
+        if isinstance(m, QuantModuleBase):
+            name = m.fp_name or str(id(m))
+            handles.append(m.register_forward_hook(_save(name)))
+
+    return handles, cache
+
+
+def compare_layer_outputs(
+    model: torch.nn.Module,
+    cache: Dict[str, torch.Tensor],
+    *,
+    metrics: Optional[List[str]] = None,
+    custom_metrics: Optional[Dict[str, Callable]] = None,
+    rtol: float = 1e-3,
+    atol: float = 1e-3,
+    collect: bool = False,
+):
+    """
+    Register forward-hooks on every `QuantModuleBase` wrapper to compare its
+    QUANT-mode output to the FP32 reference saved by `save_fp_outputs()`.
+
+    Each hook prints a per-layer diff report:
+
+        ✓ layer_name max=1.23e-02 mean=8.45e-04 (within tolerance)
+        ⚠️ layer_name max=3.07e+00 mean=5.12e-01 (exceeds tolerance)
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        The model whose wrappers are now in QUANT mode
+        (`freeze_qparams()` has been called).
+    cache : dict[str, torch.Tensor]
+        The reference activations captured during CALIB mode.
+    metrics
+        Metrics to compute. Defaults to `["diff"]`. Add `peir` to print PEIR.
+    custom_metrics
+        Optional user metric functions. Same signature as built-ins.
+    rtol, atol : float, optional
+        Relative / absolute tolerances used to flag large deviations
+        (similar to `torch.allclose` semantics).
+    collect : bool, optional
+        • False (default) → print one-line report per layer, return `None`
+        • True → suppress printing, return a nested dict
+            {layer_name -> {metric -> value}}
+
+    Returns
+    -------
+    handles
+        Hook handles; call `.remove()` once diffing is complete.
+    results
+        Only if *collect* is True.
+    """
+    metrics = metrics or ["diff"]
+    calc = MetricCalculator(custom_metrics)
+    handles: List[torch.utils.hooks.RemovableHandle] = []
+    results: Dict[
+        str, Dict[str, float]
+    ] = {}  # Dict[layer_name, Dict[metric_name, value]]
+
+    def _cmp(name: str):
+        ref = cache.get(name)
+
+        def hook(_, __, out):
+            if ref is None:
+                if not collect:
+                    print(f"[{name}] no cached reference")
+                return
+            if isinstance(out, tuple):
+                out = out[0]
+            assert isinstance(out, torch.Tensor)
+
+            # Compute all requested metrics
+            res = calc.compute([ref], [out], metrics)  # lists with length-1 tensors
+            res = {k: v[0] for k, v in res.items()}  # flatten
+
+            if collect:
+                results[name] = res  # type: ignore[assignment]
+                return
+
+            # Pretty print ------------------------------------------------ #
+            diff_val = res.get("diff") or res.get("max_abs_diff")
+            thresh = atol + rtol * ref.abs().max().item()
+            flag = "⚠️" if (diff_val is not None and diff_val > thresh) else "✓"  # type: ignore[operator]
+
+            pieces = [f"{flag} {name:45s}"]
+            for key, val in res.items():
+                pieces.append(f"{key}={val:<7.4}")
+            print(" ".join(pieces))
+
+        return hook
+
+    for m in model.modules():
+        if isinstance(m, PTQWrapper):
+            # skip the internal fp module inside the wrapper
+            continue
+        if isinstance(m, QuantModuleBase):
+            lname = m.fp_name or str(id(m))
+            handles.append(m.register_forward_hook(_cmp(lname)))
+
+    if collect:
+        return handles, results
+    return handles
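`build_fqn_map` is what the example scripts use to label each wrapper with its original fully-qualified name before the layer list is rewritten. A self-contained sketch on a toy module (the names are simply whatever `named_modules()` yields):

    import torch

    from tico.experimental.quantization.ptq.utils.introspection import build_fqn_map

    toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.SiLU())
    fqn = build_fqn_map(toy)
    print(fqn[toy[0]])      # "0"  (the Linear's fully-qualified name inside toy)
    print(repr(fqn[toy]))   # ''   (the root module has an empty name)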
tico/experimental/quantization/ptq/wrappers/__init__.py
CHANGED
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
tico/experimental/quantization/ptq/wrappers/llama/__init__.py
CHANGED
@@ -1,9 +1 @@
-from tico.experimental.quantization.ptq.wrappers.llama.quant_attn import (
-    QuantLlamaAttention,
-)
-from tico.experimental.quantization.ptq.wrappers.llama.quant_decoder_layer import (
-    QuantLlamaDecoderLayer,
-)
-from tico.experimental.quantization.ptq.wrappers.llama.quant_mlp import QuantLlamaMLP
-
-__all__ = ["QuantLlamaAttention", "QuantLlamaDecoderLayer", "QuantLlamaMLP"]
+# DO NOT REMOVE THIS FILE
tico/experimental/quantization/ptq/wrappers/llama/quant_attn.py
CHANGED
@@ -25,7 +25,10 @@ from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
 from tico.experimental.quantization.ptq.wrappers.registry import try_register
 
 
-@try_register("transformers.models.llama.modeling_llama.LlamaAttention")
+@try_register(
+    "transformers.models.llama.modeling_llama.LlamaAttention",
+    "transformers.models.llama.modeling_llama.LlamaSdpaAttention",
+)
 class QuantLlamaAttention(QuantModuleBase):
     def __init__(
         self,
tico/experimental/quantization/ptq/wrappers/nn/__init__.py
CHANGED
@@ -1,11 +1 @@
-from tico.experimental.quantization.ptq.wrappers.nn.quant_layernorm import (
-    QuantLayerNorm,
-)
-from tico.experimental.quantization.ptq.wrappers.nn.quant_linear import QuantLinear
-from tico.experimental.quantization.ptq.wrappers.nn.quant_silu import QuantSiLU
-
-__all__ = [
-    "QuantLayerNorm",
-    "QuantLinear",
-    "QuantSiLU",
-]
+# DO NOT REMOVE THIS FILE
tico/experimental/quantization/ptq/wrappers/registry.py
CHANGED
@@ -90,7 +90,9 @@ def register(
 
 
 # ───────────────────────────── conditional decorator
-def try_register(path: str) -> Callable[[Type[QuantModuleBase]], Type[QuantModuleBase]]:
+def try_register(
+    *paths: str,
+) -> Callable[[Type[QuantModuleBase]], Type[QuantModuleBase]]:
     """
     @try_register("transformers.models.llama.modeling_llama.LlamaMLP")
 
@@ -99,14 +101,15 @@ def try_register(path: str) -> Callable[[Type[QuantModuleBase]], Type[QuantModuleBase]]:
     """
 
     def _decorator(quant_cls: Type[QuantModuleBase]):
-        module_name, _, cls_name = path.rpartition(".")
-        try:
-            mod = importlib.import_module(module_name)
-            fp_cls = getattr(mod, cls_name)
-            _WRAPPERS[fp_cls] = quant_cls
-        except (ModuleNotFoundError, AttributeError):
-            # optional dep missing or class renamed – skip silently
-            pass
+        for path in paths:
+            module_name, _, cls_name = path.rpartition(".")
+            try:
+                mod = importlib.import_module(module_name)
+                fp_cls = getattr(mod, cls_name)
+                _WRAPPERS[fp_cls] = quant_cls
+            except (ModuleNotFoundError, AttributeError):
+                # optional dep missing or class renamed – skip silently
+                pass
         return quant_cls
 
     return _decorator
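`try_register` now takes a variadic `*paths` and registers the wrapper for every path that resolves; unresolvable paths are skipped, which is what lets `QuantLlamaAttention` above cover both `LlamaAttention` and the SDPA variant that only some transformers releases define. A hedged sketch of the usage pattern (the wrapper class below is hypothetical):

    from tico.experimental.quantization.ptq.wrappers.quant_module_base import QuantModuleBase
    from tico.experimental.quantization.ptq.wrappers.registry import try_register


    @try_register(
        "transformers.models.llama.modeling_llama.LlamaAttention",      # registered if importable
        "transformers.models.llama.modeling_llama.LlamaSdpaAttention",  # silently skipped if absent
    )
    class MyQuantAttention(QuantModuleBase):  # hypothetical wrapper, for illustration only
        ...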
{tico-0.1.0.dev250902.dist-info → tico-0.1.0.dev250904.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-tico/__init__.py,sha256=
+tico/__init__.py,sha256=Q7A_sZgX9kbJSpo2ndJB19BlbFSGnajl9HAFGt8D2Q0,1883
 tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
 tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
 tico/config/base.py,sha256=q5xMqGxTUZs4mFqt5c7i_y9U00fYgdMGl9nUqIVMlCo,1248
@@ -56,37 +56,40 @@ tico/experimental/quantization/passes/propagate_qparam_backward.py,sha256=TGtyW0
 tico/experimental/quantization/passes/propagate_qparam_forward.py,sha256=RhUHGCR2RpBO5KYkQ7Z8U5u7HEwDq2wdKHLKAJCi-5c,5138
 tico/experimental/quantization/passes/quantize_bias.py,sha256=T7YxJ70N0tSK0FF9VJZA5iP0sHdnnsX9GX4AT4JDFSk,4325
 tico/experimental/quantization/passes/remove_weight_dequant_op.py,sha256=gI1MtrHazWpdNfys7f1ngTTWplzluF7SA-uX0HMR5Mc,6592
-tico/experimental/quantization/ptq/__init__.py,sha256=
+tico/experimental/quantization/ptq/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/ptq/dtypes.py,sha256=xfCBtq6mQmUYRwsoFgII6gvRl1raQi0Inj9pznDuKwQ,2236
 tico/experimental/quantization/ptq/mode.py,sha256=lT-T8vIv8YWcwrjT7xXVhOw1g7aoAdh_3PWB-ptPKaI,1052
 tico/experimental/quantization/ptq/qscheme.py,sha256=uwhv7bCxOOXB3I-IKlRyr_u4eXOq48uIqGy4TLDqGxY,1301
 tico/experimental/quantization/ptq/quant_config.py,sha256=nm7570Y1X2mOT_8s27ilWid04otor6cVTi9GwgAEaKc,4300
 tico/experimental/quantization/ptq/examples/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
-tico/experimental/quantization/ptq/examples/compare_ppl.py,sha256=
+tico/experimental/quantization/ptq/examples/compare_ppl.py,sha256=SmSmaCBVWTcGRPRk2zopDqESD_gF8D7J4kUNNZ-0cMk,5295
+tico/experimental/quantization/ptq/examples/debug_quant_outputs.py,sha256=astXzx-maq1W4gKvX2QaGmD2Tpmjunv4JqDYVk9eZRQ,5177
 tico/experimental/quantization/ptq/examples/quantize_linear.py,sha256=8zq-ZJDYgam0xQ-PbC6Xb1I7W1mv0Wi-b--IP2wwXtw,4539
 tico/experimental/quantization/ptq/examples/quantize_llama_attn.py,sha256=cVWUSSzaZWFp5QZkNkrlpHU3kXyP84QtnZbahVml_yQ,4329
 tico/experimental/quantization/ptq/examples/quantize_llama_decoder_layer.py,sha256=mBWrjkyEovYQsPC4Rrsri6Pm1rlFmDb3NiP0DQQhFyM,5751
 tico/experimental/quantization/ptq/examples/quantize_llama_mlp.py,sha256=N1qZQgt1S-xZrdv-PW7OfXEcv0gsO2q9faOF4aD-zKo,4147
-tico/experimental/quantization/ptq/
+tico/experimental/quantization/ptq/examples/quantize_with_gptq.py,sha256=w21Qao5_6SnWMuxmnZbZOoqaLQOuSnK52mHin4aedtA,6979
+tico/experimental/quantization/ptq/observers/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/ptq/observers/affine_base.py,sha256=e2Eba64nrxKQyE4F_WJ7WTSsk3xe6bkdGUKaoLFWGFw,4638
 tico/experimental/quantization/ptq/observers/base.py,sha256=Wons1MzpqK1mfcy-ppl-B2Dum0edXg2dWW2Lw3V18tw,3280
-tico/experimental/quantization/ptq/observers/ema.py,sha256=
+tico/experimental/quantization/ptq/observers/ema.py,sha256=oISP1XaD3lapVaHQKscD3rjLcKbhOy4Nvi6dqRFZwF8,2070
 tico/experimental/quantization/ptq/observers/identity.py,sha256=vkec8Or-7VwM4zkFEvEKROQJk8XEHMVX8mBNDnxSyS8,2591
-tico/experimental/quantization/ptq/observers/minmax.py,sha256=
+tico/experimental/quantization/ptq/observers/minmax.py,sha256=WWcAyEIrd5j3k9qsoBJi3nUnWtrwPaKlR9CPezbDSqQ,1499
 tico/experimental/quantization/ptq/observers/mx.py,sha256=aP4qmBgeiRIYZJksShN5gs6UyYOFi2-Sbk5k5xvPQ4w,1863
-tico/experimental/quantization/ptq/utils/__init__.py,sha256=
+tico/experimental/quantization/ptq/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
+tico/experimental/quantization/ptq/utils/introspection.py,sha256=y2oGf7RoApMHJeXLmIz3VVWB9vazGEgyLbxLiVTTQdw,6000
 tico/experimental/quantization/ptq/utils/metrics.py,sha256=EW_FQmJrl9Y4esspZQ0GHfJ58RwuJUz0l8IfYq3NWY4,4461
 tico/experimental/quantization/ptq/utils/reduce_utils.py,sha256=3kWawLB91EcvvHlCrNqqfZF7tpgr22htBSA049mKw_4,973
-tico/experimental/quantization/ptq/wrappers/__init__.py,sha256=
+tico/experimental/quantization/ptq/wrappers/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/ptq/wrappers/ptq_wrapper.py,sha256=F9sK_DiRaXiGNHULcwIbs5EUtHz6ZJ7N4r5CWTTfhsM,2442
 tico/experimental/quantization/ptq/wrappers/quant_elementwise.py,sha256=LhEoobfvto6zKrBOKL4gmxfFFc31jHzyQV_zfps-iQM,3604
 tico/experimental/quantization/ptq/wrappers/quant_module_base.py,sha256=vkcDos_knGSS29rIZuEIWkAJLHrENbGz8nCH2-iara8,5969
-tico/experimental/quantization/ptq/wrappers/registry.py,sha256=
-tico/experimental/quantization/ptq/wrappers/llama/__init__.py,sha256=
-tico/experimental/quantization/ptq/wrappers/llama/quant_attn.py,sha256
+tico/experimental/quantization/ptq/wrappers/registry.py,sha256=wauoZdZBR15bGj1Upt9owEfFDT-Tj6HzciG9HDM1BHo,4845
+tico/experimental/quantization/ptq/wrappers/llama/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
+tico/experimental/quantization/ptq/wrappers/llama/quant_attn.py,sha256=-K1COLHIHfJZhQu-RE6KfJIkaL7S6yR4iUj48QkjMTw,8652
 tico/experimental/quantization/ptq/wrappers/llama/quant_decoder_layer.py,sha256=2XsIf5rcabDXXkahqriSxfo2curFq0Y5bnRPcYkJPg8,7187
 tico/experimental/quantization/ptq/wrappers/llama/quant_mlp.py,sha256=uZMnrX66oZwxhKhcNbLXXeri-WxxRBiZnr15aBXJMm0,3562
-tico/experimental/quantization/ptq/wrappers/nn/__init__.py,sha256=
+tico/experimental/quantization/ptq/wrappers/nn/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/ptq/wrappers/nn/quant_layernorm.py,sha256=G5Sgt-tXnzh0Rxyk-2honmZIfEQOZlRfOsoDBdSGmA4,6887
 tico/experimental/quantization/ptq/wrappers/nn/quant_linear.py,sha256=xW-VEPB7RJoslS3xLVCdhIuMjppknvpkZleRGK4JFVQ,2240
 tico/experimental/quantization/ptq/wrappers/nn/quant_silu.py,sha256=XnJDggkWUTfXC1-BLeAbcCUtp687XLIkIIbuQlqycDw,1864
@@ -244,9 +247,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
 tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
 tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
+tico-0.1.0.dev250904.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
+tico-0.1.0.dev250904.dist-info/METADATA,sha256=jCLzVEpVnwflcRRdt3AZElB_Lw1_lWSzusXsiA5RAig,8450
+tico-0.1.0.dev250904.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+tico-0.1.0.dev250904.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
+tico-0.1.0.dev250904.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
+tico-0.1.0.dev250904.dist-info/RECORD,,
{tico-0.1.0.dev250902.dist-info → tico-0.1.0.dev250904.dist-info}/LICENSE
File without changes
{tico-0.1.0.dev250902.dist-info → tico-0.1.0.dev250904.dist-info}/WHEEL
File without changes
{tico-0.1.0.dev250902.dist-info → tico-0.1.0.dev250904.dist-info}/entry_points.txt
File without changes
{tico-0.1.0.dev250902.dist-info → tico-0.1.0.dev250904.dist-info}/top_level.txt
File without changes