tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,71 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.errors import NotYetSupportedError
28
+ from tico.utils.validate_args_kwargs import SubTensorArgs
29
+
30
+
31
@register_node_visitor
class SubVisitor(NodeVisitor):
    """Serializer that turns ``aten.sub.Tensor`` nodes into Circle ``SUB`` operators."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.sub.Tensor]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the Circle ``SUB`` operator for ``node``.

        Args:
            node: the fx node whose args/kwargs are parsed with ``SubTensorArgs``.

        Returns:
            The operator with ``SubOptions`` attached (no fused activation,
            ``potScaleInt16`` disabled).

        Raises:
            NotYetSupportedError: if the parsed ``alpha`` is not ``None``.
        """
        # The opcode index is resolved before argument validation.
        sub_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.SUB, self._op_codes
        )

        parsed = SubTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type, call-arg]
        if parsed.alpha is not None:
            raise NotYetSupportedError(
                "'alpha' of aten::sub.Tensor is not supported yet"
            )

        sub_op = create_builtin_operator(
            self.graph, sub_index, [parsed.input, parsed.other], [node]
        )

        # Op-specific option
        sub_options = circle.SubOptions.SubOptionsT()
        sub_options.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        sub_options.potScaleInt16 = False

        sub_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.SubOptions
        sub_op.builtinOptions = sub_options
        return sub_op
@@ -0,0 +1,63 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.circle_mapping import circle_legalize_dtype_to
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.validate_args_kwargs import SumDimIntListArgs
29
+
30
+
31
@register_node_visitor
class SumVisitor(NodeVisitor):
    """Serializer that turns ``aten.sum.dim_IntList`` nodes into Circle ``SUM`` operators."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.sum.dim_IntList]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the Circle ``SUM`` operator for ``node``.

        The reduction dims are passed through ``circle_legalize_dtype_to`` with
        ``dtype=torch.int32`` and supplied as the second operator input;
        ``keepdim`` is stored in ``ReducerOptions.keepDims``.
        """
        parsed = SumDimIntListArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        axes_i32 = circle_legalize_dtype_to(parsed.dim, dtype=torch.int32)

        sum_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.SUM, self._op_codes
        )
        sum_op = create_builtin_operator(
            self.graph, sum_index, [parsed.input, axes_i32], [node]
        )

        reducer_options = circle.ReducerOptions.ReducerOptionsT()
        reducer_options.keepDims = parsed.keepdim

        sum_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ReducerOptions
        )
        sum_op.builtinOptions = reducer_options
        return sum_op
@@ -0,0 +1,54 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import TanhArgs
28
+
29
+
30
@register_node_visitor
class TanhVisitor(NodeVisitor):
    """Serializer that turns ``aten.tanh.default`` nodes into Circle ``TANH`` operators."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.tanh.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the Circle ``TANH`` operator for ``node`` (one input, no options)."""
        parsed = TanhArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        tanh_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.TANH, self._op_codes
        )
        return create_builtin_operator(
            self.graph, tanh_index, [parsed.input], [node]
        )
