tico 0.1.0.dev250911__py3-none-any.whl → 0.1.0.dev250914__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/experimental/quantization/ptq/examples/compare_ppl.py +199 -81
- tico/experimental/quantization/ptq/wrappers/fairseq/quant_encoder.py +333 -0
- tico/experimental/quantization/ptq/wrappers/llama/quant_attn.py +53 -14
- tico/experimental/quantization/ptq/wrappers/llama/quant_decoder_layer.py +14 -2
- tico/experimental/quantization/ptq/wrappers/registry.py +1 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -3
- tico/passes/remove_redundant_expand.py +3 -1
- {tico-0.1.0.dev250911.dist-info → tico-0.1.0.dev250914.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250911.dist-info → tico-0.1.0.dev250914.dist-info}/RECORD +14 -13
- {tico-0.1.0.dev250911.dist-info → tico-0.1.0.dev250914.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250911.dist-info → tico-0.1.0.dev250914.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250911.dist-info → tico-0.1.0.dev250914.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250911.dist-info → tico-0.1.0.dev250914.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED

tico/experimental/quantization/ptq/examples/compare_ppl.py
CHANGED

@@ -20,6 +20,10 @@
 # • Full post-training UINT-8 flow (wrap → calibrate → eval).
 # =============================================================================
 
+import argparse
+import sys
+from typing import Optional
+
 import torch
 import tqdm
 from datasets import load_dataset
@@ -29,14 +33,6 @@ from tico.experimental.quantization.ptq.quant_config import QuantConfig
 from tico.experimental.quantization.ptq.utils.metrics import perplexity
 from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
 
-# -------------------------------------------------------------------------
-# 0. Global configuration
-# -------------------------------------------------------------------------
-MODEL_NAME = "meta-llama/Meta-Llama-3-1B"
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-STRIDE = 512  # sliding-window stride for perplexity
-RUN_FP = True  # set False → run UINT-8 path
-
 # Token-budget presets for activation calibration
 TOKENS: dict[str, int] = {
     # Smoke test (<1 min turnaround on CPU/GPU)
@@ -46,76 +42,198 @@ TOKENS: dict[str, int] = {
     # Production / 4-bit observer smoothing
     "production": 200_000,
 }
-… (73 lines of the old hard-coded script removed; their content is not shown in this view)
+
+DTYPE_MAP = {
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+    "float16": torch.float16,
+}
+
+# Hardcoded dataset settings
+DATASET_NAME = "wikitext"
+DATASET_CONFIG = "wikitext-2-raw-v1"
+TRAIN_SPLIT = "train"
+TEST_SPLIT = "test"
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Quick PTQ example (FP or UINT8)")
+    parser.add_argument(
+        "--mode",
+        choices=["fp", "uint8"],
+        default="fp",
+        help="Choose FP baseline only or full UINT8 PTQ path.",
+    )
+    parser.add_argument(
+        "--model", type=str, required=True, help="HF repo name or local path."
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="Device to run on (cuda|cpu).",
+    )
+    parser.add_argument(
+        "--dtype",
+        choices=list(DTYPE_MAP.keys()),
+        default="float32",
+        help="Model dtype for load (float32|bfloat16|float16).",
+    )
+    parser.add_argument(
+        "--stride", type=int, default=512, help="Sliding-window stride for perplexity."
+    )
+    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
+    parser.add_argument(
+        "--trust-remote-code",
+        action="store_true",
+        help="Enable only if you trust the model repo code.",
+    )
+    parser.add_argument(
+        "--hf-token",
+        type=str,
+        default=None,
+        help="Optional HF token for gated/private models.",
+    )
+    parser.add_argument(
+        "--use-cache",
+        dest="use_cache",
+        action="store_true",
+        default=False,
+        help="Use model KV cache if enabled (off by default).",
+    )
+    parser.add_argument(
+        "--no-tqdm", action="store_true", help="Disable tqdm progress bars."
+    )
+    # 2) calib-preset default = debug
+    parser.add_argument(
+        "--calib-preset",
+        choices=list(TOKENS.keys()),
+        default="debug",
+        help="Calibration token budget preset.",
+    )
+
+    args = parser.parse_args()
+
+    # Basic setup
+    torch.manual_seed(args.seed)
+    device = torch.device(args.device)
+    dtype = DTYPE_MAP[args.dtype]
+
+    print("=== Config ===")
+    print(f"Mode : {args.mode}")
+    print(f"Model : {args.model}")
+    print(f"Device : {device.type}")
+    print(f"DType : {args.dtype}")
+    print(f"Stride : {args.stride}")
+    print(f"Use HF cache? : {args.use_cache}")
+    print(f"Calib preset : {args.calib_preset}")
+    print()
+
+    # -------------------------------------------------------------------------
+    # 1. Load model and tokenizer
+    # -------------------------------------------------------------------------
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model,
+        trust_remote_code=args.trust_remote_code,
+        token=args.hf_token,
+    )
+
+    model = (
+        AutoModelForCausalLM.from_pretrained(
+            args.model,
+            torch_dtype=dtype,
+            trust_remote_code=args.trust_remote_code,
+            token=args.hf_token,
+        )
+        .to(device)
+        .eval()
+    )
+
+    model.config.use_cache = args.use_cache
+
+    if args.mode == "fp":
+        fp_model = model
+    else:
+        # INT8 PTQ path
+        uint8_model = model
+
+        CALIB_TOKENS = TOKENS[args.calib_preset]
+        print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
+
+        # ---------------------------------------------------------------------
+        # 2. Wrap every Transformer layer with PTQWrapper
+        # ---------------------------------------------------------------------
+        qcfg = QuantConfig()  # all-uint8 defaults
+
+        wrapped_layers = torch.nn.ModuleList()
+        for idx, layer in enumerate(uint8_model.model.layers):
+            layer_cfg = qcfg.child(f"layer{idx}")
+            wrapped_layers.append(PTQWrapper(layer, qcfg=layer_cfg))
+        uint8_model.model.layers = wrapped_layers
+
+        # ---------------------------------------------------------------------
+        # 3. Single-pass activation calibration
+        # ---------------------------------------------------------------------
+        print("Calibrating UINT-8 observers …")
+        calib_txt = " ".join(
+            load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)["text"]
+        )[:CALIB_TOKENS]
+        ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
+
+        # (a) switch every QuantModuleBase to CALIB mode
+        for l in uint8_model.model.layers:
+            l.enable_calibration()
+
+        # (b) run inference to collect ranges
+        iterator = range(0, ids.size(1) - 1, args.stride)
+        if not args.no_tqdm:
+            iterator = tqdm.tqdm(iterator, desc="Calibration")
+        with torch.no_grad():
+            for i in iterator:
+                uint8_model(ids[:, i : i + args.stride])
+
+        # (c) freeze (scale, zero-point)
+        for l in uint8_model.model.layers:
+            l.freeze_qparams()
+
+    # -------------------------------------------------------------------------
+    # 4. Evaluate perplexity
+    # -------------------------------------------------------------------------
+    print("\nCalculating perplexities …")
+    test_ds = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT)
+    enc = tokenizer("\n\n".join(test_ds["text"]), return_tensors="pt")
+
+    if args.mode == "fp":
+        ppl_fp = perplexity(
+            fp_model,
+            enc,
+            args.device,
+            stride=args.stride,
+            show_progress=not args.no_tqdm,
+        )
+    else:
+        ppl_int8 = perplexity(
+            uint8_model,
+            enc,
+            args.device,
+            stride=args.stride,
+            show_progress=not args.no_tqdm,
+        )
+
+    # -------------------------------------------------------------------------
+    # 5. Report
+    # -------------------------------------------------------------------------
+    print("\n┌── Wikitext-2 test perplexity ─────────────")
+    if args.mode == "fp":
+        print(f"│ FP : {ppl_fp:8.2f}")
+    else:
+        print(f"│ UINT-8 : {ppl_int8:8.2f}")
+    print("└───────────────────────────────────────────")
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    except Exception as e:
+        print(f"\n[Error] {e}", file=sys.stderr)
+        sys.exit(1)
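The rewritten example replaces the old hard-coded globals (MODEL_NAME, DEVICE, STRIDE, RUN_FP) with CLI flags. A minimal usage sketch, assuming the wheel plus its example dependencies (transformers, datasets, tqdm) are installed; the model id below is only illustrative (it is the one the removed MODEL_NAME constant pointed at and may require an HF token):

```python
# Sketch only: drive the new argparse entry point in-process, equivalent to
#   python -m tico.experimental.quantization.ptq.examples.compare_ppl \
#       --model meta-llama/Meta-Llama-3-1B --mode uint8 --calib-preset debug
import sys

from tico.experimental.quantization.ptq.examples import compare_ppl

sys.argv = [
    "compare_ppl",
    "--model", "meta-llama/Meta-Llama-3-1B",   # illustrative, gated repo
    "--mode", "uint8",
    "--calib-preset", "debug",
    "--stride", "512",
    "--no-tqdm",
]
compare_ppl.main()
```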
tico/experimental/quantization/ptq/wrappers/fairseq/quant_encoder.py
ADDED

@@ -0,0 +1,333 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -----------------------------------------------------------------------------
+# This file includes modifications based on fairseq
+# (https://github.com/facebookresearch/fairseq), originally licensed under
+# the MIT License. See the LICENSE file in the fairseq repository for details.
+# -----------------------------------------------------------------------------
+
+import math
+from typing import Dict, List, Literal, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from tico.experimental.quantization.ptq.quant_config import QuantConfig
+from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
+from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
+    QuantModuleBase,
+)
+from tico.experimental.quantization.ptq.wrappers.registry import try_register
+
+
+@try_register("fairseq.models.transformer.TransformerEncoderBase")
+class QuantFairseqEncoder(QuantModuleBase):
+    """
+    Quant-aware drop-in replacement for Fairseq TransformerEncoderBase.
+
+    Key design choices:
+      - Keep embeddings and LayerNorms in FP.
+      - Remove training-time logic (dropout, activation-dropout, quant_noise).
+      - Attention masks are handled statically inside the layer wrapper; this
+        encoder only does the original padding zero-out before the stack.
+
+    I/O contracts:
+      - Forward signature and returned dictionary are identical to the original
+        when `use_external_inputs=False`.
+      - When `use_external_inputs=True`, forward returns a single Tensor (T,B,C)
+        and completely skips embedding/positional/LN/mask-creation paths.
+      - Tensor shapes follow Fairseq convention.
+    """
+
+    def __init__(
+        self,
+        fp_encoder: nn.Module,
+        *,
+        qcfg: Optional[QuantConfig] = None,
+        fp_name: Optional[str] = None,
+        use_external_inputs: bool = False,  # export-mode flag
+        return_type: Literal["tensor", "dict"] = "dict",
+    ):
+        super().__init__(qcfg, fp_name=fp_name)
+        self.use_external_inputs = use_external_inputs
+        self.return_type: Literal["tensor", "dict"] = return_type
+
+        # --- carry basic config / metadata (read-only copies) ---------------
+        assert hasattr(fp_encoder, "cfg")
+        self.cfg = fp_encoder.cfg
+        self.return_fc: bool = bool(getattr(fp_encoder, "return_fc", False))
+
+        # Embedding stack ----------------------------------------------------
+        assert hasattr(fp_encoder, "embed_tokens") and isinstance(
+            fp_encoder.embed_tokens, nn.Module
+        )
+        self.embed_tokens = fp_encoder.embed_tokens  # keep FP embeddings
+
+        assert hasattr(fp_encoder, "padding_idx")
+        self.padding_idx: int = int(fp_encoder.padding_idx)  # type: ignore[arg-type]
+
+        # scale = sqrt(embed_dim) unless disabled
+        embed_dim = int(self.embed_tokens.embedding_dim)  # type: ignore[arg-type]
+        no_scale = bool(getattr(self.cfg, "no_scale_embedding", False))
+        self.embed_scale: float = 1.0 if no_scale else math.sqrt(embed_dim)
+
+        # Positional embeddings (keep as-is; no FQ)
+        self.embed_positions = getattr(fp_encoder, "embed_positions", None)
+        # Optional embedding LayerNorm
+        self.layernorm_embedding = getattr(fp_encoder, "layernorm_embedding", None)
+
+        # Final encoder LayerNorm (pre-norm stacks may set this to None)
+        self.layer_norm = getattr(fp_encoder, "layer_norm", None)
+
+        # Max positions (reuse for API parity)
+        self.max_source_positions: int = int(fp_encoder.max_source_positions)  # type: ignore[arg-type]
+
+        # --- wrap encoder layers with PTQWrapper ----------------------------
+        assert hasattr(fp_encoder, "layers")
+        fp_layers = list(fp_encoder.layers)  # type: ignore[arg-type]
+        self.layers = nn.ModuleList()
+
+        # Prepare child QuantConfig namespaces: layers/<idx>
+        layers_qcfg = qcfg.child("layers") if qcfg else None
+        for i, layer in enumerate(fp_layers):
+            child_cfg = layers_qcfg.child(str(i)) if layers_qcfg else None
+            self.layers.append(
+                PTQWrapper(layer, qcfg=child_cfg, fp_name=f"{fp_name}.layers.{i}")
+            )
+
+        # Version buffer (keep for state_dict parity)
+        version = getattr(fp_encoder, "version", None)
+        if isinstance(version, torch.Tensor):
+            self.register_buffer("version", version.clone(), persistent=False)
+        else:
+            self.register_buffer("version", torch.tensor([3.0]), persistent=False)
+
+    # ----------------------------------------------------------------------
+    def forward_embedding(
+        self, src_tokens: Tensor, token_embedding: Optional[Tensor] = None
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Embed tokens and add positional embeddings. Dropout/quant_noise are removed.
+        Returns:
+            x (B, T, C), embed (B, T, C)  # embed is the token-only embedding
+        """
+        if token_embedding is None:
+            token_embedding = self.embed_tokens(src_tokens)
+        embed = token_embedding  # token-only
+
+        x = self.embed_scale * token_embedding
+        if self.embed_positions is not None:
+            x = x + self.embed_positions(src_tokens)
+        if self.layernorm_embedding is not None:
+            x = self.layernorm_embedding(x)
+        # No dropout, no quant_noise here (inference-only)
+        return x, embed
+
+    # ----------------------------------------------------------------------
+    def forward(
+        self,
+        src_tokens: Tensor,
+        src_lengths: Optional[Tensor] = None,
+        return_all_hiddens: bool = False,
+        token_embeddings: Optional[Tensor] = None,
+        *,
+        # External-inputs branch (used for export)
+        encoder_padding_mask: Optional[Tensor] = None,  # B x T (bool)
+    ) -> Tensor | Dict[str, List[Optional[Tensor]]]:
+        """
+        If `self.use_external_inputs` is True:
+          - Use only x_external and encoder_padding_mask.
+          - Return a single Tensor (T, B, C) for export friendliness.
+
+        Otherwise (False):
+          - Behave like the original Fairseq encoder forward and return dict-of-lists.
+        """
+        if self.use_external_inputs:
+            # ----- External-input mode: completely skip embedding/positional/LN/mask creation -----
+            x_external = src_tokens  # T x B x C (already embedded + transposed)
+
+            encoder_states: List[Tensor] = []
+            if return_all_hiddens:
+                encoder_states.append(x_external)
+
+            for layer in self.layers:
+                out = layer(x_external, encoder_padding_mask=encoder_padding_mask)
+                x_external = (
+                    out[0] if (isinstance(out, tuple) and len(out) == 2) else out
+                )
+                if return_all_hiddens:
+                    encoder_states.append(x_external)
+
+            if self.layer_norm is not None:
+                x_external = self.layer_norm(x_external)
+
+            if self.return_type == "dict":
+                return {
+                    "encoder_out": [x_external],
+                    "encoder_padding_mask": [encoder_padding_mask],
+                    "encoder_states": encoder_states,  # type: ignore[dict-item]
+                }
+            else:
+                # For export, returning a single Tensor is simpler and more portable.
+                return x_external
+
+        # ----- Original path (training/eval compatibility) ------------------
+
+        # Compute padding mask [B, T] (bool). We keep the original "has_pads" logic.
+        encoder_padding_mask = src_tokens.eq(self.padding_idx)
+        has_pads: Tensor = (
+            torch.tensor(src_tokens.device.type == "xla") or encoder_padding_mask.any()
+        )
+        if torch.jit.is_scripting():
+            has_pads = torch.tensor(1) if has_pads else torch.tensor(0)
+
+        # Embedding path (B,T,C). No dropout/quant_noise.
+        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
+
+        # Zero out padded timesteps prior to the stack (same as original)
+        x = x * (
+            1 - encoder_padding_mask.unsqueeze(-1).type_as(x) * has_pads.type_as(x)
+        )
+
+        # B x T x C -> T x B x C
+        x = x.transpose(0, 1)
+
+        encoder_states: List[Tensor] = []  # type: ignore[no-redef]
+        fc_results: List[Optional[Tensor]] = []
+
+        if return_all_hiddens:
+            encoder_states.append(x)
+
+        # Encoder layers (each item is PTQ-wrapped and uses static additive masks internally)
+        for layer in self.layers:
+            out = layer(
+                x, encoder_padding_mask=encoder_padding_mask if has_pads else None
+            )
+            if isinstance(out, tuple) and len(out) == 2:
+                x, fc_res = out
+            else:
+                x = out
+                fc_res = None
+
+            if return_all_hiddens and not torch.jit.is_scripting():
+                encoder_states.append(x)
+                fc_results.append(fc_res)
+
+        if self.layer_norm is not None:
+            x = self.layer_norm(x)
+
+        # src_lengths (B, 1) int32, identical to original
+        src_lengths_out = (
+            src_tokens.ne(self.padding_idx)
+            .sum(dim=1, dtype=torch.int32)
+            .reshape(-1, 1)
+            .contiguous()
+        )
+
+        return {
+            "encoder_out": [x],  # T x B x C
+            "encoder_padding_mask": [encoder_padding_mask],  # B x T
+            "encoder_embedding": [encoder_embedding],  # B x T x C
+            "encoder_states": encoder_states,  # type: ignore[dict-item]  # List[T x B x C]
+            "fc_results": fc_results,  # type: ignore[dict-item]  # List[T x B x C]
+            "src_tokens": [],
+            "src_lengths": [src_lengths_out],
+        }
+
+    def forward_torchscript(self, net_input: Dict[str, Tensor]):
+        """A TorchScript-compatible version of forward.
+
+        Encoders which use additional arguments may want to override
+        this method for TorchScript compatibility.
+        """
+        if "encoder_padding_mask" in net_input:
+            return self.forward(
+                src_tokens=net_input["src_tokens"],
+                src_lengths=net_input["src_lengths"],
+                encoder_padding_mask=net_input["encoder_padding_mask"],
+            )
+        else:
+            return self.forward(
+                src_tokens=net_input["src_tokens"],
+                src_lengths=net_input["src_lengths"],
+            )
+
+    # ----------------------------------------------------------------------
+    @torch.jit.export
+    def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
+        """
+        Match original API: reorder the batched dimension (B) according to new_order.
+        """
+        reordered = dict()  # type: ignore[var-annotated]
+        if len(encoder_out["encoder_out"]) == 0:
+            new_encoder_out = []
+        else:
+            new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
+        reordered["encoder_out"] = new_encoder_out
+        keys = [
+            "encoder_padding_mask",
+            "encoder_embedding",
+            "src_tokens",
+            "src_lengths",
+        ]
+        for k in keys:
+            if k not in encoder_out:
+                continue
+            if len(encoder_out[k]) == 0:
+                reordered[k] = []
+            else:
+                reordered[k] = [encoder_out[k][0].index_select(0, new_order)]
+
+        if "encoder_states" in encoder_out:
+            encoder_states = encoder_out["encoder_states"]
+            if len(encoder_states) > 0:
+                for idx, state in enumerate(encoder_states):
+                    encoder_states[idx] = state.index_select(1, new_order)
+            reordered["encoder_states"] = encoder_states
+
+        return reordered
+
+    @torch.jit.export
+    def _reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
+        """Dummy re-order for beamable enc-dec attention (API parity)."""
+        return encoder_out
+
+    def max_positions(self) -> int:
+        """Maximum input length supported by the encoder (same policy as the original)."""
+        if self.embed_positions is None:
+            return self.max_source_positions
+        return min(self.max_source_positions, self.embed_positions.max_positions)
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        """
+        Forward-compat mapping for older checkpoints (mirror original behavior for LNs).
+        The actual remapping of per-layer norms is delegated to the wrapped layers.
+        """
+        for i, layer in enumerate(self.layers):
+            if hasattr(layer, "upgrade_state_dict_named"):
+                layer.upgrade_state_dict_named(state_dict, f"{name}.layers.{i}")
+
+        version_key = f"{name}.version"
+        v = state_dict.get(version_key, torch.Tensor([1]))
+        if float(v[0].item()) < 2:
+            self.layer_norm = None
+            state_dict[version_key] = torch.Tensor([1])
+        return state_dict
+
+    def _all_observers(self):
+        for m in self.layers:
+            if isinstance(m, QuantModuleBase):
+                yield from m._all_observers()
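The new wrapper copies the encoder's constructor-time attributes and wraps each encoder layer in a PTQWrapper. A minimal wrap-and-calibrate sketch, assuming fairseq is installed, that "model.pt" is a placeholder checkpoint path, and that enable_calibration()/freeze_qparams() are the same QuantModuleBase helpers that compare_ppl.py calls on PTQWrapper instances:

```python
# Sketch under the assumptions stated above; not a documented flow of this wheel.
import torch
from fairseq import checkpoint_utils

from tico.experimental.quantization.ptq.quant_config import QuantConfig
from tico.experimental.quantization.ptq.wrappers.fairseq.quant_encoder import (
    QuantFairseqEncoder,
)

models, _cfg, _task = checkpoint_utils.load_model_ensemble_and_task(["model.pt"])
model = models[0].eval()

qcfg = QuantConfig()  # all-uint8 defaults, as in compare_ppl.py
model.encoder = QuantFairseqEncoder(
    model.encoder,
    qcfg=qcfg.child("encoder"),
    fp_name="encoder",
)

# Calibrate observers on illustrative token ids, then freeze (scale, zero-point).
src_tokens = torch.randint(10, 100, (2, 16))
src_lengths = torch.full((2,), 16)
model.encoder.enable_calibration()
with torch.no_grad():
    model.encoder(src_tokens, src_lengths=src_lengths)
model.encoder.freeze_qparams()
```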
tico/experimental/quantization/ptq/wrappers/llama/quant_attn.py
CHANGED

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -131,28 +131,38 @@ class QuantLlamaAttention(QuantModuleBase):
         x2n = self._fq(-x2, o_neg)
         return self._fq(torch.cat((x2n, x1), -1), o_cat)
 
+    @staticmethod
+    def _concat_kv(
+        past: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        k_new: torch.Tensor,
+        v_new: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Concat along sequence dim (dim=2): (B, n_kv, S, H)."""
+        if past is None:
+            return k_new, v_new
+        past_k, past_v = past
+        k = torch.cat([past_k, k_new], dim=2)
+        v = torch.cat([past_v, v_new], dim=2)
+        return k, v
+
     def forward(
         self,
         hidden_states: torch.Tensor,
         position_embeddings: tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
-        past_key_value=None,  #
+        past_key_value=None,  # tuple(k, v) or HF Cache-like object
+        use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ):
-        if past_key_value is not None:
-            raise NotImplementedError(
-                "QuantLlamaAttention does not support KV cache yet."
-            )
-
         hidden = self._fq(hidden_states, self.obs_hidden)
         B, S, _ = hidden.shape
         H = self.hdim
 
         # projections
-        q = self.q_proj(hidden).view(B, S, -1, H).transpose(1, 2)
-        k = self.k_proj(hidden).view(B, S, -1, H).transpose(1, 2)
-        v = self.v_proj(hidden).view(B, S, -1, H).transpose(1, 2)
+        q = self.q_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_h, S, H)
+        k = self.k_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_kv, S, H)
+        v = self.v_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_kv, S, H)
 
         # rope tables
         cos, sin = position_embeddings
@@ -176,14 +186,37 @@ class QuantLlamaAttention(QuantModuleBase):
         k_sin = self._fq(k_half * sin_u, self.obs_k_sin)
         k_rot = self._fq(k_cos + k_sin, self.obs_k_rot)
 
+        # --- build/update KV for attention & present_key_value -------------
+        present_key_value: Tuple[torch.Tensor, torch.Tensor]
+
+        # HF Cache path (if available)
+        if use_cache and hasattr(past_key_value, "update"):
+            # Many HF Cache impls use update(k, v) and return (k_total, v_total)
+            try:
+                k_total, v_total = past_key_value.update(k_rot, v)
+                present_key_value = (k_total, v_total)
+                k_for_attn, v_for_attn = k_total, v_total
+            except Exception:
+                # Fallback to tuple concat if Cache signature mismatches
+                k_for_attn, v_for_attn = self._concat_kv(
+                    getattr(past_key_value, "kv", None), k_rot, v
+                )
+                present_key_value = (k_for_attn, v_for_attn)
+        else:
+            # Tuple or None path
+            pkv_tuple = past_key_value if isinstance(past_key_value, tuple) else None
+            k_for_attn, v_for_attn = self._concat_kv(pkv_tuple, k_rot, v)
+            present_key_value = (k_for_attn, v_for_attn)
+
         # logits
-        k_rep =
+        k_rep = k_for_attn.repeat_interleave(self.kv_rep, dim=1)  # (B, n_h, K, H)
         logits_raw = self._fq(q_rot @ k_rep.transpose(-2, -1), self.obs_logits_raw)
         scale = self._fq(self.scale_t, self.obs_scale)
         logits = self._fq(logits_raw * scale, self.obs_logits)
 
         if attention_mask is None or attention_mask.dtype == torch.bool:
-            _, _, q_len,
+            _, _, q_len, _ = logits.shape
+            k_len = k_for_attn.size(2)
             assert isinstance(self.causal_mask_template, torch.Tensor)
             attention_mask = self.causal_mask_template[..., :q_len, :k_len].to(
                 hidden_states.device
@@ -196,7 +229,7 @@ class QuantLlamaAttention(QuantModuleBase):
         attn_weights = self._fq(attn_weights, self.obs_softmax)
 
         # attn out
-        v_rep =
+        v_rep = v_for_attn.repeat_interleave(self.kv_rep, dim=1)  # (B, n_h, K, H)
         attn_out = (
             self._fq(attn_weights @ v_rep, self.obs_attn_out)
             .transpose(1, 2)
@@ -204,7 +237,13 @@ class QuantLlamaAttention(QuantModuleBase):
         )
 
         # final projection
-
+        out = self.o_proj(attn_out)
+
+        # return with/without cache
+        if use_cache:
+            return out, attn_weights, present_key_value
+        else:
+            return out, attn_weights
 
     def _all_observers(self):
         # local first
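QuantLlamaAttention.forward now accepts use_cache and an optional past_key_value (a (k, v) tuple or an HF Cache-like object) and, when caching is on, returns present_key_value that can be fed back on the next step. The static helper added above has a fully specified shape contract, so a small shape check is possible without constructing the module; a sketch:

```python
# Tuple KV-cache contract: keys/values are (B, n_kv, S, H) and grow along dim=2.
# With use_cache=True, forward() returns (out, attn_weights, present_key_value).
import torch

from tico.experimental.quantization.ptq.wrappers.llama.quant_attn import (
    QuantLlamaAttention,
)

B, n_kv, H = 1, 2, 4
k0, v0 = torch.randn(B, n_kv, 3, H), torch.randn(B, n_kv, 3, H)  # 3-token prefill
k1, v1 = torch.randn(B, n_kv, 1, H), torch.randn(B, n_kv, 1, H)  # next decoded token

k_total, v_total = QuantLlamaAttention._concat_kv(None, k0, v0)              # no past yet
k_total, v_total = QuantLlamaAttention._concat_kv((k_total, v_total), k1, v1)
assert k_total.shape == (B, n_kv, 4, H) and v_total.shape == (B, n_kv, 4, H)
```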
tico/experimental/quantization/ptq/wrappers/llama/quant_decoder_layer.py
CHANGED

@@ -136,7 +136,7 @@ class QuantLlamaDecoderLayer(QuantModuleBase):
         L = hidden_states.size(1)
         attention_mask = self._slice_causal(L, hidden_states.device)
 
-
+        attn_out = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -147,7 +147,13 @@ class QuantLlamaDecoderLayer(QuantModuleBase):
             position_embeddings=position_embeddings,
             **kwargs,
         )
-
+        if use_cache:
+            hidden_states_attn, _attn_weights, present_key_value = attn_out
+        else:
+            hidden_states_attn, _attn_weights = attn_out
+            present_key_value = None
+
+        hidden_states = residual + hidden_states_attn
 
         # ─── MLP block ─────────────────────────────────────────────────
         residual = hidden_states
@@ -155,6 +161,12 @@ class QuantLlamaDecoderLayer(QuantModuleBase):
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
 
+        # Return type policy:
+        #   - If use_cache: always return (hidden_states, present_key_value)
+        #   - Else: return as configured (tuple/tensor) for HF compatibility
+        if use_cache:
+            return hidden_states, present_key_value
+
         if self.return_type == "tuple":
             return (hidden_states,)
         elif self.return_type == "tensor":
tico/experimental/quantization/ptq/wrappers/registry.py
CHANGED

@@ -33,6 +33,7 @@ _CORE_MODULES = (
     "tico.experimental.quantization.ptq.wrappers.llama.quant_decoder_layer",
     "tico.experimental.quantization.ptq.wrappers.llama.quant_mlp",
     # fairseq
+    "tico.experimental.quantization.ptq.wrappers.fairseq.quant_encoder",
     "tico.experimental.quantization.ptq.wrappers.fairseq.quant_encoder_layer",
     "tico.experimental.quantization.ptq.wrappers.fairseq.quant_mha",
     # add future core wrappers here
tico/passes/decompose_fake_quantize_tensor_qparams.py
CHANGED

@@ -245,9 +245,10 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
             # mask_user(output).args == (dequantize_per_tensor.tensor, mask)
             if mask:
                 assert len(mask) == 1
-
-
-
+                if len(mask[0].users) > 0:
+                    mask_user = list(mask[0].users.keys())[0]
+                    assert len(mask_user.args) == 1
+                    mask_user.args = ((mask_user.args[0][0],),)
                 modified = True
             if (
                 node.target
tico/passes/remove_redundant_expand.py
CHANGED

@@ -46,7 +46,9 @@ class RemoveRedundantExpand(PassBase):
             input, size = args.input, args.size
 
             input_shape = extract_shape(input)
-
+            output_shape = extract_shape(node)
+
+            if input_shape != output_shape:
                 continue
 
             node.replace_all_uses_with(input, propagate_meta=False)
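The pass now compares the expand node's input shape with the node's own output shape and skips the rewrite whenever they differ. The plain-PyTorch behaviour the check relies on, as a small illustration (not the FX pass itself):

```python
# expand() to an identical shape is a no-op view, which is exactly the case the
# pass is allowed to drop; expand() to a larger, broadcast shape must be kept.
import torch

x = torch.randn(2, 3, 4)
same = x.expand(2, 3, 4)        # same shape -> redundant, removable
bigger = x.expand(5, 2, 3, 4)   # broadcasts -> not redundant, must be kept

assert same.shape == x.shape and torch.equal(same, x)
assert bigger.shape == (5, 2, 3, 4)
```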
{tico-0.1.0.dev250911.dist-info → tico-0.1.0.dev250914.dist-info}/RECORD
CHANGED

@@ -1,4 +1,4 @@
-tico/__init__.py,sha256=
+tico/__init__.py,sha256=Lmo72Xd9sheKIW4XhH6oc5SheplvnXak_Zbh0EQZsrI,1883
 tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
 tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
 tico/config/base.py,sha256=q5xMqGxTUZs4mFqt5c7i_y9U00fYgdMGl9nUqIVMlCo,1248
@@ -62,7 +62,7 @@ tico/experimental/quantization/ptq/mode.py,sha256=lT-T8vIv8YWcwrjT7xXVhOw1g7aoAd
 tico/experimental/quantization/ptq/qscheme.py,sha256=uwhv7bCxOOXB3I-IKlRyr_u4eXOq48uIqGy4TLDqGxY,1301
 tico/experimental/quantization/ptq/quant_config.py,sha256=nm7570Y1X2mOT_8s27ilWid04otor6cVTi9GwgAEaKc,4300
 tico/experimental/quantization/ptq/examples/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
-tico/experimental/quantization/ptq/examples/compare_ppl.py,sha256=
+tico/experimental/quantization/ptq/examples/compare_ppl.py,sha256=QWUuO50lITnooYqEe57VV6mvIHKWZMB_TOGvtZ8C8qQ,8238
 tico/experimental/quantization/ptq/examples/debug_quant_outputs.py,sha256=astXzx-maq1W4gKvX2QaGmD2Tpmjunv4JqDYVk9eZRQ,5177
 tico/experimental/quantization/ptq/examples/quantize_linear.py,sha256=8zq-ZJDYgam0xQ-PbC6Xb1I7W1mv0Wi-b--IP2wwXtw,4539
 tico/experimental/quantization/ptq/examples/quantize_llama_attn.py,sha256=cVWUSSzaZWFp5QZkNkrlpHU3kXyP84QtnZbahVml_yQ,4329
@@ -84,13 +84,14 @@ tico/experimental/quantization/ptq/wrappers/__init__.py,sha256=IO6FP_xYbGy0dW0HL
 tico/experimental/quantization/ptq/wrappers/ptq_wrapper.py,sha256=F9sK_DiRaXiGNHULcwIbs5EUtHz6ZJ7N4r5CWTTfhsM,2442
 tico/experimental/quantization/ptq/wrappers/quant_elementwise.py,sha256=LhEoobfvto6zKrBOKL4gmxfFFc31jHzyQV_zfps-iQM,3604
 tico/experimental/quantization/ptq/wrappers/quant_module_base.py,sha256=vkcDos_knGSS29rIZuEIWkAJLHrENbGz8nCH2-iara8,5969
-tico/experimental/quantization/ptq/wrappers/registry.py,sha256=
+tico/experimental/quantization/ptq/wrappers/registry.py,sha256=GlVBPWPAnLRqTtemu_YOEX9WisF1eN6Mud7y1zzvpW0,5092
 tico/experimental/quantization/ptq/wrappers/fairseq/__init__.py,sha256=Mc8FLd9DusyB_IT1vk1OYrRkngOYnYd05IvtA9ORVQc,160
+tico/experimental/quantization/ptq/wrappers/fairseq/quant_encoder.py,sha256=r9DPUAbL2KRJ8zpMJ39Y9n6Oe79nte-mFcdjG2qEP-w,13809
 tico/experimental/quantization/ptq/wrappers/fairseq/quant_encoder_layer.py,sha256=aGr80Ku75j2H-UZ0elEa0mOQEyaAs2YJ4WJCN0lonn0,6412
 tico/experimental/quantization/ptq/wrappers/fairseq/quant_mha.py,sha256=HsigmOLeacLXc46QNeFqwQ0DwKQhNrtWTKEtLJoqXoc,15562
 tico/experimental/quantization/ptq/wrappers/llama/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
-tico/experimental/quantization/ptq/wrappers/llama/quant_attn.py,sha256
-tico/experimental/quantization/ptq/wrappers/llama/quant_decoder_layer.py,sha256=
+tico/experimental/quantization/ptq/wrappers/llama/quant_attn.py,sha256=futw-XhAhErdaK2cZY8T3_xCxZbsj-l1dbsSbeunE_4,10403
+tico/experimental/quantization/ptq/wrappers/llama/quant_decoder_layer.py,sha256=ZImtfT2pyYyGJa0QCcHgCVootiWeflpRvLa4LisjZSY,7646
 tico/experimental/quantization/ptq/wrappers/llama/quant_mlp.py,sha256=uZMnrX66oZwxhKhcNbLXXeri-WxxRBiZnr15aBXJMm0,3562
 tico/experimental/quantization/ptq/wrappers/nn/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/ptq/wrappers/nn/quant_layernorm.py,sha256=G5Sgt-tXnzh0Rxyk-2honmZIfEQOZlRfOsoDBdSGmA4,6887
@@ -111,7 +112,7 @@ tico/passes/convert_to_relu6.py,sha256=1BJpUwUb6Zli_1y3eyJQo7dg9B1xvZ7sYjMbvEQsF
 tico/passes/decompose_addmm.py,sha256=KjnpZjSuA0uvNmKaTN_EMwobcOi3CAB81buORzTDxro,3979
 tico/passes/decompose_batch_norm.py,sha256=06LAxhSmpTxFZJmUelwB3I_GipNWrLoM7PfM6ZkxOZY,6512
 tico/passes/decompose_fake_quantize.py,sha256=736srs8SM8K_mLR0WG10LVMMLRkYkBM9OF0k1GCkAW0,5218
-tico/passes/decompose_fake_quantize_tensor_qparams.py,sha256=
+tico/passes/decompose_fake_quantize_tensor_qparams.py,sha256=CalubQ1OYC2l59_TNPOcAnl4VxvameYWIQcy57Z6yjI,13985
 tico/passes/decompose_group_norm.py,sha256=6BqvYtMTPzeIgp8cPA8OFMwEBvb7odcg04IUgwtp7NQ,10120
 tico/passes/decompose_grouped_conv2d.py,sha256=n2qv320akL1ju33ucZ6lU1cKEAaj0NI8YZ5CrUnkRLM,8512
 tico/passes/decompose_slice_scatter.py,sha256=xqMHKhW2595YoAeubKZ4jRhYW4TQ09EXPgLNgODqXG8,5653
@@ -128,7 +129,7 @@ tico/passes/merge_consecutive_cat.py,sha256=ayZNLDA1DFM7Fxxi2Dmk1CujkgUuaVCH1rhQ
 tico/passes/ops.py,sha256=cSj3Sk2x2cOE9b8oU5pmSa_rHr-iX2lORzu3N_UHMSQ,2967
 tico/passes/remove_nop.py,sha256=Hf91p_EJAOC6DyWNthash0_UWtEcNc_M7znamQfYQ5Y,2686
 tico/passes/remove_redundant_assert_nodes.py,sha256=rYbTCyuNIXIC-2NreHKBVCuaSUkEQvB_iSRzb26P_EA,1821
-tico/passes/remove_redundant_expand.py,sha256=
+tico/passes/remove_redundant_expand.py,sha256=8yhlMnbog-T9gIK6LKIU0tu0__gfhZzO36g_fJIVVP4,2162
 tico/passes/remove_redundant_permute.py,sha256=98UsaZzFZdQzEEAR1pIzRisAf6hgfXLa88aayjalt3E,4292
 tico/passes/remove_redundant_reshape.py,sha256=aeep6LDvY58GEuOrWckkEXnJa6wkkbiJ9FrimT9F3-s,16384
 tico/passes/remove_redundant_slice.py,sha256=Iv7TbB39fktNb4eq0VdyZnwxL_VsKLJ90diMmaf3kZk,2087
@@ -251,9 +252,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
 tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
 tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
+tico-0.1.0.dev250914.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
+tico-0.1.0.dev250914.dist-info/METADATA,sha256=qW47MJq3y-q2MtV7kSDUrT8dkZtBWScPMBwZgvMR6tg,8450
+tico-0.1.0.dev250914.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+tico-0.1.0.dev250914.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
+tico-0.1.0.dev250914.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
+tico-0.1.0.dev250914.dist-info/RECORD,,

{tico-0.1.0.dev250911.dist-info → tico-0.1.0.dev250914.dist-info}/LICENSE: file without changes
{tico-0.1.0.dev250911.dist-info → tico-0.1.0.dev250914.dist-info}/WHEEL: file without changes
{tico-0.1.0.dev250911.dist-info → tico-0.1.0.dev250914.dist-info}/entry_points.txt: file without changes
{tico-0.1.0.dev250911.dist-info → tico-0.1.0.dev250914.dist-info}/top_level.txt: file without changes