tico 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,66 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_mapping import circle_legalize_dtype_to
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import MeanDimArgs
28
+
29
+
30
@register_node_visitor
class MeanVisitor(NodeVisitor):
    """Serialize ``aten.mean.dim`` into a Circle MEAN operator.

    The reduction axes become MEAN's second input (legalized to int32), and the
    ``keepdim`` flag is carried in ``ReducerOptions``.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.mean.dim]

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        # Resolve the MEAN opcode before parsing arguments (same order as before).
        mean_opcode = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.MEAN, self._op_codes
        )

        parsed = MeanDimArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        # The reduction axes are passed to MEAN as an int32 tensor operand.
        axes_i32 = circle_legalize_dtype_to(parsed.dim, dtype=torch.int32)
        mean_op = create_builtin_operator(
            self.graph, mean_opcode, [parsed.input, axes_i32], [node]
        )

        mean_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ReducerOptions
        )
        reducer_option = circle.ReducerOptions.ReducerOptionsT()
        # keepDims is only written when keepdim is truthy; otherwise the
        # option object's own default is left untouched.
        if parsed.keepdim:
            reducer_option.keepDims = parsed.keepdim
        mean_op.builtinOptions = reducer_option

        return mean_op
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import MinimumArgs
28
+
29
+
30
@register_node_visitor
class MinimumVisitor(NodeVisitor):
    """Serialize ``aten.minimum`` into an option-less Circle MINIMUM operator."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.minimum.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        # Parse/validate the operands first, then look up the opcode.
        parsed = MinimumArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        lhs, rhs = parsed.input, parsed.other

        minimum_opcode = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.MINIMUM, self._op_codes
        )

        return create_builtin_operator(
            self.graph, minimum_opcode, [lhs, rhs], [node]
        )
