tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
tico/utils/convert.py ADDED
@@ -0,0 +1,292 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import operator
16
+ import os
17
+ from typing import Any, Dict, Optional, Tuple
18
+
19
+ import torch
20
+ from torch.export import export, ExportedProgram
21
+
22
+ from tico.config import CompileConfigBase, get_default_config
23
+ from tico.experimental.quantization.passes.fold_quant_ops import FoldQuantOps
24
+ from tico.experimental.quantization.passes.insert_quantize_on_dtype_mismatch import (
25
+ InsertQuantizeOnDtypeMismatch,
26
+ )
27
+ from tico.experimental.quantization.passes.propagate_qparam_backward import (
28
+ PropagateQParamBackward,
29
+ )
30
+ from tico.experimental.quantization.passes.propagate_qparam_forward import (
31
+ PropagateQParamForward,
32
+ )
33
+ from tico.experimental.quantization.passes.remove_weight_dequant_op import (
34
+ RemoveWeightDequantOp,
35
+ )
36
+ from tico.passes.cast_aten_where_arg_type import CastATenWhereArgType
37
+ from tico.passes.cast_mixed_type_args import CastMixedTypeArgs
38
+ from tico.passes.const_prop_pass import ConstPropPass
39
+ from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d
40
+ from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape
41
+ from tico.passes.convert_repeat_to_expand_copy import ConvertRepeatToExpandCopy
42
+ from tico.passes.convert_to_relu6 import ConvertToReLU6
43
+ from tico.passes.decompose_addmm import DecomposeAddmm
44
+ from tico.passes.decompose_batch_norm import DecomposeBatchNorm
45
+ from tico.passes.decompose_fake_quantize import DecomposeFakeQuantize
46
+ from tico.passes.decompose_fake_quantize_tensor_qparams import (
47
+ DecomposeFakeQuantizeTensorQParams,
48
+ )
49
+ from tico.passes.decompose_group_norm import DecomposeGroupNorm
50
+ from tico.passes.decompose_grouped_conv2d import DecomposeGroupedConv2d
51
+ from tico.passes.decompose_slice_scatter import DecomposeSliceScatter
52
+ from tico.passes.extract_dtype_kwargs import ExtractDtypeKwargsPass
53
+ from tico.passes.fill_meta_val import FillMetaVal
54
+ from tico.passes.fuse_redundant_reshape_to_mean import FuseRedundantReshapeToMean
55
+ from tico.passes.legalize_causal_mask_value import LegalizeCausalMaskValue
56
+ from tico.passes.legalize_predefined_layout_operators import (
57
+ LegalizePreDefinedLayoutOperators,
58
+ )
59
+ from tico.passes.lower_pow2_to_mul import LowerPow2ToMul
60
+ from tico.passes.lower_to_resize_nearest_neighbor import LowerToResizeNearestNeighbor
61
+ from tico.passes.lower_to_slice import LowerToSlice
62
+ from tico.passes.merge_consecutive_cat import MergeConsecutiveCat
63
+ from tico.passes.remove_nop import RemoveNop
64
+ from tico.passes.remove_redundant_assert_nodes import RemoveRedundantAssertionNodes
65
+ from tico.passes.remove_redundant_expand import RemoveRedundantExpand
66
+ from tico.passes.remove_redundant_permute import passes as RemoveRedundantPermutePasses
67
+ from tico.passes.remove_redundant_reshape import passes as RemoveRedundantViewPasses
68
+ from tico.passes.remove_redundant_slice import RemoveRedundantSlice
69
+ from tico.passes.remove_redundant_to_copy import RemoveRedundantToCopy
70
+ from tico.passes.restore_linear import RestoreLinear
71
+ from tico.passes.segment_index_select import SegmentIndexSelectConst
72
+ from tico.serialize.circle_serializer import build_circle
73
+ from tico.serialize.operators.node_visitor import get_support_targets
74
+ from tico.utils import logging
75
+ from tico.utils.errors import NotYetSupportedError
76
+ from tico.utils.model import CircleModel
77
+ from tico.utils.passes import PassManager
78
+ from tico.utils.trace_decorators import (
79
+ trace_const_diff_on_func,
80
+ trace_graph_diff_on_func,
81
+ )
82
+ from tico.utils.utils import has_quantization_ops, SuppressWarning
83
+
84
+
85
@trace_const_diff_on_func
@trace_graph_diff_on_func
def traced_run_decompositions(exported_program: ExportedProgram):
    """
    Run ATen decompositions while keeping selected high-level operators intact.

    Left alone, `run_decompositions()` lowers every convolution variant to the
    generic `aten.convolution`, which would then need to be re-specialized into
    circle ops such as CircleConv2D or TransposeConv. Keeping the specific
    Conv-related ops lets them be converted to circle ops directly.
    instance_norm, _safe_softmax, relu6 and linear are kept as well. On the
    PyTorch 2.6/2.7/2.8 path, prelu is also kept.

    Raises:
        RuntimeError: if the installed PyTorch version is not 2.5, 2.6, 2.7 or 2.8.
    """
    torch_version = torch.__version__
    uses_preserve_ops_api = torch_version.startswith("2.5")
    uses_decomp_table_api = torch_version.startswith(("2.6", "2.7", "2.8"))
    if not (uses_preserve_ops_api or uses_decomp_table_api):
        raise RuntimeError(f"Unsupported PyTorch version: {torch.__version__}")

    # Ops kept undecomposed on every supported PyTorch version.
    common_kept_ops = (
        torch.ops.aten.conv2d.default,
        torch.ops.aten.conv2d.padding,
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv1d.padding,
        torch.ops.aten.instance_norm.default,
        torch.ops.aten._safe_softmax.default,
        torch.ops.aten.relu6.default,  # Do not decompose to hardtanh
    )

    if uses_preserve_ops_api:
        # PyTorch 2.5 accepts the ops to keep directly.
        return exported_program.run_decompositions(
            _preserve_ops=(*common_kept_ops, torch.ops.aten.linear.default)
        )

    # PyTorch 2.6+ takes a decomposition table; remove the ops to keep from it.
    decomp_table = torch.export.default_decompositions()  # type: ignore[attr-defined]
    kept_ops = (
        *common_kept_ops,
        torch.ops.aten.prelu.default,
        torch.ops.aten.linear.default,
    )
    for kept_op in kept_ops:
        if kept_op in decomp_table:
            del decomp_table[kept_op]
    return exported_program.run_decompositions(decomp_table=decomp_table)
140
+
141
+
142
def check_unsupported_target(exported_program: ExportedProgram):
    """
    Ensure every `call_function` node in the graph has a circle serializer.

    Each offending node is logged with its operator name and stack trace.

    Raises:
        NotYetSupportedError: if at least one unsupported operator is found.
    """
    logger = logging.getLogger(__name__)

    # A set gives O(1) membership per node instead of scanning a list.
    supported_targets = set(get_support_targets())
    # Ignore `getitem` since it is no-op for multiple outputs.
    supported_targets.add(operator.getitem)

    unsupported = [
        node
        for node in exported_program.graph.nodes
        if node.op == "call_function" and node.target not in supported_targets
    ]
    if not unsupported:
        return

    for node in unsupported:
        # Fall back to str() so reporting cannot fail on a target without __name__.
        target_name = getattr(node.target, "__name__", str(node.target))
        logger.error(
            f"NOT SUPPORTED OPERATOR\n\t(op) {target_name}\n\t(trace) {node.meta.get('stack_trace')}"
        )
    raise NotYetSupportedError("NOT SUPPORTED OPERATOR IN GRAPH MODULE")
161
+
162
+
163
def convert_exported_module_to_circle(
    exported_program: ExportedProgram,
    config: Optional[CompileConfigBase] = None,
) -> bytes:
    """
    Lower an ExportedProgram to a serialized circle model.

    Pipeline, in order:
      1. Decompose fake-quantize ops.
      2. Run the (traced) ATen decompositions.
      3. Run the circle legalization passes.
      4. Run the quantization passes, only if the graph has quantization ops.
      5. Check for unsupported operators.
      6. Serialize with `build_circle`.

    Args:
        exported_program: program to convert. NOTE(review): PassManager.run is
            called on it without using a return value, so it is presumably
            modified in place — confirm before reusing the caller's object.
        config: compile configuration. When None, `get_default_config()` is
            called for this conversion. Previously the default was built once
            at import time and shared by all calls.

    Returns:
        The circle model produced by `build_circle`.

    Raises:
        NotYetSupportedError: if the graph still contains an operator without
            a circle serializer (see `check_unsupported_target`).
    """
    if config is None:
        config = get_default_config()

    logger = logging.getLogger(__name__)
    logger.debug("Input ExportedProgram (must be core aten)")
    logger.debug(exported_program)

    # PRE-EDGE PASSES
    #
    # Here are the passes that run before to_edge() conversion.
    # Let's decompose nodes that are not Aten Canonical, which can't be converted to the edge IR.
    decompose_quantize_op = PassManager(
        passes=[
            DecomposeFakeQuantize(),
            DecomposeFakeQuantizeTensorQParams(),
        ]
    )
    decompose_quantize_op.run(exported_program)

    # The decomposition below must run before 'RestoreLinear' and after 'decompose_quantize_op'.
    # TODO run pass regardless of the orders.
    with SuppressWarning(UserWarning, ".*quantize_per_tensor"), SuppressWarning(
        UserWarning,
        ".*TF32 acceleration on top of oneDNN is available for Intel GPUs.*",
    ):
        # Warning details:
        #   ...site-packages/torch/_subclasses/functional_tensor.py:364
        #   UserWarning: At pre-dispatch tracing, we assume that any custom op marked with
        #     CompositeImplicitAutograd and have functional schema are safe to not decompose.
        exported_program = traced_run_decompositions(exported_program)

    # TODO Distinguish legalize and optimize
    circle_legalize = PassManager(
        passes=[
            FillMetaVal(),
            ExtractDtypeKwargsPass(),
            RemoveNop(),
            ConvertLayoutOpToReshape(),
            RestoreLinear(),
            ConvertToReLU6(),
            DecomposeAddmm(),
            DecomposeSliceScatter(),
            DecomposeGroupNorm(),
            DecomposeBatchNorm(),
            DecomposeGroupedConv2d(),
            CastATenWhereArgType(),
            ConvertRepeatToExpandCopy(),
            *RemoveRedundantPermutePasses(),
            RemoveRedundantAssertionNodes(),
            RemoveRedundantExpand(),
            RemoveRedundantSlice(),
            FuseRedundantReshapeToMean(),
            *RemoveRedundantViewPasses(),
            RemoveRedundantToCopy(),
            MergeConsecutiveCat(),
            CastMixedTypeArgs(preserve_ep_invariant=True),
            ConstPropPass(),
            SegmentIndexSelectConst(),
            LegalizeCausalMaskValue(enabled=config.get("legalize_causal_mask_value")),
            LowerToResizeNearestNeighbor(),
            LegalizePreDefinedLayoutOperators(),
            LowerPow2ToMul(),
            ConvertConv1dToConv2d(),
            LowerToSlice(),
        ]
    )
    circle_legalize.run(exported_program)

    # After this stage, ExportedProgram invariant is broken, i.e.,
    # graph can have a constant torch.tensor not lifted to a placeholder
    circle_post_legalize = PassManager(
        passes=[
            FillMetaVal(),
            CastMixedTypeArgs(preserve_ep_invariant=False),
        ]
    )
    circle_post_legalize.run(exported_program)

    # TODO Give an option to enable quantization to user
    enable_quantization = has_quantization_ops(exported_program.graph)
    if enable_quantization:
        quantize_graph = PassManager(
            passes=[
                FoldQuantOps(),
                RemoveWeightDequantOp(),
                PropagateQParamForward(),
                PropagateQParamBackward(),
                InsertQuantizeOnDtypeMismatch(),
            ]
        )
        quantize_graph.run(exported_program)

    check_unsupported_target(exported_program)
    circle_program = build_circle(exported_program)

    return circle_program
