tico 0.1.0.dev250904__py3-none-any.whl → 0.1.0.dev251109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_mixed_type_args.py +2 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -3
- tico/passes/ops.py +0 -1
- tico/passes/remove_redundant_expand.py +3 -1
- tico/quantization/__init__.py +6 -0
- tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +161 -0
- tico/quantization/algorithm/fpi_gptq/quantizer.py +179 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +14 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/fpi_gptq.py +29 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/{experimental/quantization/ptq/quant_config.py → quantization/config/ptq.py} +18 -10
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -37
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +6 -12
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/public_interface.py +11 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_linear.py +11 -10
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_attn.py +10 -12
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_decoder_layer.py +10 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_mlp.py +13 -13
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py +3 -3
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py +1 -1
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/mx.py +1 -1
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py +3 -5
- tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py +3 -2
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py +58 -21
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py +21 -13
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py +5 -7
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py +6 -7
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_linear.py +7 -8
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_silu.py +8 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/ptq_wrapper.py +4 -6
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_elementwise.py +55 -17
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_module_base.py +10 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/registry.py +17 -10
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_mm.py +15 -132
- tico/utils/convert.py +20 -15
- tico/utils/register_custom_op.py +6 -4
- tico/utils/signature.py +7 -8
- tico/utils/validate_args_kwargs.py +12 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/METADATA +48 -2
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/RECORD +128 -108
- tico/experimental/quantization/__init__.py +0 -6
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/ptq/examples/compare_ppl.py +0 -121
- tico/experimental/quantization/ptq/examples/debug_quant_outputs.py +0 -129
- tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +0 -165
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/gptq → quantization/algorithm/fpi_gptq}/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e → quantization/algorithm/gptq}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e/annotation → quantization/algorithm/pt2e}/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e/transformation → quantization/algorithm/pt2e/annotation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/smoothquant → quantization/algorithm/pt2e/transformation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/algorithm/smoothquant}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/metric.py +0 -0
- /tico/{experimental/quantization/ptq/examples → quantization/passes}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- /tico/{experimental/quantization/ptq/observers → quantization/wrapq}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/dtypes.py +0 -0
- /tico/{experimental/quantization/ptq/utils → quantization/wrapq/examples}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/mode.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers → quantization/wrapq/observers}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/qscheme.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers/llama → quantization/wrapq/utils}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/utils/reduce_utils.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers/nn → quantization/wrapq/wrappers}/__init__.py +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/top_level.txt +0 -0
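The headline change in this release is a namespace move: everything that previously lived under tico/experimental/quantization now ships under tico/quantization, the ptq subpackage is renamed to wrapq, and the per-algorithm configs are split out into tico/quantization/config/. A minimal import-migration sketch, using only paths documented in the rename list above (illustrative only, not the package's full public API):

# Before (0.1.0.dev250904): old experimental namespace
#   from tico.experimental.quantization.ptq.observers.minmax import MinMaxObserver

# After (0.1.0.dev251109): new top-level namespace
from tico.quantization.wrapq.observers.minmax import MinMaxObserver
from tico.quantization.config.ptq import PTQConfig  # was experimental/quantization/ptq/quant_config.py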
tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py
@@ -17,9 +17,9 @@ from typing import Optional, Tuple

 import torch

-from tico.
-from tico.
-from tico.
+from tico.quantization.wrapq.dtypes import DType, UINT8
+from tico.quantization.wrapq.observers.base import ObserverBase
+from tico.quantization.wrapq.qscheme import QScheme


 class AffineObserverBase(ObserverBase):
tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py
@@ -17,8 +17,8 @@ from typing import Optional, Tuple

 import torch

-from tico.
-from tico.
+from tico.quantization.wrapq.dtypes import DType, UINT8
+from tico.quantization.wrapq.qscheme import QScheme


 class ObserverBase(ABC):
tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py
@@ -14,8 +14,8 @@

 import torch

-from tico.
-from tico.
+from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
+from tico.quantization.wrapq.utils.reduce_utils import channelwise_minmax


 class EMAObserver(AffineObserverBase):
tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py
@@ -24,7 +24,7 @@ performing any statistics gathering or fake-quantization.
 """
 import torch

-from tico.
+from tico.quantization.wrapq.observers.affine_base import AffineObserverBase


 class IdentityObserver(AffineObserverBase):
tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py
@@ -14,8 +14,8 @@

 import torch

-from tico.
-from tico.
+from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
+from tico.quantization.wrapq.utils.reduce_utils import channelwise_minmax


 class MinMaxObserver(AffineObserverBase):
tico/quantization/wrapq/quantizer.py (new file)
@@ -0,0 +1,179 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn as nn
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.quantizer import BaseQuantizer
+from tico.quantization.quantizer_registry import register_quantizer
+
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+
+
+@register_quantizer(PTQConfig)
+class PTQQuantizer(BaseQuantizer):
+    """
+    Post-Training Quantization (PTQ) quantizer integrated with the public interface.
+
+    Features
+    --------
+    • Automatically wraps quantizable modules using PTQWrapper.
+    • Supports leaf-level (single-module) quantization (e.g., prepare(model.fc, PTQConfig())).
+    • Enforces strict wrapping if `strict_wrap=True`: raises NotImplementedError if
+      no quantizable module was found at any boundary.
+    • If `strict_wrap=False`, unquantizable modules are silently skipped.
+    """
+
+    def __init__(self, config: PTQConfig):
+        super().__init__(config)
+        self.qcfg: PTQConfig = config
+        self.strict_wrap: bool = bool(getattr(config, "strict_wrap", True))
+
+    @torch.no_grad()
+    def prepare(
+        self,
+        model: torch.nn.Module,
+        args: Optional[Any] = None,
+        kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        # Wrap the tree (or single module) according to strictness policy
+        model = self._wrap_supported(model, self.qcfg)
+
+        # Switch all quant modules into calibration mode
+        if isinstance(model, QuantModuleBase):
+            model.enable_calibration()
+        for m in model.modules():
+            if isinstance(m, QuantModuleBase):
+                m.enable_calibration()
+        return model
+
+    @torch.no_grad()
+    def convert(self, model):
+        # Freeze qparams across the tree (QUANT mode)
+        if isinstance(model, QuantModuleBase):
+            model.freeze_qparams()
+        for m in model.modules():
+            if isinstance(m, QuantModuleBase):
+                m.freeze_qparams()
+        return model
+
+    def _wrap_supported(
+        self,
+        root: nn.Module,
+        qcfg: PTQConfig,
+    ) -> nn.Module:
+        """
+        Recursively attempt to wrap boundaries. Strictness is applied at every boundary.
+        """
+        assert not isinstance(root, QuantModuleBase), "The module is already wrapped."
+
+        # Case A: HuggingFace-style transformers: model.model.layers
+        lm = getattr(root, "model", None)
+        layers = getattr(lm, "layers", None) if isinstance(lm, nn.Module) else None
+        if isinstance(layers, nn.ModuleList):
+            new_list = nn.ModuleList()
+            for idx, layer in enumerate(layers):
+                child_scope = f"layer{idx}"
+                child_cfg = qcfg.child(child_scope)
+
+                # Enforce strictness at the child boundary
+                wrapped = self._try_wrap(
+                    layer,
+                    child_cfg,
+                    fp_name=child_scope,
+                    raise_on_fail=self.strict_wrap,
+                )
+                new_list.append(wrapped)
+            lm.layers = new_list  # type: ignore[union-attr]
+            return root
+
+        # Case B: Containers
+        if isinstance(root, (nn.Sequential, nn.ModuleList)):
+            for i, child in enumerate(list(root)):
+                name = str(i)
+                child_cfg = qcfg.child(name)
+
+                wrapped = self._try_wrap(
+                    child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+                )
+                if wrapped is child:
+                    assert not self.strict_wrap
+                    wrapped = self._wrap_supported(wrapped, child_cfg)
+                root[i] = wrapped  # type: ignore[index]
+
+        if isinstance(root, nn.ModuleDict):
+            for k, child in list(root.items()):
+                name = k
+                child_cfg = qcfg.child(name)
+
+                wrapped = self._try_wrap(
+                    child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+                )
+                if wrapped is child:
+                    assert not self.strict_wrap
+                    wrapped = self._wrap_supported(wrapped, child_cfg)
+                root[k] = wrapped  # type: ignore[index]
+
+        # Case C: Leaf node
+        root_name = getattr(root, "_get_name", lambda: None)()
+        wrapped = self._try_wrap(
+            root, qcfg, fp_name=root_name, raise_on_fail=self.strict_wrap
+        )
+        if wrapped is not root:
+            return wrapped
+
+        assert not self.strict_wrap
+        # Case D: Named children
+        for name, child in list(root.named_children()):
+            child_cfg = qcfg.child(name)
+
+            wrapped = self._try_wrap(
+                child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+            )
+            if wrapped is child:
+                assert not self.strict_wrap
+                wrapped = self._wrap_supported(wrapped, child_cfg)
+            setattr(root, name, wrapped)
+
+        return root
+
+    def _try_wrap(
+        self,
+        module: nn.Module,
+        qcfg_for_child: PTQConfig,
+        *,
+        fp_name: Optional[str],
+        raise_on_fail: bool,
+    ) -> nn.Module:
+        """
+        Attempt to wrap a boundary with PTQWrapper.
+
+        Behavior:
+          • If PTQWrapper succeeds: return wrapped module.
+          • If PTQWrapper raises NotImplementedError:
+              - raise_on_fail=True -> re-raise (strict)
+              - raise_on_fail=False -> return original module (permissive)
+        """
+        try:
+            return PTQWrapper(module, qcfg=qcfg_for_child, fp_name=fp_name)
+        except NotImplementedError as e:
+            if raise_on_fail:
+                raise NotImplementedError(
+                    f"PTQQuantizer: no quantization wrapper for {type(module).__name__}"
+                ) from e
+            return module
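For orientation, here is a minimal usage sketch of the PTQQuantizer added above, based on its prepare/convert methods and the leaf-level example in its docstring. The module path, the no-argument PTQConfig() constructor, and the calibration loop are assumptions for illustration, not code from this release:

import torch
import torch.nn as nn

from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.quantizer import PTQQuantizer

model = nn.Linear(512, 512).eval()     # leaf-level example, as in the docstring

quantizer = PTQQuantizer(PTQConfig())  # strict_wrap defaults to True
model = quantizer.prepare(model)       # wraps the module with PTQWrapper, enables calibration

with torch.no_grad():                  # feed a few calibration batches through the wrapper
    for _ in range(8):
        model(torch.randn(4, 512))

model = quantizer.convert(model)       # freezes qparams (QUANT mode)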
tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py
@@ -16,11 +16,9 @@ from typing import Callable, Dict, List, Optional, Tuple

 import torch

-from tico.
-from tico.
-from tico.
-    QuantModuleBase,
-)
+from tico.quantization.evaluation.metric import MetricCalculator
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase


 def build_fqn_map(root: torch.nn.Module) -> dict[torch.nn.Module, str]:
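Only the import block of introspection.py changes here. For context, a helper with the build_fqn_map signature shown above usually reduces to a named_modules() sweep; the following is a sketch of that idea, not the shipped implementation:

import torch

def build_fqn_map(root: torch.nn.Module) -> dict[torch.nn.Module, str]:
    # Map each submodule object to its fully-qualified name within `root`.
    return {module: name for name, module in root.named_modules()}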
tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py
@@ -98,7 +98,8 @@ def perplexity(

         input_ids = input_ids_full[:, begin:end]
         target_ids = input_ids.clone()
-
+        # mask previously-scored tokens
+        target_ids[:, :-trg_len] = ignore_index  # type: ignore[assignment]

         with torch.no_grad():
             outputs = model(input_ids, labels=target_ids)
@@ -106,7 +107,7 @@ def perplexity(
             neg_log_likelihood = outputs.loss

         # exact number of labels that contributed to loss
-        loss_tokens = (target_ids[:, 1:] != ignore_index).sum().item()
+        loss_tokens = (target_ids[:, 1:] != ignore_index).sum().item()  # type: ignore[attr-defined]
         nll_sum += neg_log_likelihood * loss_tokens
         n_tokens += int(loss_tokens)

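The added masking follows the standard strided-perplexity recipe: within each evaluation window only the last trg_len labels should contribute to the loss, so everything before them is set to ignore_index (commonly -100). A small self-contained illustration of the two changed lines, with made-up values:

import torch

ignore_index = -100
input_ids = torch.tensor([[11, 12, 13, 14, 15, 16]])  # one window of 6 tokens
trg_len = 2                                           # only the last 2 tokens are newly scored

target_ids = input_ids.clone()
target_ids[:, :-trg_len] = ignore_index               # mask previously-scored tokens
# target_ids is now [[-100, -100, -100, -100, 15, 16]]

# labels that actually contribute to the loss (the first position is dropped by the next-token shift)
loss_tokens = (target_ids[:, 1:] != ignore_index).sum().item()  # -> 2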
tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py (new file)
@@ -0,0 +1,234 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -----------------------------------------------------------------------------
+# This file includes modifications based on fairseq
+# (https://github.com/facebookresearch/fairseq), originally licensed under
+# the MIT License. See the LICENSE file in the fairseq repository for details.
+# -----------------------------------------------------------------------------
+
+"""
+Q) Why the name "SingleStep"?
+
+Fairseq's decoder already advances one token at a time during generation,
+but the default path is "stateful" and "shape-polymorphic": it owns and
+mutates K/V caches internally, prefix lengths and triangular masks grow with
+the step, and beam reordering updates hidden module state. That's friendly
+for eager execution, but hostile to `torch.export` and many accelerator
+backends.
+
+This export wrapper makes the per-token call truly "single-step" in the
+export sense: "stateless" and "fixed-shape" so every invocation has the
+exact same graph.
+
+Key invariants
+--------------
+• "Stateless": K/V caches come in as explicit inputs and go out as outputs.
+  The module does not store or mutate hidden state.
+• "Static shapes": Query is always [B, 1, C]; encoder features and masks
+  have fixed, predeclared sizes; K/V slots use fixed capacity (unused tail
+  is simply masked/ignored).
+• "External control": Step indexing, cache slot management (append/roll),
+  and beam reordering are handled outside the module.
+• "Prebuilt additive masks": Self-attention masks are provided by the
+  caller (0 for valid, large negative sentinel, e.g. -120, for masked),
+  avoiding data-dependent control flow.
+
+In short: still step-wise like fairseq, but restructured for export—no
+internal state, no data-dependent shapes, no dynamic control flow.
+"""
+
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+import tico
+
+# ----- 1) Export wrapper module -------------------------------------------
+class DecoderExportSingleStep(nn.Module):
+    """
+    Export-only single-step decoder module.
+
+    Inputs (example shapes; B=1, H=8, Dh=64, C=512, S=64, Tprev=63):
+      - prev_x: [B, 1, C] embedded decoder input for the current step
+      - enc_x: [S, B, C] encoder hidden states (fixed-length export input)
+      - enc_pad_additive: [B, 1, S] additive float key_padding_mask for enc-dec attn (0 for keep, -120 for pad)
+      - self_attn_mask: [B, 1, S] additive float mask for decoder self-attn at this step; pass zeros if unused
+      - prev_self_k_0..L-1: [B, H, Tprev, Dh] cached self-attn K per layer
+      - prev_self_v_0..L-1: [B, H, Tprev, Dh] cached self-attn V per layer
+
+    Outputs:
+      - x_out: [B, 1, C] new decoder features at the current step
+      - new_k_0..L-1: [H, B, Dh] per-layer new K (single-timestep; time dim squeezed)
+      - new_v_0..L-1: [H, B, Dh] per-layer new V (single-timestep; time dim squeezed)
+
+    Notes:
+      • We keep masks/additive semantics externally to avoid any mask-building inside the graph.
+      • We reshape the new K/V from [B,H,1,Dh] -> [H,B,Dh] to match the requested output spec (8,1,64).
+    """
+
+    def __init__(self, decoder: nn.Module):
+        super().__init__()
+        self.decoder = decoder
+        # Cache common meta for assertions
+        self.num_layers = len(getattr(decoder, "layers"))
+        # Infer heads/head_dim from the wrapped self_attn of layer 0
+        any_layer = getattr(decoder.layers[0], "wrapped", decoder.layers[0])  # type: ignore[index]
+        mha = getattr(any_layer, "self_attn", None)
+        assert mha is not None, "Decoder layer must expose self_attn"
+        self.num_heads = int(mha.num_heads)
+        self.head_dim = int(mha.head_dim)
+        # Embed dim (C)
+        self.embed_dim = int(getattr(decoder, "embed_dim"))
+
+    def forward(
+        self,
+        prev_x: torch.Tensor,  # [B,1,C]
+        enc_x: torch.Tensor,  # [S,B,C]
+        enc_pad_additive: torch.Tensor,  # [B,1,S]
+        *kv_args: torch.Tensor,  # prev_k_0..L-1, prev_v_0..L-1 (total 2L tensors)
+        self_attn_mask: torch.Tensor,  # [B,1,S] (or zeros)
+    ):
+        L = self.num_layers
+        H = self.num_heads
+        Dh = self.head_dim
+        B, one, C = prev_x.shape
+        S, B2, C2 = enc_x.shape
+        assert (
+            one == 1 and C == self.embed_dim and B == B2 and C2 == C
+        ), "Shape mismatch in prev_x/enc_x"
+        assert len(kv_args) == 2 * L, f"Expected {2*L} KV tensors, got {len(kv_args)}"
+
+        # Unpack previous self-attn caches
+        prev_k_list: List[torch.Tensor] = list()  # each [B,H,Tprev,Dh]
+        prev_v_list: List[torch.Tensor] = list()  # each [B,H,Tprev,Dh]
+        for i in range(L):
+            prev_k_list.append(kv_args[2 * i])
+            prev_v_list.append(kv_args[2 * i + 1])
+        for i in range(L):
+            assert (
+                prev_k_list[i].dim() == 4 and prev_v_list[i].dim() == 4
+            ), "KV must be [B,H,Tprev,Dh]"
+            assert (
+                prev_k_list[i].shape[0] == B
+                and prev_k_list[i].shape[1] == H
+                and prev_k_list[i].shape[3] == Dh
+            )
+
+        # Call decoder's external single-step path
+        # Returns:
+        #   x_step: [B,1,C]
+        #   newk/newv: lists of length L, each [B*H,1,Dh]
+        x_step, newk_list, newv_list = self.decoder.forward_external_step(  # type: ignore[operator]
+            prev_output_x=prev_x,
+            encoder_out_x=enc_x,
+            encoder_padding_mask=enc_pad_additive,
+            self_attn_mask=self_attn_mask,
+            prev_self_k_list=prev_k_list,
+            prev_self_v_list=prev_v_list,
+        )
+
+        out_tensors: List[torch.Tensor] = [
+            x_step
+        ]  # first output is the new decoder features
+        for i in range(L):
+            nk = newk_list[i]  # [B*H, Tnew, Dh]
+            nv = newv_list[i]  # [B*H, Tnew, Dh]
+            out_tensors.append(nk)
+            out_tensors.append(nv)
+
+        # Return tuple: (x_step, new_k_0, new_v_0, new_k_1, new_v_1, ..., new_k_{L-1}, new_v_{L-1})
+        return tuple(out_tensors)
+
+
+# ----- 2) Example inputs (B=1, S=64, H=8, Dh=64, C=512, L=4) ---------------
+def make_example_inputs(*, L=4, B=1, S=64, H=8, Dh=64, C=512, Tprev=63, device="cpu"):
+    """
+    Build example tensors that match the export I/O spec.
+    Shapes follow the request:
+      prev_x: [1,1,512]
+      enc_x: [64,1,512]
+      enc_pad_additive: [1,1,64] (additive float; zeros -> keep)
+      prev_k_i / prev_v_i (for i in 0..L-1): [1,8,63,64]
+      self_attn_mask: [1,1,64] (additive float; zeros -> keep)
+    """
+    g = torch.Generator(device=device).manual_seed(0)
+
+    prev_x = torch.randn(B, 1, C, device=device, dtype=torch.float32, generator=g)
+    enc_x = torch.randn(S, B, C, device=device, dtype=torch.float32, generator=g)
+
+    # Additive masks (0 for allowed, -120 for masked)
+    enc_pad_additive = torch.full((B, 1, S), float(-120), device=device)
+    self_attn_mask = torch.full((B, 1, S), float(-120), device=device)
+    enc_pad_additive[0, :27] = 0  # 27 is a random example.
+    self_attn_mask[0, :27] = 0  # 27 is a random example.
+
+    # Previous self-attn caches for each layer
+    prev_k_list = []
+    prev_v_list = []
+    for _ in range(L):
+        prev_k = torch.randn(
+            B, H, Tprev, Dh, device=device, dtype=torch.float32, generator=g
+        )
+        prev_v = torch.randn(
+            B, H, Tprev, Dh, device=device, dtype=torch.float32, generator=g
+        )
+        prev_k_list.append(prev_k)
+        prev_v_list.append(prev_v)
+
+    # Pack inputs as the export function will expect:
+    #   (prev_x, enc_x, enc_pad_additive, self_attn_mask, prev_k_0..L-1, prev_v_0..L-1)
+    example_args: Tuple[torch.Tensor, ...] = (
+        prev_x,
+        enc_x,
+        enc_pad_additive,
+        *prev_k_list,
+        *prev_v_list,
+    )
+    example_kwargs = {"self_attn_mask": self_attn_mask}
+    return example_args, example_kwargs
+
+
+# ----- 3) Export driver -----------------------------------------------------
+def export_decoder_single_step(translator, *, save_path="decoder_step_export.circle"):
+    """
+    Wrap the QuantFairseqDecoder into the export-friendly single-step module
+    and export with torch.export.export using example inputs.
+    """
+    # Grab the wrapped decoder
+    dec = translator.models[
+        0
+    ].decoder  # assumed QuantFairseqDecoder with forward_external_step
+    # Build export wrapper
+    wrapper = DecoderExportSingleStep(decoder=dec).eval()
+
+    # Example inputs (L inferred from wrapper/decoder)
+    L = wrapper.num_layers
+    H = wrapper.num_heads
+    Dh = wrapper.head_dim
+    C = wrapper.embed_dim
+    example_inputs, example_kwargs = make_example_inputs(L=L, H=H, Dh=Dh, C=C)
+
+    # Export circle (no dynamism assumed; shapes are fixed for export)
+    cm = tico.convert(
+        wrapper,
+        args=example_inputs,
+        kwargs=example_kwargs,
+        strict=True,  # fail if something cannot be captured
+    )
+
+    # Save .pte
+    cm.save(save_path)
+    print(f"Saved decoder single-step export to: {save_path}")
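Because the wrapper is stateless, it can be sanity-checked eagerly with the example tensors before calling tico.convert; the outputs come back as one flat tuple, x_step first, then the per-layer K/V pairs. A sketch, assuming dec is a QuantFairseqDecoder that exposes forward_external_step as in export_decoder_single_step above:

import torch

wrapper = DecoderExportSingleStep(decoder=dec).eval()
example_args, example_kwargs = make_example_inputs(
    L=wrapper.num_layers, H=wrapper.num_heads, Dh=wrapper.head_dim, C=wrapper.embed_dim
)

with torch.no_grad():
    outs = wrapper(*example_args, **example_kwargs)

x_step = outs[0]         # [B, 1, C] decoder features for this step
per_layer_kv = outs[1:]  # new_k_0, new_v_0, ..., new_k_{L-1}, new_v_{L-1}
assert len(per_layer_kv) == 2 * wrapper.num_layers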