@@ -0,0 +1,105 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_mapping import (
24
+ extract_circle_dtype,
25
+ extract_torch_dtype,
26
+ to_circle_dtype,
27
+ )
28
+ from tico.serialize.operators.hashable_opcode import OpCode
29
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
30
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
31
+ from tico.utils.errors import NotYetSupportedError
32
+ from tico.utils.validate_args_kwargs import ToCopyArgs
33
+
34
+
35
@register_node_visitor
class ToCopyVisitor(NodeVisitor):
    """Serializer that turns ``aten._to_copy.default`` nodes into Circle ``CAST`` operators."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten._to_copy.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_cast_node(
        self,
        inputs: List[torch.fx.Node],
        outputs: List[torch.fx.Node],
        in_type: int,
        out_type: int,
    ):
        """Create a ``CAST`` operator converting ``in_type`` to ``out_type``.

        Args:
            inputs: nodes used as operator inputs.
            outputs: nodes used as operator outputs.
            in_type: value stored in ``CastOptions.inDataType``.
            out_type: value stored in ``CastOptions.outDataType``.

        Returns:
            The operator with ``CastOptions`` attached.
        """
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.CAST, self._op_codes
        )

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.CastOptions
        option = circle.CastOptions.CastOptionsT()
        option.inDataType = in_type
        option.outDataType = out_type
        operator.builtinOptions = option

        return operator

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the ``CAST`` operator for a ``_to_copy`` node.

        Only the ``dtype``, ``device`` and ``layout`` kwargs are accepted. The
        target dtype is the ``dtype`` kwarg when given, otherwise the dtype
        extracted from ``node`` itself.

        Raises:
            NotYetSupportedError: if ``node`` carries any other kwarg, or if it
                requests a layout different from its input's layout.
        """
        supported_kwargs = ["dtype", "device", "layout"]
        unsupported_node_kargs = [
            key for key in node.kwargs if key not in supported_kwargs
        ]
        if unsupported_node_kargs:
            raise NotYetSupportedError(
                f"Support only {supported_kwargs} kwargs now. Do not support {unsupported_node_kargs}"
            )

        args = ToCopyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type, call-arg]
        src = args.input
        dtype = args.dtype

        input_meta = src.meta["val"]
        # https://pytorch.org/docs/stable/tensor_attributes.html#torch-layout
        # layout is two types: torch.strided(dense Tensors), torch.sparse_coo(sparse COO Tensors)
        #
        # The requested layout lives in *this* node's kwargs and must be compared
        # with the input tensor's ``layout`` attribute. The previous check read
        # the producer's kwargs and compared against the whole meta value, which
        # could raise KeyError while formatting the message.
        node_layout = node.kwargs.get("layout")
        if node_layout is not None and node_layout != input_meta.layout:
            raise NotYetSupportedError(
                f"Only support when node and its input have same layout: (input layout: {input_meta.layout}), (node layout: {node_layout})."
            )

        # device and layout are meaningless for the cast itself; without an
        # explicit dtype, fall back to the dtype recorded for the node.
        target_type = dtype if dtype is not None else extract_torch_dtype(node)
        assert isinstance(target_type, torch.dtype), type(target_type)

        # define cast node
        in_type: int = extract_circle_dtype(src)
        out_type: int = to_circle_dtype(target_type)
        operator = self.define_cast_node([src], [node], in_type, out_type)

        # TODO Support layout conversion

        return operator
@@ -0,0 +1,66 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.circle_mapping import circle_legalize_dtype_to
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.validate_args_kwargs import UnSqueezeArgs
29
+
30
+
31
@register_node_visitor
class UnsqueezeVisitor(NodeVisitor):
    """Serializer that turns unsqueeze nodes into Circle ``EXPAND_DIMS`` operators."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.unsqueeze.default,
        torch.ops.aten.unsqueeze_copy.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the Circle ``EXPAND_DIMS`` operator for ``node``.

        The ``dim`` argument is passed through ``circle_legalize_dtype_to`` with
        ``dtype=torch.int32`` and supplied as the second operator input.
        """
        expand_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.EXPAND_DIMS, self._op_codes
        )

        parsed = UnSqueezeArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        axis_i32 = circle_legalize_dtype_to(parsed.dim, dtype=torch.int32)

        expand_op = create_builtin_operator(
            self.graph, expand_index, [parsed.input, axis_i32], [node]
        )

        # ExpandDimsOptions carries no fields that are set here.
        expand_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ExpandDimsOptions
        )
        expand_op.builtinOptions = circle.ExpandDimsOptions.ExpandDimsOptionsT()
        return expand_op