260
+
261
+
262
def convert(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    strict: bool = True,
    config: Optional[CompileConfigBase] = None,
) -> CircleModel:
    """
    Export `mod` with `torch.export` and convert the result to a circle model.

    Args:
        mod: module to export.
        args: positional example inputs passed to `torch.export.export`.
        kwargs: keyword example inputs passed to `torch.export.export`.
        strict: forwarded to `torch.export.export`.
        config: compile configuration. When None, `get_default_config()` is
            called for this conversion instead of being shared from import time.

    Returns:
        A `CircleModel` wrapping the serialized circle binary.
    """
    if config is None:
        config = get_default_config()

    # Export under no_grad so autograd state is not recorded during tracing.
    with torch.no_grad():
        exported_program = export(mod, args, kwargs, strict=strict)

    circle_binary = convert_exported_module_to_circle(exported_program, config=config)

    return CircleModel(circle_binary)
275
+
276
+
277
def convert_from_exported_program(
    exported_program: ExportedProgram,
    config: Optional[CompileConfigBase] = None,
) -> CircleModel:
    """
    Convert an already-exported program to a circle model.

    Args:
        exported_program: program produced by `torch.export`.
        config: compile configuration. When None, `get_default_config()` is
            called for this conversion instead of being shared from import time.

    Returns:
        A `CircleModel` wrapping the serialized circle binary.
    """
    if config is None:
        config = get_default_config()

    circle_binary = convert_exported_module_to_circle(exported_program, config=config)

    return CircleModel(circle_binary)
