tico 0.1.0.dev250904__py3-none-any.whl → 0.1.0.dev251109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (133)
  1. tico/__init__.py +1 -1
  2. tico/config/v1.py +5 -0
  3. tico/passes/cast_mixed_type_args.py +2 -0
  4. tico/passes/convert_expand_to_slice_cat.py +153 -0
  5. tico/passes/convert_matmul_to_linear.py +312 -0
  6. tico/passes/convert_to_relu6.py +1 -1
  7. tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -3
  8. tico/passes/ops.py +0 -1
  9. tico/passes/remove_redundant_expand.py +3 -1
  10. tico/quantization/__init__.py +6 -0
  11. tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +161 -0
  12. tico/quantization/algorithm/fpi_gptq/quantizer.py +179 -0
  13. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
  14. tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +14 -6
  15. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
  16. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  17. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  18. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  19. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  20. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
  21. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  22. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  23. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  24. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  25. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  26. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  27. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  28. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  29. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
  30. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  31. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  32. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  33. tico/quantization/config/base.py +26 -0
  34. tico/quantization/config/fpi_gptq.py +29 -0
  35. tico/quantization/config/gptq.py +29 -0
  36. tico/quantization/config/pt2e.py +25 -0
  37. tico/{experimental/quantization/ptq/quant_config.py → quantization/config/ptq.py} +18 -10
  38. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -37
  39. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +6 -12
  40. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  41. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  42. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/public_interface.py +11 -18
  44. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  45. tico/quantization/quantizer_registry.py +73 -0
  46. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  47. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  48. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_linear.py +11 -10
  49. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_attn.py +10 -12
  50. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_decoder_layer.py +10 -9
  51. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_mlp.py +13 -13
  52. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  53. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py +3 -3
  54. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py +2 -2
  55. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py +2 -2
  56. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py +1 -1
  57. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py +2 -2
  58. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/mx.py +1 -1
  59. tico/quantization/wrapq/quantizer.py +179 -0
  60. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py +3 -5
  61. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py +3 -2
  62. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  63. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  64. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  65. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  66. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  67. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  68. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  69. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  70. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py +58 -21
  71. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py +21 -13
  72. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py +5 -7
  73. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  74. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py +6 -7
  75. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_linear.py +7 -8
  76. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_silu.py +8 -9
  77. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/ptq_wrapper.py +4 -6
  78. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_elementwise.py +55 -17
  79. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_module_base.py +10 -9
  80. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/registry.py +17 -10
  81. tico/serialize/circle_serializer.py +11 -4
  82. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  83. tico/serialize/operators/op_le.py +54 -0
  84. tico/serialize/operators/op_mm.py +15 -132
  85. tico/utils/convert.py +20 -15
  86. tico/utils/register_custom_op.py +6 -4
  87. tico/utils/signature.py +7 -8
  88. tico/utils/validate_args_kwargs.py +12 -0
  89. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/METADATA +48 -2
  90. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/RECORD +128 -108
  91. tico/experimental/quantization/__init__.py +0 -6
  92. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  93. tico/experimental/quantization/ptq/examples/compare_ppl.py +0 -121
  94. tico/experimental/quantization/ptq/examples/debug_quant_outputs.py +0 -129
  95. tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +0 -165
  96. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  97. /tico/{experimental/quantization/algorithm/gptq → quantization/algorithm/fpi_gptq}/__init__.py +0 -0
  98. /tico/{experimental/quantization/algorithm/pt2e → quantization/algorithm/gptq}/__init__.py +0 -0
  99. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  100. /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
  101. /tico/{experimental/quantization/algorithm/pt2e/annotation → quantization/algorithm/pt2e}/__init__.py +0 -0
  102. /tico/{experimental/quantization/algorithm/pt2e/transformation → quantization/algorithm/pt2e/annotation}/__init__.py +0 -0
  103. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  104. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  105. /tico/{experimental/quantization/algorithm/smoothquant → quantization/algorithm/pt2e/transformation}/__init__.py +0 -0
  106. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  107. /tico/{experimental/quantization/evaluation → quantization/algorithm/smoothquant}/__init__.py +0 -0
  108. /tico/{experimental/quantization/evaluation/executor → quantization/config}/__init__.py +0 -0
  109. /tico/{experimental/quantization/passes → quantization/evaluation}/__init__.py +0 -0
  110. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  111. /tico/{experimental/quantization/ptq → quantization/evaluation/executor}/__init__.py +0 -0
  112. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  113. /tico/{experimental/quantization → quantization}/evaluation/metric.py +0 -0
  114. /tico/{experimental/quantization/ptq/examples → quantization/passes}/__init__.py +0 -0
  115. /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
  116. /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
  117. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  118. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  119. /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
  120. /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
  121. /tico/{experimental/quantization/ptq/observers → quantization/wrapq}/__init__.py +0 -0
  122. /tico/{experimental/quantization/ptq → quantization/wrapq}/dtypes.py +0 -0
  123. /tico/{experimental/quantization/ptq/utils → quantization/wrapq/examples}/__init__.py +0 -0
  124. /tico/{experimental/quantization/ptq → quantization/wrapq}/mode.py +0 -0
  125. /tico/{experimental/quantization/ptq/wrappers → quantization/wrapq/observers}/__init__.py +0 -0
  126. /tico/{experimental/quantization/ptq → quantization/wrapq}/qscheme.py +0 -0
  127. /tico/{experimental/quantization/ptq/wrappers/llama → quantization/wrapq/utils}/__init__.py +0 -0
  128. /tico/{experimental/quantization/ptq → quantization/wrapq}/utils/reduce_utils.py +0 -0
  129. /tico/{experimental/quantization/ptq/wrappers/nn → quantization/wrapq/wrappers}/__init__.py +0 -0
  130. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/LICENSE +0 -0
  131. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/WHEEL +0 -0
  132. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/entry_points.txt +0 -0
  133. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,224 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # ============================================================================
+ # LAYER-WISE DIFF DEBUGGING PIPELINE
+ # ----------------------------------------------------------------------------
+ # A quantization debugging pipeline that identifies accuracy regressions
+ # by comparing UINT vs FP outputs at each layer.
+ #
+ # 1. Load a full-precision (FP) LLaMA-3-1B model.
+ # 2. Wrap each Transformer block with PTQWrapper (activations → fake-quant).
+ # 3. Capture reference FP layer outputs before quantization.
+ # 4. Calibrate UINT-8 activation observers in a single pass.
+ # 5. Freeze quantization parameters (scale, zero-point).
+ # 6. Re-run inference and compare UINT-8 vs FP outputs per layer.
+ # 7. Report where quantization hurts the most.
+ #
+ # Use this pipeline to trace precision loss layer by layer, and pinpoint
+ # problematic modules during post-training quantization.
+ # ============================================================================
+
+ import argparse
+ import sys
+
+ import torch
+ import tqdm
+ from datasets import load_dataset
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ from tico.quantization import convert, prepare
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.wrapq.utils.introspection import (
+     build_fqn_map,
+     compare_layer_outputs,
+     save_fp_outputs,
+ )
+ from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+
+ # Token-budget presets for activation calibration
+ TOKENS: dict[str, int] = {
+     # Smoke test (<1 min turnaround on CPU/GPU)
+     "debug": 2_000, # ≈16 × 128-seq batches
+     # Good default for 1-7B models (≲3 % ppl delta)
+     "baseline": 50_000,
+     # Production / 4-bit observer smoothing
+     "production": 200_000,
+ }
+
+ DTYPE_MAP = {
+     "float32": torch.float32,
+     "bfloat16": torch.bfloat16,
+     "float16": torch.float16,
+ }
+
+ # Hardcoded dataset settings
+ DATASET_NAME = "wikitext"
+ DATASET_CONFIG = "wikitext-2-raw-v1"
+ TRAIN_SPLIT = "train"
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Layer-wise diff debugging pipeline for PTQ"
+     )
+     parser.add_argument(
+         "--model", type=str, required=True, help="HF repo name or local path."
+     )
+     parser.add_argument(
+         "--device",
+         type=str,
+         default="cuda" if torch.cuda.is_available() else "cpu",
+         help="Device to run on (cuda|cpu|mps).",
+     )
+     parser.add_argument(
+         "--dtype",
+         choices=list(DTYPE_MAP.keys()),
+         default="float32",
+         help=f"Model dtype for load.",
+     )
+     parser.add_argument(
+         "--stride",
+         type=int,
+         default=512,
+         help="Sliding-window stride used during calibration.",
+     )
+     parser.add_argument(
+         "--calib-preset",
+         choices=list(TOKENS.keys()),
+         default="debug",
+         help="Calibration token budget preset.",
+     )
+     parser.add_argument("--seed", type=int, default=42, help="Random seed.")
+     parser.add_argument(
+         "--trust-remote-code",
+         action="store_true",
+         help="Enable only if you trust the model repo code.",
+     )
+     parser.add_argument(
+         "--hf-token",
+         type=str,
+         default=None,
+         help="Optional HF token for gated/private repos.",
+     )
+     parser.add_argument(
+         "--use-cache",
+         dest="use_cache",
+         action="store_true",
+         default=False,
+         help="Use model KV cache if enabled (off by default).",
+     )
+     parser.add_argument(
+         "--no-tqdm", action="store_true", help="Disable tqdm progress bars."
+     )
+
+     args = parser.parse_args()
+
+     # Basic setup
+     torch.manual_seed(args.seed)
+     device = torch.device(args.device)
+     dtype = DTYPE_MAP[args.dtype] # noqa: E999 (kept readable)
+
+     print("=== Config ===")
+     print(f"Model : {args.model}")
+     print(f"Device : {device.type}")
+     print(f"DType : {args.dtype}")
+     print(f"Stride : {args.stride}")
+     print(
+         f"Calib preset : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
+     )
+     print(f"Use HF cache? : {args.use_cache}")
+     print()
+
+     # -------------------------------------------------------------------------
+     # 1. Load the FP backbone and tokenizer
+     # -------------------------------------------------------------------------
+     print("Loading FP model …")
+     tokenizer = AutoTokenizer.from_pretrained(
+         args.model,
+         trust_remote_code=args.trust_remote_code,
+         token=args.hf_token,
+     )
+     model = (
+         AutoModelForCausalLM.from_pretrained(
+             args.model,
+             torch_dtype=dtype,
+             trust_remote_code=args.trust_remote_code,
+             token=args.hf_token,
+         )
+         .to(device)
+         .eval()
+     )
+
+     # Disable KV cache to force full forward passes for introspection
+     model.config.use_cache = args.use_cache
+
+     # Build module -> FQN map before wrapping
+     m_to_fqn = build_fqn_map(model)
+
+     # Prepare calibration inputs (HF Wikitext-2 train split)
+     CALIB_TOKENS = TOKENS[args.calib_preset]
+     print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
+     # Use Wikitext-2 train split for calibration.
+     dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
+
+     # -------------------------------------------------------------------------
+     # 2. Wrap every layer with PTQWrapper (UINT-8 activations)
+     # -------------------------------------------------------------------------
+     print("Wrapping layers with PTQWrapper …")
+     qcfg = PTQConfig() # default: per-tensor UINT8
+     prepare(model, qcfg)
+
+     # -------------------------------------------------------------------------
+     # 3. Activation calibration plus FP-vs-UINT8 diffing
+     # -------------------------------------------------------------------------
+     print("Calibrating UINT-8 observers …")
+     calib_txt = " ".join(dataset["text"])[:CALIB_TOKENS]
+     ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
+
+     # Save reference FP activations before observers clamp/quantize
+     save_handles, act_cache = save_fp_outputs(model)
+
+     iterator = range(0, ids.size(1) - 1, args.stride)
+     if not args.no_tqdm:
+         iterator = tqdm.tqdm(iterator, desc="Act-Calibration")
+     with torch.no_grad():
+         for i in iterator:
+             inputs = ids[:, i : i + args.stride]
+             model(inputs) # observers collect act. ranges
+
+     # Remove save hooks now that FP activations are cached
+     for h in save_handles:
+         h.remove()
+
+     # Freeze (scale, zero-point) after calibration
+     convert(model)
+
+     # Register diff hooks and measure per-layer deltas
+     cmp_handles = compare_layer_outputs(model, act_cache, metrics=["diff", "peir"])
+     # Use same inputs for comparison.
+     with torch.no_grad():
+         model(inputs)
+
+     assert isinstance(cmp_handles, list)
+     for h in cmp_handles:
+         h.remove()
+
+
+ if __name__ == "__main__":
+     try:
+         main()
+     except Exception as e:
+         print(f"\n[Error] {e}", file=sys.stderr)
+         sys.exit(1)
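
Taken as a whole, the new debug_quant_outputs.py above boils down to a short prepare → calibrate → convert loop around the introspection helpers. Below is a minimal sketch of that flow using only the calls that appear in this diff (prepare, convert, PTQConfig, save_fp_outputs, compare_layer_outputs); the toy nn.Sequential model and the random calibration batch are illustrative stand-ins, and whether prepare wraps a plain Sequential exactly like the HF blocks above is an assumption.

# Minimal sketch of the layer-wise diff pipeline above, on a toy model.
# The tico.quantization calls are the ones shown in this diff; the toy
# module and random data are stand-ins, not part of the package.
import torch
import torch.nn as nn

from tico.quantization import convert, prepare
from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.utils.introspection import (
    compare_layer_outputs,
    save_fp_outputs,
)

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8)).eval()
prepare(model, PTQConfig())  # wrap supported layers with fake-quant observers

save_handles, act_cache = save_fp_outputs(model)  # cache FP reference outputs
x = torch.randn(4, 16)
with torch.no_grad():
    model(x)  # observers collect activation ranges
for h in save_handles:
    h.remove()

convert(model)  # freeze (scale, zero-point)

cmp_handles = compare_layer_outputs(model, act_cache, metrics=["diff", "peir"])
with torch.no_grad():
    model(x)  # hooks report per-layer FP-vs-quant deltas
for h in cmp_handles:
    h.remove()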
@@ -29,13 +29,15 @@ import pathlib
  import torch
  import torch.nn as nn

- from tico.experimental.quantization.evaluation.metric import compute_peir
- from tico.experimental.quantization.evaluation.utils import plot_two_outputs
-
- from tico.experimental.quantization.ptq.mode import Mode
- from tico.experimental.quantization.ptq.wrappers.nn.quant_linear import QuantLinear
+ from tico.quantization import convert, prepare
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.evaluation.metric import compute_peir
+ from tico.quantization.evaluation.utils import plot_two_outputs
+ from tico.quantization.wrapq.mode import Mode
+ from tico.quantization.wrapq.wrappers.nn.quant_linear import QuantLinear
  from tico.utils.utils import SuppressWarning

+
  # -------------------------------------------------------------------------
  # 0. Define a toy model (1 Linear layer only)
  # -------------------------------------------------------------------------
@@ -60,20 +62,19 @@ fp32_layer = model.fc
  # -------------------------------------------------------------------------
  # 1. Replace the Linear with QuantLinear wrapper
  # -------------------------------------------------------------------------
- model.fc = QuantLinear(fp32_layer) # type: ignore[assignment]
- # model.fc = PTQWrapper(fp32_layer) (Wrapping helper class)
+ model.fc = prepare(fp32_layer, PTQConfig()) # type: ignore[assignment]
  qlayer = model.fc # alias for brevity

  # -------------------------------------------------------------------------
  # 2. Single-pass calibration (collect activation ranges)
  # -------------------------------------------------------------------------
- assert isinstance(qlayer, QuantLinear)
+ assert isinstance(qlayer.wrapped, QuantLinear)
  with torch.no_grad():
-     qlayer.enable_calibration()
      for _ in range(16): # small toy batch
          x = torch.randn(4, 16) # (batch=4, features=16)
          _ = model(x)
-     qlayer.freeze_qparams() # lock scales & zero-points
+
+ convert(qlayer)

  assert qlayer._mode is Mode.QUANT, "Quantization mode should be active now."

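Since the quantize_linear.py hunks above only show the changed regions, here is roughly how the updated example reads end to end once they are applied. The prepare/convert calls, the .wrapped attribute, and the Mode.QUANT check are taken from the hunks; the toy one-Linear module and the input shape are assumed from the example's own comments.

# Reconstructed sketch of the updated quantize_linear example (toy model
# assumed; the prepare/convert flow is the one shown in the hunks above).
import torch
import torch.nn as nn

from tico.quantization import convert, prepare
from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.mode import Mode
from tico.quantization.wrapq.wrappers.nn.quant_linear import QuantLinear


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 8)

    def forward(self, x):
        return self.fc(x)


model = TinyNet().eval()
model.fc = prepare(model.fc, PTQConfig())  # returns a PTQ wrapper around the Linear
qlayer = model.fc
assert isinstance(qlayer.wrapped, QuantLinear)

with torch.no_grad():  # single-pass calibration
    for _ in range(16):
        _ = model(torch.randn(4, 16))

convert(qlayer)  # freeze scales & zero-points
assert qlayer._mode is Mode.QUANT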
@@ -17,13 +17,12 @@ import pathlib
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer

- from tico.experimental.quantization.evaluation.metric import compute_peir
- from tico.experimental.quantization.evaluation.utils import plot_two_outputs
-
- from tico.experimental.quantization.ptq.mode import Mode
- from tico.experimental.quantization.ptq.wrappers.llama.quant_attn import (
-     QuantLlamaAttention,
- )
+ from tico.quantization import convert, prepare
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.evaluation.metric import compute_peir
+ from tico.quantization.evaluation.utils import plot_two_outputs
+ from tico.quantization.wrapq.mode import Mode
+ from tico.quantization.wrapq.wrappers.llama.quant_attn import QuantLlamaAttention
  from tico.utils.utils import SuppressWarning

  name = "Maykeye/TinyLLama-v0"
@@ -34,12 +33,11 @@ tokenizer = AutoTokenizer.from_pretrained(name)
  # 1. Replace layer-0’s MLP with QuantLlamaMLP
  # -------------------------------------------------------------------------
  orig_attn = model.model.layers[0].self_attn
- model.model.layers[0].self_attn = QuantLlamaAttention(
-     orig_attn
- ) # PTQWrapper(orig_attn) is also fine
+ model.model.layers[0].self_attn = prepare(orig_attn, PTQConfig())
  model.eval()

  attn_q = model.model.layers[0].self_attn # quant wrapper
+ assert isinstance(attn_q.wrapped, QuantLlamaAttention)
  rotary = model.model.rotary_emb

  # -------------------------------------------------------------------------
@@ -55,7 +53,6 @@ PROMPTS = [
  ]

  with torch.no_grad():
-     attn_q.enable_calibration()
      for prompt in PROMPTS:
          ids = tokenizer(prompt, return_tensors="pt")
          embeds = model.model.embed_tokens(ids["input_ids"])
@@ -63,7 +60,8 @@ with torch.no_grad():
          S = cos_sin[0].shape[1]
          float_mask = torch.zeros(1, 1, S, S)
          _ = attn_q(embeds, cos_sin) # observers collect
-     attn_q.freeze_qparams()
+
+ convert(attn_q)

  assert attn_q._mode is Mode.QUANT, "Quantization mode should be active now."

@@ -31,10 +31,12 @@ import pathlib
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer

- from tico.experimental.quantization.evaluation.metric import compute_peir
- from tico.experimental.quantization.evaluation.utils import plot_two_outputs
- from tico.experimental.quantization.ptq.mode import Mode
- from tico.experimental.quantization.ptq.wrappers.llama.quant_decoder_layer import (
+ from tico.quantization import convert, prepare
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.evaluation.metric import compute_peir
+ from tico.quantization.evaluation.utils import plot_two_outputs
+ from tico.quantization.wrapq.mode import Mode
+ from tico.quantization.wrapq.wrappers.llama.quant_decoder_layer import (
      QuantLlamaDecoderLayer,
  )
  from tico.utils.utils import SuppressWarning
@@ -50,12 +52,11 @@ rotary = model.model.rotary_emb # RoPE helper
  # 1. Swap in the quant wrapper
  # -------------------------------------------------------------------------
  fp32_layer = model.model.layers[0] # keep a reference for diff check
- model.model.layers[0] = QuantLlamaDecoderLayer(
-     fp32_layer
- ) # PTQWrapper(fp32_layer) is also fine
+ model.model.layers[0] = prepare(fp32_layer, PTQConfig())
  model.eval()

  qlayer = model.model.layers[0] # alias for brevity
+ assert isinstance(qlayer.wrapped, QuantLlamaDecoderLayer)

  # -------------------------------------------------------------------------
  # 2. Single-pass calibration (gather activation ranges)
@@ -70,7 +71,6 @@ PROMPTS = [
  ]

  with torch.no_grad():
-     qlayer.enable_calibration()
      for prompt in PROMPTS:
          ids = tokenizer(prompt, return_tensors="pt")
          hidden = model.model.embed_tokens(ids["input_ids"])
@@ -78,7 +78,8 @@ with torch.no_grad():
          S = pos[0].shape[1]
          attn_mask = torch.zeros(1, 1, S, S) # causal-mask placeholder
          _ = qlayer(hidden, attention_mask=attn_mask, position_embeddings=pos)
-     qlayer.freeze_qparams()
+
+ convert(qlayer)

  assert qlayer._mode is Mode.QUANT, "Quantization mode should be active now."

@@ -18,13 +18,14 @@ import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer

  import tico
- from tico.experimental.quantization.evaluation.metric import compute_peir
- from tico.experimental.quantization.evaluation.utils import plot_two_outputs
- from tico.experimental.quantization.ptq.dtypes import INT16
- from tico.experimental.quantization.ptq.mode import Mode
- from tico.experimental.quantization.ptq.qscheme import QScheme
- from tico.experimental.quantization.ptq.quant_config import QuantConfig
- from tico.experimental.quantization.ptq.wrappers.llama.quant_mlp import QuantLlamaMLP
+ from tico.quantization import convert, prepare
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.evaluation.metric import compute_peir
+ from tico.quantization.evaluation.utils import plot_two_outputs
+ from tico.quantization.wrapq.dtypes import INT16
+ from tico.quantization.wrapq.mode import Mode
+ from tico.quantization.wrapq.qscheme import QScheme
+ from tico.quantization.wrapq.wrappers.llama.quant_mlp import QuantLlamaMLP
  from tico.utils.utils import SuppressWarning

  name = "Maykeye/TinyLLama-v0"
@@ -36,13 +37,13 @@ model.eval()
  # 1. Replace layer-0’s MLP with QuantLlamaMLP
  # -------------------------------------------------------------------------
  fp32_mlp = model.model.layers[0].mlp
- model.model.layers[0].mlp = QuantLlamaMLP(
-     fp32_mlp,
-     qcfg=QuantConfig(default_dtype=INT16, default_qscheme=QScheme.PER_TENSOR_SYMM),
- ) # PTQWrapper(fp32_mlp) is also fine
+ model.model.layers[0].mlp = prepare(
+     fp32_mlp, PTQConfig(default_dtype=INT16, default_qscheme=QScheme.PER_TENSOR_SYMM)
+ )
  model.eval()

  mlp_q = model.model.layers[0].mlp
+ assert isinstance(mlp_q.wrapped, QuantLlamaMLP)

  # -------------------------------------------------------------------------
  # 2. Single-pass calibration
@@ -57,13 +58,12 @@ PROMPTS = [
  ]

  with torch.no_grad():
-     mlp_q.enable_calibration()
      for prompt in PROMPTS:
          enc = tokenizer(prompt, return_tensors="pt")
          emb = model.model.embed_tokens(enc["input_ids"])
          _ = mlp_q(emb)

-     mlp_q.freeze_qparams()
+ convert(mlp_q)

  assert mlp_q._mode is Mode.QUANT, "Quantization mode should be active now."

@@ -0,0 +1,265 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # =============================================================================
+ # PTQ + GPTQ HYBRID QUANTIZATION PIPELINE
+ # -----------------------------------------------------------------------------
+ # This script shows how to:
+ # 1. Load a pretrained FP Llama-3 model.
+ # 2. Run GPTQ to quantize weights only.
+ # 3. Wrap every Transformer layer with a PTQWrapper to quantize activations.
+ # 4. Calibrate UINT-8 observers in a single pass over a text corpus.
+ # 5. Inject GPTQ’s per-tensor weight scales / zero-points into the PTQ graph.
+ # 6. Freeze all Q-params and compute Wikitext-2 perplexity.
+ # =============================================================================
+
+ import argparse
+ import sys
+ from typing import Any
+
+ import torch
+ import tqdm
+ from datasets import load_dataset
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ from tico.quantization import convert, prepare
+ from tico.quantization.config.gptq import GPTQConfig
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
+ from tico.quantization.wrapq.utils.introspection import build_fqn_map
+ from tico.quantization.wrapq.utils.metrics import perplexity
+ from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+
+
+ # Token-budget presets for activation calibration
+ TOKENS: dict[str, int] = {
+     # Smoke test (<1 min turnaround on CPU/GPU)
+     "debug": 2_000, # ≈16 × 128-seq batches
+     # Good default for 1-7B models (≲3 % ppl delta)
+     "baseline": 50_000,
+     # Production / 4-bit observer smoothing
+     "production": 200_000,
+ }
+
+ DTYPE_MAP = {
+     "float32": torch.float32,
+     "bfloat16": torch.bfloat16,
+     "float16": torch.float16,
+ }
+
+ # Hardcoded dataset settings
+ DATASET_NAME = "wikitext"
+ DATASET_CONFIG = "wikitext-2-raw-v1"
+ TRAIN_SPLIT = "train"
+ TEST_SPLIT = "test"
+
+ # -------------------------------------------------------------------------
+ # 1. Helper — copy GPTQ (scale, zp) into PTQ observers
+ # -------------------------------------------------------------------------
+ def inject_gptq_qparams(
+     root: torch.nn.Module,
+     gptq_quantizers: dict[str, Any], # {fp_name: quantizer}
+     weight_obs_name: str = "weight",
+ ):
+     """
+     For every `QuantModuleBase` whose `fp_name` matches a GPTQ key,
+     locate the observer called `weight_obs_name` and overwrite its
+     (scale, zero-point), then lock them against further updates.
+     """
+     for m in root.modules():
+         if not isinstance(m, QuantModuleBase):
+             continue
+         if m.fp_name is None:
+             continue
+         quantizer = gptq_quantizers.get(m.fp_name)
+         if quantizer is None:
+             continue
+         obs = m.get_observer(weight_obs_name)
+         if obs is None:
+             continue
+         assert isinstance(obs, AffineObserverBase)
+         # GPTQ quantizer attributes
+         obs.load_qparams(quantizer.scale, quantizer.zero, lock=True)
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="GPTQ+PTQ pipeline (weight-only + activation UINT8)"
+     )
+     parser.add_argument(
+         "--model", type=str, required=True, help="HF repo name or local path."
+     )
+     parser.add_argument(
+         "--device",
+         type=str,
+         default="cuda" if torch.cuda.is_available() else "cpu",
+         help="Device to run on (cuda|cpu|mps).",
+     )
+     parser.add_argument(
+         "--dtype",
+         choices=list(DTYPE_MAP.keys()),
+         default="float32",
+         help="Model dtype for load.",
+     )
+     parser.add_argument(
+         "--stride",
+         type=int,
+         default=512,
+         help="Sliding-window stride used for calibration and eval.",
+     )
+     parser.add_argument(
+         "--calib-preset",
+         choices=list(TOKENS.keys()),
+         default="debug",
+         help="Activation calibration token budget preset.",
+     )
+     parser.add_argument("--seed", type=int, default=42, help="Random seed.")
+     parser.add_argument(
+         "--trust-remote-code",
+         action="store_true",
+         help="Enable only if you trust the model repo code.",
+     )
+     parser.add_argument(
+         "--hf-token",
+         type=str,
+         default=None,
+         help="Optional HF token for gated/private repos.",
+     )
+     parser.add_argument(
+         "--use-cache",
+         dest="use_cache",
+         action="store_true",
+         default=False,
+         help="Use model KV cache if enabled (off by default).",
+     )
+     parser.add_argument(
+         "--no-tqdm", action="store_true", help="Disable tqdm progress bars."
+     )
+
+     args = parser.parse_args()
+
+     # Basic setup
+     torch.manual_seed(args.seed)
+     device = torch.device(args.device)
+     dtype = DTYPE_MAP[args.dtype]
+
+     print("=== Config ===")
+     print(f"Model : {args.model}")
+     print(f"Device : {device.type}")
+     print(f"DType : {args.dtype}")
+     print(f"Stride : {args.stride}")
+     print(
+         f"Calib preset : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
+     )
+     print(f"Use HF cache? : {args.use_cache}")
+     print()
+
+     # -------------------------------------------------------------------------
+     # 2. Load the FP backbone and tokenizer
+     # -------------------------------------------------------------------------
+     print("Loading FP model …")
+     tokenizer = AutoTokenizer.from_pretrained(
+         args.model,
+         trust_remote_code=args.trust_remote_code,
+         token=args.hf_token,
+     )
+     model = (
+         AutoModelForCausalLM.from_pretrained(
+             args.model,
+             torch_dtype=dtype,
+             trust_remote_code=args.trust_remote_code,
+             token=args.hf_token,
+         )
+         .to(device)
+         .eval()
+     )
+
+     model.config.use_cache = args.use_cache
+
+     # Build module -> FQN map BEFORE wrapping
+     m_to_fqn = build_fqn_map(model)
+
+     # -------------------------------------------------------------------------
+     # 3. Run GPTQ (weight-only) pass
+     # -------------------------------------------------------------------------
+     print("Applying GPTQ …")
+     dataset_test = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT)
+     q_m = prepare(model, GPTQConfig(), inplace=True)
+
+     it = (
+         dataset_test
+         if args.no_tqdm
+         else tqdm.tqdm(dataset_test, desc="GPTQ calibration")
+     )
+     for d in it:
+         ids = tokenizer(d["text"], return_tensors="pt").input_ids.to(device)
+         q_m(ids) # observers gather weight stats
+
+     q_m = convert(q_m, inplace=True) # materialize INT-weight tensors
+
+     # -------------------------------------------------------------------------
+     # 4. Wrap every layer with PTQWrapper (activation UINT-8)
+     # -------------------------------------------------------------------------
+     print("Wrapping layers with PTQWrapper …")
+     qcfg = PTQConfig() # default: per-tensor UINT8
+     prepare(q_m, qcfg)
+
+     # -------------------------------------------------------------------------
+     # 5. Single-pass activation calibration
+     # -------------------------------------------------------------------------
+     print("Calibrating UINT-8 observers …")
+     CALIB_TOKENS = TOKENS[args.calib_preset]
+     print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
+     dataset_train = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
+     calib_txt = " ".join(dataset_train["text"])[:CALIB_TOKENS]
+     train_ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
+
+     # Overwrite weight observers with GPTQ statistics
+     if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict):
+         inject_gptq_qparams(q_m, q_m.quantizers)
+     else:
+         print(
+             "[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection."
+         )
+
+     # Forward passes to collect activation ranges
+     iterator = range(0, train_ids.size(1) - 1, args.stride)
+     if not args.no_tqdm:
+         iterator = tqdm.tqdm(iterator, desc="Act-calibration")
+     with torch.no_grad():
+         for i in iterator:
+             q_m(train_ids[:, i : i + args.stride])
+
+     # Freeze all Q-params (scale, zero-point)
+     convert(q_m)
+
+     # -------------------------------------------------------------------------
+     # 6. Evaluate perplexity on Wikitext-2
+     # -------------------------------------------------------------------------
+     print("\nCalculating perplexities …")
+     enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt")
+     ppl_uint8 = perplexity(q_m, enc, device, stride=args.stride)
+
+     print("\n┌── Wikitext-2 test perplexity ─────────────")
+     print(f"│ UINT-8 : {ppl_uint8:8.2f}")
+     print("└───────────────────────────────────────────")
+
+
+ if __name__ == "__main__":
+     try:
+         main()
+     except Exception as e:
+         print(f"\n[Error] {e}", file=sys.stderr)
+         sys.exit(1)
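
For orientation, the weight-only GPTQ step in the script above reduces to three calls: prepare with a GPTQConfig, forward passes over calibration text, and convert. A condensed sketch follows, reusing the TinyLLama checkpoint from the other examples in this diff; the single calibration sentence is illustrative only (the script itself calibrates on Wikitext-2), and prepare/convert/GPTQConfig are the calls taken from the diff.

# Condensed sketch of the weight-only GPTQ step from quantize_with_gptq.py.
# Model choice and the one-sentence calibration are illustrative assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from tico.quantization import convert, prepare
from tico.quantization.config.gptq import GPTQConfig

name = "Maykeye/TinyLLama-v0"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).eval()

q_m = prepare(model, GPTQConfig(), inplace=True)  # register GPTQ weight hooks

ids = tokenizer("A few calibration sentences go here.", return_tensors="pt").input_ids
with torch.no_grad():
    q_m(ids)  # gather weight statistics

q_m = convert(q_m, inplace=True)  # materialize the quantized weights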