tico 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,319 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from collections import defaultdict
16
+ from typing import Any, cast, Dict, final, List, Optional, TYPE_CHECKING, Union
17
+
18
+ if TYPE_CHECKING:
19
+ import torch.fx
20
+ import numpy as np
21
+ import torch
22
+ from circle_schema import circle
23
+ from torch._subclasses.fake_tensor import FakeTensor
24
+
25
+ from tico.serialize.circle_mapping import (
26
+ extract_circle_dtype,
27
+ extract_shape,
28
+ str_to_circle_dtype,
29
+ to_circle_dtype,
30
+ )
31
+ from tico.serialize.pack import pack_buffer
32
+ from tico.serialize.quant_param import QPARAM_KEY, QuantParam
33
+ from tico.utils.utils import to_circle_qparam
34
+
35
+ """
36
+ Type alias for const
37
+ """
38
+ _PRIMITIVE_TYPES = (
39
+ float,
40
+ int,
41
+ bool,
42
+ str,
43
+ torch.Tensor,
44
+ torch.device,
45
+ torch.dtype,
46
+ torch.layout,
47
+ )
48
+ ConstDataElement = Union[
49
+ int, float, bool, str, torch.Tensor, torch.device, torch.dtype, torch.layout
50
+ ]
51
+ ConstData = Union[ConstDataElement, List[ConstDataElement]]
52
+
53
+
54
def is_const(arg) -> bool:
    """Return True if ``arg`` is a concrete constant value.

    - A ``FakeTensor`` is never constant.
    - Any instance of ``_PRIMITIVE_TYPES`` is constant.
    - A ``dict`` is constant when all of its values are constant.
    - A ``tuple``/``list`` is constant when all of its elements are constant.
    - Anything else is not constant.
    """
    # FakeTensor is rejected before the primitive check, which would otherwise
    # accept it through torch.Tensor.
    if isinstance(arg, FakeTensor):
        return False
    if isinstance(arg, _PRIMITIVE_TYPES):
        return True
    if isinstance(arg, dict):
        return all(is_const(value) for value in arg.values())
    if isinstance(arg, (tuple, list)):
        return all(is_const(item) for item in arg)
    return False
64
+
65
+
66
@final
class CircleModel(circle.Model.ModelT):
    """Circle model being assembled: owns the subgraphs and the buffer table.

    Tensors in the subgraphs refer to entries of ``buffers`` by index.
    """

    def __init__(self):
        super().__init__()
        self.subgraphs: List[circle.SubGraph.SubGraphT] = []
        self.buffers: List[circle.Buffer.BufferT] = []

    def add_subgraph(self, graph: circle.SubGraph.SubGraphT) -> None:
        """Append ``graph`` to the model's subgraph list."""
        self.subgraphs.append(graph)

    def add_buffer(self, buffer: circle.Buffer.BufferT) -> int:
        """Append ``buffer`` to the model and return its index (the buffer id)."""
        new_id = len(self.buffers)
        self.buffers.append(buffer)
        return new_id
81
+
82
+
83
@final
class CircleSubgraph(circle.SubGraph.SubGraphT):
    """Builder for a single Circle subgraph.

    Tensors are appended to ``tensors`` and looked up by name through
    ``name_to_tid``. Every tensor registered here is given a freshly
    allocated buffer in the owning ``CircleModel`` (a buffer with no data
    for non-constant tensors).
    """

    def __init__(self, model: CircleModel):
        super().__init__()
        self.model: CircleModel = model
        self.name: str = "subgraph"
        self.inputs: List[int] = []
        self.outputs: List[int] = []
        self.tensors: List[circle.Tensor.TensorT] = []
        self.operators: List[circle.Operator.OperatorT] = []
        # Circle tensor name -> index into `self.tensors`.
        self.name_to_tid: Dict[str, int] = {}
        # Mapping from Circle tensor names to their originating FX nodes.
        # Used to trace back tensor definitions to their source and finalize
        # human-readable tensor names after serialization.
        self.name_to_node: Dict[str, torch.fx.Node] = {}
        # Per-prefix next suffix index, used by `_gen_unique_name_with_prefix`.
        self.counter: defaultdict = defaultdict(int)

    def _gen_unique_name_with_prefix(self, prefix: str):
        """Return a tensor name based on ``prefix`` that is not yet registered.

        Naming rule:
          - If no tensor is named ``prefix``, return ``prefix`` itself.
          - Otherwise append ``_{idx}``, where ``idx`` starts at 0 and
            increases by 1 (tracked per prefix) until an unused name is found.

        Example: for prefix ``"add"`` the candidates are tried in the order
        ``"add"``, ``"add_0"``, ``"add_1"``, ...
        """
        name = prefix
        while self.has_tensor(name):
            index = self.counter[prefix]
            name = f"{prefix}_{index}"
            self.counter[prefix] += 1

        return name

    def _add_tensor(self, tensor: circle.Tensor.TensorT) -> None:
        """Append ``tensor`` to the subgraph and index it by name."""
        # Check uniqueness before mutating anything so a duplicate name cannot
        # leave `tensors` and `name_to_tid` out of sync.
        assert tensor.name not in self.name_to_tid
        self.tensors.append(tensor)
        self.name_to_tid[tensor.name] = len(self.tensors) - 1

    def add_operator(self, op: circle.Operator.OperatorT) -> None:
        """Append ``op`` to the subgraph's operator list."""
        self.operators.append(op)

    def add_input(self, input_name: str) -> None:
        """Mark the already-registered tensor ``input_name`` as a subgraph input."""
        assert input_name in self.name_to_tid, f"{input_name}"
        tid = self.name_to_tid[input_name]
        self.inputs.append(tid)

    def add_output(self, output: Any) -> None:
        """Mark ``output`` as a subgraph output.

        ``output`` is either the name of a registered tensor, or a Python
        ``int``/``float`` (``bool`` included, as an ``int`` subclass). In the
        latter case a new constant tensor holding the value is created and
        used as the output.

        Raises:
            NotImplementedError: for any other type of ``output``.
        """
        if isinstance(output, str):
            assert output in self.name_to_tid
            output_name = output
        elif isinstance(output, (int, float)):
            # output is built-in type.
            circle_tensor = self.add_const_tensor(output)
            output_name = circle_tensor.name
        else:
            raise NotImplementedError(f"Unsupported output dtype: {type(output)}")
        tid = self.name_to_tid[output_name]
        self.outputs.append(tid)

    def has_tensor(self, name: str):
        """Return True if a tensor named ``name`` is registered."""
        return name in self.name_to_tid

    def add_tensor_from_node(
        self, node: torch.fx.Node, data: Optional[np.ndarray] = None
    ) -> None:
        """Register a Circle tensor describing the FX ``node``.

        The tensor name is derived from ``node.name`` (made unique). Its dtype
        and shape come from ``node.meta["val"]``. If the node carries
        quantization parameters under ``QPARAM_KEY``, they are attached and
        their dtype overrides the tensor type.

        Parameters
        ----------
        node : torch.fx.Node
            Source node; ``node.meta["val"]`` must be set.
        data : Optional[np.ndarray]
            Constant contents for the tensor's buffer. ``None`` leaves the
            buffer without data. Any other non-ndarray value fails an
            assertion.
        """
        tensor = circle.Tensor.TensorT()
        tensor.name = self._gen_unique_name_with_prefix(node.name)
        assert tensor.name not in self.name_to_node
        self.name_to_node[tensor.name] = node
        assert node.meta.get("val") is not None
        tensor.type = extract_circle_dtype(node)
        tensor.shape = list(extract_shape(node))
        if QPARAM_KEY in node.meta:
            tensor.quantization = to_circle_qparam(node.meta[QPARAM_KEY])
            tensor.type = str_to_circle_dtype(node.meta[QPARAM_KEY].dtype)

        buffer = circle.Buffer.BufferT()
        # isinstance(None, np.ndarray) is False, so no separate None check.
        if isinstance(data, np.ndarray):
            data = data.flatten()

            if QPARAM_KEY in node.meta:
                if node.meta[QPARAM_KEY].dtype == "uint4":
                    # uint4 data is packed by `pack_buffer` before serialization.
                    data = pack_buffer(data, "uint4")

            # Packing np.ndarray is faster than packing bytes
            buffer.data = data.view(np.uint8)  # type: ignore[assignment]
        else:
            assert data is None
        bid = self.model.add_buffer(buffer)
        tensor.buffer = bid
        self._add_tensor(tensor)

    def add_const_tensor(
        self, data: ConstData, source_node: Optional[torch.fx.Node] = None
    ) -> circle.Tensor.TensorT:
        """Register a constant tensor holding ``data`` and return it.

        ``data`` is converted with ``torch.as_tensor``. The resulting dtype
        and shape become the Circle tensor's type and shape, and its flattened
        bytes become the buffer contents. The tensor is named from the prefix
        ``"const_tensor"`` (made unique). If ``source_node`` is given, it is
        recorded in ``name_to_node`` for traceability.
        """
        assert is_const(data)
        tensor = circle.Tensor.TensorT()
        tensor.name = self._gen_unique_name_with_prefix("const_tensor")
        assert tensor.name not in self.name_to_node
        if source_node is not None:
            self.name_to_node[tensor.name] = source_node
        assert not self.has_tensor(tensor.name)
        torch_t = torch.as_tensor(data=data)
        torch_t_shape = list(torch_t.size())
        tensor.type = to_circle_dtype(torch_dtype=torch_t.dtype)
        tensor.shape = torch_t_shape

        buffer = circle.Buffer.BufferT()
        buffer.data = torch_t.flatten().cpu().numpy().view(np.uint8)  # type: ignore[assignment]
        bid = self.model.add_buffer(buffer)
        tensor.buffer = bid
        self._add_tensor(tensor)

        return tensor

    def add_tensor_from_scratch(
        self,
        prefix: str,
        shape: List[int],
        dtype: int,
        qparam: Optional[QuantParam] = None,
        source_node: Optional[torch.fx.Node] = None,
    ) -> circle.Tensor.TensorT:
        """
        Create a new tensor and register it into the Circle subgraph from scratch.

        This function is used to allocate tensors that are not directly derived from
        values in the FX graph, such as those created by padding or shape-generating
        operators.

        If a `source_node` is provided, it is used to enrich the tensor's metadata
        (e.g., by associating the tensor with the module hierarchy path stored in
        the node's `nn_module_stack`). This enables better traceability and more
        informative tensor names in the final Circle model.

        Parameters
        ----------
        prefix : str
            A name prefix used to generate a unique tensor name.
        shape : List[int]
            The shape of the tensor.
        dtype : int
            The Circle-compatible dtype of the tensor. Use `to_circle_dtype()` to convert.
            Ignored when `qparam` is given; the qparam dtype is used instead.
        qparam : Optional[QuantParam]
            Optional quantization parameters to apply to the tensor.
        source_node : Optional[torch.fx.Node]
            If provided, the FX node from which this tensor originates. Used to generate
            a richer name and track module origin.

        Returns
        -------
        circle.Tensor.TensorT
            The newly created and registered tensor (its buffer has no data).
        """
        assert isinstance(dtype, int), f"{dtype} must be integer. Use to_circle_dtype."
        tensor = circle.Tensor.TensorT()
        tensor.name = self._gen_unique_name_with_prefix(prefix)
        assert tensor.name not in self.name_to_node
        if source_node is not None:
            self.name_to_node[tensor.name] = source_node
        tensor.shape = shape
        if qparam is not None:
            tensor.quantization = to_circle_qparam(qparam)
            tensor.type = str_to_circle_dtype(qparam.dtype)
        else:
            tensor.type = dtype

        buffer = circle.Buffer.BufferT()
        bid = self.model.add_buffer(buffer)
        tensor.buffer = bid
        self._add_tensor(tensor)

        return tensor

    # Some operators like `full`, `arange_start_step` or `scalar_tensor` need
    # their buffers to be updated in place.
    # TODO remove this function
    def update_tensor_buffer(
        self, data: ConstData, tensor_name: str = str()
    ) -> circle.Tensor.TensorT:
        """Fill the registered tensor ``tensor_name`` with the constant ``data``.

        After ``torch.as_tensor``, the dtype and shape of ``data`` must match
        the tensor's. A new buffer is appended to the model and the tensor is
        re-pointed to it. The previously referenced buffer remains in
        ``model.buffers``.
        """
        assert is_const(data)
        assert self.has_tensor(tensor_name)
        data_tensor = torch.as_tensor(data=data)
        data_shape = list(data_tensor.size())
        op_tensor = self.tensors[self.name_to_tid[tensor_name]]
        assert op_tensor.type == to_circle_dtype(
            data_tensor.dtype
        ), f"{op_tensor.type}, {data_tensor.dtype}"
        assert op_tensor.shape == data_shape

        buffer = circle.Buffer.BufferT()
        # Packing np.ndarray is faster than packing bytes
        buffer.data = data_tensor.flatten().cpu().numpy().view(np.uint8)  # type: ignore[assignment]
        bid = self.model.add_buffer(buffer)
        op_tensor.buffer = bid

        return op_tensor

    def get_tid_registered(
        self, node: Union[torch.fx.node.Node, circle.Tensor.TensorT]
    ) -> int:
        """Return the index of the already-registered tensor named ``node.name``.

        Raises:
            KeyError: if no tensor with that name is registered.
        """
        assert hasattr(node, "name"), "FIX CALLER UNLESS"

        tid = self.name_to_tid.get(node.name, None)

        if tid is None:
            raise KeyError(f"{node}({node.name}) is not registered.")

        assert tid < len(self.tensors)

        return tid

    def get_tensor(self, node: torch.fx.node.Node) -> circle.Tensor.TensorT:
        """Return the registered Circle tensor for ``node`` (KeyError if absent)."""
        tid = self.get_tid_registered(node)

        return self.tensors[tid]

    def get_buffer(self, node: torch.fx.Node) -> circle.Buffer.BufferT:
        """Return the model buffer referenced by ``node``'s registered tensor."""
        buf_id = self.get_tensor(node).buffer
        return self.model.buffers[buf_id]

    # TODO Rename, it doesn't only get_tid but also possibly add a new const tensor
    def get_tid(
        self, node: Union[torch.fx.Node, circle.Tensor.TensorT, ConstData]
    ) -> int:
        """Return the tensor index for ``node``, registering constants on demand.

        - ``None`` maps to -1, which is used for generating CircleOutputExclude.
        - An object whose ``name`` is already registered maps to that tensor.
        - A constant value (see ``is_const``) is added as a new const tensor.

        Raises:
            RuntimeError: if ``node`` is none of the above, e.g. an FX node that
                has not been converted to a tensor yet.
        """
        # Identity check: `==` would dispatch to the operand's `__eq__`, which
        # is element-wise (not a plain bool) for array-like values.
        if node is None:
            return -1

        if hasattr(node, "name") and node.name in self.name_to_tid:
            return self.name_to_tid[node.name]

        if is_const(node):
            node_name = self.add_const_tensor(cast(ConstData, node)).name
            return self.name_to_tid[node_name]

        # Reached for an unregistered, non-constant input (e.g. an FX node
        # whose tensor has not been created yet).
        raise RuntimeError("fx Node was not converted to tensor.")
