tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
tico/experimental/quantization/passes/propagate_qparam_forward.py
@@ -0,0 +1,141 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import copy
+
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.serialize.quant_param import QPARAM_KEY
+ from tico.utils import logging
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.validate_args_kwargs import (
+     CatArgs,
+     NegArgs,
+     PermuteArgs,
+     ReshapeArgs,
+     SliceArgs,
+ )
+
+
+ @trace_graph_diff_on_pass
+ class PropagateQParamForward(PassBase):
+     """
+     A pass that propagates quantization parameters through operations that do
+     not alter them.
+
+     This pass identifies and propagates quantization parameters through
+     operations that do not change their values, such as `permute`, `reshape`,
+     `transpose`, `view`, and similar tensor transformations.
+
+     By ensuring that quantization parameters remain consistent across such
+     operations, this pass helps maintain correctness in quantization-aware
+     representations.
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node):
+             if QPARAM_KEY not in src.meta:
+                 return
+
+             if (
+                 QPARAM_KEY in dst.meta
+                 and src.meta[QPARAM_KEY].dtype != dst.meta[QPARAM_KEY].dtype
+             ):
+                 return
+
+             dst.meta[QPARAM_KEY] = copy.deepcopy(src.meta[QPARAM_KEY])
+
+             logger.debug(f"{src.name}'s quantparam is propagated to {dst.name}.")
+
+         graph_module = exported_program.graph_module
+         graph: torch.fx.Graph = graph_module.graph
+         for node in graph.nodes:
+             if node.op != "call_function":
+                 continue
+             if node.target == torch.ops.aten.permute.default:
+                 permute_args = PermuteArgs(*node.args, **node.kwargs)
+                 _propagate_qparam_if_possible(permute_args.input, node)
+             elif node.target == torch.ops.aten.reshape.default:
+                 reshape_args = ReshapeArgs(*node.args, **node.kwargs)
+                 _propagate_qparam_if_possible(reshape_args.input, node)
+             elif node.target == torch.ops.aten.slice.Tensor:
+                 slice_args = SliceArgs(*node.args, **node.kwargs)
+                 _propagate_qparam_if_possible(slice_args.input, node)
+             elif node.target == torch.ops.aten.neg.default:
+                 neg_args = NegArgs(*node.args, **node.kwargs)
+
+                 if QPARAM_KEY not in neg_args.input.meta:
+                     continue
+                 # Only support int16 for now
+                 if neg_args.input.meta[QPARAM_KEY].dtype != "int16":
+                     continue
+
+                 _propagate_qparam_if_possible(neg_args.input, node)
+
+             elif node.target == torch.ops.aten.cat.default:
+                 concat_args = CatArgs(*node.args, **node.kwargs)
+                 concat_inputs = concat_args.tensors
+
+                 cond = True
+                 for concat_input in concat_inputs:
+                     # Check all inputs have qparam
+                     if QPARAM_KEY not in concat_input.meta:
+                         cond = False
+                         break
+
+                     # Only support int16 for now
+                     if concat_input.meta[QPARAM_KEY].dtype != "int16":
+                         cond = False
+                         break
+
+                     if concat_input.meta[QPARAM_KEY].scale is None:
+                         cond = False
+                         break
+
+                     if len(concat_input.meta[QPARAM_KEY].scale) != 1:
+                         cond = False
+                         break
+
+                 if not cond:
+                     continue
+
+                 # Find the input node with the largest scale
+                 max_scale = 0.0
+                 max_scale_node = None
+                 for concat_input in concat_inputs:
+                     scale = concat_input.meta[QPARAM_KEY].scale[0]
+                     if max_scale < scale:
+                         max_scale = scale
+                         max_scale_node = concat_input
+
+                 assert max_scale_node is not None
+                 _propagate_qparam_if_possible(max_scale_node, node)
+
+         # TODO Support more ops.
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         # Run only once.
+         return PassResult(False)
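
To make the propagation concrete, here is a minimal editor's sketch of driving the pass by hand. It is not part of the package: it assumes `QuantParam` is default-constructible with settable `scale`, `zero_point`, and `dtype` attributes (as the `RemoveWeightDequantOp` diff below uses it), and that `@trace_graph_diff_on_pass` preserves normal instantiation.

import torch
from torch.export import export

from tico.experimental.quantization.passes.propagate_qparam_forward import (
    PropagateQParamForward,
)
from tico.serialize.quant_param import QPARAM_KEY, QuantParam


class M(torch.nn.Module):
    def forward(self, x):
        return torch.permute(x, (1, 0))


ep = export(M(), (torch.randn(4, 8),))

# Tag the graph input with a qparam, as an upstream quantization step would.
qparam = QuantParam()
qparam.scale = [0.02]
qparam.zero_point = [0]
qparam.dtype = "int16"
x_node = next(n for n in ep.graph_module.graph.nodes if n.op == "placeholder")
x_node.meta[QPARAM_KEY] = qparam

PropagateQParamForward().call(ep)

# The permute node now carries a deep copy of the input's qparam.
permute_node = next(
    n
    for n in ep.graph_module.graph.nodes
    if n.target == torch.ops.aten.permute.default
)
print(permute_node.meta[QPARAM_KEY].scale)  # [0.02]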
tico/experimental/quantization/passes/remove_weight_dequant_op.py
@@ -0,0 +1,168 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List, Optional, TYPE_CHECKING, Union
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch._subclasses.fake_tensor import FakeTensor
+ from torch.export import ExportedProgram
+
+ from tico.serialize.quant_param import QPARAM_KEY, QuantParam, to_qparam_dtype
+ from tico.utils import logging
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.validate_args_kwargs import (
+     DequantizePerChannelArgs,
+     DequantizePerTensorArgs,
+ )
+
+
+ def get_constant(exported_program: ExportedProgram, node: torch.fx.Node):
+     assert isinstance(node, torch.fx.Node)
+     if node.name in exported_program.constants:
+         return exported_program.constants[node.name]
+     elif node.name in exported_program.graph_signature.inputs_to_buffers:
+         buffer_name = exported_program.graph_signature.inputs_to_buffers[node.name]
+         named_buffer = dict(exported_program.named_buffers())
+         return named_buffer[buffer_name]
+     else:
+         raise RuntimeError("NYI constant")
+
+
+ class ValRange:
+     def __init__(self, val: Union[torch.Tensor, List[int]]):
+         if isinstance(val, torch.Tensor):
+             self.max = torch.max(val).item()
+             self.min = torch.min(val).item()
+         elif isinstance(val, list):
+             self.max = max(val)
+             self.min = min(val)
+         else:
+             raise RuntimeError("Wrong dtype (val)")
+
+     def within(self, min_val, max_val):
+         return self.min >= min_val and self.max <= max_val
+
+
+ # Infer dtype using weight, zero point, and dtype
+ def infer_dtype(weight: torch.Tensor, zerop: List[int], dtype: torch.dtype) -> str:
+     weight_val = ValRange(weight)
+     zp_val = ValRange(zerop)
+
+     if weight_val.within(0, 15) and zp_val.within(0, 15) and dtype == torch.uint8:
+         return "uint4"
+     else:
+         return to_qparam_dtype(dtype)
+
+
+ @trace_graph_diff_on_pass
+ class RemoveWeightDequantOp(PassBase):
+     """
+     This pass identifies and removes any remaining Dequantize ops associated with
+     quantized weights.
+
+     Since weights were already quantized earlier (and possibly kept in float by
+     attaching a DQ op), the final stage of the quantization pipeline typically
+     does not require those DQ ops anymore.
+
+     NOTE Removing 'DQ' causes a semantic change: f32 -> quantized
+
+     [BEFORE]
+       W (quantized) - Dequantize (float)
+
+     [AFTER]
+       W (quantized)
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph: torch.fx.Graph = graph_module.graph
+         for dq in graph.nodes:
+             if not dq.op == "call_function":
+                 continue
+
+             if dq.target not in [
+                 torch.ops.quantized_decomposed.dequantize_per_channel.default,
+                 torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+             ]:
+                 continue
+             dq_args: Optional[DequantizePerChannelArgs | DequantizePerTensorArgs] = None
+
+             if (
+                 dq.target
+                 == torch.ops.quantized_decomposed.dequantize_per_channel.default
+             ):
+                 dq_args = DequantizePerChannelArgs(*dq.args, **dq.kwargs)
+             elif (
+                 dq.target
+                 == torch.ops.quantized_decomposed.dequantize_per_tensor.default
+             ):
+                 dq_args = DequantizePerTensorArgs(*dq.args, **dq.kwargs)
+             else:
+                 raise RuntimeError(f"Invalid DQ target: {dq.target}")
+
+             q_weight = dq_args.input
+             # All weights are placeholders.
+             if q_weight.op != "placeholder":
+                 continue
+             # Skip if the weight already has a quant param; a weight can be
+             # shared by multiple DQ ops.
+             if QPARAM_KEY in q_weight.meta:
+                 continue
+
+             q_weight_meta = q_weight.meta["val"]
+             assert isinstance(q_weight_meta, FakeTensor)
+             # Weight should have quantized values.
+             assert q_weight_meta.dtype != torch.float
+
+             q_weight_val = get_constant(exported_program, q_weight)
+             assert isinstance(q_weight_val, torch.Tensor)
+
+             quant_param = QuantParam()
+             if isinstance(dq_args, DequantizePerChannelArgs):
+                 scales = get_constant(exported_program, dq_args.scales)
+                 zero_ps = get_constant(exported_program, dq_args.zero_points)
+                 quant_param.scale = scales.tolist()
+                 quant_param.zero_point = zero_ps.tolist()
+                 assert quant_param.zero_point is not None  # To avoid mypy error
+                 quant_param.quantized_dimension = dq_args.axis
+                 quant_param.dtype = infer_dtype(
+                     q_weight_val, quant_param.zero_point, q_weight_meta.dtype
+                 )
+             elif isinstance(dq_args, DequantizePerTensorArgs):
+                 quant_param.scale = [dq_args.scale]
+                 quant_param.zero_point = [dq_args.zero_point]
+                 assert quant_param.zero_point is not None  # To avoid mypy error
+                 quant_param.dtype = infer_dtype(
+                     q_weight_val, quant_param.zero_point, q_weight_meta.dtype
+                 )
+             else:
+                 raise RuntimeError(f"Invalid DQ target: {dq.target}")
+
+             q_weight.meta[QPARAM_KEY] = quant_param
+             dq.replace_all_uses_with(q_weight, propagate_meta=False)
+             logger.debug(f"{dq.name} is removed.")
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         # Run only once.
+         return PassResult(False)
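
A quick sanity check (editor's sketch, not in the package) of `infer_dtype`'s uint4 detection: a `torch.uint8` weight whose values and zero points all fall within [0, 15] is reported as "uint4"; anything else falls through to `to_qparam_dtype`.

import torch

from tico.experimental.quantization.passes.remove_weight_dequant_op import infer_dtype

# Values and zero point fit in [0, 15], stored as uint8 -> detected as uint4.
w4 = torch.tensor([[0, 7, 15], [3, 1, 12]], dtype=torch.uint8)
print(infer_dtype(w4, zerop=[8], dtype=torch.uint8))    # "uint4"

# A value outside [0, 15] falls back to the storage dtype's qparam name.
w8 = torch.tensor([[0, 7, 200]], dtype=torch.uint8)
print(infer_dtype(w8, zerop=[128], dtype=torch.uint8))  # presumably "uint8"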
tico/experimental/quantization/public_interface.py
@@ -0,0 +1,108 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import copy
+ from typing import Any, Dict, Optional, Type
+
+ import torch
+
+ from tico.experimental.quantization.algorithm.gptq.quantizer import GPTQQuantizer
+ from tico.experimental.quantization.algorithm.pt2e.quantizer import PT2EQuantizer
+ from tico.experimental.quantization.algorithm.smoothquant.quantizer import (
+     SmoothQuantQuantizer,
+ )
+ from tico.experimental.quantization.config import BaseConfig
+ from tico.experimental.quantization.quantizer import BaseQuantizer
+
+
+ config_to_quantizer: Dict[str, Type[BaseQuantizer]] = {
+     "pt2e": PT2EQuantizer,
+     "gptq": GPTQQuantizer,
+     "smooth_quant": SmoothQuantQuantizer,
+ }
+
+ QUANTIZER_ATTRIBUTE_NAME = "tico_quantizer"
+
+
+ def prepare(
+     model: torch.nn.Module,
+     quant_config: BaseConfig,
+     args: Optional[Any] = None,
+     kwargs: Optional[Dict[str, Any]] = None,
+     inplace: Optional[bool] = False,
+ ):
+     """
+     Prepare the model for quantization using the provided configuration.
+
+     Determines the appropriate quantizer based on the type of `quant_config` and
+     prepares the model accordingly.
+
+     Parameters:
+         model: The PyTorch model to be quantized.
+         quant_config (BaseConfig): The quantization configuration.
+         args (Any, optional): Positional example inputs required for activation quantization.
+         kwargs (Dict[str, Any], optional): Keyword example inputs required for activation quantization.
+         inplace (bool, optional): If True, the model is modified in place;
+             otherwise, a new prepared model is returned.
+
+     Returns:
+         The model prepared for quantization.
+     """
+     if quant_config.name == "pt2e" and inplace:
+         raise RuntimeError(
+             "In-place is not supported for PT2E quantization due to limitations in the underlying Torch APIs. Please set 'inplace=False' to proceed."
+         )
+
+     model = model if inplace else copy.deepcopy(model)
+
+     quantizer = config_to_quantizer[quant_config.name](quant_config)
+     model = quantizer.prepare(model, args, kwargs)
+     setattr(model, QUANTIZER_ATTRIBUTE_NAME, quantizer)
+
+     return model
+
+
+ def convert(model, inplace: Optional[bool] = False):
+     """
+     Convert the prepared model to a quantized model.
+
+     Uses the quantizer stored on the model by `prepare()` to convert the model
+     accordingly.
+
+     Parameters:
+         model: The prepared PyTorch model.
+         inplace (bool, optional): If True, the model is modified in place;
+             otherwise, a new quantized model is returned.
+
+     Returns:
+         The quantized model.
+     """
+     # Fetch the quantizer before deepcopy, which does not copy custom attributes properly.
+     if hasattr(model, QUANTIZER_ATTRIBUTE_NAME):
+         quantizer = getattr(model, QUANTIZER_ATTRIBUTE_NAME)
+         delattr(model, QUANTIZER_ATTRIBUTE_NAME)
+     else:
+         raise RuntimeError("Call prepare() function first.")
+
+     if isinstance(quantizer, PT2EQuantizer) and inplace:
+         raise RuntimeError(
+             "In-place is not supported for PT2E quantization due to limitations in the underlying Torch APIs. Please set 'inplace=False' to proceed."
+         )
+
+     model = model if inplace else copy.deepcopy(model)
+
+     assert isinstance(quantizer, BaseQuantizer)
+     model = quantizer.convert(model)
+
+     return model
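
The intended prepare -> calibrate -> convert flow, as an editor's sketch. `PT2EConfig` is a hypothetical stand-in for whichever `BaseConfig` subclass in tico/experimental/quantization/config.py carries `name == "pt2e"`; the rest follows the signatures above.

import torch

from tico.experimental.quantization.config import PT2EConfig  # hypothetical name
from tico.experimental.quantization.public_interface import convert, prepare

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
example = (torch.randn(1, 4),)

# 1. Attach observers. PT2E requires inplace=False (the default).
model = prepare(model, PT2EConfig(), args=example)

# 2. Calibrate: run representative inputs so the observers collect statistics.
for _ in range(8):
    model(torch.randn(1, 4))

# 3. Fold the collected statistics into a quantized model.
model = convert(model)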
tico/experimental/quantization/quantizer.py
@@ -0,0 +1,71 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, Optional
+
+ import torch
+
+ from tico.experimental.quantization.config import BaseConfig
+
+
+ class BaseQuantizer(ABC):
+     """
+     Abstract base class for quantizers that apply a quantization algorithm to a
+     target model.
+     """
+
+     def __init__(self, config: BaseConfig):
+         """
+         Initialize the quantizer with the given configuration.
+
+         Parameters:
+             config (BaseConfig): Quantization configuration parameters.
+         """
+         self.config = config
+
+     @abstractmethod
+     def prepare(
+         self,
+         model: torch.nn.Module,
+         args: Optional[Any] = None,
+         kwargs: Optional[Dict[str, Any]] = None,
+     ):
+         """
+         Prepare the given model for quantization based on the provided
+         algorithm-specific configuration. This involves setting up the necessary
+         observers or hooks, and may optionally use example inputs, which is
+         particularly useful for activation quantization.
+
+         Parameters:
+             model: The target PyTorch model.
+             args (Any, optional): Positional example inputs required for activation quantization.
+             kwargs (Dict[str, Any], optional): Keyword example inputs required for activation quantization.
+
+         Returns:
+             The prepared model.
+         """
+         pass
+
+     @abstractmethod
+     def convert(self, model):
+         """
+         Convert the prepared (or calibrated) model into its quantized form. This
+         function leverages the statistics collected during calibration to perform
+         the quantization transformation.
+
+         Parameters:
+             model: The prepared PyTorch model.
+
+         Returns:
+             The quantized model.
+         """
+         pass
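
For orientation, a minimal sketch of what a concrete subclass looks like. `IdentityQuantizer` is illustrative only; real quantizers in this package (see `PT2EQuantizer`, `GPTQQuantizer`) do actual work in both hooks.

from typing import Any, Dict, Optional

import torch

from tico.experimental.quantization.quantizer import BaseQuantizer


class IdentityQuantizer(BaseQuantizer):
    """A no-op quantizer: satisfies the interface but changes nothing."""

    def prepare(
        self,
        model: torch.nn.Module,
        args: Optional[Any] = None,
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        # A real implementation would insert observers or hooks here.
        return model

    def convert(self, model):
        # A real implementation would rewrite modules using calibration stats.
        return model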
tico/interpreter/__init__.py
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
tico/interpreter/infer.py
@@ -0,0 +1,116 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any
+
+ import numpy as np
+ import torch
+ from circle_schema import circle
+
+ from tico.interpreter.interpreter import Interpreter
+ from tico.serialize.circle_mapping import np_dtype_from_circle_dtype, to_circle_dtype
+
+
+ def preprocess_inputs(inputs: Any):
+     """
+     Preprocess user inputs for circle inference.
+
+     1. None inputs are ignored.
+     2. A list/tuple input is flattened, as it is when a torch module is exported.
+        e.g. inputs = (torch.Tensor, [2,3,4]) -> inputs = (torch.Tensor, 2, 3, 4)
+     """
+     flat = []
+     for value in inputs:
+         if value is None:
+             continue
+         if isinstance(value, (tuple, list)):
+             for val in value:
+                 flat.append(val)
+         else:
+             flat.append(value)
+     # Flatten nested lists/tuples recursively.
+     if any(isinstance(item, (tuple, list)) for item in flat):
+         flat = preprocess_inputs(flat)
+     return tuple(flat)
+
+
+ def infer(circle_binary: bytes, *args: Any, **kwargs: Any) -> Any:
+     # When converting a model, it is assumed that the order of keyword arguments is maintained.
+     user_inputs = args + tuple(kwargs.values())
+     user_inputs = preprocess_inputs(user_inputs)
+     # Cast everything to torch.Tensor to keep the checks below simple.
+     user_inputs = tuple(
+         user_input
+         if isinstance(user_input, torch.Tensor)
+         else torch.tensor(user_input)
+         for user_input in user_inputs
+     )
+
+     # Get the input spec from the circle binary.
+     model = circle.Model.Model.GetRootAsModel(circle_binary, 0)
+     assert model.SubgraphsLength() == 1
+     graph = model.Subgraphs(0)
+     model_input_tensors = [
+         graph.Tensors(graph.Inputs(i)) for i in range(graph.InputsLength())
+     ]
+     model_input_shapes_np = [t.ShapeAsNumpy() for t in model_input_tensors]
+     model_input_types_cm = [t.Type() for t in model_input_tensors]
+
+     # Check that the dtype and shape of the given inputs match those in the model binary.
+     if len(model_input_shapes_np) != len(user_inputs):
+         raise RuntimeError(
+             f"Mismatch input length: input({len(user_inputs)}) != circle model({len(model_input_shapes_np)})"
+         )
+     for input_idx, user_input in enumerate(user_inputs):
+         # Shape
+         if list(user_input.shape) != list(model_input_shapes_np[input_idx]):
+             raise RuntimeError(
+                 f"Mismatch input {input_idx} shape : input({user_input.shape}) != circle model({model_input_shapes_np[input_idx]})"
+             )
+         # Data type
+         user_input_type_cm = to_circle_dtype(user_input.dtype)
+         if user_input_type_cm != model_input_types_cm[input_idx]:
+             raise RuntimeError(
+                 f"Mismatch input {input_idx} data type : input({user_input_type_cm}) != circle model({model_input_types_cm[input_idx]})"
+             )
+
+     # Initialize the interpreter.
+     intp = Interpreter(circle_binary)
+
+     # Set inputs.
+     for input_idx, user_input in enumerate(user_inputs):
+         intp.writeInputTensor(input_idx, user_input)
+
+     # Interpret.
+     intp.interpret()
+
+     # Retrieve the outputs' dtype and shape from the circle model.
+     model_output_tensors = [
+         graph.Tensors(graph.Outputs(o)) for o in range(graph.OutputsLength())
+     ]
+     model_output_shapes_np = [t.ShapeAsNumpy() for t in model_output_tensors]
+     model_output_types_cm = [t.Type() for t in model_output_tensors]
+
+     output = []
+     # Get outputs.
+     for output_idx in range(len(model_output_tensors)):
+         result: np.ndarray = np.empty(
+             model_output_shapes_np[output_idx],
+             dtype=np_dtype_from_circle_dtype(model_output_types_cm[output_idx]),
+         )
+         intp.readOutputTensor(output_idx, result)
+         output.append(result)
+
+     if len(output) == 1:
+         return output[0]
+     else:
+         return output
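
Typical usage, sketched by the editor: feed `infer()` the raw bytes of a compiled circle model plus inputs matching the model's signature. The file name and input shape below are assumptions.

import torch

from tico.interpreter.infer import infer

# "model.circle" stands in for a binary produced by tico's conversion pipeline,
# here assumed to take a single 1x4 float input.
with open("model.circle", "rb") as f:
    circle_binary = f.read()

output = infer(circle_binary, torch.randn(1, 4))
print(output)  # a numpy array, or a list of arrays for multi-output models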