tico 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,240 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import operator
16
+ from typing import Dict
17
+
18
+ import flatbuffers
19
+ import torch
20
+ from circle_schema import circle
21
+ from torch.export.exported_program import (
22
+ ConstantArgument,
23
+ ExportedProgram,
24
+ InputKind,
25
+ TensorArgument,
26
+ )
27
+
28
+ from tico.serialize.circle_mapping import to_circle_dtype
29
+ from tico.serialize.operators import *
30
+ from tico.serialize.circle_graph import CircleModel, CircleSubgraph
31
+ from tico.serialize.operators.hashable_opcode import OpCode
32
+ from tico.serialize.operators.node_visitor import get_node_visitors
33
+ from tico.utils import logging
34
+ from tico.utils.serialize import finalise_tensor_names
35
+
36
+
37
# Ops whose FX node value is a tuple of tensors rather than a single tensor.
# build_circle does not create a tensor for these nodes directly; their
# individual outputs are presumably exported through the `operator.getitem`
# nodes that consume them (those getitem nodes get tensors in the tensor pass
# and are skipped in the operator pass).
multiple_output_ops = [
    torch.ops.aten.split_with_sizes.default,
    torch.ops.aten.max.dim,
]


def _is_none_constant(arg) -> bool:
    """Return True when ``arg`` is a ``ConstantArgument`` whose value is None.

    Such user inputs carry no data, so they get neither a tensor nor an input
    slot in the circle model.
    """
    return isinstance(arg, ConstantArgument) and arg.value is None


