tico 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,20 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from enum import Enum
16
+
17
+
18
class BACKEND(Enum):
    """
    Backends on which a quantized circle model can be evaluated.

    Each member is mapped to the executor class that drives it (see
    `BACKEND_TO_EXECUTOR` in the evaluation module).
    """

    # Fake-quantized execution of the circle model (CircleExecutor).
    CIRCLE = 1
    # Compilation and execution with the TRIV24 toolchain (Triv24Executor).
    TRIV24 = 2
@@ -0,0 +1,223 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
16
+
17
+ import numpy as np
18
+ import torch
19
+
20
+ from circle_schema import circle
21
+ from torch.utils import _pytree as pytree
22
+
23
+ from tico.experimental.quantization.evaluation.backend import BACKEND
24
+ from tico.experimental.quantization.evaluation.executor.backend_executor import (
25
+ BackendExecutor,
26
+ )
27
+ from tico.experimental.quantization.evaluation.executor.circle_executor import (
28
+ CircleExecutor,
29
+ )
30
+ from tico.experimental.quantization.evaluation.executor.triv24_executor import (
31
+ Triv24Executor,
32
+ )
33
+ from tico.experimental.quantization.evaluation.metric import MetricCalculator
34
+ from tico.experimental.quantization.evaluation.utils import (
35
+ ensure_list,
36
+ find_invalid_types,
37
+ get_graph_input_output,
38
+ plot_two_outputs,
39
+ )
40
+ from tico.utils.model import CircleModel
41
+
42
# A single array-like input; nested into InputDataType below (typing flattens
# nested Unions, so the resulting alias is unchanged).
_SingleInput = Union[np.ndarray, torch.Tensor]

# Accepted forms of user-supplied input data for `evaluate`: None (random
# inputs are generated), a single array/tensor, or a list/tuple of them.
InputDataType = Union[
    None,
    _SingleInput,
    List[np.ndarray],
    List[torch.Tensor],
    Tuple[np.ndarray],
    Tuple[torch.Tensor],
]

# Executor class instantiated by `evaluate` for each supported backend.
BACKEND_TO_EXECUTOR: Dict[BACKEND, type[BackendExecutor]] = dict(
    [
        (BACKEND.CIRCLE, CircleExecutor),
        (BACKEND.TRIV24, Triv24Executor),
    ]
)
56
+
57
+
58
def _validate_input_data(
    input_data: InputDataType, circle_inputs: List[circle.Tensor.Tensor]
) -> None:
    """
    Check user-supplied input data against the inputs of a circle model.

    Only the number of inputs and the element types are checked; tensor
    shapes are not compared here.

    Parameters
    -----------
    input_data
        The input data to check. `None` is accepted as-is (nothing to check);
        otherwise it must already be a list.
    circle_inputs
        The circle.Tensor.Tensor inputs of the model to check against.

    Raises
    ------
    RuntimeError
        If the number of inputs differs from the model's, or if
        `find_invalid_types` reports elements that are not torch.Tensor or
        numpy.ndarray.
    """
    if input_data is not None:
        assert isinstance(input_data, list)

        num_given, num_expected = len(input_data), len(circle_inputs)
        if num_given != num_expected:
            raise RuntimeError(
                f"Mismatch between the length of input data and circle model: input_data({num_given}) != circle_model({num_expected})"
            )

        invalid_type = find_invalid_types(input_data, [torch.Tensor, np.ndarray])
        if invalid_type:
            raise RuntimeError(
                f"Only support tuple of torch.Tensor or numpy.ndarray for input data. Invalid types: {invalid_type}"
            )
