tico 0.1.0.dev250901__py3-none-any.whl → 0.1.0.dev250902__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tico/__init__.py CHANGED
@@ -29,7 +29,7 @@ __all__ = [
29
29
  ]
30
30
 
31
31
  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
32
- __version__ = "0.1.0.dev250901"
32
+ __version__ = "0.1.0.dev250902"
33
33
 
34
34
  MINIMUM_SUPPORTED_VERSION = "2.5.0"
35
35
  SECURE_TORCH_VERSION = "2.6.0"
@@ -0,0 +1,121 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # =============================================================================
16
+ # QUICK PTQ WORKFLOW (OPTIONAL FP32 BASELINE)
17
+ # -----------------------------------------------------------------------------
18
+ # Toggle RUN_FP to choose between:
19
+ # • FP32 perplexity measurement only, OR
20
+ # • Full post-training UINT-8 flow (wrap → calibrate → eval).
21
+ # =============================================================================
22
+
23
+ import torch
24
+ import tqdm
25
+ from datasets import load_dataset
26
+ from transformers import AutoModelForCausalLM, AutoTokenizer
27
+
28
+ from tico.experimental.quantization.ptq.quant_config import QuantConfig
29
+ from tico.experimental.quantization.ptq.utils import perplexity
30
+ from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
31
+
32
# -------------------------------------------------------------------------
# 0. Global configuration
# -------------------------------------------------------------------------
MODEL_NAME = "meta-llama/Meta-Llama-3-1B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
STRIDE = 512  # sliding-window stride for perplexity
RUN_FP = True  # set False → run UINT-8 path

# Token-budget presets for activation calibration
TOKENS: dict[str, int] = {
    # Smoke test (<1 min turnaround on CPU/GPU)
    "debug": 2_000,  # ≈16 × 128-seq batches
    # Good default for 1-7B models (≲3 % ppl delta)
    "baseline": 50_000,
    # Production / 4-bit observer smoothing
    "production": 200_000,
}
CALIB_TOKENS = TOKENS["baseline"]
print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")

# -------------------------------------------------------------------------
# 1. Load model
# -------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

if RUN_FP:
    # -- FP32 baseline ------------------------------------------------------
    print("Loading FP32 model …")
    fp_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
    fp_model.config.use_cache = False
else:
    # -- UINT-8 pipeline -----------------------------------------------------
    print("Creating UINT-8 clone …")
    uint8_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
    uint8_model.config.use_cache = False

    # ---------------------------------------------------------------------
    # 2. Wrap every Transformer layer with PTQWrapper
    # ---------------------------------------------------------------------
    qcfg = QuantConfig()  # all-uint8 defaults

    wrapped_layers = torch.nn.ModuleList()
    for idx, layer in enumerate(uint8_model.model.layers):
        layer_cfg = qcfg.child(f"layer{idx}")
        wrapped_layers.append(PTQWrapper(layer, qcfg=layer_cfg))
    uint8_model.model.layers = wrapped_layers

    # ---------------------------------------------------------------------
    # 3. Single-pass activation calibration
    # ---------------------------------------------------------------------
    print("Calibrating UINT-8 observers …")
    calib_txt = " ".join(
        load_dataset("wikitext", "wikitext-2-raw-v1", split="train")["text"]
    )
    # Apply the budget AFTER tokenisation so that CALIB_TOKENS really counts
    # tokens. Slicing the raw string would cap the number of characters
    # instead and calibrate on far fewer tokens than the message above reports.
    ids = tokenizer(calib_txt, return_tensors="pt").input_ids[:, :CALIB_TOKENS]
    ids = ids.to(DEVICE)

    # (a) switch every QuantModuleBase to CALIB mode
    for wrapped in uint8_model.model.layers:
        wrapped.enable_calibration()

    # (b) run inference to collect ranges
    with torch.no_grad():
        for i in tqdm.trange(0, ids.size(1) - 1, STRIDE, desc="Calibration"):
            uint8_model(ids[:, i : i + STRIDE])

    # (c) freeze (scale, zero-point)
    for wrapped in uint8_model.model.layers:
        wrapped.freeze_qparams()

# -------------------------------------------------------------------------
# 4. Evaluate perplexity on Wikitext-2
# -------------------------------------------------------------------------
print("\nCalculating perplexities …")
test_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
enc = tokenizer("\n\n".join(test_ds["text"]), return_tensors="pt")