@@ -0,0 +1,74 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph, is_const
24
+ from tico.serialize.circle_mapping import circle_legalize_dtype_to
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.validate_args_kwargs import ViewArgs
29
+
30
+
31
@register_node_visitor
class ViewVisitor(NodeVisitor):
    """Serializer that turns view nodes into Circle ``RESHAPE`` operators."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.view,
        torch.ops.aten.view.default,
        torch.ops.aten.view_copy.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the Circle ``RESHAPE`` operator for ``node``.

        The target size must satisfy ``is_const`` and must not be a bare int.
        It is passed through ``circle_legalize_dtype_to`` with
        ``dtype=torch.int32`` and used both as the second operator input and
        as ``ReshapeOptions.newShape``.

        Raises:
            Exception: if the size argument is a single ``int``.
        """
        reshape_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.RESHAPE,
            self._op_codes,
        )
        parsed = ViewArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        new_size = parsed.size

        assert is_const(new_size), type(new_size)

        if isinstance(new_size, int):
            raise Exception("scalar size conversion is not supported yet.")

        shape_i32 = circle_legalize_dtype_to(new_size, dtype=torch.int32)

        reshape_op = create_builtin_operator(
            self.graph, reshape_index, [parsed.input, shape_i32], [node]
        )

        # Op-specific option
        reshape_options = circle.ReshapeOptions.ReshapeOptionsT()
        reshape_options.newShape = shape_i32

        reshape_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
        )
        reshape_op.builtinOptions = reshape_options
        return reshape_op
@@ -0,0 +1,82 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.circle_mapping import extract_torch_dtype
25
+
26
+ from tico.serialize.operators.hashable_opcode import OpCode
27
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
28
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
29
+ from tico.utils.validate_args_kwargs import WhereSelfArgs
30
+
31
+
32
@register_node_visitor
class WhereVisitor(NodeVisitor):
    """Serializer that turns ``aten.where.self`` nodes into Circle ``SELECT_V2`` operators."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.where.self]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the Circle ``SELECT_V2`` operator for ``node``.

        Raises:
            RuntimeError: if the two value operands do not share a dtype.
        """
        select_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.SELECT_V2,
            self._op_codes,
        )

        parsed = WhereSelfArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        def dtype_of(operand):
            # fx nodes go through extract_torch_dtype; anything else is
            # expected to expose a ``dtype`` attribute directly.
            if isinstance(operand, torch.fx.node.Node):
                return extract_torch_dtype(operand)
            return operand.dtype  # type: ignore[union-attr]

        result_true_dtype = dtype_of(parsed.input)
        result_false_dtype = dtype_of(parsed.other)

        if result_true_dtype != result_false_dtype:
            raise RuntimeError(
                f"Data type of arguments are not matched. result_true: {result_true_dtype}, result_false: {result_false_dtype}"
            )

        select_op = create_builtin_operator(
            self.graph,
            select_index,
            [parsed.condition, parsed.input, parsed.other],
            [node],
        )

        # Op-specific option
        select_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.SelectV2Options
        )
        select_op.builtinOptions = circle.SelectV2Options.SelectV2OptionsT()
        return select_op
@@ -0,0 +1,51 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List
16
+
17
+ from circle_schema import circle
18
+
19
+ from tico.serialize.operators.hashable_opcode import OpCode
20
+
21
+
22
def create_builtin_opcode(opcode: int) -> OpCode:
    """Create a version-1 ``OpCode`` for the given builtin operator code.

    ``deprecatedBuiltinCode`` is int8, so it is capped at 127 (127 is reserved
    as a placeholder for greater opcodes); the real code is always stored in
    ``builtinCode``.
    """
    deprecated_limit = 127

    code = OpCode()
    code.deprecatedBuiltinCode = opcode if opcode < deprecated_limit else deprecated_limit
    code.builtinCode = opcode
    code.version = 1
    return code
31
+
32
+
33
def get_op_index(opcode: int, opcode_map: Dict[OpCode, int]) -> int:
    """Return the index of ``opcode`` in ``opcode_map``, registering it if new.

    A new opcode receives the next free index, i.e. the current size of the
    map, and ``opcode_map`` is updated in place.
    """
    key = create_builtin_opcode(opcode)
    # len() is evaluated before any insertion, so a new key gets the old size.
    return opcode_map.setdefault(key, len(opcode_map))