83
+
84
+
85
def _convert_to_torch_tensor(input_data: InputDataType) -> List[torch.Tensor]:
    """
    Convert the input data into a new list of torch.Tensor.

    - numpy arrays with a floating-point dtype are converted to torch's default
      floating dtype (float32 unless changed), as `torch.Tensor(arr)` did.
    - numpy arrays with any other dtype (integers, bool) keep their dtype, so
      inputs such as indices are not silently turned into floats.
    - torch.Tensor elements are passed through unchanged.

    The caller's list is not modified, and numpy data is copied rather than
    aliased.

    Parameters
    -----------
    input_data
        A list of torch.Tensor / numpy.ndarray to convert.

    Returns
    --------
    List[torch.Tensor]
        A new list holding one tensor per input, in the same order.
    """
    assert isinstance(input_data, list)

    converted: List[torch.Tensor] = []
    for data in input_data:
        if isinstance(data, np.ndarray):
            if np.issubdtype(data.dtype, np.floating):
                # Preserve the previous behaviour for float inputs (e.g. float64
                # numpy arrays become float32 tensors matching model weights).
                data = torch.tensor(data, dtype=torch.get_default_dtype())
            else:
                # Keep integer/bool dtypes intact.
                data = torch.tensor(data)
        converted.append(data)

    assert all(isinstance(t, torch.Tensor) for t in converted)

    return converted
108
+
109
+
110
def evaluate(
    torch_module: torch.nn.Module,
    circle_model: CircleModel,
    backend: BACKEND,
    input_data: InputDataType = None,
    *,
    mode="plot",
    metrics: Optional[List[str]] = None,
    custom_metrics: Optional[Dict[str, Callable]] = None,
) -> Optional[Dict[str, Any]]:
    """
    Evaluate and compare a PyTorch module with a quantized circle model on a
    specific backend using the given metrics.

    The circle model is compiled for `backend`. Both the PyTorch module and the
    compiled model are run on the same input, and their outputs are compared
    with every requested metric.

    Parameters
    -----------
    torch_module
        A callable PyTorch module.
    circle_model
        The circle model to be compiled and evaluated.
    backend
        The backend used to compile and execute the circle model.
    input_data
        The input data to be used for evaluation. Should be compatible with both models.
        If None, random data (torch.randn, shaped like the circle model inputs)
        is generated.
    mode
        The mode of operation. Options are:
            - "plot": Print a plot and the metric values per output (default).
            - "return": Return the results.
    metrics
        A list of built-in metric names for comparison. Defaults to ["peir"].
    custom_metrics
        A dictionary of metric names and corresponding callable functions for comparison.
        Example: {'mse': mean_squared_error, 'cosine_similarity': cosine_similarity_fn}
    # TODO Support options for backend optimizations.

    Returns
    --------
    dict or None
        If `mode` is "plot", prints the results and returns None.
        If `mode` is "return", returns a dictionary mapping each metric name
        (built-in and custom) to a list with one value per model output.

    Raises
    ------
    RuntimeError
        On invalid argument types, an unknown `mode` or metric name, input data
        that does not match the circle model, or a mismatch between the number
        of torch and circle outputs.
    """
    # Avoid shared mutable defaults.
    if metrics is None:
        metrics = ["peir"]
    if custom_metrics is None:
        custom_metrics = dict()

    # Check if arguments are allowed types.
    if not isinstance(torch_module, torch.nn.Module):
        raise RuntimeError(
            f"Only support torch.nn.Module. Given module type: {type(torch_module)}"
        )
    if not isinstance(circle_model, CircleModel):
        raise RuntimeError(
            f"Only support CircleModel. Given module type: {type(circle_model)}"
        )
    if not isinstance(backend, BACKEND):
        raise RuntimeError(
            "Invalid backend. Please use tico.quantization.evaluate.BACKEND enum class"
        )
    # Validate the mode and metrics up front so bad arguments fail before the
    # (potentially slow) backend compilation and inference.
    if mode not in ("plot", "return"):
        raise RuntimeError("Invalid mode.")
    metric_calculator = MetricCalculator(metrics, custom_metrics)

    # Make it a list for simpler logic.
    if input_data is not None:
        input_data = ensure_list(input_data)

    circle_inputs, _ = get_graph_input_output(circle_model)
    _validate_input_data(input_data, circle_inputs)

    if input_data:
        input_data = _convert_to_torch_tensor(input_data)
    else:
        # Make random inputs
        circle_input_shapes_np = [t.ShapeAsNumpy() for t in circle_inputs]
        input_data = [torch.randn(*shape) for shape in circle_input_shapes_np]

    assert isinstance(input_data, list)
    assert all(isinstance(data, torch.Tensor) for data in input_data)

    # Compile circle model and run inference.
    executor: BackendExecutor = BACKEND_TO_EXECUTOR[backend]()
    executor.compile(circle_model)
    raw_circle_output = executor.run_inference(input_data)
    # numpy outputs are wrapped as tensors and tensor outputs are kept as-is;
    # any other output type is dropped (the length check below reports it).
    circle_output = [
        torch.from_numpy(out) if isinstance(out, np.ndarray) else out
        for out in raw_circle_output
        if isinstance(out, (np.ndarray, torch.Tensor))
    ]

    # Run torch model.
    with torch.no_grad():
        torch_output = torch_module(*input_data)
    # tree_flatten always returns a flat list of leaves, so a single-tensor
    # output also becomes a one-element list.
    torch_output, _ = pytree.tree_flatten(torch_output)
    if len(torch_output) != len(circle_output):
        raise RuntimeError(
            f"Mismatch between the length of torch output and circle output: torch_output({len(torch_output)}) != circle_output({len(circle_output)})"
        )

    # Computes the comparison score based on the provided metrics.
    results: Dict[str, Any] = metric_calculator.compute(torch_output, circle_output)

    if mode == "return":
        return results

    # mode == "plot" (validated above).
    for idx, (t_out, c_out) in enumerate(zip(torch_output, circle_output)):
        print(f"OUTPUT [{idx}]")
        fig = plot_two_outputs(t_out, c_out)
        print(fig)
        for metric_name, values in results.items():
            print(f"{metric_name}: {values[idx]}")

    return None
@@ -0,0 +1 @@
1
+ # DO NOT REMOVE THIS FILE
@@ -0,0 +1,54 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import abc
16
+ from typing import Any
17
+
18
+ from tico.utils.model import CircleModel
19
+
20
+
21
class BackendExecutor(abc.ABC):
    """
    Interface implemented by every backend executor.

    An executor is used in two steps: `compile` is called once with a circle
    model, then `run_inference` is called with the input data.
    """

    @abc.abstractmethod
    def compile(self, circle_model: CircleModel) -> None:
        """
        Prepare `circle_model` for execution on this backend, if the backend
        needs a preparation step.

        Parameters
        -----------
        circle_model
            The circle model to prepare.
        """
        ...

    @abc.abstractmethod
    def run_inference(self, input_data: Any) -> Any:
        """
        Execute the prepared (or directly usable) model on `input_data`.

        Parameters
        -----------
        input_data
            The data fed to the model.

        Returns
        --------
        Any
            Whatever the backend produces as the model's inference output.
        """
        ...
@@ -0,0 +1,75 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import tempfile
16
+ from pathlib import Path
17
+ from typing import List
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from tico.experimental.quantization.evaluation.executor.backend_executor import (
23
+ BackendExecutor,
24
+ )
25
+ from tico.utils.model import CircleModel
26
+ from tico.utils.utils import run_bash_cmd
27
+
28
+
29
class CircleExecutor(BackendExecutor):
    """
    A class for running inference on fake-quantized circle models.

    Instead of executing the quantized circle model on a real backend,
    `compile` uses `onecc quantize --fake_quantize` to produce a
    fake-quantized circle model, and `run_inference` executes that model
    through `CircleModel`.
    """

    def __init__(self):
        self.compiler_path = Path("/usr/share/one/bin/onecc")
        self.interpreter_path = None  # Use circle-interpreter
        self.fq_circle_path = None

        # Check if the toolchain is installed.
        if not self.compiler_path.is_file():
            raise RuntimeError(
                "Not found one-compiler. Please install the one-compiler package first."
            )

        # Holds the intermediate circle files; cleaned up in __del__.
        self.temp_dir = tempfile.TemporaryDirectory()

    def compile(self, circle_model: CircleModel) -> None:
        """
        Save `circle_model` to the temporary directory and create a
        fake-quantized copy of it with `onecc quantize --fake_quantize`.
        """
        assert isinstance(circle_model, CircleModel)
        work_dir = Path(self.temp_dir.name)
        circle_path = work_dir / "quantized.circle"
        circle_model.save(str(circle_path))
        self.fq_circle_path = work_dir / "fake_quantized.circle"
        cmd = [
            str(self.compiler_path),
            "quantize",
            "--fake_quantize",
            "-i",
            str(circle_path),
            "-o",
            str(self.fq_circle_path),
        ]
        run_bash_cmd(cmd)

    def run_inference(self, input_data: List[torch.Tensor]) -> List[np.ndarray]:
        """
        Run the fake-quantized model on `input_data` and return its outputs
        as a list.

        Raises
        ------
        RuntimeError
            If `compile` has not been called first.
        """
        if not self.fq_circle_path:
            raise RuntimeError("You must compile the model before running inference.")

        fq_circle = CircleModel.load(self.fq_circle_path)
        assert isinstance(fq_circle, CircleModel)
        out = fq_circle(*input_data)
        # Normalize to a list: a tuple of outputs is unpacked, any other single
        # output is wrapped.
        if isinstance(out, tuple):
            out = list(out)
        elif not isinstance(out, list):
            out = [out]
        return out

    def __del__(self):
        # __init__ may have raised before temp_dir was created.
        temp_dir = getattr(self, "temp_dir", None)
        if temp_dir is not None:
            temp_dir.cleanup()
