tico 0.1.0.dev250826__py3-none-any.whl → 0.1.0.dev250828__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tico/__init__.py CHANGED
@@ -29,7 +29,7 @@ __all__ = [
29
29
  ]
30
30
 
31
31
  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
32
- __version__ = "0.1.0.dev250826"
32
+ __version__ = "0.1.0.dev250828"
33
33
 
34
34
  MINIMUM_SUPPORTED_VERSION = "2.5.0"
35
35
  SECURE_TORCH_VERSION = "2.6.0"
@@ -0,0 +1,103 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import pathlib
16
+
17
+ import torch
18
+ from transformers import AutoModelForCausalLM, AutoTokenizer
19
+
20
+ from tico.experimental.quantization.evaluation.metric import compute_peir
21
+ from tico.experimental.quantization.evaluation.utils import plot_two_outputs
22
+
23
+ from tico.experimental.quantization.ptq.mode import Mode
24
+ from tico.experimental.quantization.ptq.wrappers.llama.quant_attn import (
25
+ QuantLlamaAttention,
26
+ )
27
+ from tico.utils.utils import SuppressWarning
28
+
29
+ name = "Maykeye/TinyLLama-v0"
30
+ model = AutoModelForCausalLM.from_pretrained(name)
31
+ tokenizer = AutoTokenizer.from_pretrained(name)
32
+
33
+ # -------------------------------------------------------------------------
34
+ # 1. Replace layer-0’s self-attention with QuantLlamaAttention
35
+ # -------------------------------------------------------------------------
36
+ orig_attn = model.model.layers[0].self_attn
37
+ model.model.layers[0].self_attn = QuantLlamaAttention(
38
+ orig_attn
39
+ ) # PTQWrapper(orig_attn) is also fine
40
+ model.eval()
41
+
42
+ attn_q = model.model.layers[0].self_attn # quant wrapper
43
+ rotary = model.model.rotary_emb
44
+
45
+ # -------------------------------------------------------------------------
46
+ # 2. Single-pass calibration
47
+ # -------------------------------------------------------------------------
48
+ PROMPTS = [
49
+ "The quick brown fox jumps over the lazy dog.",
50
+ "In 2025, AI systems accelerated hardware-software co-design at scale.",
51
+ "양자화는 왜 어려울까? 분포, 길이, 마스크가 관건이다.",
52
+ "今日はいい天気ですね。ところでRoPE角度は長さに依存します。",
53
+ "def quicksort(arr):\n if len(arr) <= 1: return arr\n ...",
54
+ "Prices rose 3.14% — see Figure 2; emails: foo@bar.com!",
55
+ ]
56
+
57
+ with torch.no_grad():
58
+ attn_q.enable_calibration()
59
+ for prompt in PROMPTS:
60
+ ids = tokenizer(prompt, return_tensors="pt")
61
+ embeds = model.model.embed_tokens(ids["input_ids"])
62
+ cos_sin = rotary(embeds, ids["input_ids"])
63
+ S = cos_sin[0].shape[1]
64
+ float_mask = torch.zeros(1, 1, S, S)
65
+ _ = attn_q(embeds, cos_sin) # observers collect
66
+ attn_q.freeze_qparams()
67
+
68
+ assert attn_q._mode is Mode.QUANT, "Quantization mode should be active now."
69
+
70
+ # -------------------------------------------------------------------------
71
+ # 3. Quick diff check (INT-sim vs FP32)
72
+ # -------------------------------------------------------------------------
73
+ ids = tokenizer("check", return_tensors="pt")
74
+ emb = model.model.embed_tokens(ids["input_ids"])
75
+ pos = rotary(emb, ids["input_ids"])
76
+ S = pos[0].shape[1]
77
+ float_mask = torch.zeros(1, 1, S, S)
78
+ with torch.no_grad():
79
+ int8 = attn_q(emb, pos)[0]
80
+ fp32 = orig_attn(emb, position_embeddings=pos, attention_mask=None)[0]
81
+
82
+ print("┌───────────── Quantization Error Summary ─────────────")
83
+ print(f"│ Mean |diff|: {(int8 - fp32).abs().mean().item():.6f}")
84
+ print(f"│ PEIR : {compute_peir(fp32, int8) * 100:.6f} %")
85
+ print("└──────────────────────────────────────────────────────")
86
+ print(plot_two_outputs(fp32, int8))
87
+
88
+ # -------------------------------------------------------------------------
89
+ # 4. Export the quantized block
90
+ # -------------------------------------------------------------------------
91
+ import tico
92
+
93
+ save_path = pathlib.Path("attn.q.circle")
94
+ B, S, D = 1, 4, model.config.hidden_size
95
+ example = torch.randn(B, S, D)
96
+ example_pos = rotary(example, torch.arange(S)[None, :])
97
+ float_mask = torch.zeros(1, 1, S, S)
98
+
99
+ with SuppressWarning(UserWarning, ".*"):
100
+ cm = tico.convert(attn_q, (example, example_pos))
101
+ cm.save(save_path)
102
+
103
+ print(f"Quantized Circle model saved to {save_path.resolve()}")
@@ -0,0 +1,95 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import pathlib
16
+
17
+ import torch
18
+ from transformers import AutoModelForCausalLM, AutoTokenizer
19
+
20
+ import tico
21
+ from tico.experimental.quantization.evaluation.metric import compute_peir
22
+ from tico.experimental.quantization.evaluation.utils import plot_two_outputs
23
+ from tico.experimental.quantization.ptq.dtypes import INT16
24
+ from tico.experimental.quantization.ptq.mode import Mode
25
+ from tico.experimental.quantization.ptq.qscheme import QScheme
26
+ from tico.experimental.quantization.ptq.quant_config import QuantConfig
27
+ from tico.experimental.quantization.ptq.wrappers.llama.quant_mlp import QuantLlamaMLP
28
+ from tico.utils.utils import SuppressWarning
29
+
30
+ name = "Maykeye/TinyLLama-v0"
31
+ model = AutoModelForCausalLM.from_pretrained(name)
32
+ tokenizer = AutoTokenizer.from_pretrained(name)
33
+ model.eval()
34
+
35
+ # -------------------------------------------------------------------------
36
+ # 1. Replace layer-0’s MLP with QuantLlamaMLP
37
+ # -------------------------------------------------------------------------
38
+ fp32_mlp = model.model.layers[0].mlp
39
+ model.model.layers[0].mlp = QuantLlamaMLP(
40
+ fp32_mlp,
41
+ qcfg=QuantConfig(default_dtype=INT16, default_qscheme=QScheme.PER_TENSOR_SYMM),
42
+ ) # PTQWrapper(fp32_mlp) is also fine
43
+ model.eval()
44
+
45
+ mlp_q = model.model.layers[0].mlp
46
+
47
+ # -------------------------------------------------------------------------
48
+ # 2. Single-pass calibration
49
+ # -------------------------------------------------------------------------
50
+ PROMPTS = [
51
+ "The quick brown fox jumps over the lazy dog.",
52
+ "In 2025, AI systems accelerated hardware-software co-design at scale.",
53
+ "양자화는 왜 어려울까? 분포, 길이, 마스크가 관건이다.",
54
+ "今日はいい天気ですね。ところでRoPE角度は長さに依存します。",
55
+ "def quicksort(arr):\n if len(arr) <= 1: return arr\n ...",
56
+ "Prices rose 3.14% — see Figure 2; emails: foo@bar.com!",
57
+ ]
58
+
59
+ with torch.no_grad():
60
+ mlp_q.enable_calibration()
61
+ for prompt in PROMPTS:
62
+ enc = tokenizer(prompt, return_tensors="pt")
63
+ emb = model.model.embed_tokens(enc["input_ids"])
64
+ _ = mlp_q(emb)
65
+
66
+ mlp_q.freeze_qparams()
67
+
68
+ assert mlp_q._mode is Mode.QUANT, "Quantization mode should be active now."
69
+
70
+ # -------------------------------------------------------------------------
71
+ # 3. Quick diff check (INT-sim vs FP32)
72
+ # -------------------------------------------------------------------------
73
+ with torch.no_grad():
74
+ ids = tokenizer("quant all tensors!", return_tensors="pt")
75
+ emb = model.model.embed_tokens(ids["input_ids"])
76
+ int16 = mlp_q(emb) # INT-sim
77
+ fp32 = fp32_mlp(emb) # baseline reference
78
+
79
+ print("┌───────────── Quantization Error Summary ─────────────")
80
+ print(f"│ Mean |diff|: {(int16 - fp32).abs().mean().item():.6f}")
81
+ print(f"│ PEIR : {compute_peir(fp32, int16) * 100:.6f} %")
82
+ print("└──────────────────────────────────────────────────────")
83
+ print(plot_two_outputs(fp32, int16))
84
+
85
+ # -------------------------------------------------------------------------
86
+ # 4. Export the quantized block
87
+ # -------------------------------------------------------------------------
88
+ save_path = pathlib.Path("mlp.q.circle")
89
+ example_in = (torch.randn(1, 1, model.config.hidden_size),)
90
+
91
+ with SuppressWarning(UserWarning, ".*"):
92
+ cm = tico.convert(mlp_q, example_in)
93
+ cm.save(save_path)
94
+
95
+ print(f"Quantized Circle model saved to {save_path.resolve()}")
@@ -0,0 +1,6 @@
1
+ from tico.experimental.quantization.ptq.wrappers.llama.quant_attn import (
2
+ QuantLlamaAttention,
3
+ )
4
+ from tico.experimental.quantization.ptq.wrappers.llama.quant_mlp import QuantLlamaMLP
5
+
6
+ __all__ = ["QuantLlamaAttention", "QuantLlamaMLP"]
@@ -0,0 +1,236 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from tico.experimental.quantization.ptq.quant_config import QuantConfig
21
+ from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
22
+ from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
23
+ QuantModuleBase,
24
+ )
25
+ from tico.experimental.quantization.ptq.wrappers.registry import try_register
26
+
27
+
28
+ @try_register("transformers.models.llama.modeling_llama.LlamaAttention")
29
+ class QuantLlamaAttention(QuantModuleBase):
30
+ def __init__(
31
+ self,
32
+ fp_attn: nn.Module,
33
+ *,
34
+ qcfg: Optional[QuantConfig] = None,
35
+ fp_name: Optional[str] = None,
36
+ ):
37
+ super().__init__(qcfg, fp_name=fp_name)
38
+
39
+ cfg = fp_attn.config
40
+ assert hasattr(cfg, "hidden_size") and hasattr(cfg, "num_attention_heads")
41
+ assert hasattr(cfg, "num_key_value_heads")
42
+ assert isinstance(cfg.hidden_size, int) and isinstance(
43
+ cfg.num_attention_heads, int
44
+ )
45
+ assert isinstance(cfg.num_key_value_heads, int)
46
+ self.hdim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)
47
+ self.kv_rep = cfg.num_attention_heads // cfg.num_key_value_heads
48
+
49
+ # constant scale (1/√d)
50
+ self.scale_t = torch.tensor(self.hdim**-0.5)
51
+ self.obs_scale = self._make_obs("scale")
52
+
53
+ # ---- wrap q k v o projections via PTQWrapper ---------------
54
+ q_cfg = qcfg.child("q_proj") if qcfg else None
55
+ k_cfg = qcfg.child("k_proj") if qcfg else None
56
+ v_cfg = qcfg.child("v_proj") if qcfg else None
57
+ o_cfg = qcfg.child("o_proj") if qcfg else None
58
+ assert hasattr(fp_attn, "q_proj") and isinstance(
59
+ fp_attn.q_proj, torch.nn.Module
60
+ )
61
+ assert hasattr(fp_attn, "k_proj") and isinstance(
62
+ fp_attn.k_proj, torch.nn.Module
63
+ )
64
+ assert hasattr(fp_attn, "v_proj") and isinstance(
65
+ fp_attn.v_proj, torch.nn.Module
66
+ )
67
+ assert hasattr(fp_attn, "o_proj") and isinstance(
68
+ fp_attn.o_proj, torch.nn.Module
69
+ )
70
+ self.q_proj = PTQWrapper(
71
+ fp_attn.q_proj, qcfg=q_cfg, fp_name=f"{fp_name}.q_proj"
72
+ )
73
+ self.k_proj = PTQWrapper(
74
+ fp_attn.k_proj, qcfg=k_cfg, fp_name=f"{fp_name}.k_proj"
75
+ )
76
+ self.v_proj = PTQWrapper(
77
+ fp_attn.v_proj, qcfg=v_cfg, fp_name=f"{fp_name}.v_proj"
78
+ )
79
+ self.o_proj = PTQWrapper(
80
+ fp_attn.o_proj, qcfg=o_cfg, fp_name=f"{fp_name}.o_proj"
81
+ )
82
+
83
+ # ---- create arithmetic observers ---------------------------
84
+ mk = self._make_obs
85
+ self.obs_hidden = mk("hidden")
86
+
87
+ self.obs_cos = mk("cos")
88
+ self.obs_sin = mk("sin")
89
+
90
+ self.obs_causal_mask = mk("causal_mask")
91
+
92
+ # rotate-half sub-steps
93
+ self.obs_q_x1 = mk("q_x1")
94
+ self.obs_q_x2 = mk("q_x2")
95
+ self.obs_q_neg = mk("q_neg")
96
+ self.obs_q_cat = mk("q_cat")
97
+ self.obs_k_x1 = mk("k_x1")
98
+ self.obs_k_x2 = mk("k_x2")
99
+ self.obs_k_neg = mk("k_neg")
100
+ self.obs_k_cat = mk("k_cat")
101
+
102
+ # q / k paths
103
+ self.obs_q_cos = mk("q_cos")
104
+ self.obs_q_sin = mk("q_sin")
105
+ self.obs_q_rot = mk("q_rot")
106
+ self.obs_k_cos = mk("k_cos")
107
+ self.obs_k_sin = mk("k_sin")
108
+ self.obs_k_rot = mk("k_rot")
109
+
110
+ # logits / softmax / out
111
+ self.obs_logits_raw = mk("logits_raw")
112
+ self.obs_logits = mk("logits")
113
+ self.obs_mask_add = mk("mask_add")
114
+ self.obs_softmax = mk("softmax")
115
+ self.obs_attn_out = mk("attn_out")
116
+
117
+ # Static causal mask template
118
+ assert hasattr(cfg, "max_position_embeddings")
119
+ max_seq = cfg.max_position_embeddings
120
+ mask = torch.full((1, 1, max_seq, max_seq), float("-120")) # type: ignore[arg-type]
121
+ mask.triu_(1)
122
+ self.register_buffer("causal_mask_template", mask, persistent=False)
123
+
124
+ def _rot(self, t, o_x1, o_x2, o_neg, o_cat):
125
+ x1, x2 = torch.chunk(t, 2, dim=-1)
126
+ x1 = self._fq(x1, o_x1)
127
+ x2 = self._fq(x2, o_x2)
128
+ x2n = self._fq(-x2, o_neg)
129
+ return self._fq(torch.cat((x2n, x1), -1), o_cat)
130
+
131
+ def forward(
132
+ self,
133
+ hidden_states: torch.Tensor,
134
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
135
+ attention_mask: Optional[torch.Tensor] = None,
136
+ past_key_value=None, # not supported yet
137
+ cache_position: Optional[torch.LongTensor] = None,
138
+ **kwargs,
139
+ ):
140
+ if past_key_value is not None:
141
+ raise NotImplementedError(
142
+ "QuantLlamaAttention does not support KV cache yet."
143
+ )
144
+
145
+ hidden = self._fq(hidden_states, self.obs_hidden)
146
+ B, S, _ = hidden.shape
147
+ H = self.hdim
148
+
149
+ # projections
150
+ q = self.q_proj(hidden).view(B, S, -1, H).transpose(1, 2)
151
+ k = self.k_proj(hidden).view(B, S, -1, H).transpose(1, 2)
152
+ v = self.v_proj(hidden).view(B, S, -1, H).transpose(1, 2)
153
+
154
+ # rope tables
155
+ cos, sin = position_embeddings
156
+ cos = self._fq(cos, self.obs_cos)
157
+ sin = self._fq(sin, self.obs_sin)
158
+ cos_u, sin_u = cos.unsqueeze(1), sin.unsqueeze(1)
159
+
160
+ # q_rot
161
+ q_half = self._rot(
162
+ q, self.obs_q_x1, self.obs_q_x2, self.obs_q_neg, self.obs_q_cat
163
+ )
164
+ q_cos = self._fq(q * cos_u, self.obs_q_cos)
165
+ q_sin = self._fq(q_half * sin_u, self.obs_q_sin)
166
+ q_rot = self._fq(q_cos + q_sin, self.obs_q_rot)
167
+
168
+ # k_rot
169
+ k_half = self._rot(
170
+ k, self.obs_k_x1, self.obs_k_x2, self.obs_k_neg, self.obs_k_cat
171
+ )
172
+ k_cos = self._fq(k * cos_u, self.obs_k_cos)
173
+ k_sin = self._fq(k_half * sin_u, self.obs_k_sin)
174
+ k_rot = self._fq(k_cos + k_sin, self.obs_k_rot)
175
+
176
+ # logits
177
+ k_rep = k_rot.repeat_interleave(self.kv_rep, dim=1)
178
+ logits_raw = self._fq(q_rot @ k_rep.transpose(-2, -1), self.obs_logits_raw)
179
+ scale = self._fq(self.scale_t, self.obs_scale)
180
+ logits = self._fq(logits_raw * scale, self.obs_logits)
181
+
182
+ if attention_mask is None or attention_mask.dtype == torch.bool:
183
+ _, _, q_len, k_len = logits.shape
184
+ assert isinstance(self.causal_mask_template, torch.Tensor)
185
+ attention_mask = self.causal_mask_template[..., :q_len, :k_len].to(
186
+ hidden_states.device
187
+ )
188
+ attention_mask = self._fq(attention_mask, self.obs_causal_mask)
189
+ logits = self._fq(logits + attention_mask, self.obs_mask_add)
190
+
191
+ # softmax
192
+ attn_weights = torch.softmax(logits, -1, dtype=torch.float32).to(q.dtype)
193
+ attn_weights = self._fq(attn_weights, self.obs_softmax)
194
+
195
+ # attn out
196
+ v_rep = v.repeat_interleave(self.kv_rep, dim=1)
197
+ attn_out = (
198
+ self._fq(attn_weights @ v_rep, self.obs_attn_out)
199
+ .transpose(1, 2)
200
+ .reshape(B, S, -1)
201
+ )
202
+
203
+ # final projection
204
+ return self.o_proj(attn_out), attn_weights
205
+
206
+ def _all_observers(self):
207
+ # local first
208
+ yield from (
209
+ self.obs_hidden,
210
+ self.obs_scale,
211
+ self.obs_cos,
212
+ self.obs_sin,
213
+ self.obs_causal_mask,
214
+ self.obs_q_x1,
215
+ self.obs_q_x2,
216
+ self.obs_q_neg,
217
+ self.obs_q_cat,
218
+ self.obs_k_x1,
219
+ self.obs_k_x2,
220
+ self.obs_k_neg,
221
+ self.obs_k_cat,
222
+ self.obs_q_cos,
223
+ self.obs_q_sin,
224
+ self.obs_q_rot,
225
+ self.obs_k_cos,
226
+ self.obs_k_sin,
227
+ self.obs_k_rot,
228
+ self.obs_logits_raw,
229
+ self.obs_logits,
230
+ self.obs_mask_add,
231
+ self.obs_softmax,
232
+ self.obs_attn_out,
233
+ )
234
+ # recurse into children that are QuantModuleBase
235
+ for m in (self.q_proj, self.k_proj, self.v_proj, self.o_proj):
236
+ yield from m._all_observers()
@@ -0,0 +1,98 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from tico.experimental.quantization.ptq.quant_config import QuantConfig
21
+ from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
22
+ from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
23
+ QuantModuleBase,
24
+ )
25
+ from tico.experimental.quantization.ptq.wrappers.registry import try_register
26
+
27
+
28
+ @try_register("transformers.models.llama.modeling_llama.LlamaMLP")
29
+ class QuantLlamaMLP(QuantModuleBase):
30
+ def __init__(
31
+ self,
32
+ mlp_fp: nn.Module,
33
+ *,
34
+ qcfg: Optional[QuantConfig] = None,
35
+ fp_name: Optional[str] = None,
36
+ ):
37
+ super().__init__(qcfg, fp_name=fp_name)
38
+
39
+ # ----- child configs (hierarchical override) -------------------
40
+ gate_cfg = qcfg.child("gate_proj") if qcfg else None
41
+ up_cfg = qcfg.child("up_proj") if qcfg else None
42
+ down_cfg = qcfg.child("down_proj") if qcfg else None
43
+ act_cfg = qcfg.child("act_fn") if qcfg else None
44
+
45
+ # ----- wrap three Linear layers -------------------------------
46
+ assert hasattr(mlp_fp, "gate_proj") and isinstance(
47
+ mlp_fp.gate_proj, torch.nn.Module
48
+ )
49
+ assert hasattr(mlp_fp, "up_proj") and isinstance(
50
+ mlp_fp.up_proj, torch.nn.Module
51
+ )
52
+ assert hasattr(mlp_fp, "down_proj") and isinstance(
53
+ mlp_fp.down_proj, torch.nn.Module
54
+ )
55
+ self.gate_proj = PTQWrapper(
56
+ mlp_fp.gate_proj, qcfg=gate_cfg, fp_name=f"{fp_name}.gate_proj"
57
+ )
58
+ self.up_proj = PTQWrapper(
59
+ mlp_fp.up_proj, qcfg=up_cfg, fp_name=f"{fp_name}.up_proj"
60
+ )
61
+ self.down_proj = PTQWrapper(
62
+ mlp_fp.down_proj, qcfg=down_cfg, fp_name=f"{fp_name}.down_proj"
63
+ )
64
+
65
+ # ----- activation ---------------------------------------------
66
+ assert hasattr(mlp_fp, "act_fn") and isinstance(mlp_fp.act_fn, torch.nn.Module)
67
+ self.act_fn = PTQWrapper(
68
+ mlp_fp.act_fn, qcfg=act_cfg, fp_name=f"{fp_name}.act_fn"
69
+ )
70
+
71
+ # ----- local observers ----------------------------------------
72
+ self.act_in_obs = self._make_obs("act_in")
73
+ self.mul_obs = self._make_obs("mul")
74
+
75
+ def forward(self, x: torch.Tensor):
76
+ # 1) quantize input once
77
+ x_q = self._fq(x, self.act_in_obs)
78
+
79
+ # 2) parallel projections
80
+ g = self.gate_proj(x_q)
81
+ u = self.up_proj(x_q)
82
+
83
+ # 3) activation on gate
84
+ a = self.act_fn(g)
85
+
86
+ # 4) element-wise product
87
+ h = self._fq(a * u, self.mul_obs)
88
+
89
+ # 5) final projection
90
+ return self.down_proj(h)
91
+
92
+ def _all_observers(self):
93
+ # local first
94
+ yield self.act_in_obs
95
+ yield self.mul_obs
96
+ # recurse into children that are QuantModuleBase
97
+ for m in (self.gate_proj, self.up_proj, self.down_proj, self.act_fn):
98
+ yield from m._all_observers()
@@ -28,6 +28,8 @@ _CORE_MODULES = (
28
28
  "tico.experimental.quantization.ptq.wrappers.nn.quant_layernorm",
29
29
  "tico.experimental.quantization.ptq.wrappers.nn.quant_linear",
30
30
  "tico.experimental.quantization.ptq.wrappers.nn.quant_silu",
31
+ "tico.experimental.quantization.ptq.wrappers.llama.quant_attn",
32
+ "tico.experimental.quantization.ptq.wrappers.llama.quant_mlp",
31
33
  # add future core wrappers here
32
34
  )