284
+
285
+
286
def convert_from_pt2(
    pt2_path: str | os.PathLike, config: Optional[CompileConfigBase] = None
) -> CircleModel:
    """
    Load a saved `.pt2` ExportedProgram and convert it to a circle model.

    Args:
        pt2_path: path passed to `torch.export.load`.
        config: compile configuration. When None, `get_default_config()` is
            called for this conversion instead of being shared from import time.

    Returns:
        A `CircleModel` wrapping the serialized circle binary.
    """
    if config is None:
        config = get_default_config()

    exported_program = torch.export.load(pt2_path)
    circle_binary = convert_exported_module_to_circle(exported_program, config=config)

    return CircleModel(circle_binary)
tico/utils/define.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List
16
+
17
+ from circle_schema import circle
18
+
19
+ from tico.serialize.circle_graph import CircleSubgraph
20
+ from tico.serialize.operators.hashable_opcode import OpCode
21
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
22
+
23
+
24
def define_pad_node(
    graph: CircleSubgraph, op_codes: Dict[OpCode, int], inputs: List, outputs: List
) -> circle.Operator.OperatorT:
    """
    Create a circle PAD operator in `graph` and attach default PadOptions.

    Args:
        graph: subgraph the operator is created in.
        op_codes: opcode-to-index table passed to `get_op_index`.
        inputs: operator inputs, forwarded to `create_builtin_operator`.
        outputs: operator outputs, forwarded to `create_builtin_operator`.

    Returns:
        The newly created PAD operator.
    """
    pad_opcode_index = get_op_index(circle.BuiltinOperator.BuiltinOperator.PAD, op_codes)
    pad_operator = create_builtin_operator(graph, pad_opcode_index, inputs, outputs)

    # PAD carries an (empty) PadOptions table as its builtin options.
    pad_operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.PadOptions
    pad_operator.builtinOptions = circle.PadOptions.PadOptionsT()
    return pad_operator
tico/utils/diff_graph.py ADDED
@@ -0,0 +1,181 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from copy import deepcopy
16
+ from difflib import ndiff
17
+ from functools import reduce
18
+ from logging import DEBUG
19
+ from typing import Optional, TYPE_CHECKING
20
+
21
+ if TYPE_CHECKING:
22
+ import torch.fx
23
+ import torch
24
+
25
+ from tico.utils.logging import getLogger, LOG_LEVEL
26
+
27
+
28
def strdiff(a: str, b: str):
    """
    Return only the changed lines between two strings, like the linux `diff` command.

    Lines are compared with `difflib.ndiff`. Only removed ("- ") and added
    ("+ ") lines are kept; unchanged and hint lines are dropped.
    """
    assert isinstance(a, str), f"{a} must be str, type: {type(a)}"
    assert isinstance(b, str), f"{b} must be str, type: {type(b)}"

    old_lines = a.splitlines(keepends=True)
    new_lines = b.splitlines(keepends=True)
    return "".join(
        diff_line
        for diff_line in ndiff(old_lines, new_lines)
        if diff_line.startswith(("-", "+"))
    )
40
+
41
+
42
def disable_when(predicate):
    """
    Build a decorator that turns the wrapped function into a no-op when `predicate` is truthy.

    The predicate is evaluated once, when the decorator is applied. A disabled
    function accepts any arguments and returns None. Otherwise the original
    function is returned unchanged.
    """

    def _inner_disable_when(func):
        def nop(*args, **kwargs):
            return None

        return nop if predicate else func

    return _inner_disable_when
58
+
59
+
60
# The diff-logging helpers below are replaced with no-ops (via `disable_when`)
# when LOG_LEVEL is above this threshold.
LOGGER_THRESHOLD = DEBUG
# Baseline graph snapshot used by `capture`/`log` for graph diffs; None when nothing is captured.
graph_captured: Optional[str | torch.fx.Graph] = None
# Baseline total constant size in bytes, set by `capture_const` and compared in `log_const`.
const_size_captured: Optional[int] = None
63
+
64
+
65
+ def get_const_size(ep: torch.export.ExportedProgram) -> int:
66
+ """
67
+ Return const tensor's size in **byte**
68
+ """
69
+
70
+ def const_size(items):
71
+ const_sum = 0
72
+ for _, tensor in items:
73
+ if len(tensor.size()) == 0:
74
+ # scalar tensor
75
+ const_sum += tensor.dtype.itemsize
76
+ else:
77
+ const_sum += (
78
+ reduce(lambda x, y: x * y, list(tensor.size()))
79
+ * tensor.dtype.itemsize
80
+ )
81
+ return const_sum
82
+
83
+ constant_tensor_sum = 0
84
+
85
+ constant_tensor_sum += const_size(ep.state_dict.items())
86
+ constant_tensor_sum += const_size(ep.constants.items())
87
+
88
+ return constant_tensor_sum
89
+
90
+
91
@disable_when(LOG_LEVEL > LOGGER_THRESHOLD)
def capture_const(ep: torch.export.ExportedProgram):
    """
    Record the current total constant size of `ep` as the baseline for `log_const`.

    This is a no-op when LOG_LEVEL is above LOGGER_THRESHOLD.
    """
    global const_size_captured

    assert isinstance(ep, torch.export.ExportedProgram)
    const_size_captured = get_const_size(ep)