@@ -0,0 +1,128 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import tempfile
16
+ from pathlib import Path
17
+ from typing import List
18
+
19
+ import numpy as np
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.experimental.quantization.evaluation.executor.backend_executor import (
24
+ BackendExecutor,
25
+ )
26
+ from tico.experimental.quantization.evaluation.utils import (
27
+ dequantize,
28
+ get_graph_input_output,
29
+ quantize,
30
+ )
31
+ from tico.serialize.circle_mapping import np_dtype_from_circle_dtype
32
+ from tico.utils.model import CircleModel
33
+ from tico.utils.utils import run_bash_cmd
34
+
35
+
36
class Triv24Executor(BackendExecutor):
    """
    Implementation for a TRIV24 backend.

    `compile` builds a `.tvn` loadable with `triv24-compile`. `run_inference`
    writes the inputs as raw binary files (quantized when the corresponding
    circle input carries quantization parameters), runs `triv24-ssinfer`,
    and reads the outputs back (dequantized when the corresponding circle
    output carries quantization parameters).
    """

    def __init__(self):
        self.compiler_path = Path("/usr/share/one/backends/triv24/bin/triv24-compile")
        self.interpreter_path = Path(
            "/usr/share/one/backends/triv24/bin/triv24-ssinfer"
        )
        self.circle_model = None
        self.tvn_path = None

        # Check if triv24 toolchain is installed.
        if not self.compiler_path.is_file() or not self.interpreter_path.is_file():
            raise RuntimeError(
                "Not found triv24 toolchain. Please install the toolchain package first."
            )

        # Holds the circle/tvn files and the input/output dumps; cleaned up in __del__.
        self.temp_dir = tempfile.TemporaryDirectory()

    def compile(self, circle_model: CircleModel) -> None:
        """
        Save `circle_model` to the temporary directory and compile it into a
        `.tvn` loadable with `triv24-compile`.

        The model is also kept so `run_inference` can read its input/output
        quantization parameters.
        """
        assert isinstance(circle_model, CircleModel)
        self.circle_model = circle_model
        circle_path = Path(self.temp_dir.name) / "quantized.circle"
        circle_model.save(str(circle_path))
        self.tvn_path = Path(self.temp_dir.name) / "compiled.tvn"
        cmd = [str(self.compiler_path), "-o", str(self.tvn_path), str(circle_path)]
        run_bash_cmd(cmd)

    def run_inference(self, input_data: List[torch.Tensor]) -> List[np.ndarray]:
        """
        Run the compiled model on `input_data`.

        Returns
        --------
        List[np.ndarray]
            One array per circle output, dequantized when the output is quantized.

        Raises
        ------
        RuntimeError
            If `compile` has not been called first.
        """
        if not self.tvn_path:
            raise RuntimeError("You must compile the model before running inference.")

        assert isinstance(self.circle_model, CircleModel)
        circle_inputs, circle_outputs = get_graph_input_output(self.circle_model)
        # Get input/output of scale/zp from quantized circle.
        # Note that qparams may be None because some of them are not quantized like indices, argmax.
        input_qparam: List[circle.QuantizationParameters.QuantizationParameters] = [
            inp.Quantization() for inp in circle_inputs
        ]
        output_qparam: List[circle.QuantizationParameters.QuantizationParameters] = [
            out.Quantization() for out in circle_outputs
        ]

        work_dir = Path(self.temp_dir.name)

        # Create input files for inference.
        for in_idx, data in enumerate(input_data):
            in_data_path = work_dir / f"input.{in_idx}.tv2b"
            assert isinstance(data, torch.Tensor)
            qparam = input_qparam[in_idx]
            if qparam is None:
                np_data = data.numpy()
            else:
                np_data = quantize(
                    data.numpy(),
                    qparam.ScaleAsNumpy()[0],
                    qparam.ZeroPointAsNumpy()[0],
                    np_dtype_from_circle_dtype(circle_inputs[in_idx].Type()),
                )
            np_data.tofile(in_data_path)
        args = []
        args += ["--loadable", str(self.tvn_path)]
        args += ["--input-spec", f"tv2b:{str(self.temp_dir.name)}/input"]
        args += ["--dump-output-as-tv2b", f"{str(self.temp_dir.name)}/output"]
        cmd = [str(self.interpreter_path)] + args

        # Run inference
        run_bash_cmd(cmd)

        # Load outputs from file
        circle_dequantized_output = []
        for out_idx, out_tensor in enumerate(circle_outputs):
            out_data_path = work_dir / f"output.{out_idx}.tv2b"
            # Read the raw dump with the output tensor's own element type.
            # Reading everything as uint8 breaks int16-quantized and
            # non-quantized (e.g. int32 / float32) outputs.
            out_dtype = np_dtype_from_circle_dtype(out_tensor.Type())
            circle_output = np.fromfile(out_data_path, out_dtype).reshape(
                out_tensor.ShapeAsNumpy()
            )
            qparam = output_qparam[out_idx]
            if qparam is None:
                circle_dequantized_output.append(circle_output)
            else:
                circle_dequantized_output.append(
                    dequantize(
                        circle_output,
                        qparam.ScaleAsNumpy()[0],
                        qparam.ZeroPointAsNumpy()[0],
                        out_dtype,
                    )
                )

        return circle_dequantized_output

    def __del__(self):
        # __init__ may have raised before temp_dir was created.
        temp_dir = getattr(self, "temp_dir", None)
        if temp_dir is not None:
            temp_dir.cleanup()
