tico 0.1.0.dev250803__py3-none-any.whl → 0.1.0.dev251106__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in those public registries.
- tico/__init__.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_mixed_type_args.py +2 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
- tico/passes/ops.py +0 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/public_interface.py +11 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +60 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +128 -0
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_mm.py +15 -132
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/utils/convert.py +20 -15
- tico/utils/dtype.py +22 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/signature.py +247 -0
- tico/utils/utils.py +50 -53
- tico/utils/validate_args_kwargs.py +37 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/METADATA +49 -2
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/RECORD +130 -73
- tico/experimental/quantization/__init__.py +0 -6
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/top_level.txt +0 -0
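
The largest structural change in this release is the promotion of tico.experimental.quantization to tico.quantization, plus a new tico.quantization.wrapq subpackage (dtypes, observers, qscheme, wrappers, examples). The sketch below is illustrative only and is not code from the package: it assumes the PTQWrapper(module, qcfg=..., fp_name=...) constructor signature used throughout the new wrappers in this diff, and that the wrapper registry handles a plain nn.Linear. The calibration and conversion workflow is not shown here; the scripts added under tico/quantization/wrapq/examples/ are the in-tree reference.

# Illustrative sketch only (not taken from the package). Assumes the
# PTQWrapper(module, qcfg=..., fp_name=...) signature seen in the new wrappers
# and that calling the wrapper runs the wrapped module with observers attached.
import torch
import torch.nn as nn

from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper  # new path

fp_linear = nn.Linear(64, 64)
q_linear = PTQWrapper(fp_linear, qcfg=None, fp_name="model.fc")

x = torch.randn(2, 64)
y = q_linear(x)  # forward pass; observers (if configured) see the activations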
tico/quantization/wrapq/wrappers/fairseq/quant_mha.py (new file; matches the +381-line entry in the listing above)
@@ -0,0 +1,381 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# -----------------------------------------------------------------------------
# This file includes modifications based on fairseq
# (https://github.com/facebookresearch/fairseq), originally licensed under
# the MIT License. See the LICENSE file in the fairseq repository for details.
# -----------------------------------------------------------------------------

from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
from tico.quantization.wrapq.wrappers.registry import try_register


@try_register("fairseq.modules.multihead_attention.MultiheadAttention")
class QuantFairseqMultiheadAttention(QuantModuleBase):
    """
    Quant-aware drop-in for Fairseq MultiheadAttention.

    - No xFormers / no torch F.multi_head_attention_forward fast-path.
    - Self/cross attention + minimal incremental KV cache.
    - Causal mask is pre-built statically; `key_padding_mask` is additive float.
    - I/O shape: [T, B, C]

    Runtime optimization flags
    --------------------------
    use_static_causal : bool
        If True, reuse a precomputed upper-triangular causal mask template
        instead of rebuilding it each forward step. Reduces per-step mask
        construction overhead during incremental decoding.

    assume_additive_key_padding : bool
        If True, assume the `key_padding_mask` is already an additive float
        tensor (large negative values at padded positions). Skips conversion
        from boolean masks, reducing runtime overhead.
    """

    def __init__(
        self,
        fp_attn: nn.Module,
        *,
        qcfg: Optional[PTQConfig] = None,
        fp_name: Optional[str] = None,
        max_seq: int = 4096,
        use_static_causal: bool = False,
        mask_fill_value: float = -120.0,
        assume_additive_key_padding: bool = False,
    ):
        super().__init__(qcfg, fp_name=fp_name)

        self.use_static_causal = use_static_causal
        self.mask_fill_value = mask_fill_value
        self.assume_additive_key_padding = assume_additive_key_padding
        self.embed_dim: int = int(fp_attn.embed_dim)  # type: ignore[arg-type]
        self.num_heads: int = int(fp_attn.num_heads)  # type: ignore[arg-type]
        self.head_dim: int = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim

        self.self_attention: bool = bool(getattr(fp_attn, "self_attention", False))
        self.encoder_decoder_attention: bool = bool(
            getattr(fp_attn, "encoder_decoder_attention", False)
        )
        assert self.self_attention != self.encoder_decoder_attention

        # PTQ-wrapped projections
        qc = qcfg.child("q_proj") if qcfg else None
        kc = qcfg.child("k_proj") if qcfg else None
        vc = qcfg.child("v_proj") if qcfg else None
        oc = qcfg.child("out_proj") if qcfg else None
        assert hasattr(fp_attn, "q_proj") and hasattr(fp_attn, "k_proj")
        assert hasattr(fp_attn, "v_proj") and hasattr(fp_attn, "out_proj")
        assert isinstance(fp_attn.q_proj, nn.Module) and isinstance(
            fp_attn.k_proj, nn.Module
        )
        assert isinstance(fp_attn.v_proj, nn.Module) and isinstance(
            fp_attn.out_proj, nn.Module
        )
        self.q_proj = PTQWrapper(fp_attn.q_proj, qcfg=qc, fp_name=f"{fp_name}.q_proj")
        self.k_proj = PTQWrapper(fp_attn.k_proj, qcfg=kc, fp_name=f"{fp_name}.k_proj")
        self.v_proj = PTQWrapper(fp_attn.v_proj, qcfg=vc, fp_name=f"{fp_name}.v_proj")
        self.out_proj = PTQWrapper(
            fp_attn.out_proj, qcfg=oc, fp_name=f"{fp_name}.out_proj"
        )

        # scale & static causal mask
        self.register_buffer(
            "scale_const", torch.tensor(self.head_dim**-0.5), persistent=False
        )
        mask = torch.full((1, 1, max_seq, max_seq), float(self.mask_fill_value))
        mask.triu_(1)
        self.register_buffer("causal_mask_template", mask, persistent=False)

        # observers (no *_proj_out here; PTQWrapper handles module outputs)
        mk = self._make_obs
        self.obs_query_in = mk("query_in")
        self.obs_key_in = mk("key_in")
        self.obs_value_in = mk("value_in")
        self.obs_kpm_in = mk("kpm_in")
        self.obs_causal_mask = mk("causal_mask")
        self.obs_q_fold = mk("q_fold")
        self.obs_k_fold = mk("k_fold")
        self.obs_v_fold = mk("v_fold")
        self.obs_scale = mk("scale")
        self.obs_logits_raw = mk("logits_raw")
        self.obs_logits = mk("logits_scaled")
        self.obs_attn_mask_add = mk("obs_attn_mask_add")
        self.obs_kp_mask_add = mk("obs_kp_mask_add")
        self.obs_softmax = mk("softmax")
        self.obs_attn_out = mk("attn_out")

        safe_name = (
            fp_name if (fp_name not in (None, "", "None")) else f"QuantFsMHA_{id(self)}"
        )
        assert safe_name is not None
        self._state_key = safe_name + ".attn_state"

    def _get_input_buffer(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]],
    ) -> Optional[Dict[str, Optional[torch.Tensor]]]:
        """Return saved KV/mask dict or None."""
        if incremental_state is None:
            return None
        return incremental_state.get(self._state_key, None)

    def _set_input_buffer(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]],
        buffer: Dict[str, Optional[torch.Tensor]],
    ):
        """Store KV/mask dict in incremental_state."""
        if incremental_state is not None:
            incremental_state[self._state_key] = buffer
        return incremental_state

    # ---- utils ----
    def _fold_heads(self, x: torch.Tensor, B: int) -> torch.Tensor:
        # [T,B,E] -> [B*H, T, Dh]
        T = x.size(0)
        x = x.view(T, B, self.num_heads, self.head_dim).permute(1, 2, 0, 3).contiguous()
        return x.view(B * self.num_heads, T, self.head_dim)

    def _unfold_heads(self, x: torch.Tensor, B: int, T: int) -> torch.Tensor:
        # [B*H, T, Dh] -> [T,B,E]
        x = x.view(B, self.num_heads, T, self.head_dim).permute(2, 0, 1, 3).contiguous()
        return x.view(T, B, self.embed_dim)

    def forward(
        self,
        query: torch.Tensor,  # [Tq,B,C]
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        key_padding_mask: Optional[
            torch.Tensor
        ] = None,  # additive float (e.g. -120 at pads)
        incremental_state: Optional[
            Dict[str, Dict[str, Optional[torch.Tensor]]]
        ] = None,
        need_weights: bool = False,
        static_kv: bool = False,
        attn_mask: Optional[torch.Tensor] = None,  # if None -> internal causal
        before_softmax: bool = False,
        need_head_weights: bool = False,
        return_new_kv: bool = False,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        Tuple[
            torch.Tensor,
            Optional[torch.Tensor],
            Optional[torch.Tensor],
            Optional[torch.Tensor],
        ],
    ]:

        if need_head_weights:
            need_weights = True

        Tq, B, _ = query.shape
        if self.self_attention:
            key = query if key is None else key
            value = query if value is None else value
        else:
            assert key is not None and value is not None

        Tk, Bk, _ = key.shape
        Tv, Bv, _ = value.shape
        assert B == Bk == Bv

        q = self.q_proj(self._fq(query, self.obs_query_in))
        k = self.k_proj(self._fq(key, self.obs_key_in))
        v = self.v_proj(self._fq(value, self.obs_value_in))

        state = self._get_input_buffer(incremental_state)
        if incremental_state is not None and state is None:
            state = {}

        # Capture "new" K/V for this call BEFORE concatenating with cache
        new_k_bh: Optional[torch.Tensor] = None
        new_v_bh: Optional[torch.Tensor] = None

        # Fold heads
        q = self._fq(self._fold_heads(q, B), self.obs_q_fold)
        if state is not None and "prev_key" in state and static_kv:
            # Cross-attention static_kv path: reuse cached KV; there is no new KV this call.
            k = None
            v = None
        if k is not None:
            k = self._fq(self._fold_heads(k, B), self.obs_k_fold)  # [B*H, Tnew, Dh]
            if return_new_kv:
                new_k_bh = k.contiguous()
        if v is not None:
            v = self._fq(self._fold_heads(v, B), self.obs_v_fold)  # [B*H, Tnew, Dh]
            if return_new_kv:
                new_v_bh = v.contiguous()

        # Append/reuse cache
        if state is not None:
            pk = state.get("prev_key")
            pv = state.get("prev_value")
            if pk is not None:
                pk = pk.view(B * self.num_heads, -1, self.head_dim)
                k = pk if static_kv else torch.cat([pk, k], dim=1)
            if pv is not None:
                pv = pv.view(B * self.num_heads, -1, self.head_dim)
                v = pv if static_kv else torch.cat([pv, v], dim=1)

        assert k is not None and v is not None
        Ts = k.size(1)

        # Scaled dot-product
        scale = self._fq(self.scale_const, self.obs_scale).to(q.dtype)
        logits_raw = self._fq(
            torch.bmm(q, k.transpose(1, 2)), self.obs_logits_raw
        )  # [B*H,Tq,Ts]
        logits = self._fq(logits_raw * scale, self.obs_logits)

        assert isinstance(self.causal_mask_template, torch.Tensor)
        # Masks
        device = logits.device
        if attn_mask is None and self.use_static_causal:
            # Incremental decoding aware slicing:
            # align the causal row(s) to the current time indices
            start_q = max(Ts - Tq, 0)
            cm = self.causal_mask_template[..., start_q : start_q + Tq, :Ts].to(
                device=device, dtype=logits.dtype
            )
            attn_mask = cm.squeeze(0).squeeze(0)  # [Tq,Ts]

        if attn_mask is not None:
            # Bool/byte mask -> additive float with large negatives
            if not torch.is_floating_point(attn_mask):
                fill = self.causal_mask_template.new_tensor(self.mask_fill_value)
                attn_mask = torch.where(
                    attn_mask.to(torch.bool), fill, fill.new_zeros(())
                )
            attn_mask = self._fq(attn_mask, self.obs_causal_mask)
            assert isinstance(attn_mask, torch.Tensor)

            if not self.assume_additive_key_padding:
                # attn_mask -> [B*H,Tq,Ts]
                if attn_mask.dim() == 2:
                    add_mask = attn_mask.unsqueeze(0).expand(logits.size(0), -1, -1)
                elif attn_mask.dim() == 3:
                    add_mask = (
                        attn_mask.unsqueeze(1)
                        .expand(B, self.num_heads, Tq, Ts)
                        .contiguous()
                    )
                    add_mask = add_mask.view(B * self.num_heads, Tq, Ts)
                else:
                    raise RuntimeError("attn_mask must be [T,S] or [B,T,S]")
            else:
                add_mask = attn_mask
            logits = self._fq(logits + add_mask, self.obs_attn_mask_add)

        if key_padding_mask is not None:
            if not torch.is_floating_point(key_padding_mask):
                fill = self.causal_mask_template.new_tensor(self.mask_fill_value)
                kpm = torch.where(
                    key_padding_mask.to(torch.bool), fill, fill.new_zeros(())
                )
            else:
                kpm = key_padding_mask
            kpm = self._fq(kpm, self.obs_kpm_in)

            if not self.assume_additive_key_padding:
                # key_padding_mask: additive float already
                kpm = kpm.to(dtype=logits.dtype, device=device)
                if kpm.dim() == 2:  # [B,S]
                    kpm = (
                        kpm.view(B, 1, 1, Ts)
                        .expand(B, self.num_heads, Tq, Ts)
                        .contiguous()
                    )
                    kpm = kpm.view(B * self.num_heads, Tq, Ts)
                elif kpm.dim() == 3:  # [B,T,S]
                    kpm = (
                        kpm.unsqueeze(1).expand(B, self.num_heads, Tq, Ts).contiguous()
                    )
                    kpm = kpm.view(B * self.num_heads, Tq, Ts)
                else:
                    raise RuntimeError(
                        "key_padding_mask must be [B,S] or [B,T,S] (additive)"
                    )
            logits = self._fq(logits + kpm, self.obs_kp_mask_add)

        if before_softmax:
            if return_new_kv:
                return logits, v, new_k_bh, new_v_bh
            return logits, v

        # Softmax (float32) -> back to q.dtype
        attn_probs = torch.softmax(logits, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_probs = self._fq(attn_probs, self.obs_softmax)

        # Context + output proj
        ctx = self._fq(torch.bmm(attn_probs, v), self.obs_attn_out)  # [B*H,Tq,Dh]
        ctx = self._unfold_heads(ctx, B, Tq)  # [Tq,B,E]
        out = self.out_proj(ctx)

        # Weights (optional)
        attn_weights_out: Optional[torch.Tensor] = None
        if need_weights:
            aw = (
                torch.softmax(logits, dim=-1, dtype=torch.float32)
                .view(B, self.num_heads, Tq, Ts)
                .transpose(1, 0)
            )
            if not need_head_weights:
                aw = aw.mean(dim=1)  # [B,Tq,Ts]
            attn_weights_out = aw

        # Cache write
        if state is not None:
            state["prev_key"] = k.view(B, self.num_heads, -1, self.head_dim).detach()
            state["prev_value"] = v.view(B, self.num_heads, -1, self.head_dim).detach()
            self._set_input_buffer(incremental_state, state)

        if return_new_kv:
            return out, attn_weights_out, new_k_bh, new_v_bh
        return out, attn_weights_out

    def _all_observers(self):
        yield from (
            self.obs_query_in,
            self.obs_key_in,
            self.obs_value_in,
            self.obs_kpm_in,
            self.obs_causal_mask,
            self.obs_q_fold,
            self.obs_k_fold,
            self.obs_v_fold,
            self.obs_scale,
            self.obs_logits_raw,
            self.obs_logits,
            self.obs_attn_mask_add,
            self.obs_kp_mask_add,
            self.obs_softmax,
            self.obs_attn_out,
        )
        for m in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
            if isinstance(m, QuantModuleBase):
                yield from m._all_observers()
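
As a rough usage illustration (not part of the diff, and not code from the package): wrapping an existing fairseq MultiheadAttention with the class above might look like the sketch below. It assumes QuantModuleBase tolerates qcfg=None with default observers (the base classes are not shown in this diff), and that fairseq is installed; the MultiheadAttention import path is taken from the try_register target string above.

# Hedged sketch: wrap a self-attention fairseq MultiheadAttention.
import torch
from fairseq.modules.multihead_attention import MultiheadAttention

from tico.quantization.wrapq.wrappers.fairseq.quant_mha import (
    QuantFairseqMultiheadAttention,
)

fp_mha = MultiheadAttention(embed_dim=512, num_heads=8, self_attention=True)
q_mha = QuantFairseqMultiheadAttention(
    fp_mha,
    qcfg=None,                  # PTQConfig construction is not shown in this diff
    fp_name="decoder.layers.0.self_attn",
    use_static_causal=True,     # reuse the precomputed causal-mask template
)

x = torch.randn(16, 4, 512)     # [T, B, C]
out, attn = q_mha(x, x, x)      # need_weights=False -> attn is None
print(out.shape)                # torch.Size([16, 4, 512])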
tico/quantization/wrapq/wrappers/llama/__init__.py (new file; matches the +1-line entry in the listing above)
@@ -0,0 +1 @@
# DO NOT REMOVE THIS FILE
tico/quantization/wrapq/wrappers/llama/quant_attn.py (new file; matches the +276-line entry in the listing above)
@@ -0,0 +1,276 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch
import torch.nn as nn

from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
from tico.quantization.wrapq.wrappers.registry import try_register


@try_register(
    "transformers.models.llama.modeling_llama.LlamaAttention",
    "transformers.models.llama.modeling_llama.LlamaSdpaAttention",
)
class QuantLlamaAttention(QuantModuleBase):
    def __init__(
        self,
        fp_attn: nn.Module,
        *,
        qcfg: Optional[PTQConfig] = None,
        fp_name: Optional[str] = None,
    ):
        super().__init__(qcfg, fp_name=fp_name)

        cfg = fp_attn.config
        assert hasattr(cfg, "hidden_size") and hasattr(cfg, "num_attention_heads")
        assert hasattr(cfg, "num_key_value_heads")
        assert isinstance(cfg.hidden_size, int) and isinstance(
            cfg.num_attention_heads, int
        )
        assert isinstance(cfg.num_key_value_heads, int)
        self.hdim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)
        self.kv_rep = cfg.num_attention_heads // cfg.num_key_value_heads

        # constant scale (1/√d)
        self.scale_t = torch.tensor(self.hdim**-0.5)
        self.obs_scale = self._make_obs("scale")

        # ---- wrap q k v o projections via PTQWrapper ---------------
        q_cfg = qcfg.child("q_proj") if qcfg else None
        k_cfg = qcfg.child("k_proj") if qcfg else None
        v_cfg = qcfg.child("v_proj") if qcfg else None
        o_cfg = qcfg.child("o_proj") if qcfg else None
        assert hasattr(fp_attn, "q_proj") and isinstance(
            fp_attn.q_proj, torch.nn.Module
        )
        assert hasattr(fp_attn, "k_proj") and isinstance(
            fp_attn.k_proj, torch.nn.Module
        )
        assert hasattr(fp_attn, "v_proj") and isinstance(
            fp_attn.v_proj, torch.nn.Module
        )
        assert hasattr(fp_attn, "o_proj") and isinstance(
            fp_attn.o_proj, torch.nn.Module
        )
        self.q_proj = PTQWrapper(
            fp_attn.q_proj, qcfg=q_cfg, fp_name=f"{fp_name}.q_proj"
        )
        self.k_proj = PTQWrapper(
            fp_attn.k_proj, qcfg=k_cfg, fp_name=f"{fp_name}.k_proj"
        )
        self.v_proj = PTQWrapper(
            fp_attn.v_proj, qcfg=v_cfg, fp_name=f"{fp_name}.v_proj"
        )
        self.o_proj = PTQWrapper(
            fp_attn.o_proj, qcfg=o_cfg, fp_name=f"{fp_name}.o_proj"
        )

        # ---- create arithmetic observers ---------------------------
        mk = self._make_obs
        self.obs_hidden = mk("hidden")

        self.obs_cos = mk("cos")
        self.obs_sin = mk("sin")

        self.obs_causal_mask = mk("causal_mask")

        # rotate-half sub-steps
        self.obs_q_x1 = mk("q_x1")
        self.obs_q_x2 = mk("q_x2")
        self.obs_q_neg = mk("q_neg")
        self.obs_q_cat = mk("q_cat")
        self.obs_k_x1 = mk("k_x1")
        self.obs_k_x2 = mk("k_x2")
        self.obs_k_neg = mk("k_neg")
        self.obs_k_cat = mk("k_cat")

        # q / k paths
        self.obs_q_cos = mk("q_cos")
        self.obs_q_sin = mk("q_sin")
        self.obs_q_rot = mk("q_rot")
        self.obs_k_cos = mk("k_cos")
        self.obs_k_sin = mk("k_sin")
        self.obs_k_rot = mk("k_rot")

        # logits / softmax / out
        self.obs_logits_raw = mk("logits_raw")
        self.obs_logits = mk("logits")
        self.obs_mask_add = mk("mask_add")
        self.obs_softmax = mk("softmax")
        self.obs_attn_out = mk("attn_out")

        # Static causal mask template
        assert hasattr(cfg, "max_position_embeddings")
        max_seq = cfg.max_position_embeddings
        mask = torch.full((1, 1, max_seq, max_seq), float("-120"))  # type: ignore[arg-type]
        mask.triu_(1)
        self.register_buffer("causal_mask_template", mask, persistent=False)

    def _rot(self, t, o_x1, o_x2, o_neg, o_cat):
        x1, x2 = torch.chunk(t, 2, dim=-1)
        x1 = self._fq(x1, o_x1)
        x2 = self._fq(x2, o_x2)
        x2n = self._fq(-x2, o_neg)
        return self._fq(torch.cat((x2n, x1), -1), o_cat)

    @staticmethod
    def _concat_kv(
        past: Optional[Tuple[torch.Tensor, torch.Tensor]],
        k_new: torch.Tensor,
        v_new: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Concat along sequence dim (dim=2): (B, n_kv, S, H)."""
        if past is None:
            return k_new, v_new
        past_k, past_v = past
        k = torch.cat([past_k, k_new], dim=2)
        v = torch.cat([past_v, v_new], dim=2)
        return k, v

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value=None,  # tuple(k, v) or HF Cache-like object
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        hidden = self._fq(hidden_states, self.obs_hidden)
        B, S, _ = hidden.shape
        H = self.hdim

        # projections
        q = self.q_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_h, S, H)
        k = self.k_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_kv, S, H)
        v = self.v_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_kv, S, H)

        # rope tables
        cos, sin = position_embeddings
        cos = self._fq(cos, self.obs_cos)
        sin = self._fq(sin, self.obs_sin)
        cos_u, sin_u = cos.unsqueeze(1), sin.unsqueeze(1)

        # q_rot
        q_half = self._rot(
            q, self.obs_q_x1, self.obs_q_x2, self.obs_q_neg, self.obs_q_cat
        )
        q_cos = self._fq(q * cos_u, self.obs_q_cos)
        q_sin = self._fq(q_half * sin_u, self.obs_q_sin)
        q_rot = self._fq(q_cos + q_sin, self.obs_q_rot)

        # k_rot
        k_half = self._rot(
            k, self.obs_k_x1, self.obs_k_x2, self.obs_k_neg, self.obs_k_cat
        )
        k_cos = self._fq(k * cos_u, self.obs_k_cos)
        k_sin = self._fq(k_half * sin_u, self.obs_k_sin)
        k_rot = self._fq(k_cos + k_sin, self.obs_k_rot)

        # --- build/update KV for attention & present_key_value -------------
        present_key_value: Tuple[torch.Tensor, torch.Tensor]

        # HF Cache path (if available)
        if use_cache and hasattr(past_key_value, "update"):
            # Many HF Cache impls use update(k, v) and return (k_total, v_total)
            try:
                k_total, v_total = past_key_value.update(k_rot, v)
                present_key_value = (k_total, v_total)
                k_for_attn, v_for_attn = k_total, v_total
            except Exception:
                # Fallback to tuple concat if Cache signature mismatches
                k_for_attn, v_for_attn = self._concat_kv(
                    getattr(past_key_value, "kv", None), k_rot, v
                )
                present_key_value = (k_for_attn, v_for_attn)
        else:
            # Tuple or None path
            pkv_tuple = past_key_value if isinstance(past_key_value, tuple) else None
            k_for_attn, v_for_attn = self._concat_kv(pkv_tuple, k_rot, v)
            present_key_value = (k_for_attn, v_for_attn)

        # logits
        k_rep = k_for_attn.repeat_interleave(self.kv_rep, dim=1)  # (B, n_h, K, H)
        logits_raw = self._fq(q_rot @ k_rep.transpose(-2, -1), self.obs_logits_raw)
        scale = self._fq(self.scale_t, self.obs_scale)
        logits = self._fq(logits_raw * scale, self.obs_logits)

        if attention_mask is None or attention_mask.dtype == torch.bool:
            _, _, q_len, _ = logits.shape
            k_len = k_for_attn.size(2)
            assert isinstance(self.causal_mask_template, torch.Tensor)
            attention_mask = self.causal_mask_template[..., :q_len, :k_len].to(
                hidden_states.device
            )
        attention_mask = self._fq(attention_mask, self.obs_causal_mask)
        logits = self._fq(logits + attention_mask, self.obs_mask_add)

        # softmax
        attn_weights = torch.softmax(logits, -1, dtype=torch.float32).to(q.dtype)
        attn_weights = self._fq(attn_weights, self.obs_softmax)

        # attn out
        v_rep = v_for_attn.repeat_interleave(self.kv_rep, dim=1)  # (B, n_h, K, H)
        attn_out = (
            self._fq(attn_weights @ v_rep, self.obs_attn_out)
            .transpose(1, 2)
            .reshape(B, S, -1)
        )

        # final projection
        out = self.o_proj(attn_out)

        # return with/without cache
        if use_cache:
            return out, attn_weights, present_key_value
        else:
            return out, attn_weights

    def _all_observers(self):
        # local first
        yield from (
            self.obs_hidden,
            self.obs_scale,
            self.obs_cos,
            self.obs_sin,
            self.obs_causal_mask,
            self.obs_q_x1,
            self.obs_q_x2,
            self.obs_q_neg,
            self.obs_q_cat,
            self.obs_k_x1,
            self.obs_k_x2,
            self.obs_k_neg,
            self.obs_k_cat,
            self.obs_q_cos,
            self.obs_q_sin,
            self.obs_q_rot,
            self.obs_k_cos,
            self.obs_k_sin,
            self.obs_k_rot,
            self.obs_logits_raw,
            self.obs_logits,
            self.obs_mask_add,
            self.obs_softmax,
            self.obs_attn_out,
        )
        # recurse into children that are QuantModuleBase
        for m in (self.q_proj, self.k_proj, self.v_proj, self.o_proj):
            yield from m._all_observers()
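
For the Llama wrapper, a similarly hedged sketch (again, not code from the package): it assumes a Hugging Face Llama checkpoint is available locally, that the attention module lives at model.model.layers[i].self_attn, and that the model-level rotary embedding produces the (cos, sin) tables expected as position_embeddings. These attribute paths and call signatures vary across transformers versions; the quantize_llama_attn.py example added under tico/quantization/wrapq/examples/ is the in-tree reference for the intended workflow.

# Hedged sketch: wrap one Llama attention layer and run it without a KV cache.
# "<llama-checkpoint>" is a placeholder; transformers attribute names
# (model.model.rotary_emb, layers[i].self_attn) depend on the library version.
import torch
from transformers import AutoModelForCausalLM

from tico.quantization.wrapq.wrappers.llama.quant_attn import QuantLlamaAttention

model = AutoModelForCausalLM.from_pretrained("<llama-checkpoint>")
fp_attn = model.model.layers[0].self_attn
q_attn = QuantLlamaAttention(fp_attn, qcfg=None, fp_name="model.layers.0.self_attn")

B, S = 1, 8
hidden = torch.randn(B, S, model.config.hidden_size)
position_ids = torch.arange(S).unsqueeze(0)
cos, sin = model.model.rotary_emb(hidden, position_ids)  # rope tables

out, attn_weights = q_attn(hidden, (cos, sin))   # use_cache=False path
print(out.shape)                                 # (B, S, hidden_size)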