33
35
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tico
3
- Version: 0.1.0.dev250826
3
+ Version: 0.1.0.dev250828
4
4
  Summary: Convert exported Torch module to circle
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- tico/__init__.py,sha256=M4dQ4CTD_7xsO5DjUx76t4A5o1Q_2NqGlMe0fjkGDxQ,1883
1
+ tico/__init__.py,sha256=Ur6T0ZsgBvl70ek5stJ0_uxhK9xNmHado-clOup1OEM,1883
2
2
  tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
3
3
  tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
4
4
  tico/config/base.py,sha256=q5xMqGxTUZs4mFqt5c7i_y9U00fYgdMGl9nUqIVMlCo,1248
@@ -63,6 +63,8 @@ tico/experimental/quantization/ptq/qscheme.py,sha256=uwhv7bCxOOXB3I-IKlRyr_u4eXO
63
63
  tico/experimental/quantization/ptq/quant_config.py,sha256=nm7570Y1X2mOT_8s27ilWid04otor6cVTi9GwgAEaKc,4300
64
64
  tico/experimental/quantization/ptq/examples/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
65
65
  tico/experimental/quantization/ptq/examples/quantize_linear.py,sha256=8zq-ZJDYgam0xQ-PbC6Xb1I7W1mv0Wi-b--IP2wwXtw,4539
