tico 0.1.0.dev250806__py3-none-any.whl → 0.1.0.dev250810__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/experimental/quantization/algorithm/gptq/quantizer.py +8 -1
- tico/experimental/quantization/config.py +2 -1
- tico/experimental/quantization/evaluation/metric.py +17 -0
- tico/experimental/quantization/ptq/__init__.py +13 -0
- tico/experimental/quantization/ptq/dtypes.py +70 -0
- tico/experimental/quantization/ptq/mode.py +32 -0
- tico/experimental/quantization/ptq/qscheme.py +40 -0
- tico/utils/dtype.py +22 -0
- tico/utils/signature.py +248 -0
- {tico-0.1.0.dev250806.dist-info → tico-0.1.0.dev250810.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250806.dist-info → tico-0.1.0.dev250810.dist-info}/RECORD +16 -11
- {tico-0.1.0.dev250806.dist-info → tico-0.1.0.dev250810.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250806.dist-info → tico-0.1.0.dev250810.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250806.dist-info → tico-0.1.0.dev250810.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250806.dist-info → tico-0.1.0.dev250810.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED

tico/experimental/quantization/algorithm/gptq/quantizer.py
CHANGED
```diff
@@ -183,7 +183,12 @@ class GPTQQuantizer(BaseQuantizer):
 
         quantizers: Dict[str, Any] = {}
         for l_idx, layer in enumerate(
-            tqdm(
+            tqdm(
+                target_layers,
+                desc="Quantizing layers",
+                unit="layer",
+                disable=not gptq_conf.show_progress,
+            )
         ):
             # 1) Identify quantizable submodules within the layer
             full = find_layers(layer)
@@ -218,6 +223,7 @@ class GPTQQuantizer(BaseQuantizer):
             desc=f"[L{l_idx}] collecting",
             leave=False,
             unit="batch",
+            disable=not gptq_conf.show_progress,
         ):
             cache_args_batch = gather_single_batch_from_list(
                 self.cache_args, batch_idx
@@ -251,6 +257,7 @@ class GPTQQuantizer(BaseQuantizer):
             desc=f"[L{l_idx}] re-forward",
             leave=False,
             unit="batch",
+            disable=not gptq_conf.show_progress,
         ):
             cache_args_batch = gather_single_batch_from_list(
                 self.cache_args, batch_idx
```
tico/experimental/quantization/config.py
CHANGED

```diff
@@ -42,8 +42,9 @@ class GPTQConfig(BaseConfig):
     Configuration for GPTQ.
     """
 
-    def __init__(self, verbose: bool = False):
+    def __init__(self, verbose: bool = False, show_progress: bool = True):
         self.verbose = verbose
+        self.show_progress = show_progress
 
     @property
     def name(self) -> str:
```
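Together with the quantizer changes above, this lets callers silence the per-layer and per-batch tqdm bars. A minimal sketch using only the constructor shown in this diff:

```python
from tico.experimental.quantization.config import GPTQConfig

# Default: progress bars enabled, verbose logging off.
cfg = GPTQConfig()
assert cfg.show_progress is True

# Silence the tqdm bars, e.g. to keep CI logs clean.
quiet_cfg = GPTQConfig(verbose=False, show_progress=False)
assert quiet_cfg.show_progress is False
```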
tico/experimental/quantization/evaluation/metric.py
CHANGED

```diff
@@ -42,6 +42,22 @@ def compute_peir(base: torch.Tensor, target: torch.Tensor) -> float:
     return peak_error / interval
 
 
+def mse(base: torch.Tensor, target: torch.Tensor) -> float:
+    """
+    Mean Squared Error (MSE).
+    Penalizes **larger** deviations more heavily than MAE by squaring each
+    difference — helpful to expose occasional large spikes.
+    Formula
+    -------
+    MSE = mean((base - target)²)
+    Returns
+    -------
+    float
+        Mean squared error. *Lower is better*.
+    """
+    return torch.mean((base.detach() - target.detach()) ** 2).item()
+
+
 class MetricCalculator:
     """
     Lightweight registry-and-dispatcher for **pair-wise tensor comparison metrics**.
@@ -83,6 +99,7 @@ class MetricCalculator:
         "diff": compute_max_abs_diff,
         "max_abs_diff": compute_max_abs_diff,
         "peir": compute_peir,
+        "mse": mse,
     }
 
     def __init__(
```
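Since the function is also registered under the "mse" key, MetricCalculator can dispatch it by name, and it can be called directly as well. A quick standalone check of the formula, with values chosen for illustration:

```python
import torch

from tico.experimental.quantization.evaluation.metric import mse

base = torch.tensor([0.0, 1.0, 2.0])
target = torch.tensor([0.0, 1.5, 2.0])

# Squared diffs are [0.0, 0.25, 0.0], so the mean is 0.25 / 3 ≈ 0.0833.
print(mse(base, target))
```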
tico/experimental/quantization/ptq/__init__.py
ADDED

```diff
@@ -0,0 +1,13 @@
+"""
+Public PTQ API — re-export the most common symbols.
+"""
+
+from tico.experimental.quantization.ptq.dtypes import DType
+from tico.experimental.quantization.ptq.mode import Mode
+from tico.experimental.quantization.ptq.qscheme import QScheme
+
+__all__ = [
+    "DType",
+    "Mode",
+    "QScheme",
+]
```
tico/experimental/quantization/ptq/dtypes.py
ADDED

```diff
@@ -0,0 +1,70 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class DType:
+    """
+    Self-contained integer dtypes for quantization.
+
+    A DType is just an immutable value-object with two fields:
+      - bits
+      - signed
+
+    Common presets (INT8, UINT4, ..) are provided as constants for convenience.
+    """
+
+    bits: int  # pylint: disable=used-before-assignment
+    signed: bool = False  # False -> unsigned
+
+    @property
+    def qmin(self) -> int:
+        assert self.bits is not None
+        if self.signed:
+            return -(1 << (self.bits - 1))
+        return 0
+
+    @property
+    def qmax(self) -> int:
+        assert self.bits is not None
+        if self.signed:
+            return (1 << (self.bits - 1)) - 1
+        return (1 << self.bits) - 1
+
+    def __str__(self) -> str:
+        prefix = "int" if self.signed else "uint"
+        return f"{prefix}{self.bits}"
+
+    # ────────────────────────────────
+    # Factory helpers
+    # ────────────────────────────────
+    @staticmethod
+    def int(bits: int):  # type: ignore[valid-type]
+        return DType(bits, signed=True)
+
+    @staticmethod
+    def uint(bits: int):  # type: ignore[valid-type]
+        return DType(bits, signed=False)
+
+
+# ---------------------------------------------------------------------
+# Convenient canned versions
+# ---------------------------------------------------------------------
+UINT4 = DType.uint(4)
+INT4 = DType.int(4)
+INT8 = DType.int(8)
+UINT8 = DType.uint(8)
+INT16 = DType.int(16)
```
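A short illustration of the value-object semantics, derived directly from the definitions above:

```python
from tico.experimental.quantization.ptq.dtypes import DType, INT8, UINT4

# Presets carry their integer range via qmin/qmax.
assert (INT8.qmin, INT8.qmax) == (-128, 127)
assert (UINT4.qmin, UINT4.qmax) == (0, 15)

# Arbitrary widths via the factory helpers; instances are frozen, so they
# compare by value and can be used as dict keys.
int6 = DType.int(6)
assert str(int6) == "int6"
assert int6 == DType(6, signed=True)
```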
tico/experimental/quantization/ptq/mode.py
ADDED

```diff
@@ -0,0 +1,32 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import auto, Enum
+
+
+class Mode(Enum):
+    """
+    Mode — global FSM for PTQWrapper & Handlers.
+
+    • NO_QUANT : pure pass-through (no stats, no fake-quant)
+    • CALIB    : collect observer statistics only
+    • QUANT    : use cached (scale, zero-point) → fake-quant enabled
+    """
+
+    NO_QUANT = auto()
+    CALIB = auto()
+    QUANT = auto()
+
+    def __str__(self) -> str:
+        return self.name.lower()
```
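A calibration flow would step through these states in order. A minimal illustration of the enum itself (the PTQWrapper and handler plumbing referenced in the docstring is not part of this diff):

```python
from tico.experimental.quantization.ptq import Mode

# Collect observer statistics first, then switch to fake-quant.
phase = Mode.CALIB
assert str(phase) == "calib"

phase = Mode.QUANT
assert str(phase) == "quant"
```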
tico/experimental/quantization/ptq/qscheme.py
ADDED

```diff
@@ -0,0 +1,40 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import auto, Enum
+
+
+class QScheme(Enum):
+    # ───── Per-tensor ────────────
+    PER_TENSOR_ASYMM = auto()
+    PER_TENSOR_SYMM = auto()
+    # ───── Per-channel ───────────
+    PER_CHANNEL_ASYMM = auto()
+    PER_CHANNEL_SYMM = auto()
+
+    # helper
+    def is_per_channel(self) -> bool:
+        return self in {
+            QScheme.PER_CHANNEL_ASYMM,
+            QScheme.PER_CHANNEL_SYMM,
+        }
+
+    def is_symmetric(self) -> bool:
+        return self in {
+            QScheme.PER_TENSOR_SYMM,
+            QScheme.PER_CHANNEL_SYMM,
+        }
+
+    def __str__(self) -> str:
+        return self.name.lower()
```
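The two helper predicates partition the four schemes along the granularity and symmetry axes. For example:

```python
from tico.experimental.quantization.ptq import QScheme

scheme = QScheme.PER_CHANNEL_SYMM
assert scheme.is_per_channel() and scheme.is_symmetric()
assert str(scheme) == "per_channel_symm"

# Per-tensor asymmetric is the only scheme for which both predicates are False.
assert not QScheme.PER_TENSOR_ASYMM.is_per_channel()
assert not QScheme.PER_TENSOR_ASYMM.is_symmetric()
```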
tico/utils/dtype.py
CHANGED
```diff
@@ -1,6 +1,8 @@
 import numpy as np
 import torch
 
+from circle_schema import circle
+
 NUMPY_TO_TORCH_DTYPE_DICT = {
     np.dtype("float32"): torch.float32,
     np.dtype("float64"): torch.float64,
@@ -15,6 +17,26 @@ NUMPY_TO_TORCH_DTYPE_DICT = {
     np.dtype("bool"): torch.bool,
 }
 
+CIRCLE_TO_TORCH_DTYPE_DICT = {
+    circle.TensorType.TensorType.FLOAT32: torch.float32,
+    circle.TensorType.TensorType.UINT8: torch.uint8,
+    circle.TensorType.TensorType.INT8: torch.int8,
+    circle.TensorType.TensorType.INT16: torch.int16,
+    circle.TensorType.TensorType.INT32: torch.int32,
+    circle.TensorType.TensorType.INT64: torch.int64,
+    circle.TensorType.TensorType.BOOL: torch.bool,
+}
+
 
 def numpy_dtype_to_torch_dtype(np_dtype: np.dtype) -> torch.dtype:
     return NUMPY_TO_TORCH_DTYPE_DICT[np_dtype]
+
+
+def circle_dtype_to_torch_dtype(circle_dtype: int) -> torch.dtype:
+    assert isinstance(circle_dtype, int)
+    if circle_dtype not in CIRCLE_TO_TORCH_DTYPE_DICT:
+        raise RuntimeError(f"Unsupported dtype {circle_dtype}")
+
+    torch_dtype = CIRCLE_TO_TORCH_DTYPE_DICT[circle_dtype]
+    assert torch_dtype is not None
+    return torch_dtype
```
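A small usage sketch, assuming the circle_schema package is installed (it is imported at the top of the module); Circle tensor types are plain ints in the flatbuffer schema, which is why the function takes an int:

```python
import torch
from circle_schema import circle

from tico.utils.dtype import circle_dtype_to_torch_dtype

assert circle_dtype_to_torch_dtype(circle.TensorType.TensorType.INT8) == torch.int8
assert circle_dtype_to_torch_dtype(circle.TensorType.TensorType.FLOAT32) == torch.float32
```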
tico/utils/signature.py
ADDED
```diff
@@ -0,0 +1,248 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Sequence
+
+import numpy as np
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_mapping import to_circle_shape
+from tico.utils.dtype import circle_dtype_to_torch_dtype
+from tico.utils.installed_packages import is_dynamic_cache_available
+
+
+def is_dynamic_cache_instance(value):
+    if is_dynamic_cache_available():
+        from transformers.cache_utils import DynamicCache
+
+        return isinstance(value, DynamicCache)
+    else:
+        return False
+
+
+def flatten_and_convert_kwargs(kwargs: dict) -> dict[str, torch.Tensor]:
+    result = {}  # type: ignore[var-annotated]
+    for k, v in kwargs.items():
+        if v is None:
+            continue
+        elif isinstance(v, (list, tuple)):
+            # 1. handle list
+            def unpack_recursive(name, value, store=None):
+                if store is None:
+                    store = {}
+
+                if isinstance(value, (tuple, list)):
+                    for i, v in enumerate(value):
+                        # recursive call. Append index to name and explore lower level
+                        unpack_recursive(f"{name}_{i}", v, store)
+                else:
+                    # base type (scalar etc.) directly stored
+                    store[name] = value
+
+                return store
+
+            unpack_recursive(k, v, result)
+        elif is_dynamic_cache_instance(v):
+            # 2. handle DynamicCache
+            for idx, cache_val in enumerate(v.key_cache):
+                result[f"{k}_key_cache_{idx}"] = cache_val
+
+            for idx, cache_val in enumerate(v.value_cache):
+                result[f"{k}_value_cache_{idx}"] = cache_val
+        else:
+            result[k] = v
+
+    # 3. Convert to tensors
+    for k, v in result.items():
+        result[k] = v if isinstance(v, torch.Tensor) else torch.tensor(v)
+
+    return result
+
+
+def flatten_and_convert_args(args: Sequence) -> tuple:
+    result = []  # type: ignore[var-annotated]
+    for item in args:
+        if item is None:
+            continue
+
+        # 1. recursion on list and tuple
+        if isinstance(item, (list, tuple)):
+            result.extend(flatten_and_convert_args(item))
+            continue
+
+        # 2. handle DynamicCache
+        if is_dynamic_cache_available():
+            from transformers.cache_utils import DynamicCache
+
+            if isinstance(item, DynamicCache):
+                # NOTE The tensor order is: key_in → key_out → value_in → value_out
+                #
+                # Refer to https://github.com/huggingface/transformers/blob/3457e8e73e4f5532cc69059682b1ba4484d7e7e8/src/transformers/cache_utils.py#L557
+                # ```
+                # self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+                # self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+                # ```
+                result.extend(item.key_cache)
+                result.extend(item.value_cache)
+                continue
+
+        # 3. Convert to tensors
+        result.append(item if isinstance(item, torch.Tensor) else torch.tensor(item))
+
+    return tuple(result)
+
+
+class ModelInputSpec:
+    @classmethod
+    def load(cls, circle_path):
+        def load(circle_path: str) -> bytes:
+            with open(circle_path, "rb") as f:
+                buf = bytes(f.read())
+            return buf
+
+        circle_binary = load(circle_path)
+        return cls(circle_binary)
+
+    def __init__(self, circle_binary):
+        model = circle.Model.Model.GetRootAsModel(circle_binary, 0)
+        assert model.SubgraphsLength() == 1, "Only one subgraph is supported"
+
+        graph = model.Subgraphs(0)
+        tensors = [graph.Tensors(graph.Inputs(o)) for o in range(graph.InputsLength())]
+
+        self.names = [t.Name().decode("utf-8").split("::")[-1] for t in tensors]
+        self.shapes = [t.ShapeAsNumpy() for t in tensors]
+        self.shape_signatures = list(
+            map(
+                lambda x: None if (isinstance(x, int) and x == 0) else x,
+                (t.ShapeSignatureAsNumpy() for t in tensors),
+            )
+        )
+        self.types: list[torch.dtype] = [
+            circle_dtype_to_torch_dtype(t.Type()) for t in tensors
+        ]
+        self.name_to_idx = {name: idx for idx, name in enumerate(self.names)}
+
+    def bind(self, args, kwargs, check=True):
+        """Convert args and kwargs into an ordered list according to model input order"""
+        inputs = []
+        args = flatten_and_convert_args(args)
+        kwargs = flatten_and_convert_kwargs(kwargs)
+
+        # 1. positional arguments
+        for i, val in enumerate(args):
+            if i >= len(self.names):
+                raise ValueError(f"Too many positional arguments ({i+1}).")
+            name = self.names[i]
+            if name in kwargs:
+                raise TypeError(
+                    f"Got multiple values for argument '{name}' (positional and keyword)."
+                )
+            inputs.append(val)
+
+        # 2. keyword arguments
+        for idx in range(len(args), len(self.names)):
+            name = self.names[idx]
+            if name not in kwargs:
+                raise ValueError(f"Missing argument for input '{name}'.")
+            inputs.append(kwargs[name])
+
+        if check:
+            self.check_types(inputs)
+            self.check_shapes(inputs)
+
+        return inputs
+
+    def check_types(self, inputs):
+        """Check the types of input values"""
+        for i, (inp, ref_type) in enumerate(zip(inputs, self.types)):
+            # TODO: Support more data types (np array)
+            assert isinstance(
+                inp, (torch.Tensor | int | float)
+            ), f"Input '{self.names[i]}' type must be a torch tensor or scalar."
+
+            if isinstance(inp, torch.Tensor):
+                if inp.dtype != ref_type:
+                    raise TypeError(
+                        f"Input '{self.names[i]}' type {inp.dtype} != expected {ref_type}"
+                    )
+            else:
+                # Scalars (int, float)
+                if ref_type == torch.float32:
+                    if not isinstance(inp, (float)):
+                        raise TypeError(
+                            f"Input '{self.names[i]}' type {type(inp)} != expected {ref_type}"
+                        )
+                elif ref_type == torch.int64:
+                    if not isinstance(inp, (int)):
+                        raise TypeError(
+                            f"Input '{self.names[i]}' type {type(inp)} != expected {ref_type}"
+                        )
+                else:
+                    print(f"Unexpected ref_type: {ref_type}")
+
+    def check_shapes(self, inputs):
+        """Check the shapes of input values"""
+
+        def merge(shape, shape_sig):
+            """
+            Merge shape signature with shape
+            """
+            from copy import deepcopy
+
+            shape_merged = deepcopy(shape)
+            if shape_sig is not None:
+                for idx, ss in enumerate(shape_sig):
+                    if ss == -1:
+                        shape_merged[idx] = -1
+
+            return shape_merged
+
+        for i, (inp, ref_shape, ref_shape_sig) in enumerate(
+            zip(inputs, self.shapes, self.shape_signatures)
+        ):
+            # TODO: Support more data types (np array)
+            assert isinstance(
+                inp, (torch.Tensor | int | float)
+            ), f"Input '{self.names[i]}' type must be a torch tensor or scalar."
+
+            if isinstance(inp, torch.Tensor):  # Tensor
+                in_shape, in_shape_sig = to_circle_shape(inp.size())
+
+                if len(in_shape) != len(ref_shape):
+                    raise ValueError(
+                        f"Input '{self.names[i]}' has invalid rank {len(in_shape)}!= expected {len(ref_shape)}"
+                    )
+
+                in_merged_shape = merge(in_shape, in_shape_sig)
+                ref_merged_shape = merge(ref_shape, ref_shape_sig)
+                for in_shp, ref_shp in zip(in_merged_shape, ref_merged_shape):
+                    if ref_shp == -1:
+                        continue
+                    if in_shp == -1:
+                        raise ValueError(
+                            f"Input '{self.names[i]}' has unknown dimension {inp.size()} != expected shape({ref_shape}) / shape signature({ref_shape_sig}) "
+                        )
+                    if in_shp != ref_shp:
+                        raise ValueError(
+                            f"Input '{self.names[i]}' has wrong dimension {inp.size()} != expected shape({ref_shape}) / shape signature({ref_shape_sig}) "
+                        )
+            elif isinstance(inp, (int, float)):  # Scalar
+                if len(ref_shape) > 0:
+                    raise ValueError(
+                        f"Input '{self.names[i]}' has invalid rank {len(ref_shape)}"
+                    )
+            else:
+                print(f"Unexpected input type: {type(inp)}")
```
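End to end, bind() flattens nested args/kwargs (including DynamicCache objects), orders them by the graph's input names, and validates dtypes and signature-aware shapes. A hypothetical sketch; the path model.circle and the input names in the comments are illustrative, not from this diff:

```python
import torch

from tico.utils.signature import ModelInputSpec

# Hypothetical single-subgraph compiled model on disk.
spec = ModelInputSpec.load("model.circle")
print(spec.names)  # e.g. ["input_ids", "attention_mask"]

# bind() accepts a mix of positional and keyword arguments and returns the
# values reordered to match the graph inputs, after type/shape checks.
ordered = spec.bind(
    (torch.zeros(1, 4, dtype=torch.int64),),                   # positional
    {"attention_mask": torch.ones(1, 4, dtype=torch.int64)},   # keyword
)
```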
{tico-0.1.0.dev250806.dist-info → tico-0.1.0.dev250810.dist-info}/RECORD
CHANGED

```diff
@@ -1,4 +1,4 @@
-tico/__init__.py,sha256=
+tico/__init__.py,sha256=H67V6wHv1jk6smTy2N_nzF9vFR0U8Rf-Vsc19oWjhpM,1883
 tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
 tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
 tico/config/base.py,sha256=q5xMqGxTUZs4mFqt5c7i_y9U00fYgdMGl9nUqIVMlCo,1248
@@ -6,14 +6,14 @@ tico/config/factory.py,sha256=il0zqB6Lm5NX2LnG-TUhmiP9vVeZ_3TucJMorVZIodY,1324
 tico/config/v1.py,sha256=O1jzpUBDwoWpLohEpI08pJNwVB-yz3ufPrQm2_XWq4Y,1108
 tico/experimental/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/__init__.py,sha256=IaJPZegVJp0P3luutBo907Kp5sOJensE1Mm-XBG_jBs,122
-tico/experimental/quantization/config.py,sha256=
+tico/experimental/quantization/config.py,sha256=1bCSAUI043Kbq08j59mb-K1cP2lmBMbekh8p3hNK6b8,1675
 tico/experimental/quantization/public_interface.py,sha256=4-v9VXsokRG2-UUYYHd_MlbHxChqdGI5iuySyYDY_Pw,4420
 tico/experimental/quantization/quantizer.py,sha256=_2pDtWFKDCuKfYF2bptOwIYsa0VFNFM1ZNgi8_OGvHM,2365
 tico/experimental/quantization/algorithm/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/algorithm/gptq/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/algorithm/gptq/gptq.py,sha256=Qn9b_2ki7B64DcVEY25NMkww3PdZ5EqYQQXfYhNDQ6I,5555
 tico/experimental/quantization/algorithm/gptq/quant.py,sha256=Rl4wAOCmlE0U09BtNCDbccaSNohRHCNLwFi3zCqZfNo,5127
-tico/experimental/quantization/algorithm/gptq/quantizer.py,sha256=
+tico/experimental/quantization/algorithm/gptq/quantizer.py,sha256=_ZnSD_LBag_FVcVEniPKBmw7bNZ2iZLZ8aZnexnCgrs,11693
 tico/experimental/quantization/algorithm/gptq/utils.py,sha256=leGKayf-xbSjVwwAGTA5RsxUKrhDiklOQdlsLifjdrs,1811
 tico/experimental/quantization/algorithm/pt2e/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/algorithm/pt2e/quantizer.py,sha256=mdTvsG87bo8fu0GaWqSM8iBCs-4f4EfUlVtk-Ko6M34,2546
@@ -43,7 +43,7 @@ tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py,sha256=O1h7
 tico/experimental/quantization/evaluation/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/evaluation/backend.py,sha256=CZL9rZOA0t8cH7PHp6u9l7dGqWNvTj9bKOvwo0PVul0,692
 tico/experimental/quantization/evaluation/evaluate.py,sha256=kfa_GvFaX6DoSTAmuCImMJqF2jgqtnor5UpC7wVmGPI,7877
-tico/experimental/quantization/evaluation/metric.py,sha256=
+tico/experimental/quantization/evaluation/metric.py,sha256=t9M058dOQ8iy_2PcrbNMAebBNJs8TU8USZw_nbi2iWI,5488
 tico/experimental/quantization/evaluation/utils.py,sha256=82RG_e5LuKfWo786wEZUVwXY93nNl901n04fB7D0Z6k,5909
 tico/experimental/quantization/evaluation/executor/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/evaluation/executor/backend_executor.py,sha256=3kLu3_rcsreA_NK42yRgRgubPtZmVp7QCRvaqLNw10E,1522
@@ -56,6 +56,10 @@ tico/experimental/quantization/passes/propagate_qparam_backward.py,sha256=TGtyW0
 tico/experimental/quantization/passes/propagate_qparam_forward.py,sha256=RhUHGCR2RpBO5KYkQ7Z8U5u7HEwDq2wdKHLKAJCi-5c,5138
 tico/experimental/quantization/passes/quantize_bias.py,sha256=T7YxJ70N0tSK0FF9VJZA5iP0sHdnnsX9GX4AT4JDFSk,4325
 tico/experimental/quantization/passes/remove_weight_dequant_op.py,sha256=gI1MtrHazWpdNfys7f1ngTTWplzluF7SA-uX0HMR5Mc,6592
+tico/experimental/quantization/ptq/__init__.py,sha256=ZoPdEwZ1i1n5pBFChx8GuUrkfRP2vsSoLPNILQjNBaA,298
+tico/experimental/quantization/ptq/dtypes.py,sha256=xfCBtq6mQmUYRwsoFgII6gvRl1raQi0Inj9pznDuKwQ,2236
+tico/experimental/quantization/ptq/mode.py,sha256=lT-T8vIv8YWcwrjT7xXVhOw1g7aoAdh_3PWB-ptPKaI,1052
+tico/experimental/quantization/ptq/qscheme.py,sha256=uwhv7bCxOOXB3I-IKlRyr_u4eXOq48uIqGy4TLDqGxY,1301
 tico/interpreter/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/interpreter/infer.py,sha256=1ZFe3DVMR2mlwBosoedqoL0-CGN_01CKLgMgxuw62KA,4861
 tico/interpreter/interpreter.py,sha256=tGbluCbrehTCqBu8mtGDNzby_ieJ2ry8_RH_eC0CQxk,3828
@@ -186,7 +190,7 @@ tico/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/utils/convert.py,sha256=GgZwZtiqFzTdszfUQO0vcX39lKjs97gYwZ-Tiw_4Bbo,13222
 tico/utils/define.py,sha256=Ypgp7YffM4pgPl4Zh6TmogSn1OxGBMRw_e09qYGflZk,1467
 tico/utils/diff_graph.py,sha256=_eDGGPDPYQD4b--MXX0DLoVgSt_wLfNPt47UlolLLR4,5272
-tico/utils/dtype.py,sha256=
+tico/utils/dtype.py,sha256=L5Qb7qgbt0eQ5frUTvHYrRtTJb1dg4-JNEopcxCNg1U,1389
 tico/utils/errors.py,sha256=f3csJjgbXG9W1aHhqEcou008Aor19W57X8oT5Hx8w1M,954
 tico/utils/graph.py,sha256=jD6m58m5JmN9mPfaROA9CW3406iJxmnukke00AuwRqI,9131
 tico/utils/installed_packages.py,sha256=J0FTwnkCGs0MxRWoCMYAqiwH7Z0GWFDLV--x-IndSp4,1017
@@ -198,6 +202,7 @@ tico/utils/pytree_utils.py,sha256=jrk3N6X6LiUnBCX_gM1K9nywbVAJBVnszlTAgeIeDUc,52
 tico/utils/record_input.py,sha256=QN-8D71G_WAX3QQQ5CIwbEfFJZTQ3CvL4wCMiVddua4,3894
 tico/utils/register_custom_op.py,sha256=3-Yl6iYmx1qQA2igNHt4hYhQhQMkdPb7gF50LIY8yvc,27350
 tico/utils/serialize.py,sha256=mEuusEzi82WFsz3AkowgWwxSLeo50JDxyOj6yYDQhEI,1914
+tico/utils/signature.py,sha256=R2GV0alRpXEbZISqPKyxCUWbgDcsrQ2ovbVG3737IzA,9595
 tico/utils/torch_compat.py,sha256=oc6PztVsXdHcQ3iaVR90wLLxrGaj6zFHWZ8K9rRS6q8,1795
 tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
 tico/utils/utils.py,sha256=A5p3iAAxRGDsZJh4ybp-Qo3MX3vk5RrmSY-R3rXqVeI,12976
@@ -206,9 +211,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
 tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
 tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
+tico-0.1.0.dev250810.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
+tico-0.1.0.dev250810.dist-info/METADATA,sha256=N1kcg1vk8kn6bLJgRUSZMZalY-mTn46jqLvVj9NGvR4,8450
+tico-0.1.0.dev250810.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+tico-0.1.0.dev250810.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
+tico-0.1.0.dev250810.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
+tico-0.1.0.dev250810.dist-info/RECORD,,
```
{tico-0.1.0.dev250806.dist-info → tico-0.1.0.dev250810.dist-info}/LICENSE
File without changes

{tico-0.1.0.dev250806.dist-info → tico-0.1.0.dev250810.dist-info}/WHEEL
File without changes

{tico-0.1.0.dev250806.dist-info → tico-0.1.0.dev250810.dist-info}/entry_points.txt
File without changes

{tico-0.1.0.dev250806.dist-info → tico-0.1.0.dev250810.dist-info}/top_level.txt
File without changes