tico-0.1.0.dev250714-py3-none-any.whl → tico-0.1.0.dev251102-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +9 -1
- tico/config/base.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_aten_where_arg_type.py +1 -1
- tico/passes/cast_clamp_mixed_type_args.py +169 -0
- tico/passes/cast_mixed_type_args.py +4 -2
- tico/passes/const_prop_pass.py +1 -1
- tico/passes/convert_conv1d_to_conv2d.py +1 -1
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_addmm.py +0 -3
- tico/passes/decompose_batch_norm.py +2 -2
- tico/passes/decompose_fake_quantize.py +0 -3
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
- tico/passes/decompose_group_norm.py +0 -3
- tico/passes/legalize_predefined_layout_operators.py +2 -11
- tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
- tico/passes/lower_to_slice.py +1 -1
- tico/passes/merge_consecutive_cat.py +1 -1
- tico/passes/ops.py +1 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -6
- tico/passes/remove_redundant_reshape.py +5 -5
- tico/passes/segment_index_select.py +1 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
- tico/quantization/algorithm/gptq/quantizer.py +292 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
- tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
- tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
- tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
- tico/{experimental/quantization → quantization}/public_interface.py +19 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +125 -0
- tico/serialize/circle_graph.py +12 -4
- tico/serialize/circle_mapping.py +76 -2
- tico/serialize/circle_serializer.py +253 -148
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_any.py +7 -14
- tico/serialize/operators/op_avg_pool2d.py +11 -4
- tico/serialize/operators/op_clamp.py +5 -7
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_conv2d.py +14 -6
- tico/serialize/operators/op_copy.py +26 -3
- tico/serialize/operators/op_cumsum.py +3 -1
- tico/serialize/operators/op_depthwise_conv2d.py +17 -7
- tico/serialize/operators/op_full_like.py +0 -2
- tico/serialize/operators/op_index_select.py +8 -1
- tico/serialize/operators/op_instance_norm.py +0 -6
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_log1p.py +3 -2
- tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
- tico/serialize/operators/op_mm.py +15 -131
- tico/serialize/operators/op_mul.py +2 -8
- tico/serialize/operators/op_pow.py +3 -1
- tico/serialize/operators/op_repeat.py +12 -3
- tico/serialize/operators/op_reshape.py +1 -1
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/serialize/operators/op_softmax.py +7 -14
- tico/serialize/operators/op_split_with_sizes.py +16 -8
- tico/serialize/operators/op_transpose_conv.py +11 -8
- tico/serialize/operators/op_view.py +2 -1
- tico/serialize/quant_param.py +5 -5
- tico/utils/convert.py +30 -17
- tico/utils/dtype.py +42 -0
- tico/utils/graph.py +1 -1
- tico/utils/model.py +2 -1
- tico/utils/padding.py +2 -2
- tico/utils/pytree_utils.py +134 -0
- tico/utils/record_input.py +102 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/serialize.py +16 -3
- tico/utils/signature.py +247 -0
- tico/utils/torch_compat.py +52 -0
- tico/utils/utils.py +50 -58
- tico/utils/validate_args_kwargs.py +38 -3
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
- tico-0.1.0.dev251102.dist-info/RECORD +271 -0
- tico/experimental/quantization/__init__.py +0 -1
- tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
- tico-0.1.0.dev250714.dist-info/RECORD +0 -209
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
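The bulk of the moves above relocate the quantization stack from tico/experimental/quantization to tico/quantization. A minimal sketch of what that means for downstream imports, assuming the old experimental package exposed the same prepare/convert entry points that the new examples below import (the old-path line is illustrative, not taken from this diff):

# Old layout (0.1.0.dev250714) — assumed entry point under the experimental namespace:
#   from tico.experimental.quantization import prepare, convert
# New layout (0.1.0.dev251102):
from tico.quantization import prepare, convert
from tico.quantization.config.ptq import PTQConfig  # config module added in this release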
+++ tico/quantization/wrapq/dtypes.py
@@ -0,0 +1,70 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass


@dataclass(frozen=True)
class DType:
    """
    Self-contained integer dtypes for quantization.

    A DType is just an immutable value-object with two fields:
      - bits
      - signed

    Common presets (INT8, UINT4, ..) are provided as constants for convenience.
    """

    bits: int  # pylint: disable=used-before-assignment
    signed: bool = False  # False -> unsigned

    @property
    def qmin(self) -> int:
        assert self.bits is not None
        if self.signed:
            return -(1 << (self.bits - 1))
        return 0

    @property
    def qmax(self) -> int:
        assert self.bits is not None
        if self.signed:
            return (1 << (self.bits - 1)) - 1
        return (1 << self.bits) - 1

    def __str__(self) -> str:
        prefix = "int" if self.signed else "uint"
        return f"{prefix}{self.bits}"

    # ────────────────────────────────
    # Factory helpers
    # ────────────────────────────────
    @staticmethod
    def int(bits: int):  # type: ignore[valid-type]
        return DType(bits, signed=True)

    @staticmethod
    def uint(bits: int):  # type: ignore[valid-type]
        return DType(bits, signed=False)


# ---------------------------------------------------------------------
# Convenient canned versions
# ---------------------------------------------------------------------
UINT4 = DType.uint(4)
INT4 = DType.int(4)
INT8 = DType.int(8)
UINT8 = DType.uint(8)
INT16 = DType.int(16)
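A quick usage sketch of the DType presets defined above; the expected values follow directly from the qmin/qmax properties, and the import path mirrors the file's location in the listing:

from tico.quantization.wrapq.dtypes import DType, INT8, UINT4

assert (INT8.qmin, INT8.qmax) == (-128, 127)  # signed 8-bit range
assert (UINT4.qmin, UINT4.qmax) == (0, 15)    # unsigned 4-bit range
assert str(DType.int(16)) == "int16"          # __str__ joins sign prefix and bit width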
@@ -0,0 +1 @@
# DO NOT REMOVE THIS FILE
+++ tico/quantization/wrapq/examples/compare_ppl.py
@@ -0,0 +1,230 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# =============================================================================
# QUICK PTQ WORKFLOW (OPTIONAL FP32 BASELINE)
# -----------------------------------------------------------------------------
# Toggle RUN_FP to choose between:
#   • FP32 perplexity measurement only, OR
#   • Full post-training UINT-8 flow (wrap → calibrate → eval).
# =============================================================================

import argparse
import sys

import torch
import tqdm
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from tico.quantization import convert, prepare
from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.utils.metrics import perplexity

# Token-budget presets for activation calibration
TOKENS: dict[str, int] = {
    # Smoke test (<1 min turnaround on CPU/GPU)
    "debug": 2_000,  # ≈16 × 128-seq batches
    # Good default for 1-7B models (≲3 % ppl delta)
    "baseline": 50_000,
    # Production / 4-bit observer smoothing
    "production": 200_000,
}

DTYPE_MAP = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
}

# Hardcoded dataset settings
DATASET_NAME = "wikitext"
DATASET_CONFIG = "wikitext-2-raw-v1"
TRAIN_SPLIT = "train"
TEST_SPLIT = "test"


def main():
    parser = argparse.ArgumentParser(description="Quick PTQ example (FP or UINT8)")
    parser.add_argument(
        "--mode",
        choices=["fp", "uint8"],
        default="fp",
        help="Choose FP baseline only or full UINT8 PTQ path.",
    )
    parser.add_argument(
        "--model", type=str, required=True, help="HF repo name or local path."
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run on (cuda|cpu).",
    )
    parser.add_argument(
        "--dtype",
        choices=list(DTYPE_MAP.keys()),
        default="float32",
        help=f"Model dtype for load.",
    )
    parser.add_argument(
        "--stride", type=int, default=512, help="Sliding-window stride for perplexity."
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Enable only if you trust the model repo code.",
    )
    parser.add_argument(
        "--hf-token",
        type=str,
        default=None,
        help="Optional HF token for gated/private models.",
    )
    parser.add_argument(
        "--use-cache",
        dest="use_cache",
        action="store_true",
        default=False,
        help="Use model KV cache if enabled (off by default).",
    )
    parser.add_argument(
        "--no-tqdm", action="store_true", help="Disable tqdm progress bars."
    )
    # 2) calib-preset default = debug
    parser.add_argument(
        "--calib-preset",
        choices=list(TOKENS.keys()),
        default="debug",
        help="Calibration token budget preset.",
    )

    args = parser.parse_args()

    # Basic setup
    torch.manual_seed(args.seed)
    device = torch.device(args.device)
    dtype = DTYPE_MAP[args.dtype]

    print("=== Config ===")
    print(f"Mode : {args.mode}")
    print(f"Model : {args.model}")
    print(f"Device : {device.type}")
    print(f"DType : {args.dtype}")
    print(f"Stride : {args.stride}")
    print(f"Use HF cache? : {args.use_cache}")
    print(
        f"Calib preset : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
    )
    print()

    # -------------------------------------------------------------------------
    # 1. Load model and tokenizer
    # -------------------------------------------------------------------------
    tokenizer = AutoTokenizer.from_pretrained(
        args.model,
        trust_remote_code=args.trust_remote_code,
        token=args.hf_token,
    )

    model = (
        AutoModelForCausalLM.from_pretrained(
            args.model,
            torch_dtype=dtype,
            trust_remote_code=args.trust_remote_code,
            token=args.hf_token,
        )
        .to(device)
        .eval()
    )

    model.config.use_cache = args.use_cache

    if args.mode == "fp":
        fp_model = model
    else:
        # INT8 PTQ path
        uint8_model = model

        CALIB_TOKENS = TOKENS[args.calib_preset]
        print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")

        # ---------------------------------------------------------------------
        # 2. Wrap every Transformer layer with PTQWrapper
        # ---------------------------------------------------------------------
        qcfg = PTQConfig()  # all-uint8 defaults
        prepare(uint8_model, qcfg)

        # ---------------------------------------------------------------------
        # 3. Single-pass activation calibration
        # ---------------------------------------------------------------------
        print("Calibrating UINT-8 observers …")
        calib_txt = " ".join(
            load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)["text"]
        )[:CALIB_TOKENS]
        ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)

        # Run inference to collect ranges
        iterator = range(0, ids.size(1) - 1, args.stride)
        if not args.no_tqdm:
            iterator = tqdm.tqdm(iterator, desc="Calibration")
        with torch.no_grad():
            for i in iterator:
                uint8_model(ids[:, i : i + args.stride])

        # Freeze (scale, zero-point)
        convert(uint8_model)

    # -------------------------------------------------------------------------
    # 4. Evaluate perplexity
    # -------------------------------------------------------------------------
    print("\nCalculating perplexities …")
    test_ds = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT)
    enc = tokenizer("\n\n".join(test_ds["text"]), return_tensors="pt")

    if args.mode == "fp":
        ppl_fp = perplexity(
            fp_model,
            enc,
            args.device,
            stride=args.stride,
            show_progress=not args.no_tqdm,
        )
    else:
        ppl_int8 = perplexity(
            uint8_model,
            enc,
            args.device,
            stride=args.stride,
            show_progress=not args.no_tqdm,
        )

    # -------------------------------------------------------------------------
    # 5. Report
    # -------------------------------------------------------------------------
    print("\n┌── Wikitext-2 test perplexity ─────────────")
    if args.mode == "fp":
        print(f"│ FP : {ppl_fp:8.2f}")
    else:
        print(f"│ UINT-8 : {ppl_int8:8.2f}")
    print("└───────────────────────────────────────────")


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        print(f"\n[Error] {e}", file=sys.stderr)
        sys.exit(1)
+++ tico/quantization/wrapq/examples/debug_quant_outputs.py
@@ -0,0 +1,224 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ============================================================================
# LAYER-WISE DIFF DEBUGGING PIPELINE
# ----------------------------------------------------------------------------
# A quantization debugging pipeline that identifies accuracy regressions
# by comparing UINT vs FP outputs at each layer.
#
# 1. Load a full-precision (FP) LLaMA-3-1B model.
# 2. Wrap each Transformer block with PTQWrapper (activations → fake-quant).
# 3. Capture reference FP layer outputs before quantization.
# 4. Calibrate UINT-8 activation observers in a single pass.
# 5. Freeze quantization parameters (scale, zero-point).
# 6. Re-run inference and compare UINT-8 vs FP outputs per layer.
# 7. Report where quantization hurts the most.
#
# Use this pipeline to trace precision loss layer by layer, and pinpoint
# problematic modules during post-training quantization.
# ============================================================================

import argparse
import sys

import torch
import tqdm
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from tico.quantization import convert, prepare
from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.utils.introspection import (
    build_fqn_map,
    compare_layer_outputs,
    save_fp_outputs,
)
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper

# Token-budget presets for activation calibration
TOKENS: dict[str, int] = {
    # Smoke test (<1 min turnaround on CPU/GPU)
    "debug": 2_000,  # ≈16 × 128-seq batches
    # Good default for 1-7B models (≲3 % ppl delta)
    "baseline": 50_000,
    # Production / 4-bit observer smoothing
    "production": 200_000,
}

DTYPE_MAP = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
}

# Hardcoded dataset settings
DATASET_NAME = "wikitext"
DATASET_CONFIG = "wikitext-2-raw-v1"
TRAIN_SPLIT = "train"


def main():
    parser = argparse.ArgumentParser(
        description="Layer-wise diff debugging pipeline for PTQ"
    )
    parser.add_argument(
        "--model", type=str, required=True, help="HF repo name or local path."
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run on (cuda|cpu|mps).",
    )
    parser.add_argument(
        "--dtype",
        choices=list(DTYPE_MAP.keys()),
        default="float32",
        help=f"Model dtype for load.",
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=512,
        help="Sliding-window stride used during calibration.",
    )
    parser.add_argument(
        "--calib-preset",
        choices=list(TOKENS.keys()),
        default="debug",
        help="Calibration token budget preset.",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Enable only if you trust the model repo code.",
    )
    parser.add_argument(
        "--hf-token",
        type=str,
        default=None,
        help="Optional HF token for gated/private repos.",
    )
    parser.add_argument(
        "--use-cache",
        dest="use_cache",
        action="store_true",
        default=False,
        help="Use model KV cache if enabled (off by default).",
    )
    parser.add_argument(
        "--no-tqdm", action="store_true", help="Disable tqdm progress bars."
    )

    args = parser.parse_args()

    # Basic setup
    torch.manual_seed(args.seed)
    device = torch.device(args.device)
    dtype = DTYPE_MAP[args.dtype]  # noqa: E999 (kept readable)

    print("=== Config ===")
    print(f"Model : {args.model}")
    print(f"Device : {device.type}")
    print(f"DType : {args.dtype}")
    print(f"Stride : {args.stride}")
    print(
        f"Calib preset : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
    )
    print(f"Use HF cache? : {args.use_cache}")
    print()

    # -------------------------------------------------------------------------
    # 1. Load the FP backbone and tokenizer
    # -------------------------------------------------------------------------
    print("Loading FP model …")
    tokenizer = AutoTokenizer.from_pretrained(
        args.model,
        trust_remote_code=args.trust_remote_code,
        token=args.hf_token,
    )
    model = (
        AutoModelForCausalLM.from_pretrained(
            args.model,
            torch_dtype=dtype,
            trust_remote_code=args.trust_remote_code,
            token=args.hf_token,
        )
        .to(device)
        .eval()
    )

    # Disable KV cache to force full forward passes for introspection
    model.config.use_cache = args.use_cache

    # Build module -> FQN map before wrapping
    m_to_fqn = build_fqn_map(model)

    # Prepare calibration inputs (HF Wikitext-2 train split)
    CALIB_TOKENS = TOKENS[args.calib_preset]
    print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
    # Use Wikitext-2 train split for calibration.
    dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)

    # -------------------------------------------------------------------------
    # 2. Wrap every layer with PTQWrapper (UINT-8 activations)
    # -------------------------------------------------------------------------
    print("Wrapping layers with PTQWrapper …")
    qcfg = PTQConfig()  # default: per-tensor UINT8
    prepare(model, qcfg)

    # -------------------------------------------------------------------------
    # 3. Activation calibration plus FP-vs-UINT8 diffing
    # -------------------------------------------------------------------------
    print("Calibrating UINT-8 observers …")
    calib_txt = " ".join(dataset["text"])[:CALIB_TOKENS]
    ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)

    # Save reference FP activations before observers clamp/quantize
    save_handles, act_cache = save_fp_outputs(model)

    iterator = range(0, ids.size(1) - 1, args.stride)
    if not args.no_tqdm:
        iterator = tqdm.tqdm(iterator, desc="Act-Calibration")
    with torch.no_grad():
        for i in iterator:
            inputs = ids[:, i : i + args.stride]
            model(inputs)  # observers collect act. ranges

    # Remove save hooks now that FP activations are cached
    for h in save_handles:
        h.remove()

    # Freeze (scale, zero-point) after calibration
    convert(model)

    # Register diff hooks and measure per-layer deltas
    cmp_handles = compare_layer_outputs(model, act_cache, metrics=["diff", "peir"])
    # Use same inputs for comparison.
    with torch.no_grad():
        model(inputs)

    assert isinstance(cmp_handles, list)
    for h in cmp_handles:
        h.remove()


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        print(f"\n[Error] {e}", file=sys.stderr)
        sys.exit(1)
+++ tico/quantization/wrapq/examples/quantize_linear.py
@@ -0,0 +1,107 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# =============================================================================
# POST-TRAINING QUANTIZATION EXAMPLE — Simple Linear Model
# -----------------------------------------------------------------------------
# This demo shows a minimal PTQ flow for a toy model:
#   1. Define a simple model with a single Linear layer.
#   2. Replace the FP32 Linear with a QuantLinear wrapper.
#   3. Run a short calibration pass to collect activation statistics.
#   4. Freeze scales / zero-points and switch to INT-simulation mode.
#   5. Compare INT vs FP32 outputs with a mean-absolute-diff check.
#   6. Export the quantized model to a Circle format.
# =============================================================================

import pathlib

import torch
import torch.nn as nn

from tico.quantization import convert, prepare
from tico.quantization.config.ptq import PTQConfig
from tico.quantization.evaluation.metric import compute_peir
from tico.quantization.evaluation.utils import plot_two_outputs
from tico.quantization.wrapq.mode import Mode
from tico.quantization.wrapq.wrappers.nn.quant_linear import QuantLinear
from tico.utils.utils import SuppressWarning


# -------------------------------------------------------------------------
# 0. Define a toy model (1 Linear layer only)
# -------------------------------------------------------------------------
class TinyLinearModel(nn.Module):
    """A minimal model: single Linear layer."""

    def __init__(self, in_features=16, out_features=8):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        return self.fc(x)


# Instantiate FP32 model
model = TinyLinearModel()
model.eval()

# Keep FP32 reference for diff check
fp32_layer = model.fc

# -------------------------------------------------------------------------
# 1. Replace the Linear with QuantLinear wrapper
# -------------------------------------------------------------------------
model.fc = prepare(fp32_layer, PTQConfig())  # type: ignore[assignment]
qlayer = model.fc  # alias for brevity

# -------------------------------------------------------------------------
# 2. Single-pass calibration (collect activation ranges)
# -------------------------------------------------------------------------
assert isinstance(qlayer.wrapped, QuantLinear)
with torch.no_grad():
    for _ in range(16):  # small toy batch
        x = torch.randn(4, 16)  # (batch=4, features=16)
        _ = model(x)

convert(qlayer)

assert qlayer._mode is Mode.QUANT, "Quantization mode should be active now."

# -------------------------------------------------------------------------
# 3. Quick INT-sim vs FP32 sanity check
# -------------------------------------------------------------------------
x = torch.randn(2, 16)
with torch.no_grad():
    int8_out = model(x)
    fp32_out = fp32_layer(x)

print("┌───────────── Quantization Error Summary ─────────────")
print(f"│ Mean |diff|: {(int8_out - fp32_out).abs().mean().item():.6f}")
print(f"│ PEIR : {compute_peir(fp32_out, int8_out) * 100:.6f} %")
print("└──────────────────────────────────────────────────────")
print(plot_two_outputs(fp32_out, int8_out))

# -------------------------------------------------------------------------
# 4. Export the calibrated model to Circle
# -------------------------------------------------------------------------
import tico

save_path = pathlib.Path("tiny_linear.q.circle")
example_input = torch.randn(1, 16)

with SuppressWarning(UserWarning, ".*"):
    cm = tico.convert(model, (example_input,))  # forward(x) only
    cm.save(save_path)

print(f"Quantized Circle model saved to {save_path.resolve()}")