tico 0.1.0.dev250803__py3-none-any.whl → 0.1.0.dev251106__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- tico/__init__.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_mixed_type_args.py +2 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
- tico/passes/ops.py +0 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/public_interface.py +11 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +60 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +128 -0
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_mm.py +15 -132
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/utils/convert.py +20 -15
- tico/utils/dtype.py +22 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/signature.py +247 -0
- tico/utils/utils.py +50 -53
- tico/utils/validate_args_kwargs.py +37 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/METADATA +49 -2
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/RECORD +130 -73
- tico/experimental/quantization/__init__.py +0 -6
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/top_level.txt +0 -0
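Editorial note: the bulk of the renames above move the quantization stack out of the experimental namespace (tico/experimental/quantization/... → tico/quantization/...). A minimal sketch of the import change this implies for downstream code, assuming the public submodule names are otherwise unchanged:

try:
    from tico.quantization import public_interface  # 0.1.0.dev251106 layout
except ImportError:
    from tico.experimental.quantization import public_interface  # 0.1.0.dev250803 layout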
tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py
@@ -0,0 +1,234 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -----------------------------------------------------------------------------
+# This file includes modifications based on fairseq
+# (https://github.com/facebookresearch/fairseq), originally licensed under
+# the MIT License. See the LICENSE file in the fairseq repository for details.
+# -----------------------------------------------------------------------------
+
+"""
+Q) Why the name "SingleStep"?
+
+Fairseq's decoder already advances one token at a time during generation,
+but the default path is "stateful" and "shape-polymorphic": it owns and
+mutates K/V caches internally, prefix lengths and triangular masks grow with
+the step, and beam reordering updates hidden module state. That's friendly
+for eager execution, but hostile to `torch.export` and many accelerator
+backends.
+
+This export wrapper makes the per-token call truly "single-step" in the
+export sense: "stateless" and "fixed-shape" so every invocation has the
+exact same graph.
+
+Key invariants
+--------------
+• "Stateless": K/V caches come in as explicit inputs and go out as outputs.
+  The module does not store or mutate hidden state.
+• "Static shapes": Query is always [B, 1, C]; encoder features and masks
+  have fixed, predeclared sizes; K/V slots use fixed capacity (unused tail
+  is simply masked/ignored).
+• "External control": Step indexing, cache slot management (append/roll),
+  and beam reordering are handled outside the module.
+• "Prebuilt additive masks": Self-attention masks are provided by the
+  caller (0 for valid, large negative sentinel, e.g. -120, for masked),
+  avoiding data-dependent control flow.
+
+In short: still step-wise like fairseq, but restructured for export—no
+internal state, no data-dependent shapes, no dynamic control flow.
+"""
+
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+import tico
+
+# ----- 1) Export wrapper module -------------------------------------------
+class DecoderExportSingleStep(nn.Module):
+    """
+    Export-only single-step decoder module.
+
+    Inputs (example shapes; B=1, H=8, Dh=64, C=512, S=64, Tprev=63):
+      - prev_x: [B, 1, C] embedded decoder input for the current step
+      - enc_x: [S, B, C] encoder hidden states (fixed-length export input)
+      - enc_pad_additive: [B, 1, S] additive float key_padding_mask for enc-dec attn (0 for keep, -120 for pad)
+      - self_attn_mask: [B, 1, S] additive float mask for decoder self-attn at this step; pass zeros if unused
+      - prev_self_k_0..L-1: [B, H, Tprev, Dh] cached self-attn K per layer
+      - prev_self_v_0..L-1: [B, H, Tprev, Dh] cached self-attn V per layer
+
+    Outputs:
+      - x_out: [B, 1, C] new decoder features at the current step
+      - new_k_0..L-1: [H, B, Dh] per-layer new K (single-timestep; time dim squeezed)
+      - new_v_0..L-1: [H, B, Dh] per-layer new V (single-timestep; time dim squeezed)
+
+    Notes:
+      • We keep masks/additive semantics externally to avoid any mask-building inside the graph.
+      • We reshape the new K/V from [B,H,1,Dh] -> [H,B,Dh] to match the requested output spec (8,1,64).
+    """
+
+    def __init__(self, decoder: nn.Module):
+        super().__init__()
+        self.decoder = decoder
+        # Cache common meta for assertions
+        self.num_layers = len(getattr(decoder, "layers"))
+        # Infer heads/head_dim from the wrapped self_attn of layer 0
+        any_layer = getattr(decoder.layers[0], "wrapped", decoder.layers[0])  # type: ignore[index]
+        mha = getattr(any_layer, "self_attn", None)
+        assert mha is not None, "Decoder layer must expose self_attn"
+        self.num_heads = int(mha.num_heads)
+        self.head_dim = int(mha.head_dim)
+        # Embed dim (C)
+        self.embed_dim = int(getattr(decoder, "embed_dim"))
+
+    def forward(
+        self,
+        prev_x: torch.Tensor,  # [B,1,C]
+        enc_x: torch.Tensor,  # [S,B,C]
+        enc_pad_additive: torch.Tensor,  # [B,1,S]
+        *kv_args: torch.Tensor,  # prev_k_0..L-1, prev_v_0..L-1 (total 2L tensors)
+        self_attn_mask: torch.Tensor,  # [B,1,S] (or zeros)
+    ):
+        L = self.num_layers
+        H = self.num_heads
+        Dh = self.head_dim
+        B, one, C = prev_x.shape
+        S, B2, C2 = enc_x.shape
+        assert (
+            one == 1 and C == self.embed_dim and B == B2 and C2 == C
+        ), "Shape mismatch in prev_x/enc_x"
+        assert len(kv_args) == 2 * L, f"Expected {2*L} KV tensors, got {len(kv_args)}"
+
+        # Unpack previous self-attn caches
+        prev_k_list: List[torch.Tensor] = list()  # each [B,H,Tprev,Dh]
+        prev_v_list: List[torch.Tensor] = list()  # each [B,H,Tprev,Dh]
+        for i in range(L):
+            prev_k_list.append(kv_args[2 * i])
+            prev_v_list.append(kv_args[2 * i + 1])
+        for i in range(L):
+            assert (
+                prev_k_list[i].dim() == 4 and prev_v_list[i].dim() == 4
+            ), "KV must be [B,H,Tprev,Dh]"
+            assert (
+                prev_k_list[i].shape[0] == B
+                and prev_k_list[i].shape[1] == H
+                and prev_k_list[i].shape[3] == Dh
+            )
+
+        # Call decoder's external single-step path
+        # Returns:
+        #   x_step: [B,1,C]
+        #   newk/newv: lists of length L, each [B*H,1,Dh]
+        x_step, newk_list, newv_list = self.decoder.forward_external_step(  # type: ignore[operator]
+            prev_output_x=prev_x,
+            encoder_out_x=enc_x,
+            encoder_padding_mask=enc_pad_additive,
+            self_attn_mask=self_attn_mask,
+            prev_self_k_list=prev_k_list,
+            prev_self_v_list=prev_v_list,
+        )
+
+        out_tensors: List[torch.Tensor] = [
+            x_step
+        ]  # first output is the new decoder features
+        for i in range(L):
+            nk = newk_list[i]  # [B*H, Tnew, Dh]
+            nv = newv_list[i]  # [B*H, Tnew, Dh]
+            out_tensors.append(nk)
+            out_tensors.append(nv)
+
+        # Return tuple: (x_step, new_k_0, new_v_0, new_k_1, new_v_1, ..., new_k_{L-1}, new_v_{L-1})
+        return tuple(out_tensors)
+
+
+# ----- 2) Example inputs (B=1, S=64, H=8, Dh=64, C=512, L=4) ---------------
+def make_example_inputs(*, L=4, B=1, S=64, H=8, Dh=64, C=512, Tprev=63, device="cpu"):
+    """
+    Build example tensors that match the export I/O spec.
+    Shapes follow the request:
+      prev_x: [1,1,512]
+      enc_x: [64,1,512]
+      enc_pad_additive: [1,1,64] (additive float; zeros -> keep)
+      prev_k_i / prev_v_i (for i in 0..L-1): [1,8,63,64]
+      self_attn_mask: [1,1,64] (additive float; zeros -> keep)
+    """
+    g = torch.Generator(device=device).manual_seed(0)
+
+    prev_x = torch.randn(B, 1, C, device=device, dtype=torch.float32, generator=g)
+    enc_x = torch.randn(S, B, C, device=device, dtype=torch.float32, generator=g)
+
+    # Additive masks (0 for allowed, -120 for masked)
+    enc_pad_additive = torch.full((B, 1, S), float(-120), device=device)
+    self_attn_mask = torch.full((B, 1, S), float(-120), device=device)
+    enc_pad_additive[0, :27] = 0  # 27 is a random example.
+    self_attn_mask[0, :27] = 0  # 27 is a random example.
+
+    # Previous self-attn caches for each layer
+    prev_k_list = []
+    prev_v_list = []
+    for _ in range(L):
+        prev_k = torch.randn(
+            B, H, Tprev, Dh, device=device, dtype=torch.float32, generator=g
+        )
+        prev_v = torch.randn(
+            B, H, Tprev, Dh, device=device, dtype=torch.float32, generator=g
+        )
+        prev_k_list.append(prev_k)
+        prev_v_list.append(prev_v)
+
+    # Pack inputs as the export function will expect:
+    #   (prev_x, enc_x, enc_pad_additive, self_attn_mask, prev_k_0..L-1, prev_v_0..L-1)
+    example_args: Tuple[torch.Tensor, ...] = (
+        prev_x,
+        enc_x,
+        enc_pad_additive,
+        *prev_k_list,
+        *prev_v_list,
+    )
+    example_kwargs = {"self_attn_mask": self_attn_mask}
+    return example_args, example_kwargs
+
+
+# ----- 3) Export driver -----------------------------------------------------
+def export_decoder_single_step(translator, *, save_path="decoder_step_export.circle"):
+    """
+    Wrap the QuantFairseqDecoder into the export-friendly single-step module
+    and export with torch.export.export using example inputs.
+    """
+    # Grab the wrapped decoder
+    dec = translator.models[
+        0
+    ].decoder  # assumed QuantFairseqDecoder with forward_external_step
+    # Build export wrapper
+    wrapper = DecoderExportSingleStep(decoder=dec).eval()
+
+    # Example inputs (L inferred from wrapper/decoder)
+    L = wrapper.num_layers
+    H = wrapper.num_heads
+    Dh = wrapper.head_dim
+    C = wrapper.embed_dim
+    example_inputs, example_kwargs = make_example_inputs(L=L, H=H, Dh=Dh, C=C)
+
+    # Export circle (no dynamism assumed; shapes are fixed for export)
+    cm = tico.convert(
+        wrapper,
+        args=example_inputs,
+        kwargs=example_kwargs,
+        strict=True,  # fail if something cannot be captured
+    )
+
+    # Save .pte
+    cm.save(save_path)
+    print(f"Saved decoder single-step export to: {save_path}")
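Editorial sketch (not part of the diff above): the module docstring's "external control" invariant means the exported graph never touches the K/V cache; the caller owns it. The driver below shows one way a host loop might roll a fixed-capacity cache between steps. The step_fn argument, the reshape of the per-step K/V back to [B, H, 1, Dh], and the drop-oldest policy are illustrative assumptions, since the published file documents more than one layout for the new K/V tensors.

from typing import List, Tuple

import torch


def host_side_step(
    step_fn,                            # eager DecoderExportSingleStep or its exported module
    prev_x: torch.Tensor,               # [B, 1, C] embedded token for this step
    enc_x: torch.Tensor,                # [S, B, C]
    enc_pad_additive: torch.Tensor,     # [B, 1, S] additive float mask
    self_attn_mask: torch.Tensor,       # [B, 1, S] additive float mask
    k_cache: List[torch.Tensor],        # per layer [B, H, Tprev, Dh]
    v_cache: List[torch.Tensor],        # per layer [B, H, Tprev, Dh]
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
    """Run one decode step and roll the host-owned caches (fixed capacity)."""
    L = len(k_cache)
    B, H, Tprev, Dh = k_cache[0].shape

    # Pack inputs the same way make_example_inputs() does: positional K/V, mask as kwarg.
    outputs = step_fn(
        prev_x, enc_x, enc_pad_additive, *k_cache, *v_cache,
        self_attn_mask=self_attn_mask,
    )
    x_out, kv_flat = outputs[0], outputs[1:]  # outputs interleave (k_0, v_0, k_1, v_1, ...)

    new_k_cache, new_v_cache = [], []
    for i in range(L):
        nk, nv = kv_flat[2 * i], kv_flat[2 * i + 1]
        # Assumed layout: bring the single-step K/V back to [B, H, 1, Dh].
        nk = nk.reshape(B, H, 1, Dh)
        nv = nv.reshape(B, H, 1, Dh)
        # Append the new step and drop the oldest slot so Tprev stays fixed.
        new_k_cache.append(torch.cat([k_cache[i], nk], dim=2)[:, :, -Tprev:, :])
        new_v_cache.append(torch.cat([v_cache[i], nv], dim=2)[:, :, -Tprev:, :])
    return x_out, new_k_cache, new_v_cache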
tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py
@@ -0,0 +1,429 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -----------------------------------------------------------------------------
+# This file includes modifications based on fairseq
+# (https://github.com/facebookresearch/fairseq), originally licensed under
+# the MIT License. See the LICENSE file in the fairseq repository for details.
+# -----------------------------------------------------------------------------
+
+import math
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+@try_register("fairseq.models.transformer.TransformerDecoderBase")
+class QuantFairseqDecoder(QuantModuleBase):
+    """
+    Quant-aware drop-in replacement for Fairseq TransformerDecoderBase.
+
+    Design (inference-only):
+      - Keep embeddings, positional embeddings, LayerNorms, output_projection in FP.
+      - PTQ-wrap all TransformerDecoderLayerBase items via PTQWrapper (uses QuantFairseqDecoderLayer).
+      - Drop training-only logic (dropout, activation-dropout, quant-noise, checkpoint wrappers).
+      - Preserve Fairseq forward/extract_features contract, shapes, and incremental decoding behavior.
+
+    I/O:
+      - Forward(prev_output_tokens, encoder_out, incremental_state, ...) -> (logits, extra) like the original.
+      - `features_only=True` returns features without output projection.
+    """
+
+    def __init__(
+        self,
+        fp_decoder: nn.Module,
+        *,
+        qcfg: Optional[PTQConfig] = None,
+        fp_name: Optional[str] = None,
+    ):
+        super().__init__(qcfg, fp_name=fp_name)
+
+        # ---- carry config/meta (read-only views) --------------------------
+        assert hasattr(fp_decoder, "cfg")
+        self.cfg = fp_decoder.cfg
+        self.share_input_output_embed: bool = bool(
+            getattr(fp_decoder, "share_input_output_embed", False)
+        )
+
+        # Version buffer (parity with original)
+        version = getattr(fp_decoder, "version", None)
+        if isinstance(version, torch.Tensor):
+            self.register_buffer("version", version.clone(), persistent=False)
+        else:
+            self.register_buffer("version", torch.tensor([3.0]), persistent=False)
+
+        # Embeddings / positional encodings (FP; reuse modules)
+        assert hasattr(fp_decoder, "embed_tokens") and isinstance(
+            fp_decoder.embed_tokens, nn.Module
+        )
+        self.embed_tokens = fp_decoder.embed_tokens  # (B,T)->(B,T,C)
+
+        self.padding_idx: int = int(fp_decoder.padding_idx)  # type: ignore[arg-type]
+        self.max_target_positions: int = int(fp_decoder.max_target_positions)  # type: ignore[arg-type]
+
+        self.embed_positions = getattr(fp_decoder, "embed_positions", None)
+        self.layernorm_embedding = getattr(fp_decoder, "layernorm_embedding", None)
+
+        # Dimensions / projections (reuse)
+        self.embed_dim: int = int(getattr(fp_decoder, "embed_dim"))
+        self.output_embed_dim: int = int(getattr(fp_decoder, "output_embed_dim"))
+        self.project_in_dim = getattr(fp_decoder, "project_in_dim", None)
+        self.project_out_dim = getattr(fp_decoder, "project_out_dim", None)
+
+        # Scale factor (sqrt(embed_dim) unless disabled)
+        no_scale = bool(getattr(self.cfg, "no_scale_embedding", False))
+        self.embed_scale: float = 1.0 if no_scale else math.sqrt(self.embed_dim)
+
+        # Final decoder LayerNorm (may be None depending on cfg)
+        self.layer_norm = getattr(fp_decoder, "layer_norm", None)
+
+        # Output projection / adaptive softmax (reuse FP modules)
+        self.adaptive_softmax = getattr(fp_decoder, "adaptive_softmax", None)
+        self.output_projection = getattr(fp_decoder, "output_projection", None)
+
+        # ---- wrap decoder layers ------------------------------------------
+        assert hasattr(fp_decoder, "layers")
+        fp_layers = list(fp_decoder.layers)  # type: ignore[arg-type]
+        self.layers = nn.ModuleList()
+
+        # Safe prefix to avoid None-based name collisions in KV cache keys
+        def _safe_prefix(name: Optional[str]) -> str:
+            return (
+                name
+                if (name is not None and name != "" and name != "None")
+                else f"{self.__class__.__name__}_{id(self)}"
+            )
+
+        prefix = _safe_prefix(fp_name)
+
+        # Prepare child PTQConfig namespaces: layers/<idx>
+        layers_qcfg = qcfg.child("layers") if qcfg else None
+        for i, layer in enumerate(fp_layers):
+            child_cfg = layers_qcfg.child(str(i)) if layers_qcfg else None
+            # Not every item is necessarily a TransformerDecoderLayerBase (e.g., BaseLayer).
+            # If there's no registered wrapper for a layer type, keep it FP.
+            try:
+                wrapped = PTQWrapper(
+                    layer, qcfg=child_cfg, fp_name=f"{prefix}.layers.{i}"
+                )
+            except NotImplementedError:
+                wrapped = layer  # keep as-is (FP)
+            self.layers.append(wrapped)
+        self.num_layers = len(self.layers)
+
+        # choose a generous upper-bound; you can wire this from cfg if you like
+        self.mask_fill_value: float = -120.0
+        max_tgt = int(getattr(self.cfg, "max_target_positions", 2048))  # fallback: 2048
+
+        mask = torch.full((1, 1, max_tgt, max_tgt), float(self.mask_fill_value))
+        mask.triu_(1)  # upper triangle set to fill_value; diagonal/lower are zeros
+        self.register_buffer("causal_mask_template", mask, persistent=False)
+
+    def forward(
+        self,
+        prev_output_tokens: Tensor,  # [B, T]
+        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+        features_only: bool = False,
+        full_context_alignment: bool = False,
+        alignment_layer: Optional[int] = None,
+        alignment_heads: Optional[int] = None,
+        src_lengths: Optional[Any] = None,
+        return_all_hiddens: bool = False,
+    ):
+        """
+        Match the original API.
+        Returns:
+            (logits_or_features, extra_dict)
+        """
+        x, extra = self.extract_features_scriptable(
+            prev_output_tokens=prev_output_tokens,
+            encoder_out=encoder_out,
+            incremental_state=incremental_state,
+            full_context_alignment=full_context_alignment,
+            alignment_layer=alignment_layer,
+            alignment_heads=alignment_heads,
+        )
+        if not features_only:
+            x = self.output_layer(x)
+        return x, extra
+
+    def extract_features_scriptable(
+        self,
+        prev_output_tokens: Tensor,  # [B,T]
+        encoder_out: Optional[Dict[str, List[Tensor]]],
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+        full_context_alignment: bool = False,
+        alignment_layer: Optional[int] = None,
+        alignment_heads: Optional[int] = None,
+    ) -> Tuple[Tensor, Dict[str, List[Optional[Tensor]]]]:
+        """
+        Feature path that mirrors Fairseq's implementation (minus training-only code).
+
+        Returns:
+            x: [B, T, C]
+            extra: {"attn": [attn or None], "inner_states": [T x B x C tensors]}
+        """
+        B, T = prev_output_tokens.size()
+        if alignment_layer is None:
+            alignment_layer = self.num_layers - 1
+
+        # Unpack encoder outputs in Fairseq dict format
+        enc: Optional[Tensor] = None
+        padding_mask: Optional[Tensor] = None
+        if encoder_out is not None and len(encoder_out.get("encoder_out", [])) > 0:
+            enc = encoder_out["encoder_out"][0]  # [S,B,Ce]
+        if (
+            encoder_out is not None
+            and len(encoder_out.get("encoder_padding_mask", [])) > 0
+        ):
+            padding_mask = encoder_out["encoder_padding_mask"][0]  # [B,S] (bool)
+
+        # Positional embeddings (support incremental decoding)
+        positions = None
+        if self.embed_positions is not None:
+            positions = self.embed_positions(
+                prev_output_tokens, incremental_state=incremental_state
+            )
+
+        # In incremental mode, only the last step is consumed
+        if incremental_state is not None:
+            prev_output_tokens = prev_output_tokens[:, -1:]
+            if positions is not None:
+                positions = positions[:, -1:]
+
+        # Prevent view quirks (TorchScript parity in original)
+        prev_output_tokens = prev_output_tokens.contiguous()
+
+        # Token embeddings (+ optional proj-in), + positions, + optional LN
+        x = self.embed_scale * self.embed_tokens(prev_output_tokens)  # [B,T,C]
+        if self.project_in_dim is not None:
+            x = self.project_in_dim(x)
+        if positions is not None:
+            x = x + positions
+        if self.layernorm_embedding is not None:
+            x = self.layernorm_embedding(x)
+
+        # No dropout / quant_noise (inference-only)
+
+        # B x T x C -> T x B x C
+        x = x.transpose(0, 1)
+
+        # Build self-attn masks
+        self_attn_padding_mask: Optional[Tensor] = None
+        if (
+            getattr(self.cfg, "cross_self_attention", False)
+            or prev_output_tokens.eq(self.padding_idx).any()
+        ):
+            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)  # [B,T]
+
+        attn: Optional[Tensor] = None
+        inner_states: List[Optional[Tensor]] = [x]
+
+        for idx, layer in enumerate(self.layers):
+            # Causal mask unless full-context alignment or incremental decoding
+            if incremental_state is None and not full_context_alignment:
+                Tq = x.size(0)
+                self_attn_mask = self.buffered_future_mask(
+                    Tq, Tq, x=x
+                )  # [Tq,Tq] additive float
+            else:
+                self_attn_mask = None
+
+            x, layer_attn, _ = layer(
+                x,
+                enc,
+                padding_mask,
+                incremental_state,
+                self_attn_mask=self_attn_mask,
+                self_attn_padding_mask=self_attn_padding_mask,
+                need_attn=bool(idx == alignment_layer),
+                need_head_weights=bool(idx == alignment_layer),
+            )
+
+            inner_states.append(x)
+            if layer_attn is not None and idx == alignment_layer:
+                attn = layer_attn.float().to(x)
+
+        # Average heads if needed
+        if attn is not None and alignment_heads is not None:
+            attn = attn[:alignment_heads]
+        if attn is not None:
+            attn = attn.mean(dim=0)  # [B,T,S]
+
+        # Optional final layer norm
+        if self.layer_norm is not None:
+            x = self.layer_norm(x)
+
+        # T x B x C -> B x T x C
+        x = x.transpose(0, 1)
+
+        # Optional proj-out
+        if self.project_out_dim is not None:
+            assert self.project_out_dim is not None
+            x = self.project_out_dim(x)
+
+        return x, {"attn": [attn], "inner_states": inner_states}
+
+    def output_layer(self, features: Tensor) -> Tensor:
+        """Project features to vocabulary size (or return features with adaptive softmax)."""
+        if self.adaptive_softmax is None:
+            assert self.output_projection is not None
+            return self.output_projection(features)  # type: ignore[operator]
+        else:
+            return features
+
+    def buffered_future_mask(
+        self, Tq: int, Ts: int, *, x: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Return additive float mask [Tq, Ts]: zeros on allowed, large-neg on disallowed.
+        Uses the prebuilt template; will re-build if you exceed template size.
+        """
+        assert isinstance(self.causal_mask_template, torch.Tensor)
+        Mmax = self.causal_mask_template.size(-1)
+        assert Tq <= Mmax and Ts <= Mmax
+        cm = self.causal_mask_template[..., :Tq, :Ts].to(device=x.device, dtype=x.dtype)
+        return cm.squeeze(0).squeeze(0)  # [Tq, Ts]
+
+    def max_positions(self) -> int:
+        """Maximum output length supported by the decoder (same policy as the original)."""
+        if self.embed_positions is None:
+            return self.max_target_positions
+        return min(self.max_target_positions, self.embed_positions.max_positions)
+
+    def get_normalized_probs(
+        self,
+        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
+        log_probs: bool,
+        sample: Optional[Dict[str, Tensor]] = None,
+    ):
+        """Get normalized probabilities (or log probs) from a net's output."""
+        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
+
+    def get_normalized_probs_scriptable(
+        self,
+        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
+        log_probs: bool,
+        sample: Optional[Dict[str, Tensor]] = None,
+    ):
+        """Get normalized probabilities (or log probs) from a net's output."""
+
+        if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
+            if sample is not None:
+                assert "target" in sample
+                target = sample["target"]
+            else:
+                target = None
+            out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
+            return out.exp_() if not log_probs else out
+
+        logits = net_output[0]
+        if log_probs:
+            return F.log_softmax(logits, dim=-1, dtype=torch.float32)
+        else:
+            return F.softmax(logits, dim=-1, dtype=torch.float32)
+
+    def reorder_incremental_state_scripting(
+        self,
+        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+        new_order: Tensor,
+    ):
+        """Main entry point for reordering the incremental state.
+
+        Due to limitations in TorchScript, we call this function in
+        :class:`fairseq.sequence_generator.SequenceGenerator` instead of
+        calling :func:`reorder_incremental_state` directly.
+        """
+        for module in self.modules():
+            if hasattr(module, "reorder_incremental_state"):
+                result = module.reorder_incremental_state(incremental_state, new_order)  # type: ignore[operator]
+                if result is not None:
+                    incremental_state = result
+
+    def forward_external_step(
+        self,
+        prev_output_x: Tensor,  # [1, B, C]
+        *,
+        encoder_out_x: Tensor,  # [S, B, Ce]
+        encoder_padding_mask: Optional[
+            Tensor
+        ] = None,  # [B,S] or [B,1,S] additive-float
+        self_attn_mask: Optional[
+            Tensor
+        ] = None,  # [1,S_hist+1] or [B,1,S_hist+1] additive-float
+        prev_self_k_list: Optional[
+            List[Tensor]
+        ] = None,  # length=L; each [B,H,Tprev,Dh]
+        prev_self_v_list: Optional[
+            List[Tensor]
+        ] = None,  # length=L; each [B,H,Tprev,Dh]
+        need_attn: bool = False,
+        need_head_weights: bool = False,
+    ) -> Tuple[Tensor, List[Tensor], List[Tensor]]:
+        """
+        Export-only single-step decoder.
+        Returns:
+            - x_out: [1, B, C]
+            - new_self_k_list/new_self_v_list: lists of length L; each [B*H, Tnew, Dh]
+        """
+        assert (
+            prev_output_x.dim() == 3 and prev_output_x.size(0) == 1
+        ), "prev_output_x must be [1,B,C]"
+        L = self.num_layers
+        if prev_self_k_list is None:
+            prev_self_k_list = [None] * L  # type: ignore[list-item]
+        if prev_self_v_list is None:
+            prev_self_v_list = [None] * L  # type: ignore[list-item]
+        assert len(prev_self_k_list) == L and len(prev_self_v_list) == L
+
+        assert encoder_out_x.dim() == 3, "encoder_out_x must be [S,B,C]"
+        x = prev_output_x  # [1,B,C]
+        enc = encoder_out_x
+
+        new_k_list: List[Tensor] = []
+        new_v_list: List[Tensor] = []
+
+        for li, layer in enumerate(self.layers):
+            assert isinstance(layer, PTQWrapper)
+            x, _, k_new, v_new = layer.wrapped.forward_external(  # type: ignore[attr-defined, operator]
+                x,
+                encoder_out=enc,
+                encoder_padding_mask=encoder_padding_mask,
+                prev_self_k=prev_self_k_list[li],
+                prev_self_v=prev_self_v_list[li],
+                self_attn_mask=self_attn_mask,
+                need_attn=need_attn and (li == L - 1),
+                need_head_weights=need_head_weights and (li == L - 1),
+            )
+            new_k_list.append(k_new)  # [B*H, Tnew, Dh]
+            new_v_list.append(v_new)  # [B*H, Tnew, Dh]
+
+        if self.layer_norm is not None:
+            x = self.layer_norm(x.transpose(0, 1)).transpose(0, 1)
+
+        return x, new_k_list, new_v_list  # [1,B,C], lists of [B*H, Tnew, Dh]
+
+    def _all_observers(self) -> Iterable:
+        """Yield all observers from wrapped decoder layers (if any)."""
+        for m in self.layers:
+            if isinstance(m, QuantModuleBase):
+                yield from m._all_observers()