66
+ tico/experimental/quantization/ptq/examples/quantize_llama_attn.py,sha256=cVWUSSzaZWFp5QZkNkrlpHU3kXyP84QtnZbahVml_yQ,4329
67
+ tico/experimental/quantization/ptq/examples/quantize_llama_mlp.py,sha256=N1qZQgt1S-xZrdv-PW7OfXEcv0gsO2q9faOF4aD-zKo,4147
66
68
  tico/experimental/quantization/ptq/observers/__init__.py,sha256=WF2MvL9M_jl-B1FqcY9zic34NOCRp17HkRYv-TMxMr4,613
67
69
  tico/experimental/quantization/ptq/observers/affine_base.py,sha256=e2Eba64nrxKQyE4F_WJ7WTSsk3xe6bkdGUKaoLFWGFw,4638
68
70
  tico/experimental/quantization/ptq/observers/base.py,sha256=Wons1MzpqK1mfcy-ppl-B2Dum0edXg2dWW2Lw3V18tw,3280
@@ -76,7 +78,10 @@ tico/experimental/quantization/ptq/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TIm
76
78
  tico/experimental/quantization/ptq/wrappers/ptq_wrapper.py,sha256=F9sK_DiRaXiGNHULcwIbs5EUtHz6ZJ7N4r5CWTTfhsM,2442
