tico 0.1.0.dev250831__py3-none-any.whl → 0.1.0.dev250902__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/experimental/quantization/ptq/examples/compare_ppl.py +121 -0
- tico/experimental/quantization/ptq/examples/quantize_llama_decoder_layer.py +124 -0
- tico/experimental/quantization/ptq/utils/__init__.py +2 -0
- tico/experimental/quantization/ptq/utils/metrics.py +123 -0
- tico/experimental/quantization/ptq/wrappers/llama/__init__.py +4 -1
- tico/experimental/quantization/ptq/wrappers/llama/quant_decoder_layer.py +168 -0
- tico/experimental/quantization/ptq/wrappers/registry.py +1 -0
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- {tico-0.1.0.dev250831.dist-info → tico-0.1.0.dev250902.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250831.dist-info → tico-0.1.0.dev250902.dist-info}/RECORD +15 -11
- {tico-0.1.0.dev250831.dist-info → tico-0.1.0.dev250902.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250831.dist-info → tico-0.1.0.dev250902.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250831.dist-info → tico-0.1.0.dev250902.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250831.dist-info → tico-0.1.0.dev250902.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
tico/experimental/quantization/ptq/examples/compare_ppl.py
ADDED
@@ -0,0 +1,121 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# =============================================================================
+# QUICK PTQ WORKFLOW (OPTIONAL FP32 BASELINE)
+# -----------------------------------------------------------------------------
+# Toggle RUN_FP to choose between:
+#   • FP32 perplexity measurement only, OR
+#   • Full post-training UINT-8 flow (wrap → calibrate → eval).
+# =============================================================================
+
+import torch
+import tqdm
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from tico.experimental.quantization.ptq.quant_config import QuantConfig
+from tico.experimental.quantization.ptq.utils import perplexity
+from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
+
+# -------------------------------------------------------------------------
+# 0. Global configuration
+# -------------------------------------------------------------------------
+MODEL_NAME = "meta-llama/Meta-Llama-3-1B"
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+STRIDE = 512  # sliding-window stride for perplexity
+RUN_FP = True  # set False → run UINT-8 path
+
+# Token-budget presets for activation calibration
+TOKENS: dict[str, int] = {
+    # Smoke test (<1 min turnaround on CPU/GPU)
+    "debug": 2_000,  # ≈16 × 128-seq batches
+    # Good default for 1-7B models (≲3 % ppl delta)
+    "baseline": 50_000,
+    # Production / 4-bit observer smoothing
+    "production": 200_000,
+}
+CALIB_TOKENS = TOKENS["baseline"]
+print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
+
+# -------------------------------------------------------------------------
+# 1. Load model
+# -------------------------------------------------------------------------
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
+if RUN_FP:
+    # -- FP32 baseline ------------------------------------------------------
+    print("Loading FP32 model …")
+    fp_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
+    fp_model.config.use_cache = False
+else:
+    # -- UINT-8 pipeline -----------------------------------------------------
+    print("Creating UINT-8 clone …")
+    uint8_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
+    uint8_model.config.use_cache = False
+
+    # ---------------------------------------------------------------------
+    # 2. Wrap every Transformer layer with PTQWrapper
+    # ---------------------------------------------------------------------
+    qcfg = QuantConfig()  # all-uint8 defaults
+
+    wrapped_layers = torch.nn.ModuleList()
+    for idx, layer in enumerate(uint8_model.model.layers):
+        layer_cfg = qcfg.child(f"layer{idx}")
+        wrapped_layers.append(PTQWrapper(layer, qcfg=layer_cfg))
+    uint8_model.model.layers = wrapped_layers
+
+    # ---------------------------------------------------------------------
+    # 3. Single-pass activation calibration
+    # ---------------------------------------------------------------------
+    print("Calibrating UINT-8 observers …")
+    calib_txt = " ".join(
+        load_dataset("wikitext", "wikitext-2-raw-v1", split="train")["text"]
+    )[:CALIB_TOKENS]
+    ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(DEVICE)
+
+    # (a) switch every QuantModuleBase to CALIB mode
+    for l in uint8_model.model.layers:
+        l.enable_calibration()
+
+    # (b) run inference to collect ranges
+    with torch.no_grad():
+        for i in tqdm.trange(0, ids.size(1) - 1, STRIDE, desc="Calibration"):
+            uint8_model(ids[:, i : i + STRIDE])
+
+    # (c) freeze (scale, zero-point)
+    for l in uint8_model.model.layers:
+        l.freeze_qparams()
+
+# -------------------------------------------------------------------------
+# 4. Evaluate perplexity on Wikitext-2
+# -------------------------------------------------------------------------
+print("\nCalculating perplexities …")
+test_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+enc = tokenizer("\n\n".join(test_ds["text"]), return_tensors="pt")
+
+if RUN_FP:
+    ppl_fp = perplexity(fp_model, enc, DEVICE, stride=STRIDE)
+else:
+    ppl_int8 = perplexity(uint8_model, enc, DEVICE, stride=STRIDE)
+
+# -------------------------------------------------------------------------
+# 5. Report
+# -------------------------------------------------------------------------
+print("\n┌── Wikitext-2 test perplexity ─────────────")
+if RUN_FP:
+    print(f"│ FP32   : {ppl_fp:8.2f}")
+else:
+    print(f"│ UINT-8 : {ppl_int8:8.2f}")
+print("└───────────────────────────────────────────")
tico/experimental/quantization/ptq/examples/quantize_llama_decoder_layer.py
ADDED
@@ -0,0 +1,124 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# =============================================================================
+# POST-TRAINING QUANTIZATION EXAMPLE — Llama Decoder Layer (Self-Attn + MLP)
+# -----------------------------------------------------------------------------
+# This demo shows how to:
+#   1. Replace a single FP32 `LlamaDecoderLayer` with `QuantLlamaDecoderLayer`.
+#   2. Collect activation statistics in one calibration sweep.
+#   3. Freeze scales / zero-points and switch to INT-simulation mode.
+#   4. Compare INT-8 vs FP32 outputs with a quick mean-absolute-diff check.
+#   5. Export the calibrated, quantized block to a Circle model.
+# -----------------------------------------------------------------------------
+# Style / layout is kept identical to the `quantize_llama_attn.py` and
+# `quantize_llama_mlp.py` examples for easy side-by-side reading.
+# =============================================================================
+
+import pathlib
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from tico.experimental.quantization.evaluation.metric import compute_peir
+from tico.experimental.quantization.evaluation.utils import plot_two_outputs
+from tico.experimental.quantization.ptq.mode import Mode
+from tico.experimental.quantization.ptq.wrappers.llama.quant_decoder_layer import (
+    QuantLlamaDecoderLayer,
+)
+from tico.utils.utils import SuppressWarning
+
+MODEL_NAME = "Maykeye/TinyLLama-v0"
+model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
+model.eval()  # disable dropout, etc.
+rotary = model.model.rotary_emb  # RoPE helper
+
+# -------------------------------------------------------------------------
+# 1. Swap in the quant wrapper
+# -------------------------------------------------------------------------
+fp32_layer = model.model.layers[0]  # keep a reference for diff check
+model.model.layers[0] = QuantLlamaDecoderLayer(
+    fp32_layer
+)  # PTQWrapper(fp32_layer) is also fine
+model.eval()
+
+qlayer = model.model.layers[0]  # alias for brevity
+
+# -------------------------------------------------------------------------
+# 2. Single-pass calibration (gather activation ranges)
+# -------------------------------------------------------------------------
+PROMPTS = [
+    "The quick brown fox jumps over the lazy dog.",
+    "In 2025, AI systems accelerated hardware-software co-design at scale.",
+    "양자화는 왜 어려울까? 분포, 길이, 마스크가 관건이다.",
+    "今日はいい天気ですね。ところでRoPE角度は長さに依存します。",
+    "def quicksort(arr):\n    if len(arr) <= 1: return arr\n    ...",
+    "Prices rose 3.14% — see Figure 2; emails: foo@bar.com!",
+]
+
+with torch.no_grad():
+    qlayer.enable_calibration()
+    for prompt in PROMPTS:
+        ids = tokenizer(prompt, return_tensors="pt")
+        hidden = model.model.embed_tokens(ids["input_ids"])
+        pos = rotary(hidden, ids["input_ids"])  # (cos, sin) tuple
+        S = pos[0].shape[1]
+        attn_mask = torch.zeros(1, 1, S, S)  # causal-mask placeholder
+        _ = qlayer(hidden, attention_mask=attn_mask, position_embeddings=pos)
+    qlayer.freeze_qparams()
+
+assert qlayer._mode is Mode.QUANT, "Quantization mode should be active now."
+
+# -------------------------------------------------------------------------
+# 3. Quick INT-sim vs FP32 sanity check
+# -------------------------------------------------------------------------
+ids = tokenizer("check", return_tensors="pt")
+hidden = model.model.embed_tokens(ids["input_ids"])
+pos = rotary(hidden, ids["input_ids"])
+S = pos[0].shape[1]
+attn_mask = torch.zeros(1, 1, S, S)
+
+with torch.no_grad():
+    int8_out = qlayer(hidden, attention_mask=attn_mask, position_embeddings=pos)
+    int8 = int8_out[0] if isinstance(int8_out, tuple) else int8_out
+    fp32_out = fp32_layer(hidden, attention_mask=attn_mask, position_embeddings=pos)
+    fp32 = fp32_out[0] if isinstance(fp32_out, tuple) else fp32_out
+
+print("┌───────────── Quantization Error Summary ─────────────")
+print(f"│ Mean |diff|: {(int8 - fp32).abs().mean().item():.6f}")
+print(f"│ PEIR       : {compute_peir(fp32, int8) * 100:.6f} %")
+print("└──────────────────────────────────────────────────────")
+print(plot_two_outputs(fp32, int8))
+
+# -------------------------------------------------------------------------
+# 4. Export the calibrated layer to Circle
+# -------------------------------------------------------------------------
+import tico
+
+save_path = pathlib.Path("decoder_layer.q.circle")
+B, S, D = 1, 4, model.config.hidden_size
+example_hidden = torch.randn(B, S, D)
+example_pos = rotary(example_hidden, torch.arange(S)[None, :])
+attn_mask = torch.zeros(1, 1, S, S)
+
+with SuppressWarning(UserWarning, ".*"):
+    cm = tico.convert(
+        qlayer, (example_hidden, attn_mask), {"position_embeddings": example_pos}
+    )
+# Note that the model is not fully quantized.
+cm.save(save_path)
+
+print(f"Quantized Circle model saved to {save_path.resolve()}")
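As an aside, once `freeze_qparams()` has run, the observers inside the wrapped layer hold the fixed quantization parameters. The short inspection sketch below is an editor's illustration, not part of the diff: `_all_observers()` is confirmed by the `quant_decoder_layer.py` listing later in this page, but the `scale` / `zero_point` attribute names are assumptions about the observer classes, which are not shown here (hence the `getattr` defaults).

# Illustrative sketch; run after the calibration block above.
for obs in qlayer._all_observers():
    print(type(obs).__name__, getattr(obs, "scale", None), getattr(obs, "zero_point", None))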
tico/experimental/quantization/ptq/utils/metrics.py
ADDED
@@ -0,0 +1,123 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+import tqdm
+
+
+def perplexity(
+    model: torch.nn.Module,
+    encodings: torch.Tensor,
+    device: torch.device | str,
+    *,
+    max_length: Optional[int] = None,
+    stride: int = 512,
+    ignore_index: int | None = -100,
+    show_progress: bool = True,
+) -> float:
+    """
+    Compute perplexity (PPL) using a "strided sliding-window"
+    evaluation strategy.
+
+    The function:
+      1. Splits the token sequence into overlapping windows of length
+         `max_length` (model context size).
+      2. Masks tokens that were already scored in previous windows
+         (`labels == -100`), so each token's negative log-likelihood (NLL)
+         is counted EXACTLY once.
+      3. Aggregates token-wise NLL to return corpus-level PPL.
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        Causal LM loaded in evaluation mode (`model.eval()`).
+    encodings : torch.Tensor | transformers.BatchEncoding
+        Tokenised corpus. If a `BatchEncoding` is passed, its
+        `.input_ids` field is used. Shape must be `(1, seq_len)`.
+    device : torch.device | str
+        CUDA or CPU device on which to run evaluation.
+    max_length : int, optional
+        Context window size. Defaults to `model.config.max_position_embeddings`.
+    stride : int, default = 512
+        Step size by which the sliding window advances. Must satisfy
+        `1 ≤ stride ≤ max_length`.
+    ignore_index : int, default = -100
+        Label value to ignore in loss computation. This should match
+        the `ignore_index` used by the model's internal
+        `CrossEntropyLoss`. For Hugging Face causal LMs, the
+        convention is `-100`.
+    show_progress : bool, default = True
+        If True, displays a tqdm progress bar while evaluating.
+
+    Returns
+    -------
+    float
+        Corpus-level perplexity.
+    """
+    # -------- input preparation -------- #
+    try:
+        # transformers.BatchEncoding has `input_ids`
+        input_ids_full = encodings.input_ids  # type: ignore[attr-defined]
+    except AttributeError:  # already a tensor
+        input_ids_full = encodings
+    assert isinstance(input_ids_full, torch.Tensor)
+    input_ids_full = input_ids_full.to(device)
+
+    if max_length is None:
+        assert hasattr(model, "config")
+        assert hasattr(model.config, "max_position_embeddings")
+        assert isinstance(model.config.max_position_embeddings, int)
+        max_length = model.config.max_position_embeddings
+    assert max_length is not None
+    assert (
+        1 <= stride <= max_length
+    ), f"stride ({stride}) must be in [1, max_length ({max_length})]"
+
+    seq_len = input_ids_full.size(1)
+    nll_sum = 0.0
+    n_tokens = 0
+    prev_end = 0
+
+    # -------- main loop -------- #
+    for begin in tqdm.trange(0, seq_len, stride, desc="PPL", disable=not show_progress):
+        end = min(begin + max_length, seq_len)
+        trg_len = end - prev_end  # fresh tokens in this window
+
+        input_ids = input_ids_full[:, begin:end]
+        target_ids = input_ids.clone()
+        target_ids[:, :-trg_len] = ignore_index  # mask previously-scored tokens
+
+        with torch.no_grad():
+            outputs = model(input_ids, labels=target_ids)
+            # loss is already averaged over non-masked labels
+            neg_log_likelihood = outputs.loss
+
+        # exact number of labels that contributed to loss
+        loss_tokens = (target_ids[:, 1:] != ignore_index).sum().item()
+        nll_sum += neg_log_likelihood * loss_tokens
+        n_tokens += int(loss_tokens)
+
+        prev_end = end
+        if end == seq_len:
+            break
+
+    avg_nll: float | torch.Tensor = nll_sum / n_tokens
+    if not isinstance(avg_nll, torch.Tensor):
+        avg_nll = torch.tensor(avg_nll)
+    assert isinstance(avg_nll, torch.Tensor)
+    ppl = torch.exp(avg_nll)
+
+    return ppl.item()
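To make the strided-window bookkeeping concrete, here is a small standalone sketch with toy numbers (an editor's illustration, not part of the package); it only mirrors the index arithmetic in the loop above and prints how many tokens each window scores versus re-feeds as masked context.

# Toy illustration of the strided sliding window used by perplexity().
max_length, stride, seq_len = 6, 4, 10
prev_end = 0
for begin in range(0, seq_len, stride):
    end = min(begin + max_length, seq_len)
    trg_len = end - prev_end          # fresh tokens scored in this window
    masked = (end - begin) - trg_len  # tokens re-fed only as context (labels = ignore_index)
    print(f"window [{begin:2d},{end:2d}): score {trg_len} new tokens, mask {masked} context tokens")
    prev_end = end
    if end == seq_len:
        break
# Corpus perplexity is then exp(total_NLL / total_scored_tokens).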
tico/experimental/quantization/ptq/wrappers/llama/__init__.py
CHANGED
@@ -1,6 +1,9 @@
 from tico.experimental.quantization.ptq.wrappers.llama.quant_attn import (
     QuantLlamaAttention,
 )
+from tico.experimental.quantization.ptq.wrappers.llama.quant_decoder_layer import (
+    QuantLlamaDecoderLayer,
+)
 from tico.experimental.quantization.ptq.wrappers.llama.quant_mlp import QuantLlamaMLP
 
-__all__ = ["QuantLlamaAttention", "QuantLlamaMLP"]
+__all__ = ["QuantLlamaAttention", "QuantLlamaDecoderLayer", "QuantLlamaMLP"]
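With the package export updated, the new wrapper can also be imported from the `llama` wrappers package directly, as a one-line usage sketch:

from tico.experimental.quantization.ptq.wrappers.llama import QuantLlamaDecoderLayer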
tico/experimental/quantization/ptq/wrappers/llama/quant_decoder_layer.py
ADDED
@@ -0,0 +1,168 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from tico.experimental.quantization.ptq.quant_config import QuantConfig
+from tico.experimental.quantization.ptq.wrappers.llama.quant_attn import (
+    QuantLlamaAttention,
+)
+from tico.experimental.quantization.ptq.wrappers.llama.quant_mlp import QuantLlamaMLP
+from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
+from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
+    QuantModuleBase,
+)
+from tico.experimental.quantization.ptq.wrappers.registry import try_register
+
+
+@try_register("transformers.models.llama.modeling_llama.LlamaDecoderLayer")
+class QuantLlamaDecoderLayer(QuantModuleBase):
+    """
+    Quant-aware drop-in replacement for HF `LlamaDecoderLayer`.
+    Signature and return value are identical to the original.
+
+    ▸ Attention & MLP blocks are replaced by their quantized counterparts
+    ▸ LayerNorms remain FP32 (no fake-quant)
+    ▸ A "static" causal mask is pre-built in `__init__` to avoid
+      dynamic boolean-to-float casts inside `forward`.
+
+    Notes on the causal mask
+    ------------------------
+    Building a boolean mask "inside" `forward` would introduce
+    non-deterministic dynamic ops that an integer-only accelerator cannot
+    fuse easily. Therefore we:
+
+      1. Pre-compute a full upper-triangular mask of size
+         `[1, 1, max_seq, max_seq]` in `__init__`.
+      2. In `forward`, if the caller passes `attention_mask=None`, we
+         slice the pre-computed template to the current sequence length.
+    """
+
+    def __init__(
+        self,
+        fp_layer: nn.Module,
+        *,
+        qcfg: Optional[QuantConfig] = None,
+        fp_name: Optional[str] = None,
+        return_type: Optional[str] = None,
+    ):
+        """
+        Q) Why do we need `return_type`?
+        A) Different versions of `transformers` wrap the decoder output in
+           different containers: a plain Tensor or a tuple.
+        """
+        self.return_type = return_type
+        if self.return_type is None:
+            import transformers
+
+            v = tuple(map(int, transformers.__version__.split(".")[:2]))
+            self.return_type = "tensor" if v >= (4, 54) else "tuple"
+        assert self.return_type is not None
+        super().__init__(qcfg, fp_name=fp_name)
+
+        # Child QuantConfigs -------------------------------------------------
+        attn_cfg = qcfg.child("self_attn") if qcfg else None
+        mlp_cfg = qcfg.child("mlp") if qcfg else None
+
+        # Quantized sub-modules ---------------------------------------------
+        assert hasattr(fp_layer, "self_attn") and isinstance(
+            fp_layer.self_attn, torch.nn.Module
+        )
+        assert hasattr(fp_layer, "mlp") and isinstance(fp_layer.mlp, torch.nn.Module)
+        self.self_attn = PTQWrapper(
+            fp_layer.self_attn, qcfg=attn_cfg, fp_name=f"{fp_name}.self_attn"
+        )
+        self.mlp = PTQWrapper(fp_layer.mlp, qcfg=mlp_cfg, fp_name=f"{fp_name}.mlp")
+
+        # LayerNorms remain FP (copied from fp_layer to keep weights)
+        assert hasattr(fp_layer, "input_layernorm") and isinstance(
+            fp_layer.input_layernorm, torch.nn.Module
+        )
+        assert hasattr(fp_layer, "post_attention_layernorm") and isinstance(
+            fp_layer.post_attention_layernorm, torch.nn.Module
+        )
+        self.input_layernorm = fp_layer.input_layernorm
+        self.post_attention_layernorm = fp_layer.post_attention_layernorm
+
+        # Static causal mask template ---------------------------------------
+        assert hasattr(fp_layer.self_attn, "config") and hasattr(
+            fp_layer.self_attn.config, "max_position_embeddings"
+        )
+        assert isinstance(fp_layer.self_attn.config.max_position_embeddings, int)
+        max_seq = fp_layer.self_attn.config.max_position_embeddings
+        mask = torch.full((1, 1, max_seq, max_seq), float("-120"))
+        mask.triu_(1)
+        self.register_buffer("causal_mask_template", mask, persistent=False)
+
+    def _slice_causal(self, seq_len: int, device: torch.device) -> torch.Tensor:
+        """Return `[1,1,L,L]` causal mask slice on *device*."""
+        assert isinstance(self.causal_mask_template, torch.Tensor)
+        return self.causal_mask_template[..., :seq_len, :seq_len].to(device)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional["Cache"] = None,  # type: ignore[name-defined]
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor] | torch.Tensor:
+        if output_attentions:
+            raise NotImplementedError(
+                "QuantLlamaDecoderLayer does not support output attention yet."
+            )
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        if attention_mask is None or attention_mask.dtype == torch.bool:
+            L = hidden_states.size(1)
+            attention_mask = self._slice_causal(L, hidden_states.device)
+
+        hidden_states, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # ─── MLP block ─────────────────────────────────────────────────
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        if self.return_type == "tuple":
+            return (hidden_states,)
+        elif self.return_type == "tensor":
+            return hidden_states
+        else:
+            raise RuntimeError("Invalid return type.")
+
+    # No local observers; just recurse into children
+    def _all_observers(self):
+        yield from self.self_attn._all_observers()
+        yield from self.mlp._all_observers()
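The static-mask design described in the docstring is easiest to see with a tiny standalone example (an editor's sketch with a toy `max_seq`; the construction is copied from the `__init__` above):

import torch

max_seq = 4
mask = torch.full((1, 1, max_seq, max_seq), float("-120"))
mask.triu_(1)          # keep -120 strictly above the diagonal, 0 on and below it
print(mask[0, 0])
# tensor([[   0., -120., -120., -120.],
#         [   0.,    0., -120., -120.],
#         [   0.,    0.,    0., -120.],
#         [   0.,    0.,    0.,    0.]])
# forward() slices mask[..., :L, :L] instead of rebuilding a boolean mask per call.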
tico/experimental/quantization/ptq/wrappers/registry.py
CHANGED
@@ -29,6 +29,7 @@ _CORE_MODULES = (
     "tico.experimental.quantization.ptq.wrappers.nn.quant_linear",
     "tico.experimental.quantization.ptq.wrappers.nn.quant_silu",
     "tico.experimental.quantization.ptq.wrappers.llama.quant_attn",
+    "tico.experimental.quantization.ptq.wrappers.llama.quant_decoder_layer",
     "tico.experimental.quantization.ptq.wrappers.llama.quant_mlp",
     # add future core wrappers here
 )
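Because `quant_decoder_layer` is now listed in `_CORE_MODULES`, its `@try_register(...)` decorator runs when the registry is imported, so `PTQWrapper` can resolve a Hugging Face `LlamaDecoderLayer` on its own (the example above already notes that `PTQWrapper(fp32_layer)` is also fine). A minimal sketch, assuming the same TinyLLama checkpoint used in the examples:

from transformers import AutoModelForCausalLM
from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper

model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0")
# Dispatches to QuantLlamaDecoderLayer via the registry entry added above.
model.model.layers[0] = PTQWrapper(model.model.layers[0])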
tico/passes/remove_redundant_assert_nodes.py
CHANGED
@@ -21,7 +21,9 @@ from tico.utils.utils import is_target_node
 
 
 assert_node_targets = [
+    torch.ops.aten._assert_scalar.default,
     torch.ops.aten._assert_tensor_metadata.default,
+    torch.ops.aten.sym_constrain_range_for_size.default,  # Related to symbolic shape validation
 ]
 
 
@@ -29,7 +31,7 @@ assert_node_targets = [
 class RemoveRedundantAssertionNodes(PassBase):
     """
     This removes redundant assertion nodes.
-
+    When an assertion node is erased, related comparison nodes are also removed by graph.eliminate_dead_code().
     """
 
     def __init__(self):
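For context, the overall shape of the pass (erase the listed assert ops, then let dead-code elimination drop the comparison and sym-size nodes that only fed them) can be sketched roughly as below. This is a hypothetical torch.fx-style illustration by the editor, not the actual pass implementation shipped in tico.

import torch

ASSERT_TARGETS = {
    torch.ops.aten._assert_scalar.default,
    torch.ops.aten._assert_tensor_metadata.default,
    torch.ops.aten.sym_constrain_range_for_size.default,
}

def remove_asserts(gm: torch.fx.GraphModule) -> None:
    # Assertion ops produce no value used elsewhere, so they can be erased directly.
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target in ASSERT_TARGETS:
            gm.graph.erase_node(node)
    # Comparison nodes that only fed the asserts are now dead and get removed here.
    gm.graph.eliminate_dead_code()
    gm.recompile()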
{tico-0.1.0.dev250831.dist-info → tico-0.1.0.dev250902.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-tico/__init__.py,sha256=
+tico/__init__.py,sha256=PXZzhb0ZexNIwGhVJpg4Ln_RqskbSIMigqj0GdZgbeA,1883
 tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
 tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
 tico/config/base.py,sha256=q5xMqGxTUZs4mFqt5c7i_y9U00fYgdMGl9nUqIVMlCo,1248
@@ -62,8 +62,10 @@ tico/experimental/quantization/ptq/mode.py,sha256=lT-T8vIv8YWcwrjT7xXVhOw1g7aoAd
 tico/experimental/quantization/ptq/qscheme.py,sha256=uwhv7bCxOOXB3I-IKlRyr_u4eXOq48uIqGy4TLDqGxY,1301
 tico/experimental/quantization/ptq/quant_config.py,sha256=nm7570Y1X2mOT_8s27ilWid04otor6cVTi9GwgAEaKc,4300
 tico/experimental/quantization/ptq/examples/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
+tico/experimental/quantization/ptq/examples/compare_ppl.py,sha256=ODaRB234iy2dFfGIBd-OtKxdSzxnIbgKZkQ_o30tSts,5287
 tico/experimental/quantization/ptq/examples/quantize_linear.py,sha256=8zq-ZJDYgam0xQ-PbC6Xb1I7W1mv0Wi-b--IP2wwXtw,4539
 tico/experimental/quantization/ptq/examples/quantize_llama_attn.py,sha256=cVWUSSzaZWFp5QZkNkrlpHU3kXyP84QtnZbahVml_yQ,4329
+tico/experimental/quantization/ptq/examples/quantize_llama_decoder_layer.py,sha256=mBWrjkyEovYQsPC4Rrsri6Pm1rlFmDb3NiP0DQQhFyM,5751
 tico/experimental/quantization/ptq/examples/quantize_llama_mlp.py,sha256=N1qZQgt1S-xZrdv-PW7OfXEcv0gsO2q9faOF4aD-zKo,4147
 tico/experimental/quantization/ptq/observers/__init__.py,sha256=WF2MvL9M_jl-B1FqcY9zic34NOCRp17HkRYv-TMxMr4,613
 tico/experimental/quantization/ptq/observers/affine_base.py,sha256=e2Eba64nrxKQyE4F_WJ7WTSsk3xe6bkdGUKaoLFWGFw,4638
@@ -72,15 +74,17 @@ tico/experimental/quantization/ptq/observers/ema.py,sha256=MAMdBmjVNMg_vsqXrcBzb
 tico/experimental/quantization/ptq/observers/identity.py,sha256=vkec8Or-7VwM4zkFEvEKROQJk8XEHMVX8mBNDnxSyS8,2591
 tico/experimental/quantization/ptq/observers/minmax.py,sha256=mLHkwIzWFzQXev7EU7w1333KckwRjukc3_cUPJOnUfs,1486
 tico/experimental/quantization/ptq/observers/mx.py,sha256=aP4qmBgeiRIYZJksShN5gs6UyYOFi2-Sbk5k5xvPQ4w,1863
-tico/experimental/quantization/ptq/utils/__init__.py,sha256=
+tico/experimental/quantization/ptq/utils/__init__.py,sha256=MrQwMbbKS0dJrO8jsceCai4Z59iKQNpTPZND3GN6TrM,216
+tico/experimental/quantization/ptq/utils/metrics.py,sha256=EW_FQmJrl9Y4esspZQ0GHfJ58RwuJUz0l8IfYq3NWY4,4461
 tico/experimental/quantization/ptq/utils/reduce_utils.py,sha256=3kWawLB91EcvvHlCrNqqfZF7tpgr22htBSA049mKw_4,973
 tico/experimental/quantization/ptq/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tico/experimental/quantization/ptq/wrappers/ptq_wrapper.py,sha256=F9sK_DiRaXiGNHULcwIbs5EUtHz6ZJ7N4r5CWTTfhsM,2442
 tico/experimental/quantization/ptq/wrappers/quant_elementwise.py,sha256=LhEoobfvto6zKrBOKL4gmxfFFc31jHzyQV_zfps-iQM,3604
 tico/experimental/quantization/ptq/wrappers/quant_module_base.py,sha256=vkcDos_knGSS29rIZuEIWkAJLHrENbGz8nCH2-iara8,5969
-tico/experimental/quantization/ptq/wrappers/registry.py,sha256=
-tico/experimental/quantization/ptq/wrappers/llama/__init__.py,sha256=
+tico/experimental/quantization/ptq/wrappers/registry.py,sha256=M1D_foC0PR-Ii4G0lbOO3_pmhvHlMF28NolK_q2DZtw,4783
+tico/experimental/quantization/ptq/wrappers/llama/__init__.py,sha256=4xuAYnJcohMTtBzrH4cxq8WKG2GQo8nbhektVg8w7F0,380
 tico/experimental/quantization/ptq/wrappers/llama/quant_attn.py,sha256=WIUI6EFMTvvruvqu8pBxWy6qJeDyjkaYbJk1R3pAmwE,8578
+tico/experimental/quantization/ptq/wrappers/llama/quant_decoder_layer.py,sha256=2XsIf5rcabDXXkahqriSxfo2curFq0Y5bnRPcYkJPg8,7187
 tico/experimental/quantization/ptq/wrappers/llama/quant_mlp.py,sha256=uZMnrX66oZwxhKhcNbLXXeri-WxxRBiZnr15aBXJMm0,3562
 tico/experimental/quantization/ptq/wrappers/nn/__init__.py,sha256=I9uTt5HfcRoMEDYHpAeATMv2TbCQiX0ZbfUFMzSJ4Qw,336
 tico/experimental/quantization/ptq/wrappers/nn/quant_layernorm.py,sha256=G5Sgt-tXnzh0Rxyk-2honmZIfEQOZlRfOsoDBdSGmA4,6887
@@ -117,7 +121,7 @@ tico/passes/lower_to_slice.py,sha256=OzlFzK3lBYyYwC3WThsWd94Ob4JINIJF8UaLAtnumzU
 tico/passes/merge_consecutive_cat.py,sha256=ayZNLDA1DFM7Fxxi2Dmk1CujkgUuaVCH1rhQgLrvvOQ,2701
 tico/passes/ops.py,sha256=cSj3Sk2x2cOE9b8oU5pmSa_rHr-iX2lORzu3N_UHMSQ,2967
 tico/passes/remove_nop.py,sha256=Hf91p_EJAOC6DyWNthash0_UWtEcNc_M7znamQfYQ5Y,2686
-tico/passes/remove_redundant_assert_nodes.py,sha256=
+tico/passes/remove_redundant_assert_nodes.py,sha256=rYbTCyuNIXIC-2NreHKBVCuaSUkEQvB_iSRzb26P_EA,1821
 tico/passes/remove_redundant_expand.py,sha256=auyqIoQT4HJhiJfuUe6BrEtUhvz221ohnIK5EuszWeg,2112
 tico/passes/remove_redundant_permute.py,sha256=98UsaZzFZdQzEEAR1pIzRisAf6hgfXLa88aayjalt3E,4292
 tico/passes/remove_redundant_reshape.py,sha256=aeep6LDvY58GEuOrWckkEXnJa6wkkbiJ9FrimT9F3-s,16384
@@ -240,9 +244,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
 tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
 tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
+tico-0.1.0.dev250902.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
+tico-0.1.0.dev250902.dist-info/METADATA,sha256=CePT5yw5-ln0-Ct8n61iGDnFfnoASlqAfPQmxRQ9QQ0,8450
+tico-0.1.0.dev250902.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+tico-0.1.0.dev250902.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
+tico-0.1.0.dev250902.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
+tico-0.1.0.dev250902.dist-info/RECORD,,
{tico-0.1.0.dev250831.dist-info → tico-0.1.0.dev250902.dist-info}/LICENSE
File without changes
{tico-0.1.0.dev250831.dist-info → tico-0.1.0.dev250902.dist-info}/WHEEL
File without changes
{tico-0.1.0.dev250831.dist-info → tico-0.1.0.dev250902.dist-info}/entry_points.txt
File without changes
{tico-0.1.0.dev250831.dist-info → tico-0.1.0.dev250902.dist-info}/top_level.txt
File without changes