tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,264 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from collections import defaultdict
16
+ from typing import Any, cast, Dict, final, List, Optional, TYPE_CHECKING, Union
17
+
18
+ if TYPE_CHECKING:
19
+ import torch.fx
20
+ import numpy as np
21
+ import torch
22
+ from circle_schema import circle
23
+ from torch._subclasses.fake_tensor import FakeTensor
24
+
25
+ from tico.serialize.circle_mapping import (
26
+ extract_circle_dtype,
27
+ extract_shape,
28
+ str_to_circle_dtype,
29
+ to_circle_dtype,
30
+ )
31
+ from tico.serialize.pack import pack_buffer
32
+ from tico.serialize.quant_param import QPARAM_KEY
33
+ from tico.utils.utils import to_circle_qparam
34
+
35
# Type aliases for constant data.
#
# `_PRIMITIVE_TYPES` is the runtime tuple used by `is_const`; `ConstDataElement`
# and `ConstData` are the matching annotations for a single constant and for a
# constant that may also be a list of such elements.
_PRIMITIVE_TYPES = (
    float,
    int,
    bool,
    str,
    torch.Tensor,
    torch.device,
    torch.dtype,
    torch.layout,
)
ConstDataElement = Union[
    int, float, bool, str, torch.Tensor, torch.device, torch.dtype, torch.layout
]
ConstData = Union[ConstDataElement, List[ConstDataElement]]


def is_const(arg) -> bool:
    """Tell whether `arg` is a constant value or a container holding only constants.

    Returns True for instances of `_PRIMITIVE_TYPES`, and for tuples, lists and
    dicts (values only) whose contents are all constants, recursively.
    Returns False for `FakeTensor` instances and for anything else.
    """
    # FakeTensor is tested before the primitive tuple so that it is rejected
    # even though `torch.Tensor` is listed in `_PRIMITIVE_TYPES`.
    if isinstance(arg, FakeTensor):
        return False
    if isinstance(arg, _PRIMITIVE_TYPES):
        return True
    if isinstance(arg, dict):
        # Only the values are inspected; keys are not checked.
        return all(is_const(value) for value in arg.values())
    if isinstance(arg, (tuple, list)):
        return all(is_const(element) for element in arg)
    return False
64
+
65
+
66
@final
class CircleModel(circle.Model.ModelT):
    """Circle model that accumulates subgraphs and buffers in plain lists."""

    def __init__(self):
        super().__init__()
        self.subgraphs: List[circle.SubGraph.SubGraphT] = []
        self.buffers: List[circle.Buffer.BufferT] = []

    def add_subgraph(self, graph: circle.SubGraph.SubGraphT) -> None:
        """Append `graph` to the end of `self.subgraphs`."""
        self.subgraphs.append(graph)

    def add_buffer(self, buffer: circle.Buffer.BufferT) -> int:
        """Append `buffer` to `self.buffers` and return its index (the buffer id)."""
        # The new buffer lands at the current end of the list, so its id is
        # the list length measured before the append.
        new_buffer_id = len(self.buffers)
        self.buffers.append(buffer)
        return new_buffer_id
