tico 0.1.0.dev250803__py3-none-any.whl → 0.1.0.dev251106__py3-none-any.whl
- tico/__init__.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_mixed_type_args.py +2 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
- tico/passes/ops.py +0 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/public_interface.py +11 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +60 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +128 -0
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_mm.py +15 -132
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/utils/convert.py +20 -15
- tico/utils/dtype.py +22 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/signature.py +247 -0
- tico/utils/utils.py +50 -53
- tico/utils/validate_args_kwargs.py +37 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/METADATA +49 -2
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/RECORD +130 -73
- tico/experimental/quantization/__init__.py +0 -6
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/top_level.txt +0 -0

tico/quantization/wrapq/observers/minmax.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
+from tico.quantization.wrapq.utils.reduce_utils import channelwise_minmax
+
+
+class MinMaxObserver(AffineObserverBase):
+    """Plain min/max range tracker."""
+
+    @torch.no_grad()
+    def _update_stats(self, x: torch.Tensor) -> None:
+        """
+        Update running min/max with the incoming batch.
+
+        Per-tensor: use global min/max.
+        Per-channel: reduce all axes except the channel axis.
+        """
+        if self.channel_axis is None:
+            curr_min, curr_max = x.min(), x.max()
+        else:
+            curr_min, curr_max = channelwise_minmax(x, self.channel_axis)
+
+        # Broadcasting handles scalar-vs-vector cases
+        self.min_val = torch.minimum(self.min_val, curr_min)
+        self.max_val = torch.maximum(self.max_val, curr_max)
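
A minimal sketch (not part of the diff) of the running-range update that MinMaxObserver._update_stats performs, written against plain tensors because the observer's constructor lives in AffineObserverBase, outside this diff; the per-channel branch mirrors channelwise_minmax:

    import torch

    # Assumed stand-ins for the observer's running state (min_val/max_val are
    # presumably initialized by AffineObserverBase before the first update).
    min_val = torch.full((16,), float("inf"))
    max_val = torch.full((16,), float("-inf"))

    for _ in range(4):                      # four calibration batches
        x = torch.randn(8, 16, 32)          # (batch, channel, feature)
        # channel_axis=1: reduce every axis except the channel axis
        curr_min, curr_max = x.amin(dim=(0, 2)), x.amax(dim=(0, 2))
        min_val = torch.minimum(min_val, curr_min)
        max_val = torch.maximum(max_val, curr_max)

    print(min_val.shape, max_val.shape)     # torch.Size([16]) torch.Size([16])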

tico/quantization/wrapq/observers/mx.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from tico.quantization.wrapq.observers.base import ObserverBase
+from tico.utils.mx.mx_ops import quantize_mx
+
+
+class MXObserver(ObserverBase):
+    """MX (micro-scaling) observer: no min/max, no affine qparams."""
+
+    def __init__(
+        self,
+        *,
+        name: str,
+        elem_format: str = "int8",
+        axis: int = 0,
+        shared_exp_method: str = "max",
+        round: str = "nearest",
+        **base_kwargs,
+    ):
+        super().__init__(name=name, **base_kwargs)
+        self.elem_format = elem_format
+        self.axis = axis
+        self.shared_exp_method = shared_exp_method
+        self.round = round
+
+    def reset(self) -> None:
+        # No state to reset
+        return
+
+    @torch.no_grad()
+    def _update_stats(self, x: torch.Tensor) -> None:
+        # No stats required
+        return None
+
+    def compute_qparams(self):
+        # MX path does not produce affine qparams; keep interface contract.
+        return None
+
+    def fake_quant(self, x: torch.Tensor) -> torch.Tensor:
+        return quantize_mx(
+            x,
+            elem_format=self.elem_format,
+            axis=self.axis,
+            shared_exp_method=self.shared_exp_method,
+            round=self.round,
+        )
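
A hedged usage sketch (not part of the diff) for MXObserver; only the keyword arguments visible above are passed, since the argument set accepted by ObserverBase is not shown here, and the block format details are delegated to quantize_mx:

    import torch
    from tico.quantization.wrapq.observers.mx import MXObserver

    obs = MXObserver(name="proj.weight", elem_format="int8", axis=0)
    w = torch.randn(128, 64)
    w_fq = obs.fake_quant(w)   # micro-scaling fake-quant along axis 0 via quantize_mx
    print(w_fq.shape)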

tico/quantization/wrapq/qscheme.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import auto, Enum
+
+
+class QScheme(Enum):
+    # ───── Per-tensor ────────────
+    PER_TENSOR_ASYMM = auto()
+    PER_TENSOR_SYMM = auto()
+    # ───── Per-channel ───────────
+    PER_CHANNEL_ASYMM = auto()
+    PER_CHANNEL_SYMM = auto()
+
+    # helper
+    def is_per_channel(self) -> bool:
+        return self in {
+            QScheme.PER_CHANNEL_ASYMM,
+            QScheme.PER_CHANNEL_SYMM,
+        }
+
+    def is_symmetric(self) -> bool:
+        return self in {
+            QScheme.PER_TENSOR_SYMM,
+            QScheme.PER_CHANNEL_SYMM,
+        }
+
+    def __str__(self) -> str:
+        return self.name.lower()
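
A quick check (not part of the diff) of the helper predicates above, which split the four schemes along granularity and symmetry:

    from tico.quantization.wrapq.qscheme import QScheme

    s = QScheme.PER_CHANNEL_SYMM
    print(s.is_per_channel(), s.is_symmetric())   # True True
    print(str(QScheme.PER_TENSOR_ASYMM))          # per_tensor_asymm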

tico/quantization/wrapq/quantizer.py
@@ -0,0 +1,179 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn as nn
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.quantizer import BaseQuantizer
+from tico.quantization.quantizer_registry import register_quantizer
+
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+
+
+@register_quantizer(PTQConfig)
+class PTQQuantizer(BaseQuantizer):
+    """
+    Post-Training Quantization (PTQ) quantizer integrated with the public interface.
+
+    Features
+    --------
+    • Automatically wraps quantizable modules using PTQWrapper.
+    • Supports leaf-level (single-module) quantization (e.g., prepare(model.fc, PTQConfig())).
+    • Enforces strict wrapping if `strict_wrap=True`: raises NotImplementedError if
+      no quantizable module was found at any boundary.
+    • If `strict_wrap=False`, unquantizable modules are silently skipped.
+    """
+
+    def __init__(self, config: PTQConfig):
+        super().__init__(config)
+        self.qcfg: PTQConfig = config
+        self.strict_wrap: bool = bool(getattr(config, "strict_wrap", True))
+
+    @torch.no_grad()
+    def prepare(
+        self,
+        model: torch.nn.Module,
+        args: Optional[Any] = None,
+        kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        # Wrap the tree (or single module) according to strictness policy
+        model = self._wrap_supported(model, self.qcfg)
+
+        # Switch all quant modules into calibration mode
+        if isinstance(model, QuantModuleBase):
+            model.enable_calibration()
+        for m in model.modules():
+            if isinstance(m, QuantModuleBase):
+                m.enable_calibration()
+        return model
+
+    @torch.no_grad()
+    def convert(self, model):
+        # Freeze qparams across the tree (QUANT mode)
+        if isinstance(model, QuantModuleBase):
+            model.freeze_qparams()
+        for m in model.modules():
+            if isinstance(m, QuantModuleBase):
+                m.freeze_qparams()
+        return model
+
+    def _wrap_supported(
+        self,
+        root: nn.Module,
+        qcfg: PTQConfig,
+    ) -> nn.Module:
+        """
+        Recursively attempt to wrap boundaries. Strictness is applied at every boundary.
+        """
+        assert not isinstance(root, QuantModuleBase), "The module is already wrapped."
+
+        # Case A: HuggingFace-style transformers: model.model.layers
+        lm = getattr(root, "model", None)
+        layers = getattr(lm, "layers", None) if isinstance(lm, nn.Module) else None
+        if isinstance(layers, nn.ModuleList):
+            new_list = nn.ModuleList()
+            for idx, layer in enumerate(layers):
+                child_scope = f"layer{idx}"
+                child_cfg = qcfg.child(child_scope)
+
+                # Enforce strictness at the child boundary
+                wrapped = self._try_wrap(
+                    layer,
+                    child_cfg,
+                    fp_name=child_scope,
+                    raise_on_fail=self.strict_wrap,
+                )
+                new_list.append(wrapped)
+            lm.layers = new_list  # type: ignore[union-attr]
+            return root
+
+        # Case B: Containers
+        if isinstance(root, (nn.Sequential, nn.ModuleList)):
+            for i, child in enumerate(list(root)):
+                name = str(i)
+                child_cfg = qcfg.child(name)
+
+                wrapped = self._try_wrap(
+                    child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+                )
+                if wrapped is child:
+                    assert not self.strict_wrap
+                    wrapped = self._wrap_supported(wrapped, child_cfg)
+                root[i] = wrapped  # type: ignore[index]
+
+        if isinstance(root, nn.ModuleDict):
+            for k, child in list(root.items()):
+                name = k
+                child_cfg = qcfg.child(name)
+
+                wrapped = self._try_wrap(
+                    child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+                )
+                if wrapped is child:
+                    assert not self.strict_wrap
+                    wrapped = self._wrap_supported(wrapped, child_cfg)
+                root[k] = wrapped  # type: ignore[index]
+
+        # Case C: Leaf node
+        root_name = getattr(root, "_get_name", lambda: None)()
+        wrapped = self._try_wrap(
+            root, qcfg, fp_name=root_name, raise_on_fail=self.strict_wrap
+        )
+        if wrapped is not root:
+            return wrapped
+
+        assert not self.strict_wrap
+        # Case D: Named children
+        for name, child in list(root.named_children()):
+            child_cfg = qcfg.child(name)
+
+            wrapped = self._try_wrap(
+                child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+            )
+            if wrapped is child:
+                assert not self.strict_wrap
+                wrapped = self._wrap_supported(wrapped, child_cfg)
+            setattr(root, name, wrapped)
+
+        return root
+
+    def _try_wrap(
+        self,
+        module: nn.Module,
+        qcfg_for_child: PTQConfig,
+        *,
+        fp_name: Optional[str],
+        raise_on_fail: bool,
+    ) -> nn.Module:
+        """
+        Attempt to wrap a boundary with PTQWrapper.
+
+        Behavior:
+          • If PTQWrapper succeeds: return wrapped module.
+          • If PTQWrapper raises NotImplementedError:
+              - raise_on_fail=True  -> re-raise (strict)
+              - raise_on_fail=False -> return original module (permissive)
+        """
+        try:
+            return PTQWrapper(module, qcfg=qcfg_for_child, fp_name=fp_name)
+        except NotImplementedError as e:
+            if raise_on_fail:
+                raise NotImplementedError(
+                    f"PTQQuantizer: no quantization wrapper for {type(module).__name__}"
+                ) from e
+            return module
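
A minimal prepare → calibrate → convert sketch (not part of the diff) for the quantizer above, assuming `model` is a module with PTQWrapper support and `calib_loader` is a placeholder calibration loader:

    import torch
    from tico.quantization.config.ptq import PTQConfig
    from tico.quantization.wrapq.quantizer import PTQQuantizer

    quantizer = PTQQuantizer(PTQConfig())   # strict_wrap defaults to True
    model = quantizer.prepare(model)        # wrap modules, switch to CALIB mode

    with torch.no_grad():
        for batch in calib_loader:          # observers collect activation ranges
            model(batch)

    model = quantizer.convert(model)        # freeze qparams, switch to QUANT mode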

tico/quantization/wrapq/utils/__init__.py
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE

tico/quantization/wrapq/utils/introspection.py
@@ -0,0 +1,167 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Dict, List, Optional, Tuple
+
+import torch
+
+from tico.quantization.evaluation.metric import MetricCalculator
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+
+
+def build_fqn_map(root: torch.nn.Module) -> dict[torch.nn.Module, str]:
+    """
+    Return {module_object: full_qualified_name} without touching the modules.
+    """
+    return {m: n for n, m in root.named_modules()}
+
+
+def save_fp_outputs(
+    model: torch.nn.Module,
+) -> Tuple[List[torch.utils.hooks.RemovableHandle], Dict[str, torch.Tensor]]:
+    """
+    Register forward-hooks on every `QuantModuleBase` wrapper itself (not the
+    wrapped `module`) and cache its output while the wrapper runs in CALIB mode.
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        The model whose wrappers are already switched to CALIB mode
+        (`enable_calibration()` has been called).
+
+    Returns
+    -------
+    handles : list[RemovableHandle]
+        Hook handles; call `.remove()` on each one to detach the hooks.
+    cache : dict[str, torch.Tensor]
+        Mapping "wrapper-name → cached FP32 activation" captured from the first
+        forward pass. Keys default to `wrapper.fp_name`; if that attribute is
+        `None`, the `id(wrapper)` string is used instead.
+    """
+    cache: Dict[str, torch.Tensor] = {}
+    handles: List[torch.utils.hooks.RemovableHandle] = []
+
+    def _save(name: str):
+        def hook(_, __, out: torch.Tensor | Tuple):
+            if isinstance(out, tuple):
+                out = out[0]
+            assert isinstance(out, torch.Tensor)
+            cache[name] = out.detach()
+
+        return hook
+
+    for m in model.modules():
+        if isinstance(m, QuantModuleBase):
+            name = m.fp_name or str(id(m))
+            handles.append(m.register_forward_hook(_save(name)))
+
+    return handles, cache
+
+
+def compare_layer_outputs(
+    model: torch.nn.Module,
+    cache: Dict[str, torch.Tensor],
+    *,
+    metrics: Optional[List[str]] = None,
+    custom_metrics: Optional[Dict[str, Callable]] = None,
+    rtol: float = 1e-3,
+    atol: float = 1e-3,
+    collect: bool = False,
+):
+    """
+    Register forward-hooks on every `QuantModuleBase` wrapper to compare its
+    QUANT-mode output to the FP32 reference saved by `save_fp_outputs()`.
+
+    Each hook prints a per-layer diff report:
+
+        ✓ layer_name max=1.23e-02 mean=8.45e-04 (within tolerance)
+        ⚠️ layer_name max=3.07e+00 mean=5.12e-01 (exceeds tolerance)
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        The model whose wrappers are now in QUANT mode
+        (`freeze_qparams()` has been called).
+    cache : dict[str, torch.Tensor]
+        The reference activations captured during CALIB mode.
+    metrics
+        Metrics to compute. Defaults to `["diff"]`. Add `peir` to print PEIR.
+    custom_metrics
+        Optional user metric functions. Same signature as built-ins.
+    rtol, atol : float, optional
+        Relative / absolute tolerances used to flag large deviations
+        (similar to `torch.allclose` semantics).
+    collect : bool, optional
+        • False (default) → print one-line report per layer, return `None`
+        • True → suppress printing, return a nested dict
+          {layer_name -> {metric -> value}}
+
+    Returns
+    -------
+    handles
+        Hook handles; call `.remove()` once diffing is complete.
+    results
+        Only if *collect* is True.
+    """
+    metrics = metrics or ["diff"]
+    calc = MetricCalculator(custom_metrics)
+    handles: List[torch.utils.hooks.RemovableHandle] = []
+    results: Dict[
+        str, Dict[str, float]
+    ] = {}  # Dict[layer_name, Dict[metric_name, value]]
+
+    def _cmp(name: str):
+        ref = cache.get(name)
+
+        def hook(_, __, out):
+            if ref is None:
+                if not collect:
+                    print(f"[{name}] no cached reference")
+                return
+            if isinstance(out, tuple):
+                out = out[0]
+            assert isinstance(out, torch.Tensor)
+
+            # Compute all requested metrics
+            res = calc.compute([ref], [out], metrics)  # lists with length-1 tensors
+            res = {k: v[0] for k, v in res.items()}  # flatten
+
+            if collect:
+                results[name] = res  # type: ignore[assignment]
+                return
+
+            # Pretty print ------------------------------------------------ #
+            diff_val = res.get("diff") or res.get("max_abs_diff")
+            thresh = atol + rtol * ref.abs().max().item()
+            flag = "⚠️" if (diff_val is not None and diff_val > thresh) else "✓"  # type: ignore[operator]
+
+            pieces = [f"{flag} {name:45s}"]
+            for key, val in res.items():
+                pieces.append(f"{key}={val:<7.4}")
+            print(" ".join(pieces))
+
+        return hook
+
+    for m in model.modules():
+        if isinstance(m, PTQWrapper):
+            # skip the internal fp module inside the wrapper
+            continue
+        if isinstance(m, QuantModuleBase):
+            lname = m.fp_name or str(id(m))
+            handles.append(m.register_forward_hook(_cmp(lname)))
+
+    if collect:
+        return handles, results
+    return handles
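
A hedged sketch (not part of the diff) of how the two hook helpers above combine with the calibration flow; `model` and `sample_batch` are placeholders:

    from tico.quantization.wrapq.utils.introspection import (
        compare_layer_outputs,
        save_fp_outputs,
    )

    # 1) CALIB mode (after enable_calibration): cache FP32 reference activations
    fp_handles, fp_cache = save_fp_outputs(model)
    model(sample_batch)
    for h in fp_handles:
        h.remove()

    # 2) QUANT mode (after freeze_qparams): diff each wrapper against the cache
    cmp_handles = compare_layer_outputs(model, fp_cache, metrics=["diff", "peir"])
    model(sample_batch)            # prints one report line per wrapper
    for h in cmp_handles:
        h.remove()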

tico/quantization/wrapq/utils/metrics.py
@@ -0,0 +1,124 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+import tqdm
+
+
+def perplexity(
+    model: torch.nn.Module,
+    encodings: torch.Tensor,
+    device: torch.device | str,
+    *,
+    max_length: Optional[int] = None,
+    stride: int = 512,
+    ignore_index: int | None = -100,
+    show_progress: bool = True,
+) -> float:
+    """
+    Compute perplexity (PPL) using a "strided sliding-window"
+    evaluation strategy.
+
+    The function:
+      1. Splits the token sequence into overlapping windows of length
+         `max_length` (model context size).
+      2. Masks tokens that were already scored in previous windows
+         (`labels == -100`), so each token's negative log-likelihood (NLL)
+         is counted EXACTLY once.
+      3. Aggregates token-wise NLL to return corpus-level PPL.
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        Causal LM loaded in evaluation mode (`model.eval()`).
+    encodings : torch.Tensor | transformers.BatchEncoding
+        Tokenised corpus. If a `BatchEncoding` is passed, its
+        `.input_ids` field is used. Shape must be `(1, seq_len)`.
+    device : torch.device | str
+        CUDA or CPU device on which to run evaluation.
+    max_length : int, optional
+        Context window size. Defaults to `model.config.max_position_embeddings`.
+    stride : int, default = 512
+        Step size by which the sliding window advances. Must satisfy
+        `1 ≤ stride ≤ max_length`.
+    ignore_index : int, default = -100
+        Label value to ignore in loss computation. This should match
+        the `ignore_index` used by the model's internal
+        `CrossEntropyLoss`. For Hugging Face causal LMs, the
+        convention is `-100`.
+    show_progress : bool, default = True
+        If True, displays a tqdm progress bar while evaluating.
+
+    Returns
+    -------
+    float
+        Corpus-level perplexity.
+    """
+    # -------- input preparation -------- #
+    try:
+        # transformers.BatchEncoding has `input_ids`
+        input_ids_full = encodings.input_ids  # type: ignore[attr-defined]
+    except AttributeError:  # already a tensor
+        input_ids_full = encodings
+    assert isinstance(input_ids_full, torch.Tensor)
+    input_ids_full = input_ids_full.to(device)
+
+    if max_length is None:
+        assert hasattr(model, "config")
+        assert hasattr(model.config, "max_position_embeddings")
+        assert isinstance(model.config.max_position_embeddings, int)
+        max_length = model.config.max_position_embeddings
+    assert max_length is not None
+    assert (
+        1 <= stride <= max_length
+    ), f"stride ({stride}) must be in [1, max_length ({max_length})]"
+
+    seq_len = input_ids_full.size(1)
+    nll_sum = 0.0
+    n_tokens = 0
+    prev_end = 0
+
+    # -------- main loop -------- #
+    for begin in tqdm.trange(0, seq_len, stride, desc="PPL", disable=not show_progress):
+        end = min(begin + max_length, seq_len)
+        trg_len = end - prev_end  # fresh tokens in this window
+
+        input_ids = input_ids_full[:, begin:end]
+        target_ids = input_ids.clone()
+        # mask previously-scored tokens
+        target_ids[:, :-trg_len] = ignore_index  # type: ignore[assignment]
+
+        with torch.no_grad():
+            outputs = model(input_ids, labels=target_ids)
+            # loss is already averaged over non-masked labels
+            neg_log_likelihood = outputs.loss
+
+        # exact number of labels that contributed to loss
+        loss_tokens = (target_ids[:, 1:] != ignore_index).sum().item()  # type: ignore[attr-defined]
+        nll_sum += neg_log_likelihood * loss_tokens
+        n_tokens += int(loss_tokens)
+
+        prev_end = end
+        if end == seq_len:
+            break
+
+    avg_nll: float | torch.Tensor = nll_sum / n_tokens
+    if not isinstance(avg_nll, torch.Tensor):
+        avg_nll = torch.tensor(avg_nll)
+    assert isinstance(avg_nll, torch.Tensor)
+    ppl = torch.exp(avg_nll)
+
+    return ppl.item()
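
A hedged usage sketch (not part of the diff) for the strided-window perplexity helper, following the Hugging Face convention described in the docstring; the checkpoint and corpus are illustrative placeholders:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from tico.quantization.wrapq.utils.metrics import perplexity

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()

    text = "\n\n".join(["first document ...", "second document ..."])
    enc = tok(text, return_tensors="pt")    # BatchEncoding with (1, seq_len) input_ids

    ppl = perplexity(lm, enc, device, max_length=1024, stride=512)
    print(f"PPL = {ppl:.2f}")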

tico/quantization/wrapq/utils/reduce_utils.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+def channelwise_minmax(x: torch.Tensor, channel_axis: int):
+    """
+    Compute per-channel (min, max) by reducing all axes except `channel_axis`.
+    """
+    channel_axis = channel_axis % x.ndim  # handle negative indices safely
+    dims = tuple(d for d in range(x.ndim) if d != channel_axis)
+
+    return x.amin(dim=dims), x.amax(dim=dims)
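
A shape check (not part of the diff) for the reducer above, including a negative channel axis handled by the modulo on `x.ndim`:

    import torch
    from tico.quantization.wrapq.utils.reduce_utils import channelwise_minmax

    x = torch.randn(8, 16, 32)                        # (batch, channels, features)
    mn, mx = channelwise_minmax(x, channel_axis=1)
    print(mn.shape, mx.shape)                         # torch.Size([16]) torch.Size([16])

    mn_neg, _ = channelwise_minmax(x, channel_axis=-2)   # same axis via negative index
    assert torch.equal(mn, mn_neg)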

tico/quantization/wrapq/wrappers/__init__.py
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE