tico 0.1.0.dev250918__py3-none-any.whl → 0.1.0.dev250922__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tico might be problematic. Click here for more details.
- tico/__init__.py +1 -1
- tico/config/v1.py +1 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +2 -2
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +1 -1
- tico/experimental/quantization/config/__init__.py +1 -0
- tico/experimental/quantization/config/base.py +26 -0
- tico/experimental/quantization/config/gptq.py +29 -0
- tico/experimental/quantization/config/pt2e.py +25 -0
- tico/experimental/quantization/{config.py → config/smoothquant.py} +1 -35
- tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +1 -1
- tico/experimental/quantization/ptq/wrappers/fairseq/quant_decoder.py +431 -0
- tico/experimental/quantization/public_interface.py +1 -1
- tico/experimental/quantization/quantizer.py +1 -1
- tico/passes/convert_matmul_to_linear.py +119 -7
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/utils/convert.py +3 -0
- {tico-0.1.0.dev250918.dist-info → tico-0.1.0.dev250922.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250918.dist-info → tico-0.1.0.dev250922.dist-info}/RECORD +22 -17
- {tico-0.1.0.dev250918.dist-info → tico-0.1.0.dev250922.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250918.dist-info → tico-0.1.0.dev250922.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250918.dist-info → tico-0.1.0.dev250922.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250918.dist-info → tico-0.1.0.dev250922.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
tico/config/v1.py
CHANGED
|
@@ -23,6 +23,7 @@ class CompileConfigV1(CompileConfigBase):
|
|
|
23
23
|
remove_constant_input: bool = False
|
|
24
24
|
convert_lhs_const_mm_to_fc: bool = False
|
|
25
25
|
convert_rhs_const_mm_to_fc: bool = True
|
|
26
|
+
convert_single_batch_lhs_const_bmm_to_fc: bool = False
|
|
26
27
|
|
|
27
28
|
def get(self, name: str):
|
|
28
29
|
return super().get(name)
|
|
@@ -25,7 +25,7 @@ from tico.experimental.quantization.algorithm.gptq.utils import (
|
|
|
25
25
|
gather_single_batch_from_dict,
|
|
26
26
|
gather_single_batch_from_list,
|
|
27
27
|
)
|
|
28
|
-
from tico.experimental.quantization.config import
|
|
28
|
+
from tico.experimental.quantization.config.gptq import GPTQConfig
|
|
29
29
|
from tico.experimental.quantization.quantizer import BaseQuantizer
|
|
30
30
|
|
|
31
31
|
|
|
@@ -44,7 +44,7 @@ class GPTQQuantizer(BaseQuantizer):
|
|
|
44
44
|
3) convert(model) to consume the collected data and apply GPTQ.
|
|
45
45
|
"""
|
|
46
46
|
|
|
47
|
-
def __init__(self, config:
|
|
47
|
+
def __init__(self, config: GPTQConfig):
|
|
48
48
|
super().__init__(config)
|
|
49
49
|
|
|
50
50
|
# cache_args[i] -> list of the i-th positional argument for each batch
|
|
@@ -23,7 +23,7 @@ from tico.experimental.quantization.algorithm.smoothquant.observer import (
|
|
|
23
23
|
from tico.experimental.quantization.algorithm.smoothquant.smooth_quant import (
|
|
24
24
|
apply_smoothing,
|
|
25
25
|
)
|
|
26
|
-
from tico.experimental.quantization.config import SmoothQuantConfig
|
|
26
|
+
from tico.experimental.quantization.config.smoothquant import SmoothQuantConfig
|
|
27
27
|
from tico.experimental.quantization.quantizer import BaseQuantizer
|
|
28
28
|
|
|
29
29
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# DO NOT REMOVE THIS FILE
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from abc import ABC, abstractmethod
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BaseConfig(ABC):
|
|
19
|
+
"""
|
|
20
|
+
Base configuration class for quantization.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
@property
|
|
24
|
+
@abstractmethod
|
|
25
|
+
def name(self) -> str:
|
|
26
|
+
pass
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from tico.experimental.quantization.config.base import BaseConfig
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class GPTQConfig(BaseConfig):
|
|
19
|
+
"""
|
|
20
|
+
Configuration for GPTQ.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
def __init__(self, verbose: bool = False, show_progress: bool = True):
|
|
24
|
+
self.verbose = verbose
|
|
25
|
+
self.show_progress = show_progress
|
|
26
|
+
|
|
27
|
+
@property
|
|
28
|
+
def name(self) -> str:
|
|
29
|
+
return "gptq"
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from tico.experimental.quantization.config.base import BaseConfig
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class PT2EConfig(BaseConfig):
|
|
19
|
+
"""
|
|
20
|
+
Configuration for pytorch 2.0 export quantization.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
@property
|
|
24
|
+
def name(self) -> str:
|
|
25
|
+
return "pt2e"
|
|
@@ -12,43 +12,9 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from abc import ABC, abstractmethod
|
|
16
15
|
from typing import Dict, Literal, Optional
|
|
17
16
|
|
|
18
|
-
|
|
19
|
-
class BaseConfig(ABC):
|
|
20
|
-
"""
|
|
21
|
-
Base configuration class for quantization.
|
|
22
|
-
"""
|
|
23
|
-
|
|
24
|
-
@property
|
|
25
|
-
@abstractmethod
|
|
26
|
-
def name(self) -> str:
|
|
27
|
-
pass
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
class PT2EConfig(BaseConfig):
|
|
31
|
-
"""
|
|
32
|
-
Configuration for pytorch 2.0 export quantization.
|
|
33
|
-
"""
|
|
34
|
-
|
|
35
|
-
@property
|
|
36
|
-
def name(self) -> str:
|
|
37
|
-
return "pt2e"
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
class GPTQConfig(BaseConfig):
|
|
41
|
-
"""
|
|
42
|
-
Configuration for GPTQ.
|
|
43
|
-
"""
|
|
44
|
-
|
|
45
|
-
def __init__(self, verbose: bool = False, show_progress: bool = True):
|
|
46
|
-
self.verbose = verbose
|
|
47
|
-
self.show_progress = show_progress
|
|
48
|
-
|
|
49
|
-
@property
|
|
50
|
-
def name(self) -> str:
|
|
51
|
-
return "gptq"
|
|
17
|
+
from tico.experimental.quantization.config.base import BaseConfig
|
|
52
18
|
|
|
53
19
|
|
|
54
20
|
class SmoothQuantConfig(BaseConfig):
|
|
@@ -34,7 +34,7 @@ from datasets import load_dataset
|
|
|
34
34
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
35
35
|
|
|
36
36
|
from tico.experimental.quantization import convert, prepare
|
|
37
|
-
from tico.experimental.quantization.config import GPTQConfig
|
|
37
|
+
from tico.experimental.quantization.config.gptq import GPTQConfig
|
|
38
38
|
from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
|
|
39
39
|
from tico.experimental.quantization.ptq.quant_config import QuantConfig
|
|
40
40
|
from tico.experimental.quantization.ptq.utils.introspection import build_fqn_map
|
|
@@ -0,0 +1,431 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
#
|
|
15
|
+
# -----------------------------------------------------------------------------
|
|
16
|
+
# This file includes modifications based on fairseq
|
|
17
|
+
# (https://github.com/facebookresearch/fairseq), originally licensed under
|
|
18
|
+
# the MIT License. See the LICENSE file in the fairseq repository for details.
|
|
19
|
+
# -----------------------------------------------------------------------------
|
|
20
|
+
|
|
21
|
+
import math
|
|
22
|
+
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
|
23
|
+
|
|
24
|
+
import torch
|
|
25
|
+
import torch.nn.functional as F
|
|
26
|
+
from torch import nn, Tensor
|
|
27
|
+
|
|
28
|
+
from tico.experimental.quantization.ptq.quant_config import QuantConfig
|
|
29
|
+
from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
|
|
30
|
+
from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
|
|
31
|
+
QuantModuleBase,
|
|
32
|
+
)
|
|
33
|
+
from tico.experimental.quantization.ptq.wrappers.registry import try_register
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@try_register("fairseq.models.transformer.TransformerDecoderBase")
|
|
37
|
+
class QuantFairseqDecoder(QuantModuleBase):
|
|
38
|
+
"""
|
|
39
|
+
Quant-aware drop-in replacement for Fairseq TransformerDecoderBase.
|
|
40
|
+
|
|
41
|
+
Design (inference-only):
|
|
42
|
+
- Keep embeddings, positional embeddings, LayerNorms, output_projection in FP.
|
|
43
|
+
- PTQ-wrap all TransformerDecoderLayerBase items via PTQWrapper (uses QuantFairseqDecoderLayer).
|
|
44
|
+
- Drop training-only logic (dropout, activation-dropout, quant-noise, checkpoint wrappers).
|
|
45
|
+
- Preserve Fairseq forward/extract_features contract, shapes, and incremental decoding behavior.
|
|
46
|
+
|
|
47
|
+
I/O:
|
|
48
|
+
- Forward(prev_output_tokens, encoder_out, incremental_state, ...) -> (logits, extra) like the original.
|
|
49
|
+
- `features_only=True` returns features without output projection.
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
def __init__(
|
|
53
|
+
self,
|
|
54
|
+
fp_decoder: nn.Module,
|
|
55
|
+
*,
|
|
56
|
+
qcfg: Optional[QuantConfig] = None,
|
|
57
|
+
fp_name: Optional[str] = None,
|
|
58
|
+
):
|
|
59
|
+
super().__init__(qcfg, fp_name=fp_name)
|
|
60
|
+
|
|
61
|
+
# ---- carry config/meta (read-only views) --------------------------
|
|
62
|
+
assert hasattr(fp_decoder, "cfg")
|
|
63
|
+
self.cfg = fp_decoder.cfg
|
|
64
|
+
self.share_input_output_embed: bool = bool(
|
|
65
|
+
getattr(fp_decoder, "share_input_output_embed", False)
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
# Version buffer (parity with original)
|
|
69
|
+
version = getattr(fp_decoder, "version", None)
|
|
70
|
+
if isinstance(version, torch.Tensor):
|
|
71
|
+
self.register_buffer("version", version.clone(), persistent=False)
|
|
72
|
+
else:
|
|
73
|
+
self.register_buffer("version", torch.tensor([3.0]), persistent=False)
|
|
74
|
+
|
|
75
|
+
# Embeddings / positional encodings (FP; reuse modules)
|
|
76
|
+
assert hasattr(fp_decoder, "embed_tokens") and isinstance(
|
|
77
|
+
fp_decoder.embed_tokens, nn.Module
|
|
78
|
+
)
|
|
79
|
+
self.embed_tokens = fp_decoder.embed_tokens # (B,T)->(B,T,C)
|
|
80
|
+
|
|
81
|
+
self.padding_idx: int = int(fp_decoder.padding_idx) # type: ignore[arg-type]
|
|
82
|
+
self.max_target_positions: int = int(fp_decoder.max_target_positions) # type: ignore[arg-type]
|
|
83
|
+
|
|
84
|
+
self.embed_positions = getattr(fp_decoder, "embed_positions", None)
|
|
85
|
+
self.layernorm_embedding = getattr(fp_decoder, "layernorm_embedding", None)
|
|
86
|
+
|
|
87
|
+
# Dimensions / projections (reuse)
|
|
88
|
+
self.embed_dim: int = int(getattr(fp_decoder, "embed_dim"))
|
|
89
|
+
self.output_embed_dim: int = int(getattr(fp_decoder, "output_embed_dim"))
|
|
90
|
+
self.project_in_dim = getattr(fp_decoder, "project_in_dim", None)
|
|
91
|
+
self.project_out_dim = getattr(fp_decoder, "project_out_dim", None)
|
|
92
|
+
|
|
93
|
+
# Scale factor (sqrt(embed_dim) unless disabled)
|
|
94
|
+
no_scale = bool(getattr(self.cfg, "no_scale_embedding", False))
|
|
95
|
+
self.embed_scale: float = 1.0 if no_scale else math.sqrt(self.embed_dim)
|
|
96
|
+
|
|
97
|
+
# Final decoder LayerNorm (may be None depending on cfg)
|
|
98
|
+
self.layer_norm = getattr(fp_decoder, "layer_norm", None)
|
|
99
|
+
|
|
100
|
+
# Output projection / adaptive softmax (reuse FP modules)
|
|
101
|
+
self.adaptive_softmax = getattr(fp_decoder, "adaptive_softmax", None)
|
|
102
|
+
self.output_projection = getattr(fp_decoder, "output_projection", None)
|
|
103
|
+
|
|
104
|
+
# ---- wrap decoder layers ------------------------------------------
|
|
105
|
+
assert hasattr(fp_decoder, "layers")
|
|
106
|
+
fp_layers = list(fp_decoder.layers) # type: ignore[arg-type]
|
|
107
|
+
self.layers = nn.ModuleList()
|
|
108
|
+
|
|
109
|
+
# Safe prefix to avoid None-based name collisions in KV cache keys
|
|
110
|
+
def _safe_prefix(name: Optional[str]) -> str:
|
|
111
|
+
return (
|
|
112
|
+
name
|
|
113
|
+
if (name is not None and name != "" and name != "None")
|
|
114
|
+
else f"{self.__class__.__name__}_{id(self)}"
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
prefix = _safe_prefix(fp_name)
|
|
118
|
+
|
|
119
|
+
# Prepare child QuantConfig namespaces: layers/<idx>
|
|
120
|
+
layers_qcfg = qcfg.child("layers") if qcfg else None
|
|
121
|
+
for i, layer in enumerate(fp_layers):
|
|
122
|
+
child_cfg = layers_qcfg.child(str(i)) if layers_qcfg else None
|
|
123
|
+
# Not every item is necessarily a TransformerDecoderLayerBase (e.g., BaseLayer).
|
|
124
|
+
# If there's no registered wrapper for a layer type, keep it FP.
|
|
125
|
+
try:
|
|
126
|
+
wrapped = PTQWrapper(
|
|
127
|
+
layer, qcfg=child_cfg, fp_name=f"{prefix}.layers.{i}"
|
|
128
|
+
)
|
|
129
|
+
except NotImplementedError:
|
|
130
|
+
wrapped = layer # keep as-is (FP)
|
|
131
|
+
self.layers.append(wrapped)
|
|
132
|
+
self.num_layers = len(self.layers)
|
|
133
|
+
|
|
134
|
+
# Causal-mask template: large-negative fill value plus an upper bound on length (cfg.max_target_positions, fallback 2048)
|
|
135
|
+
self.mask_fill_value: float = -120.0
|
|
136
|
+
max_tgt = int(getattr(self.cfg, "max_target_positions", 2048)) # fallback: 2048
|
|
137
|
+
|
|
138
|
+
mask = torch.full((1, 1, max_tgt, max_tgt), float(self.mask_fill_value))
|
|
139
|
+
mask.triu_(1) # upper triangle set to fill_value; diagonal/lower are zeros
|
|
140
|
+
self.register_buffer("causal_mask_template", mask, persistent=False)
|
|
141
|
+
|
|
142
|
+
def forward(
|
|
143
|
+
self,
|
|
144
|
+
prev_output_tokens: Tensor, # [B, T]
|
|
145
|
+
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
|
|
146
|
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
|
147
|
+
features_only: bool = False,
|
|
148
|
+
full_context_alignment: bool = False,
|
|
149
|
+
alignment_layer: Optional[int] = None,
|
|
150
|
+
alignment_heads: Optional[int] = None,
|
|
151
|
+
src_lengths: Optional[Any] = None,
|
|
152
|
+
return_all_hiddens: bool = False,
|
|
153
|
+
):
|
|
154
|
+
"""
|
|
155
|
+
Match the original API.
|
|
156
|
+
Returns:
|
|
157
|
+
(logits_or_features, extra_dict)
|
|
158
|
+
"""
|
|
159
|
+
x, extra = self.extract_features_scriptable(
|
|
160
|
+
prev_output_tokens=prev_output_tokens,
|
|
161
|
+
encoder_out=encoder_out,
|
|
162
|
+
incremental_state=incremental_state,
|
|
163
|
+
full_context_alignment=full_context_alignment,
|
|
164
|
+
alignment_layer=alignment_layer,
|
|
165
|
+
alignment_heads=alignment_heads,
|
|
166
|
+
)
|
|
167
|
+
if not features_only:
|
|
168
|
+
x = self.output_layer(x)
|
|
169
|
+
return x, extra
|
|
170
|
+
|
|
171
|
+
def extract_features_scriptable(
|
|
172
|
+
self,
|
|
173
|
+
prev_output_tokens: Tensor, # [B,T]
|
|
174
|
+
encoder_out: Optional[Dict[str, List[Tensor]]],
|
|
175
|
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
|
176
|
+
full_context_alignment: bool = False,
|
|
177
|
+
alignment_layer: Optional[int] = None,
|
|
178
|
+
alignment_heads: Optional[int] = None,
|
|
179
|
+
) -> Tuple[Tensor, Dict[str, List[Optional[Tensor]]]]:
|
|
180
|
+
"""
|
|
181
|
+
Feature path that mirrors Fairseq's implementation (minus training-only code).
|
|
182
|
+
|
|
183
|
+
Returns:
|
|
184
|
+
x: [B, T, C]
|
|
185
|
+
extra: {"attn": [attn or None], "inner_states": [T x B x C tensors]}
|
|
186
|
+
"""
|
|
187
|
+
B, T = prev_output_tokens.size()
|
|
188
|
+
if alignment_layer is None:
|
|
189
|
+
alignment_layer = self.num_layers - 1
|
|
190
|
+
|
|
191
|
+
# Unpack encoder outputs in Fairseq dict format
|
|
192
|
+
enc: Optional[Tensor] = None
|
|
193
|
+
padding_mask: Optional[Tensor] = None
|
|
194
|
+
if encoder_out is not None and len(encoder_out.get("encoder_out", [])) > 0:
|
|
195
|
+
enc = encoder_out["encoder_out"][0] # [S,B,Ce]
|
|
196
|
+
if (
|
|
197
|
+
encoder_out is not None
|
|
198
|
+
and len(encoder_out.get("encoder_padding_mask", [])) > 0
|
|
199
|
+
):
|
|
200
|
+
padding_mask = encoder_out["encoder_padding_mask"][0] # [B,S] (bool)
|
|
201
|
+
|
|
202
|
+
# Positional embeddings (support incremental decoding)
|
|
203
|
+
positions = None
|
|
204
|
+
if self.embed_positions is not None:
|
|
205
|
+
positions = self.embed_positions(
|
|
206
|
+
prev_output_tokens, incremental_state=incremental_state
|
|
207
|
+
)
|
|
208
|
+
|
|
209
|
+
# In incremental mode, only the last step is consumed
|
|
210
|
+
if incremental_state is not None:
|
|
211
|
+
prev_output_tokens = prev_output_tokens[:, -1:]
|
|
212
|
+
if positions is not None:
|
|
213
|
+
positions = positions[:, -1:]
|
|
214
|
+
|
|
215
|
+
# Prevent view quirks (TorchScript parity in original)
|
|
216
|
+
prev_output_tokens = prev_output_tokens.contiguous()
|
|
217
|
+
|
|
218
|
+
# Token embeddings (+ optional proj-in), + positions, + optional LN
|
|
219
|
+
x = self.embed_scale * self.embed_tokens(prev_output_tokens) # [B,T,C]
|
|
220
|
+
if self.project_in_dim is not None:
|
|
221
|
+
x = self.project_in_dim(x)
|
|
222
|
+
if positions is not None:
|
|
223
|
+
x = x + positions
|
|
224
|
+
if self.layernorm_embedding is not None:
|
|
225
|
+
x = self.layernorm_embedding(x)
|
|
226
|
+
|
|
227
|
+
# No dropout / quant_noise (inference-only)
|
|
228
|
+
|
|
229
|
+
# B x T x C -> T x B x C
|
|
230
|
+
x = x.transpose(0, 1)
|
|
231
|
+
|
|
232
|
+
# Build self-attn masks
|
|
233
|
+
self_attn_padding_mask: Optional[Tensor] = None
|
|
234
|
+
if (
|
|
235
|
+
getattr(self.cfg, "cross_self_attention", False)
|
|
236
|
+
or prev_output_tokens.eq(self.padding_idx).any()
|
|
237
|
+
):
|
|
238
|
+
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) # [B,T]
|
|
239
|
+
|
|
240
|
+
attn: Optional[Tensor] = None
|
|
241
|
+
inner_states: List[Optional[Tensor]] = [x]
|
|
242
|
+
|
|
243
|
+
for idx, layer in enumerate(self.layers):
|
|
244
|
+
# Causal mask unless full-context alignment or incremental decoding
|
|
245
|
+
if incremental_state is None and not full_context_alignment:
|
|
246
|
+
Tq = x.size(0)
|
|
247
|
+
self_attn_mask = self.buffered_future_mask(
|
|
248
|
+
Tq, Tq, x=x
|
|
249
|
+
) # [Tq,Tq] additive float
|
|
250
|
+
else:
|
|
251
|
+
self_attn_mask = None
|
|
252
|
+
|
|
253
|
+
x, layer_attn, _ = layer(
|
|
254
|
+
x,
|
|
255
|
+
enc,
|
|
256
|
+
padding_mask,
|
|
257
|
+
incremental_state,
|
|
258
|
+
self_attn_mask=self_attn_mask,
|
|
259
|
+
self_attn_padding_mask=self_attn_padding_mask,
|
|
260
|
+
need_attn=bool(idx == alignment_layer),
|
|
261
|
+
need_head_weights=bool(idx == alignment_layer),
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
inner_states.append(x)
|
|
265
|
+
if layer_attn is not None and idx == alignment_layer:
|
|
266
|
+
attn = layer_attn.float().to(x)
|
|
267
|
+
|
|
268
|
+
# Average heads if needed
|
|
269
|
+
if attn is not None and alignment_heads is not None:
|
|
270
|
+
attn = attn[:alignment_heads]
|
|
271
|
+
if attn is not None:
|
|
272
|
+
attn = attn.mean(dim=0) # [B,T,S]
|
|
273
|
+
|
|
274
|
+
# Optional final layer norm
|
|
275
|
+
if self.layer_norm is not None:
|
|
276
|
+
x = self.layer_norm(x)
|
|
277
|
+
|
|
278
|
+
# T x B x C -> B x T x C
|
|
279
|
+
x = x.transpose(0, 1)
|
|
280
|
+
|
|
281
|
+
# Optional proj-out
|
|
282
|
+
if self.project_out_dim is not None:
|
|
283
|
+
assert self.project_out_dim is not None
|
|
284
|
+
x = self.project_out_dim(x)
|
|
285
|
+
|
|
286
|
+
return x, {"attn": [attn], "inner_states": inner_states}
|
|
287
|
+
|
|
288
|
+
def output_layer(self, features: Tensor) -> Tensor:
|
|
289
|
+
"""Project features to vocabulary size (or return features with adaptive softmax)."""
|
|
290
|
+
if self.adaptive_softmax is None:
|
|
291
|
+
assert self.output_projection is not None
|
|
292
|
+
return self.output_projection(features) # type: ignore[operator]
|
|
293
|
+
else:
|
|
294
|
+
return features
|
|
295
|
+
|
|
296
|
+
def buffered_future_mask(
|
|
297
|
+
self, Tq: int, Ts: int, *, x: torch.Tensor
|
|
298
|
+
) -> torch.Tensor:
|
|
299
|
+
"""
|
|
300
|
+
Return additive float mask [Tq, Ts]: zeros on allowed, large-neg on disallowed.
|
|
301
|
+
Uses the prebuilt template; asserts that Tq and Ts do not exceed the template size (no re-build).
|
|
302
|
+
"""
|
|
303
|
+
assert isinstance(self.causal_mask_template, torch.Tensor)
|
|
304
|
+
Mmax = self.causal_mask_template.size(-1)
|
|
305
|
+
assert Tq <= Mmax and Ts <= Mmax
|
|
306
|
+
cm = self.causal_mask_template[..., :Tq, :Ts].to(device=x.device, dtype=x.dtype)
|
|
307
|
+
return cm.squeeze(0).squeeze(0) # [Tq, Ts]
|
|
308
|
+
|
|
309
|
+
def max_positions(self) -> int:
|
|
310
|
+
"""Maximum output length supported by the decoder (same policy as the original)."""
|
|
311
|
+
if self.embed_positions is None:
|
|
312
|
+
return self.max_target_positions
|
|
313
|
+
return min(self.max_target_positions, self.embed_positions.max_positions)
|
|
314
|
+
|
|
315
|
+
def get_normalized_probs(
|
|
316
|
+
self,
|
|
317
|
+
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
|
|
318
|
+
log_probs: bool,
|
|
319
|
+
sample: Optional[Dict[str, Tensor]] = None,
|
|
320
|
+
):
|
|
321
|
+
"""Get normalized probabilities (or log probs) from a net's output."""
|
|
322
|
+
return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
|
|
323
|
+
|
|
324
|
+
def get_normalized_probs_scriptable(
|
|
325
|
+
self,
|
|
326
|
+
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
|
|
327
|
+
log_probs: bool,
|
|
328
|
+
sample: Optional[Dict[str, Tensor]] = None,
|
|
329
|
+
):
|
|
330
|
+
"""Get normalized probabilities (or log probs) from a net's output."""
|
|
331
|
+
|
|
332
|
+
if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
|
|
333
|
+
if sample is not None:
|
|
334
|
+
assert "target" in sample
|
|
335
|
+
target = sample["target"]
|
|
336
|
+
else:
|
|
337
|
+
target = None
|
|
338
|
+
out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
|
|
339
|
+
return out.exp_() if not log_probs else out
|
|
340
|
+
|
|
341
|
+
logits = net_output[0]
|
|
342
|
+
if log_probs:
|
|
343
|
+
return F.log_softmax(logits, dim=-1, dtype=torch.float32)
|
|
344
|
+
else:
|
|
345
|
+
return F.softmax(logits, dim=-1, dtype=torch.float32)
|
|
346
|
+
|
|
347
|
+
def reorder_incremental_state_scripting(
|
|
348
|
+
self,
|
|
349
|
+
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
|
350
|
+
new_order: Tensor,
|
|
351
|
+
):
|
|
352
|
+
"""Main entry point for reordering the incremental state.
|
|
353
|
+
|
|
354
|
+
Due to limitations in TorchScript, we call this function in
|
|
355
|
+
:class:`fairseq.sequence_generator.SequenceGenerator` instead of
|
|
356
|
+
calling :func:`reorder_incremental_state` directly.
|
|
357
|
+
"""
|
|
358
|
+
for module in self.modules():
|
|
359
|
+
if hasattr(module, "reorder_incremental_state"):
|
|
360
|
+
result = module.reorder_incremental_state(incremental_state, new_order) # type: ignore[operator]
|
|
361
|
+
if result is not None:
|
|
362
|
+
incremental_state = result
|
|
363
|
+
|
|
364
|
+
def forward_external_step(
|
|
365
|
+
self,
|
|
366
|
+
prev_output_x: Tensor, # [1, B, C]
|
|
367
|
+
*,
|
|
368
|
+
encoder_out_x: Tensor, # [S, B, Ce]
|
|
369
|
+
encoder_padding_mask: Optional[
|
|
370
|
+
Tensor
|
|
371
|
+
] = None, # [B,S] or [B,1,S] additive-float
|
|
372
|
+
self_attn_mask: Optional[
|
|
373
|
+
Tensor
|
|
374
|
+
] = None, # [1,S_hist+1] or [B,1,S_hist+1] additive-float
|
|
375
|
+
prev_self_k_list: Optional[
|
|
376
|
+
List[Tensor]
|
|
377
|
+
] = None, # length=L; each [B,H,Tprev,Dh]
|
|
378
|
+
prev_self_v_list: Optional[
|
|
379
|
+
List[Tensor]
|
|
380
|
+
] = None, # length=L; each [B,H,Tprev,Dh]
|
|
381
|
+
need_attn: bool = False,
|
|
382
|
+
need_head_weights: bool = False,
|
|
383
|
+
) -> Tuple[Tensor, List[Tensor], List[Tensor]]:
|
|
384
|
+
"""
|
|
385
|
+
Export-only single-step decoder.
|
|
386
|
+
Returns:
|
|
387
|
+
- x_out: [1, B, C]
|
|
388
|
+
- new_self_k_list/new_self_v_list: lists of length L; each [B*H, Tnew, Dh]
|
|
389
|
+
"""
|
|
390
|
+
assert (
|
|
391
|
+
prev_output_x.dim() == 3 and prev_output_x.size(0) == 1
|
|
392
|
+
), "prev_output_x must be [1,B,C]"
|
|
393
|
+
L = self.num_layers
|
|
394
|
+
if prev_self_k_list is None:
|
|
395
|
+
prev_self_k_list = [None] * L # type: ignore[list-item]
|
|
396
|
+
if prev_self_v_list is None:
|
|
397
|
+
prev_self_v_list = [None] * L # type: ignore[list-item]
|
|
398
|
+
assert len(prev_self_k_list) == L and len(prev_self_v_list) == L
|
|
399
|
+
|
|
400
|
+
assert encoder_out_x.dim() == 3, "encoder_out_x must be [S,B,C]"
|
|
401
|
+
x = prev_output_x # [1,B,C]
|
|
402
|
+
enc = encoder_out_x
|
|
403
|
+
|
|
404
|
+
new_k_list: List[Tensor] = []
|
|
405
|
+
new_v_list: List[Tensor] = []
|
|
406
|
+
|
|
407
|
+
for li, layer in enumerate(self.layers):
|
|
408
|
+
assert isinstance(layer, PTQWrapper)
|
|
409
|
+
x, _, k_new, v_new = layer.wrapped.forward_external( # type: ignore[attr-defined, operator]
|
|
410
|
+
x,
|
|
411
|
+
encoder_out=enc,
|
|
412
|
+
encoder_padding_mask=encoder_padding_mask,
|
|
413
|
+
prev_self_k=prev_self_k_list[li],
|
|
414
|
+
prev_self_v=prev_self_v_list[li],
|
|
415
|
+
self_attn_mask=self_attn_mask,
|
|
416
|
+
need_attn=need_attn and (li == L - 1),
|
|
417
|
+
need_head_weights=need_head_weights and (li == L - 1),
|
|
418
|
+
)
|
|
419
|
+
new_k_list.append(k_new) # [B*H, Tnew, Dh]
|
|
420
|
+
new_v_list.append(v_new) # [B*H, Tnew, Dh]
|
|
421
|
+
|
|
422
|
+
if self.layer_norm is not None:
|
|
423
|
+
x = self.layer_norm(x.transpose(0, 1)).transpose(0, 1)
|
|
424
|
+
|
|
425
|
+
return x, new_k_list, new_v_list # [1,B,C], lists of [B*H, Tnew, Dh]
|
|
426
|
+
|
|
427
|
+
def _all_observers(self) -> Iterable:
|
|
428
|
+
"""Yield all observers from wrapped decoder layers (if any)."""
|
|
429
|
+
for m in self.layers:
|
|
430
|
+
if isinstance(m, QuantModuleBase):
|
|
431
|
+
yield from m._all_observers()
|
|
@@ -22,7 +22,7 @@ from tico.experimental.quantization.algorithm.pt2e.quantizer import PT2EQuantize
|
|
|
22
22
|
from tico.experimental.quantization.algorithm.smoothquant.quantizer import (
|
|
23
23
|
SmoothQuantQuantizer,
|
|
24
24
|
)
|
|
25
|
-
from tico.experimental.quantization.config import BaseConfig
|
|
25
|
+
from tico.experimental.quantization.config.base import BaseConfig
|
|
26
26
|
from tico.experimental.quantization.quantizer import BaseQuantizer
|
|
27
27
|
|
|
28
28
|
|
|
@@ -20,11 +20,13 @@ import torch
|
|
|
20
20
|
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
|
|
21
21
|
from torch.export import ExportedProgram
|
|
22
22
|
|
|
23
|
+
from tico.serialize.circle_mapping import extract_shape
|
|
24
|
+
|
|
23
25
|
from tico.utils import logging
|
|
24
26
|
from tico.utils.graph import create_node
|
|
25
27
|
from tico.utils.passes import PassBase, PassResult
|
|
26
28
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
|
27
|
-
from tico.utils.validate_args_kwargs import MatmulArgs
|
|
29
|
+
from tico.utils.validate_args_kwargs import BmmArgs, MatmulArgs
|
|
28
30
|
|
|
29
31
|
|
|
30
32
|
class Converter: # type: ignore[empty-body]
|
|
@@ -57,14 +59,14 @@ class MatmulToLinearConverter(Converter):
|
|
|
57
59
|
torch.ops.aten.permute.default,
|
|
58
60
|
args=(rhs, [1, 0]),
|
|
59
61
|
)
|
|
60
|
-
|
|
62
|
+
linear_node = create_node(
|
|
61
63
|
graph,
|
|
62
64
|
torch.ops.aten.linear.default,
|
|
63
65
|
args=(lhs, transpose_node),
|
|
64
66
|
)
|
|
65
|
-
node.replace_all_uses_with(
|
|
67
|
+
node.replace_all_uses_with(linear_node, propagate_meta=True)
|
|
66
68
|
|
|
67
|
-
return
|
|
69
|
+
return linear_node
|
|
68
70
|
|
|
69
71
|
|
|
70
72
|
class RhsConstMatmulToLinearConverter(MatmulToLinearConverter):
|
|
@@ -110,18 +112,125 @@ class LhsConstMatmulToLinearConverter(MatmulToLinearConverter):
|
|
|
110
112
|
return True
|
|
111
113
|
elif is_buffer(exported_program, lhs):
|
|
112
114
|
return True
|
|
113
|
-
else:
|
|
114
|
-
return False
|
|
115
115
|
return False
|
|
116
116
|
|
|
117
117
|
def convert(self, exported_program, node) -> torch.fx.Node:
|
|
118
118
|
return super().convert(exported_program, node)
|
|
119
119
|
|
|
120
120
|
|
|
121
|
+
class SingleBatchLhsConstBmmToLinearConverter(Converter):
    """
    Convert `single-batched & lhs-const BatchMatMul` to `linear` operation.

    `linear(x, w)` computes `x @ w^T`, so `lhs @ rhs` can be rewritten as
    `linear(rhs^T, lhs)^T`:

    [1] exchange lhs and rhs
    [2] transpose rhs
    [3] transpose output

    **Before**

        lhs[1,a,b](const)      rhs[1,b,c]
               |                   |
               |                   |
               ---------bmm---------
                         |
                    output[1,a,c]

    **After**

          rhs[1,b,c]
              |
              tr           lhs'[a,b](const-folded)
              |[1,c,b]          |
              |                 |
              ---------fc--------
                       |[1,c,a]
                       tr
                       |
                  output[1,a,c]
    """

    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:
        """
        Return True when `node` is a bmm whose operands both have batch size 1
        and whose lhs is a lifted tensor constant, a parameter or a buffer.
        """
        if not node.target == torch.ops.aten.bmm.default:
            return False

        bmm_args = BmmArgs(*node.args, **node.kwargs)
        lhs = bmm_args.input
        rhs = bmm_args.mat2

        # Check the operand type first: the shape lookup and the constant
        # checks below both treat `lhs` as a graph node.
        if not isinstance(lhs, torch.fx.Node):
            return False

        # [1] Single-batch
        lhs_shape = extract_shape(lhs)
        rhs_shape = extract_shape(rhs)

        assert len(lhs_shape) == len(
            rhs_shape
        ), f"Bmm input's ranks must be the same but got {lhs_shape} and {rhs_shape}"

        if not (lhs_shape[0] == rhs_shape[0] == 1):
            return False

        # [2] Lhs is constant
        if not (
            is_lifted_tensor_constant(exported_program, lhs)
            or is_param(exported_program, lhs)
            or is_buffer(exported_program, lhs)
        ):
            return False

        return True

    def convert(self, exported_program, node) -> torch.fx.Node:
        """
        Replace the matched bmm `node` with permute -> linear -> permute.

        Returns the final permute node, which takes over all uses of `node`.
        """
        graph_module = exported_program.graph_module
        graph = graph_module.graph

        bmm_args = BmmArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        lhs = bmm_args.input  # const
        rhs = bmm_args.mat2  # non-const
        lhs_shape = extract_shape(lhs)
        rhs_shape = extract_shape(rhs)
        assert rhs_shape[0] == 1
        assert lhs_shape[0] == 1

        with graph.inserting_before(node):
            # rhs: [1, b, c] -> [1, c, b]; it becomes the linear input.
            rhs_tr = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(rhs, [0, 2, 1]),
            )
            # lhs: [1, a, b] -> [a, b]; drop the batch axis so it can act as
            # the 2-D linear weight.
            lhs_reshape = create_node(
                graph,
                torch.ops.aten.view.default,
                args=(lhs, list(lhs_shape[1:])),
            )

            # linear(rhs^T, lhs') = rhs^T @ lhs'^T -> [1, c, a]
            linear_node = create_node(
                graph,
                torch.ops.aten.linear.default,
                args=(rhs_tr, lhs_reshape),
            )

            # Transpose back to the original bmm layout [1, a, c].
            tr_linear_node = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(linear_node, [0, 2, 1]),
            )

            node.replace_all_uses_with(tr_linear_node, propagate_meta=False)

        return tr_linear_node