81
+
82
+
83
@final
class CircleSubgraph(circle.SubGraph.SubGraphT):
    """Circle subgraph that registers tensors by name and keeps their buffers in a model.

    Each `add_*tensor*` method creates a tensor, appends a buffer for it to the
    owning `CircleModel`, and records the tensor's index in `name_to_tid`.
    """

    def __init__(self, model: CircleModel):
        super().__init__()
        self.model: CircleModel = model
        self.name: str = "subgraph"
        self.inputs: List[int] = []
        self.outputs: List[int] = []
        self.tensors: List[circle.Tensor.TensorT] = []
        self.operators: List[circle.Operator.OperatorT] = []
        # Tensor name -> index of that tensor in `self.tensors`.
        self.name_to_tid: Dict[str, int] = {}
        # Per-prefix postfix index that `_gen_unique_name_with_prefix` tries next.
        self.counter: defaultdict = defaultdict(int)

    # Generate a unique name with prefix.
    # Naming rule
    # - If no tensor has the same name with prefix, return prefix
    # - Otherwise, add postfix f"_{idx}" where idx increases by 1 from 0
    # Example
    # If prefix = "add", this function will find a unique name in the following order.
    # "add", "add_0", "add_1", ...
    def _gen_unique_name_with_prefix(self, prefix: str):
        """Return `prefix`, or `prefix_<idx>`, whichever is not yet a registered tensor name."""
        name = prefix
        while self.has_tensor(name):
            index = self.counter[prefix]
            name = f"{prefix}_{index}"
            self.counter[prefix] += 1

        return name

    def _add_tensor(self, tensor: circle.Tensor.TensorT) -> None:
        """Append `tensor` and record its index under `tensor.name`."""
        self.tensors.append(tensor)
        self.name_to_tid[tensor.name] = len(self.tensors) - 1

    def add_operator(self, op: circle.Operator.OperatorT) -> None:
        """Append `op` to the subgraph's operator list."""
        self.operators.append(op)

    def add_input(self, input_name: str) -> None:
        """Mark the already-registered tensor named `input_name` as a subgraph input."""
        assert input_name in self.name_to_tid, f"{input_name}"
        tid = self.name_to_tid[input_name]
        self.inputs.append(tid)

    def add_output(self, output: Any) -> None:
        """Mark `output` as a subgraph output.

        A `str` must name an already-registered tensor. An `int` or `float`
        is first stored as a new constant tensor.

        Raises:
            NotImplementedError: if `output` is of any other type.
        """
        if isinstance(output, str):
            assert output in self.name_to_tid
            output_name = output
        # A tuple of types is used instead of `int | float` so the check also
        # works on interpreters where `isinstance` rejects union objects.
        elif isinstance(output, (int, float)):
            # output is built-in type.
            circle_tensor = self.add_const_tensor(output)
            output_name = circle_tensor.name
        else:
            raise NotImplementedError(f"Unsupported output dtype: {type(output)}")
        tid = self.name_to_tid[output_name]
        self.outputs.append(tid)

    def has_tensor(self, name: str):
        """Return True if a tensor named `name` has been registered."""
        return name in self.name_to_tid

    def add_tensor_from_node(
        self, node: torch.fx.node.Node, data: Optional[np.ndarray] = None
    ) -> None:
        """Register a tensor whose dtype and shape come from `node`.

        The tensor name is derived from `node.name`. When `node.meta` carries
        `QPARAM_KEY`, the tensor's quantization and type are taken from it.
        If `data` is a numpy array it is flattened, packed when the qparam
        dtype is "uint4", and stored as the tensor's buffer; otherwise `data`
        must be None and an empty buffer is attached.
        """
        tensor = circle.Tensor.TensorT()
        tensor.name = self._gen_unique_name_with_prefix(node.name)
        assert node.meta.get("val") is not None
        tensor.type = extract_circle_dtype(node)
        tensor.shape = list(extract_shape(node))
        if QPARAM_KEY in node.meta:
            # The quantized dtype overrides the dtype extracted from the node.
            tensor.quantization = to_circle_qparam(node.meta[QPARAM_KEY])
            tensor.type = str_to_circle_dtype(node.meta[QPARAM_KEY].dtype)

        buffer = circle.Buffer.BufferT()
        if data is not None and isinstance(data, np.ndarray):
            data = data.flatten()

            if QPARAM_KEY in node.meta:
                if node.meta[QPARAM_KEY].dtype == "uint4":
                    data = pack_buffer(data, "uint4")

            # Packing np.ndarray is faster than packing bytes
            buffer.data = data.view(np.uint8)  # type: ignore[assignment]
        else:
            assert data is None
        bid = self.model.add_buffer(buffer)
        tensor.buffer = bid
        self._add_tensor(tensor)

    def add_const_tensor(self, data: ConstData) -> circle.Tensor.TensorT:
        """Register a new tensor named `const_tensor[_<idx>]` holding `data` and return it."""
        assert is_const(data)
        tensor = circle.Tensor.TensorT()
        tensor.name = self._gen_unique_name_with_prefix("const_tensor")
        assert not self.has_tensor(tensor.name)
        torch_t = torch.as_tensor(data=data)
        torch_t_shape = list(torch_t.size())
        tensor.type = to_circle_dtype(torch_dtype=torch_t.dtype)
        tensor.shape = torch_t_shape

        buffer = circle.Buffer.BufferT()
        buffer.data = torch_t.flatten().cpu().numpy().view(np.uint8)  # type: ignore[assignment]
        bid = self.model.add_buffer(buffer)
        tensor.buffer = bid
        self._add_tensor(tensor)

        return tensor

    def add_tensor_from_scratch(
        self, prefix: str, shape: List[int], dtype: int
    ) -> circle.Tensor.TensorT:
        """Register a tensor with the given circle `dtype` and `shape` and an empty buffer.

        `dtype` must already be a circle dtype integer (see `to_circle_dtype`).
        """
        assert isinstance(dtype, int), f"{dtype} must be integer. Use to_circle_dtype."
        tensor = circle.Tensor.TensorT()
        tensor.name = self._gen_unique_name_with_prefix(prefix)
        tensor.type = dtype
        tensor.shape = shape

        buffer = circle.Buffer.BufferT()
        bid = self.model.add_buffer(buffer)
        tensor.buffer = bid
        self._add_tensor(tensor)

        return tensor

    # Some operators like `full`, `arange_start_step` or `scalar_tensor` needs buffers to be in-place updated.
    # TODO remove this function
    def update_tensor_buffer(
        self, data: ConstData, tensor_name: str = str()
    ) -> circle.Tensor.TensorT:
        """Point the registered tensor `tensor_name` at a new buffer holding `data`.

        `data` must match the tensor's existing circle dtype and shape. A new
        buffer is appended to the model; the previously referenced buffer is
        left in place.
        """
        assert is_const(data)
        assert self.has_tensor(tensor_name)
        data_tensor = torch.as_tensor(data=data)
        data_shape = list(data_tensor.size())
        op_tensor = self.tensors[self.name_to_tid[tensor_name]]
        assert op_tensor.type == to_circle_dtype(
            data_tensor.dtype
        ), f"{op_tensor.type}, {data_tensor.dtype}"
        assert op_tensor.shape == data_shape

        buffer = circle.Buffer.BufferT()
        # Packing np.ndarray is faster than packing bytes
        buffer.data = data_tensor.flatten().cpu().numpy().view(np.uint8)  # type: ignore[assignment]
        bid = self.model.add_buffer(buffer)
        op_tensor.buffer = bid

        return op_tensor

    def get_tid_registered(
        self, node: Union[torch.fx.node.Node, circle.Tensor.TensorT]
    ) -> int:
        """Return the tensor id registered under `node.name`.

        Raises:
            KeyError: if no tensor is registered under that name.
        """
        assert hasattr(node, "name"), "FIX CALLER UNLESS"

        tid = self.name_to_tid.get(node.name, None)

        if tid is None:
            raise KeyError(f"{node}({node.name}) is not registered.")

        assert tid < len(self.tensors)

        return tid

    def get_tensor(self, node: torch.fx.node.Node) -> circle.Tensor.TensorT:
        """Return the tensor registered for `node`."""
        tid = self.get_tid_registered(node)

        return self.tensors[tid]

    def get_buffer(self, node: torch.fx.Node) -> circle.Buffer.BufferT:
        """Return the model buffer referenced by the tensor registered for `node`."""
        buf_id = self.get_tensor(node).buffer
        return self.model.buffers[buf_id]

    # TODO Rename, it doesn't only get_tid but also possibly add a new const tensor
    def get_tid(
        self, node: Union[torch.fx.node.Node, circle.Tensor.TensorT, ConstData]
    ) -> int:
        """Return a tensor id for `node`, creating a constant tensor if needed.

        Returns -1 for None, the registered id when `node.name` is a known
        tensor name, and otherwise the id of a constant tensor newly added for
        a constant `node` (one new tensor per call).

        Raises:
            RuntimeError: if `node` is neither registered nor a constant.
        """
        # return -1 if node is None. This is for generating CircleOutputExclude
        # Identity test: `== None` would dispatch to the argument's own __eq__.
        if node is None:
            return -1

        if hasattr(node, "name") and node.name in self.name_to_tid:
            return self.name_to_tid[node.name]

        if is_const(node):
            node_name = self.add_const_tensor(cast(ConstData, node)).name
            return self.name_to_tid[node_name]

        # Unreachable
        raise RuntimeError("fx Node was not converted to tensor.")
