tico 0.1.0.dev250904__py3-none-any.whl → 0.1.0.dev251109__py3-none-any.whl
- tico/__init__.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_mixed_type_args.py +2 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -3
- tico/passes/ops.py +0 -1
- tico/passes/remove_redundant_expand.py +3 -1
- tico/quantization/__init__.py +6 -0
- tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +161 -0
- tico/quantization/algorithm/fpi_gptq/quantizer.py +179 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +14 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/fpi_gptq.py +29 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/{experimental/quantization/ptq/quant_config.py → quantization/config/ptq.py} +18 -10
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -37
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +6 -12
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/public_interface.py +11 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_linear.py +11 -10
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_attn.py +10 -12
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_decoder_layer.py +10 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_mlp.py +13 -13
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py +3 -3
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py +1 -1
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/mx.py +1 -1
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py +3 -5
- tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py +3 -2
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py +58 -21
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py +21 -13
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py +5 -7
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py +6 -7
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_linear.py +7 -8
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_silu.py +8 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/ptq_wrapper.py +4 -6
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_elementwise.py +55 -17
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_module_base.py +10 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/registry.py +17 -10
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_mm.py +15 -132
- tico/utils/convert.py +20 -15
- tico/utils/register_custom_op.py +6 -4
- tico/utils/signature.py +7 -8
- tico/utils/validate_args_kwargs.py +12 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/METADATA +48 -2
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/RECORD +128 -108
- tico/experimental/quantization/__init__.py +0 -6
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/ptq/examples/compare_ppl.py +0 -121
- tico/experimental/quantization/ptq/examples/debug_quant_outputs.py +0 -129
- tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +0 -165
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/gptq → quantization/algorithm/fpi_gptq}/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e → quantization/algorithm/gptq}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e/annotation → quantization/algorithm/pt2e}/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e/transformation → quantization/algorithm/pt2e/annotation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/smoothquant → quantization/algorithm/pt2e/transformation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/algorithm/smoothquant}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/metric.py +0 -0
- /tico/{experimental/quantization/ptq/examples → quantization/passes}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- /tico/{experimental/quantization/ptq/observers → quantization/wrapq}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/dtypes.py +0 -0
- /tico/{experimental/quantization/ptq/utils → quantization/wrapq/examples}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/mode.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers → quantization/wrapq/observers}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/qscheme.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers/llama → quantization/wrapq/utils}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/utils/reduce_utils.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers/nn → quantization/wrapq/wrappers}/__init__.py +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/top_level.txt +0 -0
tico/quantization/wrapq/wrappers/fairseq/quant_mha.py ADDED
@@ -0,0 +1,381 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -----------------------------------------------------------------------------
+# This file includes modifications based on fairseq
+# (https://github.com/facebookresearch/fairseq), originally licensed under
+# the MIT License. See the LICENSE file in the fairseq repository for details.
+# -----------------------------------------------------------------------------
+
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+@try_register("fairseq.modules.multihead_attention.MultiheadAttention")
+class QuantFairseqMultiheadAttention(QuantModuleBase):
+    """
+    Quant-aware drop-in for Fairseq MultiheadAttention.
+
+    - No xFormers / no torch F.multi_head_attention_forward fast-path.
+    - Self/cross attention + minimal incremental KV cache.
+    - Causal mask is pre-built statically; `key_padding_mask` is additive float.
+    - I/O shape: [T, B, C]
+
+    Runtime optimization flags
+    --------------------------
+    use_static_causal : bool
+        If True, reuse a precomputed upper-triangular causal mask template
+        instead of rebuilding it each forward step. Reduces per-step mask
+        construction overhead during incremental decoding.
+
+    assume_additive_key_padding : bool
+        If True, assume the `key_padding_mask` is already an additive float
+        tensor (large negative values at padded positions). Skips conversion
+        from boolean masks, reducing runtime overhead.
+    """
+
+    def __init__(
+        self,
+        fp_attn: nn.Module,
+        *,
+        qcfg: Optional[PTQConfig] = None,
+        fp_name: Optional[str] = None,
+        max_seq: int = 4096,
+        use_static_causal: bool = False,
+        mask_fill_value: float = -120.0,
+        assume_additive_key_padding: bool = False,
+    ):
+        super().__init__(qcfg, fp_name=fp_name)
+
+        self.use_static_causal = use_static_causal
+        self.mask_fill_value = mask_fill_value
+        self.assume_additive_key_padding = assume_additive_key_padding
+        self.embed_dim: int = int(fp_attn.embed_dim)  # type: ignore[arg-type]
+        self.num_heads: int = int(fp_attn.num_heads)  # type: ignore[arg-type]
+        self.head_dim: int = self.embed_dim // self.num_heads
+        assert self.head_dim * self.num_heads == self.embed_dim
+
+        self.self_attention: bool = bool(getattr(fp_attn, "self_attention", False))
+        self.encoder_decoder_attention: bool = bool(
+            getattr(fp_attn, "encoder_decoder_attention", False)
+        )
+        assert self.self_attention != self.encoder_decoder_attention
+
+        # PTQ-wrapped projections
+        qc = qcfg.child("q_proj") if qcfg else None
+        kc = qcfg.child("k_proj") if qcfg else None
+        vc = qcfg.child("v_proj") if qcfg else None
+        oc = qcfg.child("out_proj") if qcfg else None
+        assert hasattr(fp_attn, "q_proj") and hasattr(fp_attn, "k_proj")
+        assert hasattr(fp_attn, "v_proj") and hasattr(fp_attn, "out_proj")
+        assert isinstance(fp_attn.q_proj, nn.Module) and isinstance(
+            fp_attn.k_proj, nn.Module
+        )
+        assert isinstance(fp_attn.v_proj, nn.Module) and isinstance(
+            fp_attn.out_proj, nn.Module
+        )
+        self.q_proj = PTQWrapper(fp_attn.q_proj, qcfg=qc, fp_name=f"{fp_name}.q_proj")
+        self.k_proj = PTQWrapper(fp_attn.k_proj, qcfg=kc, fp_name=f"{fp_name}.k_proj")
+        self.v_proj = PTQWrapper(fp_attn.v_proj, qcfg=vc, fp_name=f"{fp_name}.v_proj")
+        self.out_proj = PTQWrapper(
+            fp_attn.out_proj, qcfg=oc, fp_name=f"{fp_name}.out_proj"
+        )
+
+        # scale & static causal mask
+        self.register_buffer(
+            "scale_const", torch.tensor(self.head_dim**-0.5), persistent=False
+        )
+        mask = torch.full((1, 1, max_seq, max_seq), float(self.mask_fill_value))
+        mask.triu_(1)
+        self.register_buffer("causal_mask_template", mask, persistent=False)
+
+        # observers (no *_proj_out here; PTQWrapper handles module outputs)
+        mk = self._make_obs
+        self.obs_query_in = mk("query_in")
+        self.obs_key_in = mk("key_in")
+        self.obs_value_in = mk("value_in")
+        self.obs_kpm_in = mk("kpm_in")
+        self.obs_causal_mask = mk("causal_mask")
+        self.obs_q_fold = mk("q_fold")
+        self.obs_k_fold = mk("k_fold")
+        self.obs_v_fold = mk("v_fold")
+        self.obs_scale = mk("scale")
+        self.obs_logits_raw = mk("logits_raw")
+        self.obs_logits = mk("logits_scaled")
+        self.obs_attn_mask_add = mk("obs_attn_mask_add")
+        self.obs_kp_mask_add = mk("obs_kp_mask_add")
+        self.obs_softmax = mk("softmax")
+        self.obs_attn_out = mk("attn_out")
+
+        safe_name = (
+            fp_name if (fp_name not in (None, "", "None")) else f"QuantFsMHA_{id(self)}"
+        )
+        assert safe_name is not None
+        self._state_key = safe_name + ".attn_state"
+
+    def _get_input_buffer(
+        self,
+        incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]],
+    ) -> Optional[Dict[str, Optional[torch.Tensor]]]:
+        """Return saved KV/mask dict or None."""
+        if incremental_state is None:
+            return None
+        return incremental_state.get(self._state_key, None)
+
+    def _set_input_buffer(
+        self,
+        incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]],
+        buffer: Dict[str, Optional[torch.Tensor]],
+    ):
+        """Store KV/mask dict in incremental_state."""
+        if incremental_state is not None:
+            incremental_state[self._state_key] = buffer
+        return incremental_state
+
+    # ---- utils ----
+    def _fold_heads(self, x: torch.Tensor, B: int) -> torch.Tensor:
+        # [T,B,E] -> [B*H, T, Dh]
+        T = x.size(0)
+        x = x.view(T, B, self.num_heads, self.head_dim).permute(1, 2, 0, 3).contiguous()
+        return x.view(B * self.num_heads, T, self.head_dim)
+
+    def _unfold_heads(self, x: torch.Tensor, B: int, T: int) -> torch.Tensor:
+        # [B*H, T, Dh] -> [T,B,E]
+        x = x.view(B, self.num_heads, T, self.head_dim).permute(2, 0, 1, 3).contiguous()
+        return x.view(T, B, self.embed_dim)
+
+    def forward(
+        self,
+        query: torch.Tensor,  # [Tq,B,C]
+        key: Optional[torch.Tensor],
+        value: Optional[torch.Tensor],
+        key_padding_mask: Optional[
+            torch.Tensor
+        ] = None,  # additive float (e.g. -120 at pads)
+        incremental_state: Optional[
+            Dict[str, Dict[str, Optional[torch.Tensor]]]
+        ] = None,
+        need_weights: bool = False,
+        static_kv: bool = False,
+        attn_mask: Optional[torch.Tensor] = None,  # if None -> internal causal
+        before_softmax: bool = False,
+        need_head_weights: bool = False,
+        return_new_kv: bool = False,
+    ) -> Union[
+        Tuple[torch.Tensor, Optional[torch.Tensor]],
+        Tuple[
+            torch.Tensor,
+            Optional[torch.Tensor],
+            Optional[torch.Tensor],
+            Optional[torch.Tensor],
+        ],
+    ]:
+
+        if need_head_weights:
+            need_weights = True
+
+        Tq, B, _ = query.shape
+        if self.self_attention:
+            key = query if key is None else key
+            value = query if value is None else value
+        else:
+            assert key is not None and value is not None
+
+        Tk, Bk, _ = key.shape
+        Tv, Bv, _ = value.shape
+        assert B == Bk == Bv
+
+        q = self.q_proj(self._fq(query, self.obs_query_in))
+        k = self.k_proj(self._fq(key, self.obs_key_in))
+        v = self.v_proj(self._fq(value, self.obs_value_in))
+
+        state = self._get_input_buffer(incremental_state)
+        if incremental_state is not None and state is None:
+            state = {}
+
+        # Capture "new" K/V for this call BEFORE concatenating with cache
+        new_k_bh: Optional[torch.Tensor] = None
+        new_v_bh: Optional[torch.Tensor] = None
+
+        # Fold heads
+        q = self._fq(self._fold_heads(q, B), self.obs_q_fold)
+        if state is not None and "prev_key" in state and static_kv:
+            # Cross-attention static_kv path: reuse cached KV; there is no new KV this call.
+            k = None
+            v = None
+        if k is not None:
+            k = self._fq(self._fold_heads(k, B), self.obs_k_fold)  # [B*H, Tnew, Dh]
+            if return_new_kv:
+                new_k_bh = k.contiguous()
+        if v is not None:
+            v = self._fq(self._fold_heads(v, B), self.obs_v_fold)  # [B*H, Tnew, Dh]
+            if return_new_kv:
+                new_v_bh = v.contiguous()
+
+        # Append/reuse cache
+        if state is not None:
+            pk = state.get("prev_key")
+            pv = state.get("prev_value")
+            if pk is not None:
+                pk = pk.view(B * self.num_heads, -1, self.head_dim)
+                k = pk if static_kv else torch.cat([pk, k], dim=1)
+            if pv is not None:
+                pv = pv.view(B * self.num_heads, -1, self.head_dim)
+                v = pv if static_kv else torch.cat([pv, v], dim=1)
+
+        assert k is not None and v is not None
+        Ts = k.size(1)
+
+        # Scaled dot-product
+        scale = self._fq(self.scale_const, self.obs_scale).to(q.dtype)
+        logits_raw = self._fq(
+            torch.bmm(q, k.transpose(1, 2)), self.obs_logits_raw
+        )  # [B*H,Tq,Ts]
+        logits = self._fq(logits_raw * scale, self.obs_logits)
+
+        assert isinstance(self.causal_mask_template, torch.Tensor)
+        # Masks
+        device = logits.device
+        if attn_mask is None and self.use_static_causal:
+            # Incremental decoding aware slicing:
+            # align the causal row(s) to the current time indices
+            start_q = max(Ts - Tq, 0)
+            cm = self.causal_mask_template[..., start_q : start_q + Tq, :Ts].to(
+                device=device, dtype=logits.dtype
+            )
+            attn_mask = cm.squeeze(0).squeeze(0)  # [Tq,Ts]
+
+        if attn_mask is not None:
+            # Bool/byte mask -> additive float with large negatives
+            if not torch.is_floating_point(attn_mask):
+                fill = self.causal_mask_template.new_tensor(self.mask_fill_value)
+                attn_mask = torch.where(
+                    attn_mask.to(torch.bool), fill, fill.new_zeros(())
+                )
+            attn_mask = self._fq(attn_mask, self.obs_causal_mask)
+            assert isinstance(attn_mask, torch.Tensor)
+
+            if not self.assume_additive_key_padding:
+                # attn_mask -> [B*H,Tq,Ts]
+                if attn_mask.dim() == 2:
+                    add_mask = attn_mask.unsqueeze(0).expand(logits.size(0), -1, -1)
+                elif attn_mask.dim() == 3:
+                    add_mask = (
+                        attn_mask.unsqueeze(1)
+                        .expand(B, self.num_heads, Tq, Ts)
+                        .contiguous()
+                    )
+                    add_mask = add_mask.view(B * self.num_heads, Tq, Ts)
+                else:
+                    raise RuntimeError("attn_mask must be [T,S] or [B,T,S]")
+            else:
+                add_mask = attn_mask
+            logits = self._fq(logits + add_mask, self.obs_attn_mask_add)
+
+        if key_padding_mask is not None:
+            if not torch.is_floating_point(key_padding_mask):
+                fill = self.causal_mask_template.new_tensor(self.mask_fill_value)
+                kpm = torch.where(
+                    key_padding_mask.to(torch.bool), fill, fill.new_zeros(())
+                )
+            else:
+                kpm = key_padding_mask
+            kpm = self._fq(kpm, self.obs_kpm_in)
+
+            if not self.assume_additive_key_padding:
+                # key_padding_mask: additive float already
+                kpm = kpm.to(dtype=logits.dtype, device=device)
+                if kpm.dim() == 2:  # [B,S]
+                    kpm = (
+                        kpm.view(B, 1, 1, Ts)
+                        .expand(B, self.num_heads, Tq, Ts)
+                        .contiguous()
+                    )
+                    kpm = kpm.view(B * self.num_heads, Tq, Ts)
+                elif kpm.dim() == 3:  # [B,T,S]
+                    kpm = (
+                        kpm.unsqueeze(1).expand(B, self.num_heads, Tq, Ts).contiguous()
+                    )
+                    kpm = kpm.view(B * self.num_heads, Tq, Ts)
+                else:
+                    raise RuntimeError(
+                        "key_padding_mask must be [B,S] or [B,T,S] (additive)"
+                    )
+            logits = self._fq(logits + kpm, self.obs_kp_mask_add)
+
+        if before_softmax:
+            if return_new_kv:
+                return logits, v, new_k_bh, new_v_bh
+            return logits, v
+
+        # Softmax (float32) -> back to q.dtype
+        attn_probs = torch.softmax(logits, dim=-1, dtype=torch.float32).to(q.dtype)
+        attn_probs = self._fq(attn_probs, self.obs_softmax)
+
+        # Context + output proj
+        ctx = self._fq(torch.bmm(attn_probs, v), self.obs_attn_out)  # [B*H,Tq,Dh]
+        ctx = self._unfold_heads(ctx, B, Tq)  # [Tq,B,E]
+        out = self.out_proj(ctx)
+
+        # Weights (optional)
+        attn_weights_out: Optional[torch.Tensor] = None
+        if need_weights:
+            aw = (
+                torch.softmax(logits, dim=-1, dtype=torch.float32)
+                .view(B, self.num_heads, Tq, Ts)
+                .transpose(1, 0)
+            )
+            if not need_head_weights:
+                aw = aw.mean(dim=1)  # [B,Tq,Ts]
+            attn_weights_out = aw
+
+        # Cache write
+        if state is not None:
+            state["prev_key"] = k.view(B, self.num_heads, -1, self.head_dim).detach()
+            state["prev_value"] = v.view(B, self.num_heads, -1, self.head_dim).detach()
+            self._set_input_buffer(incremental_state, state)
+
+        if return_new_kv:
+            return out, attn_weights_out, new_k_bh, new_v_bh
+        return out, attn_weights_out
+
+    def _all_observers(self):
+        yield from (
+            self.obs_query_in,
+            self.obs_key_in,
+            self.obs_value_in,
+            self.obs_kpm_in,
+            self.obs_causal_mask,
+            self.obs_q_fold,
+            self.obs_k_fold,
+            self.obs_v_fold,
+            self.obs_scale,
+            self.obs_logits_raw,
+            self.obs_logits,
+            self.obs_attn_mask_add,
+            self.obs_kp_mask_add,
+            self.obs_softmax,
+            self.obs_attn_out,
+        )
+        for m in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
+            if isinstance(m, QuantModuleBase):
+                yield from m._all_observers()
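For orientation (not part of the diff), here is a minimal usage sketch of the new wrapper. It assumes fairseq is installed, that PTQConfig is default-constructible, and that calibration/quantization modes are driven elsewhere; the fp_name string is purely illustrative.

import torch
from fairseq.modules.multihead_attention import MultiheadAttention

from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.wrappers.fairseq.quant_mha import (
    QuantFairseqMultiheadAttention,
)

# Wrap a float self-attention module; the q/k/v/out projections become
# PTQWrapper'd submodules with their own observers.
embed_dim, num_heads = 512, 8
fp_attn = MultiheadAttention(embed_dim, num_heads, self_attention=True)
qmha = QuantFairseqMultiheadAttention(
    fp_attn,
    qcfg=PTQConfig(),                      # assumed default-constructible
    fp_name="decoder.layers.0.self_attn",  # illustrative module name
    use_static_causal=True,                # reuse the precomputed causal template
)

# I/O layout is [T, B, C], as documented in the class docstring.
x = torch.randn(16, 2, embed_dim)
out, attn_weights = qmha(query=x, key=x, value=x, need_weights=True)
print(out.shape)  # torch.Size([16, 2, 512])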
tico/quantization/wrapq/wrappers/llama/__init__.py ADDED
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py RENAMED
@@ -12,17 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
 
-from tico.
-from tico.
-from tico.
-
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register(
@@ -34,7 +32,7 @@ class QuantLlamaAttention(QuantModuleBase):
         self,
         fp_attn: nn.Module,
         *,
-        qcfg: Optional[
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
     ):
         super().__init__(qcfg, fp_name=fp_name)
@@ -131,28 +129,38 @@ class QuantLlamaAttention(QuantModuleBase):
         x2n = self._fq(-x2, o_neg)
         return self._fq(torch.cat((x2n, x1), -1), o_cat)
 
+    @staticmethod
+    def _concat_kv(
+        past: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        k_new: torch.Tensor,
+        v_new: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Concat along sequence dim (dim=2): (B, n_kv, S, H)."""
+        if past is None:
+            return k_new, v_new
+        past_k, past_v = past
+        k = torch.cat([past_k, k_new], dim=2)
+        v = torch.cat([past_v, v_new], dim=2)
+        return k, v
+
     def forward(
         self,
         hidden_states: torch.Tensor,
         position_embeddings: tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
-        past_key_value=None,  #
+        past_key_value=None,  # tuple(k, v) or HF Cache-like object
+        use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ):
-        if past_key_value is not None:
-            raise NotImplementedError(
-                "QuantLlamaAttention does not support KV cache yet."
-            )
-
         hidden = self._fq(hidden_states, self.obs_hidden)
         B, S, _ = hidden.shape
        H = self.hdim
 
         # projections
-        q = self.q_proj(hidden).view(B, S, -1, H).transpose(1, 2)
-        k = self.k_proj(hidden).view(B, S, -1, H).transpose(1, 2)
-        v = self.v_proj(hidden).view(B, S, -1, H).transpose(1, 2)
+        q = self.q_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_h, S, H)
+        k = self.k_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_kv, S, H)
+        v = self.v_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_kv, S, H)
 
         # rope tables
         cos, sin = position_embeddings
@@ -176,14 +184,37 @@ class QuantLlamaAttention(QuantModuleBase):
         k_sin = self._fq(k_half * sin_u, self.obs_k_sin)
         k_rot = self._fq(k_cos + k_sin, self.obs_k_rot)
 
+        # --- build/update KV for attention & present_key_value -------------
+        present_key_value: Tuple[torch.Tensor, torch.Tensor]
+
+        # HF Cache path (if available)
+        if use_cache and hasattr(past_key_value, "update"):
+            # Many HF Cache impls use update(k, v) and return (k_total, v_total)
+            try:
+                k_total, v_total = past_key_value.update(k_rot, v)
+                present_key_value = (k_total, v_total)
+                k_for_attn, v_for_attn = k_total, v_total
+            except Exception:
+                # Fallback to tuple concat if Cache signature mismatches
+                k_for_attn, v_for_attn = self._concat_kv(
+                    getattr(past_key_value, "kv", None), k_rot, v
+                )
+                present_key_value = (k_for_attn, v_for_attn)
+        else:
+            # Tuple or None path
+            pkv_tuple = past_key_value if isinstance(past_key_value, tuple) else None
+            k_for_attn, v_for_attn = self._concat_kv(pkv_tuple, k_rot, v)
+            present_key_value = (k_for_attn, v_for_attn)
+
         # logits
-        k_rep =
+        k_rep = k_for_attn.repeat_interleave(self.kv_rep, dim=1)  # (B, n_h, K, H)
         logits_raw = self._fq(q_rot @ k_rep.transpose(-2, -1), self.obs_logits_raw)
         scale = self._fq(self.scale_t, self.obs_scale)
         logits = self._fq(logits_raw * scale, self.obs_logits)
 
         if attention_mask is None or attention_mask.dtype == torch.bool:
-            _, _, q_len,
+            _, _, q_len, _ = logits.shape
+            k_len = k_for_attn.size(2)
             assert isinstance(self.causal_mask_template, torch.Tensor)
             attention_mask = self.causal_mask_template[..., :q_len, :k_len].to(
                 hidden_states.device
@@ -196,7 +227,7 @@ class QuantLlamaAttention(QuantModuleBase):
         attn_weights = self._fq(attn_weights, self.obs_softmax)
 
         # attn out
-        v_rep =
+        v_rep = v_for_attn.repeat_interleave(self.kv_rep, dim=1)  # (B, n_h, K, H)
         attn_out = (
             self._fq(attn_weights @ v_rep, self.obs_attn_out)
             .transpose(1, 2)
@@ -204,7 +235,13 @@ class QuantLlamaAttention(QuantModuleBase):
         )
 
         # final projection
-
+        out = self.o_proj(attn_out)
+
+        # return with/without cache
+        if use_cache:
+            return out, attn_weights, present_key_value
+        else:
+            return out, attn_weights
 
     def _all_observers(self):
         # local first
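A brief sketch of the cache convention these hunks introduce (illustrative only; it assumes q_attn is an already-constructed QuantLlamaAttention wrapping a transformers LlamaAttention, and that hidden states and rotary embeddings come from the host model):

def decode_two_steps(q_attn, hidden0, hidden1, pos_emb0, pos_emb1):
    # Prefill: no cache yet. With use_cache=True the wrapper returns a third
    # element, present_key_value, holding the accumulated (k, v) tuple.
    out0, _w0, present = q_attn(
        hidden_states=hidden0,
        position_embeddings=pos_emb0,
        past_key_value=None,
        use_cache=True,
    )
    # Decode step: feed the tuple back; _concat_kv appends along dim=2 (sequence).
    out1, _w1, present = q_attn(
        hidden_states=hidden1,
        position_embeddings=pos_emb1,
        past_key_value=present,
        use_cache=True,
    )
    return out1, present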
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py RENAMED
@@ -17,16 +17,12 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 
-from tico.
-from tico.
-
-
-from tico.
-from tico.
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.llama.quant_attn import QuantLlamaAttention
+from tico.quantization.wrapq.wrappers.llama.quant_mlp import QuantLlamaMLP
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register("transformers.models.llama.modeling_llama.LlamaDecoderLayer")
@@ -56,7 +52,7 @@ class QuantLlamaDecoderLayer(QuantModuleBase):
         self,
         fp_layer: nn.Module,
         *,
-        qcfg: Optional[
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
         return_type: Optional[str] = None,
     ):
@@ -136,7 +132,7 @@ class QuantLlamaDecoderLayer(QuantModuleBase):
             L = hidden_states.size(1)
             attention_mask = self._slice_causal(L, hidden_states.device)
 
-
+        attn_out = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -147,7 +143,13 @@ class QuantLlamaDecoderLayer(QuantModuleBase):
             position_embeddings=position_embeddings,
             **kwargs,
         )
-
+        if use_cache:
+            hidden_states_attn, _attn_weights, present_key_value = attn_out
+        else:
+            hidden_states_attn, _attn_weights = attn_out
+            present_key_value = None
+
+        hidden_states = residual + hidden_states_attn
 
         # ─── MLP block ─────────────────────────────────────────────────
         residual = hidden_states
@@ -155,6 +157,12 @@ class QuantLlamaDecoderLayer(QuantModuleBase):
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
 
+        # Return type policy:
+        #  - If use_cache: always return (hidden_states, present_key_value)
+        #  - Else: return as configured (tuple/tensor) for HF compatibility
+        if use_cache:
+            return hidden_states, present_key_value  # type: ignore[return-value]
+
         if self.return_type == "tuple":
             return (hidden_states,)
         elif self.return_type == "tensor":
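The decoder layer mirrors that policy. A hypothetical caller (argument names taken from the hunks above; the layer object, its full forward signature, and the inputs are assumed) would branch on use_cache:

def run_layer(q_layer, hidden_states, position_embeddings, use_cache=False):
    out = q_layer(
        hidden_states=hidden_states,
        position_embeddings=position_embeddings,
        use_cache=use_cache,
    )
    if use_cache:
        # With a cache, the layer always returns (hidden_states, present_key_value).
        hidden_states, present_key_value = out
        return hidden_states, present_key_value
    # Without a cache, the configured return_type ("tuple"/"tensor") applies.
    return out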
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py RENAMED
@@ -17,12 +17,10 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from tico.
-from tico.
-from tico.
-
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register("transformers.models.llama.modeling_llama.LlamaMLP")
@@ -31,7 +29,7 @@ class QuantLlamaMLP(QuantModuleBase):
         self,
         mlp_fp: nn.Module,
         *,
-        qcfg: Optional[
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
     ):
         super().__init__(qcfg, fp_name=fp_name)
tico/quantization/wrapq/wrappers/nn/__init__.py ADDED
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py RENAMED
@@ -17,12 +17,11 @@ from typing import Iterable, Optional, Tuple
 import torch
 import torch.nn as nn
 
-from tico.
-
-from tico.
-
-
-from tico.experimental.quantization.ptq.wrappers.registry import register
+from tico.quantization.config.ptq import PTQConfig
+
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import register
 
 
 @register(nn.LayerNorm)
@@ -46,7 +45,7 @@ class QuantLayerNorm(QuantModuleBase):
         self,
         fp: nn.LayerNorm,
         *,
-        qcfg: Optional[
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None
     ):
         super().__init__(qcfg, fp_name=fp_name)
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_linear.py RENAMED
@@ -17,13 +17,12 @@ from typing import Optional
 import torch.nn as nn
 import torch.nn.functional as F
 
-from tico.
-
-from tico.
-from tico.
-
-
-from tico.experimental.quantization.ptq.wrappers.registry import register
+from tico.quantization.config.ptq import PTQConfig
+
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.qscheme import QScheme
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import register
 
 
 @register(nn.Linear)
@@ -34,7 +33,7 @@ class QuantLinear(QuantModuleBase):
         self,
         fp: nn.Linear,
         *,
-        qcfg: Optional[
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None
     ):
         super().__init__(qcfg, fp_name=fp_name)
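Taken together, the renamed files move the wrappers from tico.experimental.quantization.ptq to tico.quantization.wrapq and the quant config to tico.quantization.config.ptq. A before/after import sketch, restricted to symbols visible in the hunks above:

# Old layout (removed lines), e.g.:
#   from tico.experimental.quantization.ptq.wrappers.registry import try_register

# New layout (added lines):
from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.mode import Mode
from tico.quantization.wrapq.qscheme import QScheme
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
from tico.quantization.wrapq.wrappers.registry import register, try_register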