# Build circle model from ExportedProgram
# Return raw bytes of circle model
def build_circle(edge_program: ExportedProgram) -> bytes:
    """
    Serialize ``edge_program`` into a circle flatbuffer and return its bytes.

    The program is lowered into a single circle subgraph in four passes:
    (1) create a circle tensor for every FX value, (2) register the user
    inputs, (3) register the user outputs, and (4) lower every
    ``call_function`` node through its registered node visitor.

    Args:
        edge_program: the exported program to serialize.

    Returns:
        The finished flatbuffer (file identifier ``"CIR0"``) as ``bytes``.

    Raises:
        RuntimeError: if a ``call_function`` value is not a dense (strided)
            tensor, or an op has no registered node visitor.
        AssertionError: for ``call_method`` / ``call_module`` nodes, unknown
            FX node kinds, or violated placeholder/attribute invariants.
    """
    logger = logging.getLogger(__name__)

    builder = flatbuffers.Builder()

    # Init Model
    model = CircleModel()

    # Add empty buffer at the front (convention)
    model.add_buffer(circle.Buffer.BufferT())

    # Create an empty subgraph (assume a single subgraph)
    graph = CircleSubgraph(model)

    signature = edge_program.graph_signature

    # Names of user inputs that are `None` constants. Computed once here rather
    # than rebuilt for every placeholder node inside the loop below.
    none_const_input_names = {
        spec.arg.name
        for spec in signature.input_specs
        if spec.kind == InputKind.USER_INPUT and _is_none_constant(spec.arg)
    }

    # Export tensors
    logger.debug("---------------Export tensors--------------")
    buf_name_to_data = dict(edge_program.named_buffers())
    for node in edge_program.graph.nodes:
        if node.op == "call_function":
            # Tuple-valued ops get no tensor of their own (see multiple_output_ops).
            if node.target in multiple_output_ops:
                continue
            node_val = node.meta["val"]
            if node_val.layout != torch.strided:
                raise RuntimeError(
                    f"Only support dense tensors (node layout: {node_val.layout})"
                )
            graph.add_tensor_from_node(node)
            logger.debug(f"call_function: {node.name} tensor exported.")

        # placeholder: function input (including parameters, buffers, constant tensors)
        elif node.op == "placeholder":
            # placeholder invariants
            assert node.args is None or len(node.args) == 0  # Not support default param

            # parameters
            if node.name in signature.inputs_to_parameters:
                param_name = signature.inputs_to_parameters[node.name]
                param_data = edge_program.state_dict[param_name]

                assert isinstance(
                    param_data, torch.Tensor
                ), "Expect parameters to be a tensor"
                param_value = param_data.cpu().detach().numpy()

                graph.add_tensor_from_node(node, param_value)
                logger.debug(f"placeholder(param): {node.name} tensor exported.")
            elif node.name in signature.inputs_to_buffers:
                buffer_name = signature.inputs_to_buffers[node.name]
                assert buffer_name in buf_name_to_data
                buffer_data = buf_name_to_data[buffer_name]
                assert isinstance(
                    buffer_data, torch.Tensor
                ), "Expect buffers to be a tensor"
                buffer_value = buffer_data.cpu().detach().numpy()

                graph.add_tensor_from_node(node, buffer_value)
                logger.debug(f"placeholder(buffer): {node.name} tensor exported.")
            elif node.name in signature.inputs_to_lifted_tensor_constants:
                ctensor_name = signature.inputs_to_lifted_tensor_constants[node.name]
                ctensor_data = edge_program.constants[ctensor_name]

                assert isinstance(
                    ctensor_data, torch.Tensor
                ), "Expect constant tensor to be a tensor"
                ctensor_value = ctensor_data.cpu().detach().numpy()

                graph.add_tensor_from_node(node, ctensor_value)
                logger.debug(
                    f"placeholder(constant tensor): {node.name} tensor exported."
                )
            else:
                # NoneType ConstantArgument is ignored.
                if node.name in none_const_input_names:
                    continue
                graph.add_tensor_from_node(node)
                logger.debug(f"placeholder: {node.name} tensor exported.")

        # get_attr: retrieve parameter
        elif node.op == "get_attr":
            # node.name: Place where fetched attribute is saved
            # node.target: fully-qualified attribute path in the owning module.
            # It may be dotted (e.g. "sub.weight"), which a single getattr()
            # cannot resolve, so walk it one component at a time.
            attr_tensor = node.graph.owning_module
            for atom in node.target.split("."):
                attr_tensor = getattr(attr_tensor, atom)
            assert isinstance(attr_tensor, torch.Tensor)

            graph.add_tensor_from_scratch(
                prefix=node.name,
                shape=list(attr_tensor.shape),
                dtype=to_circle_dtype(attr_tensor.dtype),
                source_node=node,
            )

            logger.debug(f"get_attr: {node.name} tensor exported.")

        # output: function output
        elif node.op == "output":
            # output node itself does not need a buffer
            # argument of output node is assumed to be exported beforehand
            for output in node.args[0]:
                if isinstance(output, torch.fx.Node):
                    assert graph.has_tensor(output.name)
            continue

        # call_method: call method
        elif node.op == "call_method":
            raise AssertionError("Not yet implemented")

        # call_module: call 'forward' of module
        elif node.op == "call_module":
            raise AssertionError("Not yet implemented")

        else:
            # Add more if fx.Node is extended
            raise AssertionError(f"Unknown fx.Node op {node.op}")

    # Register inputs
    logger.debug("---------------Register inputs--------------")
    for in_spec in signature.input_specs:
        if in_spec.kind != InputKind.USER_INPUT:
            continue
        # NoneType ConstantArgument is ignored.
        if _is_none_constant(in_spec.arg):
            continue
        arg_name = in_spec.arg.name
        graph.add_input(arg_name)
        logger.debug(f"Registered input: {arg_name}")

    # Register outputs
    logger.debug("---------------Register outputs--------------")
    for user_output in signature.user_outputs:
        if user_output is None:
            logger.debug("Ignore 'None' output")
            continue

        graph.add_output(user_output)
        logger.debug(f"Registered output: {user_output}")

    # Export operators
    logger.debug("---------------Export operators--------------")
    op_codes: Dict[OpCode, int] = {}
    visitors = get_node_visitors(op_codes, graph)
    for node in edge_program.graph.nodes:
        if node.op != "call_function":
            continue

        opcode = node.target
        # getitem only selects one element of a multi-output op; it emits no operator.
        if opcode == operator.getitem:
            continue
        if opcode not in visitors:
            raise RuntimeError(f"{opcode} is not yet supported")
        circle_op = visitors[opcode].define_node(node)

        # A visitor may return a falsy value to emit nothing for this node.
        if circle_op:
            graph.add_operator(circle_op)
            logger.debug(f"call_function: {node.name} ({opcode}) Op exported.")

    # Register subgraph
    finalise_tensor_names(graph)
    model.subgraphs.append(graph)

    # Encode operator codes, ordered by the index each was assigned during export.
    model.operatorCodes = [
        code for code, _ in sorted(op_codes.items(), key=lambda x: x[1])
    ]

    # Description
    model.description = "circle"

    # Set version
    model.version = 0

    # Finish model
    builder.Finish(model.Pack(builder), "CIR0".encode("utf8"))
    buf = builder.Output()

    return bytes(buf)