@@ -0,0 +1,177 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Tuple, TYPE_CHECKING, Union
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import numpy as np
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+
24
# Convert torch dtype to circle dtype
def to_circle_dtype(
    torch_dtype: torch.dtype,
) -> int:
    """Look up the circle `TensorType` value that corresponds to `torch_dtype`.

    Raises:
        RuntimeError: if `torch_dtype` has no entry in the lookup table.
    """
    assert isinstance(torch_dtype, torch.dtype)
    circle_types = circle.TensorType.TensorType
    torch_to_circle = {
        torch.float32: circle_types.FLOAT32,
        torch.float: circle_types.FLOAT32,
        torch.uint8: circle_types.UINT8,
        torch.int8: circle_types.INT8,
        torch.int16: circle_types.INT16,
        torch.short: circle_types.INT16,
        torch.int32: circle_types.INT32,
        torch.int: circle_types.INT32,
        torch.int64: circle_types.INT64,
        torch.bool: circle_types.BOOL,
    }

    if torch_dtype in torch_to_circle:
        circle_type = torch_to_circle[torch_dtype]
        assert circle_type is not None
        return circle_type

    raise RuntimeError(f"Unsupported dtype {torch_dtype}")
48
+
49
+
50
# Convert str dtype used in QuantParam to circle dtype
def str_to_circle_dtype(
    str_dtype: str,
) -> int:
    """Look up the circle `TensorType` value for a dtype given by name.

    Raises:
        RuntimeError: if `str_dtype` has no entry in the lookup table.
    """
    circle_types = circle.TensorType.TensorType
    name_to_circle = {
        "float32": circle_types.FLOAT32,
        "float": circle_types.FLOAT32,
        "uint8": circle_types.UINT8,
        "int8": circle_types.INT8,
        "int16": circle_types.INT16,
        "short": circle_types.INT16,
        "int32": circle_types.INT32,
        "int": circle_types.INT32,
        "int64": circle_types.INT64,
        "bool": circle_types.BOOL,
        "uint4": circle_types.UINT4,
        # TODO Add more dtypes
    }

    if str_dtype in name_to_circle:
        circle_type = name_to_circle[str_dtype]
        assert circle_type is not None
        return circle_type

    raise RuntimeError(f"Unsupported dtype {str_dtype}")
