tico 0.1.0.dev250901__py3-none-any.whl → 0.1.0.dev250902__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/experimental/quantization/ptq/examples/compare_ppl.py +121 -0
- tico/experimental/quantization/ptq/utils/__init__.py +2 -0
- tico/experimental/quantization/ptq/utils/metrics.py +123 -0
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- {tico-0.1.0.dev250901.dist-info → tico-0.1.0.dev250902.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250901.dist-info → tico-0.1.0.dev250902.dist-info}/RECORD +11 -9
- {tico-0.1.0.dev250901.dist-info → tico-0.1.0.dev250902.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250901.dist-info → tico-0.1.0.dev250902.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250901.dist-info → tico-0.1.0.dev250902.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250901.dist-info → tico-0.1.0.dev250902.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
@@ -0,0 +1,121 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# =============================================================================
|
16
|
+
# QUICK PTQ WORKFLOW (OPTIONAL FP32 BASELINE)
|
17
|
+
# -----------------------------------------------------------------------------
|
18
|
+
# Toggle RUN_FP to choose between:
|
19
|
+
# • FP32 perplexity measurement only, OR
|
20
|
+
# • Full post-training UINT-8 flow (wrap → calibrate → eval).
|
21
|
+
# =============================================================================
|
22
|
+
|
23
|
+
import torch
|
24
|
+
import tqdm
|
25
|
+
from datasets import load_dataset
|
26
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
27
|
+
|
28
|
+
from tico.experimental.quantization.ptq.quant_config import QuantConfig
|
29
|
+
from tico.experimental.quantization.ptq.utils import perplexity
|
30
|
+
from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
|
31
|
+
|
32
|
+
# -------------------------------------------------------------------------
|
33
|
+
# 0. Global configuration
|
34
|
+
# -------------------------------------------------------------------------
|
35
|
+
MODEL_NAME = "meta-llama/Meta-Llama-3-1B"
|
36
|
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
37
|
+
STRIDE = 512 # sliding-window stride for perplexity
|
38
|
+
RUN_FP = True # set False → run UINT-8 path
|
39
|
+
|
40
|
+
# Token-budget presets for activation calibration
|
41
|
+
TOKENS: dict[str, int] = {
|
42
|
+
# Smoke test (<1 min turnaround on CPU/GPU)
|
43
|
+
"debug": 2_000, # ≈16 × 128-seq batches
|
44
|
+
# Good default for 1-7B models (≲3 % ppl delta)
|
45
|
+
"baseline": 50_000,
|
46
|
+
# Production / 4-bit observer smoothing
|
47
|
+
"production": 200_000,
|
48
|
+
}
|
49
|
+
CALIB_TOKENS = TOKENS["baseline"]
|
50
|
+
print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
|
51
|
+
|
52
|
+
# -------------------------------------------------------------------------
|
53
|
+
# 1. Load model
|
54
|
+
# -------------------------------------------------------------------------
|
55
|
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
56
|
+
|
57
|
+
if RUN_FP:
|
58
|
+
# -- FP32 baseline ------------------------------------------------------
|
59
|
+
print("Loading FP32 model …")
|
60
|
+
fp_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
|
61
|
+
fp_model.config.use_cache = False
|
62
|
+
else:
|
63
|
+
# -- UINT-8 pipeline -----------------------------------------------------
|
64
|
+
print("Creating UINT-8 clone …")
|
65
|
+
uint8_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
|
66
|
+
uint8_model.config.use_cache = False
|
67
|
+
|
68
|
+
# ---------------------------------------------------------------------
|
69
|
+
# 2. Wrap every Transformer layer with PTQWrapper
|
70
|
+
# ---------------------------------------------------------------------
|
71
|
+
qcfg = QuantConfig() # all-uint8 defaults
|
72
|
+
|
73
|
+
wrapped_layers = torch.nn.ModuleList()
|
74
|
+
for idx, layer in enumerate(uint8_model.model.layers):
|
75
|
+
layer_cfg = qcfg.child(f"layer{idx}")
|
76
|
+
wrapped_layers.append(PTQWrapper(layer, qcfg=layer_cfg))
|
77
|
+
uint8_model.model.layers = wrapped_layers
|
78
|
+
|
79
|
+
# ---------------------------------------------------------------------
|
80
|
+
# 3. Single-pass activation calibration
|
81
|
+
# ---------------------------------------------------------------------
|
82
|
+
print("Calibrating UINT-8 observers …")
|
83
|
+
calib_txt = " ".join(
|
84
|
+
load_dataset("wikitext", "wikitext-2-raw-v1", split="train")["text"]
|
85
|
+
)[:CALIB_TOKENS]
|
86
|
+
ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(DEVICE)
|
87
|
+
|
88
|
+
# (a) switch every QuantModuleBase to CALIB mode
|
89
|
+
for l in uint8_model.model.layers:
|
90
|
+
l.enable_calibration()
|
91
|
+
|
92
|
+
# (b) run inference to collect ranges
|
93
|
+
with torch.no_grad():
|
94
|
+
for i in tqdm.trange(0, ids.size(1) - 1, STRIDE, desc="Calibration"):
|
95
|
+
uint8_model(ids[:, i : i + STRIDE])
|
96
|
+
|
97
|
+
# (c) freeze (scale, zero-point)
|
98
|
+
for l in uint8_model.model.layers:
|
99
|
+
l.freeze_qparams()
|
100
|
+
|
101
|
+
# -------------------------------------------------------------------------
|
102
|
+
# 4. Evaluate perplexity on Wikitext-2
|
103
|
+
# -------------------------------------------------------------------------
|
104
|
+
print("\nCalculating perplexities …")
|
105
|
+
test_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
|
106
|
+
enc = tokenizer("\n\n".join(test_ds["text"]), return_tensors="pt")
|
107
|
+
|
108
|
+
if RUN_FP:
|
109
|
+
ppl_fp = perplexity(fp_model, enc, DEVICE, stride=STRIDE)
|
110
|
+
else:
|
111
|
+
ppl_int8 = perplexity(uint8_model, enc, DEVICE, stride=STRIDE)
|
112
|
+
|
113
|
+
# -------------------------------------------------------------------------
|
114
|
+
# 5. Report
|
115
|
+
# -------------------------------------------------------------------------
|
116
|
+
print("\n┌── Wikitext-2 test perplexity ─────────────")
|
117
|
+
if RUN_FP:
|
118
|
+
print(f"│ FP32 : {ppl_fp:8.2f}")
|
119
|
+
else:
|
120
|
+
print(f"│ UINT-8 : {ppl_int8:8.2f}")
|
121
|
+
print("└───────────────────────────────────────────")
|
@@ -0,0 +1,123 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Optional
|
16
|
+
|
17
|
+
import torch
|
18
|
+
import tqdm
|
19
|
+
|
20
|
+
|
21
|
+
def perplexity(
|
22
|
+
model: torch.nn.Module,
|
23
|
+
encodings: torch.Tensor,
|
24
|
+
device: torch.device | str,
|
25
|
+
*,
|
26
|
+
max_length: Optional[int] = None,
|
27
|
+
stride: int = 512,
|
28
|
+
ignore_index: int | None = -100,
|
29
|
+
show_progress: bool = True,
|
30
|
+
) -> float:
|
31
|
+
"""
|
32
|
+
Compute perplexity (PPL) using a "strided sliding-window"
|
33
|
+
evaluation strategy.
|
34
|
+
|
35
|
+
The function:
|
36
|
+
1. Splits the token sequence into overlapping windows of length
|
37
|
+
`max_length` (model context size).
|
38
|
+
2. Masks tokens that were already scored in previous windows
|
39
|
+
(`labels == -100`), so each token's negative log-likelihood (NLL)
|
40
|
+
is counted EXACTLY once.
|
41
|
+
3. Aggregates token-wise NLL to return corpus-level PPL.
|
42
|
+
|
43
|
+
Parameters
|
44
|
+
----------
|
45
|
+
model : torch.nn.Module
|
46
|
+
Causal LM loaded in evaluation mode (`model.eval()`).
|
47
|
+
encodings : torch.Tensor | transformers.BatchEncoding
|
48
|
+
Tokenised corpus. If a `BatchEncoding` is passed, its
|
49
|
+
`.input_ids` field is used. Shape must be `(1, seq_len)`.
|
50
|
+
device : torch.device | str
|
51
|
+
CUDA or CPU device on which to run evaluation.
|
52
|
+
max_length : int, optional
|
53
|
+
Context window size. Defaults to `model.config.max_position_embeddings`.
|
54
|
+
stride : int, default = 512
|
55
|
+
Step size by which the sliding window advances. Must satisfy
|
56
|
+
`1 ≤ stride ≤ max_length`.
|
57
|
+
ignore_index : int, default = -100
|
58
|
+
Label value to ignore in loss computation. This should match
|
59
|
+
the `ignore_index` used by the model's internal
|
60
|
+
`CrossEntropyLoss`. For Hugging Face causal LMs, the
|
61
|
+
convention is `-100`.
|
62
|
+
show_progress : bool, default = True
|
63
|
+
If True, displays a tqdm progress bar while evaluating.
|
64
|
+
|
65
|
+
Returns
|
66
|
+
-------
|
67
|
+
float
|
68
|
+
Corpus-level perplexity.
|
69
|
+
"""
|
70
|
+
# -------- input preparation -------- #
|
71
|
+
try:
|
72
|
+
# transformers.BatchEncoding has `input_ids`
|
73
|
+
input_ids_full = encodings.input_ids # type: ignore[attr-defined]
|
74
|
+
except AttributeError: # already a tensor
|
75
|
+
input_ids_full = encodings
|
76
|
+
assert isinstance(input_ids_full, torch.Tensor)
|
77
|
+
input_ids_full = input_ids_full.to(device)
|
78
|
+
|
79
|
+
if max_length is None:
|
80
|
+
assert hasattr(model, "config")
|
81
|
+
assert hasattr(model.config, "max_position_embeddings")
|
82
|
+
assert isinstance(model.config.max_position_embeddings, int)
|
83
|
+
max_length = model.config.max_position_embeddings
|
84
|
+
assert max_length is not None
|
85
|
+
assert (
|
86
|
+
1 <= stride <= max_length
|
87
|
+
), f"stride ({stride}) must be in [1, max_length ({max_length})]"
|
88
|
+
|
89
|
+
seq_len = input_ids_full.size(1)
|
90
|
+
nll_sum = 0.0
|
91
|
+
n_tokens = 0
|
92
|
+
prev_end = 0
|
93
|
+
|
94
|
+
# -------- main loop -------- #
|
95
|
+
for begin in tqdm.trange(0, seq_len, stride, desc="PPL", disable=not show_progress):
|
96
|
+
end = min(begin + max_length, seq_len)
|
97
|
+
trg_len = end - prev_end # fresh tokens in this window
|
98
|
+
|
99
|
+
input_ids = input_ids_full[:, begin:end]
|
100
|
+
target_ids = input_ids.clone()
|
101
|
+
target_ids[:, :-trg_len] = ignore_index # mask previously-scored tokens
|
102
|
+
|
103
|
+
with torch.no_grad():
|
104
|
+
outputs = model(input_ids, labels=target_ids)
|
105
|
+
# loss is already averaged over non-masked labels
|
106
|
+
neg_log_likelihood = outputs.loss
|
107
|
+
|
108
|
+
# exact number of labels that contributed to loss
|
109
|
+
loss_tokens = (target_ids[:, 1:] != ignore_index).sum().item()
|
110
|
+
nll_sum += neg_log_likelihood * loss_tokens
|
111
|
+
n_tokens += int(loss_tokens)
|
112
|
+
|
113
|
+
prev_end = end
|
114
|
+
if end == seq_len:
|
115
|
+
break
|
116
|
+
|
117
|
+
avg_nll: float | torch.Tensor = nll_sum / n_tokens
|
118
|
+
if not isinstance(avg_nll, torch.Tensor):
|
119
|
+
avg_nll = torch.tensor(avg_nll)
|
120
|
+
assert isinstance(avg_nll, torch.Tensor)
|
121
|
+
ppl = torch.exp(avg_nll)
|
122
|
+
|
123
|
+
return ppl.item()
|
@@ -21,7 +21,9 @@ from tico.utils.utils import is_target_node
|
|
21
21
|
|
22
22
|
|
23
23
|
assert_node_targets = [
|
24
|
+
torch.ops.aten._assert_scalar.default,
|
24
25
|
torch.ops.aten._assert_tensor_metadata.default,
|
26
|
+
torch.ops.aten.sym_constrain_range_for_size.default, # Related to symbolic shape validation
|
25
27
|
]
|
26
28
|
|
27
29
|
|
@@ -29,7 +31,7 @@ assert_node_targets = [
|
|
29
31
|
class RemoveRedundantAssertionNodes(PassBase):
|
30
32
|
"""
|
31
33
|
This removes redundant assertion nodes.
|
32
|
-
|
34
|
+
When an assertion node is erased, its related comparison nodes are also removed by graph.eliminate_dead_code().
|
33
35
|
"""
|
34
36
|
|
35
37
|
def __init__(self):
|
@@ -1,4 +1,4 @@
|
|
1
|
-
tico/__init__.py,sha256=
|
1
|
+
tico/__init__.py,sha256=PXZzhb0ZexNIwGhVJpg4Ln_RqskbSIMigqj0GdZgbeA,1883
|
2
2
|
tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
|
3
3
|
tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
|
4
4
|
tico/config/base.py,sha256=q5xMqGxTUZs4mFqt5c7i_y9U00fYgdMGl9nUqIVMlCo,1248
|
@@ -62,6 +62,7 @@ tico/experimental/quantization/ptq/mode.py,sha256=lT-T8vIv8YWcwrjT7xXVhOw1g7aoAd
|
|
62
62
|
tico/experimental/quantization/ptq/qscheme.py,sha256=uwhv7bCxOOXB3I-IKlRyr_u4eXOq48uIqGy4TLDqGxY,1301
|
63
63
|
tico/experimental/quantization/ptq/quant_config.py,sha256=nm7570Y1X2mOT_8s27ilWid04otor6cVTi9GwgAEaKc,4300
|
64
64
|
tico/experimental/quantization/ptq/examples/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
65
|
+
tico/experimental/quantization/ptq/examples/compare_ppl.py,sha256=ODaRB234iy2dFfGIBd-OtKxdSzxnIbgKZkQ_o30tSts,5287
|
65
66
|
tico/experimental/quantization/ptq/examples/quantize_linear.py,sha256=8zq-ZJDYgam0xQ-PbC6Xb1I7W1mv0Wi-b--IP2wwXtw,4539
|
66
67
|
tico/experimental/quantization/ptq/examples/quantize_llama_attn.py,sha256=cVWUSSzaZWFp5QZkNkrlpHU3kXyP84QtnZbahVml_yQ,4329
|
67
68
|
tico/experimental/quantization/ptq/examples/quantize_llama_decoder_layer.py,sha256=mBWrjkyEovYQsPC4Rrsri6Pm1rlFmDb3NiP0DQQhFyM,5751
|
@@ -73,7 +74,8 @@ tico/experimental/quantization/ptq/observers/ema.py,sha256=MAMdBmjVNMg_vsqXrcBzb
|
|
73
74
|
tico/experimental/quantization/ptq/observers/identity.py,sha256=vkec8Or-7VwM4zkFEvEKROQJk8XEHMVX8mBNDnxSyS8,2591
|
74
75
|
tico/experimental/quantization/ptq/observers/minmax.py,sha256=mLHkwIzWFzQXev7EU7w1333KckwRjukc3_cUPJOnUfs,1486
|
75
76
|
tico/experimental/quantization/ptq/observers/mx.py,sha256=aP4qmBgeiRIYZJksShN5gs6UyYOFi2-Sbk5k5xvPQ4w,1863
|
76
|
-
tico/experimental/quantization/ptq/utils/__init__.py,sha256=
|
77
|
+
tico/experimental/quantization/ptq/utils/__init__.py,sha256=MrQwMbbKS0dJrO8jsceCai4Z59iKQNpTPZND3GN6TrM,216
|
78
|
+
tico/experimental/quantization/ptq/utils/metrics.py,sha256=EW_FQmJrl9Y4esspZQ0GHfJ58RwuJUz0l8IfYq3NWY4,4461
|
77
79
|
tico/experimental/quantization/ptq/utils/reduce_utils.py,sha256=3kWawLB91EcvvHlCrNqqfZF7tpgr22htBSA049mKw_4,973
|
78
80
|
tico/experimental/quantization/ptq/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
79
81
|
tico/experimental/quantization/ptq/wrappers/ptq_wrapper.py,sha256=F9sK_DiRaXiGNHULcwIbs5EUtHz6ZJ7N4r5CWTTfhsM,2442
|
@@ -119,7 +121,7 @@ tico/passes/lower_to_slice.py,sha256=OzlFzK3lBYyYwC3WThsWd94Ob4JINIJF8UaLAtnumzU
|
|
119
121
|
tico/passes/merge_consecutive_cat.py,sha256=ayZNLDA1DFM7Fxxi2Dmk1CujkgUuaVCH1rhQgLrvvOQ,2701
|
120
122
|
tico/passes/ops.py,sha256=cSj3Sk2x2cOE9b8oU5pmSa_rHr-iX2lORzu3N_UHMSQ,2967
|
121
123
|
tico/passes/remove_nop.py,sha256=Hf91p_EJAOC6DyWNthash0_UWtEcNc_M7znamQfYQ5Y,2686
|
122
|
-
tico/passes/remove_redundant_assert_nodes.py,sha256=
|
124
|
+
tico/passes/remove_redundant_assert_nodes.py,sha256=rYbTCyuNIXIC-2NreHKBVCuaSUkEQvB_iSRzb26P_EA,1821
|
123
125
|
tico/passes/remove_redundant_expand.py,sha256=auyqIoQT4HJhiJfuUe6BrEtUhvz221ohnIK5EuszWeg,2112
|
124
126
|
tico/passes/remove_redundant_permute.py,sha256=98UsaZzFZdQzEEAR1pIzRisAf6hgfXLa88aayjalt3E,4292
|
125
127
|
tico/passes/remove_redundant_reshape.py,sha256=aeep6LDvY58GEuOrWckkEXnJa6wkkbiJ9FrimT9F3-s,16384
|
@@ -242,9 +244,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
|
242
244
|
tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
|
243
245
|
tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
|
244
246
|
tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
|
245
|
-
tico-0.1.0.
|
246
|
-
tico-0.1.0.
|
247
|
-
tico-0.1.0.
|
248
|
-
tico-0.1.0.
|
249
|
-
tico-0.1.0.
|
250
|
-
tico-0.1.0.
|
247
|
+
tico-0.1.0.dev250902.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
|
248
|
+
tico-0.1.0.dev250902.dist-info/METADATA,sha256=CePT5yw5-ln0-Ct8n61iGDnFfnoASlqAfPQmxRQ9QQ0,8450
|
249
|
+
tico-0.1.0.dev250902.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
|
250
|
+
tico-0.1.0.dev250902.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
|
251
|
+
tico-0.1.0.dev250902.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
|
252
|
+
tico-0.1.0.dev250902.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|