tico 0.1.0.dev250924__py3-none-any.whl → 0.1.0.dev251109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/quantization/__init__.py +6 -0
- tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +161 -0
- tico/quantization/algorithm/fpi_gptq/quantizer.py +179 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +12 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +4 -4
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +6 -10
- tico/quantization/config/fpi_gptq.py +29 -0
- tico/{experimental/quantization → quantization}/config/gptq.py +1 -1
- tico/{experimental/quantization → quantization}/config/pt2e.py +1 -1
- tico/{experimental/quantization/ptq/quant_config.py → quantization/config/ptq.py} +18 -10
- tico/{experimental/quantization → quantization}/config/smoothquant.py +1 -1
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +6 -12
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +1 -3
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/public_interface.py +7 -7
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/{experimental/quantization → quantization}/quantizer_registry.py +11 -10
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/compare_ppl.py +8 -19
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/debug_quant_outputs.py +9 -24
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_linear.py +11 -10
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_attn.py +10 -12
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_decoder_layer.py +10 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_mlp.py +13 -13
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_with_gptq.py +14 -35
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py +3 -3
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py +1 -1
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/mx.py +1 -1
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py +3 -5
- tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py +3 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/__init__.py +1 -1
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_decoder.py +6 -8
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_decoder_layer.py +6 -8
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_encoder.py +6 -8
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_encoder_layer.py +6 -8
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_mha.py +5 -7
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py +5 -7
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py +8 -12
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py +5 -7
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py +6 -7
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_linear.py +7 -8
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_silu.py +8 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/ptq_wrapper.py +4 -6
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_elementwise.py +55 -17
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_module_base.py +10 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/registry.py +17 -16
- tico/utils/convert.py +9 -14
- {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251109.dist-info}/METADATA +48 -2
- {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251109.dist-info}/RECORD +113 -108
- tico/experimental/quantization/__init__.py +0 -6
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/gptq → quantization/algorithm/fpi_gptq}/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e → quantization/algorithm/gptq}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e/annotation → quantization/algorithm/pt2e}/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e/transformation → quantization/algorithm/pt2e/annotation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/smoothquant → quantization/algorithm/pt2e/transformation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization/config → quantization/algorithm/smoothquant}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/smooth_quant.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/config/base.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/metric.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/passes}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- /tico/{experimental/quantization/ptq/examples → quantization/wrapq}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/dtypes.py +0 -0
- /tico/{experimental/quantization/ptq/observers → quantization/wrapq/examples}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/mode.py +0 -0
- /tico/{experimental/quantization/ptq/utils → quantization/wrapq/observers}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/qscheme.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers → quantization/wrapq/utils}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/utils/reduce_utils.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers/llama → quantization/wrapq/wrappers}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/decoder_export_single_step.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers/nn → quantization/wrapq/wrappers/llama}/__init__.py +0 -0
- {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251109.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251109.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251109.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251109.dist-info}/top_level.txt +0 -0
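Taken together, the rename list boils down to two moves: the quantization stack is promoted out of `tico.experimental`, and the `ptq` subpackage becomes `wrapq` (with `ptq/quant_config.py` resurfacing as `quantization/config/ptq.py`). A minimal before/after sketch of the import migration a downstream user would apply; the paths come from the rename list above, and the `QuantConfig` → `PTQConfig` class rename is visible in the example diff below:

```python
# Before (0.1.0.dev250924):
#   from tico.experimental.quantization import convert, prepare
#   from tico.experimental.quantization.ptq.quant_config import QuantConfig
#   from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper

# After (0.1.0.dev251109):
from tico.quantization import convert, prepare
from tico.quantization.config.ptq import PTQConfig  # was ptq.quant_config.QuantConfig
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
```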
tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_with_gptq.py
RENAMED

```diff
@@ -33,16 +33,14 @@ import tqdm
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from tico.experimental.quantization import convert, prepare
-from tico.experimental.quantization.config.gptq import GPTQConfig
-from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.utils.introspection import build_fqn_map
-from tico.experimental.quantization.ptq.utils.metrics import perplexity
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
+from tico.quantization import convert, prepare
+from tico.quantization.config.gptq import GPTQConfig
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
+from tico.quantization.wrapq.utils.introspection import build_fqn_map
+from tico.quantization.wrapq.utils.metrics import perplexity
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
 
 
 # Token-budget presets for activation calibration

@@ -215,22 +213,8 @@ def main():
     # 4. Wrap every layer with PTQWrapper (activation UINT-8)
     # -------------------------------------------------------------------------
     print("Wrapping layers with PTQWrapper …")
-    layers = q_m.model.layers
-    if not isinstance(layers, (list, torch.nn.ModuleList)):
-        raise TypeError(f"'model.layers' must be list/ModuleList, got {type(layers)}")
-
-    qcfg = QuantConfig()  # default: per-tensor UINT8
-    wrapped = torch.nn.ModuleList()
-    for idx, fp_layer in enumerate(layers):
-        layer_cfg = qcfg.child(f"layer{idx}")
-        wrapped.append(
-            PTQWrapper(
-                fp_layer,
-                qcfg=layer_cfg,
-                fp_name=m_to_fqn.get(fp_layer),
-            )
-        )
-    q_m.model.layers = wrapped
+    qcfg = PTQConfig()  # default: per-tensor UINT8
+    prepare(q_m, qcfg)
 
     # -------------------------------------------------------------------------
     # 5. Single-pass activation calibration

@@ -242,11 +226,7 @@ def main():
     calib_txt = " ".join(dataset_train["text"])[:CALIB_TOKENS]
     train_ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
 
-    # (a) Enable calibration mode on every layer
-    for l in q_m.model.layers:
-        l.enable_calibration()
-
-    # (b) Overwrite weight observers with GPTQ statistics
+    # Overwrite weight observers with GPTQ statistics
     if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict):
         inject_gptq_qparams(q_m, q_m.quantizers)
     else:

@@ -254,7 +234,7 @@ def main():
         "[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection."
     )
 
-    # (c) Forward passes to collect activation ranges
+    # Forward passes to collect activation ranges
     iterator = range(0, train_ids.size(1) - 1, args.stride)
     if not args.no_tqdm:
         iterator = tqdm.tqdm(iterator, desc="Act-calibration")

@@ -262,9 +242,8 @@
     for i in iterator:
         q_m(train_ids[:, i : i + args.stride])
 
-    # (d) Freeze all Q-params (scale, zero-point)
-    for l in q_m.model.layers:
-        l.freeze_qparams()
+    # Freeze all Q-params (scale, zero-point)
+    convert(q_m)
 
     # -------------------------------------------------------------------------
     # 6. Evaluate perplexity on Wikitext-2
```
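Net effect of the five hunks above: the example no longer wraps layers, toggles calibration, and freezes qparams by hand; the public `prepare`/`convert` interface drives the whole flow. A condensed sketch of the new sequence, reusing the example's own variables (`q_m`, `train_ids`, `args`):

```python
from tico.quantization import convert, prepare
from tico.quantization.config.ptq import PTQConfig

qcfg = PTQConfig()  # default: per-tensor UINT8
prepare(q_m, qcfg)  # wraps supported modules and enters calibration mode

# Forward passes collect activation ranges (GPTQ weight qparams were injected above).
for i in range(0, train_ids.size(1) - 1, args.stride):
    q_m(train_ids[:, i : i + args.stride])

convert(q_m)  # freezes scale/zero-point (QUANT mode)
```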
tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py
RENAMED

```diff
@@ -17,9 +17,9 @@ from typing import Optional, Tuple
 
 import torch
 
-from tico.experimental.quantization.ptq.dtypes import DType, UINT8
-from tico.experimental.quantization.ptq.observers.base import ObserverBase
-from tico.experimental.quantization.ptq.qscheme import QScheme
+from tico.quantization.wrapq.dtypes import DType, UINT8
+from tico.quantization.wrapq.observers.base import ObserverBase
+from tico.quantization.wrapq.qscheme import QScheme
 
 
 class AffineObserverBase(ObserverBase):
```
tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py
RENAMED

```diff
@@ -17,8 +17,8 @@ from typing import Optional, Tuple
 
 import torch
 
-from tico.experimental.quantization.ptq.dtypes import DType, UINT8
-from tico.experimental.quantization.ptq.qscheme import QScheme
+from tico.quantization.wrapq.dtypes import DType, UINT8
+from tico.quantization.wrapq.qscheme import QScheme
 
 
 class ObserverBase(ABC):
```
tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py
RENAMED

```diff
@@ -14,8 +14,8 @@
 
 import torch
 
-from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
-from tico.experimental.quantization.ptq.utils.reduce_utils import channelwise_minmax
+from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
+from tico.quantization.wrapq.utils.reduce_utils import channelwise_minmax
 
 
 class EMAObserver(AffineObserverBase):
```
tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py
RENAMED

```diff
@@ -24,7 +24,7 @@ performing any statistics gathering or fake-quantization.
 """
 import torch
 
-from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
+from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
 
 
 class IdentityObserver(AffineObserverBase):
```
tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py
RENAMED

```diff
@@ -14,8 +14,8 @@
 
 import torch
 
-from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
-from tico.experimental.quantization.ptq.utils.reduce_utils import channelwise_minmax
+from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
+from tico.quantization.wrapq.utils.reduce_utils import channelwise_minmax
 
 
 class MinMaxObserver(AffineObserverBase):
```
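The five observer hunks in this block are import-path updates only; the statistics logic is untouched. For orientation: every `AffineObserverBase` subclass (MinMax, EMA, Identity) ultimately turns its tracked min/max into an affine scale/zero-point pair. The helper below is a generic sketch of that standard math, not the tico API:

```python
import torch


def affine_qparams(lo: torch.Tensor, hi: torch.Tensor, qmin: int = 0, qmax: int = 255):
    """Affine (asymmetric) qparams from an observed min/max, UINT8-style by default."""
    lo = torch.minimum(lo, torch.zeros_like(lo))  # quantization range must contain 0
    hi = torch.maximum(hi, torch.zeros_like(hi))
    scale = (hi - lo) / float(qmax - qmin)
    scale = torch.clamp(scale, min=torch.finfo(torch.float32).eps)
    zero_point = torch.clamp((qmin - lo / scale).round(), qmin, qmax)
    return scale, zero_point
```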
tico/quantization/wrapq/quantizer.py
ADDED

```diff
@@ -0,0 +1,179 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn as nn
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.quantizer import BaseQuantizer
+from tico.quantization.quantizer_registry import register_quantizer
+
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+
+
+@register_quantizer(PTQConfig)
+class PTQQuantizer(BaseQuantizer):
+    """
+    Post-Training Quantization (PTQ) quantizer integrated with the public interface.
+
+    Features
+    --------
+    • Automatically wraps quantizable modules using PTQWrapper.
+    • Supports leaf-level (single-module) quantization (e.g., prepare(model.fc, PTQConfig())).
+    • Enforces strict wrapping if `strict_wrap=True`: raises NotImplementedError if
+      no quantizable module was found at any boundary.
+    • If `strict_wrap=False`, unquantizable modules are silently skipped.
+    """
+
+    def __init__(self, config: PTQConfig):
+        super().__init__(config)
+        self.qcfg: PTQConfig = config
+        self.strict_wrap: bool = bool(getattr(config, "strict_wrap", True))
+
+    @torch.no_grad()
+    def prepare(
+        self,
+        model: torch.nn.Module,
+        args: Optional[Any] = None,
+        kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        # Wrap the tree (or single module) according to strictness policy
+        model = self._wrap_supported(model, self.qcfg)
+
+        # Switch all quant modules into calibration mode
+        if isinstance(model, QuantModuleBase):
+            model.enable_calibration()
+        for m in model.modules():
+            if isinstance(m, QuantModuleBase):
+                m.enable_calibration()
+        return model
+
+    @torch.no_grad()
+    def convert(self, model):
+        # Freeze qparams across the tree (QUANT mode)
+        if isinstance(model, QuantModuleBase):
+            model.freeze_qparams()
+        for m in model.modules():
+            if isinstance(m, QuantModuleBase):
+                m.freeze_qparams()
+        return model
+
+    def _wrap_supported(
+        self,
+        root: nn.Module,
+        qcfg: PTQConfig,
+    ) -> nn.Module:
+        """
+        Recursively attempt to wrap boundaries. Strictness is applied at every boundary.
+        """
+        assert not isinstance(root, QuantModuleBase), "The module is already wrapped."
+
+        # Case A: HuggingFace-style transformers: model.model.layers
+        lm = getattr(root, "model", None)
+        layers = getattr(lm, "layers", None) if isinstance(lm, nn.Module) else None
+        if isinstance(layers, nn.ModuleList):
+            new_list = nn.ModuleList()
+            for idx, layer in enumerate(layers):
+                child_scope = f"layer{idx}"
+                child_cfg = qcfg.child(child_scope)
+
+                # Enforce strictness at the child boundary
+                wrapped = self._try_wrap(
+                    layer,
+                    child_cfg,
+                    fp_name=child_scope,
+                    raise_on_fail=self.strict_wrap,
+                )
+                new_list.append(wrapped)
+            lm.layers = new_list  # type: ignore[union-attr]
+            return root
+
+        # Case B: Containers
+        if isinstance(root, (nn.Sequential, nn.ModuleList)):
+            for i, child in enumerate(list(root)):
+                name = str(i)
+                child_cfg = qcfg.child(name)
+
+                wrapped = self._try_wrap(
+                    child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+                )
+                if wrapped is child:
+                    assert not self.strict_wrap
+                    wrapped = self._wrap_supported(wrapped, child_cfg)
+                root[i] = wrapped  # type: ignore[index]
+
+        if isinstance(root, nn.ModuleDict):
+            for k, child in list(root.items()):
+                name = k
+                child_cfg = qcfg.child(name)
+
+                wrapped = self._try_wrap(
+                    child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+                )
+                if wrapped is child:
+                    assert not self.strict_wrap
+                    wrapped = self._wrap_supported(wrapped, child_cfg)
+                root[k] = wrapped  # type: ignore[index]
+
+        # Case C: Leaf node
+        root_name = getattr(root, "_get_name", lambda: None)()
+        wrapped = self._try_wrap(
+            root, qcfg, fp_name=root_name, raise_on_fail=self.strict_wrap
+        )
+        if wrapped is not root:
+            return wrapped
+
+        assert not self.strict_wrap
+        # Case D: Named children
+        for name, child in list(root.named_children()):
+            child_cfg = qcfg.child(name)
+
+            wrapped = self._try_wrap(
+                child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+            )
+            if wrapped is child:
+                assert not self.strict_wrap
+                wrapped = self._wrap_supported(wrapped, child_cfg)
+            setattr(root, name, wrapped)
+
+        return root
+
+    def _try_wrap(
+        self,
+        module: nn.Module,
+        qcfg_for_child: PTQConfig,
+        *,
+        fp_name: Optional[str],
+        raise_on_fail: bool,
+    ) -> nn.Module:
+        """
+        Attempt to wrap a boundary with PTQWrapper.
+
+        Behavior:
+          • If PTQWrapper succeeds: return wrapped module.
+          • If PTQWrapper raises NotImplementedError:
+              - raise_on_fail=True  -> re-raise (strict)
+              - raise_on_fail=False -> return original module (permissive)
+        """
+        try:
+            return PTQWrapper(module, qcfg=qcfg_for_child, fp_name=fp_name)
+        except NotImplementedError as e:
+            if raise_on_fail:
+                raise NotImplementedError(
+                    f"PTQQuantizer: no quantization wrapper for {type(module).__name__}"
+                ) from e
+            return module
```
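Because `PTQQuantizer` is registered against `PTQConfig`, the generic `prepare`/`convert` entry points dispatch to it automatically, including for plain containers (Case B above). A minimal usage sketch based on the docstring; it assumes the wrapper registry covers `nn.Linear` and `nn.SiLU`, which matches the `quant_linear.py`/`quant_silu.py` entries in the file list:

```python
import torch
import torch.nn as nn

from tico.quantization import convert, prepare
from tico.quantization.config.ptq import PTQConfig

model = nn.Sequential(nn.Linear(16, 32), nn.SiLU(), nn.Linear(32, 4))

# Dispatches to PTQQuantizer: wraps the children in place, enters calibration mode.
prepare(model, PTQConfig())

# Representative data lets the observers record activation ranges.
for _ in range(8):
    model(torch.randn(2, 16))

convert(model)  # freeze scale/zero-point
```

With the default `strict_wrap=True`, an unwrappable child raises `NotImplementedError`; setting `strict_wrap=False` on the config skips such modules silently instead.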
tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py
RENAMED

```diff
@@ -16,11 +16,9 @@ from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
 
-from tico.experimental.quantization.evaluation.metric import MetricCalculator
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
+from tico.quantization.evaluation.metric import MetricCalculator
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
 
 
 def build_fqn_map(root: torch.nn.Module) -> dict[torch.nn.Module, str]:
```
tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py
RENAMED

```diff
@@ -98,7 +98,8 @@ def perplexity(
 
         input_ids = input_ids_full[:, begin:end]
         target_ids = input_ids.clone()
-        target_ids[:, :-trg_len] = ignore_index
+        # mask previously-scored tokens
+        target_ids[:, :-trg_len] = ignore_index  # type: ignore[assignment]
 
         with torch.no_grad():
             outputs = model(input_ids, labels=target_ids)

@@ -106,7 +107,7 @@ def perplexity(
             neg_log_likelihood = outputs.loss
 
         # exact number of labels that contributed to loss
-        loss_tokens = (target_ids[:, 1:] != ignore_index).sum().item()
+        loss_tokens = (target_ids[:, 1:] != ignore_index).sum().item()  # type: ignore[attr-defined]
         nll_sum += neg_log_likelihood * loss_tokens
         n_tokens += int(loss_tokens)
 
```
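The two hunks above only add a comment and two mypy suppressions, but the line they document is the heart of strided perplexity: positions already scored by a previous window are masked so every token contributes to the NLL exactly once. A tiny worked sketch, assuming the conventional `ignore_index = -100`:

```python
import torch

ignore_index = -100  # assumed default, matching the usual HF labels convention
input_ids = torch.tensor([[11, 12, 13, 14, 15, 16]])
trg_len = 2  # only the last 2 positions of this window are newly scored

target_ids = input_ids.clone()
target_ids[:, :-trg_len] = ignore_index
# target_ids -> tensor([[-100, -100, -100, -100, 15, 16]])

# Labels shift by one inside a causal-LM loss, hence the [:, 1:] slice:
loss_tokens = (target_ids[:, 1:] != ignore_index).sum().item()
assert loss_tokens == 2
```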
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_decoder.py
RENAMED

```diff
@@ -25,12 +25,10 @@ import torch
 import torch.nn.functional as F
 from torch import nn, Tensor
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register("fairseq.models.transformer.TransformerDecoderBase")

@@ -53,7 +51,7 @@ class QuantFairseqDecoder(QuantModuleBase):
         self,
         fp_decoder: nn.Module,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
     ):
         super().__init__(qcfg, fp_name=fp_name)

@@ -116,7 +114,7 @@ class QuantFairseqDecoder(QuantModuleBase):
 
         prefix = _safe_prefix(fp_name)
 
-        # Prepare child QuantConfig namespaces: layers/<idx>
+        # Prepare child PTQConfig namespaces: layers/<idx>
         layers_qcfg = qcfg.child("layers") if qcfg else None
         for i, layer in enumerate(fp_layers):
             child_cfg = layers_qcfg.child(str(i)) if layers_qcfg else None
```
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_decoder_layer.py
RENAMED

```diff
@@ -23,15 +23,13 @@ from typing import Dict, Iterable, List, Optional, Tuple
 import torch
 from torch import nn, Tensor
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.fairseq.quant_mha import (
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.fairseq.quant_mha import (
     QuantFairseqMultiheadAttention,
 )
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register("fairseq.modules.transformer_layer.TransformerDecoderLayerBase")

@@ -55,7 +53,7 @@ class QuantFairseqDecoderLayer(QuantModuleBase):
         self,
         fp_layer: nn.Module,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
     ):
         super().__init__(qcfg, fp_name=fp_name)
```
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_encoder.py
RENAMED

```diff
@@ -25,12 +25,10 @@ import torch
 import torch.nn as nn
 from torch import Tensor
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register("fairseq.models.transformer.TransformerEncoderBase")

@@ -56,7 +54,7 @@ class QuantFairseqEncoder(QuantModuleBase):
         self,
         fp_encoder: nn.Module,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
         use_external_inputs: bool = False,  # export-mode flag
         return_type: Literal["tensor", "dict"] = "dict",

@@ -100,7 +98,7 @@ class QuantFairseqEncoder(QuantModuleBase):
         fp_layers = list(fp_encoder.layers)  # type: ignore[arg-type]
         self.layers = nn.ModuleList()
 
-        # Prepare child QuantConfig namespaces: layers/<idx>
+        # Prepare child PTQConfig namespaces: layers/<idx>
         layers_qcfg = qcfg.child("layers") if qcfg else None
         for i, layer in enumerate(fp_layers):
             child_cfg = layers_qcfg.child(str(i)) if layers_qcfg else None
```
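A convention these fairseq wrappers share (and the llama wrappers below repeat): a parent wrapper hands each sub-layer its own config namespace via `qcfg.child(...)`, so settings can be scoped per layer as `layers/0`, `layers/1`, and so on. The free function below condenses the pattern from the hunks above; `wrap_layers` and its `fp_name` value are illustrative, not library code:

```python
from typing import Optional

import torch.nn as nn

from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper


def wrap_layers(fp_layers: nn.ModuleList, qcfg: Optional[PTQConfig]) -> nn.ModuleList:
    # Each sub-layer gets its own namespace: layers/0, layers/1, ...
    layers_qcfg = qcfg.child("layers") if qcfg else None
    wrapped = nn.ModuleList()
    for i, layer in enumerate(fp_layers):
        child_cfg = layers_qcfg.child(str(i)) if layers_qcfg else None
        wrapped.append(PTQWrapper(layer, qcfg=child_cfg, fp_name=str(i)))
    return wrapped
```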
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_encoder_layer.py
RENAMED

```diff
@@ -23,15 +23,13 @@ from typing import Optional
 import torch.nn as nn
 from torch import Tensor
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.fairseq.quant_mha import (
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.fairseq.quant_mha import (
     QuantFairseqMultiheadAttention,
 )
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register("fairseq.modules.transformer_layer.TransformerEncoderLayerBase")

@@ -49,7 +47,7 @@ class QuantFairseqEncoderLayer(QuantModuleBase):
         self,
         fp_layer: nn.Module,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
     ):
         super().__init__(qcfg, fp_name=fp_name)
```
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_mha.py
RENAMED

```diff
@@ -24,12 +24,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register("fairseq.modules.multihead_attention.MultiheadAttention")

@@ -59,7 +57,7 @@ class QuantFairseqMultiheadAttention(QuantModuleBase):
         self,
         fp_attn: nn.Module,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
         max_seq: int = 4096,
         use_static_causal: bool = False,
```
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py
RENAMED

```diff
@@ -17,12 +17,10 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register(

@@ -34,7 +32,7 @@ class QuantLlamaAttention(QuantModuleBase):
         self,
         fp_attn: nn.Module,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
     ):
         super().__init__(qcfg, fp_name=fp_name)
```
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py
RENAMED

```diff
@@ -17,16 +17,12 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.llama.quant_attn import (
-    QuantLlamaAttention,
-)
-from tico.experimental.quantization.ptq.wrappers.llama.quant_mlp import QuantLlamaMLP
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.llama.quant_attn import QuantLlamaAttention
+from tico.quantization.wrapq.wrappers.llama.quant_mlp import QuantLlamaMLP
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register("transformers.models.llama.modeling_llama.LlamaDecoderLayer")

@@ -56,7 +52,7 @@ class QuantLlamaDecoderLayer(QuantModuleBase):
         self,
         fp_layer: nn.Module,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
         return_type: Optional[str] = None,
     ):

@@ -165,7 +161,7 @@ class QuantLlamaDecoderLayer(QuantModuleBase):
         # - If use_cache: always return (hidden_states, present_key_value)
         # - Else: return as configured (tuple/tensor) for HF compatibility
         if use_cache:
-            return hidden_states, present_key_value
+            return hidden_states, present_key_value  # type: ignore[return-value]
 
         if self.return_type == "tuple":
             return (hidden_states,)
```
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py
RENAMED

```diff
@@ -17,12 +17,10 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register("transformers.models.llama.modeling_llama.LlamaMLP")

@@ -31,7 +29,7 @@ class QuantLlamaMLP(QuantModuleBase):
         self,
         mlp_fp: nn.Module,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
     ):
         super().__init__(qcfg, fp_name=fp_name)
```
tico/quantization/wrapq/wrappers/nn/__init__.py
ADDED

```diff
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
```
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py
RENAMED

```diff
@@ -17,12 +17,11 @@ from typing import Iterable, Optional, Tuple
 import torch
 import torch.nn as nn
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import register
+from tico.quantization.config.ptq import PTQConfig
+
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import register
 
 
 @register(nn.LayerNorm)

@@ -46,7 +45,7 @@ class QuantLayerNorm(QuantModuleBase):
         self,
         fp: nn.LayerNorm,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None
     ):
         super().__init__(qcfg, fp_name=fp_name)
```
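A closing note on the two registry decorators seen throughout these wrappers: `@register(nn.LayerNorm)` keys a wrapper by a torch class that is always importable, while `@try_register("fairseq.modules...")` takes a dotted path so registration can degrade gracefully when an optional dependency (fairseq, transformers) is not installed. A hedged sketch of how such a dual registry is commonly built; the real implementation lives in `quantization/wrapq/wrappers/registry.py` and may differ:

```python
import importlib
from typing import Callable, Dict

WRAPPERS: Dict[type, type] = {}  # fp module class -> wrapper class


def register(fp_cls: type) -> Callable[[type], type]:
    """Key a wrapper by a class that is always importable (e.g. nn.LayerNorm)."""

    def deco(wrapper_cls: type) -> type:
        WRAPPERS[fp_cls] = wrapper_cls
        return wrapper_cls

    return deco


def try_register(dotted: str) -> Callable[[type], type]:
    """Resolve the target class from a dotted path; if the optional dependency
    is missing, skip registration instead of failing at import time."""

    def deco(wrapper_cls: type) -> type:
        module_path, _, cls_name = dotted.rpartition(".")
        try:
            fp_cls = getattr(importlib.import_module(module_path), cls_name)
        except (ImportError, AttributeError):
            return wrapper_cls  # dependency not installed: leave unregistered
        WRAPPERS[fp_cls] = wrapper_cls
        return wrapper_cls

    return deco
```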