75
+
76
+
77
# Convert circle dtype to numpy dtype
def np_dtype_from_circle_dtype(circle_dtype: int):
    """Look up the numpy scalar type that corresponds to a circle `TensorType` value.

    Raises:
        RuntimeError: if `circle_dtype` has no entry in the lookup table.
    """
    circle_types = circle.TensorType.TensorType
    circle_to_np = {
        circle_types.FLOAT32: np.float32,
        circle_types.UINT8: np.uint8,
        circle_types.INT8: np.int8,
        circle_types.INT16: np.int16,
        circle_types.INT32: np.int32,
        circle_types.INT64: np.int64,
        circle_types.BOOL: np.bool_,
    }

    if circle_dtype in circle_to_np:
        np_dtype = circle_to_np[circle_dtype]
        assert np_dtype is not None
        return np_dtype

    raise RuntimeError(f"Unsupported dtype {circle_dtype}")
95
+
96
+
97
# Return dtype of node
def extract_torch_dtype(node: torch.fx.Node) -> torch.dtype:
    """Return the torch dtype of the value stored in `node.meta["val"]`.

    A tensor value reports its own dtype; any other value is passed through
    `torch.tensor` and the dtype of the resulting tensor is returned.
    """
    assert node.meta is not None
    val = node.meta.get("val")
    assert val is not None

    if not isinstance(val, torch.Tensor):
        return torch.tensor(val).dtype

    assert isinstance(val.dtype, torch.dtype)
    return val.dtype