77
79
  tico/experimental/quantization/ptq/wrappers/quant_elementwise.py,sha256=LhEoobfvto6zKrBOKL4gmxfFFc31jHzyQV_zfps-iQM,3604
78
80
  tico/experimental/quantization/ptq/wrappers/quant_module_base.py,sha256=vkcDos_knGSS29rIZuEIWkAJLHrENbGz8nCH2-iara8,5969
79
- tico/experimental/quantization/ptq/wrappers/registry.py,sha256=562nKSlp9qF-w4-aQeJbx2V_wMGE2FRrjIKUfRwC4Mg,4571
81
+ tico/experimental/quantization/ptq/wrappers/registry.py,sha256=TwH-MD-qkTkG6M-f1VqFLmSNcXLNYsh21yjyzCcojJc,4706
82
+ tico/experimental/quantization/ptq/wrappers/llama/__init__.py,sha256=b360gkQ0RxExmiV-ZaaxwJdMPJ53g6uCRlR2-_dOby0,240
83
+ tico/experimental/quantization/ptq/wrappers/llama/quant_attn.py,sha256=WIUI6EFMTvvruvqu8pBxWy6qJeDyjkaYbJk1R3pAmwE,8578
84
+ tico/experimental/quantization/ptq/wrappers/llama/quant_mlp.py,sha256=uZMnrX66oZwxhKhcNbLXXeri-WxxRBiZnr15aBXJMm0,3562
80
85
  tico/experimental/quantization/ptq/wrappers/nn/__init__.py,sha256=I9uTt5HfcRoMEDYHpAeATMv2TbCQiX0ZbfUFMzSJ4Qw,336