if RUN_FP:
    ppl_fp = perplexity(fp_model, enc, DEVICE, stride=STRIDE)
else:
    ppl_int8 = perplexity(uint8_model, enc, DEVICE, stride=STRIDE)

# -------------------------------------------------------------------------
# 5. Report
# -------------------------------------------------------------------------
print("\n┌── Wikitext-2 test perplexity ─────────────")
if RUN_FP:
    print(f"│ FP32 : {ppl_fp:8.2f}")
else:
    print(f"│ UINT-8 : {ppl_int8:8.2f}")
print("└───────────────────────────────────────────")
@@ -1,5 +1,7 @@
1
+ from tico.experimental.quantization.ptq.utils.metrics import perplexity
1
2
  from tico.experimental.quantization.ptq.utils.reduce_utils import channelwise_minmax
2
3
 
3
4
  __all__ = [
4
5
  "channelwise_minmax",
6
+ "perplexity",
5
7
  ]
@@ -0,0 +1,123 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ import torch
18
+ import tqdm
19
+
20
+
21
+ def perplexity(
22
+ model: torch.nn.Module,
23
+ encodings: torch.Tensor,
24
+ device: torch.device | str,
25
+ *,
26
+ max_length: Optional[int] = None,
27
+ stride: int = 512,
28
+ ignore_index: int | None = -100,
29
+ show_progress: bool = True,
30
+ ) -> float:
31
+ """
32
+ Compute perplexity (PPL) using a "strided sliding-window"
33
+ evaluation strategy.
34
+
35
+ The function:
36
+ 1. Splits the token sequence into overlapping windows of length
37
+ `max_length` (model context size).
38
+ 2. Masks tokens that were already scored in previous windows
39
+ (`labels == -100`), so each token's negative log-likelihood (NLL)
40
+ is counted EXACTLY once.
41
+ 3. Aggregates token-wise NLL to return corpus-level PPL.
42
+
43
+ Parameters
44
+ ----------
45
+ model : torch.nn.Module
46
+ Causal LM loaded in evaluation mode (`model.eval()`).
47
+ encodings : torch.Tensor | transformers.BatchEncoding
48
+ Tokenised corpus. If a `BatchEncoding` is passed, its
49
+ `.input_ids` field is used. Shape must be `(1, seq_len)`.
50
+ device : torch.device | str
51
+ CUDA or CPU device on which to run evaluation.
52
+ max_length : int, optional
53
+ Context window size. Defaults to `model.config.max_position_embeddings`.
54
+ stride : int, default = 512
55
+ Step size by which the sliding window advances. Must satisfy
56
+ `1 ≤ stride ≤ max_length`.
57
+ ignore_index : int, default = -100
58
+ Label value to ignore in loss computation. This should match
59
+ the `ignore_index` used by the model's internal
60
+ `CrossEntropyLoss`. For Hugging Face causal LMs, the
61
+ convention is `-100`.
62
+ show_progress : bool, default = True
63
+ If True, displays a tqdm progess bar while evaluating.
64
+
65
+ Returns
66
+ -------
67
+ float
68
+ Corpus-level perplexity.
69
+ """
70
+ # -------- input preparation -------- #
71
+ try:
72
+ # transformers.BatchEncoding has `input_ids`
73
+ input_ids_full = encodings.input_ids # type: ignore[attr-defined]
74
+ except AttributeError: # already a tensor
75
+ input_ids_full = encodings
76
+ assert isinstance(input_ids_full, torch.Tensor)
77
+ input_ids_full = input_ids_full.to(device)
78
+
79
+ if max_length is None:
80
+ assert hasattr(model, "config")
81
+ assert hasattr(model.config, "max_position_embeddings")
82
+ assert isinstance(model.config.max_position_embeddings, int)
83
+ max_length = model.config.max_position_embeddings
84
+ assert max_length is not None
85
+ assert (
86
+ 1 <= stride <= max_length
87
+ ), f"stride ({stride}) must be in [1, max_length ({max_length})]"
88
+
89
+ seq_len = input_ids_full.size(1)
90
+ nll_sum = 0.0
91
+ n_tokens = 0
92
+ prev_end = 0
93
+
94
+ # -------- main loop -------- #
95
+ for begin in tqdm.trange(0, seq_len, stride, desc="PPL", disable=not show_progress):
96
+ end = min(begin + max_length, seq_len)
97
+ trg_len = end - prev_end # fresh tokens in this window
98
+
99
+ input_ids = input_ids_full[:, begin:end]
100
+ target_ids = input_ids.clone()
101
+ target_ids[:, :-trg_len] = ignore_index # mask previously-scored tokens
102
+
103
+ with torch.no_grad():
104
+ outputs = model(input_ids, labels=target_ids)
105
+ # loss is already averaged over non-masked labels
106
+ neg_log_likelihood = outputs.loss
107
+
108
+ # exact number of labels that contributed to loss
109
+ loss_tokens = (target_ids[:, 1:] != ignore_index).sum().item()
110
+ nll_sum += neg_log_likelihood * loss_tokens
111
+ n_tokens += int(loss_tokens)
112
+
113
+ prev_end = end
114
+ if end == seq_len:
115
+ break
116
+
117
+ avg_nll: float | torch.Tensor = nll_sum / n_tokens
118
+ if not isinstance(avg_nll, torch.Tensor):
119
+ avg_nll = torch.tensor(avg_nll)
120
+ assert isinstance(avg_nll, torch.Tensor)
121
+ ppl = torch.exp(avg_nll)
122
+
123
+ return ppl.item()
@@ -21,7 +21,9 @@ from tico.utils.utils import is_target_node
21
21
 
