tico 0.1.0.dev250803__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. tico/__init__.py +1 -1
  2. tico/config/v1.py +5 -0
  3. tico/passes/cast_mixed_type_args.py +2 -0
  4. tico/passes/convert_expand_to_slice_cat.py +153 -0
  5. tico/passes/convert_matmul_to_linear.py +312 -0
  6. tico/passes/convert_to_relu6.py +1 -1
  7. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
  8. tico/passes/ops.py +0 -1
  9. tico/passes/remove_redundant_assert_nodes.py +3 -1
  10. tico/passes/remove_redundant_expand.py +3 -1
  11. tico/quantization/__init__.py +6 -0
  12. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  13. tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
  14. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
  15. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  16. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  17. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  18. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  19. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
  20. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  21. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  22. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  23. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  24. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  25. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  26. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  27. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  28. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
  29. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  30. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  31. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  32. tico/quantization/config/base.py +26 -0
  33. tico/quantization/config/gptq.py +29 -0
  34. tico/quantization/config/pt2e.py +25 -0
  35. tico/quantization/config/ptq.py +119 -0
  36. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  37. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
  38. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  39. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  40. tico/quantization/evaluation/metric.py +146 -0
  41. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  42. tico/quantization/passes/__init__.py +1 -0
  43. tico/{experimental/quantization → quantization}/public_interface.py +11 -18
  44. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  45. tico/quantization/quantizer_registry.py +73 -0
  46. tico/quantization/wrapq/__init__.py +1 -0
  47. tico/quantization/wrapq/dtypes.py +70 -0
  48. tico/quantization/wrapq/examples/__init__.py +1 -0
  49. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  50. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  51. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  52. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  53. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  54. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  55. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  56. tico/quantization/wrapq/mode.py +32 -0
  57. tico/quantization/wrapq/observers/__init__.py +1 -0
  58. tico/quantization/wrapq/observers/affine_base.py +128 -0
  59. tico/quantization/wrapq/observers/base.py +98 -0
  60. tico/quantization/wrapq/observers/ema.py +62 -0
  61. tico/quantization/wrapq/observers/identity.py +74 -0
  62. tico/quantization/wrapq/observers/minmax.py +39 -0
  63. tico/quantization/wrapq/observers/mx.py +60 -0
  64. tico/quantization/wrapq/qscheme.py +40 -0
  65. tico/quantization/wrapq/quantizer.py +179 -0
  66. tico/quantization/wrapq/utils/__init__.py +1 -0
  67. tico/quantization/wrapq/utils/introspection.py +167 -0
  68. tico/quantization/wrapq/utils/metrics.py +124 -0
  69. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  70. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  71. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  72. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  73. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  74. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  75. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  76. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  77. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  78. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  79. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  80. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  81. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  82. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  83. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  84. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  85. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  86. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  87. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  88. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  89. tico/quantization/wrapq/wrappers/registry.py +125 -0
  90. tico/serialize/circle_serializer.py +11 -4
  91. tico/serialize/operators/adapters/__init__.py +1 -0
  92. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  93. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  94. tico/serialize/operators/op_le.py +54 -0
  95. tico/serialize/operators/op_mm.py +15 -132
  96. tico/serialize/operators/op_rmsnorm.py +65 -0
  97. tico/utils/convert.py +20 -15
  98. tico/utils/dtype.py +22 -0
  99. tico/utils/register_custom_op.py +29 -4
  100. tico/utils/signature.py +247 -0
  101. tico/utils/utils.py +50 -53
  102. tico/utils/validate_args_kwargs.py +37 -0
  103. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  104. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/RECORD +130 -73
  105. tico/experimental/quantization/__init__.py +0 -6
  106. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  107. tico/experimental/quantization/evaluation/metric.py +0 -109
  108. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  109. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  110. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  111. /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
  112. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  113. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  114. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  115. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  116. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  117. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  118. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  119. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  120. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  121. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  122. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  123. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  124. /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
  125. /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
  126. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  127. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  128. /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
  129. /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
  130. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  131. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  132. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  133. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
tico/quantization/evaluation/metric.py
@@ -0,0 +1,146 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Callable, Dict, List, Optional
+
+ import torch
+
+
+ def compute_max_abs_diff(base: torch.Tensor, target: torch.Tensor) -> float:
+     """
+     Return the *maximum* absolute element-wise difference between two tensors.
+     """
+     assert base.shape == target.shape, "shape mismatch"
+     return (base.detach() - target.detach()).abs().max().item()
+
+
+ def compute_peir(base: torch.Tensor, target: torch.Tensor) -> float:
+     """
+     Peak-Error-to-Interval Ratio (PEIR).
+
+     PEIR = max(|base - target|) / (max(base) - min(base))
+
+     The interval denominator uses the reference (*base*) tensor only — this
+     makes PEIR independent of quantization error in `target`.
+     """
+     assert base.shape == target.shape, "shape mismatch"
+     peak_error = (base.detach() - target.detach()).abs().max().item()
+     interval = (base.detach().max() - base.detach().min()).item()
+     interval = 1.0 if interval == 0.0 else interval  # avoid divide-by-zero
+     return peak_error / interval
+
+
+ def mse(base: torch.Tensor, target: torch.Tensor) -> float:
+     """
+     Mean Squared Error (MSE).
+     Penalizes **larger** deviations more heavily than MAE by squaring each
+     difference — helpful to expose occasional large spikes.
+     Formula
+     -------
+     MSE = mean((base - target)²)
+     Returns
+     -------
+     float
+         Mean squared error. *Lower is better*.
+     """
+     return torch.mean((base.detach() - target.detach()) ** 2).item()
+
+
+ class MetricCalculator:
+     """
+     Lightweight registry-and-dispatcher for **pair-wise tensor comparison metrics**.
+
+     Purpose
+     -------
+     Consolidate all metrics used to assess the discrepancy between a reference
+     (usually FP32) tensor and its quantized counterpart, while letting the caller
+     choose *at runtime* which subset to evaluate.
+
+     Built-in metrics
+     ----------------
+     Key                      Description
+     -----------------------  -------------------------------------------------
+     "diff" / "max_abs_diff"  Maximum absolute element-wise difference
+     "peir"                   Peak-Error-to-Interval Ratio
+     "mse"                    Mean squared error
+
+     Usage pattern
+     -------------
+     >>> calc = MetricCalculator(custom_metrics={'mad': mad_fn})
+     >>> stats = calc.compute(fp_outs, q_outs, metrics=['diff', 'mad'])
+
+     • **Instantiation** registers any extra user metrics
+       (signature: ``fn(base: Tensor, target: Tensor) -> float``).
+     • **compute(...)** takes two *equal-length* lists of tensors and an optional
+       list of metric names.
+       — If *metrics* is *None*, every registered metric is evaluated.
+       — Returns a dict: ``{metric_name -> [value for each tensor pair]}``.
+
+     Implementation notes
+     --------------------
+     * All tensors are detached before calculation to avoid autograd overhead.
+     * Registrations are stored in `self.registry` (str → callable).
+     * Duplicate metric names between built-ins and custom metrics raise an error
+       at construction time to prevent silent shadowing.
+     """
+
+     builtin_metrics: Dict[str, Callable[[torch.Tensor, torch.Tensor], float]] = {
+         "diff": compute_max_abs_diff,
+         "max_abs_diff": compute_max_abs_diff,
+         "peir": compute_peir,
+         "mse": mse,
+     }
+
+     def __init__(
+         self,
+         custom_metrics: Optional[
+             Dict[str, Callable[[torch.Tensor, torch.Tensor], float]]
+         ] = None,
+     ):
+         self.registry: Dict[str, Callable] = self.builtin_metrics.copy()
+         if custom_metrics:
+             dup = self.registry.keys() & custom_metrics.keys()
+             if dup:
+                 raise RuntimeError(f"Duplicate metric names: {dup}")
+             assert custom_metrics is not None
+             self.registry.update(custom_metrics)  # type: ignore[arg-type]
+
+     # ----------------------------------------------------------------- #
+     # Public API                                                        #
+     # ----------------------------------------------------------------- #
+     def compute(
+         self,
+         base_outputs: List[torch.Tensor],
+         target_outputs: List[torch.Tensor],
+         metrics: Optional[List[str]] = None,
+     ) -> Dict[str, List[Any]]:
+         """
+         Compute selected metrics for every (base, target) pair.
+
+         Parameters
+         ----------
+         metrics
+             List of metric names to evaluate **this call**.
+             • None → evaluate *all* registered metrics.
+         """
+         sel = metrics or list(self.registry)
+         unknown = set(sel) - self.registry.keys()
+         if unknown:
+             raise RuntimeError(f"Unknown metric(s): {unknown}")
+
+         results: Dict[str, List[Any]] = {m: [] for m in sel}
+         for base, tgt in zip(base_outputs, target_outputs):
+             for m in sel:
+                 results[m].append(self.registry[m](base, tgt))
+         return results
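
A minimal usage sketch for the new MetricCalculator (illustrative, not part of the package diff; the import path follows the file list above, and `mean_abs_diff` is a hypothetical custom metric):

    import torch
    from tico.quantization.evaluation.metric import MetricCalculator

    # Custom metrics use the same (base, target) -> float signature as the built-ins.
    def mean_abs_diff(base: torch.Tensor, target: torch.Tensor) -> float:
        return (base.detach() - target.detach()).abs().mean().item()

    calc = MetricCalculator(custom_metrics={"mad": mean_abs_diff})
    fp_outs = [torch.randn(4, 8)]
    q_outs = [t + 0.01 * torch.randn_like(t) for t in fp_outs]

    # Passing metrics=None would evaluate every registered metric instead.
    stats = calc.compute(fp_outs, q_outs, metrics=["max_abs_diff", "peir", "mad"])
    # -> {"max_abs_diff": [...], "peir": [...], "mad": [...]}, one value per tensor pair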
tico/quantization/evaluation/utils.py
@@ -44,7 +44,7 @@ def quantize(
      data = np.array(data)
      # Perform quantization
      if not scale:
-         logger.warn("WARNING: scale value is 0. 1e-7 will be used instead.")
+         logger.warning("WARNING: scale value is 0. 1e-7 will be used instead.")
          scale = 1e-7
      rescaled = np.round(data / scale) + zero_point
      # Clamp the values
tico/quantization/passes/__init__.py
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
tico/quantization/public_interface.py
@@ -13,25 +13,17 @@
  # limitations under the License.

  import copy
- from typing import Any, Dict, Optional, Type
+ from typing import Any, Dict, Optional

  import torch

- from tico.experimental.quantization.algorithm.gptq.quantizer import GPTQQuantizer
- from tico.experimental.quantization.algorithm.pt2e.quantizer import PT2EQuantizer
- from tico.experimental.quantization.algorithm.smoothquant.quantizer import (
-     SmoothQuantQuantizer,
- )
- from tico.experimental.quantization.config import BaseConfig
- from tico.experimental.quantization.quantizer import BaseQuantizer
+ from tico.quantization.algorithm.gptq.quantizer import GPTQQuantizer
+ from tico.quantization.algorithm.pt2e.quantizer import PT2EQuantizer
+ from tico.quantization.config.base import BaseConfig
+ from tico.quantization.quantizer import BaseQuantizer
+ from tico.quantization.quantizer_registry import get_quantizer


- config_to_quantizer: Dict[str, Type[BaseQuantizer]] = {
-     "pt2e": PT2EQuantizer,
-     "gptq": GPTQQuantizer,
-     "smooth_quant": SmoothQuantQuantizer,
- }
-
  QUANTIZER_ATTRIBUTE_NAME = "tico_quantizer"


@@ -40,7 +32,7 @@ def prepare(
      quant_config: BaseConfig,
      args: Optional[Any] = None,
      kwargs: Optional[Dict[str, Any]] = None,
-     inplace: Optional[bool] = False,
+     inplace: Optional[bool] = True,
  ):
      """
      Prepare the model for quantization using the provided configuration.
@@ -61,21 +53,22 @@ def prepare(
      """
      if hasattr(model, QUANTIZER_ATTRIBUTE_NAME):
          raise RuntimeError("prepare() already has been called.")
-     if quant_config.name == "pt2e" and inplace:
+     quantizer = get_quantizer(quant_config)
+
+     if isinstance(quantizer, PT2EQuantizer) and inplace:
          raise RuntimeError(
              "In-place is not supported for PT2E quantization due to limitation in the underlying Torch APIs. Please set 'inplace=False' to proceed."
          )

      model = model if inplace else copy.deepcopy(model)

-     quantizer = config_to_quantizer[quant_config.name](quant_config)
      model = quantizer.prepare(model, args, kwargs)
      setattr(model, QUANTIZER_ATTRIBUTE_NAME, quantizer)

      return model


- def convert(model, inplace: Optional[bool] = False):
+ def convert(model, inplace: Optional[bool] = True):
      """
      Convert the prepared model to a quantized model using the provided configuration.
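
For reference, the end-to-end flow with the new defaults looks roughly like this (a sketch modeled on the compare_ppl.py example later in this diff; the calibration loop is hypothetical):

    import torch
    from tico.quantization import convert, prepare
    from tico.quantization.config.ptq import PTQConfig

    model = ...  # an eval-mode torch.nn.Module
    prepare(model, PTQConfig())          # inplace now defaults to True
    with torch.no_grad():
        for batch in calib_batches:      # hypothetical calibration data
            model(batch)
    convert(model)                       # freezes (scale, zero-point)

PT2E still requires inplace=False, which prepare() now enforces after resolving the quantizer through get_quantizer() instead of matching on the config's name string.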
tico/quantization/quantizer.py
@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional

  import torch

- from tico.experimental.quantization.config import BaseConfig
+ from tico.quantization.config.base import BaseConfig


  class BaseQuantizer(ABC):
tico/quantization/quantizer_registry.py
@@ -0,0 +1,73 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import importlib
+ from typing import Dict, Optional, Type, TypeVar
+
+ from tico.quantization.config.base import BaseConfig
+ from tico.quantization.quantizer import BaseQuantizer
+
+ TQ = TypeVar("TQ", bound=BaseQuantizer)
+
+ # Mapping: Config type -> Quantizer type
+ _REGISTRY: Dict[Type[BaseConfig], Type[BaseQuantizer]] = {}
+
+
+ def register_quantizer(config_cls: Type[BaseConfig]):
+     """
+     Decorator to register a quantizer for a given config class.
+     Usage:
+         @register_quantizer(GPTQConfig)
+         class GPTQQuantizer(BaseQuantizer): ...
+     """
+
+     def wrapper(quantizer_cls: Type[TQ]) -> Type[TQ]:
+         _REGISTRY[config_cls] = quantizer_cls
+         return quantizer_cls
+
+     return wrapper
+
+
+ def _lookup(cfg: BaseConfig) -> Optional[Type[BaseQuantizer]]:
+     """Return a quantizer class only if the exact config type is registered."""
+     return _REGISTRY.get(type(cfg))
+
+
+ def get_quantizer(cfg: BaseConfig) -> BaseQuantizer:
+     """Factory to return a quantizer instance for the given config."""
+     qcls = _lookup(cfg)
+     if qcls is not None:
+         return qcls(cfg)
+
+     # Lazy import by naming convention
+     name = getattr(cfg, "name", None)
+     if name:
+         if name == "ptq":
+             importlib.import_module("tico.quantization.wrapq.quantizer")
+         else:
+             try:
+                 importlib.import_module(f"tico.quantization.algorithm.{name}.quantizer")
+             except Exception as e:
+                 raise RuntimeError(
+                     f"Failed to import quantizer module for config name='{name}': {e}"
+                 )
+
+     qcls = _lookup(cfg)
+     if qcls is not None:
+         return qcls(cfg)
+
+     raise RuntimeError(
+         f"No quantizer registered for config type {type(cfg).__name__} "
+         f"(name='{getattr(cfg, 'name', None)}')."
+     )
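
A short sketch of how the registry resolves a quantizer (illustrative; the GPTQConfig import path is taken from the file list above, and a default constructor is assumed):

    from tico.quantization.config.gptq import GPTQConfig
    from tico.quantization.quantizer_registry import get_quantizer

    cfg = GPTQConfig()
    # Exact config-type lookup first; on a miss, the name-based fallback lazily
    # imports tico.quantization.algorithm.gptq.quantizer, whose module-level
    # @register_quantizer(GPTQConfig) decorator populates _REGISTRY, then retries.
    quantizer = get_quantizer(cfg)

The design keeps public_interface.py free of per-algorithm imports: algorithms register themselves when their module is imported, either eagerly or via the naming-convention fallback.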
tico/quantization/wrapq/__init__.py
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
tico/quantization/wrapq/dtypes.py
@@ -0,0 +1,70 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass
+
+
+ @dataclass(frozen=True)
+ class DType:
+     """
+     Self-contained integer dtypes for quantization.
+
+     A DType is just an immutable value-object with two fields:
+       - bits
+       - signed
+
+     Common presets (INT8, UINT4, ...) are provided as constants for convenience.
+     """
+
+     bits: int  # pylint: disable=used-before-assignment
+     signed: bool = False  # False -> unsigned
+
+     @property
+     def qmin(self) -> int:
+         assert self.bits is not None
+         if self.signed:
+             return -(1 << (self.bits - 1))
+         return 0
+
+     @property
+     def qmax(self) -> int:
+         assert self.bits is not None
+         if self.signed:
+             return (1 << (self.bits - 1)) - 1
+         return (1 << self.bits) - 1
+
+     def __str__(self) -> str:
+         prefix = "int" if self.signed else "uint"
+         return f"{prefix}{self.bits}"
+
+     # ────────────────────────────────
+     # Factory helpers
+     # ────────────────────────────────
+     @staticmethod
+     def int(bits: int):  # type: ignore[valid-type]
+         return DType(bits, signed=True)
+
+     @staticmethod
+     def uint(bits: int):  # type: ignore[valid-type]
+         return DType(bits, signed=False)
+
+
+ # ---------------------------------------------------------------------
+ # Convenient canned versions
+ # ---------------------------------------------------------------------
+ UINT4 = DType.uint(4)
+ INT4 = DType.int(4)
+ INT8 = DType.int(8)
+ UINT8 = DType.uint(8)
+ INT16 = DType.int(16)
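
The qmin/qmax properties follow the usual two's-complement ranges, so the presets behave as a quick check against the code above shows (illustrative, not part of the diff):

    from tico.quantization.wrapq.dtypes import DType, INT8, UINT4

    assert (INT8.qmin, INT8.qmax) == (-128, 127)   # signed: -(1 << 7) .. (1 << 7) - 1
    assert (UINT4.qmin, UINT4.qmax) == (0, 15)     # unsigned: 0 .. (1 << 4) - 1
    assert str(DType.int(16)) == "int16"           # same value-object as the INT16 preset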
tico/quantization/wrapq/examples/__init__.py
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
tico/quantization/wrapq/examples/compare_ppl.py
@@ -0,0 +1,230 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # =============================================================================
+ # QUICK PTQ WORKFLOW (OPTIONAL FP32 BASELINE)
+ # -----------------------------------------------------------------------------
+ # Use --mode to choose between:
+ #   • FP32 perplexity measurement only, OR
+ #   • Full post-training UINT-8 flow (wrap → calibrate → eval).
+ # =============================================================================
+
+ import argparse
+ import sys
+
+ import torch
+ import tqdm
+ from datasets import load_dataset
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ from tico.quantization import convert, prepare
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.wrapq.utils.metrics import perplexity
+
+ # Token-budget presets for activation calibration
+ TOKENS: dict[str, int] = {
+     # Smoke test (<1 min turnaround on CPU/GPU)
+     "debug": 2_000,  # ≈16 × 128-seq batches
+     # Good default for 1-7B models (≲3 % ppl delta)
+     "baseline": 50_000,
+     # Production / 4-bit observer smoothing
+     "production": 200_000,
+ }
+
+ DTYPE_MAP = {
+     "float32": torch.float32,
+     "bfloat16": torch.bfloat16,
+     "float16": torch.float16,
+ }
+
+ # Hardcoded dataset settings
+ DATASET_NAME = "wikitext"
+ DATASET_CONFIG = "wikitext-2-raw-v1"
+ TRAIN_SPLIT = "train"
+ TEST_SPLIT = "test"
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Quick PTQ example (FP or UINT8)")
+     parser.add_argument(
+         "--mode",
+         choices=["fp", "uint8"],
+         default="fp",
+         help="Choose FP baseline only or full UINT8 PTQ path.",
+     )
+     parser.add_argument(
+         "--model", type=str, required=True, help="HF repo name or local path."
+     )
+     parser.add_argument(
+         "--device",
+         type=str,
+         default="cuda" if torch.cuda.is_available() else "cpu",
+         help="Device to run on (cuda|cpu).",
+     )
+     parser.add_argument(
+         "--dtype",
+         choices=list(DTYPE_MAP.keys()),
+         default="float32",
+         help="Model dtype for load.",
+     )
+     parser.add_argument(
+         "--stride", type=int, default=512, help="Sliding-window stride for perplexity."
+     )
+     parser.add_argument("--seed", type=int, default=42, help="Random seed.")
+     parser.add_argument(
+         "--trust-remote-code",
+         action="store_true",
+         help="Enable only if you trust the model repo code.",
+     )
+     parser.add_argument(
+         "--hf-token",
+         type=str,
+         default=None,
+         help="Optional HF token for gated/private models.",
+     )
+     parser.add_argument(
+         "--use-cache",
+         dest="use_cache",
+         action="store_true",
+         default=False,
+         help="Use model KV cache if enabled (off by default).",
+     )
+     parser.add_argument(
+         "--no-tqdm", action="store_true", help="Disable tqdm progress bars."
+     )
+     # calib-preset default = debug
+     parser.add_argument(
+         "--calib-preset",
+         choices=list(TOKENS.keys()),
+         default="debug",
+         help="Calibration token budget preset.",
+     )
+
+     args = parser.parse_args()
+
+     # Basic setup
+     torch.manual_seed(args.seed)
+     device = torch.device(args.device)
+     dtype = DTYPE_MAP[args.dtype]
+
+     print("=== Config ===")
+     print(f"Mode          : {args.mode}")
+     print(f"Model         : {args.model}")
+     print(f"Device        : {device.type}")
+     print(f"DType         : {args.dtype}")
+     print(f"Stride        : {args.stride}")
+     print(f"Use HF cache? : {args.use_cache}")
+     print(
+         f"Calib preset  : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
+     )
+     print()
+
+     # -------------------------------------------------------------------------
+     # 1. Load model and tokenizer
+     # -------------------------------------------------------------------------
+     tokenizer = AutoTokenizer.from_pretrained(
+         args.model,
+         trust_remote_code=args.trust_remote_code,
+         token=args.hf_token,
+     )
+
+     model = (
+         AutoModelForCausalLM.from_pretrained(
+             args.model,
+             torch_dtype=dtype,
+             trust_remote_code=args.trust_remote_code,
+             token=args.hf_token,
+         )
+         .to(device)
+         .eval()
+     )
+
+     model.config.use_cache = args.use_cache
+
+     if args.mode == "fp":
+         fp_model = model
+     else:
+         # UINT8 PTQ path
+         uint8_model = model
+
+         CALIB_TOKENS = TOKENS[args.calib_preset]
+         print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
+
+         # ---------------------------------------------------------------------
+         # 2. Wrap every Transformer layer with PTQWrapper
+         # ---------------------------------------------------------------------
+         qcfg = PTQConfig()  # all-uint8 defaults
+         prepare(uint8_model, qcfg)
+
+         # ---------------------------------------------------------------------
+         # 3. Single-pass activation calibration
+         # ---------------------------------------------------------------------
+         print("Calibrating UINT-8 observers …")
+         calib_txt = " ".join(
+             load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)["text"]
+         )[:CALIB_TOKENS]
+         ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
+
+         # Run inference to collect ranges
+         iterator = range(0, ids.size(1) - 1, args.stride)
+         if not args.no_tqdm:
+             iterator = tqdm.tqdm(iterator, desc="Calibration")
+         with torch.no_grad():
+             for i in iterator:
+                 uint8_model(ids[:, i : i + args.stride])
+
+         # Freeze (scale, zero-point)
+         convert(uint8_model)
+
+     # -------------------------------------------------------------------------
+     # 4. Evaluate perplexity
+     # -------------------------------------------------------------------------
+     print("\nCalculating perplexities …")
+     test_ds = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT)
+     enc = tokenizer("\n\n".join(test_ds["text"]), return_tensors="pt")
+
+     if args.mode == "fp":
+         ppl_fp = perplexity(
+             fp_model,
+             enc,
+             args.device,
+             stride=args.stride,
+             show_progress=not args.no_tqdm,
+         )
+     else:
+         ppl_int8 = perplexity(
+             uint8_model,
+             enc,
+             args.device,
+             stride=args.stride,
+             show_progress=not args.no_tqdm,
+         )
+
+     # -------------------------------------------------------------------------
+     # 5. Report
+     # -------------------------------------------------------------------------
+     print("\n┌── Wikitext-2 test perplexity ─────────────")
+     if args.mode == "fp":
+         print(f"│ FP     : {ppl_fp:8.2f}")
+     else:
+         print(f"│ UINT-8 : {ppl_int8:8.2f}")
+     print("└───────────────────────────────────────────")
+
+
+ if __name__ == "__main__":
+     try:
+         main()
+     except Exception as e:
+         print(f"\n[Error] {e}", file=sys.stderr)
+         sys.exit(1)
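
Illustrative invocations of the example script above (the model name is a placeholder):

    # FP32 baseline perplexity only
    python compare_ppl.py --model <hf-repo-or-path> --mode fp

    # Full UINT8 path: wrap -> calibrate (debug token budget) -> convert -> evaluate
    python compare_ppl.py --model <hf-repo-or-path> --mode uint8 --calib-preset debug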