|
|
228
|
+
|
|
229
|
+
|
|
121
230
|
@trace_graph_diff_on_pass
|
|
122
231
|
class ConvertMatmulToLinear(PassBase):
|
|
123
232
|
"""
|
|
124
|
-
This pass converts matmul to linear selectively
|
|
233
|
+
This pass converts matmul(partially includes single-batch bmm) to linear selectively
|
|
125
234
|
|
|
126
235
|
How to select between `matmul` and `linear`?
|
|
127
236
|
|
|
@@ -164,6 +273,7 @@ class ConvertMatmulToLinear(PassBase):
|
|
|
164
273
|
self,
|
|
165
274
|
enable_lhs_const: Optional[bool] = False,
|
|
166
275
|
enable_rhs_const: Optional[bool] = True,
|
|
276
|
+
enable_single_batch_lhs_const_bmm: Optional[bool] = False,
|
|
167
277
|
):
|
|
168
278
|
super().__init__()
|
|
169
279
|
self.converters: List[Converter] = []
|
|
@@ -171,6 +281,8 @@ class ConvertMatmulToLinear(PassBase):
|
|
|
171
281
|
self.converters.append(LhsConstMatmulToLinearConverter())
|
|
172
282
|
if enable_rhs_const:
|
|
173
283
|
self.converters.append(RhsConstMatmulToLinearConverter())
|
|
284
|
+
if enable_single_batch_lhs_const_bmm:
|
|
285
|
+
self.converters.append(SingleBatchLhsConstBmmToLinearConverter())
|
|
174
286
|
|
|
175
287
|
def call(self, exported_program: ExportedProgram) -> PassResult:
|
|
176
288
|
logger = logging.getLogger(__name__)
|
|
@@ -28,6 +28,42 @@ from tico.utils.errors import InvalidArgumentError
|
|
|
28
28
|
from tico.utils.validate_args_kwargs import ConstantPadNdArgs
|
|
29
29
|
|
|
30
30
|
|
|
31
|
+
def convert_to_circle_padding(pad, input_shape_len):
    """
    Convert a flat `pad` sequence into a per-dimension padding table.

    `pad` lists `(before, after)` amounts starting from the LAST dimension and
    moving backwards, i.e. `[last_before, last_after, second_last_before, ...]`.
    The result lists one `[before, after]` pair per input dimension, starting
    from the FIRST dimension. Dimensions not covered by `pad` get `[0, 0]`.

    Args:
        pad: Flat sequence holding 2, 4, 6 or 8 padding amounts.
        input_shape_len: Rank of the input being padded (1 to 4).

    Returns:
        A list of `input_shape_len` two-element lists.

    Raises:
        InvalidArgumentError: If the rank is outside 1..4, if the pad length is
            not an even number between 2 and 8, or if `pad` describes more
            dimensions than the input has.
    """
    MAX_RANK = 4

    if not (1 <= input_shape_len <= MAX_RANK):
        raise InvalidArgumentError(
            f"Input rank must be between 1 and {MAX_RANK}, got {input_shape_len}"
        )

    if len(pad) % 2 != 0 or len(pad) < 2 or len(pad) > 8:
        raise InvalidArgumentError(
            f"Pad length must be an even number between 2 and 8, got {len(pad)}"
        )

    num_padded_dims = len(pad) // 2
    # Without this check the table would have more rows than the input has
    # dimensions, which cannot describe a valid padding for that input.
    if num_padded_dims > input_shape_len:
        raise InvalidArgumentError(
            f"Pad describes {num_padded_dims} dimensions but input rank is {input_shape_len}"
        )

    # Fill [0, 0] padding for the leading dimensions `pad` does not mention.
    # Each row is a fresh list so rows never alias one another.
    padding = [[0, 0] for _ in range(input_shape_len - num_padded_dims)]

    # Walk the pairs from the end of `pad` to its start: the last pair in `pad`
    # belongs to the outermost padded dimension.
    padding.extend([pad[i], pad[i + 1]] for i in range(len(pad) - 2, -1, -2))

    return padding
|
|
65
|
+
|
|
66
|
+
|
|
31
67
|
@register_node_visitor
|
|
32
68
|
class ConstantPadNdVisitor(NodeVisitor):
|
|
33
69
|
target: List[torch._ops.OpOverload] = [torch.ops.aten.constant_pad_nd.default]
|
|
@@ -45,19 +81,13 @@ class ConstantPadNdVisitor(NodeVisitor):
|
|
|
45
81
|
val = args.value
|
|
46
82
|
|
|
47
83
|
if val != 0:
|
|
48
|
-
raise InvalidArgumentError("Only support 0 value padding.")
|
|
84
|
+
raise InvalidArgumentError(f"Only support 0 value padding. pad:{pad}")
|
|
49
85
|
|
|
50
86
|
input_shape_len = len(extract_shape(input_))
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
padding_size = [[0, 0], [0, 0]] + padding_size
|
|
56
|
-
else:
|
|
57
|
-
raise InvalidArgumentError("Only support 3D/4D inputs.")
|
|
58
|
-
|
|
59
|
-
paddings = torch.tensor(padding_size, dtype=torch.int32)
|
|
60
|
-
inputs = [input_, paddings]
|
|
87
|
+
|
|
88
|
+
padding = convert_to_circle_padding(pad, input_shape_len)
|
|
89
|
+
|
|
90
|
+
inputs = [input_, torch.tensor(padding, dtype=torch.int32)]
|
|
61
91
|
outputs = [node]
|
|
62
92
|
|
|
63
93
|
op_index = get_op_index(
|
tico/utils/convert.py
CHANGED
|
@@ -253,6 +253,9 @@ def convert_exported_module_to_circle(
|
|
|
253
253
|
ConvertMatmulToLinear(
|
|
254
254
|
enable_lhs_const=config.get("convert_lhs_const_mm_to_fc"),
|
|
255
255
|
enable_rhs_const=config.get("convert_rhs_const_mm_to_fc"),
|
|
256
|
+
enable_single_batch_lhs_const_bmm=config.get(
|
|
257
|
+
"convert_single_batch_lhs_const_bmm_to_fc"
|
|
258
|
+
),
|
|
256
259
|
),
|
|
257
260
|
LowerToResizeNearestNeighbor(),
|
|
258
261
|
LegalizePreDefinedLayoutOperators(),
|
|
@@ -1,19 +1,18 @@
|
|
|
1
|
-
tico/__init__.py,sha256=
|
|
1
|
+
tico/__init__.py,sha256=aXzPnAgp_3hFd-ia92oDhfjfZ1NABlYkgUlEbFs5Pb0,1883
|
|
2
2
|
tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
|
|
3
3
|
tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
|
|
4
4
|
tico/config/base.py,sha256=q5xMqGxTUZs4mFqt5c7i_y9U00fYgdMGl9nUqIVMlCo,1248
|
|
5
5
|
tico/config/factory.py,sha256=il0zqB6Lm5NX2LnG-TUhmiP9vVeZ_3TucJMorVZIodY,1324
|
|
6
|
-
tico/config/v1.py,sha256=
|
|
6
|
+
tico/config/v1.py,sha256=lEyKemeKGrJ0bA5w-LPkMWVlnAiJRDm9mM48TJle-e4,1296
|
|
7
7
|
tico/experimental/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
|
8
8
|
tico/experimental/quantization/__init__.py,sha256=IaJPZegVJp0P3luutBo907Kp5sOJensE1Mm-XBG_jBs,122
|
|
9
|
-
tico/experimental/quantization/
|
|
10
|
-
tico/experimental/quantization/
|
|
11
|
-
tico/experimental/quantization/quantizer.py,sha256=_2pDtWFKDCuKfYF2bptOwIYsa0VFNFM1ZNgi8_OGvHM,2365
|
|
9
|
+
tico/experimental/quantization/public_interface.py,sha256=y-iwaeuedBvHwTh5hflQg4u2ZCdqf46IlTl9ntHq8pU,4425
|
|
10
|
+
tico/experimental/quantization/quantizer.py,sha256=pDTQGzR-BcQJeGZ7O4cXRQdCme4q_POpxHetwnv0bYg,2370
|
|
12
11
|
tico/experimental/quantization/algorithm/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
|
13
12
|
tico/experimental/quantization/algorithm/gptq/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
|
14
13
|
tico/experimental/quantization/algorithm/gptq/gptq.py,sha256=Qn9b_2ki7B64DcVEY25NMkww3PdZ5EqYQQXfYhNDQ6I,5555
|
|
15
14
|
tico/experimental/quantization/algorithm/gptq/quant.py,sha256=Rl4wAOCmlE0U09BtNCDbccaSNohRHCNLwFi3zCqZfNo,5127
|
|
16
|
-
tico/experimental/quantization/algorithm/gptq/quantizer.py,sha256=
|
|
15
|
+
tico/experimental/quantization/algorithm/gptq/quantizer.py,sha256=ZKeQQWm6eMUyRgntQxVR-QVjxJOc2pW4Dc_mrEPZA64,11686
|
|
17
16
|
tico/experimental/quantization/algorithm/gptq/utils.py,sha256=leGKayf-xbSjVwwAGTA5RsxUKrhDiklOQdlsLifjdrs,1811
|
|
18
17
|
tico/experimental/quantization/algorithm/pt2e/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
|
19
18
|
tico/experimental/quantization/algorithm/pt2e/quantizer.py,sha256=mdTvsG87bo8fu0GaWqSM8iBCs-4f4EfUlVtk-Ko6M34,2546
|
|
@@ -38,8 +37,13 @@ tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py,sha256=
|
|
|
38
37
|
tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py,sha256=Idtoya2RcGKlgUJgC9WqNz0jH3gf6ViuPmsD9ySHbls,2253
|
|
39
38
|
tico/experimental/quantization/algorithm/smoothquant/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
|
40
39
|
tico/experimental/quantization/algorithm/smoothquant/observer.py,sha256=OWBKQ3ox6PqeqgevxOjpXvb7uApoqE4YbUBelGhVSN8,3435
|
|
41
|
-
tico/experimental/quantization/algorithm/smoothquant/quantizer.py,sha256=
|
|
40
|
+
tico/experimental/quantization/algorithm/smoothquant/quantizer.py,sha256=14-QrKAW-Rw6pIbbNaD5eORcH2fqi40-TNFGaWVakIg,3649
|
|
42
41
|
tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py,sha256=fxCy4m-BsSjraciSVPFlPhgsOT46RjrOgczQGb7B9TA,11561
|
|
42
|
+
tico/experimental/quantization/config/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
|
43
|
+
tico/experimental/quantization/config/base.py,sha256=xg_HCDSuMgYvMd6ENZe4Sm2SYJgMaCBj4cmqaz_lhAs,816
|
|
44
|
+
tico/experimental/quantization/config/gptq.py,sha256=IUIEz5bLhsTXqoBCE1rfPec99zsRjwgpDbPW5YJqOPg,973
|
|
45
|
+
tico/experimental/quantization/config/pt2e.py,sha256=9HCrraTGGZeKEN9puKV-ODi7ncV2Wjc3oe_JCO1D_Rs,850
|
|
46
|
+
tico/experimental/quantization/config/smoothquant.py,sha256=fcyhu3YlOTM7fDW9lGTXh-uJOUD6CeykZj7AMCNVbak,1415
|
|
43
47
|
tico/experimental/quantization/evaluation/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
|
44
48
|
tico/experimental/quantization/evaluation/backend.py,sha256=CZL9rZOA0t8cH7PHp6u9l7dGqWNvTj9bKOvwo0PVul0,692
|
|
45
49
|
tico/experimental/quantization/evaluation/evaluate.py,sha256=kfa_GvFaX6DoSTAmuCImMJqF2jgqtnor5UpC7wVmGPI,7877
|
|
@@ -68,7 +72,7 @@ tico/experimental/quantization/ptq/examples/quantize_linear.py,sha256=8zq-ZJDYga
|
|
|
68
72
|
tico/experimental/quantization/ptq/examples/quantize_llama_attn.py,sha256=cVWUSSzaZWFp5QZkNkrlpHU3kXyP84QtnZbahVml_yQ,4329
|
|
69
73
|
tico/experimental/quantization/ptq/examples/quantize_llama_decoder_layer.py,sha256=mBWrjkyEovYQsPC4Rrsri6Pm1rlFmDb3NiP0DQQhFyM,5751
|
|
70
74
|
tico/experimental/quantization/ptq/examples/quantize_llama_mlp.py,sha256=N1qZQgt1S-xZrdv-PW7OfXEcv0gsO2q9faOF4aD-zKo,4147
|
|
71
|
-
tico/experimental/quantization/ptq/examples/quantize_with_gptq.py,sha256=
|
|
75
|
+
tico/experimental/quantization/ptq/examples/quantize_with_gptq.py,sha256=y-SK56j4wL-9j-0jtuOqQUq4CElZtGOETp-Tg4XivUI,10438
|
|
72
76
|
tico/experimental/quantization/ptq/observers/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
|
73
77
|
tico/experimental/quantization/ptq/observers/affine_base.py,sha256=e2Eba64nrxKQyE4F_WJ7WTSsk3xe6bkdGUKaoLFWGFw,4638
|
|
74
78
|
tico/experimental/quantization/ptq/observers/base.py,sha256=Wons1MzpqK1mfcy-ppl-B2Dum0edXg2dWW2Lw3V18tw,3280
|
|
@@ -86,6 +90,7 @@ tico/experimental/quantization/ptq/wrappers/quant_elementwise.py,sha256=LhEoobfv
|
|
|
86
90
|
tico/experimental/quantization/ptq/wrappers/quant_module_base.py,sha256=vkcDos_knGSS29rIZuEIWkAJLHrENbGz8nCH2-iara8,5969
|
|
87
91
|
tico/experimental/quantization/ptq/wrappers/registry.py,sha256=OVO5nev6J8Br9zsIX-Ut7ZgWzA9f_jk0Np9bGioXgQM,5171
|
|
88
92
|
tico/experimental/quantization/ptq/wrappers/fairseq/__init__.py,sha256=Mc8FLd9DusyB_IT1vk1OYrRkngOYnYd05IvtA9ORVQc,160
|
|
93
|
+
tico/experimental/quantization/ptq/wrappers/fairseq/quant_decoder.py,sha256=CILYvxPhW7xLkroWW_hunQBGAYGexLqnPnO5xmMnK-E,17877
|
|
89
94
|
tico/experimental/quantization/ptq/wrappers/fairseq/quant_decoder_layer.py,sha256=JT79shxOhDtRFgm8jrrN6HKvyVotiytLjMjAxX-Cztg,20416
|
|
90
95
|
tico/experimental/quantization/ptq/wrappers/fairseq/quant_encoder.py,sha256=r9DPUAbL2KRJ8zpMJ39Y9n6Oe79nte-mFcdjG2qEP-w,13809
|
|
91
96
|
tico/experimental/quantization/ptq/wrappers/fairseq/quant_encoder_layer.py,sha256=aGr80Ku75j2H-UZ0elEa0mOQEyaAs2YJ4WJCN0lonn0,6412
|
|
@@ -108,7 +113,7 @@ tico/passes/cast_mixed_type_args.py,sha256=Wd3sCDKJZwdb8GiMWKljm8X5CLFRd8eCz-dmW
|
|
|
108
113
|
tico/passes/const_prop_pass.py,sha256=hDxGgJNiRjsgOArdaoeAOcOOA-nKBvA1W1zcMZQA5yg,11531
|
|
109
114
|
tico/passes/convert_conv1d_to_conv2d.py,sha256=ktS3h158y9rg1sQiW8BZZbflV_dk_UdjBPQnuiOKyzg,5303
|
|
110
115
|
tico/passes/convert_layout_op_to_reshape.py,sha256=sCAFjkmVtiKjvDQSAgnjNBHl3_hWXJZElGDXQiTH-7s,2963
|
|
111
|
-
tico/passes/convert_matmul_to_linear.py,sha256=
|
|
116
|
+
tico/passes/convert_matmul_to_linear.py,sha256=WATtsHk_GzsU0HYovc3UMyEj8ApF2qLbInAsNlQj0nE,9759
|
|
112
117
|
tico/passes/convert_repeat_to_expand_copy.py,sha256=JbtFTmWyfJS2SSd_higP1IEhQeh7wHdN5dmTbbiFVCs,3237
|
|
113
118
|
tico/passes/convert_to_relu6.py,sha256=9B6OLyF72tMvD-ugV7aBx6l1szwERufNBUaX34pkZ4c,6445
|
|
114
119
|
tico/passes/decompose_addmm.py,sha256=KjnpZjSuA0uvNmKaTN_EMwobcOi3CAB81buORzTDxro,3979
|
|
@@ -158,7 +163,7 @@ tico/serialize/operators/op_bmm.py,sha256=AELjHC9ISFPIzEEl5Kr1s4GSNLZElwZmVZJWkE
|
|
|
158
163
|
tico/serialize/operators/op_cat.py,sha256=XDYOh0XAyrM0TlxVm6Sa0OFFGrKk7aSDcGXC-hYX4gs,2204
|
|
159
164
|
tico/serialize/operators/op_clamp.py,sha256=RRQVrzayDfN3PioCVJqa_yYOtcYwb5HHwkMe4E_YPmE,4408
|
|
160
165
|
tico/serialize/operators/op_clone.py,sha256=vzDYJ8TS3tc2BAyd_z8nt5VqT1inpymSseMEhd9dva0,2394
|
|
161
|
-
tico/serialize/operators/op_constant_pad_nd.py,sha256=
|
|
166
|
+
tico/serialize/operators/op_constant_pad_nd.py,sha256=nGWqYWNbj2E9ChQuoHsN-d8AO7UyVexnPil7qTqWZp8,3444
|
|
162
167
|
tico/serialize/operators/op_conv2d.py,sha256=1_vouWXaF51gDLYg8z5Zlup0Tecq_ggAzvguiHzFffw,6828
|
|
163
168
|
tico/serialize/operators/op_copy.py,sha256=boXHfl0bcvdBVl0tpzPMA_KBonh80vVqv61N3H5-PRU,6941
|
|
164
169
|
tico/serialize/operators/op_cos.py,sha256=N12bNyuTQIxRnD0eHRPdFVzRQPMy1NFM4iM8oQ4lYzw,2034
|
|
@@ -230,7 +235,7 @@ tico/serialize/operators/utils.py,sha256=lXGpEJW1h8U_-gfc6EWjvvSiq3yJ9P-v1v3EMRT
|
|
|
230
235
|
tico/serialize/operators/adapters/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
|
231
236
|
tico/serialize/operators/adapters/llama_rmsnorm.py,sha256=6t3dhfNpR03eIjsmhymF2JKd6lCf7PvInqMf77c_BOE,1139
|
|
232
237
|
tico/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
|
233
|
-
tico/utils/convert.py,sha256=
|
|
238
|
+
tico/utils/convert.py,sha256=XbogVXO-QS0UTFNvEDyADvhCp87kTUpGAUalN8I8eRQ,13645
|
|
234
239
|
tico/utils/define.py,sha256=Ypgp7YffM4pgPl4Zh6TmogSn1OxGBMRw_e09qYGflZk,1467
|
|
235
240
|
tico/utils/diff_graph.py,sha256=_eDGGPDPYQD4b--MXX0DLoVgSt_wLfNPt47UlolLLR4,5272
|
|
236
241
|
tico/utils/dtype.py,sha256=L5Qb7qgbt0eQ5frUTvHYrRtTJb1dg4-JNEopcxCNg1U,1389
|
|
@@ -254,9 +259,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
|
|
254
259
|
tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
|
|
255
260
|
tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
|
|
256
261
|
tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
|
|
257
|
-
tico-0.1.0.
|
|
258
|
-
tico-0.1.0.
|
|
259
|
-
tico-0.1.0.
|
|
260
|
-
tico-0.1.0.
|
|
261
|
-
tico-0.1.0.
|
|
262
|
-
tico-0.1.0.
|
|
262
|
+
tico-0.1.0.dev250922.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
|
|
263
|
+
tico-0.1.0.dev250922.dist-info/METADATA,sha256=2JnBgGh089dLyvlk3CyDQyTraHh_vDRRcPZla7pmuus,8450
|
|
264
|
+
tico-0.1.0.dev250922.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
|
|
265
|
+
tico-0.1.0.dev250922.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
|
|
266
|
+
tico-0.1.0.dev250922.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
|
|
267
|
+
tico-0.1.0.dev250922.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|