@@ -0,0 +1,28 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import glob
16
+ from os.path import basename, dirname, isfile, join
17
+
18
+ from tico.utils.register_custom_op import RegisterOps
19
+
20
+
21
# Register custom ops to torch namespace
RegisterOps()

# Load all modules in the current directory.
# glob() returns paths in a filesystem-dependent order. Sort them so that
# `from tico.serialize.operators import *` imports the operator modules (and
# so runs their @register_node_visitor decorators) in a deterministic order on
# every platform.
modules = sorted(glob.glob(join(dirname(__file__), "*.py")))
__all__ = [
    basename(f)[:-3]
    for f in modules
    # Exact basename match: skip only this package's own __init__.py.
    if isfile(f) and basename(f) != "__init__.py"
]
@@ -0,0 +1,43 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from circle_schema import circle
16
+
17
+
18
class OpCode(circle.OperatorCode.OperatorCodeT):
    """
    Wrapper class for operator code in circle schema

    This implements __eq__ and __hash__ for use with dict(). Two operator
    codes are equal when their ``version`` and ``builtinCode`` match and, for
    CUSTOM operators only, their ``customCode`` also matches. ``__hash__`` is
    derived from exactly those fields, so opcodes that compare equal always
    hash equally, and equality is symmetric.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _identity(code):
        """Return the tuple of fields that decides operator-code equality."""
        is_custom = code.builtinCode == circle.BuiltinOperator.BuiltinOperator.CUSTOM
        # customCode only distinguishes CUSTOM ops. Ignore it for builtins so a
        # stray customCode cannot split otherwise-identical opcodes.
        return (
            code.version,
            code.builtinCode,
            code.customCode if is_custom else None,
        )

    def __eq__(self, other):
        if not isinstance(other, circle.OperatorCode.OperatorCodeT):
            return NotImplemented
        return OpCode._identity(self) == OpCode._identity(other)

    def __hash__(self):
        # Must hash the same fields __eq__ compares: equal objects need equal
        # hashes for dict lookups to work.
        return hash(OpCode._identity(self))
@@ -0,0 +1,80 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, Type, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from circle_schema import circle
21
+
22
+ from tico.serialize.circle_graph import CircleSubgraph
23
+ from tico.serialize.operators.hashable_opcode import OpCode
24
+
25
+
26
class NodeVisitor:
    """
    Base class for lowering one edge-IR FX node into a circle operator.

    Subclasses declare a ``target`` list of op overloads (read by
    ``register_node_visitor``) and override ``define_node``.
    """

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        # Subgraph that operators produced by this visitor belong to.
        self.graph = graph
        # Shared opcode -> opcode-index table for the circle model. It is
        # updated during serialization and is not owned by this visitor.
        self._op_codes = op_codes

    # Define circle model operator
    def define_node(
        self,
        node: torch.fx.node.Node,
    ) -> circle.Operator.OperatorT:
        """Build the circle operator for ``node``; subclasses must override this."""
        message = "NodeVisitor must be extended."
        raise NotImplementedError(message)
43
+
44
+
45
# Registry mapping each supported op overload to the visitor class that lowers
# it. Filled by the @register_node_visitor decorator when visitor classes are
# defined.
_node_visitor_dict: Dict[torch._ops.OpOverload, Type[NodeVisitor]] = {}


# Decorator for each visitor
def register_node_visitor(visitor):
    """
    Class decorator that registers ``visitor`` for each overload in ``visitor.target``.

    A later registration for the same overload replaces the earlier one. The
    class is returned unchanged so it can be used as a decorator.
    """
    _node_visitor_dict.update((op, visitor) for op in visitor.target)
    return visitor
54
+
55
+
56
def get_node_visitor(target: torch._ops.OpOverload) -> Type[NodeVisitor]:
    """
    Return the visitor class registered for ``target`` (for unittest purpose).

    Raises:
        LookupError: if no visitor is registered for ``target``.
    """
    visitor_cls = _node_visitor_dict.get(target)
    if visitor_cls:
        return visitor_cls
    raise LookupError(f"NodeVisitor for {target} is not registered")
66
+
67
+
68
# Get all node visitors
def get_node_visitors(
    op_codes: Dict[OpCode, int], graph: CircleSubgraph
) -> Dict[torch._ops.OpOverload, NodeVisitor]:
    """
    Instantiate one visitor per registered op overload.

    Every instance shares the same ``op_codes`` table and ``graph``. A visitor
    class registered for several overloads is instantiated once per overload.
    """
    return {
        op: visitor_cls(op_codes, graph)
        for op, visitor_cls in _node_visitor_dict.items()
    }
77
+
78
+
79
def get_support_targets():
    """Return a live ``keys()`` view of every op overload with a registered visitor."""
    registry = _node_visitor_dict
    return registry.keys()
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import AbsArgs
28
+
29
+
30
@register_node_visitor
class AbsVisitor(NodeVisitor):
    """Lower ``aten.abs.default`` to a circle ABS operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.abs.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit an ABS operator that reads the op's input and writes ``node``."""
        abs_code = circle.BuiltinOperator.BuiltinOperator.ABS
        op_index = get_op_index(abs_code, self._op_codes)

        parsed = AbsArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        return create_builtin_operator(
            self.graph, op_index, [parsed.input], [node]
        )