110
+
111
+
112
def extract_circle_dtype(node: torch.fx.Node) -> int:
    """Return the circle dtype for `node`: its torch dtype mapped by `to_circle_dtype`."""
    node_torch_dtype = extract_torch_dtype(node)
    return to_circle_dtype(node_torch_dtype)
114
+
115
+
116
# Return shape of node
def extract_shape(node: torch.fx.Node) -> torch.Size:
    """Return the shape of the value stored in `node.meta["val"]`.

    A tensor value reports its own size; any other value is passed through
    `torch.tensor` and the shape of the resulting tensor is returned.
    """
    assert node.meta is not None
    val = node.meta.get("val")
    assert val is not None

    if not isinstance(val, torch.Tensor):
        return torch.tensor(val).shape

    return val.size()
129
+
130
+
131
# Return stride of node
def extract_stride(node: torch.fx.Node) -> Tuple[int, ...]:
    """Return the stride of the tensor stored in `node.meta["val"]`.

    Unlike the dtype and shape extractors, the value must be a tensor; this
    is enforced with an assertion.
    """
    assert node.meta is not None
    val = node.meta.get("val")
    assert val is not None
    assert isinstance(val, torch.Tensor)

    return val.stride()
142
+
143
+
144
def traverse_elements(iter, container_types=(list, tuple)):
    """Yield the leaf elements of a nested container in depth-first order.

    A value that is not an instance of `container_types` is yielded as-is;
    otherwise each of its elements is traversed recursively with the same
    `container_types`.
    """
    if not isinstance(iter, container_types):
        yield iter
        return

    for element in iter:
        yield from traverse_elements(element, container_types)
151
+
152
+
153
def check_if_i32_range(axis: Union[list, int]):
    """Return True if every leaf value of `axis` lies within [-(2**31), 2**31 - 1].

    `axis` may be a single value or a nested list/tuple of values; it is
    flattened with `traverse_elements` before the check.
    """
    int32_min, int32_max = -(2**31), 2**31 - 1
    for value in traverse_elements(axis):
        if not (int32_min <= value <= int32_max):
            return False
    return True
158
+
159
+
160
def circle_legalize_dtype_to(values, *, dtype: torch.dtype):
    """
    Convert constant integer values to a `torch.int32` tensor.

    PyTorch treats Python's built-in integers as `torch.int64`, but many parts
    of the circle infrastructure (e.g. circle-interpreter) support only int32.

    When every value lies within [INT32_MIN, INT32_MAX], the values are
    returned as a tensor of the requested dtype. Only `torch.int32` is
    accepted as `dtype`; any other dtype, or any out-of-range value, raises
    `RuntimeError`.

    TODO support more types

    NOTE. This function must be applied only to constant values.
    """
    if dtype != torch.int32:
        raise RuntimeError("Not supported data types.")
    if check_if_i32_range(values):
        return torch.as_tensor(values, dtype=dtype)
    raise RuntimeError("'size' cannot be converted from int64 to int32.")
@@ -0,0 +1,232 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import operator
16
+ from typing import Dict
17
+
18
+ import flatbuffers
19
+ import torch
20
+ from circle_schema import circle
21
+ from torch.export.exported_program import (
22
+ ConstantArgument,
23
+ ExportedProgram,
24
+ InputKind,
25
+ TensorArgument,
26
+ )
27
+
28
+ from tico.serialize.circle_mapping import to_circle_dtype
29
+ from tico.serialize.operators import *
30
+ from tico.serialize.circle_graph import CircleModel, CircleSubgraph
31
+ from tico.serialize.operators.hashable_opcode import OpCode
32
+ from tico.serialize.operators.node_visitor import get_node_visitors
33
+ from tico.utils import logging
34
+
35
+
36
# Ops whose call_function node gets no tensor of its own during tensor export.
multiple_output_ops = [
    torch.ops.aten.split_with_sizes.default,
]