@@ -0,0 +1,177 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Tuple, TYPE_CHECKING, Union
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import numpy as np
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+
24
+ # Convert torch dtype to circle dtype
25
def to_circle_dtype(
    torch_dtype: torch.dtype,
) -> int:
    """Map a ``torch.dtype`` to the corresponding Circle ``TensorType`` value.

    Raises:
        RuntimeError: if ``torch_dtype`` has no Circle counterpart here.
    """
    assert isinstance(torch_dtype, torch.dtype)
    tensor_type = circle.TensorType.TensorType
    # torch.float / torch.short / torch.int are the very same objects as
    # torch.float32 / torch.int16 / torch.int32, so one entry each suffices.
    torch_to_circle = {
        torch.float32: tensor_type.FLOAT32,
        torch.uint8: tensor_type.UINT8,
        torch.int8: tensor_type.INT8,
        torch.int16: tensor_type.INT16,
        torch.int32: tensor_type.INT32,
        torch.int64: tensor_type.INT64,
        torch.bool: tensor_type.BOOL,
    }

    if torch_dtype not in torch_to_circle:
        raise RuntimeError(f"Unsupported dtype {torch_dtype}")

    result = torch_to_circle[torch_dtype]
    assert result is not None
    return result
48
+
49
+
50
+ # Convert str dtype used in QuantParam to circle dtype
51
def str_to_circle_dtype(
    str_dtype: str,
) -> int:
    """Map a dtype string (as stored in ``QuantParam.dtype``) to a Circle dtype.

    Raises:
        RuntimeError: if ``str_dtype`` is not a recognized dtype name.
    """
    tensor_type = circle.TensorType.TensorType
    name_to_circle = {
        "float32": tensor_type.FLOAT32,
        "float": tensor_type.FLOAT32,
        "uint8": tensor_type.UINT8,
        "int8": tensor_type.INT8,
        "int16": tensor_type.INT16,
        "short": tensor_type.INT16,
        "int32": tensor_type.INT32,
        "int": tensor_type.INT32,
        "int64": tensor_type.INT64,
        "bool": tensor_type.BOOL,
        "uint4": tensor_type.UINT4,
        # TODO Add more dtypes
    }

    if str_dtype not in name_to_circle:
        raise RuntimeError(f"Unsupported dtype {str_dtype}")

    result = name_to_circle[str_dtype]
    assert result is not None
    return result
