tico 0.1.0.dev250803__py3-none-any.whl → 0.1.0.dev251106__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_mixed_type_args.py +2 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
- tico/passes/ops.py +0 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/public_interface.py +11 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +60 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +128 -0
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_mm.py +15 -132
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/utils/convert.py +20 -15
- tico/utils/dtype.py +22 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/signature.py +247 -0
- tico/utils/utils.py +50 -53
- tico/utils/validate_args_kwargs.py +37 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/METADATA +49 -2
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/RECORD +130 -73
- tico/experimental/quantization/__init__.py +0 -6
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/top_level.txt +0 -0
tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py
@@ -0,0 +1,176 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch
import torch.nn as nn

from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.wrappers.llama.quant_attn import QuantLlamaAttention
from tico.quantization.wrapq.wrappers.llama.quant_mlp import QuantLlamaMLP
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
from tico.quantization.wrapq.wrappers.registry import try_register


@try_register("transformers.models.llama.modeling_llama.LlamaDecoderLayer")
class QuantLlamaDecoderLayer(QuantModuleBase):
    """
    Quant-aware drop-in replacement for HF `LlamaDecoderLayer`.
    Signature and return-value are identical to the original.

    ▸ Attention & MLP blocks are replaced by their quantized counterparts
    ▸ LayerNorms remain FP32 (no fake-quant)
    ▸ A "static" causal mask is pre-built in `__init__` to avoid
      dynamic boolean-to-float casts inside `forward`.

    Notes on the causal mask
    ------------------------
    Building a boolean mask "inside" `forward` would introduce
    non-deterministic dynamic ops that an integer-only accelerator cannot
    fuse easily. Therefore we:

    1. Pre-compute a full upper-triangular mask of size
       `[1, 1, max_seq, max_seq]` in `__init__`.
    2. In `forward`, if the caller passes `attention_mask=None`, we
       slice the pre-computed template to the current sequence length.
    """

    def __init__(
        self,
        fp_layer: nn.Module,
        *,
        qcfg: Optional[PTQConfig] = None,
        fp_name: Optional[str] = None,
        return_type: Optional[str] = None,
    ):
        """
        Q) Why do we need `return_type`?
        A) Different versions of `transformers` wrap the decoder output in
           different containers: a plain Tensor or a tuple.
        """
        self.return_type = return_type
        if self.return_type is None:
            import transformers

            v = tuple(map(int, transformers.__version__.split(".")[:2]))
            self.return_type = "tensor" if v >= (4, 54) else "tuple"
        assert self.return_type is not None
        super().__init__(qcfg, fp_name=fp_name)

        # Child QuantConfigs -------------------------------------------------
        attn_cfg = qcfg.child("self_attn") if qcfg else None
        mlp_cfg = qcfg.child("mlp") if qcfg else None

        # Quantized sub-modules ---------------------------------------------
        assert hasattr(fp_layer, "self_attn") and isinstance(
            fp_layer.self_attn, torch.nn.Module
        )
        assert hasattr(fp_layer, "mlp") and isinstance(fp_layer.mlp, torch.nn.Module)
        self.self_attn = PTQWrapper(
            fp_layer.self_attn, qcfg=attn_cfg, fp_name=f"{fp_name}.self_attn"
        )
        self.mlp = PTQWrapper(fp_layer.mlp, qcfg=mlp_cfg, fp_name=f"{fp_name}.mlp")

        # LayerNorms remain FP (copied from fp_layer to keep weights)
        assert hasattr(fp_layer, "input_layernorm") and isinstance(
            fp_layer.input_layernorm, torch.nn.Module
        )
        assert hasattr(fp_layer, "post_attention_layernorm") and isinstance(
            fp_layer.post_attention_layernorm, torch.nn.Module
        )
        self.input_layernorm = fp_layer.input_layernorm
        self.post_attention_layernorm = fp_layer.post_attention_layernorm

        # Static causal mask template ---------------------------------------
        assert hasattr(fp_layer.self_attn, "config") and hasattr(
            fp_layer.self_attn.config, "max_position_embeddings"
        )
        assert isinstance(fp_layer.self_attn.config.max_position_embeddings, int)
        max_seq = fp_layer.self_attn.config.max_position_embeddings
        mask = torch.full((1, 1, max_seq, max_seq), float("-120"))
        mask.triu_(1)
        self.register_buffer("causal_mask_template", mask, persistent=False)

    def _slice_causal(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Return `[1,1,L,L]` causal mask slice on *device*."""
        assert isinstance(self.causal_mask_template, torch.Tensor)
        return self.causal_mask_template[..., :seq_len, :seq_len].to(device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional["Cache"] = None,  # type: ignore[name-defined]
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor] | torch.Tensor:
        if output_attentions:
            raise NotImplementedError(
                "QuantLlamaDecoderLayer does not support output attention yet."
            )
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        if attention_mask is None or attention_mask.dtype == torch.bool:
            L = hidden_states.size(1)
            attention_mask = self._slice_causal(L, hidden_states.device)

        attn_out = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        if use_cache:
            hidden_states_attn, _attn_weights, present_key_value = attn_out
        else:
            hidden_states_attn, _attn_weights = attn_out
            present_key_value = None

        hidden_states = residual + hidden_states_attn

        # ─── MLP block ─────────────────────────────────────────────────
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        # Return type policy:
        # - If use_cache: always return (hidden_states, present_key_value)
        # - Else: return as configured (tuple/tensor) for HF compatibility
        if use_cache:
            return hidden_states, present_key_value  # type: ignore[return-value]

        if self.return_type == "tuple":
            return (hidden_states,)
        elif self.return_type == "tensor":
            return hidden_states
        else:
            raise RuntimeError("Invalid return type.")

    # No local observers; just recurse into children
    def _all_observers(self):
        yield from self.self_attn._all_observers()
        yield from self.mlp._all_observers()
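Note: the static causal mask described in the docstring above can be checked in isolation. The sketch below uses plain PyTorch only (no tico imports, sizes are illustrative) and mirrors what `__init__` precomputes and what `_slice_causal` later slices.

import torch

max_seq = 8
mask = torch.full((1, 1, max_seq, max_seq), float("-120"))
mask.triu_(1)  # zero on/below the diagonal, -120 strictly above it

seq_len = 4
sliced = mask[..., :seq_len, :seq_len]  # what _slice_causal() would return
print(sliced[0, 0])
# tensor([[   0., -120., -120., -120.],
#         [   0.,    0., -120., -120.],
#         [   0.,    0.,    0., -120.],
#         [   0.,    0.,    0.,    0.]])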
tico/quantization/wrapq/wrappers/llama/quant_mlp.py
@@ -0,0 +1,96 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.nn as nn

from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
from tico.quantization.wrapq.wrappers.registry import try_register


@try_register("transformers.models.llama.modeling_llama.LlamaMLP")
class QuantLlamaMLP(QuantModuleBase):
    def __init__(
        self,
        mlp_fp: nn.Module,
        *,
        qcfg: Optional[PTQConfig] = None,
        fp_name: Optional[str] = None,
    ):
        super().__init__(qcfg, fp_name=fp_name)

        # ----- child configs (hierarchical override) -------------------
        gate_cfg = qcfg.child("gate_proj") if qcfg else None
        up_cfg = qcfg.child("up_proj") if qcfg else None
        down_cfg = qcfg.child("down_proj") if qcfg else None
        act_cfg = qcfg.child("act_fn") if qcfg else None

        # ----- wrap three Linear layers -------------------------------
        assert hasattr(mlp_fp, "gate_proj") and isinstance(
            mlp_fp.gate_proj, torch.nn.Module
        )
        assert hasattr(mlp_fp, "up_proj") and isinstance(
            mlp_fp.up_proj, torch.nn.Module
        )
        assert hasattr(mlp_fp, "down_proj") and isinstance(
            mlp_fp.down_proj, torch.nn.Module
        )
        self.gate_proj = PTQWrapper(
            mlp_fp.gate_proj, qcfg=gate_cfg, fp_name=f"{fp_name}.gate_proj"
        )
        self.up_proj = PTQWrapper(
            mlp_fp.up_proj, qcfg=up_cfg, fp_name=f"{fp_name}.up_proj"
        )
        self.down_proj = PTQWrapper(
            mlp_fp.down_proj, qcfg=down_cfg, fp_name=f"{fp_name}.down_proj"
        )

        # ----- activation ---------------------------------------------
        assert hasattr(mlp_fp, "act_fn") and isinstance(mlp_fp.act_fn, torch.nn.Module)
        self.act_fn = PTQWrapper(
            mlp_fp.act_fn, qcfg=act_cfg, fp_name=f"{fp_name}.act_fn"
        )

        # ----- local observers ----------------------------------------
        self.act_in_obs = self._make_obs("act_in")
        self.mul_obs = self._make_obs("mul")

    def forward(self, x: torch.Tensor):
        # 1) quantize input once
        x_q = self._fq(x, self.act_in_obs)

        # 2) parallel projections
        g = self.gate_proj(x_q)
        u = self.up_proj(x_q)

        # 3) activation on gate
        a = self.act_fn(g)

        # 4) element-wise product
        h = self._fq(a * u, self.mul_obs)

        # 5) final projection
        return self.down_proj(h)

    def _all_observers(self):
        # local first
        yield self.act_in_obs
        yield self.mul_obs
        # recurse into children that are QuantModuleBase
        for m in (self.gate_proj, self.up_proj, self.down_proj, self.act_fn):
            yield from m._all_observers()
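Note: for reference, the floating-point SwiGLU dataflow that QuantLlamaMLP decomposes looks like the plain-PyTorch sketch below (no tico imports; sizes are illustrative, not taken from this package). The comments mark where the wrapper inserts fake-quant in the numbered steps of `forward`.

import torch
import torch.nn as nn

hidden, inter = 64, 128
gate_proj = nn.Linear(hidden, inter, bias=False)
up_proj = nn.Linear(hidden, inter, bias=False)
down_proj = nn.Linear(inter, hidden, bias=False)
act_fn = nn.SiLU()

x = torch.randn(2, 10, hidden)
# x_q = fq(x)            <- step 1, act_in_obs
g = gate_proj(x)         # step 2, PTQWrapper(gate_proj)
u = up_proj(x)           # step 2, PTQWrapper(up_proj)
a = act_fn(g)            # step 3, PTQWrapper(act_fn) -> QuantSiLU
h = a * u                # step 4, fq(a * u) via mul_obs
y = down_proj(h)         # step 5, PTQWrapper(down_proj)
assert y.shape == (2, 10, hidden)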
tico/quantization/wrapq/wrappers/nn/__init__.py
@@ -0,0 +1 @@
# DO NOT REMOVE THIS FILE
tico/quantization/wrapq/wrappers/nn/quant_layernorm.py
@@ -0,0 +1,183 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, Optional, Tuple

import torch
import torch.nn as nn

from tico.quantization.config.ptq import PTQConfig

from tico.quantization.wrapq.mode import Mode
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
from tico.quantization.wrapq.wrappers.registry import register


@register(nn.LayerNorm)
class QuantLayerNorm(QuantModuleBase):
    """
    QuantLayerNorm — drop-in replacement for nn.LayerNorm that quantizes
    the elementary steps:
        1) μ = mean(x, dims)    (mean)
        2) c = x - μ            (sub)
        3) s = c * c            (square)
        4) v = mean(s, dims)    (variance)
        5) e = v + eps          (add-eps)
        6) r = rsqrt(e)         (rsqrt)
        7) n = c * r            (normalize)
        8) y = (n * γ) + β      (affine), with:
             • affine_mul : n * γ
             • affine_add : (n * γ) + β
    """

    def __init__(
        self,
        fp: nn.LayerNorm,
        *,
        qcfg: Optional[PTQConfig] = None,
        fp_name: Optional[str] = None
    ):
        super().__init__(qcfg, fp_name=fp_name)
        self.module = fp
        self.eps = torch.tensor(self.module.eps)
        # Number of trailing dims participating in normalization
        # (PyTorch stores normalized_shape as a tuple even if an int was passed)
        self._norm_ndim: int = len(fp.normalized_shape)  # safe for int→tuple

        # Activation / intermediate observers
        self.act_in_obs = self._make_obs("act_in")
        self.mean_obs = self._make_obs("mean")
        self.centered_obs = self._make_obs("centered")
        self.square_obs = self._make_obs("square")
        self.var_obs = self._make_obs("var")
        self.eps_obs = self._make_obs("eps")
        self.add_eps_obs = self._make_obs("add_eps")
        self.inv_std_obs = self._make_obs("inv_std")
        self.norm_obs = self._make_obs("norm")
        self.act_out_obs = self._make_obs("act_out")

        # Optional affine parameter observers (γ, β)
        self.weight_obs = None
        self.bias_obs = None
        self.affine_mul_obs = None
        self.affine_add_obs = None
        if self.module.elementwise_affine:
            if self.module.weight is not None:
                self.weight_obs = self._make_obs("weight")
            if self.module.bias is not None:
                self.bias_obs = self._make_obs("bias")
            # Per-op observers for (n * w) and (+ b)
            self.affine_mul_obs = self._make_obs("affine_mul")
            self.affine_add_obs = self._make_obs("affine_add")

    def enable_calibration(self) -> None:
        """
        Switch to CALIB mode and collect *fixed* ranges for affine params
        immediately, since they do not change across inputs.
        """
        super().enable_calibration()
        if self.module.elementwise_affine:
            if self.weight_obs is not None and self.module.weight is not None:
                self.weight_obs.collect(self.module.weight)
            if self.bias_obs is not None and self.module.bias is not None:
                self.bias_obs.collect(self.module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Determine reduction dims (last self._norm_ndim axes)
        # Example: if x.ndim=4 and norm_ndim=2 → dims=(2,3)
        dims = tuple(range(x.dim() - self._norm_ndim, x.dim()))

        # 0) input
        x_q = self._fq(x, self.act_in_obs)

        # 1) mean
        mu = x_q.mean(dim=dims, keepdim=True)
        mu_q = self._fq(mu, self.mean_obs)

        # 2) center
        c = x_q - mu_q
        c_q = self._fq(c, self.centered_obs)

        # 3) square (elementwise mul)
        s = c_q * c_q
        s_q = self._fq(s, self.square_obs)

        # 4) variance (via squared mean)
        v = s_q.mean(dim=dims, keepdim=True)
        v_q = self._fq(v, self.var_obs)

        # 5) add eps
        eps_q = self._fq(self.eps, self.eps_obs)
        e = v_q + eps_q
        e_q = self._fq(e, self.add_eps_obs)

        # 6) inverse std
        r = torch.rsqrt(e_q)
        r_q = self._fq(r, self.inv_std_obs)

        # 7) normalize
        n = c_q * r_q
        n_q = self._fq(n, self.norm_obs)

        # 8) optional affine
        if self.module.elementwise_affine:
            w = self.module.weight
            b = self.module.bias
            if self._mode is Mode.QUANT:
                if self.weight_obs is not None and w is not None:
                    w = self.weight_obs.fake_quant(w)  # type: ignore[assignment]
                if self.bias_obs is not None and b is not None:
                    b = self.bias_obs.fake_quant(b)  # type: ignore[assignment]
            y = n_q
            # 8a) n * w (fake-quant the result of the mul)
            if w is not None:
                y = y * w
                if self.affine_mul_obs is not None:
                    y = self._fq(y, self.affine_mul_obs)

            # 8b) (+ b) (fake-quant the result of the add)
            if b is not None:
                y = y + b
                if self.affine_add_obs is not None:
                    y = self._fq(y, self.affine_add_obs)
        else:
            y = n_q

        # 9) output activation
        return self._fq(y, self.act_out_obs)

    def _all_observers(self) -> Iterable:
        obs: Tuple = (
            self.act_in_obs,
            self.mean_obs,
            self.centered_obs,
            self.square_obs,
            self.var_obs,
            self.eps_obs,
            self.add_eps_obs,
            self.inv_std_obs,
            self.norm_obs,
            self.act_out_obs,
        )
        # Insert affine param observers if present
        if self.module.elementwise_affine:
            if self.weight_obs is not None:
                obs = (self.weight_obs,) + obs
            if self.bias_obs is not None:
                obs = obs + (self.bias_obs,)
            if self.affine_mul_obs is not None:
                obs = obs + (self.affine_mul_obs,)
            if self.affine_add_obs is not None:
                obs = obs + (self.affine_add_obs,)
        return obs
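Note: with fake-quant disabled, the eight elementary steps in the docstring reduce to ordinary LayerNorm. A minimal plain-PyTorch sanity check of that decomposition (no tico imports, shapes illustrative):

import torch

torch.manual_seed(0)
x = torch.randn(2, 5, 16)
ln = torch.nn.LayerNorm(16)            # fresh weight=1, bias=0

dims = (-1,)                           # trailing normalized dims
mu = x.mean(dim=dims, keepdim=True)    # 1) mean
c = x - mu                             # 2) center
v = (c * c).mean(dim=dims, keepdim=True)  # 3-4) square, variance (biased)
r = torch.rsqrt(v + ln.eps)            # 5-6) add-eps, rsqrt
y = c * r * ln.weight + ln.bias        # 7-8) normalize, affine

assert torch.allclose(y, ln(x), atol=1e-6)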
tico/quantization/wrapq/wrappers/nn/quant_linear.py
@@ -0,0 +1,65 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch.nn as nn
import torch.nn.functional as F

from tico.quantization.config.ptq import PTQConfig

from tico.quantization.wrapq.mode import Mode
from tico.quantization.wrapq.qscheme import QScheme
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
from tico.quantization.wrapq.wrappers.registry import register


@register(nn.Linear)
class QuantLinear(QuantModuleBase):
    """Per-channel weight fake-quant, eager-output activation fake-quant."""

    def __init__(
        self,
        fp: nn.Linear,
        *,
        qcfg: Optional[PTQConfig] = None,
        fp_name: Optional[str] = None
    ):
        super().__init__(qcfg, fp_name=fp_name)
        self.weight_obs = self._make_obs(
            "weight", qscheme=QScheme.PER_CHANNEL_ASYMM, channel_axis=0
        )
        self.act_in_obs = self._make_obs("act_in")
        self.act_out_obs = self._make_obs("act_out")
        self.module = fp

    def enable_calibration(self) -> None:
        super().enable_calibration()
        # immediately capture the fixed weight range
        self.weight_obs.collect(self.module.weight)

    def forward(self, x):
        x_q = self._fq(x, self.act_in_obs)

        w = self.module.weight
        if self._mode is Mode.QUANT:
            w = self.weight_obs.fake_quant(w)
        b = self.module.bias

        out = F.linear(x_q, w, b)

        return self._fq(out, self.act_out_obs)

    def _all_observers(self):
        return (self.weight_obs, self.act_in_obs, self.act_out_obs)
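Note: a generic sketch of what per-channel asymmetric fake-quant over axis 0 of the weight amounts to, in plain PyTorch. This is not the package's observer implementation; the 8-bit range and rounding policy below are assumptions for illustration only (the real dtype and ranges come from PTQConfig and the wrapq observers/dtypes modules).

import torch

w = torch.randn(8, 16)                     # [out_features, in_features]
w_min = w.amin(dim=1, keepdim=True)        # one (min, max) pair per output channel
w_max = w.amax(dim=1, keepdim=True)
qmin, qmax = 0, 255                        # assumed uint8 range
scale = (w_max - w_min) / (qmax - qmin)
zero_point = (qmin - w_min / scale).round().clamp(qmin, qmax)

w_int = (w / scale + zero_point).round().clamp(qmin, qmax)
w_fq = (w_int - zero_point) * scale        # "fake-quantized" (dequantized) weight
print("max abs error:", (w - w_fq).abs().max().item())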
tico/quantization/wrapq/wrappers/nn/quant_silu.py
@@ -0,0 +1,60 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.nn as nn

from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
from tico.quantization.wrapq.wrappers.registry import try_register


@try_register("torch.nn.SiLU", "transformers.activations.SiLUActivation")
class QuantSiLU(QuantModuleBase):
    """
    QuantSiLU — drop-in quantized implementation of the SiLU operation.

    This module quantizes both intermediate tensors:
        • s = sigmoid(x)   (logistic)
        • y = x * s        (mul)
    """

    def __init__(
        self,
        fp: nn.SiLU,
        *,
        qcfg: Optional[PTQConfig] = None,
        fp_name: Optional[str] = None
    ):
        super().__init__(qcfg, fp_name=fp_name)
        self.act_in_obs = self._make_obs("act_in")
        self.sig_obs = self._make_obs("sigmoid")
        self.mul_obs = self._make_obs("mul")
        self.module = fp

    def forward(self, x: torch.Tensor):
        x_q = self._fq(x, self.act_in_obs)

        s = torch.sigmoid(x_q)
        s = self._fq(s, self.sig_obs)

        y = x * s
        y = self._fq(y, self.mul_obs)

        return y

    def _all_observers(self):
        return (self.act_in_obs, self.sig_obs, self.mul_obs)
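Note: floating-point reference for the two quantized steps above; the decomposition s = sigmoid(x), y = x * s is exactly SiLU (plain PyTorch sanity check, no tico imports):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
assert torch.allclose(x * torch.sigmoid(x), F.silu(x), atol=1e-6)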
tico/quantization/wrapq/wrappers/ptq_wrapper.py
@@ -0,0 +1,69 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch

from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
from tico.quantization.wrapq.wrappers.registry import lookup


class PTQWrapper(QuantModuleBase):
    """
    Adapter that turns a fp module into its quantized counterpart.

    It is itself a QuantModuleBase so composite wrappers can treat
    it exactly like any other quant module.
    """

    def __init__(
        self,
        module: torch.nn.Module,
        qcfg: Optional[PTQConfig] = None,
        *,
        fp_name: Optional[str] = None,
    ):
        super().__init__(qcfg)
        wrapped_cls = lookup(type(module))
        if wrapped_cls is None:
            raise NotImplementedError(f"No quant wrapper for {type(module).__name__}")
        self.wrapped: QuantModuleBase = wrapped_cls(module, qcfg=qcfg, fp_name=fp_name)  # type: ignore[arg-type, misc]

    def forward(self, *args, **kwargs):
        return self.wrapped(*args, **kwargs)

    def _all_observers(self):
        """
        PTQWrapper itself owns NO observers (transparent node).
        Returning an empty iterator prevents double-processing when parents
        traverse the tree and then recurse into `self.wrapped`.
        """
        return ()  # no local observers

    def named_observers(self):
        """
        Proxy to the wrapped module so debugging tools can still enumerate observers.
        """
        yield from self.wrapped.named_observers()

    def get_observer(self, name: str):
        """
        Proxy to the wrapped module for direct lookup by name.
        """
        return self.wrapped.get_observer(name)

    def extra_repr(self) -> str:
        return self.wrapped.extra_repr()
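Note: a minimal dispatch sketch. PTQWrapper resolves the wrapper class through the registry populated by @register / @try_register in the files above; wrapping a module with no registered counterpart raises NotImplementedError. qcfg is left as None here to avoid assuming PTQConfig's constructor; what observers are created in that case is decided by QuantModuleBase, which is not part of this diff.

import torch.nn as nn
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper

q_linear = PTQWrapper(nn.Linear(16, 16), fp_name="proj")
print(type(q_linear.wrapped).__name__)  # expected: "QuantLinear", via @register(nn.Linear)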