tico-0.1.0.dev250904-py3-none-any.whl → tico-0.1.0.dev251109-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.

Potentially problematic release.

Files changed (133)
  1. tico/__init__.py +1 -1
  2. tico/config/v1.py +5 -0
  3. tico/passes/cast_mixed_type_args.py +2 -0
  4. tico/passes/convert_expand_to_slice_cat.py +153 -0
  5. tico/passes/convert_matmul_to_linear.py +312 -0
  6. tico/passes/convert_to_relu6.py +1 -1
  7. tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -3
  8. tico/passes/ops.py +0 -1
  9. tico/passes/remove_redundant_expand.py +3 -1
  10. tico/quantization/__init__.py +6 -0
  11. tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +161 -0
  12. tico/quantization/algorithm/fpi_gptq/quantizer.py +179 -0
  13. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
  14. tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +14 -6
  15. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
  16. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  17. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  18. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  19. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  20. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
  21. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  22. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  23. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  24. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  25. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  26. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  27. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  28. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  29. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
  30. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  31. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  32. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  33. tico/quantization/config/base.py +26 -0
  34. tico/quantization/config/fpi_gptq.py +29 -0
  35. tico/quantization/config/gptq.py +29 -0
  36. tico/quantization/config/pt2e.py +25 -0
  37. tico/{experimental/quantization/ptq/quant_config.py → quantization/config/ptq.py} +18 -10
  38. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -37
  39. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +6 -12
  40. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  41. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  42. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/public_interface.py +11 -18
  44. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  45. tico/quantization/quantizer_registry.py +73 -0
  46. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  47. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  48. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_linear.py +11 -10
  49. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_attn.py +10 -12
  50. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_decoder_layer.py +10 -9
  51. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_mlp.py +13 -13
  52. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  53. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py +3 -3
  54. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py +2 -2
  55. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py +2 -2
  56. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py +1 -1
  57. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py +2 -2
  58. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/mx.py +1 -1
  59. tico/quantization/wrapq/quantizer.py +179 -0
  60. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py +3 -5
  61. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py +3 -2
  62. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  63. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  64. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  65. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  66. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  67. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  68. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  69. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  70. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py +58 -21
  71. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py +21 -13
  72. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py +5 -7
  73. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  74. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py +6 -7
  75. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_linear.py +7 -8
  76. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_silu.py +8 -9
  77. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/ptq_wrapper.py +4 -6
  78. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_elementwise.py +55 -17
  79. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_module_base.py +10 -9
  80. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/registry.py +17 -10
  81. tico/serialize/circle_serializer.py +11 -4
  82. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  83. tico/serialize/operators/op_le.py +54 -0
  84. tico/serialize/operators/op_mm.py +15 -132
  85. tico/utils/convert.py +20 -15
  86. tico/utils/register_custom_op.py +6 -4
  87. tico/utils/signature.py +7 -8
  88. tico/utils/validate_args_kwargs.py +12 -0
  89. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/METADATA +48 -2
  90. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/RECORD +128 -108
  91. tico/experimental/quantization/__init__.py +0 -6
  92. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  93. tico/experimental/quantization/ptq/examples/compare_ppl.py +0 -121
  94. tico/experimental/quantization/ptq/examples/debug_quant_outputs.py +0 -129
  95. tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +0 -165
  96. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  97. /tico/{experimental/quantization/algorithm/gptq → quantization/algorithm/fpi_gptq}/__init__.py +0 -0
  98. /tico/{experimental/quantization/algorithm/pt2e → quantization/algorithm/gptq}/__init__.py +0 -0
  99. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  100. /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
  101. /tico/{experimental/quantization/algorithm/pt2e/annotation → quantization/algorithm/pt2e}/__init__.py +0 -0
  102. /tico/{experimental/quantization/algorithm/pt2e/transformation → quantization/algorithm/pt2e/annotation}/__init__.py +0 -0
  103. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  104. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  105. /tico/{experimental/quantization/algorithm/smoothquant → quantization/algorithm/pt2e/transformation}/__init__.py +0 -0
  106. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  107. /tico/{experimental/quantization/evaluation → quantization/algorithm/smoothquant}/__init__.py +0 -0
  108. /tico/{experimental/quantization/evaluation/executor → quantization/config}/__init__.py +0 -0
  109. /tico/{experimental/quantization/passes → quantization/evaluation}/__init__.py +0 -0
  110. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  111. /tico/{experimental/quantization/ptq → quantization/evaluation/executor}/__init__.py +0 -0
  112. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  113. /tico/{experimental/quantization → quantization}/evaluation/metric.py +0 -0
  114. /tico/{experimental/quantization/ptq/examples → quantization/passes}/__init__.py +0 -0
  115. /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
  116. /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
  117. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  118. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  119. /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
  120. /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
  121. /tico/{experimental/quantization/ptq/observers → quantization/wrapq}/__init__.py +0 -0
  122. /tico/{experimental/quantization/ptq → quantization/wrapq}/dtypes.py +0 -0
  123. /tico/{experimental/quantization/ptq/utils → quantization/wrapq/examples}/__init__.py +0 -0
  124. /tico/{experimental/quantization/ptq → quantization/wrapq}/mode.py +0 -0
  125. /tico/{experimental/quantization/ptq/wrappers → quantization/wrapq/observers}/__init__.py +0 -0
  126. /tico/{experimental/quantization/ptq → quantization/wrapq}/qscheme.py +0 -0
  127. /tico/{experimental/quantization/ptq/wrappers/llama → quantization/wrapq/utils}/__init__.py +0 -0
  128. /tico/{experimental/quantization/ptq → quantization/wrapq}/utils/reduce_utils.py +0 -0
  129. /tico/{experimental/quantization/ptq/wrappers/nn → quantization/wrapq/wrappers}/__init__.py +0 -0
  130. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/LICENSE +0 -0
  131. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/WHEEL +0 -0
  132. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/entry_points.txt +0 -0
  133. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/top_level.txt +0 -0
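The bulk of this release is a package reorganization: tico/experimental/quantization moves to tico/quantization, and its ptq subpackage is renamed wrapq (the old ptq modules and examples are the files deleted in the hunks below). As a rough sketch of the import changes downstream code would need, here is a hypothetical before/after. The new module paths follow the renames listed above; whether the same symbols (prepare, convert, QuantConfig, PTQWrapper) are still exported from those locations is an assumption, not something this diff confirms.

# Hypothetical import migration (old paths are taken from the removed examples below).
# Before, on 0.1.0.dev250904:
#   from tico.experimental.quantization import convert, prepare
#   from tico.experimental.quantization.ptq.quant_config import QuantConfig
#   from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
# After, on 0.1.0.dev251109, assuming the public names are unchanged:
from tico.quantization import convert, prepare
from tico.quantization.config.ptq import QuantConfig
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper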
tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py (deleted)
@@ -1,164 +0,0 @@
- # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from typing import Dict, List, Optional
-
- import torch
-
-
- @torch.no_grad()
- def smooth_weights(
-     front_module: torch.nn.Module,
-     back_modules: torch.nn.Module | List[torch.nn.Module],
-     activation_max: torch.Tensor,
-     alpha: float,
- ):
-     """
-     Applies SmoothQuant-style smoothing to the weights and biases of two
-     connected modules using activation maximum values.
-
-     NOTE All modules **MUST** have `weight` and optionally `bias` attributes.
-
-     Parameters
-     -----------
-     front_module
-         The front module whose weights and biases will be adjusted.
-     back_modules
-         A list of back modules whose weights and biases will be adjusted.
-     activation_max
-         A tensor of channel-wise maximum activation values for the front module.
-     alpha
-         The smoothing factor that determines the scaling for weight adjustments.
-
-     Raises
-     -------
-     AttributeError
-         If `front_module` or any module in `back_modules` does not have `weight` attributes.
-     ValueError
-         If the shape of tensors in `activation_max` does not match the number of channels
-         in `front_module`'s weight.
-     NotImplementedError
-         If `front_module` or any module in `back_modules` is of an unsupported type.
-     """
-     from transformers.models.llama.modeling_llama import LlamaRMSNorm
-
-     if not isinstance(back_modules, list):
-         back_modules = [back_modules]
-
-     # Check attributes
-     if not hasattr(front_module, "weight"):
-         raise AttributeError(
-             f"The front module '{type(front_module).__name__}' does not have a 'weight' attribute."
-         )
-     for back_m in back_modules:
-         if not hasattr(back_m, "weight"):
-             raise AttributeError(
-                 f"The back module '{type(back_m).__name__}' does not have a 'weight' attribute."
-             )
-     # Check shapes
-     if isinstance(front_module, LlamaRMSNorm):
-         front_numel = front_module.weight.numel()
-     else:
-         raise NotImplementedError(
-             f"Unsupported module type: {type(front_module).__name__}"
-         )
-     for back_m in back_modules:
-         if isinstance(back_m, torch.nn.Linear):
-             back_numel = back_m.in_features
-         else:
-             raise NotImplementedError(
-                 f"Unsupported module type: {type(back_m).__name__}"
-             )
-
-     if front_numel != back_numel or back_numel != activation_max.numel():
-         raise ValueError(
-             f"Shape mismatch: front_numel({front_numel}), back_numel({back_numel}), activation_max_numel({activation_max.numel()})"
-         )
-
-     # Compute scales
-     device, dtype = back_modules[0].weight.device, back_modules[0].weight.dtype
-     activation_max = activation_max.to(device=device, dtype=dtype) # type: ignore[arg-type]
-     weight_scales = torch.cat(
-         [back_m.weight.abs().max(dim=0, keepdim=True)[0] for back_m in back_modules], # type: ignore[operator]
-         dim=0,
-     )
-     weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
-     scales = (
-         (activation_max.pow(alpha) / weight_scales.pow(1 - alpha))
-         .clamp(min=1e-5)
-         .to(device) # type: ignore[arg-type]
-         .to(dtype) # type: ignore[arg-type]
-     )
-
-     # Smooth
-     front_module.weight.div_(scales)
-     if hasattr(front_module, "bias"):
-         front_module.bias.div_(scales)
-
-     for back_m in back_modules:
-         back_m.weight.mul_(scales.view(1, -1)) # type: ignore[operator]
-
-
- @torch.no_grad()
- def apply_smoothing(
-     model: torch.nn.Module,
-     activation_max: Dict[str, torch.Tensor],
-     alpha: float = 0.5,
-     custom_alpha_map: Optional[Dict[str, float]] = None,
- ):
-     """
-     Applies SmoothQuant-style smoothing to the model's weights using activation maximum values.
-
-     Parameters
-     -----------
-     model
-         A torch module whose weights will be smoothed.
-     activation_max
-         The channel-wise maximum activation values for the model.
-     alpha
-         The default smoothing factor to apply across all modules.
-     custom_alpha_map
-         A dictionary mapping layer/module names to custom alpha values.
-         Layers specified in this dictionary will use the corresponding alpha
-         value instead of the default.
-     """
-     from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-
-     for name, module in model.named_modules():
-         alpha_to_apply = alpha
-         if custom_alpha_map and name in custom_alpha_map:
-             alpha_to_apply = custom_alpha_map[name]
-         if alpha_to_apply > 1.0:
-             raise RuntimeError(
-                 f"Alpha value cannot exceed 1.0. Given alpha: {alpha_to_apply}"
-             )
-         # SmoothQuant is applied before capturing the graph. Therefore, it needs to know
-         # specific module information.
-         # TODO Support more modules.
-         if isinstance(module, LlamaDecoderLayer):
-             attn_ln = module.input_layernorm
-             qkv = [
-                 module.self_attn.q_proj,
-                 module.self_attn.k_proj,
-                 module.self_attn.v_proj,
-             ]
-
-             qkv_input_scales = activation_max[name + ".self_attn.q_proj"]
-             smooth_weights(attn_ln, qkv, qkv_input_scales, alpha_to_apply)
-
-             ffn_ln = module.post_attention_layernorm
-             fcs = [module.mlp.gate_proj, module.mlp.up_proj]
-             fcs_input_scales = activation_max[name + ".mlp.gate_proj"]
-
-             smooth_weights(ffn_ln, fcs, fcs_input_scales, alpha_to_apply)
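The removed smooth_quant.py above expects activation_max to be a dict of per-channel activation maxima keyed by module name (for example "<decoder layer>.self_attn.q_proj"). As a minimal sketch of how such a dict could be collected and then handed to apply_smoothing, assuming only the signatures and docstrings shown above; the hook-based collector below is illustrative and is not tico's observer code:

import torch

def collect_activation_max(model: torch.nn.Module, calib_batches) -> dict:
    """Record the per-channel max of |input| for every Linear module (sketch)."""
    act_max: dict = {}
    hooks = []

    def make_hook(name):
        def hook(module, inputs, output):
            # Flatten all leading dims, keep the channel dim, track the running max.
            x = inputs[0].detach().abs().reshape(-1, inputs[0].shape[-1]).amax(dim=0)
            act_max[name] = torch.maximum(act_max[name], x) if name in act_max else x
        return hook

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            hooks.append(module.register_forward_hook(make_hook(name)))
    with torch.no_grad():
        for batch in calib_batches:
            model(batch)
    for h in hooks:
        h.remove()
    return act_max

# Usage (sketch): smooth a Llama model in place before graph capture.
# activation_max = collect_activation_max(model, calib_batches)
# apply_smoothing(model, activation_max, alpha=0.5)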
tico/experimental/quantization/ptq/examples/compare_ppl.py (deleted)
@@ -1,121 +0,0 @@
- # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- # =============================================================================
- # QUICK PTQ WORKFLOW (OPTIONAL FP32 BASELINE)
- # -----------------------------------------------------------------------------
- # Toggle RUN_FP to choose between:
- #   • FP32 perplexity measurement only, OR
- #   • Full post-training UINT-8 flow (wrap → calibrate → eval).
- # =============================================================================
-
- import torch
- import tqdm
- from datasets import load_dataset
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
- from tico.experimental.quantization.ptq.quant_config import QuantConfig
- from tico.experimental.quantization.ptq.utils.metrics import perplexity
- from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-
- # -------------------------------------------------------------------------
- # 0. Global configuration
- # -------------------------------------------------------------------------
- MODEL_NAME = "meta-llama/Meta-Llama-3-1B"
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- STRIDE = 512 # sliding-window stride for perplexity
- RUN_FP = True # set False → run UINT-8 path
-
- # Token-budget presets for activation calibration
- TOKENS: dict[str, int] = {
-     # Smoke test (<1 min turnaround on CPU/GPU)
-     "debug": 2_000, # ≈16 × 128-seq batches
-     # Good default for 1-7B models (≲3 % ppl delta)
-     "baseline": 50_000,
-     # Production / 4-bit observer smoothing
-     "production": 200_000,
- }
- CALIB_TOKENS = TOKENS["baseline"]
- print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
-
- # -------------------------------------------------------------------------
- # 1. Load model
- # -------------------------------------------------------------------------
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-
- if RUN_FP:
-     # -- FP32 baseline ------------------------------------------------------
-     print("Loading FP32 model …")
-     fp_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
-     fp_model.config.use_cache = False
- else:
-     # -- UINT-8 pipeline -----------------------------------------------------
-     print("Creating UINT-8 clone …")
-     uint8_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
-     uint8_model.config.use_cache = False
-
-     # ---------------------------------------------------------------------
-     # 2. Wrap every Transformer layer with PTQWrapper
-     # ---------------------------------------------------------------------
-     qcfg = QuantConfig() # all-uint8 defaults
-
-     wrapped_layers = torch.nn.ModuleList()
-     for idx, layer in enumerate(uint8_model.model.layers):
-         layer_cfg = qcfg.child(f"layer{idx}")
-         wrapped_layers.append(PTQWrapper(layer, qcfg=layer_cfg))
-     uint8_model.model.layers = wrapped_layers
-
-     # ---------------------------------------------------------------------
-     # 3. Single-pass activation calibration
-     # ---------------------------------------------------------------------
-     print("Calibrating UINT-8 observers …")
-     calib_txt = " ".join(
-         load_dataset("wikitext", "wikitext-2-raw-v1", split="train")["text"]
-     )[:CALIB_TOKENS]
-     ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(DEVICE)
-
-     # (a) switch every QuantModuleBase to CALIB mode
-     for l in uint8_model.model.layers:
-         l.enable_calibration()
-
-     # (b) run inference to collect ranges
-     with torch.no_grad():
-         for i in tqdm.trange(0, ids.size(1) - 1, STRIDE, desc="Calibration"):
-             uint8_model(ids[:, i : i + STRIDE])
-
-     # (c) freeze (scale, zero-point)
-     for l in uint8_model.model.layers:
-         l.freeze_qparams()
-
- # -------------------------------------------------------------------------
- # 4. Evaluate perplexity on Wikitext-2
- # -------------------------------------------------------------------------
- print("\nCalculating perplexities …")
- test_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
- enc = tokenizer("\n\n".join(test_ds["text"]), return_tensors="pt")
-
- if RUN_FP:
-     ppl_fp = perplexity(fp_model, enc, DEVICE, stride=STRIDE)
- else:
-     ppl_int8 = perplexity(uint8_model, enc, DEVICE, stride=STRIDE)
-
- # -------------------------------------------------------------------------
- # 5. Report
- # -------------------------------------------------------------------------
- print("\n┌── Wikitext-2 test perplexity ─────────────")
- if RUN_FP:
-     print(f"│ FP32 : {ppl_fp:8.2f}")
- else:
-     print(f"│ UINT-8 : {ppl_int8:8.2f}")
- print("└───────────────────────────────────────────")
tico/experimental/quantization/ptq/examples/debug_quant_outputs.py (deleted)
@@ -1,129 +0,0 @@
- # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import torch
- import tqdm
- from datasets import load_dataset
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
- from tico.experimental.quantization.ptq.quant_config import QuantConfig
- from tico.experimental.quantization.ptq.utils.introspection import (
-     build_fqn_map,
-     compare_layer_outputs,
-     save_fp_outputs,
- )
- from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-
- # ============================================================================
- # LAYER-WISE DIFF DEBUGGING PIPELINE
- # ----------------------------------------------------------------------------
- # A quantization debugging pipeline that identifies accuracy regressions
- # by comparing UINT vs FP outputs at each layer.
- #
- # 1. Load a full-precision (FP) LLaMA-3-1B model.
- # 2. Wrap each Transformer block with PTQWrapper (activations → fake-quant).
- # 3. Capture reference FP layer outputs before quantization.
- # 4. Calibrate UINT-8 activation observers in a single pass.
- # 5. Freeze quantization parameters (scale, zero-point).
- # 6. Re-run inference and compare UINT-8 vs FP outputs per layer.
- # 7. Report where quantization hurts the most.
- #
- # Use this pipeline to trace precision loss layer by layer, and pinpoint
- # problematic modules during post-training quantization.
- # ============================================================================
-
- # -------------------------------------------------------------------------
- # 0. Global configuration
- # -------------------------------------------------------------------------
- MODEL_NAME = "meta-llama/Meta-Llama-3-1B"
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- STRIDE = 512
-
- # Token-budget presets for activation calibration
- TOKENS: dict[str, int] = {
-     # Smoke test (<1 min turnaround on CPU/GPU)
-     "debug": 2_000, # ≈16 × 128-seq batches
-     # Good default for 1-7B models (≲3 % ppl delta)
-     "baseline": 50_000,
-     # Production / 4-bit observer smoothing
-     "production": 200_000,
- }
- CALIB_TOKENS = TOKENS["baseline"]
- print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
-
- # -------------------------------------------------------------------------
- # 1. Load the FP backbone
- # -------------------------------------------------------------------------
- print("Loading FP model …")
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
- model.config.use_cache = False # disable KV-cache → full forward
- m_to_fqn = build_fqn_map(model) # map modules → fully-qualified names
-
- # Use Wikitext-2 train split for calibration.
- dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
-
- # -------------------------------------------------------------------------
- # 2. Wrap every layer with PTQWrapper (UINT-8 activations)
- # -------------------------------------------------------------------------
- print("Wrapping layers with PTQWrapper …")
- qcfg = QuantConfig() # default: per-tensor UINT8
-
- new_layers = torch.nn.ModuleList()
- for idx, fp_layer in enumerate(model.model.layers):
-     layer_cfg = qcfg.child(f"layer{idx}")
-     q_layer = PTQWrapper(
-         fp_layer,
-         qcfg=layer_cfg,
-         fp_name=m_to_fqn.get(fp_layer),
-     )
-     new_layers.append(q_layer)
-
- model.model.layers = new_layers # swap in quant wrappers
-
- # -------------------------------------------------------------------------
- # 3. Activation calibration plus FP-vs-UINT8 diffing
- # -------------------------------------------------------------------------
- print("Calibrating UINT-8 observers …")
- calib_txt = " ".join(dataset["text"])[:CALIB_TOKENS]
- ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(DEVICE)
-
- # (a) Enable CALIB mode on every QuantModuleBase
- for l in model.model.layers:
-     l.enable_calibration()
-
- # Save reference FP activations before observers clamp/quantize
- save_handles, act_cache = save_fp_outputs(model)
-
- with torch.no_grad():
-     for i in tqdm.trange(0, ids.size(1) - 1, STRIDE, desc="Act-calibration"):
-         inputs = ids[:, i : i + STRIDE]
-         model(inputs) # observers collect act. ranges
-
- # Remove save hooks now that FP activations are cached
- for h in save_handles:
-     h.remove()
-
- # (b) Freeze (scale, zero-point) after calibration
- for l in model.model.layers:
-     l.freeze_qparams()
-
- # (c) Register diff hooks and measure per-layer deltas
- cmp_handles = compare_layer_outputs(model, act_cache, metrics=["diff", "peir"])
- # Use same inputs for comparison.
- model(inputs)
-
- assert isinstance(cmp_handles, list)
- for h in cmp_handles:
-     h.remove()
tico/experimental/quantization/ptq/examples/quantize_with_gptq.py (deleted)
@@ -1,165 +0,0 @@
- # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- # =============================================================================
- # PTQ + GPTQ HYBRID QUANTIZATION PIPELINE
- # -----------------------------------------------------------------------------
- # This script shows how to:
- # 1. Load a pretrained FP Llama-3 model.
- # 2. Run GPTQ to quantize weights only.
- # 3. Wrap every Transformer layer with a PTQWrapper to quantize activations.
- # 4. Calibrate UINT-8 observers in a single pass over a text corpus.
- # 5. Inject GPTQ’s per-tensor weight scales / zero-points into the PTQ graph.
- # 6. Freeze all Q-params and compute Wikitext-2 perplexity.
- # =============================================================================
-
- from typing import Any
-
- import torch
- import tqdm
- from datasets import load_dataset
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
- from tico.experimental.quantization import convert, prepare
- from tico.experimental.quantization.config import GPTQConfig
- from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
- from tico.experimental.quantization.ptq.quant_config import QuantConfig
- from tico.experimental.quantization.ptq.utils.introspection import build_fqn_map
- from tico.experimental.quantization.ptq.utils.metrics import perplexity
- from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
- from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-     QuantModuleBase,
- )
-
- # -------------------------------------------------------------------------
- # 0. Global configuration
- # -------------------------------------------------------------------------
- MODEL_NAME = "meta-llama/Meta-Llama-3-1B"
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- STRIDE = 512
-
- # Token-budget presets for activation calibration
- TOKENS: dict[str, int] = {
-     # Smoke test (<1 min turnaround on CPU/GPU)
-     "debug": 2_000, # ≈16 × 128-seq batches
-     # Good default for 1-7B models (≲3 % ppl delta)
-     "baseline": 50_000,
-     # Production / 4-bit observer smoothing
-     "production": 200_000,
- }
- CALIB_TOKENS = TOKENS["baseline"]
-
- # -------------------------------------------------------------------------
- # 1. Helper — copy GPTQ (scale, zp) into PTQ observers
- # -------------------------------------------------------------------------
- def inject_gptq_qparams(
-     root: torch.nn.Module,
-     gptq_quantizers: dict[str, Any], # {fp_name: quantizer}
-     weight_obs_name: str = "weight",
- ):
-     """
-     For every `QuantModuleBase` whose `fp_name` matches a GPTQ key,
-     locate the observer called `weight_obs_name` and overwrite its
-     (scale, zero-point), then lock them against further updates.
-     """
-     for m in root.modules():
-         if not isinstance(m, QuantModuleBase):
-             continue
-         if m.fp_name is None:
-             continue
-         quantizer = gptq_quantizers.get(m.fp_name)
-         if quantizer is None:
-             continue
-         obs = m.get_observer(weight_obs_name)
-         if obs is None:
-             continue
-         assert isinstance(obs, AffineObserverBase)
-         # GPTQ quantizer attributes
-         obs.load_qparams(quantizer.scale, quantizer.zero, lock=True)
-
-
- # -------------------------------------------------------------------------
- # 2. Load the FP backbone
- # -------------------------------------------------------------------------
- print("Loading FP model …")
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
- model.config.use_cache = False # disable KV-cache → full forward
- m_to_fqn = build_fqn_map(model) # map modules → fully-qualified names
-
- # -------------------------------------------------------------------------
- # 3. Run GPTQ (weight-only) pass
- # -------------------------------------------------------------------------
- print("Applying GPTQ …")
- dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
- q_m = prepare(model, GPTQConfig(), inplace=True)
-
- for d in tqdm.tqdm(dataset, desc="GPTQ calibration"):
-     ids = tokenizer(d["text"], return_tensors="pt").input_ids.to(DEVICE)
-     q_m(ids) # observers gather weight stats
-
- q_m = convert(q_m, inplace=True) # materialize INT-weight tensors
-
- # -------------------------------------------------------------------------
- # 4. Wrap every layer with PTQWrapper (activation UINT-8)
- # -------------------------------------------------------------------------
- qcfg = QuantConfig() # default: per-tensor UINT8
- new_layers = torch.nn.ModuleList()
-
- for idx, fp_layer in enumerate(q_m.model.layers):
-     layer_cfg = qcfg.child(f"layer{idx}")
-     q_layer = PTQWrapper(
-         fp_layer,
-         qcfg=layer_cfg,
-         fp_name=m_to_fqn.get(fp_layer),
-     )
-     new_layers.append(q_layer)
-
- q_m.model.layers = new_layers
-
- # -------------------------------------------------------------------------
- # 5. Single-pass activation calibration
- # -------------------------------------------------------------------------
- print("Calibrating UINT-8 observers …")
- calib_txt = " ".join(
-     load_dataset("wikitext", "wikitext-2-raw-v1", split="train")["text"]
- )[:CALIB_TOKENS]
- ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(DEVICE)
-
- # (a) Enable CALIB mode on every QuantModuleBase
- for l in q_m.model.layers:
-     l.enable_calibration()
-
- # (b) Overwrite weight observers with GPTQ statistics
- inject_gptq_qparams(q_m, q_m.quantizers)
-
- with torch.no_grad():
-     for i in tqdm.trange(0, ids.size(1) - 1, STRIDE, desc="Act-calibration"):
-         q_m(ids[:, i : i + STRIDE]) # observers collect act. ranges
-
- # (c) Freeze all Q-params (scale, zp)
- for l in q_m.model.layers:
-     l.freeze_qparams()
-
- # -------------------------------------------------------------------------
- # 6. Evaluate perplexity on Wikitext-2
- # -------------------------------------------------------------------------
- print("\nCalculating perplexities …")
- test_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
- enc = tokenizer("\n\n".join(test_ds["text"]), return_tensors="pt")
- ppl_uint8 = perplexity(q_m, enc, DEVICE, stride=STRIDE)
-
- print("\n┌── Wikitext-2 test perplexity ─────────────")
- print(f"│ UINT-8 : {ppl_uint8:8.2f}")
- print("└───────────────────────────────────────────")