tico 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
tico/__init__.py ADDED
@@ -0,0 +1,42 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+
17
+ import torch
18
+ from packaging.version import Version
19
+
20
+ from tico.config import CompileConfigV1, get_default_config
21
+ from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
22
+
23
# THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
__version__ = "0.1.0"

# Oldest torch release that TICO officially supports (see the warning below).
MINIMUM_SUPPORTED_VERSION = "2.5.0"
# Torch releases older than this trigger the security warning below.
SECURE_TORCH_VERSION = "2.6.0"

# Both checks run at import time and only emit warnings; importing the
# package still succeeds with an older torch.
if Version(MINIMUM_SUPPORTED_VERSION) > Version(torch.__version__):
    warnings.warn(
        f"TICO officially supports torch>={MINIMUM_SUPPORTED_VERSION}. "
        f"You are using a lower version of torch ({torch.__version__}). "
        f"We highly recommend to upgrade torch>={MINIMUM_SUPPORTED_VERSION} to avoid unexpected behaviors."
    )

if Version(SECURE_TORCH_VERSION) > Version(torch.__version__):
    warnings.warn(
        f"Detected PyTorch version {torch.__version__}, which may include known security vulnerabilities. "
        f"We recommend upgrading to {SECURE_TORCH_VERSION} or later for better security.\n"
        "Upgrade command: pip install --upgrade torch\n"
        "For more details, see: https://pytorch.org/security"
    )
@@ -0,0 +1,4 @@
1
+ from tico.config.base import CompileConfigBase
2
+ from tico.config.factory import get_default_config
3
+
4
+ from tico.config.v1 import CompileConfigV1
tico/config/base.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+
17
+
18
@dataclass
class CompileConfigBase:
    """Base class for compile configurations stored as instance attributes.

    Options are read and written by attribute name; subclasses supply the
    actual options (e.g. as dataclass fields).
    """

    def get(self, name: str):
        """Return the attribute called ``name``, or ``None`` if it is absent."""
        return getattr(self, name, None)

    def set(self, name: str, enabled: bool):
        """Store ``enabled`` on this instance under the attribute ``name``."""
        setattr(self, name, enabled)

    def to_dict(self):
        """Return a shallow copy of the instance attributes as a dict."""
        return dict(self.__dict__)

    @classmethod
    def from_dict(cls, config_dict: dict):
        """Create a default config and overlay matching entries of ``config_dict``.

        Keys that the freshly created instance does not already expose through
        ``to_dict()`` are skipped silently.
        """
        config = cls()
        for key, value in config_dict.items():
            if key not in config.to_dict():
                continue
            # Invariant: an option that can be overridden currently holds a bool.
            assert isinstance(config.get(key), bool)
            config.set(key, value)

        return config
tico/config/factory.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Type
16
+
17
+ from tico.config.base import CompileConfigBase
18
+ from tico.config.v1 import CompileConfigV1
19
+
20
+
21
class CompileConfigFactory:
    """Registry that maps a config version string to its config class."""

    _config_classes = {
        "1.0": CompileConfigV1,
        # '2.0': CompileConfigV2,
    }

    @classmethod
    def get_config(cls, version: str) -> Type[CompileConfigBase]:
        """Return the config class registered for ``version``.

        Raises:
            ValueError: If no class is registered under ``version``.
        """
        registry = cls._config_classes
        if version in registry:
            return registry[version]

        raise ValueError(f"Unsupported version: {version}")

    @classmethod
    def create(cls, version: str):
        """Instantiate the config class for ``version`` with its defaults."""
        return cls.get_config(version)()
38
+
39
+
40
def get_default_config(version: str = "1.0"):
    """Return a default-initialised config object for ``version``.

    Delegates to ``CompileConfigFactory.create``, which raises ``ValueError``
    for a version that is not registered.
    """
    default_config = CompileConfigFactory.create(version)
    return default_config
