tico 0.1.0.dev250826__py3-none-any.whl → 0.1.0.dev250828__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
- tico/__init__.py +1 -1
- tico/experimental/quantization/ptq/examples/quantize_llama_attn.py +103 -0
- tico/experimental/quantization/ptq/examples/quantize_llama_mlp.py +95 -0
- tico/experimental/quantization/ptq/wrappers/llama/__init__.py +6 -0
- tico/experimental/quantization/ptq/wrappers/llama/quant_attn.py +236 -0
- tico/experimental/quantization/ptq/wrappers/llama/quant_mlp.py +98 -0
- tico/experimental/quantization/ptq/wrappers/registry.py +2 -0
- {tico-0.1.0.dev250826.dist-info → tico-0.1.0.dev250828.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250826.dist-info → tico-0.1.0.dev250828.dist-info}/RECORD +13 -8
- {tico-0.1.0.dev250826.dist-info → tico-0.1.0.dev250828.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250826.dist-info → tico-0.1.0.dev250828.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250826.dist-info → tico-0.1.0.dev250828.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250826.dist-info → tico-0.1.0.dev250828.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED

tico/experimental/quantization/ptq/examples/quantize_llama_attn.py
ADDED
@@ -0,0 +1,103 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pathlib

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from tico.experimental.quantization.evaluation.metric import compute_peir
from tico.experimental.quantization.evaluation.utils import plot_two_outputs

from tico.experimental.quantization.ptq.mode import Mode
from tico.experimental.quantization.ptq.wrappers.llama.quant_attn import (
    QuantLlamaAttention,
)
from tico.utils.utils import SuppressWarning

name = "Maykeye/TinyLLama-v0"
model = AutoModelForCausalLM.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

# -------------------------------------------------------------------------
# 1. Replace layer-0's self-attention with QuantLlamaAttention
# -------------------------------------------------------------------------
orig_attn = model.model.layers[0].self_attn
model.model.layers[0].self_attn = QuantLlamaAttention(
    orig_attn
)  # PTQWrapper(orig_attn) is also fine
model.eval()

attn_q = model.model.layers[0].self_attn  # quant wrapper
rotary = model.model.rotary_emb

# -------------------------------------------------------------------------
# 2. Single-pass calibration
# -------------------------------------------------------------------------
PROMPTS = [
    "The quick brown fox jumps over the lazy dog.",
    "In 2025, AI systems accelerated hardware-software co-design at scale.",
    "양자화는 왜 어려울까? 분포, 길이, 마스크가 관건이다.",
    "今日はいい天気ですね。ところでRoPE角度は長さに依存します。",
    "def quicksort(arr):\n if len(arr) <= 1: return arr\n ...",
    "Prices rose 3.14% — see Figure 2; emails: foo@bar.com!",
]

with torch.no_grad():
    attn_q.enable_calibration()
    for prompt in PROMPTS:
        ids = tokenizer(prompt, return_tensors="pt")
        embeds = model.model.embed_tokens(ids["input_ids"])
        cos_sin = rotary(embeds, ids["input_ids"])
        S = cos_sin[0].shape[1]
        float_mask = torch.zeros(1, 1, S, S)
        _ = attn_q(embeds, cos_sin)  # observers collect
    attn_q.freeze_qparams()

assert attn_q._mode is Mode.QUANT, "Quantization mode should be active now."

# -------------------------------------------------------------------------
# 3. Quick diff check (INT-sim vs FP32)
# -------------------------------------------------------------------------
ids = tokenizer("check", return_tensors="pt")
emb = model.model.embed_tokens(ids["input_ids"])
pos = rotary(emb, ids["input_ids"])
S = pos[0].shape[1]
float_mask = torch.zeros(1, 1, S, S)
with torch.no_grad():
    int8 = attn_q(emb, pos)[0]
    fp32 = orig_attn(emb, position_embeddings=pos, attention_mask=None)[0]

print("┌───────────── Quantization Error Summary ─────────────")
print(f"│ Mean |diff|: {(int8 - fp32).abs().mean().item():.6f}")
print(f"│ PEIR       : {compute_peir(fp32, int8) * 100:.6f} %")
print("└──────────────────────────────────────────────────────")
print(plot_two_outputs(fp32, int8))

# -------------------------------------------------------------------------
# 4. Export the quantized block
# -------------------------------------------------------------------------
import tico

save_path = pathlib.Path("attn.q.circle")
B, S, D = 1, 4, model.config.hidden_size
example = torch.randn(B, S, D)
example_pos = rotary(example, torch.arange(S)[None, :])
float_mask = torch.zeros(1, 1, S, S)

with SuppressWarning(UserWarning, ".*"):
    cm = tico.convert(attn_q, (example, example_pos))
cm.save(save_path)

print(f"Quantized Circle model saved to {save_path.resolve()}")
tico/experimental/quantization/ptq/examples/quantize_llama_mlp.py
ADDED
@@ -0,0 +1,95 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pathlib

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import tico
from tico.experimental.quantization.evaluation.metric import compute_peir
from tico.experimental.quantization.evaluation.utils import plot_two_outputs
from tico.experimental.quantization.ptq.dtypes import INT16
from tico.experimental.quantization.ptq.mode import Mode
from tico.experimental.quantization.ptq.qscheme import QScheme
from tico.experimental.quantization.ptq.quant_config import QuantConfig
from tico.experimental.quantization.ptq.wrappers.llama.quant_mlp import QuantLlamaMLP
from tico.utils.utils import SuppressWarning

name = "Maykeye/TinyLLama-v0"
model = AutoModelForCausalLM.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)
model.eval()

# -------------------------------------------------------------------------
# 1. Replace layer-0's MLP with QuantLlamaMLP
# -------------------------------------------------------------------------
fp32_mlp = model.model.layers[0].mlp
model.model.layers[0].mlp = QuantLlamaMLP(
    fp32_mlp,
    qcfg=QuantConfig(default_dtype=INT16, default_qscheme=QScheme.PER_TENSOR_SYMM),
)  # PTQWrapper(fp32_mlp) is also fine
model.eval()

mlp_q = model.model.layers[0].mlp

# -------------------------------------------------------------------------
# 2. Single-pass calibration
# -------------------------------------------------------------------------
PROMPTS = [
    "The quick brown fox jumps over the lazy dog.",
    "In 2025, AI systems accelerated hardware-software co-design at scale.",
    "양자화는 왜 어려울까? 분포, 길이, 마스크가 관건이다.",
    "今日はいい天気ですね。ところでRoPE角度は長さに依存します。",
    "def quicksort(arr):\n if len(arr) <= 1: return arr\n ...",
    "Prices rose 3.14% — see Figure 2; emails: foo@bar.com!",
]

with torch.no_grad():
    mlp_q.enable_calibration()
    for prompt in PROMPTS:
        enc = tokenizer(prompt, return_tensors="pt")
        emb = model.model.embed_tokens(enc["input_ids"])
        _ = mlp_q(emb)

    mlp_q.freeze_qparams()

assert mlp_q._mode is Mode.QUANT, "Quantization mode should be active now."

# -------------------------------------------------------------------------
# 3. Quick diff check (INT-sim vs FP32)
# -------------------------------------------------------------------------
with torch.no_grad():
    ids = tokenizer("quant all tensors!", return_tensors="pt")
    emb = model.model.embed_tokens(ids["input_ids"])
    int16 = mlp_q(emb)  # INT-sim
    fp32 = fp32_mlp(emb)  # baseline reference

print("┌───────────── Quantization Error Summary ─────────────")
print(f"│ Mean |diff|: {(int16 - fp32).abs().mean().item():.6f}")
print(f"│ PEIR       : {compute_peir(fp32, int16) * 100:.6f} %")
print("└──────────────────────────────────────────────────────")
print(plot_two_outputs(fp32, int16))

# -------------------------------------------------------------------------
# 4. Export the quantized block
# -------------------------------------------------------------------------
save_path = pathlib.Path("mlp.q.circle")
example_in = (torch.randn(1, 1, model.config.hidden_size),)

with SuppressWarning(UserWarning, ".*"):
    cm = tico.convert(mlp_q, example_in)
cm.save(save_path)

print(f"Quantized Circle model saved to {save_path.resolve()}")
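The example above requests INT16 per-tensor symmetric quantization through QuantConfig. As a rough mental model only, and not tico's actual implementation (the helper name and rounding/clamping details below are assumptions), per-tensor symmetric fake-quantization amounts to the following:

import torch

def fake_quant_per_tensor_symm(x: torch.Tensor, bits: int = 16) -> torch.Tensor:
    # Symmetric scheme: zero-point is 0, scale maps the largest |x| onto the int range.
    qmax = 2 ** (bits - 1) - 1                       # 32767 for INT16
    scale = x.abs().amax().clamp(min=1e-12) / qmax   # guard against all-zero inputs
    q = (x / scale).round().clamp(-qmax - 1, qmax)   # quantize
    return q * scale                                 # dequantize (INT-sim)

The Mean |diff| and PEIR printed by the example measure how far this simulated quantize/dequantize round-trip drifts from the FP32 baseline output.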
tico/experimental/quantization/ptq/wrappers/llama/quant_attn.py
ADDED
@@ -0,0 +1,236 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.nn as nn

from tico.experimental.quantization.ptq.quant_config import QuantConfig
from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
    QuantModuleBase,
)
from tico.experimental.quantization.ptq.wrappers.registry import try_register


@try_register("transformers.models.llama.modeling_llama.LlamaAttention")
class QuantLlamaAttention(QuantModuleBase):
    def __init__(
        self,
        fp_attn: nn.Module,
        *,
        qcfg: Optional[QuantConfig] = None,
        fp_name: Optional[str] = None,
    ):
        super().__init__(qcfg, fp_name=fp_name)

        cfg = fp_attn.config
        assert hasattr(cfg, "hidden_size") and hasattr(cfg, "num_attention_heads")
        assert hasattr(cfg, "num_key_value_heads")
        assert isinstance(cfg.hidden_size, int) and isinstance(
            cfg.num_attention_heads, int
        )
        assert isinstance(cfg.num_key_value_heads, int)
        self.hdim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)
        self.kv_rep = cfg.num_attention_heads // cfg.num_key_value_heads

        # constant scale (1/√d)
        self.scale_t = torch.tensor(self.hdim**-0.5)
        self.obs_scale = self._make_obs("scale")

        # ---- wrap q k v o projections via PTQWrapper ---------------
        q_cfg = qcfg.child("q_proj") if qcfg else None
        k_cfg = qcfg.child("k_proj") if qcfg else None
        v_cfg = qcfg.child("v_proj") if qcfg else None
        o_cfg = qcfg.child("o_proj") if qcfg else None
        assert hasattr(fp_attn, "q_proj") and isinstance(
            fp_attn.q_proj, torch.nn.Module
        )
        assert hasattr(fp_attn, "k_proj") and isinstance(
            fp_attn.k_proj, torch.nn.Module
        )
        assert hasattr(fp_attn, "v_proj") and isinstance(
            fp_attn.v_proj, torch.nn.Module
        )
        assert hasattr(fp_attn, "o_proj") and isinstance(
            fp_attn.o_proj, torch.nn.Module
        )
        self.q_proj = PTQWrapper(
            fp_attn.q_proj, qcfg=q_cfg, fp_name=f"{fp_name}.q_proj"
        )
        self.k_proj = PTQWrapper(
            fp_attn.k_proj, qcfg=k_cfg, fp_name=f"{fp_name}.k_proj"
        )
        self.v_proj = PTQWrapper(
            fp_attn.v_proj, qcfg=v_cfg, fp_name=f"{fp_name}.v_proj"
        )
        self.o_proj = PTQWrapper(
            fp_attn.o_proj, qcfg=o_cfg, fp_name=f"{fp_name}.o_proj"
        )

        # ---- create arithmetic observers ---------------------------
        mk = self._make_obs
        self.obs_hidden = mk("hidden")

        self.obs_cos = mk("cos")
        self.obs_sin = mk("sin")

        self.obs_causal_mask = mk("causal_mask")

        # rotate-half sub-steps
        self.obs_q_x1 = mk("q_x1")
        self.obs_q_x2 = mk("q_x2")
        self.obs_q_neg = mk("q_neg")
        self.obs_q_cat = mk("q_cat")
        self.obs_k_x1 = mk("k_x1")
        self.obs_k_x2 = mk("k_x2")
        self.obs_k_neg = mk("k_neg")
        self.obs_k_cat = mk("k_cat")

        # q / k paths
        self.obs_q_cos = mk("q_cos")
        self.obs_q_sin = mk("q_sin")
        self.obs_q_rot = mk("q_rot")
        self.obs_k_cos = mk("k_cos")
        self.obs_k_sin = mk("k_sin")
        self.obs_k_rot = mk("k_rot")

        # logits / softmax / out
        self.obs_logits_raw = mk("logits_raw")
        self.obs_logits = mk("logits")
        self.obs_mask_add = mk("mask_add")
        self.obs_softmax = mk("softmax")
        self.obs_attn_out = mk("attn_out")

        # Static causal mask template
        assert hasattr(cfg, "max_position_embeddings")
        max_seq = cfg.max_position_embeddings
        mask = torch.full((1, 1, max_seq, max_seq), float("-120"))  # type: ignore[arg-type]
        mask.triu_(1)
        self.register_buffer("causal_mask_template", mask, persistent=False)

    def _rot(self, t, o_x1, o_x2, o_neg, o_cat):
        x1, x2 = torch.chunk(t, 2, dim=-1)
        x1 = self._fq(x1, o_x1)
        x2 = self._fq(x2, o_x2)
        x2n = self._fq(-x2, o_neg)
        return self._fq(torch.cat((x2n, x1), -1), o_cat)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value=None,  # not supported yet
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        if past_key_value is not None:
            raise NotImplementedError(
                "QuantLlamaAttention does not support KV cache yet."
            )

        hidden = self._fq(hidden_states, self.obs_hidden)
        B, S, _ = hidden.shape
        H = self.hdim

        # projections
        q = self.q_proj(hidden).view(B, S, -1, H).transpose(1, 2)
        k = self.k_proj(hidden).view(B, S, -1, H).transpose(1, 2)
        v = self.v_proj(hidden).view(B, S, -1, H).transpose(1, 2)

        # rope tables
        cos, sin = position_embeddings
        cos = self._fq(cos, self.obs_cos)
        sin = self._fq(sin, self.obs_sin)
        cos_u, sin_u = cos.unsqueeze(1), sin.unsqueeze(1)

        # q_rot
        q_half = self._rot(
            q, self.obs_q_x1, self.obs_q_x2, self.obs_q_neg, self.obs_q_cat
        )
        q_cos = self._fq(q * cos_u, self.obs_q_cos)
        q_sin = self._fq(q_half * sin_u, self.obs_q_sin)
        q_rot = self._fq(q_cos + q_sin, self.obs_q_rot)

        # k_rot
        k_half = self._rot(
            k, self.obs_k_x1, self.obs_k_x2, self.obs_k_neg, self.obs_k_cat
        )
        k_cos = self._fq(k * cos_u, self.obs_k_cos)
        k_sin = self._fq(k_half * sin_u, self.obs_k_sin)
        k_rot = self._fq(k_cos + k_sin, self.obs_k_rot)

        # logits
        k_rep = k_rot.repeat_interleave(self.kv_rep, dim=1)
        logits_raw = self._fq(q_rot @ k_rep.transpose(-2, -1), self.obs_logits_raw)
        scale = self._fq(self.scale_t, self.obs_scale)
        logits = self._fq(logits_raw * scale, self.obs_logits)

        if attention_mask is None or attention_mask.dtype == torch.bool:
            _, _, q_len, k_len = logits.shape
            assert isinstance(self.causal_mask_template, torch.Tensor)
            attention_mask = self.causal_mask_template[..., :q_len, :k_len].to(
                hidden_states.device
            )
        attention_mask = self._fq(attention_mask, self.obs_causal_mask)
        logits = self._fq(logits + attention_mask, self.obs_mask_add)

        # softmax
        attn_weights = torch.softmax(logits, -1, dtype=torch.float32).to(q.dtype)
        attn_weights = self._fq(attn_weights, self.obs_softmax)

        # attn out
        v_rep = v.repeat_interleave(self.kv_rep, dim=1)
        attn_out = (
            self._fq(attn_weights @ v_rep, self.obs_attn_out)
            .transpose(1, 2)
            .reshape(B, S, -1)
        )

        # final projection
        return self.o_proj(attn_out), attn_weights

    def _all_observers(self):
        # local first
        yield from (
            self.obs_hidden,
            self.obs_scale,
            self.obs_cos,
            self.obs_sin,
            self.obs_causal_mask,
            self.obs_q_x1,
            self.obs_q_x2,
            self.obs_q_neg,
            self.obs_q_cat,
            self.obs_k_x1,
            self.obs_k_x2,
            self.obs_k_neg,
            self.obs_k_cat,
            self.obs_q_cos,
            self.obs_q_sin,
            self.obs_q_rot,
            self.obs_k_cos,
            self.obs_k_sin,
            self.obs_k_rot,
            self.obs_logits_raw,
            self.obs_logits,
            self.obs_mask_add,
            self.obs_softmax,
            self.obs_attn_out,
        )
        # recurse into children that are QuantModuleBase
        for m in (self.q_proj, self.k_proj, self.v_proj, self.o_proj):
            yield from m._all_observers()
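For readers comparing this wrapper against the stock Hugging Face implementation: _rot is the usual rotate-half step, and q_rot/k_rot reproduce q*cos + rotate_half(q)*sin with a fake-quant call wrapped around every intermediate. A standalone sanity check of that identity, assuming a recent transformers release that exposes rotate_half and apply_rotary_pos_emb with the (q, k, cos, sin) signature used by this attention module:

import torch
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half

B, heads, S, head_dim = 1, 2, 4, 8
q = torch.randn(B, heads, S, head_dim)
cos = torch.randn(B, S, head_dim)
sin = torch.randn(B, S, head_dim)

# What QuantLlamaAttention computes (minus the observers): q*cos + rotate_half(q)*sin
manual = q * cos.unsqueeze(1) + rotate_half(q) * sin.unsqueeze(1)
ref, _ = apply_rotary_pos_emb(q, q, cos, sin)
assert torch.allclose(manual, ref)

The per-step observers (q_x1, q_x2, q_neg, q_cat, ...) exist so that each intermediate tensor of this rotation gets its own quantization parameters when the block is exported.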
tico/experimental/quantization/ptq/wrappers/llama/quant_mlp.py
ADDED
@@ -0,0 +1,98 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.nn as nn

from tico.experimental.quantization.ptq.quant_config import QuantConfig
from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
    QuantModuleBase,
)
from tico.experimental.quantization.ptq.wrappers.registry import try_register


@try_register("transformers.models.llama.modeling_llama.LlamaMLP")
class QuantLlamaMLP(QuantModuleBase):
    def __init__(
        self,
        mlp_fp: nn.Module,
        *,
        qcfg: Optional[QuantConfig] = None,
        fp_name: Optional[str] = None,
    ):
        super().__init__(qcfg, fp_name=fp_name)

        # ----- child configs (hierarchical override) -------------------
        gate_cfg = qcfg.child("gate_proj") if qcfg else None
        up_cfg = qcfg.child("up_proj") if qcfg else None
        down_cfg = qcfg.child("down_proj") if qcfg else None
        act_cfg = qcfg.child("act_fn") if qcfg else None

        # ----- wrap three Linear layers -------------------------------
        assert hasattr(mlp_fp, "gate_proj") and isinstance(
            mlp_fp.gate_proj, torch.nn.Module
        )
        assert hasattr(mlp_fp, "up_proj") and isinstance(
            mlp_fp.up_proj, torch.nn.Module
        )
        assert hasattr(mlp_fp, "down_proj") and isinstance(
            mlp_fp.down_proj, torch.nn.Module
        )
        self.gate_proj = PTQWrapper(
            mlp_fp.gate_proj, qcfg=gate_cfg, fp_name=f"{fp_name}.gate_proj"
        )
        self.up_proj = PTQWrapper(
            mlp_fp.up_proj, qcfg=up_cfg, fp_name=f"{fp_name}.up_proj"
        )
        self.down_proj = PTQWrapper(
            mlp_fp.down_proj, qcfg=down_cfg, fp_name=f"{fp_name}.down_proj"
        )

        # ----- activation ---------------------------------------------
        assert hasattr(mlp_fp, "act_fn") and isinstance(mlp_fp.act_fn, torch.nn.Module)
        self.act_fn = PTQWrapper(
            mlp_fp.act_fn, qcfg=act_cfg, fp_name=f"{fp_name}.act_fn"
        )

        # ----- local observers ----------------------------------------
        self.act_in_obs = self._make_obs("act_in")
        self.mul_obs = self._make_obs("mul")

    def forward(self, x: torch.Tensor):
        # 1) quantize input once
        x_q = self._fq(x, self.act_in_obs)

        # 2) parallel projections
        g = self.gate_proj(x_q)
        u = self.up_proj(x_q)

        # 3) activation on gate
        a = self.act_fn(g)

        # 4) element-wise product
        h = self._fq(a * u, self.mul_obs)

        # 5) final projection
        return self.down_proj(h)

    def _all_observers(self):
        # local first
        yield self.act_in_obs
        yield self.mul_obs
        # recurse into children that are QuantModuleBase
        for m in (self.gate_proj, self.up_proj, self.down_proj, self.act_fn):
            yield from m._all_observers()
tico/experimental/quantization/ptq/wrappers/registry.py
CHANGED
@@ -28,6 +28,8 @@ _CORE_MODULES = (
     "tico.experimental.quantization.ptq.wrappers.nn.quant_layernorm",
     "tico.experimental.quantization.ptq.wrappers.nn.quant_linear",
     "tico.experimental.quantization.ptq.wrappers.nn.quant_silu",
+    "tico.experimental.quantization.ptq.wrappers.llama.quant_attn",
+    "tico.experimental.quantization.ptq.wrappers.llama.quant_mlp",
     # add future core wrappers here
 )
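With these two modules added to _CORE_MODULES, the try_register decorators in the new files should let PTQWrapper resolve the Hugging Face Llama classes without importing the wrappers by name, which is what the examples hint at with "PTQWrapper(orig_attn) is also fine". A minimal sketch of that path, assuming the registry dispatch behaves as those comments and decorators suggest:

from transformers import AutoModelForCausalLM
from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper

model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0")
layer0 = model.model.layers[0]

# Expected to resolve to QuantLlamaAttention / QuantLlamaMLP via the registry.
layer0.self_attn = PTQWrapper(layer0.self_attn)
layer0.mlp = PTQWrapper(layer0.mlp)
model.eval()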
{tico-0.1.0.dev250826.dist-info → tico-0.1.0.dev250828.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-tico/__init__.py,sha256=
+tico/__init__.py,sha256=Ur6T0ZsgBvl70ek5stJ0_uxhK9xNmHado-clOup1OEM,1883
 tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
 tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
 tico/config/base.py,sha256=q5xMqGxTUZs4mFqt5c7i_y9U00fYgdMGl9nUqIVMlCo,1248
@@ -63,6 +63,8 @@ tico/experimental/quantization/ptq/qscheme.py,sha256=uwhv7bCxOOXB3I-IKlRyr_u4eXO
 tico/experimental/quantization/ptq/quant_config.py,sha256=nm7570Y1X2mOT_8s27ilWid04otor6cVTi9GwgAEaKc,4300
 tico/experimental/quantization/ptq/examples/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/ptq/examples/quantize_linear.py,sha256=8zq-ZJDYgam0xQ-PbC6Xb1I7W1mv0Wi-b--IP2wwXtw,4539
+tico/experimental/quantization/ptq/examples/quantize_llama_attn.py,sha256=cVWUSSzaZWFp5QZkNkrlpHU3kXyP84QtnZbahVml_yQ,4329
+tico/experimental/quantization/ptq/examples/quantize_llama_mlp.py,sha256=N1qZQgt1S-xZrdv-PW7OfXEcv0gsO2q9faOF4aD-zKo,4147
 tico/experimental/quantization/ptq/observers/__init__.py,sha256=WF2MvL9M_jl-B1FqcY9zic34NOCRp17HkRYv-TMxMr4,613
 tico/experimental/quantization/ptq/observers/affine_base.py,sha256=e2Eba64nrxKQyE4F_WJ7WTSsk3xe6bkdGUKaoLFWGFw,4638
 tico/experimental/quantization/ptq/observers/base.py,sha256=Wons1MzpqK1mfcy-ppl-B2Dum0edXg2dWW2Lw3V18tw,3280
@@ -76,7 +78,10 @@ tico/experimental/quantization/ptq/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TIm
 tico/experimental/quantization/ptq/wrappers/ptq_wrapper.py,sha256=F9sK_DiRaXiGNHULcwIbs5EUtHz6ZJ7N4r5CWTTfhsM,2442
 tico/experimental/quantization/ptq/wrappers/quant_elementwise.py,sha256=LhEoobfvto6zKrBOKL4gmxfFFc31jHzyQV_zfps-iQM,3604
 tico/experimental/quantization/ptq/wrappers/quant_module_base.py,sha256=vkcDos_knGSS29rIZuEIWkAJLHrENbGz8nCH2-iara8,5969
-tico/experimental/quantization/ptq/wrappers/registry.py,sha256=
+tico/experimental/quantization/ptq/wrappers/registry.py,sha256=TwH-MD-qkTkG6M-f1VqFLmSNcXLNYsh21yjyzCcojJc,4706
+tico/experimental/quantization/ptq/wrappers/llama/__init__.py,sha256=b360gkQ0RxExmiV-ZaaxwJdMPJ53g6uCRlR2-_dOby0,240
+tico/experimental/quantization/ptq/wrappers/llama/quant_attn.py,sha256=WIUI6EFMTvvruvqu8pBxWy6qJeDyjkaYbJk1R3pAmwE,8578
+tico/experimental/quantization/ptq/wrappers/llama/quant_mlp.py,sha256=uZMnrX66oZwxhKhcNbLXXeri-WxxRBiZnr15aBXJMm0,3562
 tico/experimental/quantization/ptq/wrappers/nn/__init__.py,sha256=I9uTt5HfcRoMEDYHpAeATMv2TbCQiX0ZbfUFMzSJ4Qw,336
 tico/experimental/quantization/ptq/wrappers/nn/quant_layernorm.py,sha256=G5Sgt-tXnzh0Rxyk-2honmZIfEQOZlRfOsoDBdSGmA4,6887
 tico/experimental/quantization/ptq/wrappers/nn/quant_linear.py,sha256=xW-VEPB7RJoslS3xLVCdhIuMjppknvpkZleRGK4JFVQ,2240
@@ -235,9 +240,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
 tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
 tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
+tico-0.1.0.dev250828.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
+tico-0.1.0.dev250828.dist-info/METADATA,sha256=nYk1Pl6H1eZbVSUDMUYLDjb8-AsRW9e9v1EAnIgLM4k,8450
+tico-0.1.0.dev250828.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+tico-0.1.0.dev250828.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
+tico-0.1.0.dev250828.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
+tico-0.1.0.dev250828.dist-info/RECORD,,

{tico-0.1.0.dev250826.dist-info → tico-0.1.0.dev250828.dist-info}/LICENSE
File without changes

{tico-0.1.0.dev250826.dist-info → tico-0.1.0.dev250828.dist-info}/WHEEL
File without changes

{tico-0.1.0.dev250826.dist-info → tico-0.1.0.dev250828.dist-info}/entry_points.txt
File without changes

{tico-0.1.0.dev250826.dist-info → tico-0.1.0.dev250828.dist-info}/top_level.txt
File without changes