75
+
76
+
77
+ # Convert circle dtype to numpy dtype
78
def np_dtype_from_circle_dtype(circle_dtype: int):
    """Map a Circle ``TensorType`` value to the matching NumPy scalar type.

    Raises:
        RuntimeError: if ``circle_dtype`` has no NumPy counterpart here.
    """
    tensor_type = circle.TensorType.TensorType
    circle_to_np = {
        tensor_type.FLOAT32: np.float32,
        tensor_type.UINT8: np.uint8,
        tensor_type.INT8: np.int8,
        tensor_type.INT16: np.int16,
        tensor_type.INT32: np.int32,
        tensor_type.INT64: np.int64,
        tensor_type.BOOL: np.bool_,
    }

    if circle_dtype not in circle_to_np:
        raise RuntimeError(f"Unsupported dtype {circle_dtype}")

    result = circle_to_np[circle_dtype]
    assert result is not None
    return result
95
+
96
+
97
+ # Return dtype of node
98
def extract_torch_dtype(node: torch.fx.Node) -> torch.dtype:
    """Return the torch dtype of the value recorded in ``node.meta["val"]``.

    A tensor value reports its own dtype. Any other value (e.g. a Python
    scalar) is wrapped with ``torch.tensor`` so that torch infers the dtype.
    """
    assert node.meta is not None
    recorded = node.meta.get("val")
    assert recorded is not None

    if not isinstance(recorded, torch.Tensor):
        return torch.tensor(recorded).dtype

    assert isinstance(recorded.dtype, torch.dtype)
    return recorded.dtype