97
+
98
+
99
@disable_when(LOG_LEVEL > LOGGER_THRESHOLD)
def log_const(ep: torch.export.ExportedProgram, title: str, recapture: bool):
    """
    Log how the total constant size of `ep` changed since the captured baseline.

    This is a no-op when LOG_LEVEL is above LOGGER_THRESHOLD. `capture_const`
    must have been called first.

    Args:
        ep: program whose constants are measured.
        title: prefixed to the message as "[title]"; omitted when empty.
        recapture: when True, the current size becomes the new baseline.
    """
    assert isinstance(ep, torch.export.ExportedProgram)

    global const_size_captured
    assert const_size_captured is not None
    const_size = get_const_size(ep)
    const_size_diff = const_size - const_size_captured

    # print differences
    logger = getLogger(__name__)
    prefix = f"[{title}]" if title else ""
    if const_size_diff > 0:
        const_size_inc_dec = "has changed (increased)"
    elif const_size_diff == 0:
        const_size_inc_dec = "has unchanged"
    else:
        const_size_inc_dec = "has changed (decreased)"

    total_size = const_size + const_size_captured
    if total_size == 0:
        percentage_avg_str = "N/A"
    else:
        # Percent difference relative to the average of the before/after sizes.
        # Previously this divided by the sum, which halved the reported value.
        average_size = total_size / 2.0
        percentage_avg = float(const_size_diff) / average_size * 100
        if percentage_avg > 0:
            percentage_avg_str = f"+{percentage_avg:.2f}%"
        else:
            percentage_avg_str = f"{percentage_avg:.2f}%"

    if const_size_diff:
        logger.debug(
            f"{prefix} Total const size {const_size_inc_dec} by {const_size_diff} Bytes"
        )
        logger.debug(f"{const_size_captured}B -> {const_size}B ({percentage_avg_str})")

    if recapture:
        const_size_captured = const_size
138
+
139
+
140
@disable_when(LOG_LEVEL > LOGGER_THRESHOLD)
def capture(graph: torch.fx.Graph):
    """
    Store the text form of `graph` as the starting point for a later graph diff.

    The diff itself is produced by `log` and written to the debug logger.
    This is a no-op when LOG_LEVEL is above LOGGER_THRESHOLD.

    Args:
        graph (torch.fx.Graph): graph to capture
    """
    global graph_captured

    assert isinstance(graph, torch.fx.Graph)
    graph_captured = str(graph)
152
+
153
+
154
@disable_when(LOG_LEVEL > LOGGER_THRESHOLD)
def log(graph: torch.fx.Graph, title: str, recapture: bool):
    """
    Log the line diff between the captured graph and `graph` (the end point).

    The diff goes to the debug logger. This is a no-op when LOG_LEVEL is above
    LOGGER_THRESHOLD, the same predicate used by the other helpers here.

    Args:
        graph (torch.fx.Graph): graph to compare against the captured one
        title (str): Title in log, printed as "[title]" when non-empty
        recapture (bool): when True, `graph` becomes the new baseline;
            otherwise the baseline is reset to None
    """
    assert isinstance(graph, torch.fx.Graph)
    global graph_captured

    # Render once; the same text is used for the diff and for recapturing.
    graph_text = str(graph)

    logger = getLogger(__name__)
    diff = strdiff(f"{graph_captured}\n", f"{graph_text}\n")
    prefix = f"[{title}]" if title else ""
    if len(diff) > 0:
        logger.debug(f"{prefix} Graph is changed.")
        logger.debug(f"\n{diff}")

    # Keep the text form, as `capture` does. The diff only ever uses str(), so
    # a deepcopy of the whole graph (including node meta) is unnecessary.
    graph_captured = graph_text if recapture else None
179
+
180
+
181
+ # TODO diff graph signature
tico/utils/errors.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
class CircleExirError(Exception):
    """Root of the project's custom exception hierarchy."""


class NotYetSupportedError(CircleExirError):
    """
    Raised for a feature or functionality that is not supported yet.
    """


class InvalidArgumentError(CircleExirError):
    """
    Raised for an invalid argument, which is never allowed.
    """