tico 0.1.0.dev250917__py3-none-any.whl → 0.1.0.dev250918__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/config/v1.py +3 -0
- tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +190 -69
- tico/experimental/quantization/ptq/wrappers/fairseq/quant_decoder_layer.py +494 -0
- tico/experimental/quantization/ptq/wrappers/registry.py +1 -0
- tico/passes/convert_matmul_to_linear.py +200 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/op_mm.py +15 -132
- tico/utils/convert.py +6 -1
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250918.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250918.dist-info}/RECORD +16 -14
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250918.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250918.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250918.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250918.dist-info}/top_level.txt +0 -0
tico/experimental/quantization/ptq/wrappers/fairseq/quant_decoder_layer.py (new file)
@@ -0,0 +1,494 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -----------------------------------------------------------------------------
+# This file includes modifications based on fairseq
+# (https://github.com/facebookresearch/fairseq), originally licensed under
+# the MIT License. See the LICENSE file in the fairseq repository for details.
+# -----------------------------------------------------------------------------
+
+from typing import Dict, Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+
+from tico.experimental.quantization.ptq.quant_config import QuantConfig
+from tico.experimental.quantization.ptq.wrappers.fairseq.quant_mha import (
+    QuantFairseqMultiheadAttention,
+)
+from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
+from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
+    QuantModuleBase,
+)
+from tico.experimental.quantization.ptq.wrappers.registry import try_register
+
+
+@try_register("fairseq.modules.transformer_layer.TransformerDecoderLayerBase")
+class QuantFairseqDecoderLayer(QuantModuleBase):
+    """
+    Quant-aware drop-in replacement for Fairseq TransformerDecoderLayerBase.
+
+    Design (inference-only):
+      - Keep LayerNorms and scalar head/residual scalers in FP.
+      - PTQ-wrap: self_attn, (optional) encoder_attn, fc1, fc2.
+      - Preserve Fairseq tensor contracts and incremental state handling.
+      - Remove training-time behaviors: dropout, activation-dropout, quant-noise, onnx_trace.
+
+    I/O:
+      - Input/Output use Fairseq shapes: [T, B, C].
+      - Forward returns: (x, attn, None) to match the original call sites in decoder.
+        * `attn` is from encoder-attention when requested (alignment).
+    """
+
+    def __init__(
+        self,
+        fp_layer: nn.Module,
+        *,
+        qcfg: Optional[QuantConfig] = None,
+        fp_name: Optional[str] = None,
+    ):
+        super().__init__(qcfg, fp_name=fp_name)
+
+        # --- read-only metadata copied from FP layer -----------------------
+        assert hasattr(fp_layer, "embed_dim")
+        assert hasattr(fp_layer, "normalize_before")
+        self.embed_dim: int = int(fp_layer.embed_dim)  # type: ignore[arg-type]
+        self.normalize_before: bool = bool(fp_layer.normalize_before)
+
+        # Cross-self attention flag (when True, key/value can include encoder_out)
+        self.cross_self_attention: bool = bool(
+            getattr(fp_layer, "cross_self_attention", False)
+        )
+
+        # Generate prefix
+        def _safe_prefix(name: Optional[str]) -> str:
+            # Avoid "None.*" strings causing collisions
+            return (
+                name
+                if (name is not None and name != "None" and name != "")
+                else f"{self.__class__.__name__}_{id(self)}"
+            )
+
+        prefix = _safe_prefix(fp_name)
+        # Self-attn (PTQ) ---------------------------------------------------
+        # Use our MHA wrapper with identical API to the FP module.
+        attn_cfg = qcfg.child("self_attn") if qcfg else None
+        assert hasattr(fp_layer, "self_attn") and isinstance(
+            fp_layer.self_attn, nn.Module
+        )
+        self.self_attn = QuantFairseqMultiheadAttention(
+            fp_layer.self_attn, qcfg=attn_cfg, fp_name=f"{prefix}.self_attn"
+        )
+
+        # Optional attention LayerNorm applied to self-attn output (scale_attn)
+        # Kept in FP; reuse original instance for weight parity.
+        self.attn_ln = getattr(fp_layer, "attn_ln", None)
+
+        # Optional per-head scaling after self-attn output (scale_heads)
+        # Keep exact Parameter reference if present (shape: [num_heads])
+        self.c_attn = getattr(fp_layer, "c_attn", None)
+
+        # Cache head meta for c_attn path
+        self.nh = int(getattr(self.self_attn, "num_heads"))
+        self.head_dim = int(getattr(self.self_attn, "head_dim"))
+
+        # Encoder-attn (PTQ) ------------------------------------------------
+        # Only present if the original layer was constructed with encoder_attn.
+        enc_attn_mod = getattr(fp_layer, "encoder_attn", None)
+        assert enc_attn_mod is not None
+        enc_cfg = qcfg.child("encoder_attn") if qcfg else None
+        self.encoder_attn = QuantFairseqMultiheadAttention(
+            enc_attn_mod, qcfg=enc_cfg, fp_name=f"{prefix}.encoder_attn"
+        )
+
+        # Feed-forward (PTQ) ------------------------------------------------
+        fc1_cfg = qcfg.child("fc1") if qcfg else None
+        fc2_cfg = qcfg.child("fc2") if qcfg else None
+        assert hasattr(fp_layer, "fc1") and isinstance(fp_layer.fc1, nn.Module)
+        assert hasattr(fp_layer, "fc2") and isinstance(fp_layer.fc2, nn.Module)
+        self.fc1 = PTQWrapper(fp_layer.fc1, qcfg=fc1_cfg, fp_name=f"{fp_name}.fc1")
+        self.fc2 = PTQWrapper(fp_layer.fc2, qcfg=fc2_cfg, fp_name=f"{fp_name}.fc2")
+
+        # LayerNorms
+        enc_attn_ln_cfg = qcfg.child("encoder_attn_layer_norm") if qcfg else None
+        attn_ln_cfg = qcfg.child("self_attn_layer_norm") if qcfg else None
+        final_ln_cfg = qcfg.child("final_layer_norm") if qcfg else None
+        assert hasattr(fp_layer, "encoder_attn_layer_norm") and isinstance(
+            fp_layer.encoder_attn_layer_norm, nn.Module
+        )
+        assert hasattr(fp_layer, "self_attn_layer_norm") and isinstance(
+            fp_layer.self_attn_layer_norm, nn.Module
+        )
+        assert hasattr(fp_layer, "final_layer_norm") and isinstance(
+            fp_layer.final_layer_norm, nn.Module
+        )
+        self.encoder_attn_layer_norm = PTQWrapper(
+            fp_layer.encoder_attn_layer_norm,
+            qcfg=enc_attn_ln_cfg,
+            fp_name=f"{fp_name}.encoder_attn_layer_norm",
+        )
+        self.self_attn_layer_norm = PTQWrapper(
+            fp_layer.self_attn_layer_norm,
+            qcfg=attn_ln_cfg,
+            fp_name=f"{fp_name}.self_attn_layer_norm",
+        )
+        self.final_layer_norm = PTQWrapper(
+            fp_layer.final_layer_norm,
+            qcfg=final_ln_cfg,
+            fp_name=f"{fp_name}.final_layer_norm",
+        )
+
+        # Optional FFN intermediate LayerNorm (scale_fc), FP
+        self.ffn_layernorm = getattr(fp_layer, "ffn_layernorm", None)
+
+        # Optional residual scaling (scale_resids), keep Parameter reference
+        self.w_resid = getattr(fp_layer, "w_resid", None)
+
+        # Activation function
+        self.activation_fn = fp_layer.activation_fn  # type: ignore[operator]
+        self.obs_activation_fn = self._make_obs("activation_fn")
+
+        # Alignment flag used by Fairseq (kept for API parity)
+        self.need_attn: bool = bool(getattr(fp_layer, "need_attn", True))
+
+        # No dropout / activation-dropout in inference wrapper
+        # (intentionally omitted)
+
+        # --- observers for external/self-attn KV cache inputs --------------
+        self.obs_prev_self_k_in = self._make_obs("prev_self_k_in")
+        self.obs_prev_self_v_in = self._make_obs("prev_self_v_in")
+
+    # ----------------------------------------------------------------------
+    def _maybe_apply_head_scale(self, x: Tensor) -> Tensor:
+        """
+        Optional per-head scaling (scale_heads) after self-attention.
+        x: [T, B, C]
+        """
+        if self.c_attn is None:
+            return x
+        T, B, _ = x.shape
+        x = x.view(T, B, self.nh, self.head_dim)  # [T,B,H,Dh]
+        # einsum over head dim: scales each head independently
+        x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)  # [T,B,H,Dh]
+        return x.reshape(T, B, self.nh * self.head_dim)  # [T,B,C]
+
+    # ----------------------------------------------------------------------
+    def forward(
+        self,
+        x: Tensor,  # [T,B,C]
+        encoder_out: Optional[Tensor] = None,  # [S,B,Ce] or None
+        encoder_padding_mask: Optional[Tensor] = None,  # [B,S] bool or additive float
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+        prev_self_attn_state: Optional[List[Tensor]] = None,
+        prev_attn_state: Optional[List[Tensor]] = None,
+        self_attn_mask: Optional[Tensor] = None,  # [T,T] or [B,T,T] or None
+        self_attn_padding_mask: Optional[Tensor] = None,  # [B,T] or [B,T,T] or None
+        need_attn: bool = False,
+        need_head_weights: bool = False,
+    ) -> Tuple[Tensor, Optional[Tensor], None]:
+        """
+        Mirrors the original forward, minus training-only logic.
+        Returns:
+            x': [T,B,C], attn (from encoder-attn when requested), None
+        """
+        if need_head_weights:
+            need_attn = True
+
+        # ---- (1) Self-Attention block ------------------------------------
+        residual = x
+        if self.normalize_before:
+            x = self.self_attn_layer_norm(x)
+
+        # Load provided cached self-attn state (for incremental decoding)
+        if prev_self_attn_state is not None:
+            prev_key, prev_value = prev_self_attn_state[:2]
+            saved_state: Dict[str, Optional[Tensor]] = {
+                "prev_key": prev_key,
+                "prev_value": prev_value,
+            }
+            if len(prev_self_attn_state) >= 3:
+                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
+            assert incremental_state is not None
+            self.self_attn._set_input_buffer(incremental_state, saved_state)
+
+        # Cross-self-attention: prepend encoder_out to K/V at the first step
+        y = x
+        if self.cross_self_attention:
+            _buf = self.self_attn._get_input_buffer(incremental_state)
+            no_cache_yet = not (
+                incremental_state is not None
+                and _buf is not None
+                and "prev_key" in _buf
+            )
+            if no_cache_yet:
+                if self_attn_mask is not None:
+                    assert encoder_out is not None
+                    # Grow attn mask to cover encoder timesteps (no autoregressive penalty for them)
+                    self_attn_mask = torch.cat(
+                        (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask),
+                        dim=1,
+                    )
+                if self_attn_padding_mask is not None:
+                    if encoder_padding_mask is None:
+                        assert encoder_out is not None
+                        encoder_padding_mask = self_attn_padding_mask.new_zeros(
+                            encoder_out.size(1), encoder_out.size(0)
+                        )
+                    # Concatenate encoder pad-mask in front of target pad-mask
+                    self_attn_padding_mask = torch.cat(
+                        (encoder_padding_mask, self_attn_padding_mask), dim=1
+                    )
+                assert encoder_out is not None
+                y = torch.cat((encoder_out, x), dim=0)  # [S+T, B, C]
+
+        # Self-attn; Fairseq never consumes self-attn weights for alignment here
+        x, _ = self.self_attn(
+            query=x,
+            key=y,
+            value=y,
+            key_padding_mask=self_attn_padding_mask,
+            incremental_state=incremental_state,
+            need_weights=False,
+            attn_mask=self_attn_mask,
+        )
+
+        # Optional per-head scaling and attn LayerNorm on self-attn output
+        x = self._maybe_apply_head_scale(x)
+        if self.attn_ln is not None:
+            x = self.attn_ln(x)
+
+        # Residual + (post-norm if applicable)
+        x = residual + x
+        if not self.normalize_before:
+            x = self.self_attn_layer_norm(x)
+
+        # ---- (2) Encoder-Decoder Attention block --------------------------
+        attn_out: Optional[Tensor] = None
+        assert encoder_out is not None
+        residual = x
+        assert self.encoder_attn_layer_norm is not None
+        if self.normalize_before:
+            x = self.encoder_attn_layer_norm(x)
+
+        # Load provided cached cross-attn state
+        if prev_attn_state is not None:
+            prev_key, prev_value = prev_attn_state[:2]
+            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
+            if len(prev_attn_state) >= 3:
+                saved_state["prev_key_padding_mask"] = prev_attn_state[2]
+            assert incremental_state is not None
+            self.encoder_attn._set_input_buffer(incremental_state, saved_state)
+
+        # Cross-attn (static_kv=True to reuse encoder K/V across steps)
+        assert self.encoder_attn is not None
+        x, attn_out = self.encoder_attn(
+            query=x,
+            key=encoder_out,
+            value=encoder_out,
+            key_padding_mask=encoder_padding_mask,
+            incremental_state=incremental_state,
+            static_kv=True,
+            need_weights=need_attn or self.need_attn,
+            need_head_weights=need_head_weights,
+        )
+
+        x = residual + x
+        if not self.normalize_before:
+            x = self.encoder_attn_layer_norm(x)
+
+        # ---- (3) Feed-Forward block --------------------------------------
+        residual = x
+        if self.normalize_before:
+            x = self.final_layer_norm(x)
+
+        # FFN: fc1 -> activation -> (optional LN) -> fc2
+        x = self.fc1(x)
+        x = self.activation_fn(x)  # type: ignore[operator]
+        x = self._fq(x, self.obs_activation_fn)
+        if self.ffn_layernorm is not None:
+            x = self.ffn_layernorm(x)
+        x = self.fc2(x)
+
+        # Optional residual scaling (scale_resids)
+        if self.w_resid is not None:
+            residual = torch.mul(self.w_resid, residual)
+
+        x = residual + x
+        if not self.normalize_before:
+            x = self.final_layer_norm(x)
+
+        # Return attn from encoder-attn branch when requested; self-attn weights are not returned.
+        return x, attn_out, None
+
+    def forward_external(
+        self,
+        x: Tensor,  # [1, B, C] (embedded current-step token)
+        *,
+        encoder_out: Optional[Tensor],  # [S, B, Ce]
+        encoder_padding_mask: Optional[
+            Tensor
+        ] = None,  # [B,S] bool or additive-float or [B,1,S] additive-float
+        prev_self_k: Optional[Tensor] = None,  # [B, H, Tprev, Dh]
+        prev_self_v: Optional[Tensor] = None,  # [B, H, Tprev, Dh]
+        self_attn_mask: Optional[
+            Tensor
+        ] = None,  # [1, 1, S_hist+1] or [B,1,S_hist+1] additive-float
+        need_attn: bool = False,
+        need_head_weights: bool = False,
+    ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor]:
+        """
+        Export-only single-step:
+          Returns (x_out[1,B,C], attn_from_cross, new_self_k[B,H,1,Dh], new_self_v[B,H,1,Dh]).
+        """
+        if need_head_weights:
+            need_attn = True
+
+        assert x.dim() == 3 and x.size(0) == 1, "x must be [1,B,C]"
+        B = x.size(1)
+
+        # ---- Self-Attention (uses MHA return_new_kv) ----------------------
+        x_tbc = x
+        if self.normalize_before:
+            x_tbc = self.self_attn_layer_norm(x_tbc)
+
+        # Provide prev KV via incremental_state so wrapper appends internally
+        incr: Dict[str, Dict[str, Optional[Tensor]]] = {}
+        if prev_self_k is not None and prev_self_v is not None:
+            # Attach observers to incoming caches
+            prev_self_k = self._fq(prev_self_k, self.obs_prev_self_k_in)
+            prev_self_v = self._fq(prev_self_v, self.obs_prev_self_v_in)
+            assert isinstance(prev_self_k, Tensor) and isinstance(prev_self_v, Tensor)
+            saved = {
+                "prev_key": prev_self_k.detach(),
+                "prev_value": prev_self_v.detach(),
+            }
+            self.self_attn._set_input_buffer(incr, saved)  # type: ignore[arg-type]
+
+        # Normalize self-attn additive mask to shapes wrapper accepts: [T,S] or [B,T,S]
+        attn_mask_for_wrapper = None
+        if self_attn_mask is not None:
+            if (
+                self_attn_mask.dim() == 3
+                and self_attn_mask.size(0) == B
+                and self_attn_mask.size(1) == 1
+            ):
+                attn_mask_for_wrapper = self_attn_mask  # [B,1,S]
+            elif (
+                self_attn_mask.dim() == 3
+                and self_attn_mask.size(0) == 1
+                and self_attn_mask.size(1) == 1
+            ):
+                attn_mask_for_wrapper = self_attn_mask[0]  # -> [1,S]
+            elif self_attn_mask.dim() == 2 and self_attn_mask.size(0) == 1:
+                attn_mask_for_wrapper = self_attn_mask  # [1,S]
+            else:
+                raise RuntimeError(
+                    "self_attn_mask must be [1,S] or [B,1,S] additive-float."
+                )
+            attn_mask_for_wrapper = attn_mask_for_wrapper.to(
+                dtype=x_tbc.dtype, device=x_tbc.device
+            )
+
+        x_sa, _, new_k_bh, new_v_bh = self.self_attn(
+            query=x_tbc,
+            key=x_tbc,
+            value=x_tbc,
+            key_padding_mask=None,
+            incremental_state=incr,
+            need_weights=False,
+            attn_mask=attn_mask_for_wrapper,
+            return_new_kv=True,  # <<< NEW: ask wrapper to return this step's K/V
+        )  # x_sa: [1,B,C]; new_k_bh/new_v_bh: [B*H, Tnew, Dh]
+
+        x_sa = self._maybe_apply_head_scale(x_sa)
+        if self.attn_ln is not None:
+            x_sa = self.attn_ln(x_sa)
+
+        x_tbc = x_tbc + x_sa
+        if not self.normalize_before:
+            x_tbc = self.self_attn_layer_norm(x_tbc)
+
+        # ---- Encoder-Decoder Attention -----------------------------------
+        assert encoder_out is not None, "encoder_out is required in export path"
+        residual = x_tbc
+        if self.normalize_before:
+            assert self.encoder_attn_layer_norm is not None
+            x_tbc = self.encoder_attn_layer_norm(x_tbc)
+
+        enc_kpm = encoder_padding_mask  # pass-through; wrapper handles bool/additive
+        x_ed, attn_out = self.encoder_attn(
+            query=x_tbc,
+            key=encoder_out,
+            value=encoder_out,
+            key_padding_mask=enc_kpm,
+            incremental_state=None,
+            static_kv=True,
+            need_weights=need_attn,
+            need_head_weights=need_head_weights,
+        )
+
+        x_tbc = residual + x_ed
+        if not self.normalize_before:
+            assert self.encoder_attn_layer_norm is not None
+            x_tbc = self.encoder_attn_layer_norm(x_tbc)
+
+        # ---- Feed-Forward -------------------------------------------------
+        residual = x_tbc
+        if self.normalize_before:
+            x_tbc = self.final_layer_norm(x_tbc)
+
+        x_tbc = self.fc1(x_tbc)
+        x_tbc = self.activation_fn(x_tbc)  # type: ignore[operator]
+        x_tbc = self._fq(x_tbc, self.obs_activation_fn)
+        if self.ffn_layernorm is not None:
+            x_tbc = self.ffn_layernorm(x_tbc)
+        x_tbc = self.fc2(x_tbc)
+
+        if self.w_resid is not None:
+            residual = torch.mul(self.w_resid, residual)
+
+        x_tbc = residual + x_tbc
+        if not self.normalize_before:
+            x_tbc = self.final_layer_norm(x_tbc)
+
+        return (
+            x_tbc,
+            attn_out,
+            new_k_bh,
+            new_v_bh,
+        )  # [1,B,C], attn, [B*H, Tnew, Dh], [B*H, Tnew, Dh]
+
+    def _all_observers(self) -> Iterable:
+        """
+        Expose all observers from child PTQ-wrapped modules.
+        This layer itself does not add extra per-tensor observers.
+        """
+        # local observers
+        yield from (
+            self.obs_activation_fn,
+            self.obs_prev_self_k_in,
+            self.obs_prev_self_v_in,
+        )
+
+        for m in (
+            self.self_attn,
+            self.encoder_attn,
+            self.fc1,
+            self.fc2,
+            self.encoder_attn_layer_norm,
+            self.self_attn_layer_norm,
+            self.final_layer_norm,
+        ):
+            if isinstance(m, QuantModuleBase) and m is not None:
+                yield from m._all_observers()
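Usage note (not part of the diff): a minimal sketch of how the new wrapper is meant to be driven, assuming fairseq is installed and `fp_layer` is an existing `TransformerDecoderLayerBase` instance (e.g. `model.decoder.layers[0]` from a loaded fairseq translation model); `fp_name` and the shape constants below are illustrative only. It exercises the tensor contract documented in the class docstring: inputs and outputs use the fairseq [T, B, C] layout and forward returns (x, attn, None).

import torch
from tico.experimental.quantization.ptq.wrappers.fairseq.quant_decoder_layer import (
    QuantFairseqDecoderLayer,
)

# fp_layer: an existing fairseq TransformerDecoderLayerBase (placeholder here),
# e.g. fp_layer = model.decoder.layers[0] from a loaded fairseq transformer.
qlayer = QuantFairseqDecoderLayer(fp_layer, qcfg=None, fp_name="decoder.layers.0")
qlayer.eval()

T, S, B, C = 4, 7, 2, qlayer.embed_dim         # target len, source len, batch, model dim
x = torch.randn(T, B, C)                       # decoder input in fairseq [T, B, C] layout
enc = torch.randn(S, B, C)                     # encoder output [S, B, C]
enc_pad = torch.zeros(B, S, dtype=torch.bool)  # no padded source positions

with torch.no_grad():
    y, attn, _ = qlayer(x, encoder_out=enc, encoder_padding_mask=enc_pad)
assert y.shape == (T, B, C)                    # same [T, B, C] contract as the FP layer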
tico/experimental/quantization/ptq/wrappers/registry.py
@@ -33,6 +33,7 @@ _CORE_MODULES = (
     "tico.experimental.quantization.ptq.wrappers.llama.quant_decoder_layer",
     "tico.experimental.quantization.ptq.wrappers.llama.quant_mlp",
     # fairseq
+    "tico.experimental.quantization.ptq.wrappers.fairseq.quant_decoder_layer",
     "tico.experimental.quantization.ptq.wrappers.fairseq.quant_encoder",
     "tico.experimental.quantization.ptq.wrappers.fairseq.quant_encoder_layer",
     "tico.experimental.quantization.ptq.wrappers.fairseq.quant_mha",
tico/passes/convert_matmul_to_linear.py (new file)
@@ -0,0 +1,200 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
+from torch.export import ExportedProgram
+
+from tico.utils import logging
+from tico.utils.graph import create_node
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import MatmulArgs
+
+
+class Converter:  # type: ignore[empty-body]
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:  # type: ignore[empty-body]
+        return False
+
+    def convert(self, exported_program, node) -> torch.fx.Node:  # type: ignore[empty-body]
+        pass
+
+
+class MatmulToLinearConverter(Converter):
+    def __init__(self):
+        super().__init__()
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+
+        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        lhs = mm_args.input
+        rhs = mm_args.other
+
+        with graph.inserting_before(node):
+            transpose_node = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(rhs, [1, 0]),
+            )
+            fc_node = create_node(
+                graph,
+                torch.ops.aten.linear.default,
+                args=(lhs, transpose_node),
+            )
+            node.replace_all_uses_with(fc_node, propagate_meta=True)
+
+        return fc_node
+
+
+class RhsConstMatmulToLinearConverter(MatmulToLinearConverter):
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:
+        if not node.target == torch.ops.aten.mm.default:
+            return False
+
+        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        rhs = mm_args.other
+        if isinstance(rhs, torch.fx.Node):
+            if is_lifted_tensor_constant(exported_program, rhs):
+                return True
+            elif is_param(exported_program, rhs):
+                return True
+            elif is_buffer(exported_program, rhs):
+                return True
+            else:
+                return False
+        return False
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        return super().convert(exported_program, node)
+
+
+class LhsConstMatmulToLinearConverter(MatmulToLinearConverter):
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:
+        if not node.target == torch.ops.aten.mm.default:
+            return False
+
+        mm_args = MatmulArgs(*node.args, **node.kwargs)
+        lhs = mm_args.input
+        if isinstance(lhs, torch.fx.Node):
+            if is_lifted_tensor_constant(exported_program, lhs):
+                return True
+            elif is_param(exported_program, lhs):
+                return True
+            elif is_buffer(exported_program, lhs):
+                return True
+            else:
+                return False
+        return False
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        return super().convert(exported_program, node)
+
+
+@trace_graph_diff_on_pass
+class ConvertMatmulToLinear(PassBase):
+    """
+    This pass selectively converts matmul to linear.
+
+    How to select between `matmul` and `linear`?
+
+    * Linear has better quantization accuracy (NPU backend)
+      Due to the ONE compiler's quantization policy:
+      FullyConnected(=Linear) uses per-channel quantization for weight and per-tensor for input.
+      BatchMatmul(=matmul) uses per-tensor quantization for both rhs and lhs.
+
+    * Matmul to Linear requires Transpose, which may harm latency
+      When RHS is constant, the additional transpose can be folded.
+
+    [RHS non-const case]
+    Constant folding cannot be performed.
+
+        lhs   rhs (non-const)
+         |     |
+         |  transpose
+         |     |
+         -- linear --
+              |
+             out
+
+    [RHS const case]
+    Constant folding can be performed to
+
+        lhs   rhs (const)          lhs   rhs (folded const)
+         |     |                    |     |
+         |  transpose               |     |
+         |     |                    |     |
+         -- linear --      -->      -- linear --
+              |                          |
+             out                        out
+
+
+    enable_lhs_const: If true, convert matmul where the LHS is a constant tensor. Default is False.
+    enable_rhs_const: If true, convert matmul where the RHS is a constant tensor. Default is True.
+    """
+
+    def __init__(
+        self,
+        enable_lhs_const: Optional[bool] = False,
+        enable_rhs_const: Optional[bool] = True,
+    ):
+        super().__init__()
+        self.converters: List[Converter] = []
+        if enable_lhs_const:
+            self.converters.append(LhsConstMatmulToLinearConverter())
+        if enable_rhs_const:
+            self.converters.append(RhsConstMatmulToLinearConverter())
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        modified = False
+        for node in graph.nodes:
+            if not node.op == "call_function":
+                continue
+
+            for converter in self.converters:
+                if not converter.match(exported_program, node):
+                    continue
+
+                new_node = converter.convert(exported_program, node)
+                modified = True
+                logger.debug(
+                    f"{node.name} is replaced with {new_node.name} operator (permute + linear)"
+                )
+                continue
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        return PassResult(modified)
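To make the folding argument above concrete, here is a standalone sketch (plain PyTorch only, not code from the package) of the identity the converters rely on: an `aten.mm` with a constant right-hand side computes the same result as `aten.linear` fed the transposed constant, so the `permute` the pass inserts can later be constant-folded into the weight.

import torch
import torch.nn.functional as F

lhs = torch.randn(4, 8)    # activation [M, K]
rhs = torch.randn(8, 16)   # constant weight [K, N]

out_mm = torch.mm(lhs, rhs)                    # BatchMatmul path: per-tensor quantization on both operands
out_linear = F.linear(lhs, rhs.permute(1, 0))  # FullyConnected path: per-channel quantization on the weight

# Identical results, so rewriting mm -> permute + linear preserves semantics.
torch.testing.assert_close(out_mm, out_linear)

In the pass itself the same rewrite is applied to the exported FX graph: a `permute([1, 0])` and a `linear` node are created in front of the matched `aten.mm`, all uses are redirected to the new `linear`, and the now-dead `mm` is removed by `graph.eliminate_dead_code()`.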