110
+
111
+
112
def extract_circle_dtype(node: torch.fx.Node) -> int:
    """Return the Circle dtype matching the torch dtype recorded on ``node``."""
    node_torch_dtype = extract_torch_dtype(node)
    return to_circle_dtype(node_torch_dtype)
114
+
115
+
116
+ # Return shape of node
117
def extract_shape(node: torch.fx.Node) -> torch.Size:
    """Return the shape of the value recorded in ``node.meta["val"]``.

    A tensor value reports its own size. Any other value (a scalar or a
    nested list) is converted with ``torch.tensor`` to obtain its shape.
    """
    assert node.meta is not None
    recorded = node.meta.get("val")
    assert recorded is not None

    if not isinstance(recorded, torch.Tensor):
        return torch.tensor(recorded).shape

    return recorded.size()
129
+
130
+
131
+ # Return stride of node
132
def extract_stride(node: torch.fx.Node) -> Tuple[int, ...]:
    """Return the stride of the tensor recorded in ``node.meta["val"]``.

    Unlike the dtype/shape helpers, this requires the recorded value to be a
    ``torch.Tensor``; any other value fails an assertion.
    """
    assert node.meta is not None
    recorded = node.meta.get("val")
    assert recorded is not None
    assert isinstance(recorded, torch.Tensor)
    return recorded.stride()
142
+
143
+
144
def traverse_elements(iter, container_types=(list, tuple)):
    """Yield the leaves of a possibly nested container, depth-first, in order.

    Any object whose type is not in ``container_types`` is a leaf and is
    yielded unchanged; this includes a non-container top-level argument.
    (The parameter name ``iter`` shadows the builtin but is kept for
    compatibility with keyword callers.)
    """
    if not isinstance(iter, container_types):
        yield iter
        return
    for element in iter:
        yield from traverse_elements(element, container_types)
151
+
152
+
153
def check_if_i32_range(axis: Union[list, int]):
    """Return True if every leaf value of ``axis`` fits in a signed 32-bit int.

    ``axis`` may be a scalar or an arbitrarily nested list/tuple. An empty
    container yields True.
    """
    lower, upper = -(2**31), 2**31 - 1
    leaves = list(traverse_elements(axis))
    return all(lower <= leaf <= upper for leaf in leaves)
158
+
159
+
160
def circle_legalize_dtype_to(values, *, dtype: torch.dtype):
    """
    Legalize constant values from `torch.int64` to `torch.int32`.

    PyTorch treats Python's built-in integers as `torch.int64`, but much of
    the Circle infrastructure (e.g. circle-interpreter) supports only int32.
    When every value lies within [INT32_MIN, INT32_MAX], the values are
    returned as an int32 tensor.

    Only `dtype=torch.int32` is accepted; any other dtype raises RuntimeError,
    as do values outside the int32 range.

    TODO support more types

    NOTE. This function must be applied only to constant values.
    """
    if dtype == torch.int32:
        if check_if_i32_range(values):
            return torch.as_tensor(values, dtype=dtype)
        raise RuntimeError("'size' cannot be converted from int64 to int32.")
    raise RuntimeError("Not supported data types.")