tico 0.1.0.dev250904__py3-none-any.whl → 0.1.0.dev251109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tico might be problematic. Click here for more details.
- tico/__init__.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_mixed_type_args.py +2 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -3
- tico/passes/ops.py +0 -1
- tico/passes/remove_redundant_expand.py +3 -1
- tico/quantization/__init__.py +6 -0
- tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +161 -0
- tico/quantization/algorithm/fpi_gptq/quantizer.py +179 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +14 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/fpi_gptq.py +29 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/{experimental/quantization/ptq/quant_config.py → quantization/config/ptq.py} +18 -10
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -37
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +6 -12
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/public_interface.py +11 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_linear.py +11 -10
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_attn.py +10 -12
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_decoder_layer.py +10 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_mlp.py +13 -13
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py +3 -3
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py +1 -1
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/mx.py +1 -1
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py +3 -5
- tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py +3 -2
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py +58 -21
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py +21 -13
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py +5 -7
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py +6 -7
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_linear.py +7 -8
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_silu.py +8 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/ptq_wrapper.py +4 -6
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_elementwise.py +55 -17
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_module_base.py +10 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/registry.py +17 -10
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_mm.py +15 -132
- tico/utils/convert.py +20 -15
- tico/utils/register_custom_op.py +6 -4
- tico/utils/signature.py +7 -8
- tico/utils/validate_args_kwargs.py +12 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/METADATA +48 -2
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/RECORD +128 -108
- tico/experimental/quantization/__init__.py +0 -6
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/ptq/examples/compare_ppl.py +0 -121
- tico/experimental/quantization/ptq/examples/debug_quant_outputs.py +0 -129
- tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +0 -165
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/gptq → quantization/algorithm/fpi_gptq}/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e → quantization/algorithm/gptq}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e/annotation → quantization/algorithm/pt2e}/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e/transformation → quantization/algorithm/pt2e/annotation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/smoothquant → quantization/algorithm/pt2e/transformation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/algorithm/smoothquant}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/metric.py +0 -0
- /tico/{experimental/quantization/ptq/examples → quantization/passes}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- /tico/{experimental/quantization/ptq/observers → quantization/wrapq}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/dtypes.py +0 -0
- /tico/{experimental/quantization/ptq/utils → quantization/wrapq/examples}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/mode.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers → quantization/wrapq/observers}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/qscheme.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers/llama → quantization/wrapq/utils}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/utils/reduce_utils.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers/nn → quantization/wrapq/wrappers}/__init__.py +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
#
|
|
15
|
+
# -----------------------------------------------------------------------------
|
|
16
|
+
# This file includes modifications based on fairseq
|
|
17
|
+
# (https://github.com/facebookresearch/fairseq), originally licensed under
|
|
18
|
+
# the MIT License. See the LICENSE file in the fairseq repository for details.
|
|
19
|
+
# -----------------------------------------------------------------------------
|
|
20
|
+
|
|
21
|
+
import math
|
|
22
|
+
from typing import Dict, List, Literal, Optional, Tuple
|
|
23
|
+
|
|
24
|
+
import torch
|
|
25
|
+
import torch.nn as nn
|
|
26
|
+
from torch import Tensor
|
|
27
|
+
|
|
28
|
+
from tico.quantization.config.ptq import PTQConfig
|
|
29
|
+
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
|
|
30
|
+
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
|
|
31
|
+
from tico.quantization.wrapq.wrappers.registry import try_register
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@try_register("fairseq.models.transformer.TransformerEncoderBase")
class QuantFairseqEncoder(QuantModuleBase):
    """
    Quant-aware drop-in replacement for Fairseq TransformerEncoderBase.

    Key design choices:
      - Keep embeddings and LayerNorms in FP.
      - Remove training-time logic (dropout, activation-dropout, quant_noise).
      - Attention masks are handled statically inside the layer wrapper; this
        encoder only does the original padding zero-out before the stack.

    I/O contracts:
      - Forward signature and returned dictionary are identical to the original
        when `use_external_inputs=False`.
      - When `use_external_inputs=True`, forward completely skips the
        embedding/positional/LN/mask-creation paths and treats `src_tokens`
        as an already-embedded (T, B, C) tensor. The return value then
        depends on `return_type`:
          * "tensor": a single Tensor (T, B, C) (export friendly).
          * "dict" (the default): a reduced dict-of-lists containing only
            "encoder_out", "encoder_padding_mask" and "encoder_states".
      - Tensor shapes follow Fairseq convention.
    """

    def __init__(
        self,
        fp_encoder: nn.Module,
        *,
        qcfg: Optional[PTQConfig] = None,
        fp_name: Optional[str] = None,
        use_external_inputs: bool = False,  # export-mode flag
        return_type: Literal["tensor", "dict"] = "dict",
    ):
        """
        Args:
            fp_encoder: Floating-point Fairseq encoder to wrap. Must expose
                `cfg`, `embed_tokens`, `padding_idx`, `max_source_positions`
                and `layers`.
            qcfg: Optional PTQ config. Each layer receives the child
                namespace "layers/<idx>".
            fp_name: Name prefix forwarded to the wrapped layers.
            use_external_inputs: Export-mode flag (see class docstring).
            return_type: Output format used in external-input mode.
        """
        super().__init__(qcfg, fp_name=fp_name)
        self.use_external_inputs = use_external_inputs
        self.return_type: Literal["tensor", "dict"] = return_type

        # --- carry basic config / metadata (read-only copies) ---------------
        assert hasattr(fp_encoder, "cfg")
        self.cfg = fp_encoder.cfg
        self.return_fc: bool = bool(getattr(fp_encoder, "return_fc", False))

        # Embedding stack ----------------------------------------------------
        assert hasattr(fp_encoder, "embed_tokens") and isinstance(
            fp_encoder.embed_tokens, nn.Module
        )
        self.embed_tokens = fp_encoder.embed_tokens  # keep FP embeddings

        assert hasattr(fp_encoder, "padding_idx")
        self.padding_idx: int = int(fp_encoder.padding_idx)  # type: ignore[arg-type]

        # scale = sqrt(embed_dim) unless `cfg.no_scale_embedding` is set
        embed_dim = int(self.embed_tokens.embedding_dim)  # type: ignore[arg-type]
        no_scale = bool(getattr(self.cfg, "no_scale_embedding", False))
        self.embed_scale: float = 1.0 if no_scale else math.sqrt(embed_dim)

        # Positional embeddings (kept as-is; no fake-quant)
        self.embed_positions = getattr(fp_encoder, "embed_positions", None)
        # Optional embedding LayerNorm
        self.layernorm_embedding = getattr(fp_encoder, "layernorm_embedding", None)

        # Final encoder LayerNorm (None when the FP encoder has none)
        self.layer_norm = getattr(fp_encoder, "layer_norm", None)

        # Max positions (reuse for API parity)
        self.max_source_positions: int = int(fp_encoder.max_source_positions)  # type: ignore[arg-type]

        # --- wrap encoder layers with PTQWrapper ----------------------------
        assert hasattr(fp_encoder, "layers")
        fp_layers = list(fp_encoder.layers)  # type: ignore[arg-type]
        self.layers = nn.ModuleList()

        # Prepare child PTQConfig namespaces: layers/<idx>
        layers_qcfg = qcfg.child("layers") if qcfg else None
        for i, layer in enumerate(fp_layers):
            child_cfg = layers_qcfg.child(str(i)) if layers_qcfg else None
            self.layers.append(
                PTQWrapper(layer, qcfg=child_cfg, fp_name=f"{fp_name}.layers.{i}")
            )

        # Version buffer copied from the FP encoder (default: 3.0). It is
        # registered as NON-persistent, so it does not appear in this
        # module's state_dict(); it is only kept as an attribute.
        version = getattr(fp_encoder, "version", None)
        if isinstance(version, torch.Tensor):
            self.register_buffer("version", version.clone(), persistent=False)
        else:
            self.register_buffer("version", torch.tensor([3.0]), persistent=False)

    # ----------------------------------------------------------------------
    def forward_embedding(
        self, src_tokens: Tensor, token_embedding: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        """
        Embed tokens and add positional embeddings. Dropout/quant_noise are removed.

        Args:
            src_tokens: Token ids, used for the token lookup (unless
                `token_embedding` is given) and for positional embeddings.
            token_embedding: Optional precomputed token embedding that
                replaces the `embed_tokens` lookup.

        Returns:
            x (B, T, C), embed (B, T, C)  # embed is the token-only embedding
        """
        if token_embedding is None:
            token_embedding = self.embed_tokens(src_tokens)
        embed = token_embedding  # token-only

        x = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            x = x + self.embed_positions(src_tokens)
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        # No dropout, no quant_noise here (inference-only)
        return x, embed

    # ----------------------------------------------------------------------
    def forward(
        self,
        src_tokens: Tensor,
        src_lengths: Optional[Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[Tensor] = None,
        *,
        # External-inputs branch (used for export)
        encoder_padding_mask: Optional[Tensor] = None,  # B x T (bool)
    ) -> Tensor | Dict[str, List[Optional[Tensor]]]:
        """
        If `self.use_external_inputs` is True:
          - `src_tokens` is interpreted as the already-embedded encoder input
            (T, B, C) and is fed straight into the layer stack together with
            `encoder_padding_mask`.
          - Returns a single Tensor (T, B, C) when `return_type == "tensor"`;
            otherwise a reduced dict-of-lists with "encoder_out",
            "encoder_padding_mask" and "encoder_states".

        Otherwise (False):
          - Behave like the original Fairseq encoder forward and return dict-of-lists.
          - The `encoder_padding_mask` argument is ignored; the mask is
            recomputed from `src_tokens`. `src_lengths` is accepted for API
            parity only; returned lengths are recomputed from `src_tokens`.
        """
        if self.use_external_inputs:
            # ----- External-input mode: completely skip embedding/positional/LN/mask creation -----
            x_external = src_tokens  # T x B x C (already embedded + transposed)

            encoder_states: List[Tensor] = []
            if return_all_hiddens:
                encoder_states.append(x_external)

            for layer in self.layers:
                out = layer(x_external, encoder_padding_mask=encoder_padding_mask)
                # Layers with return_fc=True yield (x, fc_result); keep only x.
                x_external = (
                    out[0] if (isinstance(out, tuple) and len(out) == 2) else out
                )
                if return_all_hiddens:
                    encoder_states.append(x_external)

            if self.layer_norm is not None:
                x_external = self.layer_norm(x_external)

            if self.return_type == "dict":
                return {
                    "encoder_out": [x_external],
                    "encoder_padding_mask": [encoder_padding_mask],
                    "encoder_states": encoder_states,  # type: ignore[dict-item]
                }
            else:
                # For export, returning a single Tensor is simpler and more portable.
                return x_external

        # ----- Original path (training/eval compatibility) ------------------

        # Compute padding mask [B, T] (bool). We keep the original "has_pads" logic.
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        has_pads: Tensor = (
            torch.tensor(src_tokens.device.type == "xla") or encoder_padding_mask.any()
        )
        if torch.jit.is_scripting():
            has_pads = torch.tensor(1) if has_pads else torch.tensor(0)

        # Embedding path (B,T,C). No dropout/quant_noise.
        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)

        # Zero out padded timesteps prior to the stack (same as original)
        x = x * (
            1 - encoder_padding_mask.unsqueeze(-1).type_as(x) * has_pads.type_as(x)
        )

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_states: List[Tensor] = []  # type: ignore[no-redef]
        fc_results: List[Optional[Tensor]] = []

        if return_all_hiddens:
            encoder_states.append(x)

        # Encoder layers (each item is PTQ-wrapped and uses static additive masks internally)
        for layer in self.layers:
            out = layer(
                x, encoder_padding_mask=encoder_padding_mask if has_pads else None
            )
            if isinstance(out, tuple) and len(out) == 2:
                x, fc_res = out
            else:
                x = out
                fc_res = None

            if return_all_hiddens and not torch.jit.is_scripting():
                encoder_states.append(x)
                fc_results.append(fc_res)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # src_lengths (B, 1) int32, identical to original
        src_lengths_out = (
            src_tokens.ne(self.padding_idx)
            .sum(dim=1, dtype=torch.int32)
            .reshape(-1, 1)
            .contiguous()
        )

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [encoder_embedding],  # B x T x C
            "encoder_states": encoder_states,  # type: ignore[dict-item] # List[T x B x C]
            "fc_results": fc_results,  # type: ignore[dict-item] # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [src_lengths_out],
        }

    def forward_torchscript(self, net_input: Dict[str, Tensor]):
        """A TorchScript-compatible version of forward.

        Encoders which use additional arguments may want to override
        this method for TorchScript compatibility.
        """
        if "encoder_padding_mask" in net_input:
            return self.forward(
                src_tokens=net_input["src_tokens"],
                src_lengths=net_input["src_lengths"],
                encoder_padding_mask=net_input["encoder_padding_mask"],
            )
        else:
            return self.forward(
                src_tokens=net_input["src_tokens"],
                src_lengths=net_input["src_lengths"],
            )

    # ----------------------------------------------------------------------
    @torch.jit.export
    def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
        """
        Match original API: reorder the batched dimension (B) according to new_order.

        "encoder_out" and "encoder_states" are (T, B, C) and are reordered on
        dim 1; the remaining batch-first entries are reordered on dim 0.
        Missing keys are skipped, and a `None` entry (e.g. the padding mask in
        external-input dict mode when no mask was supplied) is kept as `None`.
        Note: the "encoder_states" list is updated in place.
        """
        reordered = dict()  # type: ignore[var-annotated]
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
        reordered["encoder_out"] = new_encoder_out
        keys = [
            "encoder_padding_mask",
            "encoder_embedding",
            "src_tokens",
            "src_lengths",
        ]
        for k in keys:
            if k not in encoder_out:
                continue
            if len(encoder_out[k]) == 0:
                reordered[k] = []
            else:
                first = encoder_out[k][0]
                # Guard against `[None]` entries instead of crashing on
                # `None.index_select`.
                reordered[k] = [
                    first.index_select(0, new_order) if first is not None else None
                ]

        if "encoder_states" in encoder_out:
            encoder_states = encoder_out["encoder_states"]
            if len(encoder_states) > 0:
                for idx, state in enumerate(encoder_states):
                    encoder_states[idx] = state.index_select(1, new_order)
            reordered["encoder_states"] = encoder_states

        return reordered

    @torch.jit.export
    def _reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
        """Dummy re-order for beamable enc-dec attention (API parity)."""
        return encoder_out

    def max_positions(self) -> int:
        """Maximum input length supported by the encoder (same policy as the original)."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Forward-compat mapping for older checkpoints (mirror original behavior for LNs).

        Per-layer remapping is delegated to each wrapped layer that exposes
        `upgrade_state_dict_named`. If the stored "<name>.version" is below 2
        (a missing key counts as 1), the final encoder LayerNorm is dropped and
        the version entry is written back as 1.
        """
        for i, layer in enumerate(self.layers):
            if hasattr(layer, "upgrade_state_dict_named"):
                layer.upgrade_state_dict_named(state_dict, f"{name}.layers.{i}")

        version_key = f"{name}.version"
        v = state_dict.get(version_key, torch.Tensor([1]))
        if float(v[0].item()) < 2:
            self.layer_norm = None
            state_dict[version_key] = torch.Tensor([1])
        return state_dict

    def _all_observers(self):
        """Yield the observers of every quant-aware wrapped layer."""
        for m in self.layers:
            if isinstance(m, QuantModuleBase):
                yield from m._all_observers()
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
#
|
|
15
|
+
# -----------------------------------------------------------------------------
|
|
16
|
+
# This file includes modifications based on fairseq
|
|
17
|
+
# (https://github.com/facebookresearch/fairseq), originally licensed under
|
|
18
|
+
# the MIT License. See the LICENSE file in the fairseq repository for details.
|
|
19
|
+
# -----------------------------------------------------------------------------
|
|
20
|
+
|
|
21
|
+
from typing import Optional
|
|
22
|
+
|
|
23
|
+
import torch.nn as nn
|
|
24
|
+
from torch import Tensor
|
|
25
|
+
|
|
26
|
+
from tico.quantization.config.ptq import PTQConfig
|
|
27
|
+
from tico.quantization.wrapq.wrappers.fairseq.quant_mha import (
|
|
28
|
+
QuantFairseqMultiheadAttention,
|
|
29
|
+
)
|
|
30
|
+
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
|
|
31
|
+
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
|
|
32
|
+
from tico.quantization.wrapq.wrappers.registry import try_register
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@try_register("fairseq.modules.transformer_layer.TransformerEncoderLayerBase")
class QuantFairseqEncoderLayer(QuantModuleBase):
    """
    Quantization-aware stand-in for Fairseq's TransformerEncoderLayerBase.

    Notes:
      - Inference only: dropout and activation-dropout are not applied.
      - Inputs and outputs use the Fairseq [T, B, C] layout.
      - If the FP layer had `return_fc` set, forward returns
        `(x, fc_result)` instead of just `x`.
    """

    def __init__(
        self,
        fp_layer: nn.Module,
        *,
        qcfg: Optional[PTQConfig] = None,
        fp_name: Optional[str] = None,
    ):
        """
        Args:
            fp_layer: Floating-point Fairseq encoder layer to wrap. Must expose
                `embed_dim`, `normalize_before`, `self_attn`, `fc1`, `fc2`,
                `self_attn_layer_norm`, `final_layer_norm` and `activation_fn`.
            qcfg: Optional PTQ config. Each submodule receives the child
                namespace named after its attribute.
            fp_name: Name prefix forwarded to the wrapped submodules.
        """
        super().__init__(qcfg, fp_name=fp_name)

        def sub_cfg(key: str) -> Optional[PTQConfig]:
            # Child namespace of `qcfg`, or None when no config was supplied.
            return qcfg.child(key) if qcfg else None

        # Read-only metadata copied from the FP layer.
        assert hasattr(fp_layer, "embed_dim")
        assert hasattr(fp_layer, "normalize_before")
        self.embed_dim: int = int(fp_layer.embed_dim)  # type: ignore[arg-type]
        self.normalize_before: bool = bool(fp_layer.normalize_before)
        self.return_fc: bool = bool(getattr(fp_layer, "return_fc", False))

        # Per-submodule PTQ configs.
        cfg_attn = sub_cfg("self_attn")
        cfg_fc1 = sub_cfg("fc1")
        cfg_fc2 = sub_cfg("fc2")
        cfg_attn_ln = sub_cfg("self_attn_layer_norm")
        cfg_final_ln = sub_cfg("final_layer_norm")

        # Attention + feed-forward projections.
        for attr in ("self_attn", "fc1", "fc2"):
            assert hasattr(fp_layer, attr) and isinstance(
                getattr(fp_layer, attr), nn.Module
            )
        self.self_attn = QuantFairseqMultiheadAttention(
            fp_layer.self_attn, qcfg=cfg_attn, fp_name=f"{fp_name}.self_attn"
        )
        self.fc1 = PTQWrapper(fp_layer.fc1, qcfg=cfg_fc1, fp_name=f"{fp_name}.fc1")
        self.fc2 = PTQWrapper(fp_layer.fc2, qcfg=cfg_fc2, fp_name=f"{fp_name}.fc2")

        # LayerNorms, also PTQ-wrapped.
        for attr in ("self_attn_layer_norm", "final_layer_norm"):
            assert hasattr(fp_layer, attr) and isinstance(
                getattr(fp_layer, attr), nn.Module
            )
        self.self_attn_layer_norm = PTQWrapper(
            fp_layer.self_attn_layer_norm,
            qcfg=cfg_attn_ln,
            fp_name=f"{fp_name}.self_attn_layer_norm",
        )
        self.final_layer_norm = PTQWrapper(
            fp_layer.final_layer_norm,
            qcfg=cfg_final_ln,
            fp_name=f"{fp_name}.final_layer_norm",
        )

        # FP activation reused as-is; its output is fake-quantized in forward.
        self.activation_fn = fp_layer.activation_fn  # type: ignore[operator]
        self.obs_activation_fn = self._make_obs("activation_fn")

    # ----------------------------------------------------------------------
    def forward(
        self,
        x: Tensor,  # [T,B,C]
        encoder_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor] = None,  # [T,S] boolean/byte or additive float
    ):
        """
        Run one encoder layer: self-attention block, then feed-forward block.

        Each block applies its LayerNorm before the sublayer when
        `normalize_before` is set, and after the residual add otherwise.

        Both masks are forwarded to the attention wrapper unchanged.
        NOTE(review): no boolean-to-additive conversion happens here. Confirm
        that the MHA wrapper accepts every mask dtype advertised above.

        Returns:
            [T, B, C] tensor, or (that tensor, fc_result) when return_fc=True,
            where fc_result is the fc2 output before the residual add.
        """
        pre_norm = self.normalize_before

        # ---- self-attention sublayer ----
        shortcut = x
        h = self.self_attn_layer_norm(x) if pre_norm else x
        h, _ = self.self_attn(
            query=h,
            key=h,
            value=h,
            key_padding_mask=encoder_padding_mask,
            need_weights=False,
            attn_mask=attn_mask,
        )
        x = shortcut + h
        if not pre_norm:
            x = self.self_attn_layer_norm(x)

        # ---- feed-forward sublayer ----
        shortcut = x
        h = self.final_layer_norm(x) if pre_norm else x
        h = self.activation_fn(self.fc1(h))  # type: ignore[operator]
        h = self._fq(h, self.obs_activation_fn)
        fc_result = self.fc2(h)
        x = shortcut + fc_result
        if not pre_norm:
            x = self.final_layer_norm(x)

        return (x, fc_result) if self.return_fc else x

    def _all_observers(self):
        """Yield this layer's activation observer, then those of its quant submodules."""
        yield self.obs_activation_fn
        children = (
            self.self_attn,
            self.fc1,
            self.fc2,
            self.self_attn_layer_norm,
            self.final_layer_norm,
        )
        for child in children:
            if isinstance(child, QuantModuleBase):
                yield from child._all_observers()
|