tico-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
tico/experimental/quantization/algorithm/pt2e/quantizer.py
@@ -0,0 +1,78 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Dict, Optional
+
+ import torch
+
+ from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+ from tico.experimental.quantization.algorithm.pt2e.annotation.annotator import (
+     get_asymmetric_quantization_config,
+     PT2EAnnotator,
+ )
+ from tico.experimental.quantization.quantizer import BaseQuantizer
+
+
+ class PT2EQuantizer(BaseQuantizer):
+     """
+     Quantizer for applying PyTorch 2.0 Export (pt2e) quantization (typically for activation quantization).
+     """
+
+     def prepare(
+         self,
+         model: torch.nn.Module,
+         args: Optional[Any] = None,
+         kwargs: Optional[Dict[str, Any]] = None,
+     ):
+         """
+         Prepare the model for pt2e quantization.
+
+         Registers activation observers using the provided example inputs.
+
+         Parameters:
+             model: The target PyTorch model.
+             args: Positional example inputs required for capturing the graph.
+             kwargs: Keyword example inputs required for capturing the graph.
+
+         Returns:
+             The model prepared for pt2e quantization.
+         """
+         # Program capture
+         assert isinstance(args, tuple)
+         model = torch.export.export_for_training(
+             model, args=args, kwargs=kwargs
+         ).module()
+         quantizer = PT2EAnnotator()
+         quantizer = quantizer.set_global(get_asymmetric_quantization_config())
+
+         # Register observers on each node
+         assert isinstance(model, torch.fx.GraphModule)
+         model = prepare_pt2e(model, quantizer)
+
+         return model
+
+     def convert(self, model: torch.fx.GraphModule):
+         """
+         Convert the prepared model to its pt2e-quantized version.
+
+         Applies pt2e quantization to activations based on the collected statistics.
+
+         Parameters:
+             model: The prepared PyTorch model.
+
+         Returns:
+             The quantized model.
+         """
+         return convert_pt2e(model)
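For context, a minimal usage sketch of the prepare -> calibrate -> convert flow above (not part of the wheel). The constructor argument is an assumption: BaseQuantizer's signature is not shown in this diff, and the toy model and example inputs are hypothetical.

# Hypothetical usage sketch, assuming BaseQuantizer.__init__ accepts a config object.
import torch
from tico.experimental.quantization.config import PT2EConfig
from tico.experimental.quantization.algorithm.pt2e.quantizer import PT2EQuantizer

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())  # assumed toy model
example_args = (torch.randn(1, 8),)

quantizer = PT2EQuantizer(PT2EConfig())                # assumption: config passed to BaseQuantizer
prepared = quantizer.prepare(model, args=example_args)

# Calibration: run representative inputs so the registered observers collect statistics.
for _ in range(8):
    prepared(torch.randn(1, 8))

quantized = quantizer.convert(prepared)                # folds observed ranges into quant/dequant ops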
tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py
@@ -0,0 +1,58 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
+
+
+ def convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+     """
+     Convert scalar values in the graph to `get_attr` nodes.
+
+     This function identifies scalar constants in the graph and transforms them
+     into `get_attr` nodes to ensure compatibility with quantization workflows.
+     """
+     for n in model.graph.nodes:
+         if n.op != "call_function" or n.target not in [
+             # The operators that have scalar parameters.
+             torch.ops.aten.add.Tensor,
+         ]:
+             continue
+         args = list(n.args)
+         new_args = []
+         for arg in args:
+             if isinstance(arg, torch.fx.Node):
+                 new_args.append(arg)
+                 continue
+
+             assert isinstance(arg, float)
+             prefix = "_tensor_constant_"
+             get_new_attr_name = get_new_attr_name_with_prefix(prefix)
+             tensor_constant_name = get_new_attr_name(model)
+             float_tensor = torch.tensor(float(arg))
+             model.register_buffer(tensor_constant_name, float_tensor)
+
+             fake_mode = n.meta["val"].fake_mode
+             with model.graph.inserting_before(n):
+                 get_attr_node = model.graph.create_node(
+                     "get_attr", tensor_constant_name, (), {}
+                 )
+                 get_attr_node.meta["val"] = fake_mode.from_tensor(
+                     float_tensor, static_shapes=True
+                 )
+                 new_args.append(get_attr_node)
+         n.args = tuple(new_args)
+     model.recompile()
+
+     return model
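A brief sketch of what the pass above does to an exported graph (not part of the package): a float operand of `aten.add.Tensor` is moved into a registered buffer and referenced through a `get_attr` node. The module path is inferred from the file list; the toy module is an assumption.

# Hypothetical sketch: turn the float constant in `x + 1.0` into a get_attr node.
import torch
from tico.experimental.quantization.algorithm.pt2e.transformation.convert_scalars_to_attrs import (
    convert_scalars_to_attrs,
)

class AddScalar(torch.nn.Module):
    def forward(self, x):
        return x + 1.0  # scalar operand of aten.add.Tensor

gm = torch.export.export_for_training(AddScalar(), (torch.randn(2, 4),)).module()
gm = convert_scalars_to_attrs(gm)  # the 1.0 now lives in a `_tensor_constant_*` buffer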
tico/experimental/quantization/algorithm/pt2e/utils.py
@@ -0,0 +1,138 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Callable, List, Optional, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch.ao.quantization.quantizer import QuantizationSpec
+ from torch.ao.quantization.quantizer.utils import _get_module_name_filter
+ from torch.utils import _pytree as pytree
+
+ from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
+     QuantizationConfig,
+ )
+
+
+ def get_module_type_filter(tp: Callable):
+     """
+     Get the module_type_filter function for a given module type.
+
+     The filter accepts a node and checks if the node comes from a module
+     that has a certain module type.
+
+     For example:
+         node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear
+
+
+     >> module_type_filter = get_module_type_filter(Sub)  # submodule with type `Sub`, under the `Block` submodule
+     >> print(module_type_filter(node))
+     True  # the node is from the submodule `Sub`
+     """
+
+     tp_str = tp.__module__ + "." + tp.__qualname__
+
+     def module_type_filter(n: torch.fx.Node) -> bool:
+         # example: {
+         #     'L__self___sub': ("L['self'].sub", <class '....Sub'>),
+         #     'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
+         # }
+         nn_module_stack = n.meta.get("nn_module_stack", {})
+         types = []
+         for _, t in nn_module_stack.values():
+             # export() returns str, but older APIs (e.g. capture_pre_autograd_graph)
+             # return type. Handle both cases.
+             if isinstance(t, type):
+                 t = t.__module__ + "." + t.__qualname__
+             types.append(t)
+         return tp_str in types
+
+     return module_type_filter
+
+
+ def get_not_module_type_or_name_filter(
+     tp_list: List[Callable], module_name_list: List[str]
+ ) -> Callable[[torch.fx.Node], bool]:
+     module_type_filters = [get_module_type_filter(tp) for tp in tp_list]
+     module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]
+
+     def not_module_type_or_name_filter(n: torch.fx.Node) -> bool:
+         return not any(f(n) for f in module_type_filters + module_name_list_filters)
+
+     return not_module_type_or_name_filter
+
+
+ def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]):
+     if quantization_config is None:
+         return None
+     if quantization_config.input_activation is None:
+         return None
+     quantization_spec: QuantizationSpec = quantization_config.input_activation
+     assert quantization_spec.qscheme in [
+         torch.per_tensor_affine,
+     ]
+     return quantization_spec
+
+
+ def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]):
+     if quantization_config is None:
+         return None
+     if quantization_config.output_activation is None:
+         return None
+     quantization_spec: QuantizationSpec = quantization_config.output_activation
+     assert quantization_spec.qscheme in [
+         torch.per_tensor_affine,
+     ]
+     return quantization_spec
+
+
+ def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
+     if quantization_config is None:
+         return None
+     if quantization_config.weight is None:
+         return None
+     quantization_spec: QuantizationSpec = quantization_config.weight
+     if quantization_spec.qscheme not in [
+         torch.per_tensor_affine,
+         torch.per_channel_affine,
+     ]:
+         raise ValueError(
+             f"Unsupported quantization_spec {quantization_spec} for weight"
+         )
+     return quantization_spec
+
+
+ def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):
+     if quantization_config is None:
+         return None
+     if quantization_config.bias is None:
+         return None
+     quantization_spec: QuantizationSpec = quantization_config.bias
+     return quantization_spec
+
+
+ def is_annotated(nodes: List[torch.fx.Node] | torch.fx.Node):
+     """
+     Check if any of the nodes in the given list is annotated.
+     """
+     annotated = False
+     if isinstance(nodes, torch.fx.Node):
+         nodes = [nodes]
+     for node in nodes:
+         annotated = annotated or (
+             "quantization_annotation" in node.meta
+             and node.meta["quantization_annotation"]._annotated
+         )
+     return annotated
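A short sketch of how the module-type filter above can be used on an exported graph (not part of the package). The import path is inferred from the file list, and the tiny module is an assumption; it relies on `nn_module_stack` metadata that torch.export records for each node.

# Hypothetical sketch: keep only nodes that originate from nn.Linear submodules.
import torch
from tico.experimental.quantization.algorithm.pt2e.utils import get_module_type_filter

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

gm = torch.export.export_for_training(Tiny(), (torch.randn(1, 4),)).module()
linear_filter = get_module_type_filter(torch.nn.Linear)
linear_nodes = [n for n in gm.graph.nodes if linear_filter(n)]  # nodes traced from self.linear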
tico/experimental/quantization/algorithm/smoothquant/__init__.py
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
tico/experimental/quantization/algorithm/smoothquant/observer.py
@@ -0,0 +1,78 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import functools
+ from typing import Any, Dict, List
+
+ import torch
+
+
+ class ChannelwiseMaxActsObserver:
+     """
+     Observer to calculate the channel-wise maximum activation.
+     """
+
+     def __init__(self, model):
+         """
+         model
+             A torch module whose activations are to be analyzed.
+         hooks
+             A list to store the hooks which are registered to collect activation statistics.
+         max_acts
+             A dictionary to store the maximum activation values.
+         """
+         self.model = model
+         self.hooks: List[Any] = []
+         self.max_acts: Dict[str, torch.Tensor] = {}
+
+     def attach(self):
+         """
+         Attach hooks to compute the maximum activation values per channel by running the given model
+         on a dataset.
+
+         WHAT IT DOES:
+         Set hooks to collect activation values at the per-channel level.
+         For each channel, it will calculate the maximum observed activation across
+         all processed samples.
+         """
+         self.model.eval()
+
+         def stat_tensor(name, tensor: torch.Tensor):
+             hidden_dim = tensor.shape[-1]
+             tensor = tensor.view(-1, hidden_dim).abs().detach()
+             coming_max = torch.max(tensor, dim=0)[0]
+             if name in self.max_acts:
+                 self.max_acts[name] = torch.max(self.max_acts[name], coming_max)
+             else:
+                 self.max_acts[name] = coming_max
+
+         def stat_input_hook(m, input, name):
+             if isinstance(input, tuple):
+                 input = input[0]
+             stat_tensor(name, input)
+
+         for name, m in self.model.named_modules():
+             if isinstance(m, torch.nn.Linear):
+                 self.hooks.append(
+                     m.register_forward_pre_hook(
+                         functools.partial(stat_input_hook, name=name)
+                     )
+                 )
+
+     def remove(self):
+         for hook in self.hooks:
+             hook.remove()
+
+     def get_max_acts(self):
+         return self.max_acts
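A minimal calibration sketch for the observer above (not part of the package); the toy model and the number of calibration batches are assumptions.

# Hypothetical sketch: collect per-channel max activations for every Linear module.
import torch
from tico.experimental.quantization.algorithm.smoothquant.observer import ChannelwiseMaxActsObserver

model = torch.nn.Sequential(torch.nn.Linear(16, 16))  # assumed toy model
observer = ChannelwiseMaxActsObserver(model)
observer.attach()                                     # registers forward pre-hooks on Linear layers

with torch.no_grad():
    for _ in range(4):                                # representative calibration batches
        model(torch.randn(2, 16))

observer.remove()
max_acts = observer.get_max_acts()                    # e.g. {"0": tensor of shape (16,)}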
tico/experimental/quantization/algorithm/smoothquant/quantizer.py
@@ -0,0 +1,81 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Dict, Optional
+
+ import torch
+
+ from tico.experimental.quantization.algorithm.smoothquant.observer import (
+     ChannelwiseMaxActsObserver,
+ )
+
+ from tico.experimental.quantization.algorithm.smoothquant.smooth_quant import (
+     apply_smoothing,
+ )
+ from tico.experimental.quantization.config import SmoothQuantConfig
+ from tico.experimental.quantization.quantizer import BaseQuantizer
+
+
+ class SmoothQuantQuantizer(BaseQuantizer):
+     """
+     Quantizer for applying the SmoothQuant algorithm.
+     """
+
+     def __init__(self, config: SmoothQuantConfig):
+         super().__init__(config)
+
+         self.alpha = config.alpha
+         self.custom_alpha_map = config.custom_alpha_map
+         self.observer: Optional[ChannelwiseMaxActsObserver] = None
+
+     @torch.no_grad()
+     def prepare(
+         self,
+         model: torch.nn.Module,
+         args: Optional[Any] = None,
+         kwargs: Optional[Dict[str, Any]] = None,
+     ):
+         """
+         Parameters:
+             model: The target PyTorch model.
+             args: Positional example inputs required for capturing the graph.
+             kwargs: Keyword example inputs required for capturing the graph.
+
+         Returns:
+             The model prepared for SmoothQuant quantization.
+         """
+         self.observer = ChannelwiseMaxActsObserver(model)
+         self.observer.attach()
+
+         return model
+
+     @torch.no_grad()
+     def convert(self, model):
+         """
+         Convert the prepared model to its SmoothQuant-quantized version.
+         Applies the SmoothQuant smoothing to weights based on the collected statistics.
+
+         Parameters:
+             model: The prepared PyTorch model.
+
+         Returns:
+             The quantized model.
+         """
+         if self.observer is not None:
+             self.observer.remove()
+             apply_smoothing(
+                 model, self.observer.get_max_acts(), self.alpha, self.custom_alpha_map
+             )
+
+         return model
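An end-to-end sketch of the quantizer above (not part of the package). `apply_smoothing` only handles Llama decoder layers in this version, so the example assumes a Hugging Face Llama checkpoint; the model name and calibration text are placeholders.

# Hypothetical sketch, assuming a Llama checkpoint is available locally or via the Hub.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tico.experimental.quantization.config import SmoothQuantConfig
from tico.experimental.quantization.algorithm.smoothquant.quantizer import SmoothQuantQuantizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

quantizer = SmoothQuantQuantizer(SmoothQuantConfig(alpha=0.5))
model = quantizer.prepare(model)                     # attaches the max-activation observer

with torch.no_grad():                                # calibration pass
    ids = tokenizer("example calibration text", return_tensors="pt").input_ids
    model(ids)

model = quantizer.convert(model)                     # folds the scales into LayerNorm/Linear weights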
tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py
@@ -0,0 +1,164 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Dict, List, Optional
+
+ import torch
+
+
+ @torch.no_grad()
+ def smooth_weights(
+     front_module: torch.nn.Module,
+     back_modules: torch.nn.Module | List[torch.nn.Module],
+     activation_max: torch.Tensor,
+     alpha: float,
+ ):
+     """
+     Applies SmoothQuant-style smoothing to the weights and biases of two
+     connected modules using activation maximum values.
+
+     NOTE All modules **MUST** have `weight` and optionally `bias` attributes.
+
+     Parameters
+     -----------
+     front_module
+         The front module whose weights and biases will be adjusted.
+     back_modules
+         A list of back modules whose weights and biases will be adjusted.
+     activation_max
+         A tensor of channel-wise maximum activation values for the front module.
+     alpha
+         The smoothing factor that determines the scaling for weight adjustments.
+
+     Raises
+     -------
+     AttributeError
+         If `front_module` or any module in `back_modules` does not have a `weight` attribute.
+     ValueError
+         If the shape of tensors in `activation_max` does not match the number of channels
+         in `front_module`'s weight.
+     NotImplementedError
+         If `front_module` or any module in `back_modules` is of an unsupported type.
+     """
+     from transformers.models.llama.modeling_llama import LlamaRMSNorm
+
+     if not isinstance(back_modules, list):
+         back_modules = [back_modules]
+
+     # Check attributes
+     if not hasattr(front_module, "weight"):
+         raise AttributeError(
+             f"The front module '{type(front_module).__name__}' does not have a 'weight' attribute."
+         )
+     for back_m in back_modules:
+         if not hasattr(back_m, "weight"):
+             raise AttributeError(
+                 f"The back module '{type(back_m).__name__}' does not have a 'weight' attribute."
+             )
+     # Check shapes
+     if isinstance(front_module, LlamaRMSNorm):
+         front_numel = front_module.weight.numel()
+     else:
+         raise NotImplementedError(
+             f"Unsupported module type: {type(front_module).__name__}"
+         )
+     for back_m in back_modules:
+         if isinstance(back_m, torch.nn.Linear):
+             back_numel = back_m.in_features
+         else:
+             raise NotImplementedError(
+                 f"Unsupported module type: {type(back_m).__name__}"
+             )
+
+     if front_numel != back_numel or back_numel != activation_max.numel():
+         raise ValueError(
+             f"Shape mismatch: front_numel({front_numel}), back_numel({back_numel}), activation_max_numel({activation_max.numel()})"
+         )
+
+     # Compute scales
+     device, dtype = back_modules[0].weight.device, back_modules[0].weight.dtype
+     activation_max = activation_max.to(device=device, dtype=dtype)  # type: ignore[arg-type]
+     weight_scales = torch.cat(
+         [back_m.weight.abs().max(dim=0, keepdim=True)[0] for back_m in back_modules],  # type: ignore[operator]
+         dim=0,
+     )
+     weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
+     scales = (
+         (activation_max.pow(alpha) / weight_scales.pow(1 - alpha))
+         .clamp(min=1e-5)
+         .to(device)  # type: ignore[arg-type]
+         .to(dtype)  # type: ignore[arg-type]
+     )
+
+     # Smooth
+     front_module.weight.div_(scales)
+     if hasattr(front_module, "bias"):
+         front_module.bias.div_(scales)
+
+     for back_m in back_modules:
+         back_m.weight.mul_(scales.view(1, -1))  # type: ignore[operator]
+
+
+ @torch.no_grad()
+ def apply_smoothing(
+     model: torch.nn.Module,
+     activation_max: Dict[str, torch.Tensor],
+     alpha: float = 0.5,
+     custom_alpha_map: Optional[Dict[str, float]] = None,
+ ):
+     """
+     Applies SmoothQuant-style smoothing to the model's weights using activation maximum values.
+
+     Parameters
+     -----------
+     model
+         A torch module whose weights will be smoothed.
+     activation_max
+         The channel-wise maximum activation values for the model.
+     alpha
+         The default smoothing factor to apply across all modules.
+     custom_alpha_map
+         A dictionary mapping layer/module names to custom alpha values.
+         Layers specified in this dictionary will use the corresponding alpha
+         value instead of the default.
+     """
+     from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+
+     for name, module in model.named_modules():
+         alpha_to_apply = alpha
+         if custom_alpha_map and name in custom_alpha_map:
+             alpha_to_apply = custom_alpha_map[name]
+         if alpha_to_apply > 1.0:
+             raise RuntimeError(
+                 f"Alpha value cannot exceed 1.0. Given alpha: {alpha_to_apply}"
+             )
+         # SmoothQuant is applied before capturing the graph. Therefore, it needs to know
+         # specific module information.
+         # TODO Support more modules.
+         if isinstance(module, LlamaDecoderLayer):
+             attn_ln = module.input_layernorm
+             qkv = [
+                 module.self_attn.q_proj,
+                 module.self_attn.k_proj,
+                 module.self_attn.v_proj,
+             ]
+
+             qkv_input_scales = activation_max[name + ".self_attn.q_proj"]
+             smooth_weights(attn_ln, qkv, qkv_input_scales, alpha_to_apply)
+
+             ffn_ln = module.post_attention_layernorm
+             fcs = [module.mlp.gate_proj, module.mlp.up_proj]
+             fcs_input_scales = activation_max[name + ".mlp.gate_proj"]
+
+             smooth_weights(ffn_ln, fcs, fcs_input_scales, alpha_to_apply)
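For intuition, the per-channel scale computed in `smooth_weights` is s_j = max|X_j|^alpha / max|W_j|^(1 - alpha); the front module's weight is divided by s and each back module's input channels are multiplied by s. A stand-alone sketch of that arithmetic with assumed toy numbers:

import torch

activation_max = torch.tensor([8.0, 0.5])   # per-channel max |activation| (assumed values)
weight_scales = torch.tensor([0.25, 1.0])   # per-channel max |weight| over the back modules
alpha = 0.5
scales = (activation_max.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5)
# scales == tensor([5.6569, 0.7071]); channels with large activations and small weights get
# large scales, shifting quantization difficulty from activations onto the weights.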
tico/experimental/quantization/config.py
@@ -0,0 +1,68 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import ABC, abstractmethod
+ from typing import Dict, Optional
+
+
+ class BaseConfig(ABC):
+     """
+     Base configuration class for quantization.
+     """
+
+     @property
+     @abstractmethod
+     def name(self) -> str:
+         pass
+
+
+ class PT2EConfig(BaseConfig):
+     """
+     Configuration for PyTorch 2.0 Export (pt2e) quantization.
+     """
+
+     @property
+     def name(self) -> str:
+         return "pt2e"
+
+
+ class GPTQConfig(BaseConfig):
+     """
+     Configuration for GPTQ.
+     """
+
+     def __init__(self, verbose: bool = False):
+         self.verbose = verbose
+
+     @property
+     def name(self) -> str:
+         return "gptq"
+
+
+ class SmoothQuantConfig(BaseConfig):
+     """
+     Configuration for SmoothQuant.
+     """
+
+     def __init__(
+         self,
+         alpha: float = 0.5,
+         custom_alpha_map: Optional[Dict[str, float]] = None,
+     ):
+         self.alpha = alpha
+         self.custom_alpha_map = custom_alpha_map
+
+     @property
+     def name(self) -> str:
+         return "smooth_quant"
tico/experimental/quantization/evaluation/__init__.py
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE