tico-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
tico/serialize/pack.py ADDED
@@ -0,0 +1,35 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import numpy as np
+
+
+ def pack_buffer(flat_data: np.ndarray, dtype: str) -> np.ndarray:
+     assert len(flat_data.shape) == 1
+
+     if dtype == "uint4":
+         if flat_data.dtype != np.uint8:
+             raise RuntimeError("uint4 data should be saved in uint8.")
+
+         numel = flat_data.shape[0]
+         packed = np.zeros((numel + 1) // 2, dtype=np.uint8)
+         for i in range(numel):
+             assert flat_data[i] >= 0 and flat_data[i] <= 15
+             if i % 2 == 0:
+                 packed[i // 2] = flat_data[i]
+             else:
+                 packed[i // 2] |= flat_data[i] << 4
+         return packed
+     else:
+         raise NotImplementedError(f"NYI dtype: {dtype}")
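
A minimal usage sketch (not part of the wheel) of pack_buffer's uint4 layout: element 2*i lands in the low nibble and element 2*i+1 in the high nibble of each output byte.

import numpy as np
from tico.serialize.pack import pack_buffer

# Eight 4-bit values, stored one per byte in uint8 as pack_buffer expects.
flat = np.array([1, 2, 3, 4, 5, 6, 7, 15], dtype=np.uint8)
packed = pack_buffer(flat, "uint4")
assert packed.tolist() == [0x21, 0x43, 0x65, 0xF7]  # two input values per output byte
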
tico/serialize/quant_param.py ADDED
@@ -0,0 +1,42 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ This is a key for torch.fx.Node's meta dict to save QuantParam
+
+ QuantParam can be retrieved as node.meta[QPARAM_KEY]
+ """
+ QPARAM_KEY = "_quantization_parameters_"
+
+ from dataclasses import dataclass
+ from typing import List, Optional
+
+ import torch
+
+
+ @dataclass
+ class QuantParam:
+     scale: Optional[List[float]] = None
+     zero_point: Optional[List[int]] = None
+     quantized_dimension: Optional[int] = None
+     min: Optional[List[float]] = None
+     max: Optional[List[float]] = None
+     # NOTE We define dtype as a string to easily extend new dtypes (ex: uint4)
+     dtype: str = ""
+
+
+ def to_qparam_dtype(dtype: torch.dtype) -> str:
+     str_type = str(dtype)
+     assert str_type.startswith("torch.")
+     return str_type[6:]
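
A short illustrative sketch (not from the package) of how QuantParam, QPARAM_KEY, and to_qparam_dtype fit together when annotating a torch.fx node:

import torch
from tico.serialize.quant_param import QPARAM_KEY, QuantParam, to_qparam_dtype

graph = torch.fx.Graph()
x = graph.placeholder("x")

# Per-tensor uint8 quantization parameters attached to the node's meta dict,
# which is where this package's quantization passes look them up.
x.meta[QPARAM_KEY] = QuantParam(
    scale=[0.05],
    zero_point=[128],
    dtype=to_qparam_dtype(torch.uint8),  # "uint8"
)
assert x.meta[QPARAM_KEY].dtype == "uint8"
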
tico/utils/__init__.py ADDED
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
tico/utils/convert.py ADDED
@@ -0,0 +1,296 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import operator
+ import os
+ from typing import Any, Dict, Optional, Tuple
+
+ import torch
+ from torch.export import export, ExportedProgram
+
+ from tico.config import CompileConfigBase, get_default_config
+ from tico.experimental.quantization.passes.fold_quant_ops import FoldQuantOps
+ from tico.experimental.quantization.passes.insert_quantize_on_dtype_mismatch import (
+     InsertQuantizeOnDtypeMismatch,
+ )
+ from tico.experimental.quantization.passes.propagate_qparam_backward import (
+     PropagateQParamBackward,
+ )
+ from tico.experimental.quantization.passes.propagate_qparam_forward import (
+     PropagateQParamForward,
+ )
+ from tico.experimental.quantization.passes.quantize_bias import QuantizeBias
+ from tico.experimental.quantization.passes.remove_weight_dequant_op import (
+     RemoveWeightDequantOp,
+ )
+ from tico.passes.cast_aten_where_arg_type import CastATenWhereArgType
+ from tico.passes.cast_mixed_type_args import CastMixedTypeArgs
+ from tico.passes.const_prop_pass import ConstPropPass
+ from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d
+ from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape
+ from tico.passes.convert_repeat_to_expand_copy import ConvertRepeatToExpandCopy
+ from tico.passes.convert_to_relu6 import ConvertToReLU6
+ from tico.passes.decompose_addmm import DecomposeAddmm
+ from tico.passes.decompose_batch_norm import DecomposeBatchNorm
+ from tico.passes.decompose_fake_quantize import DecomposeFakeQuantize
+ from tico.passes.decompose_fake_quantize_tensor_qparams import (
+     DecomposeFakeQuantizeTensorQParams,
+ )
+ from tico.passes.decompose_group_norm import DecomposeGroupNorm
+ from tico.passes.decompose_grouped_conv2d import DecomposeGroupedConv2d
+ from tico.passes.decompose_slice_scatter import DecomposeSliceScatter
+ from tico.passes.extract_dtype_kwargs import ExtractDtypeKwargsPass
+ from tico.passes.fill_meta_val import FillMetaVal
+ from tico.passes.fuse_leading_unsqueeze_reshape import FuseLeadingUnsqueezeReshape
+ from tico.passes.fuse_redundant_reshape_to_mean import FuseRedundantReshapeToMean
+ from tico.passes.legalize_causal_mask_value import LegalizeCausalMaskValue
+ from tico.passes.legalize_predefined_layout_operators import (
+     LegalizePreDefinedLayoutOperators,
+ )
+ from tico.passes.lower_pow2_to_mul import LowerPow2ToMul
+ from tico.passes.lower_to_resize_nearest_neighbor import LowerToResizeNearestNeighbor
+ from tico.passes.lower_to_slice import passes as LowerToSlicePasses
+ from tico.passes.merge_consecutive_cat import MergeConsecutiveCat
+ from tico.passes.remove_nop import RemoveNop
+ from tico.passes.remove_redundant_assert_nodes import RemoveRedundantAssertionNodes
+ from tico.passes.remove_redundant_expand import RemoveRedundantExpand
+ from tico.passes.remove_redundant_permute import passes as RemoveRedundantPermutePasses
+ from tico.passes.remove_redundant_reshape import passes as RemoveRedundantViewPasses
+ from tico.passes.remove_redundant_slice import RemoveRedundantSlice
+ from tico.passes.remove_redundant_to_copy import RemoveRedundantToCopy
+ from tico.passes.restore_linear import RestoreLinear
+ from tico.passes.segment_index_select import SegmentIndexSelectConst
+ from tico.serialize.circle_serializer import build_circle
+ from tico.serialize.operators.node_visitor import get_support_targets
+ from tico.utils import logging
+ from tico.utils.errors import NotYetSupportedError
+ from tico.utils.model import CircleModel
+ from tico.utils.passes import PassManager
+ from tico.utils.trace_decorators import (
+     trace_const_diff_on_func,
+     trace_graph_diff_on_func,
+ )
+ from tico.utils.utils import has_quantization_ops, SuppressWarning
+
+
+ @trace_const_diff_on_func
+ @trace_graph_diff_on_func
+ def traced_run_decompositions(exported_program: ExportedProgram):
+     """
+     Let's preserve convolution operators.
+     `run_decompositions()` converts all Conv-related Ops to generic `aten.convolution`.
+     But, we should re-convert them to specific circle ops such as CircleConv2D, TransposeConv, etc.
+     Therefore, we do not decompose Conv-related Ops and convert them directly to circle ops.
+     """
+
+     def run_decompositions_v25(ep: ExportedProgram):
+         _preserve_ops = (
+             torch.ops.aten.conv2d.default,
+             torch.ops.aten.conv2d.padding,
+             torch.ops.aten.conv1d.default,
+             torch.ops.aten.conv1d.padding,
+             torch.ops.aten.instance_norm.default,
+             torch.ops.aten._safe_softmax.default,
+             torch.ops.aten.relu6.default,  # Do not decompose to hardtanh
+             torch.ops.aten.linear.default,
+         )
+         ep = ep.run_decompositions(_preserve_ops=_preserve_ops)
+
+         return ep
+
+     def run_decompositions(ep: ExportedProgram):
+         _decomp_table = torch.export.default_decompositions()  # type: ignore[attr-defined]
+         _preserve_ops = (
+             torch.ops.aten.conv2d.default,
+             torch.ops.aten.conv2d.padding,
+             torch.ops.aten.conv1d.default,
+             torch.ops.aten.conv1d.padding,
+             torch.ops.aten.instance_norm.default,
+             torch.ops.aten._safe_softmax.default,
+             torch.ops.aten.relu6.default,  # Do not decompose to hardtanh
+             torch.ops.aten.prelu.default,
+             torch.ops.aten.linear.default,
+         )
+         for op in _preserve_ops:
+             if op in _decomp_table:
+                 del _decomp_table[op]
+
+         ep = ep.run_decompositions(decomp_table=_decomp_table)
+         return ep
+
+     if torch.__version__.startswith("2.5"):
+         return run_decompositions_v25(exported_program)
+     elif (
+         torch.__version__.startswith("2.6")
+         or torch.__version__.startswith("2.7")
+         or torch.__version__.startswith("2.8")
+     ):
+         return run_decompositions(exported_program)
+     else:
+         raise RuntimeError(f"Unsupported PyTorch version: {torch.__version__}")
+
+
+ def check_unsupported_target(exported_program: ExportedProgram):
+     logger = logging.getLogger(__name__)
+
+     supported_target = list(get_support_targets())
+     # Ignore `getitem` since it is no-op for multiple outputs.
+     supported_target.append(operator.getitem)
+     unsupported = []
+     for n in exported_program.graph.nodes:
+         if n.op != "call_function":
+             continue
+         if not n.target in supported_target:
+             unsupported.append(n)
+
+     if unsupported:
+         for node in unsupported:
+             logger.error(
+                 f"NOT SUPPORTED OPERATOR\n\t(op) {node.target.__name__}\n\t(trace) {node.meta.get('stack_trace')}"
+             )
+         raise NotYetSupportedError("NOT SUPPORTED OPERATOR IN GRAPH MODULE")
+
+
+ def convert_exported_module_to_circle(
+     exported_program: ExportedProgram,
+     config: CompileConfigBase = get_default_config(),
+ ) -> bytes:
+     logger = logging.getLogger(__name__)
+     logger.debug("Input ExportedProgram (must be core aten)")
+     logger.debug(exported_program)
+
+     # PRE-EDGE PASSES
+     #
+     # Here are the passes that run before to_edge() conversion.
+     # Let's decompose nodes that are not Aten Canonical, which can't be converted to the edge IR.
+     decompose_quantize_op = PassManager(
+         passes=[
+             DecomposeFakeQuantize(),
+             DecomposeFakeQuantizeTensorQParams(),
+         ]
+     )
+     decompose_quantize_op.run(exported_program)
+
+     # This pass should be run before 'RestoreLinear' and after 'decompose_quantize_op'.
+     # TODO run pass regardless of the orders.
+     with SuppressWarning(UserWarning, ".*quantize_per_tensor"), SuppressWarning(
+         UserWarning,
+         ".*TF32 acceleration on top of oneDNN is available for Intel GPUs.*",
+     ):
+         # Warning details:
+         # ...site-packages/torch/_subclasses/functional_tensor.py:364
+         # UserWarning: At pre-dispatch tracing, we assume that any custom op marked with
+         # CompositeImplicitAutograd and have functional schema are safe to not decompose.
+         exported_program = traced_run_decompositions(exported_program)
+
+     # TODO Distinguish legalize and optimize
+     circle_legalize = PassManager(
+         passes=[
+             FillMetaVal(),
+             ExtractDtypeKwargsPass(),
+             RemoveNop(),
+             ConvertLayoutOpToReshape(),
+             RestoreLinear(),
+             ConvertToReLU6(),
+             DecomposeAddmm(),
+             DecomposeSliceScatter(),
+             DecomposeGroupNorm(),
+             DecomposeBatchNorm(),
+             DecomposeGroupedConv2d(),
+             CastATenWhereArgType(),
+             ConvertRepeatToExpandCopy(),
+             *RemoveRedundantPermutePasses(),
+             RemoveRedundantAssertionNodes(),
+             RemoveRedundantExpand(),
+             RemoveRedundantSlice(),
+             FuseRedundantReshapeToMean(),
+             *RemoveRedundantViewPasses(),
+             RemoveRedundantToCopy(),
+             MergeConsecutiveCat(),
+             CastMixedTypeArgs(preserve_ep_invariant=True),
+             ConstPropPass(),
+             SegmentIndexSelectConst(),
+             LegalizeCausalMaskValue(enabled=config.get("legalize_causal_mask_value")),
+             LowerToResizeNearestNeighbor(),
+             LegalizePreDefinedLayoutOperators(),
+             LowerPow2ToMul(),
+             ConvertConv1dToConv2d(),
+             *LowerToSlicePasses(),
+             FuseLeadingUnsqueezeReshape(),
+         ]
+     )
+     circle_legalize.run(exported_program)
+
+     # After this stage, ExportedProgram invariant is broken, i.e.,
+     # graph can have a constant torch.tensor not lifted to a placeholder
+     circle_legalize = PassManager(
+         passes=[
+             FillMetaVal(),
+             CastMixedTypeArgs(preserve_ep_invariant=False),
+         ]
+     )
+     circle_legalize.run(exported_program)
+
+     # TODO Give an option to enable quantization to user
+     enable_quantization = has_quantization_ops(exported_program.graph)
+     if enable_quantization:
+         quantize_graph = PassManager(
+             passes=[
+                 FoldQuantOps(),
+                 RemoveWeightDequantOp(),
+                 PropagateQParamForward(),
+                 PropagateQParamBackward(),
+                 QuantizeBias(),
+                 InsertQuantizeOnDtypeMismatch(),
+             ]
+         )
+         quantize_graph.run(exported_program)
+
+     check_unsupported_target(exported_program)
+     circle_program = build_circle(exported_program)
+
+     return circle_program
+
+
+ def convert(
+     mod: torch.nn.Module,
+     args: Tuple[Any, ...],
+     kwargs: Optional[Dict[str, Any]] = None,
+     strict: bool = True,
+     config: CompileConfigBase = get_default_config(),
+ ) -> CircleModel:
+     with torch.no_grad():
+         exported_program = export(mod, args, kwargs, strict=strict)
+
+     circle_binary = convert_exported_module_to_circle(exported_program, config=config)
+
+     return CircleModel(circle_binary)
+
+
+ def convert_from_exported_program(
+     exported_program: ExportedProgram,
+     config: CompileConfigBase = get_default_config(),
+ ) -> CircleModel:
+     circle_binary = convert_exported_module_to_circle(exported_program, config=config)
+
+     return CircleModel(circle_binary)
+
+
+ def convert_from_pt2(
+     pt2_path: str | os.PathLike, config: CompileConfigBase = get_default_config()
+ ) -> CircleModel:
+     exported_program = torch.export.load(pt2_path)
+     circle_binary = convert_exported_module_to_circle(exported_program, config=config)
+
+     return CircleModel(circle_binary)
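
A minimal end-to-end sketch (not part of the package) using the convert() entry point above. CircleModel comes from tico/utils/model.py, which is not shown in this diff, so the save() call at the end is an assumed API for illustration only.

import torch
from tico.utils.convert import convert

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1.0

# torch.export + the legalization passes above + circle serialization,
# as orchestrated by convert_exported_module_to_circle().
circle_model = convert(AddOne().eval(), (torch.randn(2, 3),))
circle_model.save("add_one.circle")  # assumed CircleModel method; not shown in this diff
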
tico/utils/define.py ADDED
@@ -0,0 +1,35 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict, List
+
+ from circle_schema import circle
+
+ from tico.serialize.circle_graph import CircleSubgraph
+ from tico.serialize.operators.hashable_opcode import OpCode
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+
+
+ def define_pad_node(
+     graph: CircleSubgraph, op_codes: Dict[OpCode, int], inputs: List, outputs: List
+ ) -> circle.Operator.OperatorT:
+     def set_pad_option(operator: circle.Operator.OperatorT):
+         operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.PadOptions
+         option = circle.PadOptions.PadOptionsT()
+         operator.builtinOptions = option
+
+     pad_op_index = get_op_index(circle.BuiltinOperator.BuiltinOperator.PAD, op_codes)
+     operator = create_builtin_operator(graph, pad_op_index, inputs, outputs)
+     set_pad_option(operator)
+     return operator
tico/utils/diff_graph.py ADDED
@@ -0,0 +1,181 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from copy import deepcopy
+ from difflib import ndiff
+ from functools import reduce
+ from logging import DEBUG
+ from typing import Optional, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+
+ from tico.utils.logging import getLogger, LOG_LEVEL
+
+
+ def strdiff(a: str, b: str):
+     """
+     Get difference in two strings as if linux `diff` command does
+     """
+     assert isinstance(a, str), f"{a} must be str, type: {type(a)}"
+     assert isinstance(b, str), f"{b} must be str, type: {type(b)}"
+
+     changed = []
+     for line in ndiff(a.splitlines(keepends=True), b.splitlines(keepends=True)):
+         if line.startswith(("-", "+")):
+             changed.append(line)
+     return "".join(changed)
+
+
+ def disable_when(predicate):
+     """
+     Disable function only if predicate is true
+     """
+
+     def _inner_disable_when(func):
+         if predicate:
+
+             def nop(*args, **kwargs):
+                 pass
+
+             return nop
+         else:
+             return func
+
+     return _inner_disable_when
+
+
+ LOGGER_THRESHOLD = DEBUG
+ graph_captured: Optional[str | torch.fx.Graph] = None
+ const_size_captured: Optional[int] = None
+
+
+ def get_const_size(ep: torch.export.ExportedProgram) -> int:
+     """
+     Return const tensor's size in **byte**
+     """
+
+     def const_size(items):
+         const_sum = 0
+         for _, tensor in items:
+             if len(tensor.size()) == 0:
+                 # scalar tensor
+                 const_sum += tensor.dtype.itemsize
+             else:
+                 const_sum += (
+                     reduce(lambda x, y: x * y, list(tensor.size()))
+                     * tensor.dtype.itemsize
+                 )
+         return const_sum
+
+     constant_tensor_sum = 0
+
+     constant_tensor_sum += const_size(ep.state_dict.items())
+     constant_tensor_sum += const_size(ep.constants.items())
+
+     return constant_tensor_sum
+
+
+ @disable_when(LOG_LEVEL > LOGGER_THRESHOLD)
+ def capture_const(ep: torch.export.ExportedProgram):
+     assert isinstance(ep, torch.export.ExportedProgram)
+
+     global const_size_captured
+     const_size_captured = get_const_size(ep)
+
+
+ @disable_when(LOG_LEVEL > LOGGER_THRESHOLD)
+ def log_const(ep: torch.export.ExportedProgram, title: str, recapture: bool):
+     assert isinstance(ep, torch.export.ExportedProgram)
+
+     global const_size_captured
+     assert const_size_captured is not None
+     const_size = get_const_size(ep)
+     const_size_diff = const_size - const_size_captured
+
+     # print differences
+     logger = getLogger(__name__)
+     prefix = f"[{title}]" if title else ""
+     if const_size_diff > 0:
+         const_size_inc_dec = "has changed (increased)"
+     elif const_size_diff == 0:
+         const_size_inc_dec = "has unchanged"
+     else:
+         const_size_inc_dec = "has changed (decreased)"
+
+     percentage_avg_str = ""
+     if const_size + const_size_captured == 0:
+         percentage_avg_str = "N/A"
+     else:
+         percentage_avg = (
+             float(const_size_diff) / float(const_size + const_size_captured) * 100
+         )
+         if percentage_avg > 0:
+             percentage_avg_str = f"+{percentage_avg:.2f}%"
+         else:
+             percentage_avg_str = f"{percentage_avg:.2f}%"
+
+     if const_size_diff:
+         logger.debug(
+             f"{prefix} Total const size {const_size_inc_dec} by {const_size_diff} Bytes"
+         )
+         logger.debug(f"{const_size_captured}B -> {const_size}B ({percentage_avg_str})")
+
+     if recapture:
+         const_size_captured = const_size
+
+
+ @disable_when(LOG_LEVEL > LOGGER_THRESHOLD)
+ def capture(graph: torch.fx.Graph):
+     """
+     Capture the start-point graph for graph-diff.
+     String diff lines will be printed to debug logger if enabled.
+
+     Args:
+         graph (torch.fx.Graph): graph to capture
+     """
+     assert isinstance(graph, torch.fx.Graph)
+     global graph_captured
+     graph_captured = str(graph)
+
+
+ @disable_when(LOG_LEVEL > DEBUG)
+ def log(graph: torch.fx.Graph, title: str, recapture: bool):
+     """
+     Capture the end-point graph for graph-diff.
+     String diff lines will be printed to debug logger if enabled.
+
+     Args:
+         graph (torch.fx.Graph): graph to capture
+         title (str): Title in log
+         recapture (bool): recapture the graph
+     """
+     assert isinstance(graph, torch.fx.Graph)
+     global graph_captured
+
+     logger = getLogger(__name__)
+     diff = strdiff(f"{graph_captured}\n", f"{graph}\n")
+     prefix = f"[{title}]" if title else ""
+     if len(diff) > 0:
+         logger.debug(f"{prefix} Graph is changed.")
+         logger.debug(f"\n{diff}")
+
+     if recapture:
+         graph_captured = deepcopy(graph)
+     else:
+         graph_captured = None  # reset
+
+
+ # TODO diff graph signature
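
A hedged usage sketch (not from the package) of the graph-diff helpers above. The diff is only emitted when tico's log level is at or below DEBUG; otherwise capture()/log() are turned into no-ops by the disable_when decorators.

import torch
from tico.utils.diff_graph import capture, log

def double(x):
    return x * 2

gm = torch.fx.symbolic_trace(double)
capture(gm.graph)          # snapshot the graph before a transformation
# ... mutate gm.graph here, e.g. inside a rewrite pass ...
log(gm.graph, title="MyPass", recapture=False)  # debug-log the textual diff
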
tico/utils/errors.py ADDED
@@ -0,0 +1,35 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ class CircleExirError(Exception):
+     """Base class for custom exceptions in project"""
+
+     pass
+
+
+ class NotYetSupportedError(CircleExirError):
+     """
+     Not yet supported feature or functionality
+     """
+
+     pass
+
+
+ class InvalidArgumentError(CircleExirError):
+     """
+     Invalid argument, which is never allowed
+     """
+
+     pass
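
A small sketch (not part of the package) of how these exceptions surface to callers: check_unsupported_target() in tico/utils/convert.py raises NotYetSupportedError when the exported graph contains an operator without a registered circle serializer.

import torch
from tico.utils.convert import convert
from tico.utils.errors import NotYetSupportedError

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

try:
    convert(TinyModel().eval(), (torch.randn(1, 4),))
except NotYetSupportedError as err:
    # Conversion stops before serialization if any operator is unsupported.
    print(f"Conversion is not supported yet: {err}")
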