tico-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +42 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +191 -0
- tico/passes/cast_mixed_type_args.py +187 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +160 -0
- tico/passes/convert_layout_op_to_reshape.py +85 -0
- tico/passes/convert_repeat_to_expand_copy.py +89 -0
- tico/passes/convert_to_relu6.py +181 -0
- tico/passes/decompose_addmm.py +124 -0
- tico/passes/decompose_batch_norm.py +192 -0
- tico/passes/decompose_fake_quantize.py +134 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
- tico/passes/decompose_group_norm.py +275 -0
- tico/passes/decompose_grouped_conv2d.py +209 -0
- tico/passes/decompose_slice_scatter.py +169 -0
- tico/passes/extract_dtype_kwargs.py +122 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +108 -0
- tico/passes/legalize_predefined_layout_operators.py +386 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
- tico/passes/lower_to_slice.py +230 -0
- tico/passes/merge_consecutive_cat.py +80 -0
- tico/passes/ops.py +78 -0
- tico/passes/remove_nop.py +84 -0
- tico/passes/remove_redundant_assert_nodes.py +51 -0
- tico/passes/remove_redundant_expand.py +66 -0
- tico/passes/remove_redundant_permute.py +122 -0
- tico/passes/remove_redundant_reshape.py +436 -0
- tico/passes/remove_redundant_slice.py +62 -0
- tico/passes/remove_redundant_to_copy.py +86 -0
- tico/passes/restore_linear.py +115 -0
- tico/passes/segment_index_select.py +145 -0
- tico/pt2_to_circle.py +105 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +319 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +240 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_abs.py +53 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +150 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +192 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +126 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +186 -0
- tico/serialize/operators/op_copy.py +164 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +95 -0
- tico/serialize/operators/op_depthwise_conv2d.py +199 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_leaky_relu.py +60 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +86 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_dim.py +70 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +177 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +141 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +100 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +99 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +94 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +296 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +282 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +609 -0
- tico/utils/serialize.py +42 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +406 -0
- tico/utils/validate_args_kwargs.py +1149 -0
- tico-0.1.0.dist-info/LICENSE +241 -0
- tico-0.1.0.dist-info/METADATA +354 -0
- tico-0.1.0.dist-info/RECORD +206 -0
- tico-0.1.0.dist-info/WHEEL +5 -0
- tico-0.1.0.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dist-info/top_level.txt +1 -0
tico/__init__.py
ADDED
@@ -0,0 +1,42 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+
+import torch
+from packaging.version import Version
+
+from tico.config import CompileConfigV1, get_default_config
+from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
+
+# THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
+__version__ = "0.1.0"
+
+MINIMUM_SUPPORTED_VERSION = "2.5.0"
+SECURE_TORCH_VERSION = "2.6.0"
+
+if Version(torch.__version__) < Version(MINIMUM_SUPPORTED_VERSION):
+    warnings.warn(
+        f"TICO officially supports torch>={MINIMUM_SUPPORTED_VERSION}. "
+        f"You are using a lower version of torch ({torch.__version__}). "
+        f"We highly recommend to upgrade torch>={MINIMUM_SUPPORTED_VERSION} to avoid unexpected behaviors."
+    )
+
+if Version(torch.__version__) < Version(SECURE_TORCH_VERSION):
+    warnings.warn(
+        f"Detected PyTorch version {torch.__version__}, which may include known security vulnerabilities. "
+        f"We recommend upgrading to {SECURE_TORCH_VERSION} or later for better security.\n"
+        "Upgrade command: pip install --upgrade torch\n"
+        "For more details, see: https://pytorch.org/security"
+    )
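The top-level module re-exports convert, convert_from_exported_program, and convert_from_pt2 from tico.utils.convert. Their signatures are not part of this excerpt, so the following is only a hedged sketch of the intended entry point, assuming convert accepts a module plus example inputs and returns an object with a save method:

import torch
import tico

class TwoLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x):
        return torch.relu(self.fc(x))

# Hypothetical call pattern; see tico/utils/convert.py for the real signature.
circle_model = tico.convert(TwoLayer().eval(), (torch.randn(1, 8),))
circle_model.save("two_layer.circle")  # assumed serialization entry point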
tico/config/__init__.py
ADDED
tico/config/base.py
ADDED
@@ -0,0 +1,37 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+
+@dataclass
+class CompileConfigBase:
+    def get(self, name: str):
+        return getattr(self, name) if hasattr(self, name) else None
+
+    def set(self, name: str, enabled: bool):
+        setattr(self, name, enabled)
+
+    def to_dict(self):
+        return {key: value for key, value in self.__dict__.items()}
+
+    @classmethod
+    def from_dict(cls, config_dict: dict):
+        config = cls()
+        for key in config_dict:
+            if key in config.to_dict():
+                assert type(config.get(key)) == bool
+                config.set(key, config_dict[key])
+
+        return config
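Note that from_dict only copies keys that already exist as fields on the dataclass and asserts each existing field holds a boolean, so unknown keys are silently ignored. A minimal round-trip sketch (the DemoConfig subclass and its flag are hypothetical, for illustration only):

from dataclasses import dataclass

@dataclass
class DemoConfig(CompileConfigBase):
    fold_constants: bool = True  # hypothetical flag

cfg = DemoConfig.from_dict({"fold_constants": False, "unknown_key": 1})
print(cfg.to_dict())  # {'fold_constants': False} -- 'unknown_key' was dropped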
tico/config/factory.py
ADDED
@@ -0,0 +1,41 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Type
+
+from tico.config.base import CompileConfigBase
+from tico.config.v1 import CompileConfigV1
+
+
+class CompileConfigFactory:
+    _config_classes = {
+        "1.0": CompileConfigV1,
+        # '2.0': CompileConfigV2,
+    }
+
+    @classmethod
+    def get_config(cls, version: str) -> Type[CompileConfigBase]:
+        if version not in cls._config_classes:
+            raise ValueError(f"Unsupported version: {version}")
+
+        return cls._config_classes[version]
+
+    @classmethod
+    def create(cls, version: str):
+        config_class = cls.get_config(version)
+        return config_class()
+
+
+def get_default_config(version: str = "1.0"):
+    return CompileConfigFactory.create(version)
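Usage follows directly from the factory above: only "1.0" is registered, and any other version string raises.

cfg = get_default_config()               # same as CompileConfigFactory.create("1.0")
assert isinstance(cfg, CompileConfigV1)

try:
    CompileConfigFactory.create("2.0")   # not registered yet (see the commented stub)
except ValueError as e:
    print(e)                             # Unsupported version: 2.0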
tico/config/v1.py
ADDED
@@ -0,0 +1,35 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+from tico.config.base import CompileConfigBase
+
+
+@dataclass
+class CompileConfigV1(CompileConfigBase):
+    legalize_causal_mask_value: bool = False
+
+    def get(self, name: str):
+        return super().get(name)
+
+    def set(self, name: str, enabled: bool):
+        super().set(name, enabled)
+
+    def to_dict(self):
+        return super().to_dict()
+
+    @classmethod
+    def from_dict(cls, config_dict: dict):
+        return super().from_dict(config_dict)
+
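CompileConfigV1 currently exposes a single boolean flag, legalize_causal_mask_value (consumed by tico/passes/legalize_causal_mask_value.py in the listing above). A short sketch of the round-trip behavior it inherits:

cfg = CompileConfigV1.from_dict({"legalize_causal_mask_value": True, "not_a_flag": 1})
assert cfg.legalize_causal_mask_value is True
assert cfg.get("not_a_flag") is None     # unknown keys are dropped; get() returns None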
tico/experimental/__init__.py
ADDED
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
tico/experimental/quantization/__init__.py
ADDED
@@ -0,0 +1 @@
+from tico.experimental.quantization.public_interface import convert, prepare
tico/experimental/quantization/algorithm/__init__.py
ADDED
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
tico/experimental/quantization/algorithm/gptq/__init__.py
ADDED
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
tico/experimental/quantization/algorithm/gptq/gptq.py
ADDED
@@ -0,0 +1,172 @@
+# Copyright IST-DASLab. 2025. (commit: 2d65066). GitHub repository.
+# Retrieved from https://github.com/IST-DASLab/gptq. Licensed under the
+# Apache License 2.0.
+
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# https://github.com/IST-DASLab/gptq/blob/2d65066/gptq.py
+
+import math
+import time
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from tico.experimental.quantization.algorithm.gptq.quant import quantize, Quantizer
+
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cudnn.allow_tf32 = False
+
+
+class GPTQ:
+    def __init__(self, layer):
+        self.layer = layer
+        self.dev = self.layer.weight.device
+        W = layer.weight.data.clone()
+        self.rows = W.shape[0]
+        self.columns = W.shape[1]
+        self.H: Optional[torch.Tensor] = torch.zeros(
+            (self.columns, self.columns), device=self.dev
+        )
+        self.nsamples = 0
+        self.quantizer: Quantizer = Quantizer()
+
+    def add_batch(self, inp, out):
+        if len(inp.shape) == 2:
+            inp = inp.unsqueeze(0)
+        tmp = inp.shape[0]
+        if isinstance(self.layer, nn.Linear):
+            if len(inp.shape) == 3:
+                inp = inp.reshape((-1, inp.shape[-1]))
+            inp = inp.t()
+        self.H *= self.nsamples / (self.nsamples + tmp)
+        self.nsamples += tmp
+        inp = math.sqrt(2 / self.nsamples) * inp.float()
+        self.H += inp.matmul(inp.t())
+
+    def fasterquant(
+        self,
+        blocksize=128,
+        percdamp=0.01,
+        groupsize=-1,
+        actorder=False,
+        static_groups=False,
+        verbose=False,
+    ):
+        W = self.layer.weight.data.clone()
+        W = W.float()
+        tick = time.time()
+        if not self.quantizer.ready():
+            self.quantizer.find_params(W, weight=True)
+
+        H = self.H
+        del self.H
+        assert isinstance(H, torch.Tensor)
+        dead = torch.diag(H) == 0
+        H[dead, dead] = 1
+        W[:, dead] = 0
+
+        if static_groups:
+            import copy
+
+            groups = []
+            for i in range(0, self.columns, groupsize):
+                quantizer = copy.deepcopy(self.quantizer)
+                quantizer.find_params(W[:, i : (i + groupsize)], weight=True)
+                groups.append(quantizer)
+
+        if actorder:
+            perm = torch.argsort(torch.diag(H), descending=True)
+            W = W[:, perm]
+            H = H[perm][:, perm]
+            invperm = torch.argsort(perm)
+
+        Losses = torch.zeros_like(W)
+        Q = torch.zeros_like(W)
+
+        damp = percdamp * torch.mean(torch.diag(H))
+        diag = torch.arange(self.columns, device=self.dev)
+        H[diag, diag] += damp
+        H = torch.linalg.cholesky(H)
+        assert isinstance(H, torch.Tensor)
+        H = torch.cholesky_inverse(H)
+        H = torch.linalg.cholesky(H, upper=True)
+        Hinv = H
+
+        assert isinstance(Hinv, torch.Tensor)
+        for i1 in range(0, self.columns, blocksize):
+            i2 = min(i1 + blocksize, self.columns)
+            count = i2 - i1
+
+            W1 = W[:, i1:i2].clone()
+            Q1 = torch.zeros_like(W1)
+            Err1 = torch.zeros_like(W1)
+            Losses1 = torch.zeros_like(W1)
+            Hinv1 = Hinv[i1:i2, i1:i2]
+
+            for i in range(count):
+                w = W1[:, i]
+                d = Hinv1[i, i]
+
+                if groupsize != -1:
+                    if not static_groups:
+                        if (i1 + i) % groupsize == 0:
+                            self.quantizer.find_params(
+                                W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
+                            )
+                    else:
+                        idx: torch.Tensor | int = i1 + i
+                        if actorder:
+                            idx = perm[idx]
+                        self.quantizer = groups[idx // groupsize]
+
+                q = quantize(
+                    w.unsqueeze(1),
+                    self.quantizer.scale,
+                    self.quantizer.zero,
+                    self.quantizer.maxq,
+                ).flatten()
+                Q1[:, i] = q
+                Losses1[:, i] = (w - q) ** 2 / d**2
+
+                err1 = (w - q) / d
+                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
+                Err1[:, i] = err1
+
+            Q[:, i1:i2] = Q1
+            Losses[:, i1:i2] = Losses1 / 2
+
+            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        if verbose:
+            print("time %.2f" % (time.time() - tick))
+            print("error", torch.sum(Losses).item())
+
+        if actorder:
+            Q = Q[:, invperm]
+
+        self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(
+            self.layer.weight.data.dtype
+        )
+
+    def free(self):
+        self.H = None
+        self.Losses = None
+        self.Trace = None
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
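GPTQ accumulates a running Hessian estimate H = (2/n) * sum(x xT) over calibration inputs via add_batch, then fasterquant solves the blockwise least-squares quantization and overwrites the layer's weights in place. The packaged driver lives in quantizer.py (not shown in this excerpt); a minimal hand-rolled loop for a single nn.Linear, with stand-in calibration data, would look like:

import torch
import torch.nn as nn

layer = nn.Linear(64, 32)
gptq = GPTQ(layer)
# configure() comes from the companion Quantizer in quant.py
gptq.quantizer.configure(bits=4, perchannel=True, sym=False, mse=False)

for _ in range(8):                       # stand-in calibration batches
    x = torch.randn(16, 64)
    gptq.add_batch(x, None)              # `out` is accepted but unused here

gptq.fasterquant(blocksize=32, percdamp=0.01, groupsize=-1)  # mutates layer.weight
gptq.free()                              # drop the Hessian, release CUDA cache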
tico/experimental/quantization/algorithm/gptq/quant.py
ADDED
@@ -0,0 +1,153 @@
+# Copyright IST-DASLab. 2025. (commit: 2d65066). GitHub repository.
+# Retrieved from https://github.com/IST-DASLab/gptq. Licensed under the
+# Apache License 2.0.
+
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# https://github.com/IST-DASLab/gptq/blob/2d65066/quant.py
+
+import torch
+import torch.nn as nn
+
+
+def quantize(x, scale, zero, maxq):
+    if maxq < 0:
+        return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
+    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
+    return scale * (q - zero)
+
+
+class Quantizer(nn.Module):
+    def __init__(self, shape=1):
+        super(Quantizer, self).__init__()
+        self.register_buffer("maxq", torch.tensor(0))
+        self.register_buffer("scale", torch.zeros(shape))
+        self.register_buffer("zero", torch.zeros(shape))
+
+    def configure(
+        self,
+        bits,
+        perchannel=False,
+        sym=True,
+        mse=False,
+        norm=2.4,
+        grid=100,
+        maxshrink=0.8,
+        trits=False,
+    ):
+        self.maxq = torch.tensor(2**bits - 1)
+        self.perchannel = perchannel
+        self.sym = sym
+        self.mse = mse
+        self.norm = norm
+        self.grid = grid
+        self.maxshrink = maxshrink
+        if trits:
+            self.maxq = torch.tensor(-1)
+
+    def find_params(self, x, weight=False):
+        dev = x.device
+        self.maxq = self.maxq.to(dev)
+
+        shape = x.shape
+        if self.perchannel:
+            if weight:
+                x = x.flatten(1)
+            else:
+                if len(shape) == 4:
+                    x = x.permute([1, 0, 2, 3])
+                    x = x.flatten(1)
+                if len(shape) == 3:
+                    x = x.reshape((-1, shape[-1])).t()
+                if len(shape) == 2:
+                    x = x.t()
+        else:
+            x = x.flatten().unsqueeze(0)
+
+        tmp = torch.zeros(x.shape[0], device=dev)
+        xmin = torch.minimum(x.min(1)[0], tmp)
+        xmax = torch.maximum(x.max(1)[0], tmp)
+
+        if self.sym:
+            xmax = torch.maximum(torch.abs(xmin), xmax)
+            tmp = xmin < 0
+            if torch.any(tmp):
+                xmin[tmp] = -xmax[tmp]
+        tmp = (xmin == 0) & (xmax == 0)
+        xmin[tmp] = -1
+        xmax[tmp] = +1
+
+        if self.maxq < 0:
+            self.scale = xmax
+            self.zero = xmin
+        else:
+            self.scale = (xmax - xmin) / self.maxq
+            if self.sym:
+                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)  # type: ignore[arg-type]
+            else:
+                self.zero = torch.round(-xmin / self.scale)
+
+        if self.mse:
+            best = torch.full([x.shape[0]], float("inf"), device=dev)
+            for i in range(int(self.maxshrink * self.grid)):
+                p = 1 - i / self.grid
+                xmin1 = p * xmin
+                xmax1 = p * xmax
+                scale1 = (xmax1 - xmin1) / self.maxq
+                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
+                q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
+                q -= x
+                q.abs_()
+                q.pow_(self.norm)
+                err = torch.sum(q, 1)
+                tmp = err < best
+                if torch.any(tmp):
+                    best[tmp] = err[tmp]
+                    self.scale[tmp] = scale1[tmp]
+                    self.zero[tmp] = zero1[tmp]
+        if not self.perchannel:
+            if weight:
+                tmp = shape[0]
+            else:
+                tmp = shape[1] if len(shape) != 3 else shape[2]
+            assert isinstance(tmp, int)
+            self.scale = self.scale.repeat(tmp)
+            self.zero = self.zero.repeat(tmp)
+
+        if weight:
+            shape = [-1] + [1] * (len(shape) - 1)
+            self.scale = self.scale.reshape(shape)
+            self.zero = self.zero.reshape(shape)
+            return
+        if len(shape) == 4:
+            self.scale = self.scale.reshape((1, -1, 1, 1))
+            self.zero = self.zero.reshape((1, -1, 1, 1))
+        if len(shape) == 3:
+            self.scale = self.scale.reshape((1, 1, -1))
+            self.zero = self.zero.reshape((1, 1, -1))
+        if len(shape) == 2:
+            self.scale = self.scale.unsqueeze(0)
+            self.zero = self.zero.unsqueeze(0)
+
+    def quantize(self, x):
+        if self.ready():
+            return quantize(x, self.scale, self.zero, self.maxq)
+        return x
+
+    def enabled(self):
+        return self.maxq > 0
+
+    def ready(self):
+        return torch.all(self.scale != 0)
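Quantizer computes affine (scale, zero) parameters per tensor or per channel, and quantize() fake-quantizes: it rounds onto the integer grid and immediately dequantizes, so outputs stay floating point. A standalone sketch using only the API shown above:

import torch

q = Quantizer()
q.configure(bits=8, perchannel=True, sym=False)

W = torch.randn(32, 64)            # e.g. an nn.Linear weight
q.find_params(W, weight=True)      # per-output-channel scale/zero, shape (32, 1)
W_fq = q.quantize(W)               # round + clamp to [0, 255], then dequantize

max_err = (W - W_fq).abs().max()   # per-channel error bounded by roughly scale / 2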