81
86
  tico/experimental/quantization/ptq/wrappers/nn/quant_layernorm.py,sha256=G5Sgt-tXnzh0Rxyk-2honmZIfEQOZlRfOsoDBdSGmA4,6887
82
87
  tico/experimental/quantization/ptq/wrappers/nn/quant_linear.py,sha256=xW-VEPB7RJoslS3xLVCdhIuMjppknvpkZleRGK4JFVQ,2240
@@ -235,9 +240,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
235
240
  tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
236
241
  tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
237
242
  tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
238
- tico-0.1.0.dev250826.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
239
- tico-0.1.0.dev250826.dist-info/METADATA,sha256=QhtUiHj_YT4ZxsClOx4OaP24kuaLUsLT83x3yl1gRDY,8450
240
- tico-0.1.0.dev250826.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
241
- tico-0.1.0.dev250826.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
242
- tico-0.1.0.dev250826.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
243
- tico-0.1.0.dev250826.dist-info/RECORD,,
243
+ tico-0.1.0.dev250828.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
244
+ tico-0.1.0.dev250828.dist-info/METADATA,sha256=nYk1Pl6H1eZbVSUDMUYLDjb8-AsRW9e9v1EAnIgLM4k,8450
245
+ tico-0.1.0.dev250828.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
246
+ tico-0.1.0.dev250828.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
247
+ tico-0.1.0.dev250828.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
248
+ tico-0.1.0.dev250828.dist-info/RECORD,,