# Build circle model from ExportedProgram
# Return raw bytes of circle model
def build_circle(edge_program: ExportedProgram) -> bytes:
    """Serialize `edge_program` into a circle model and return its raw bytes.

    The graph is walked twice: first to register one tensor per node (with
    data for parameters, buffers and lifted constant tensors), then to export
    an operator for every `call_function` node via the registered node
    visitors. Inputs and outputs are registered from the graph signature in
    between.

    Raises:
        RuntimeError: if a call_function value is not a strided (dense)
            tensor, or if an op has no registered visitor.
        AssertionError: for `call_method`, `call_module` or unknown node ops.
    """
    logger = logging.getLogger(__name__)

    builder = flatbuffers.Builder()

    # Init Model
    model = CircleModel()

    # Add empty buffer at the front (convention)
    model.add_buffer(circle.Buffer.BufferT())

    # Create an empty subgraph (assume a single subgraph)
    graph = CircleSubgraph(model)

    # Signature lookups do not change while the graph is walked, so they are
    # read once here instead of once per node.
    graph_signature = edge_program.graph_signature
    inputs_to_parameters = graph_signature.inputs_to_parameters
    inputs_to_buffers = graph_signature.inputs_to_buffers
    inputs_to_lifted_tensor_constants = (
        graph_signature.inputs_to_lifted_tensor_constants
    )
    # Names of user inputs that are ConstantArguments holding None; such
    # placeholders get no tensor.
    none_constant_input_names = {
        spec.arg.name
        for spec in graph_signature.input_specs
        if spec.kind == InputKind.USER_INPUT
        and isinstance(spec.arg, ConstantArgument)
        and spec.arg.value is None
    }

    # Export tensors
    logger.debug("---------------Export tensors--------------")
    buf_name_to_data = {name: buf for name, buf in edge_program.named_buffers()}
    for node in edge_program.graph.nodes:
        if node.op == "call_function":
            if node.target in multiple_output_ops:
                continue
            node_val = node.meta["val"]
            if node_val.layout != torch.strided:
                raise RuntimeError(
                    f"Only support dense tensors (node layout: {node_val.layout})"
                )
            graph.add_tensor_from_node(node)
            logger.debug(f"call_function: {node.name} tensor exported.")

        # placeholder: function input (including parameters, buffers, constant tensors)
        elif node.op == "placeholder":
            # placeholder invariants
            assert node.args is None or len(node.args) == 0  # Not support default param

            # parameters
            if node.name in inputs_to_parameters:
                param_name = inputs_to_parameters[node.name]
                param_data = edge_program.state_dict[param_name]

                assert isinstance(
                    param_data, torch.Tensor
                ), "Expect parameters to be a tensor"
                param_value = param_data.cpu().detach().numpy()

                graph.add_tensor_from_node(node, param_value)
                logger.debug(f"placeholder(param): {node.name} tensor exported.")
            elif node.name in inputs_to_buffers:
                buffer_name = inputs_to_buffers[node.name]
                assert buffer_name in buf_name_to_data
                buffer_data = buf_name_to_data[buffer_name]
                assert isinstance(
                    buffer_data, torch.Tensor
                ), "Expect buffers to be a tensor"
                buffer_value = buffer_data.cpu().detach().numpy()

                graph.add_tensor_from_node(node, buffer_value)
                logger.debug(f"placeholder(buffer): {node.name} tensor exported.")
            elif node.name in inputs_to_lifted_tensor_constants:
                ctensor_name = inputs_to_lifted_tensor_constants[node.name]
                ctensor_data = edge_program.constants[ctensor_name]

                assert isinstance(
                    ctensor_data, torch.Tensor
                ), "Expect constant tensor to be a tensor"
                ctensor_value = ctensor_data.cpu().detach().numpy()

                graph.add_tensor_from_node(node, ctensor_value)
                logger.debug(
                    f"placeholder(constant tensor): {node.name} tensor exported."
                )
            else:
                # NoneType ConstantArgument is ignored.
                if node.name in none_constant_input_names:
                    continue
                graph.add_tensor_from_node(node)
                logger.debug(f"placeholder: {node.name} tensor exported.")

        # get_attr: retrieve parameter
        elif node.op == "get_attr":
            # node.name: Place where fetched attribute is saved
            # node.target: Attribute in the module
            attr_tensor = getattr(node.graph.owning_module, node.target)
            assert isinstance(attr_tensor, torch.Tensor)

            graph.add_tensor_from_scratch(
                prefix=node.name,
                shape=list(attr_tensor.shape),
                dtype=to_circle_dtype(attr_tensor.dtype),
            )

            logger.debug(f"get_attr: {node.name} tensor exported.")

        # output: function output
        elif node.op == "output":
            # output node itself does not need a buffer
            # argument of output node is assumed to be exported beforehand
            for output in node.args[0]:
                if isinstance(output, torch.fx.Node):
                    assert graph.has_tensor(output.name)

        # call_method: call method
        elif node.op == "call_method":
            raise AssertionError("Not yet implemented")

        # call_module: call 'forward' of module
        elif node.op == "call_module":
            raise AssertionError("Not yet implemented")

        else:
            # Add more if fx.Node is extended
            raise AssertionError(f"Unknown fx.Node op {node.op}")

    # Register inputs
    logger.debug("---------------Register inputs--------------")
    for in_spec in graph_signature.input_specs:
        if in_spec.kind != InputKind.USER_INPUT:
            continue
        # NoneType ConstantArgument is ignored.
        if isinstance(in_spec.arg, ConstantArgument) and in_spec.arg.value is None:
            continue
        arg_name = in_spec.arg.name
        graph.add_input(arg_name)
        logger.debug(f"Registered input: {arg_name}")

    # Register outputs
    logger.debug("---------------Register outputs--------------")
    for user_output in graph_signature.user_outputs:
        graph.add_output(user_output)
        logger.debug(f"Registered output: {user_output}")

    # Export operators
    logger.debug("---------------Export operators--------------")
    op_codes: Dict[OpCode, int] = {}
    visitors = get_node_visitors(op_codes, graph)
    for node in edge_program.graph.nodes:
        if node.op != "call_function":
            continue

        opcode = node.target
        # getitem nodes produce no operator of their own.
        if opcode == operator.getitem:
            continue
        if opcode not in visitors:
            raise RuntimeError(f"{opcode} is not yet supported")
        circle_op = visitors[opcode].define_node(node)

        if circle_op:
            graph.add_operator(circle_op)
            logger.debug(f"call_function: {node.name} ({opcode}) Op exported.")

    # Register subgraph
    model.add_subgraph(graph)

    # Encode operator codes, ordered by the index stored for each code
    model.operatorCodes = [
        code for code, _ in sorted(op_codes.items(), key=lambda x: x[1])
    ]

    # Description
    model.description = "circle"

    # Set version
    model.version = 0

    # Finish model
    builder.Finish(model.Pack(builder), "CIR0".encode("utf8"))
    buf = builder.Output()

    return bytes(buf)
@@ -0,0 +1,28 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import glob
16
+ from os.path import basename, dirname, isfile, join
17
+
18
+ from tico.utils.register_custom_op import RegisterOps
19
+
20
+
21
# Register custom ops to torch namespace
RegisterOps()

# Load all modules in the current directory: every regular `*.py` file next to
# this one, except `__init__.py`, is exported by its name without the `.py`
# suffix.
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(module_path)[:-3]
    for module_path in modules
    if isfile(module_path)
    if not module_path.endswith("__init__.py")
]
+ ]