tico/config/v1.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+
17
+ from tico.config.base import CompileConfigBase
18
+
19
+
20
@dataclass
class CompileConfigV1(CompileConfigBase):
    """Version "1.0" compile configuration.

    All methods simply forward to ``CompileConfigBase``; they are spelled out
    here so the V1 interface is visible in one place.
    """

    # Disabled by default.
    # NOTE(review): presumably toggles the pass of the same name
    # (tico/passes/legalize_causal_mask_value.py) — confirm at the call site.
    legalize_causal_mask_value: bool = False

    def get(self, name: str):
        """Return the option called ``name`` (``None`` if it does not exist)."""
        value = super().get(name)
        return value

    def set(self, name: str, enabled: bool):
        """Set the option called ``name`` to ``enabled``."""
        super().set(name, enabled)

    def to_dict(self):
        """Return the options as a plain dict."""
        options = super().to_dict()
        return options

    @classmethod
    def from_dict(cls, config_dict: dict):
        """Build a ``CompileConfigV1`` from ``config_dict`` via the base class."""
        config = super().from_dict(config_dict)
        return config
@@ -0,0 +1 @@
1
+ # DO NOT REMOVE THIS FILE
@@ -0,0 +1 @@
1
+ from tico.experimental.quantization.public_interface import convert, prepare
@@ -0,0 +1 @@
1
+ # DO NOT REMOVE THIS FILE
@@ -0,0 +1 @@
1
+ # DO NOT REMOVE THIS FILE
@@ -0,0 +1,172 @@
1
+ # Copyright IST-DASLab. 2025. (commit: 2d65066). GitHub repository.
2
+ # Retrieved from https://github.com/IST-DASLab/gptq. Licensed under the
3
+ # Apache License 2.0.
4
+
5
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ # https://github.com/IST-DASLab/gptq/blob/2d65066/gptq.py
20
+
21
+ import math
22
+ import time
23
+ from typing import Optional
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+
28
+ from tico.experimental.quantization.algorithm.gptq.quant import quantize, Quantizer
29
+
30
+ torch.backends.cuda.matmul.allow_tf32 = False  # import-time, process-wide switch: turns TF32 off for CUDA matmuls
31
+ torch.backends.cudnn.allow_tf32 = False  # same for cuDNN; presumably keeps the GPTQ Hessian math in full FP32 — confirm
32
+
33
+
34
class GPTQ:
    """Layer-wise GPTQ weight quantizer.

    Adapted from https://github.com/IST-DASLab/gptq/blob/2d65066/gptq.py.

    As the code reads, the workflow is: feed calibration inputs through
    ``add_batch`` to accumulate the Hessian estimate ``H``, call
    ``fasterquant`` to overwrite ``layer.weight.data`` with quantized values,
    then ``free`` to drop the buffers.

    NOTE(review): ``self.quantizer`` is created unconfigured here; it
    presumably has to be configured by the caller before ``fasterquant`` is
    run — confirm against the call site.
    """

    def __init__(self, layer):
        """Set up Hessian accumulation for ``layer``.

        Args:
            layer: Module with a ``weight`` tensor of at least two dimensions;
                only ``weight.shape[0]`` (rows) and ``weight.shape[1]``
                (columns) are read here.
        """
        self.layer = layer
        self.dev = self.layer.weight.device
        # Only the shape is needed, so read it directly instead of cloning the
        # whole weight tensor just to inspect its dimensions.
        weight_shape = layer.weight.shape
        self.rows = weight_shape[0]
        self.columns = weight_shape[1]
        # (columns x columns) accumulator updated by add_batch().
        self.H: Optional[torch.Tensor] = torch.zeros(
            (self.columns, self.columns), device=self.dev
        )
        self.nsamples = 0
        self.quantizer: Quantizer = Quantizer()

    def add_batch(self, inp, out):
        """Fold one calibration batch into the running Hessian estimate.

        After the call ``H`` equals ``H_old * n_old / n_new + (2 / n_new) * X X^T``
        where ``n`` counts batch entries (``inp.shape[0]``) and ``X`` is the
        input laid out as (features, samples).

        Args:
            inp: Layer input. A 2-D input is treated as a single batch entry.
            out: Unused; kept for interface compatibility.
        """
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        # Rescale the old average before adding the new batch's contribution.
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())

    def fasterquant(
        self,
        blocksize=128,
        percdamp=0.01,
        groupsize=-1,
        actorder=False,
        static_groups=False,
        verbose=False,
    ):
        """Quantize ``layer.weight`` column by column and write the result back.

        The accumulated ``H`` is consumed (``self.H`` is deleted), so this can
        run only once per accumulation.

        Args:
            blocksize: Number of columns processed per block before the
                accumulated error is propagated to the remaining columns.
            percdamp: Fraction of the mean Hessian diagonal that is added to
                the diagonal before the Cholesky factorisations.
            groupsize: Columns sharing one set of quantization parameters;
                ``-1`` keeps the parameters found for the whole weight.
            actorder: Process columns in descending order of the Hessian
                diagonal and undo the permutation at the end.
            static_groups: Pre-compute one quantizer per group up front
                instead of re-fitting parameters while iterating.
            verbose: Print elapsed time and the summed loss.
        """
        W = self.layer.weight.data.clone()
        W = W.float()
        tick = time.time()
        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        del self.H
        assert isinstance(H, torch.Tensor)
        # Columns whose Hessian diagonal is zero: give them a unit diagonal
        # and zero the matching weights.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        if static_groups:
            import copy

            groups = []
            for i in range(0, self.columns, groupsize):
                quantizer = copy.deepcopy(self.quantizer)
                quantizer.find_params(W[:, i : (i + groupsize)], weight=True)
                groups.append(quantizer)

        if actorder:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]
            invperm = torch.argsort(perm)

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        # Hinv ends up as the upper Cholesky factor of the inverse of the
        # damped H.
        H = torch.linalg.cholesky(H)
        assert isinstance(H, torch.Tensor)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        assert isinstance(Hinv, torch.Tensor)
        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                # NOTE: w is a view into W1; q and err1 must be computed
                # before W1 is updated below.
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1:
                    if not static_groups:
                        if (i1 + i) % groupsize == 0:
                            self.quantizer.find_params(
                                W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
                            )
                    else:
                        idx: torch.Tensor | int = i1 + i
                        if actorder:
                            # Map back to the column's original position to
                            # pick its pre-computed group.
                            idx = perm[idx]
                        self.quantizer = groups[idx // groupsize]

                q = quantize(
                    w.unsqueeze(1),
                    self.quantizer.scale,
                    self.quantizer.zero,
                    self.quantizer.maxq,
                ).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d**2

                err1 = (w - q) / d
                # Spread this column's error over the rest of the block.
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

            # Propagate the block's error to all columns after the block.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if verbose:
            print("time %.2f" % (time.time() - tick))
            print("error", torch.sum(Losses).item())

        if actorder:
            Q = Q[:, invperm]

        self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(
            self.layer.weight.data.dtype
        )

    def free(self):
        """Drop the Hessian buffer and empty the CUDA cache when CUDA is available."""
        self.H = None
        # Losses/Trace are not assigned anywhere else in this class;
        # presumably carried over from the upstream implementation.
        self.Losses = None
        self.Trace = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
@@ -0,0 +1,153 @@
1
+ # Copyright IST-DASLab. 2025. (commit: 2d65066). GitHub repository.
2
+ # Retrieved from https://github.com/IST-DASLab/gptq. Licensed under the
3
+ # Apache License 2.0.
4
+
5
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ # https://github.com/IST-DASLab/gptq/blob/2d65066/quant.py
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+
24
+
25
def quantize(x, scale, zero, maxq):
    """Quantize ``x`` and return the de-quantized (float) result.

    For ``maxq < 0`` the output is ``scale`` where ``x > scale / 2``, ``zero``
    where ``x < zero / 2`` and 0 elsewhere. Otherwise ``x`` is rounded onto
    the integer grid ``[0, maxq]`` shifted by ``zero`` and mapped back with
    ``scale``.
    """
    if maxq < 0:
        above = (x > scale / 2).float()
        below = (x < zero / 2).float()
        return above * scale + below * zero

    levels = torch.round(x / scale) + zero
    levels = torch.clamp(levels, 0, maxq)
    return scale * (levels - zero)
30
+
31
+
32
class Quantizer(nn.Module):
    """Holds quantization parameters (``maxq``, ``scale``, ``zero``) as buffers.

    ``configure`` must be called before ``find_params``: it is the only place
    that sets ``perchannel``, ``sym``, ``mse`` and the grid-search settings
    that ``find_params`` reads.
    """

    def __init__(self, shape=1):
        """Register ``maxq`` (0) and zero-filled ``scale``/``zero`` of ``shape``."""
        super().__init__()
        self.register_buffer("maxq", torch.tensor(0))
        self.register_buffer("scale", torch.zeros(shape))
        self.register_buffer("zero", torch.zeros(shape))

    def configure(
        self,
        bits,
        perchannel=False,
        sym=True,
        mse=False,
        norm=2.4,
        grid=100,
        maxshrink=0.8,
        trits=False,
    ):
        """Store the quantization settings.

        Args:
            bits: Bit width; ``maxq`` becomes ``2**bits - 1``.
            perchannel: Fit one parameter set per row instead of one overall.
            sym: Use a range symmetric around zero.
            mse: Refine the range with a grid search in ``find_params``.
            norm: Exponent applied to the absolute error in that search.
            grid: Denominator of the shrink factor ``1 - i / grid``.
            maxshrink: The search runs ``int(maxshrink * grid)`` steps.
            trits: If true, ``maxq`` is set to -1 instead.
        """
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        self.maxq = torch.tensor(-1) if trits else torch.tensor(2**bits - 1)

    def find_params(self, x, weight=False):
        """Fit ``scale`` and ``zero`` to ``x`` and store them on the module.

        Args:
            x: Tensor to fit.
            weight: Treat ``x`` as a weight: rows are the groups, and the
                stored parameters are reshaped to ``[-1, 1, ...]`` so they
                broadcast against ``x``.
        """
        device = x.device
        self.maxq = self.maxq.to(device)

        orig_shape = x.shape
        rank = len(orig_shape)

        # Lay x out as (groups, values): one row per parameter set.
        if not self.perchannel:
            x = x.flatten().unsqueeze(0)
        elif weight:
            x = x.flatten(1)
        elif rank == 4:
            x = x.permute([1, 0, 2, 3]).flatten(1)
        elif rank == 3:
            x = x.reshape((-1, orig_shape[-1])).t()
        elif rank == 2:
            x = x.t()

        # The fitted range always contains zero.
        zeros = torch.zeros(x.shape[0], device=device)
        xmin = torch.minimum(x.min(1).values, zeros)
        xmax = torch.maximum(x.max(1).values, zeros)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            negative = xmin < 0
            if torch.any(negative):
                xmin[negative] = -xmax[negative]
        # Rows that are entirely zero get the range [-1, 1] so scale is not 0.
        all_zero = (xmin == 0) & (xmax == 0)
        xmin[all_zero] = -1
        xmax[all_zero] = +1

        if self.maxq < 0:
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            if self.sym:
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)  # type: ignore[arg-type]
            else:
                self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            # Grid search: shrink the range step by step and keep, per row,
            # the parameters with the smallest summed |error| ** norm.
            best = torch.full([x.shape[0]], float("inf"), device=device)
            for step in range(int(self.maxshrink * self.grid)):
                shrink = 1 - step / self.grid
                xmin1 = shrink * xmin
                xmax1 = shrink * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = self.zero if self.sym else torch.round(-xmin1 / scale1)
                q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
                q.sub_(x).abs_().pow_(self.norm)
                err = torch.sum(q, 1)
                improved = err < best
                if torch.any(improved):
                    best[improved] = err[improved]
                    self.scale[improved] = scale1[improved]
                    self.zero[improved] = zero1[improved]

        if not self.perchannel:
            # A single parameter set was fitted; replicate it once per channel.
            if weight:
                repeats = orig_shape[0]
            else:
                repeats = orig_shape[2] if rank == 3 else orig_shape[1]
            assert isinstance(repeats, int)
            self.scale = self.scale.repeat(repeats)
            self.zero = self.zero.repeat(repeats)

        if weight:
            broadcast_shape = [-1] + [1] * (rank - 1)
            self.scale = self.scale.reshape(broadcast_shape)
            self.zero = self.zero.reshape(broadcast_shape)
            return

        if rank == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        elif rank == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        elif rank == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        """Quantize ``x`` with the stored parameters; return ``x`` unchanged if not ready."""
        if not self.ready():
            return x
        return quantize(x, self.scale, self.zero, self.maxq)

    def enabled(self):
        """Return the result of ``self.maxq > 0``."""
        return self.maxq > 0

    def ready(self):
        """Return ``torch.all(self.scale != 0)``, i.e. whether every scale is non-zero."""
        return torch.all(self.scale != 0)