41
+
42
+
43
+ # TODO Move this to CircleSubGraph
44
def create_builtin_operator(
    graph, op_index: int, inputs: List, outputs: List
) -> circle.Operator.OperatorT:
    """Create an operator whose inputs/outputs are the tensor ids of the given items.

    Each entry of ``inputs`` and ``outputs`` is translated with
    ``graph.get_tid``; ``op_index`` is stored as ``opcodeIndex``.
    """
    new_op = circle.Operator.OperatorT()
    new_op.opcodeIndex = op_index
    new_op.inputs = list(map(graph.get_tid, inputs))
    new_op.outputs = list(map(graph.get_tid, outputs))
    return new_op
tico/serialize/pack.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+
17
+
18
def pack_buffer(flat_data: np.ndarray, dtype: str) -> np.ndarray:
    """Pack a flat array of sub-byte values into a dense ``uint8`` buffer.

    Only ``dtype == "uint4"`` is implemented. Two consecutive values share one
    byte: the even-indexed value fills the low nibble and the odd-indexed
    value the high nibble. An odd-length input leaves the last high nibble 0.

    Args:
        flat_data: 1-D array; for ``"uint4"`` it must have dtype ``np.uint8``
            with every value in ``[0, 15]``.
        dtype: name of the packed dtype.

    Returns:
        A ``uint8`` array of length ``(len(flat_data) + 1) // 2``.

    Raises:
        AssertionError: if ``flat_data`` is not 1-D or holds a value above 15.
        RuntimeError: if ``"uint4"`` data is not stored as ``np.uint8``.
        NotImplementedError: for any other ``dtype``.
    """
    assert len(flat_data.shape) == 1

    if dtype != "uint4":
        raise NotImplementedError(f"NYI dtype: {dtype}")

    if flat_data.dtype != np.uint8:
        raise RuntimeError("uint4 data should be saved in uint8.")

    # uint8 cannot be negative, so only the upper bound needs checking; one
    # vectorized reduction replaces the former per-element assert.
    assert flat_data.shape[0] == 0 or flat_data.max() <= 15

    low_nibbles = flat_data[0::2]
    high_nibbles = flat_data[1::2]

    # The copy has length (numel + 1) // 2 and already holds the low nibbles.
    packed = low_nibbles.copy()
    packed[: high_nibbles.shape[0]] |= high_nibbles << 4
    return packed
@@ -0,0 +1,42 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ This is a key for torch.fx.Node's meta dict to save QuantParam
17
+
18
+ QuantParam can be retrieved as node.meta[QPARAM_KEY]
19
+ """
20
+ QPARAM_KEY = "_quantization_parameters_"
21
+
22
+ from dataclasses import dataclass
23
+ from typing import List, Optional
24
+
25
+ import torch
26
+
27
+
28
@dataclass
class QuantParam:
    """Quantization parameters stored on a torch.fx.Node as ``node.meta[QPARAM_KEY]``.

    Every field has a default, so an instance can be created empty and filled
    in later.
    """

    # NOTE(review): presumably one entry per tensor or per channel - confirm
    scale: Optional[List[float]] = None
    # NOTE(review): presumably parallel to `scale` - confirm
    zero_point: Optional[List[int]] = None
    # NOTE(review): looks like the axis used for per-channel parameters - confirm
    quantized_dimension: Optional[int] = None
    # NOTE(review): presumably observed value ranges - confirm
    min: Optional[List[float]] = None
    max: Optional[List[float]] = None
    # NOTE We define dtype as a string to easily extend new dtypes (ex: uint4)
    dtype: str = ""
37
+
38
+
39
def to_qparam_dtype(dtype: torch.dtype) -> str:
    """Return the string form of ``dtype`` without its leading ``torch.``.

    For example ``torch.uint8`` becomes ``"uint8"``.
    """
    prefix = "torch."
    qualified_name = str(dtype)
    assert qualified_name.startswith(prefix)
    return qualified_name[len(prefix):]
tico/utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # DO NOT REMOVE THIS FILE