@@ -0,0 +1,69 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import AddTensorArgs
28
+
29
+
30
@register_node_visitor
class AddVisitor(NodeVisitor):
    """Lower ``aten.add.Tensor`` / ``aten.add.Scalar`` to a circle ADD operator."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.add.Tensor,
        torch.ops.aten.add.Scalar,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Emit an ADD operator computing ``input + other`` into ``node``.

        Both operands are passed to ``create_builtin_operator`` unchanged.
        The operator uses no fused activation and ``potScaleInt16 = False``.

        NOTE(review): aten.add's ``alpha`` multiplier is not read here; confirm
        that ``AddTensorArgs`` rejects a non-default alpha.
        """
        args = AddTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        inputs = [args.input, args.other]
        outputs = [node]

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.ADD, self._op_codes
        )

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        # Op-specific option
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.AddOptions
        option = circle.AddOptions.AddOptionsT()
        option.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        option.potScaleInt16 = False
        operator.builtinOptions = option

        return operator
@@ -0,0 +1,64 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.operators.hashable_opcode import OpCode
24
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
25
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
26
+ from tico.utils.validate_args_kwargs import AliasCopyArgs
27
+
28
+
29
@register_node_visitor
class AliasCopyVisitor(NodeVisitor):
    """
    Lower ``aten.alias`` / ``aten.alias_copy`` to a circle TRANSPOSE whose
    permutation is the identity ``[0, 1, ..., rank - 1]``.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.alias.default,
        torch.ops.aten.alias_copy.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit an identity TRANSPOSE from the aliased input to ``node``."""
        parsed = AliasCopyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        source = parsed.input

        transpose_code = circle.BuiltinOperator.BuiltinOperator.TRANSPOSE
        op_index = get_op_index(transpose_code, self._op_codes)

        # Identity permutation over every axis of the input leaves data unchanged.
        rank = len(source.meta["val"].shape)
        identity_perm = torch.IntTensor(list(range(rank)))

        transpose_op = create_builtin_operator(
            self.graph, op_index, [source, identity_perm], [node]
        )

        # Op-specific option
        transpose_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.TransposeOptions
        )
        transpose_op.builtinOptions = circle.TransposeOptions.TransposeOptionsT()

        return transpose_op
+ return operator