tico 0.1.0.dev250916__py3-none-any.whl → 0.1.0.dev250918__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of tico has been flagged as potentially problematic.

@@ -0,0 +1,494 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # -----------------------------------------------------------------------------
16
+ # This file includes modifications based on fairseq
17
+ # (https://github.com/facebookresearch/fairseq), originally licensed under
18
+ # the MIT License. See the LICENSE file in the fairseq repository for details.
19
+ # -----------------------------------------------------------------------------
20
+
21
+ from typing import Dict, Iterable, List, Optional, Tuple
22
+
23
+ import torch
24
+ from torch import nn, Tensor
25
+
26
+ from tico.experimental.quantization.ptq.quant_config import QuantConfig
27
+ from tico.experimental.quantization.ptq.wrappers.fairseq.quant_mha import (
28
+ QuantFairseqMultiheadAttention,
29
+ )
30
+ from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
31
+ from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
32
+ QuantModuleBase,
33
+ )
34
+ from tico.experimental.quantization.ptq.wrappers.registry import try_register
35
+
36
+
37
+ @try_register("fairseq.modules.transformer_layer.TransformerDecoderLayerBase")
38
+ class QuantFairseqDecoderLayer(QuantModuleBase):
39
+ """
40
+ Quant-aware drop-in replacement for Fairseq TransformerDecoderLayerBase.
41
+
42
+ Design (inference-only):
43
+ - Keep LayerNorms and scalar head/residual scalers in FP.
44
+ - PTQ-wrap: self_attn, (optional) encoder_attn, fc1, fc2.
45
+ - Preserve Fairseq tensor contracts and incremental state handling.
46
+ - Remove training-time behaviors: dropout, activation-dropout, quant-noise, onnx_trace.
47
+
48
+ I/O:
49
+ - Input/Output use Fairseq shapes: [T, B, C].
50
+ - Forward returns: (x, attn, None) to match the original call sites in the decoder.
51
+ * `attn` is from encoder-attention when requested (alignment).
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ fp_layer: nn.Module,
57
+ *,
58
+ qcfg: Optional[QuantConfig] = None,
59
+ fp_name: Optional[str] = None,
60
+ ):
61
+ super().__init__(qcfg, fp_name=fp_name)
62
+
63
+ # --- read-only metadata copied from FP layer -----------------------
64
+ assert hasattr(fp_layer, "embed_dim")
65
+ assert hasattr(fp_layer, "normalize_before")
66
+ self.embed_dim: int = int(fp_layer.embed_dim) # type: ignore[arg-type]
67
+ self.normalize_before: bool = bool(fp_layer.normalize_before)
68
+
69
+ # Cross-self attention flag (when True, key/value can include encoder_out)
70
+ self.cross_self_attention: bool = bool(
71
+ getattr(fp_layer, "cross_self_attention", False)
72
+ )
73
+
74
+ # Generate prefix
75
+ def _safe_prefix(name: Optional[str]) -> str:
76
+ # Avoid "None.*" strings causing collisions
77
+ return (
78
+ name
79
+ if (name is not None and name != "None" and name != "")
80
+ else f"{self.__class__.__name__}_{id(self)}"
81
+ )
82
+
83
+ prefix = _safe_prefix(fp_name)
84
+ # Self-attn (PTQ) ---------------------------------------------------
85
+ # Use our MHA wrapper with identical API to the FP module.
86
+ attn_cfg = qcfg.child("self_attn") if qcfg else None
87
+ assert hasattr(fp_layer, "self_attn") and isinstance(
88
+ fp_layer.self_attn, nn.Module
89
+ )
90
+ self.self_attn = QuantFairseqMultiheadAttention(
91
+ fp_layer.self_attn, qcfg=attn_cfg, fp_name=f"{prefix}.self_attn"
92
+ )
93
+
94
+ # Optional attention LayerNorm applied to self-attn output (scale_attn)
95
+ # Kept in FP; reuse original instance for weight parity.
96
+ self.attn_ln = getattr(fp_layer, "attn_ln", None)
97
+
98
+ # Optional per-head scaling after self-attn output (scale_heads)
99
+ # Keep exact Parameter reference if present (shape: [num_heads])
100
+ self.c_attn = getattr(fp_layer, "c_attn", None)
101
+
102
+ # Cache head meta for c_attn path
103
+ self.nh = int(getattr(self.self_attn, "num_heads"))
104
+ self.head_dim = int(getattr(self.self_attn, "head_dim"))
105
+
106
+ # Encoder-attn (PTQ) ------------------------------------------------
107
+ # The FP layer has encoder_attn only when built with cross-attention; this wrapper requires it.
108
+ enc_attn_mod = getattr(fp_layer, "encoder_attn", None)
109
+ assert enc_attn_mod is not None
110
+ enc_cfg = qcfg.child("encoder_attn") if qcfg else None
111
+ self.encoder_attn = QuantFairseqMultiheadAttention(
112
+ enc_attn_mod, qcfg=enc_cfg, fp_name=f"{prefix}.encoder_attn"
113
+ )
114
+
115
+ # Feed-forward (PTQ) ------------------------------------------------
116
+ fc1_cfg = qcfg.child("fc1") if qcfg else None
117
+ fc2_cfg = qcfg.child("fc2") if qcfg else None
118
+ assert hasattr(fp_layer, "fc1") and isinstance(fp_layer.fc1, nn.Module)
119
+ assert hasattr(fp_layer, "fc2") and isinstance(fp_layer.fc2, nn.Module)
120
+ self.fc1 = PTQWrapper(fp_layer.fc1, qcfg=fc1_cfg, fp_name=f"{prefix}.fc1")
121
+ self.fc2 = PTQWrapper(fp_layer.fc2, qcfg=fc2_cfg, fp_name=f"{prefix}.fc2")
122
+
123
+ # LayerNorms
124
+ enc_attn_ln_cfg = qcfg.child("encoder_attn_layer_norm") if qcfg else None
125
+ attn_ln_cfg = qcfg.child("self_attn_layer_norm") if qcfg else None
126
+ final_ln_cfg = qcfg.child("final_layer_norm") if qcfg else None
127
+ assert hasattr(fp_layer, "encoder_attn_layer_norm") and isinstance(
128
+ fp_layer.encoder_attn_layer_norm, nn.Module
129
+ )
130
+ assert hasattr(fp_layer, "self_attn_layer_norm") and isinstance(
131
+ fp_layer.self_attn_layer_norm, nn.Module
132
+ )
133
+ assert hasattr(fp_layer, "final_layer_norm") and isinstance(
134
+ fp_layer.final_layer_norm, nn.Module
135
+ )
136
+ self.encoder_attn_layer_norm = PTQWrapper(
137
+ fp_layer.encoder_attn_layer_norm,
138
+ qcfg=enc_attn_ln_cfg,
139
+ fp_name=f"{fp_name}.encoder_attn_layer_norm",
140
+ )
141
+ self.self_attn_layer_norm = PTQWrapper(
142
+ fp_layer.self_attn_layer_norm,
143
+ qcfg=attn_ln_cfg,
144
+ fp_name=f"{fp_name}.self_attn_layer_norm",
145
+ )
146
+ self.final_layer_norm = PTQWrapper(
147
+ fp_layer.final_layer_norm,
148
+ qcfg=final_ln_cfg,
149
+ fp_name=f"{fp_name}.final_layer_norm",
150
+ )
151
+
152
+ # Optional FFN intermediate LayerNorm (scale_fc), FP
153
+ self.ffn_layernorm = getattr(fp_layer, "ffn_layernorm", None)
154
+
155
+ # Optional residual scaling (scale_resids), keep Parameter reference
156
+ self.w_resid = getattr(fp_layer, "w_resid", None)
157
+
158
+ # Activation function
159
+ self.activation_fn = fp_layer.activation_fn # type: ignore[operator]
160
+ self.obs_activation_fn = self._make_obs("activation_fn")
161
+
162
+ # Alignment flag used by Fairseq (kept for API parity)
163
+ self.need_attn: bool = bool(getattr(fp_layer, "need_attn", True))
164
+
165
+ # No dropout / activation-dropout in inference wrapper
166
+ # (intentionally omitted)
167
+
168
+ # --- observers for external/self-attn KV cache inputs --------------
169
+ self.obs_prev_self_k_in = self._make_obs("prev_self_k_in")
170
+ self.obs_prev_self_v_in = self._make_obs("prev_self_v_in")
171
+
172
+ # ----------------------------------------------------------------------
173
+ def _maybe_apply_head_scale(self, x: Tensor) -> Tensor:
174
+ """
175
+ Optional per-head scaling (scale_heads) after self-attention.
176
+ x: [T, B, C]
177
+ """
178
+ if self.c_attn is None:
179
+ return x
180
+ T, B, _ = x.shape
181
+ x = x.view(T, B, self.nh, self.head_dim) # [T,B,H,Dh]
182
+ # einsum over head dim: scales each head independently
183
+ x = torch.einsum("tbhd,h->tbhd", x, self.c_attn) # [T,B,H,Dh]
184
+ return x.reshape(T, B, self.nh * self.head_dim) # [T,B,C]
185
+
186
+ # ----------------------------------------------------------------------
187
+ def forward(
188
+ self,
189
+ x: Tensor, # [T,B,C]
190
+ encoder_out: Optional[Tensor] = None, # [S,B,Ce] or None
191
+ encoder_padding_mask: Optional[Tensor] = None, # [B,S] bool or additive float
192
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
193
+ prev_self_attn_state: Optional[List[Tensor]] = None,
194
+ prev_attn_state: Optional[List[Tensor]] = None,
195
+ self_attn_mask: Optional[Tensor] = None, # [T,T] or [B,T,T] or None
196
+ self_attn_padding_mask: Optional[Tensor] = None, # [B,T] or [B,T,T] or None
197
+ need_attn: bool = False,
198
+ need_head_weights: bool = False,
199
+ ) -> Tuple[Tensor, Optional[Tensor], None]:
200
+ """
201
+ Mirrors the original forward, minus training-only logic.
202
+ Returns:
203
+ x': [T,B,C], attn (from encoder-attn when requested), None
204
+ """
205
+ if need_head_weights:
206
+ need_attn = True
207
+
208
+ # ---- (1) Self-Attention block ------------------------------------
209
+ residual = x
210
+ if self.normalize_before:
211
+ x = self.self_attn_layer_norm(x)
212
+
213
+ # Load provided cached self-attn state (for incremental decoding)
214
+ if prev_self_attn_state is not None:
215
+ prev_key, prev_value = prev_self_attn_state[:2]
216
+ saved_state: Dict[str, Optional[Tensor]] = {
217
+ "prev_key": prev_key,
218
+ "prev_value": prev_value,
219
+ }
220
+ if len(prev_self_attn_state) >= 3:
221
+ saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
222
+ assert incremental_state is not None
223
+ self.self_attn._set_input_buffer(incremental_state, saved_state)
224
+
225
+ # Cross-self-attention: prepend encoder_out to K/V at the first step
226
+ y = x
227
+ if self.cross_self_attention:
228
+ _buf = self.self_attn._get_input_buffer(incremental_state)
229
+ no_cache_yet = not (
230
+ incremental_state is not None
231
+ and _buf is not None
232
+ and "prev_key" in _buf
233
+ )
234
+ if no_cache_yet:
235
+ if self_attn_mask is not None:
236
+ assert encoder_out is not None
237
+ # Grow attn mask to cover encoder timesteps (no autoregressive penalty for them)
238
+ self_attn_mask = torch.cat(
239
+ (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask),
240
+ dim=1,
241
+ )
242
+ if self_attn_padding_mask is not None:
243
+ if encoder_padding_mask is None:
244
+ assert encoder_out is not None
245
+ encoder_padding_mask = self_attn_padding_mask.new_zeros(
246
+ encoder_out.size(1), encoder_out.size(0)
247
+ )
248
+ # Concatenate encoder pad-mask in front of target pad-mask
249
+ self_attn_padding_mask = torch.cat(
250
+ (encoder_padding_mask, self_attn_padding_mask), dim=1
251
+ )
252
+ assert encoder_out is not None
253
+ y = torch.cat((encoder_out, x), dim=0) # [S+T, B, C]
254
+
255
+ # Self-attn; Fairseq never consumes self-attn weights for alignment here
256
+ x, _ = self.self_attn(
257
+ query=x,
258
+ key=y,
259
+ value=y,
260
+ key_padding_mask=self_attn_padding_mask,
261
+ incremental_state=incremental_state,
262
+ need_weights=False,
263
+ attn_mask=self_attn_mask,
264
+ )
265
+
266
+ # Optional per-head scaling and attn LayerNorm on self-attn output
267
+ x = self._maybe_apply_head_scale(x)
268
+ if self.attn_ln is not None:
269
+ x = self.attn_ln(x)
270
+
271
+ # Residual + (post-norm if applicable)
272
+ x = residual + x
273
+ if not self.normalize_before:
274
+ x = self.self_attn_layer_norm(x)
275
+
276
+ # ---- (2) Encoder-Decoder Attention block --------------------------
277
+ attn_out: Optional[Tensor] = None
278
+ assert encoder_out is not None
279
+ residual = x
280
+ assert self.encoder_attn_layer_norm is not None
281
+ if self.normalize_before:
282
+ x = self.encoder_attn_layer_norm(x)
283
+
284
+ # Load provided cached cross-attn state
285
+ if prev_attn_state is not None:
286
+ prev_key, prev_value = prev_attn_state[:2]
287
+ saved_state = {"prev_key": prev_key, "prev_value": prev_value}
288
+ if len(prev_attn_state) >= 3:
289
+ saved_state["prev_key_padding_mask"] = prev_attn_state[2]
290
+ assert incremental_state is not None
291
+ self.encoder_attn._set_input_buffer(incremental_state, saved_state)
292
+
293
+ # Cross-attn (static_kv=True to reuse encoder K/V across steps)
294
+ assert self.encoder_attn is not None
295
+ x, attn_out = self.encoder_attn(
296
+ query=x,
297
+ key=encoder_out,
298
+ value=encoder_out,
299
+ key_padding_mask=encoder_padding_mask,
300
+ incremental_state=incremental_state,
301
+ static_kv=True,
302
+ need_weights=need_attn or self.need_attn,
303
+ need_head_weights=need_head_weights,
304
+ )
305
+
306
+ x = residual + x
307
+ if not self.normalize_before:
308
+ x = self.encoder_attn_layer_norm(x)
309
+
310
+ # ---- (3) Feed-Forward block --------------------------------------
311
+ residual = x
312
+ if self.normalize_before:
313
+ x = self.final_layer_norm(x)
314
+
315
+ # FFN: fc1 -> activation -> (optional LN) -> fc2
316
+ x = self.fc1(x)
317
+ x = self.activation_fn(x) # type: ignore[operator]
318
+ x = self._fq(x, self.obs_activation_fn)
319
+ if self.ffn_layernorm is not None:
320
+ x = self.ffn_layernorm(x)
321
+ x = self.fc2(x)
322
+
323
+ # Optional residual scaling (scale_resids)
324
+ if self.w_resid is not None:
325
+ residual = torch.mul(self.w_resid, residual)
326
+
327
+ x = residual + x
328
+ if not self.normalize_before:
329
+ x = self.final_layer_norm(x)
330
+
331
+ # Return attn from encoder-attn branch when requested; self-attn weights are not returned.
332
+ return x, attn_out, None
333
+
334
+ def forward_external(
335
+ self,
336
+ x: Tensor, # [1, B, C] (embedded current-step token)
337
+ *,
338
+ encoder_out: Optional[Tensor], # [S, B, Ce]
339
+ encoder_padding_mask: Optional[
340
+ Tensor
341
+ ] = None, # [B,S] bool or additive-float or [B,1,S] additive-float
342
+ prev_self_k: Optional[Tensor] = None, # [B, H, Tprev, Dh]
343
+ prev_self_v: Optional[Tensor] = None, # [B, H, Tprev, Dh]
344
+ self_attn_mask: Optional[
345
+ Tensor
346
+ ] = None, # [1, 1, S_hist+1] or [B,1,S_hist+1] additive-float
347
+ need_attn: bool = False,
348
+ need_head_weights: bool = False,
349
+ ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor]:
350
+ """
351
+ Export-only single-step:
352
+ Returns (x_out[1,B,C], attn_from_cross, new_self_k[B,H,1,Dh], new_self_v[B,H,1,Dh]).
353
+ """
354
+ if need_head_weights:
355
+ need_attn = True
356
+
357
+ assert x.dim() == 3 and x.size(0) == 1, "x must be [1,B,C]"
358
+ B = x.size(1)
359
+
360
+ # ---- Self-Attention (uses MHA return_new_kv) ----------------------
361
+ x_tbc = x
362
+ if self.normalize_before:
363
+ x_tbc = self.self_attn_layer_norm(x_tbc)
364
+
365
+ # Provide prev KV via incremental_state so wrapper appends internally
366
+ incr: Dict[str, Dict[str, Optional[Tensor]]] = {}
367
+ if prev_self_k is not None and prev_self_v is not None:
368
+ # Attach observers to incoming caches
369
+ prev_self_k = self._fq(prev_self_k, self.obs_prev_self_k_in)
370
+ prev_self_v = self._fq(prev_self_v, self.obs_prev_self_v_in)
371
+ assert isinstance(prev_self_k, Tensor) and isinstance(prev_self_v, Tensor)
372
+ saved = {
373
+ "prev_key": prev_self_k.detach(),
374
+ "prev_value": prev_self_v.detach(),
375
+ }
376
+ self.self_attn._set_input_buffer(incr, saved) # type: ignore[arg-type]
377
+
378
+ # Normalize self-attn additive mask to shapes wrapper accepts: [T,S] or [B,T,S]
379
+ attn_mask_for_wrapper = None
380
+ if self_attn_mask is not None:
381
+ if (
382
+ self_attn_mask.dim() == 3
383
+ and self_attn_mask.size(0) == B
384
+ and self_attn_mask.size(1) == 1
385
+ ):
386
+ attn_mask_for_wrapper = self_attn_mask # [B,1,S]
387
+ elif (
388
+ self_attn_mask.dim() == 3
389
+ and self_attn_mask.size(0) == 1
390
+ and self_attn_mask.size(1) == 1
391
+ ):
392
+ attn_mask_for_wrapper = self_attn_mask[0] # -> [1,S]
393
+ elif self_attn_mask.dim() == 2 and self_attn_mask.size(0) == 1:
394
+ attn_mask_for_wrapper = self_attn_mask # [1,S]
395
+ else:
396
+ raise RuntimeError(
397
+ "self_attn_mask must be [1,S] or [B,1,S] additive-float."
398
+ )
399
+ attn_mask_for_wrapper = attn_mask_for_wrapper.to(
400
+ dtype=x_tbc.dtype, device=x_tbc.device
401
+ )
402
+
403
+ x_sa, _, new_k_bh, new_v_bh = self.self_attn(
404
+ query=x_tbc,
405
+ key=x_tbc,
406
+ value=x_tbc,
407
+ key_padding_mask=None,
408
+ incremental_state=incr,
409
+ need_weights=False,
410
+ attn_mask=attn_mask_for_wrapper,
411
+ return_new_kv=True, # <<< NEW: ask wrapper to return this step's K/V
412
+ ) # x_sa: [1,B,C]; new_k_bh/new_v_bh: [B*H, Tnew, Dh]
413
+
414
+ x_sa = self._maybe_apply_head_scale(x_sa)
415
+ if self.attn_ln is not None:
416
+ x_sa = self.attn_ln(x_sa)
417
+
418
+ x_tbc = x + x_sa  # residual uses the pre-norm input x, mirroring forward()
419
+ if not self.normalize_before:
420
+ x_tbc = self.self_attn_layer_norm(x_tbc)
421
+
422
+ # ---- Encoder-Decoder Attention -----------------------------------
423
+ assert encoder_out is not None, "encoder_out is required in export path"
424
+ residual = x_tbc
425
+ if self.normalize_before:
426
+ assert self.encoder_attn_layer_norm is not None
427
+ x_tbc = self.encoder_attn_layer_norm(x_tbc)
428
+
429
+ enc_kpm = encoder_padding_mask # pass-through; wrapper handles bool/additive
430
+ x_ed, attn_out = self.encoder_attn(
431
+ query=x_tbc,
432
+ key=encoder_out,
433
+ value=encoder_out,
434
+ key_padding_mask=enc_kpm,
435
+ incremental_state=None,
436
+ static_kv=True,
437
+ need_weights=need_attn,
438
+ need_head_weights=need_head_weights,
439
+ )
440
+
441
+ x_tbc = residual + x_ed
442
+ if not self.normalize_before:
443
+ assert self.encoder_attn_layer_norm is not None
444
+ x_tbc = self.encoder_attn_layer_norm(x_tbc)
445
+
446
+ # ---- Feed-Forward -------------------------------------------------
447
+ residual = x_tbc
448
+ if self.normalize_before:
449
+ x_tbc = self.final_layer_norm(x_tbc)
450
+
451
+ x_tbc = self.fc1(x_tbc)
452
+ x_tbc = self.activation_fn(x_tbc) # type: ignore[operator]
453
+ x_tbc = self._fq(x_tbc, self.obs_activation_fn)
454
+ if self.ffn_layernorm is not None:
455
+ x_tbc = self.ffn_layernorm(x_tbc)
456
+ x_tbc = self.fc2(x_tbc)
457
+
458
+ if self.w_resid is not None:
459
+ residual = torch.mul(self.w_resid, residual)
460
+
461
+ x_tbc = residual + x_tbc
462
+ if not self.normalize_before:
463
+ x_tbc = self.final_layer_norm(x_tbc)
464
+
465
+ return (
466
+ x_tbc,
467
+ attn_out,
468
+ new_k_bh,
469
+ new_v_bh,
470
+ ) # [1,B,C], attn, [B*H, Tnew, Dh], [B*H, Tnew, Dh]
471
+
472
+ def _all_observers(self) -> Iterable:
473
+ """
474
+ Yield this layer's local observers (activation output, incoming self-attn KV cache),
475
+ then expose all observers from the child PTQ-wrapped modules.
476
+ """
477
+ # local observers
478
+ yield from (
479
+ self.obs_activation_fn,
480
+ self.obs_prev_self_k_in,
481
+ self.obs_prev_self_v_in,
482
+ )
483
+
484
+ for m in (
485
+ self.self_attn,
486
+ self.encoder_attn,
487
+ self.fc1,
488
+ self.fc2,
489
+ self.encoder_attn_layer_norm,
490
+ self.self_attn_layer_norm,
491
+ self.final_layer_norm,
492
+ ):
493
+ if isinstance(m, QuantModuleBase):
494
+ yield from m._all_observers()
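
A minimal usage sketch for the new decoder-layer wrapper (illustrative only: it assumes `fp_layer` is an already-constructed fairseq TransformerDecoderLayerBase with encoder_attn, and the single calibration call shown is not the package's prescribed workflow):

    import torch

    from tico.experimental.quantization.ptq.wrappers.fairseq.quant_decoder_layer import (
        QuantFairseqDecoderLayer,
    )

    # fp_layer: fairseq TransformerDecoderLayerBase built elsewhere (assumption).
    qlayer = QuantFairseqDecoderLayer(fp_layer, fp_name="decoder.layers.0").eval()

    # Fairseq tensor contract: x is [T, B, C], encoder_out is [S, B, C].
    T, S, B, C = 4, 6, 2, qlayer.embed_dim
    x = torch.randn(T, B, C)
    encoder_out = torch.randn(S, B, C)
    with torch.no_grad():
        y, attn, _ = qlayer(x, encoder_out=encoder_out)  # observers see one calibration batch
    assert y.shape == (T, B, C)
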
@@ -33,6 +33,7 @@ _CORE_MODULES = (
33
33
  "tico.experimental.quantization.ptq.wrappers.llama.quant_decoder_layer",
34
34
  "tico.experimental.quantization.ptq.wrappers.llama.quant_mlp",
35
35
  # fairseq
36
+ "tico.experimental.quantization.ptq.wrappers.fairseq.quant_decoder_layer",
36
37
  "tico.experimental.quantization.ptq.wrappers.fairseq.quant_encoder",
37
38
  "tico.experimental.quantization.ptq.wrappers.fairseq.quant_encoder_layer",
38
39
  "tico.experimental.quantization.ptq.wrappers.fairseq.quant_mha",
@@ -0,0 +1,200 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Optional, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
21
+ from torch.export import ExportedProgram
22
+
23
+ from tico.utils import logging
24
+ from tico.utils.graph import create_node
25
+ from tico.utils.passes import PassBase, PassResult
26
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
27
+ from tico.utils.validate_args_kwargs import MatmulArgs
28
+
29
+
30
+ class Converter: # type: ignore[empty-body]
31
+ def __init__(self):
32
+ super().__init__()
33
+
34
+ def match(self, exported_program, node) -> bool: # type: ignore[empty-body]
35
+ return False
36
+
37
+ def convert(self, exported_program, node) -> torch.fx.Node: # type: ignore[empty-body]
38
+ pass
39
+
40
+
41
+ class MatmulToLinearConverter(Converter):
42
+ def __init__(self):
43
+ super().__init__()
44
+
45
+ def convert(self, exported_program, node) -> torch.fx.Node:
46
+ graph_module = exported_program.graph_module
47
+ graph = graph_module.graph
48
+
49
+ mm_args = MatmulArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
50
+
51
+ lhs = mm_args.input
52
+ rhs = mm_args.other
53
+
54
+ with graph.inserting_before(node):
55
+ transpose_node = create_node(
56
+ graph,
57
+ torch.ops.aten.permute.default,
58
+ args=(rhs, [1, 0]),
59
+ )
60
+ fc_node = create_node(
61
+ graph,
62
+ torch.ops.aten.linear.default,
63
+ args=(lhs, transpose_node),
64
+ )
65
+ node.replace_all_uses_with(fc_node, propagate_meta=True)
66
+
67
+ return fc_node
68
+
69
+
70
+ class RhsConstMatmulToLinearConverter(MatmulToLinearConverter):
71
+ def __init__(self):
72
+ super().__init__()
73
+
74
+ def match(self, exported_program, node) -> bool:
75
+ if node.target != torch.ops.aten.mm.default:
76
+ return False
77
+
78
+ mm_args = MatmulArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
79
+
80
+ rhs = mm_args.other
81
+ if isinstance(rhs, torch.fx.Node):
82
+ if is_lifted_tensor_constant(exported_program, rhs):
83
+ return True
84
+ elif is_param(exported_program, rhs):
85
+ return True
86
+ elif is_buffer(exported_program, rhs):
87
+ return True
88
+ else:
89
+ return False
90
+ return False
91
+
92
+ def convert(self, exported_program, node) -> torch.fx.Node:
93
+ return super().convert(exported_program, node)
94
+
95
+
96
+ class LhsConstMatmulToLinearConverter(MatmulToLinearConverter):
97
+ def __init__(self):
98
+ super().__init__()
99
+
100
+ def match(self, exported_program, node) -> bool:
101
+ if node.target != torch.ops.aten.mm.default:
102
+ return False
103
+
104
+ mm_args = MatmulArgs(*node.args, **node.kwargs)
105
+ lhs = mm_args.input
106
+ if isinstance(lhs, torch.fx.Node):
107
+ if is_lifted_tensor_constant(exported_program, lhs):
108
+ return True
109
+ elif is_param(exported_program, lhs):
110
+ return True
111
+ elif is_buffer(exported_program, lhs):
112
+ return True
113
+ else:
114
+ return False
115
+ return False
116
+
117
+ def convert(self, exported_program, node) -> torch.fx.Node:
118
+ return super().convert(exported_program, node)
119
+
120
+
121
+ @trace_graph_diff_on_pass
122
+ class ConvertMatmulToLinear(PassBase):
123
+ """
124
+ This pass selectively converts matmul to linear.
125
+
126
+ How to select between `matmul` and `linear`?
127
+
128
+ * Linear has better quantization accuracy (NPU backend)
129
+ Due to the ONE compiler's quantization policy:
130
+ FullyConnected(=Linear) uses per-channel quantization for weight and per-tensor for input.
131
+ BatchMatmul(=matmul) uses per-tensor quantization for both rhs and lhs.
132
+
133
+ * Matmul to Linear requires Transpose, which may harm latency
134
+ When the RHS is constant, the additional transpose can be constant-folded.
135
+
136
+ [RHS non-const case]
137
+ Constant folding cannot be performed.
138
+
139
+ lhs rhs (non-const)
140
+ | |
141
+ | transpose
142
+ | |
143
+ -- linear --
144
+ |
145
+ out
146
+
147
+ [RHS const case]
148
+ Constant folding can be performed, yielding:
149
+
150
+ lhs rhs (const) lhs rhs (folded const)
151
+ | | | |
152
+ | transpose | |
153
+ | | | |
154
+ -- linear -- --> -- linear --
155
+ | |
156
+ out out
157
+
158
+
159
+ enable_lhs_const: If True, convert matmuls whose LHS is a constant tensor. Default is False.
160
+ enable_rhs_const: If True, convert matmuls whose RHS is a constant tensor. Default is True.
161
+ """
162
+
163
+ def __init__(
164
+ self,
165
+ enable_lhs_const: Optional[bool] = False,
166
+ enable_rhs_const: Optional[bool] = True,
167
+ ):
168
+ super().__init__()
169
+ self.converters: List[Converter] = []
170
+ if enable_lhs_const:
171
+ self.converters.append(LhsConstMatmulToLinearConverter())
172
+ if enable_rhs_const:
173
+ self.converters.append(RhsConstMatmulToLinearConverter())
174
+
175
+ def call(self, exported_program: ExportedProgram) -> PassResult:
176
+ logger = logging.getLogger(__name__)
177
+
178
+ graph_module = exported_program.graph_module
179
+ graph = graph_module.graph
180
+ modified = False
181
+ for node in graph.nodes:
182
+ if not node.op == "call_function":
183
+ continue
184
+
185
+ for converter in self.converters:
186
+ if not converter.match(exported_program, node):
187
+ continue
188
+
189
+ new_node = converter.convert(exported_program, node)
190
+ modified = True
191
+ logger.debug(
192
+ f"{node.name} is replaced with {new_node.name} operator (permute + linear)"
193
+ )
194
+ break  # node already rewritten; do not try further converters on it
195
+
196
+ graph.eliminate_dead_code()
197
+ graph.lint()
198
+ graph_module.recompile()
199
+
200
+ return PassResult(modified)
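
The conversion rests on a simple identity between matmul and linear; a short sketch follows (checking the math only — driving the pass itself is left to tico's pass pipeline, which is not shown in this diff):

    import torch

    # Identity behind the rewrite: for 2-D tensors,
    #   torch.mm(lhs, rhs) == torch.nn.functional.linear(lhs, rhs.t())
    # The pass inserts aten.permute([1, 0]) on rhs feeding aten.linear; when rhs is a
    # parameter/buffer/lifted constant, a later constant-folding step can absorb the
    # permute so no runtime transpose remains.
    lhs = torch.randn(2, 8)
    rhs = torch.randn(8, 4)
    assert torch.allclose(
        torch.mm(lhs, rhs),
        torch.nn.functional.linear(lhs, rhs.t()),
        atol=1e-6,
    )
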