@@ -0,0 +1,177 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph, is_const
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import MatmulArgs
28
+
29
+
30
@register_node_visitor
class MatmulDefaultVisitor(NodeVisitor):
    """
    Convert matmul to an equivalent BatchMatMul, or to FullyConnected with Transpose.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.mm.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    # NOTE: Matmul is equivalent to Batch MatMul (batch=1)
    def define_bmm_node(self, inputs, outputs) -> circle.Operator.OperatorT:
        """Create a BATCH_MATMUL operator from ``inputs`` to ``outputs``.

        Both adjoint flags and asymmetric input quantization are disabled.
        """

        def set_bmm_option(operator):
            operator.builtinOptionsType = (
                circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
            )
            option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
            option.adjointLhs, option.adjointRhs = False, False
            option.asymmetricQuantizeInputs = False
            operator.builtinOptions = option

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
        )
        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
        set_bmm_option(operator)

        return operator

    def define_transpose_node(self, inputs, outputs) -> circle.Operator.OperatorT:
        """Create a TRANSPOSE operator; ``inputs`` is ``[tensor, permutation]``."""

        def set_transpose_option(operator):
            operator.builtinOptionsType = (
                circle.BuiltinOptions.BuiltinOptions.TransposeOptions
            )
            option = circle.TransposeOptions.TransposeOptionsT()
            operator.builtinOptions = option

        transpose_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.TRANSPOSE, self._op_codes
        )
        operator = create_builtin_operator(
            self.graph, transpose_op_index, inputs, outputs
        )
        set_transpose_option(operator)
        return operator

    def define_fc_node(self, inputs, outputs) -> circle.Operator.OperatorT:
        """Create a FULLY_CONNECTED operator; ``inputs`` is ``[input, weight, bias]``.

        Options: no fused activation, DEFAULT weights format, ``keepNumDims=False``,
        no asymmetric input quantization, FLOAT32 quantized-bias type.
        """

        def set_fc_option(operator):
            operator.builtinOptionsType = (
                circle.BuiltinOptions.BuiltinOptions.FullyConnectedOptions
            )
            option = circle.FullyConnectedOptions.FullyConnectedOptionsT()

            option.fusedActivationFunction = (
                circle.ActivationFunctionType.ActivationFunctionType.NONE
            )
            option.weightsFormat = (
                circle.FullyConnectedOptionsWeightsFormat.FullyConnectedOptionsWeightsFormat.DEFAULT
            )
            option.keepNumDims = False
            option.asymmetricQuantizeInputs = False
            option.quantizedBiasType = circle.TensorType.TensorType.FLOAT32

            operator.builtinOptions = option

        fc_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED, self._op_codes
        )
        operator = create_builtin_operator(self.graph, fc_op_index, inputs, outputs)
        set_fc_option(operator)
        return operator

    def define_fc_with_transpose(
        self, node, inputs, outputs
    ) -> circle.Operator.OperatorT:
        """
        Define FullyConnected with Transpose operator.

        Note that these sets of operators are equivalent.
        (1) Matmul
            matmul( lhs[H, K], rhs[K, W'] ) -> output(H, W')

        (2) Transpose + FullyConnected
            transpose( rhs[K, W'] ) -> trs_output[W', K]
            fullyconnected( lhs[H, K], trs_output[W', K] ) -> output(H, W')

        The TRANSPOSE operator is added to the graph here; the FULLY_CONNECTED
        operator is returned (not added here).
        """
        lhs, rhs = inputs

        # get transpose shape (rhs must be rank-2)
        rhs_tid: int = self.graph.get_tid_registered(rhs)
        rhs_tensor: circle.Tensor.TensorT = self.graph.tensors[rhs_tid]
        rhs_name: str = rhs.name
        rhs_type: int = rhs_tensor.type
        rhs_shape: List[int] = rhs_tensor.shape
        assert len(rhs_shape) == 2, len(rhs_shape)
        rhs_shape_transpose = [rhs_shape[1], rhs_shape[0]]

        # create transpose output tensor
        trs_output = self.graph.add_tensor_from_scratch(
            prefix=f"{rhs_name}_transposed_output",
            shape=rhs_shape_transpose,
            dtype=rhs_type,
            source_node=node,
        )
        trs_perm = self.graph.add_const_tensor(data=[1, 0], source_node=node)
        trs_operator = self.define_transpose_node([rhs, trs_perm], [trs_output])
        self.graph.add_operator(trs_operator)

        # define fc node
        fc_input = lhs
        fc_weight = trs_output
        # Weight is [W', K], so the bias has W' entries; an all-zero bias
        # keeps the FC result equal to the plain matmul.
        fc_shape = [fc_weight.shape[0]]
        fc_bias = self.graph.add_const_tensor(
            data=[0.0] * fc_shape[0], source_node=node
        )

        operator = self.define_fc_node([fc_input, fc_weight, fc_bias], outputs)

        return operator

    def define_node(
        self, node: torch.fx.Node, prior_latency=True
    ) -> circle.Operator.OperatorT:
        """
        NOTE: Possibility of accuracy-latency trade-off
        From ONE compiler's perspective:
        - BMM uses per-tensor quantization for both rhs and lhs.
        - FC uses per-channel quantization for weight and per-tensor for input.
        Thus, FC is better in terms of accuracy.
        FC necessarily involves an additional transpose operation to be identical with mm.
        If the transposed operand is const, it can be optimized by constant folding.

        Selection rule: BMM is emitted when ``other`` is not constant and
        ``prior_latency`` is True; otherwise FC with Transpose is emitted.
        TODO set prior_latency outside
        """
        args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        other = args.other

        inputs = [input, other]
        outputs = [node]

        if not is_const(other) and prior_latency:
            operator = self.define_bmm_node(inputs, outputs)
        else:
            operator = self.define_fc_with_transpose(node, inputs, outputs)

        return operator
@@ -0,0 +1,99 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import MulScalarArgs, MulTensorArgs
28
+
29
+
30
class BaseMulVisitor(NodeVisitor):
    """Shared serializer for ``aten.mul`` overloads into a Circle MUL operator."""

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.node.Node,
        inputs=None,
    ) -> circle.Operator.OperatorT:
        """Build a MUL operator (no fused activation) producing ``node``.

        Args:
            node: The fx node being serialized.
            inputs: Optional pre-parsed operand list ``[input, other]``. When
                omitted, the raw positional ``node.args`` are used (previous
                behaviour), which ignores any operands passed as kwargs.
        """
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.MUL, self._op_codes
        )

        if inputs is None:
            inputs = list(node.args)
        outputs = [node]

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        # Op-specific option
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.MulOptions
        option = circle.MulOptions.MulOptionsT()
        option.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        operator.builtinOptions = option

        return operator


@register_node_visitor
class MulTensorVisitor(BaseMulVisitor):
    """Serialize ``aten.mul.Tensor`` into a Circle MUL operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.mul.Tensor]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        # Use the parsed operands so kwargs-passed arguments are honoured.
        args = MulTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        return super().define_node(node, inputs=[args.input, args.other])


@register_node_visitor
class MulScalarVisitor(BaseMulVisitor):
    """Serialize ``aten.mul.Scalar`` into a Circle MUL operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.mul.Scalar]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        # Use the parsed operands so kwargs-passed arguments are honoured.
        args = MulScalarArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        return super().define_node(node, inputs=[args.input, args.other])
@@ -0,0 +1,54 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import NeTensorArgs
28
+
29
+
30
@register_node_visitor
class NeVisitor(NodeVisitor):
    """Serialize ``aten.ne`` (Scalar and Tensor overloads) into Circle NOT_EQUAL."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.ne.Scalar,
        torch.ops.aten.ne.Tensor,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        not_equal_opcode = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.NOT_EQUAL, self._op_codes
        )

        # NeTensorArgs is used to parse both overloads listed in `target`.
        parsed = NeTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        operands = [parsed.input, parsed.other]

        return create_builtin_operator(
            self.graph, not_equal_opcode, operands, [node]
        )
@@ -0,0 +1,59 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import NegArgs
28
+
29
+
30
@register_node_visitor
class NegVisitor(NodeVisitor):
    """Serialize ``aten.neg`` into a Circle NEG operator with empty NegOptions."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.neg.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        neg_opcode = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.NEG,
            self._op_codes,
        )
        operand = NegArgs(*node.args, **node.kwargs).input  # type: ignore[arg-type]

        neg_op = create_builtin_operator(self.graph, neg_opcode, [operand], [node])

        # NEG carries an (empty) NegOptions table.
        neg_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.NegOptions
        neg_op.builtinOptions = circle.NegOptions.NegOptionsT()

        return neg_op
@@ -0,0 +1,65 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.circle_mapping import circle_legalize_dtype_to
25
+ from tico.serialize.operators.hashable_opcode import OpCode
26
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
27
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
28
+ from tico.utils.validate_args_kwargs import PermuteArgs
29
+
30
+
31
@register_node_visitor
class PermuteVisitor(NodeVisitor):
    """Serialize ``aten.permute`` into a Circle TRANSPOSE operator.

    The permutation is passed as TRANSPOSE's second input, legalized to int32.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.permute.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        transpose_opcode = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.TRANSPOSE,
            self._op_codes,
        )

        parsed = PermuteArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        perm_i32 = circle_legalize_dtype_to(parsed.dims, dtype=torch.int32)

        transpose_op = create_builtin_operator(
            self.graph, transpose_opcode, [parsed.input, perm_i32], [node]
        )

        # TRANSPOSE carries an (empty) TransposeOptions table.
        transpose_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.TransposeOptions
        )
        transpose_op.builtinOptions = circle.TransposeOptions.TransposeOptionsT()

        return transpose_op
+ return operator