22
22
 
23
23
  assert_node_targets = [
24
+ torch.ops.aten._assert_scalar.default,
24
25
  torch.ops.aten._assert_tensor_metadata.default,
26
+ torch.ops.aten.sym_constrain_range_for_size.default, # Related to symbolic shape validation
25
27
  ]
26
28
 
27
29
 
@@ -29,7 +31,7 @@ assert_node_targets = [
29
31
  class RemoveRedundantAssertionNodes(PassBase):
30
32
  """
31
33
  This removes redundant assertion nodes.
32
- - `aten.assert_tensor_meta.default`
34
+ When an assertion node is erased, its related comparison nodes are also removed by graph.eliminate_dead_code().
33
35
  """
34
36
 
35
37
  def __init__(self):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tico
3
- Version: 0.1.0.dev250901
3
+ Version: 0.1.0.dev250902
4
4
  Summary: Convert exported Torch module to circle
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- tico/__init__.py,sha256=MgvCVXWMpNL2dxPn54C8fdQaTJPdtHivhuNHH4qN5R8,1883
1
+ tico/__init__.py,sha256=PXZzhb0ZexNIwGhVJpg4Ln_RqskbSIMigqj0GdZgbeA,1883
2
2
  tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
3
3
  tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
4
4
  tico/config/base.py,sha256=q5xMqGxTUZs4mFqt5c7i_y9U00fYgdMGl9nUqIVMlCo,1248
@@ -62,6 +62,7 @@ tico/experimental/quantization/ptq/mode.py,sha256=lT-T8vIv8YWcwrjT7xXVhOw1g7aoAd
62
62
  tico/experimental/quantization/ptq/qscheme.py,sha256=uwhv7bCxOOXB3I-IKlRyr_u4eXOq48uIqGy4TLDqGxY,1301
63
63
  tico/experimental/quantization/ptq/quant_config.py,sha256=nm7570Y1X2mOT_8s27ilWid04otor6cVTi9GwgAEaKc,4300
64
64
  tico/experimental/quantization/ptq/examples/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
65
+ tico/experimental/quantization/ptq/examples/compare_ppl.py,sha256=ODaRB234iy2dFfGIBd-OtKxdSzxnIbgKZkQ_o30tSts,5287
65
66
  tico/experimental/quantization/ptq/examples/quantize_linear.py,sha256=8zq-ZJDYgam0xQ-PbC6Xb1I7W1mv0Wi-b--IP2wwXtw,4539
66
67
  tico/experimental/quantization/ptq/examples/quantize_llama_attn.py,sha256=cVWUSSzaZWFp5QZkNkrlpHU3kXyP84QtnZbahVml_yQ,4329
67
68
  tico/experimental/quantization/ptq/examples/quantize_llama_decoder_layer.py,sha256=mBWrjkyEovYQsPC4Rrsri6Pm1rlFmDb3NiP0DQQhFyM,5751
@@ -73,7 +74,8 @@ tico/experimental/quantization/ptq/observers/ema.py,sha256=MAMdBmjVNMg_vsqXrcBzb
73
74
  tico/experimental/quantization/ptq/observers/identity.py,sha256=vkec8Or-7VwM4zkFEvEKROQJk8XEHMVX8mBNDnxSyS8,2591
74
75
  tico/experimental/quantization/ptq/observers/minmax.py,sha256=mLHkwIzWFzQXev7EU7w1333KckwRjukc3_cUPJOnUfs,1486
75
76
  tico/experimental/quantization/ptq/observers/mx.py,sha256=aP4qmBgeiRIYZJksShN5gs6UyYOFi2-Sbk5k5xvPQ4w,1863
76
- tico/experimental/quantization/ptq/utils/__init__.py,sha256=PL9IZgiWoMtsXVljeOy7KymmLVP238SXEFRLXYK72WQ,126
77
+ tico/experimental/quantization/ptq/utils/__init__.py,sha256=MrQwMbbKS0dJrO8jsceCai4Z59iKQNpTPZND3GN6TrM,216
78
+ tico/experimental/quantization/ptq/utils/metrics.py,sha256=EW_FQmJrl9Y4esspZQ0GHfJ58RwuJUz0l8IfYq3NWY4,4461
77
79
  tico/experimental/quantization/ptq/utils/reduce_utils.py,sha256=3kWawLB91EcvvHlCrNqqfZF7tpgr22htBSA049mKw_4,973
78
80
  tico/experimental/quantization/ptq/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
79
81
  tico/experimental/quantization/ptq/wrappers/ptq_wrapper.py,sha256=F9sK_DiRaXiGNHULcwIbs5EUtHz6ZJ7N4r5CWTTfhsM,2442
@@ -119,7 +121,7 @@ tico/passes/lower_to_slice.py,sha256=OzlFzK3lBYyYwC3WThsWd94Ob4JINIJF8UaLAtnumzU
119
121
  tico/passes/merge_consecutive_cat.py,sha256=ayZNLDA1DFM7Fxxi2Dmk1CujkgUuaVCH1rhQgLrvvOQ,2701
120
122
  tico/passes/ops.py,sha256=cSj3Sk2x2cOE9b8oU5pmSa_rHr-iX2lORzu3N_UHMSQ,2967
121
123
  tico/passes/remove_nop.py,sha256=Hf91p_EJAOC6DyWNthash0_UWtEcNc_M7znamQfYQ5Y,2686
122
- tico/passes/remove_redundant_assert_nodes.py,sha256=IONd3xBy6I8tH6_Y1eN3_eCHH7WTC8soBgjXzOju9cQ,1612
124
+ tico/passes/remove_redundant_assert_nodes.py,sha256=rYbTCyuNIXIC-2NreHKBVCuaSUkEQvB_iSRzb26P_EA,1821
123
125
  tico/passes/remove_redundant_expand.py,sha256=auyqIoQT4HJhiJfuUe6BrEtUhvz221ohnIK5EuszWeg,2112
124
126
  tico/passes/remove_redundant_permute.py,sha256=98UsaZzFZdQzEEAR1pIzRisAf6hgfXLa88aayjalt3E,4292
125
127
  tico/passes/remove_redundant_reshape.py,sha256=aeep6LDvY58GEuOrWckkEXnJa6wkkbiJ9FrimT9F3-s,16384
@@ -242,9 +244,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
242
244
  tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
243
245
  tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
244
246
  tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
245
- tico-0.1.0.dev250901.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
246
- tico-0.1.0.dev250901.dist-info/METADATA,sha256=LMgoYoHYFT8cJU9VNYiiX89tMSxEX30x17x_6eWAr4o,8450
247
- tico-0.1.0.dev250901.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
248
- tico-0.1.0.dev250901.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
249
- tico-0.1.0.dev250901.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
250
- tico-0.1.0.dev250901.dist-info/RECORD,,
247
+ tico-0.1.0.dev250902.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
248
+ tico-0.1.0.dev250902.dist-info/METADATA,sha256=CePT5yw5-ln0-Ct8n61iGDnFfnoASlqAfPQmxRQ9QQ0,8450
249
+ tico-0.1.0.dev250902.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
250
+ tico-0.1.0.dev250902.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
251
+ tico-0.1.0.dev250902.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
252
+ tico-0.1.0.dev250902.dist-info/RECORD,,