@@ -0,0 +1,109 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Callable, Dict, List
16
+
17
+ import numpy as np
18
+ import torch
19
+
20
+
21
def compute_peir(base: torch.Tensor, target: torch.Tensor):
    """
    Calculate the Peak Error to Interval Ratio (PEIR) between two tensors.

    This function computes the PEIR using the formula:
        PEIR = max(abs(target - base)) / (max(base) - min(base))

    The interval is taken from `base` (the reference output). When `base` is
    constant, its interval is 0 and 1.0 is used instead to avoid dividing by
    zero; PEIR then equals the peak absolute error.

    Parameters
    -----------
    base
        The reference tensor. Must be float32.
    target
        The tensor compared against `base`. Must be float32 with the same shape.

    Returns
    --------
    The PEIR value as a numpy scalar.
    """
    assert base.shape == target.shape, f"shape mismatch: {base.shape} != {target.shape}"
    # detach/cpu so tensors that require grad or live on an accelerator work too.
    base_tensor = base.detach().cpu().numpy()
    target_tensor = target.detach().cpu().numpy()
    assert (
        base_tensor.dtype == np.float32 and target_tensor.dtype == np.float32
    ), f"dtype should be float32: base({base_tensor.dtype}), target({target_tensor.dtype})"

    base_tensor = base_tensor.reshape(-1)
    target_tensor = target_tensor.reshape(-1)

    assert (
        base_tensor.shape == target_tensor.shape
    ), f"Shape mismatch: {base_tensor.shape} != {target_tensor.shape}"

    peak_error = np.max(np.absolute(target_tensor - base_tensor))
    interval = np.max(base_tensor) - np.min(base_tensor)
    interval = 1.0 if interval == 0.0 else interval  # Avoid zero interval
    peir = peak_error / interval  # pylint: disable=invalid-name

    return peir
53
+
54
+
55
class MetricCalculator:
    """
    Compute metrics including both built-in and custom metrics.

    Parameters
    -----------
    metrics
        A list of built-in metric names for comparison (keys of `builtin_metrics`).
    custom_metrics
        A dictionary of metric names and corresponding callable functions for comparison.
        Each callable is called with one element of `output1` and the matching
        element of `output2` (see `compute`).
        Example: {'mse': mean_squared_error, 'cosine_similarity': cosine_similarity_fn}

    Raises
    ------
    RuntimeError
        If a name in `metrics` is not a built-in metric, or a custom metric
        reuses the name of a requested built-in metric.
    """

    # Metrics that can be requested by name through `metrics`.
    builtin_metrics = {
        "peir": compute_peir,
    }

    def __init__(
        self,
        metrics: List[str] = list(),
        custom_metrics: Dict[str, Callable] = dict(),
    ):
        # The default arguments are only read here, never mutated.
        self.metrics: Dict[str, Callable] = dict()

        for m in metrics:
            if m in self.builtin_metrics:
                self.metrics[m] = self.builtin_metrics[m]
            else:
                raise RuntimeError(f"Invalid metric: {m}")

        duplicates = set(self.metrics).intersection(custom_metrics.keys())
        if len(duplicates) != 0:
            raise RuntimeError(f"There are duplicate metrics: {duplicates}")

        # Built-in metrics first, then custom ones; a new dict, so the caller's
        # mapping is not aliased.
        self.metrics = {**self.metrics, **custom_metrics}

    def compute(
        self, output1: List[torch.Tensor], output2: List[torch.Tensor]
    ) -> Dict[str, List[Any]]:
        """
        Compute every registered metric, built-in and custom, on each pair of outputs.

        Parameters
        -----------
        output1, output2
            Paired outputs; `output1[i]` is compared with `output2[i]`. Pairing
            stops at the shorter of the two lists.

        Returns
        --------
        Dict[str, List[Any]]
            A dictionary mapping each metric name to a list with one value per
            output pair.
        """
        results: Dict[str, List[Any]] = dict()

        for name, metric_fn in self.metrics.items():
            # Use self.metrics (not builtin_metrics) so custom metrics resolve too.
            results[name] = [
                metric_fn(out1, out2) for out1, out2 in zip(output1, output2)
            ]

        return results