tico-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
tico/experimental/quantization/quantizer.py
@@ -0,0 +1,71 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, Optional
+
+ import torch
+
+ from tico.experimental.quantization.config import BaseConfig
+
+
+ class BaseQuantizer(ABC):
+     """
+     Abstract base class for quantizers that apply a quantization algorithm to a target model.
+     """
+
+     def __init__(self, config: BaseConfig):
+         """
+         Initialize the quantizer with the given configuration.
+
+         Parameters:
+             config (BaseConfig): Quantization configuration parameters.
+         """
+         self.config = config
+
+     @abstractmethod
+     def prepare(
+         self,
+         model: torch.nn.Module,
+         args: Optional[Any] = None,
+         kwargs: Optional[Dict[str, Any]] = None,
+     ):
+         """
+         Prepare the given model for quantization based on the provided algorithm-specific
+         configuration. This involves setting up necessary observers or hooks, and may
+         optionally use example inputs, which is particularly useful for activation quantization.
+
+         Parameters:
+             model: The target PyTorch model.
+             args (Any, optional): Positional example inputs required for activation quantization.
+             kwargs (Dict[str, Any], optional): Keyword example inputs required for activation quantization.
+
+         Returns:
+             The prepared model.
+         """
+         pass
+
+     @abstractmethod
+     def convert(self, model):
+         """
+         Convert the prepared (or calibrated) model into its quantized form. This function leverages
+         the statistics collected during calibration to perform the quantization transformation.
+
+         Parameters:
+             model: The prepared PyTorch model.
+
+         Returns:
+             The quantized model.
+         """
+         pass
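A concrete quantizer only has to implement the two abstract hooks above. As a minimal sketch of the intended subclassing pattern (the `NoOpQuantizer` name and its pass-through behavior are hypothetical, for illustration only; module paths follow the file list above):

```python
import torch

from tico.experimental.quantization.config import BaseConfig
from tico.experimental.quantization.quantizer import BaseQuantizer


class NoOpQuantizer(BaseQuantizer):
    """Hypothetical quantizer whose prepare/convert return the model unchanged."""

    def prepare(self, model: torch.nn.Module, args=None, kwargs=None):
        # A real implementation would attach observers or hooks here,
        # possibly running the model on the example (args, kwargs).
        return model

    def convert(self, model):
        # A real implementation would rewrite the model using the
        # statistics collected between prepare() and convert().
        return model
```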
tico/interpreter/__init__.py
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
tico/interpreter/infer.py
@@ -0,0 +1,116 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any
+
+ import numpy as np
+ import torch
+ from circle_schema import circle
+
+ from tico.interpreter.interpreter import Interpreter
+ from tico.serialize.circle_mapping import np_dtype_from_circle_dtype, to_circle_dtype
+
+
+ def preprocess_inputs(inputs: Any):
+     """
+     Preprocess user inputs for circle inference.
+
+     1. None inputs are ignored.
+     2. A list/tuple input is flattened when a torch module is exported.
+        e.g. inputs = (torch.Tensor, [2,3,4]) -> inputs = (torch.Tensor, 2, 3, 4)
+     """
+     l = []
+     for value in inputs:
+         if value is None:
+             continue
+         if isinstance(value, (tuple, list)):
+             for val in value:
+                 l.append(val)
+         else:
+             l.append(value)
+     # Check if it is a list of a list.
+     if any(isinstance(item, (tuple, list)) for item in l):
+         l = preprocess_inputs(l)
+     return tuple(l)
+
+
+ def infer(circle_binary: bytes, *args: Any, **kwargs: Any) -> Any:
+     # When converting a model, it is assumed that the order of keyword arguments is maintained.
+     user_inputs = args + tuple(kwargs.values())
+     user_inputs = preprocess_inputs(user_inputs)
+     # Cast them to torch.Tensor for simplicity.
+     user_inputs = tuple(
+         torch.tensor(user_input) if not isinstance(user_input, torch.Tensor) else user_input
+         for user_input in user_inputs
+     )
+
+     # Get input spec from circle binary.
+     model = circle.Model.Model.GetRootAsModel(circle_binary, 0)
+     assert model.SubgraphsLength() == 1
+     graph = model.Subgraphs(0)
+     model_input_tensors = [
+         graph.Tensors(graph.Inputs(o)) for o in range(graph.InputsLength())
+     ]
+     model_input_shapes_np = [t.ShapeAsNumpy() for t in model_input_tensors]
+     model_input_types_cm = [t.Type() for t in model_input_tensors]
+
+     # Check that the given inputs' dtypes and shapes match those of the model binary.
+     if len(model_input_shapes_np) != len(user_inputs):
+         raise RuntimeError(
+             f"Mismatch input length: input({len(user_inputs)}) != circle model({len(model_input_shapes_np)})"
+         )
+     for input_idx, user_input in enumerate(user_inputs):
+         # Shape
+         if list(user_input.shape) != list(model_input_shapes_np[input_idx]):
+             raise RuntimeError(
+                 f"Mismatch input {input_idx} shape : input({user_input.shape}) != circle model({model_input_shapes_np[input_idx]})"
+             )
+         # Data type
+         user_input_type_cm = to_circle_dtype(user_input.dtype)
+         if user_input_type_cm != model_input_types_cm[input_idx]:
+             raise RuntimeError(
+                 f"Mismatch input {input_idx} data type : input({user_input_type_cm}) != circle model({model_input_types_cm[input_idx]})"
+             )
+
+     # Initialize interpreter
+     intp = Interpreter(circle_binary)
+
+     # Set input
+     for input_idx, user_input in enumerate(user_inputs):
+         intp.writeInputTensor(input_idx, user_input)
+
+     # Interpret
+     intp.interpret()
+
+     # Retrieve outputs' dtype and shape from circle model
+     model_output_tensors = [
+         graph.Tensors(graph.Outputs(o)) for o in range(graph.OutputsLength())
+     ]
+     model_output_shapes_np = [t.ShapeAsNumpy() for t in model_output_tensors]
+     model_output_types_cm = [t.Type() for t in model_output_tensors]
+
+     output = []
+     # Get output
+     for output_idx in range(len(model_output_tensors)):
+         result: np.ndarray = np.empty(
+             model_output_shapes_np[output_idx],
+             dtype=np_dtype_from_circle_dtype(model_output_types_cm[output_idx]),
+         )
+         intp.readOutputTensor(output_idx, result)
+         output.append(result)
+
+     if len(output) == 1:
+         return output[0]
+     else:
+         return output
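The flattening behavior documented in `preprocess_inputs` is easiest to see on a concrete input; the expected result below follows directly from the code above (the recursion flattens arbitrarily nested lists):

```python
import torch

from tico.interpreter.infer import preprocess_inputs

x = torch.zeros(2, 3)
flat = preprocess_inputs((x, [2, 3, [4, 5]], None))
# None is dropped and the nested lists are flattened:
# flat == (x, 2, 3, 4, 5)
```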
tico/interpreter/interpreter.py
@@ -0,0 +1,93 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ from cffi import FFI
+
+
+ class Interpreter:
+     """
+     Python wrapper for the C++ luci-interpreter class in ONE, using CFFI.
+
+     This class provides a Python interface to the underlying C++ luci-interpreter class in ONE,
+     preserving the original C++ API. Each method corresponds to a method in the C++ class,
+     with additional error handling implemented to ensure that C++ exceptions are captured and
+     translated into Python errors.
+
+     Note that each method includes `check_for_errors` at the end of the body to catch any C++
+     exceptions and translate them into Python exceptions. This ensures that errors in the C++
+     library do not cause undefined behavior in Python.
+     """
+
+     def __init__(self, circle_binary: bytes):
+         self.ffi = FFI()
+         self.ffi.cdef(
+             """
+             typedef struct InterpreterWrapper InterpreterWrapper;
+
+             const char *get_last_error(void);
+             void clear_last_error(void);
+             InterpreterWrapper *Interpreter_new(const uint8_t *data, const size_t data_size);
+             void Interpreter_delete(InterpreterWrapper *intp);
+             void Interpreter_interpret(InterpreterWrapper *intp);
+             void Interpreter_writeInputTensor(InterpreterWrapper *intp, const int input_idx, const void *data, size_t input_size);
+             void Interpreter_readOutputTensor(InterpreterWrapper *intp, const int output_idx, void *output, size_t output_size);
+             """
+         )
+         # TODO Check that the one-compiler version is compatible, i.e. whether it ships the .so file for CFFI.
+         intp_lib_path = Path("/usr/share/one/lib/libcircle_interpreter_cffi.so")
+         if not intp_lib_path.is_file():
+             raise RuntimeError("Please install one-compiler for circle inference.")
+         self.C = self.ffi.dlopen(str(intp_lib_path))
+
+         # Initialize interpreter
+         self.intp = self.C.Interpreter_new(circle_binary, len(circle_binary))
+         self.check_for_errors()
+
+     def delete(self):
+         self.C.Interpreter_delete(self.intp)
+         self.check_for_errors()
+
+     def interpret(self):
+         self.C.Interpreter_interpret(self.intp)
+         self.check_for_errors()
+
+     def writeInputTensor(self, input_idx: int, input_data: torch.Tensor):
+         input_as_numpy = input_data.numpy()
+         # cffi.from_buffer() only accepts a C-contiguous array.
+         input_as_numpy = np.ascontiguousarray(input_as_numpy)
+         c_input = self.ffi.from_buffer(input_as_numpy)
+         self.C.Interpreter_writeInputTensor(
+             self.intp, input_idx, c_input, input_data.nbytes
+         )
+         self.check_for_errors()
+
+     def readOutputTensor(self, output_idx: int, output: np.ndarray):
+         c_output = self.ffi.from_buffer(output)
+         self.C.Interpreter_readOutputTensor(
+             self.intp, output_idx, c_output, output.nbytes
+         )
+         self.check_for_errors()
+
+     def check_for_errors(self):
+         error_message = self.ffi.string(self.C.get_last_error()).decode("utf-8")
+         if error_message:
+             self.C.clear_last_error()
+             raise RuntimeError(f"C++ Exception: {error_message}")
+
+     def __del__(self):
+         self.delete()
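Taken together, the wrapper is driven in the same sequence `infer` uses above. A minimal sketch, assuming one-compiler is installed and `model.circle` is a hypothetical single-input, single-output model:

```python
import numpy as np
import torch

from tico.interpreter.interpreter import Interpreter

with open("model.circle", "rb") as f:  # hypothetical model file
    circle_binary = f.read()

intp = Interpreter(circle_binary)  # raises RuntimeError if the CFFI library is missing
intp.writeInputTensor(0, torch.randn(1, 3, 224, 224))  # must match the model's input spec
intp.interpret()

# The caller allocates the output buffer; its shape/dtype must match the model's output.
out = np.empty((1, 1000), dtype=np.float32)
intp.readOutputTensor(0, out)
```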
tico/passes/__init__.py
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
tico/passes/cast_aten_where_arg_type.py
@@ -0,0 +1,191 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Tuple, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.serialize.circle_mapping import extract_torch_dtype
+ from tico.utils import logging
+ from tico.utils.graph import create_node
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import (
+     trace_const_diff_on_pass,
+     trace_graph_diff_on_pass,
+ )
+ from tico.utils.utils import is_target_node, set_new_meta_val
+ from tico.utils.validate_args_kwargs import WhereSelfArgs
+
+
+ dtype_ranking = {
+     torch.int32: 0,
+     torch.int64: 1,
+     torch.float32: 2,
+ }
+
+
+ def sort_by_dtype(
+     result_true: torch.fx.Node, result_false: torch.fx.Node
+ ) -> Tuple[torch.fx.Node, torch.fx.Node]:
+     true_dtype = extract_torch_dtype(result_true)
+     false_dtype = extract_torch_dtype(result_false)
+     if dtype_ranking[true_dtype] > dtype_ranking[false_dtype]:
+         return result_true, result_false
+     if dtype_ranking[true_dtype] < dtype_ranking[false_dtype]:
+         return result_false, result_true
+     assert False, "The dtype rankings of the two nodes must differ at this point"
+
+
+ def check_if_covered_by_float(tensor: torch.Tensor) -> bool:
+     # About the min/max range, please refer to https://en.wikipedia.org/wiki/Single-precision_floating-point_format#Precision_limitations_on_integer_values
+     if tensor.min() < -(2**24) or tensor.max() > 2**24:
+         return False
+     return True
+
+
+ @trace_graph_diff_on_pass
+ @trace_const_diff_on_pass
+ class CastATenWhereArgType(PassBase):
+     """
+     This pass casts the data type of `aten.where.self` operation's arguments.
+
+     It is applied only when the data types of the arguments, denoted `result_true` and
+     `result_false` in the graphs below, differ; if they are identical, the pass does nothing.
+
+     In addition, this pass casts in the direction that avoids data loss.
+     For example, if the data type of `result_true` is `float32` and the data type of `result_false` is `int32`,
+     then the data type of `result_false` will be cast to `float32`.
+     In that case, the contents of `result_false` must also be within the range of `float32`.
+     If so, `result_false` is cast to `float32`.
+     If not, a RuntimeError is raised.
+
+     After this pass, the arguments of `aten.where.self` have the same data type.
+
+     The graphs before and after this pass are shown below.
+     NOTE The example below shows the case where `result_false` is cast.
+
+     (before)
+
+     [condition]   [result_true]   [result_false]
+          |              |               |
+          |              |               |
+          +--------------+---------------+
+                         |
+                         |
+                      [where]
+                         |
+                         |
+                      [output]
+
+     (after)
+
+                                    [result_false]
+     [condition]   [result_true]          |
+          |              |             [cast]
+          |              |               |
+          +--------------+---------------+
+                         |
+                         |
+                      [where]
+                         |
+                         |
+                      [output]
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+         modified = False
+
+         for node in graph.nodes:
+             if not is_target_node(node, torch.ops.aten.where.self):
+                 continue
+
+             where_args = WhereSelfArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+             result_true, result_false = where_args.input, where_args.other
+             if not isinstance(result_true, torch.fx.Node) or not isinstance(
+                 result_false, torch.fx.Node
+             ):
+                 continue
+
+             ep = exported_program
+             assert isinstance(result_true, torch.fx.Node)
+             assert isinstance(result_false, torch.fx.Node)
+             if not (
+                 result_true.name in ep.graph_signature.inputs_to_buffers
+                 and result_false.name in ep.graph_signature.inputs_to_buffers
+             ):
+                 continue
+
+             # Check if they have different data types
+             true_dtype = extract_torch_dtype(result_true)
+             false_dtype = extract_torch_dtype(result_false)
+             if true_dtype == false_dtype:
+                 continue
+
+             node_to_dtype = {result_true: true_dtype, result_false: false_dtype}
+
+             not_to_cast, to_cast = sort_by_dtype(result_true, result_false)
+
+             buf_name_to_data = {name: buf for name, buf in ep.named_buffers()}
+             buf_name = ep.graph_signature.inputs_to_buffers[to_cast.name]
+             buf_data = buf_name_to_data[buf_name]
+
+             assert isinstance(buf_data, torch.Tensor)
+
+             dtype_to_cast = node_to_dtype[not_to_cast]
+
+             if dtype_to_cast == torch.float32:
+                 if not check_if_covered_by_float(buf_data):
+                     raise RuntimeError(
+                         f"{to_cast.name}({buf_data.dtype}) data range is out of {dtype_to_cast} range"
+                     )
+             with graph_module.graph.inserting_after(to_cast):
+                 cast = create_node(
+                     graph,
+                     torch.ops.aten._to_copy.default,
+                     args=(to_cast,),
+                     kwargs={"dtype": dtype_to_cast},
+                     origin=to_cast,
+                 )
+                 # Set the new meta["val"] in advance; it is used below to check that type promotion is valid.
+                 set_new_meta_val(cast)
+             node.update_arg(node.args.index(to_cast), cast)
+
+             # Check that type promotion did not change the node's dtype.
+             node_dtype_ori = extract_torch_dtype(node)
+             set_new_meta_val(node)
+             node_dtype = extract_torch_dtype(node)
+             assert (
+                 node_dtype == node_dtype_ori
+             ), "Type casting must not change the node's dtype."
+
+             logger.debug(
+                 f"{to_cast.name}'s dtype was cast from {buf_data.dtype} to {dtype_to_cast}"
+             )
+
+             modified = True
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         return PassResult(modified)
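The `2**24` bound in `check_if_covered_by_float` is the point past which float32 can no longer represent every integer exactly, so casting a larger int buffer would silently corrupt values:

```python
import torch

# Every integer up to 2**24 (16777216) is exactly representable in float32;
# 2**24 + 1 is the first that is not.
print(torch.tensor(2**24, dtype=torch.float32).long().item())      # 16777216
print(torch.tensor(2**24 + 1, dtype=torch.float32).long().item())  # 16777216, not 16777217
```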
tico/passes/cast_mixed_type_args.py
@@ -0,0 +1,187 @@
+ # Portions of this file are adapted from code originally authored by
+ # Meta Platforms, Inc. and affiliates, licensed under the BSD-style
+ # license found in the LICENSE file in the root directory of their source tree.
+
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch.fx
+ import torch
+ from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND
+ from torch.export import ExportedProgram
+
+ from tico.serialize.circle_mapping import extract_torch_dtype
+ from tico.utils import logging
+ from tico.utils.graph import create_node
+ from tico.utils.passes import PassBase, PassResult
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
+ from tico.utils.utils import is_target_node, set_new_meta_val
+
+
+ ops_to_promote = {
+     torch.ops.aten.add.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+     torch.ops.aten.div.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+     torch.ops.aten.eq.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+     torch.ops.aten.eq.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+     torch.ops.aten.ge.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+     torch.ops.aten.ge.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+     torch.ops.aten.gt.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+     torch.ops.aten.gt.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+     torch.ops.aten.mul.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+     torch.ops.aten.minimum.default: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+     torch.ops.aten.ne.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+     torch.ops.aten.ne.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+     torch.ops.aten.pow.Tensor_Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+     torch.ops.aten.sub.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+ }
+
+
+ def has_same_dtype(lhs, rhs):
+     if isinstance(lhs, torch.fx.Node):
+         lhs_dtype = lhs.meta["val"].dtype
+     elif isinstance(lhs, torch.Tensor):
+         lhs_dtype = lhs.dtype
+     else:
+         lhs_dtype = torch.tensor(lhs).dtype
+     if isinstance(rhs, torch.fx.Node):
+         rhs_dtype = rhs.meta["val"].dtype
+     elif isinstance(rhs, torch.Tensor):
+         rhs_dtype = rhs.dtype
+     else:
+         rhs_dtype = torch.tensor(rhs).dtype
+
+     if lhs_dtype == rhs_dtype:
+         return True
+     return False
+
+
+ def to_numeric_type(torch_dtype: torch.dtype):
+     dmap = {
+         torch.float32: float,
+         torch.float: float,
+         torch.int64: int,
+         torch.bool: bool,
+     }
+
+     if torch_dtype not in dmap:
+         return None
+
+     return dmap[torch_dtype]
+
+
+ @trace_graph_diff_on_pass
+ class CastMixedTypeArgs(PassBase):
+     def __init__(self, preserve_ep_invariant=True):
+         super().__init__()
+         self.preserve_ep_invariant = preserve_ep_invariant
+
+     # TODO Fold float and int values before this pass
+     def call(self, exported_program: ExportedProgram) -> PassResult:
+         logger = logging.getLogger(__name__)
+
+         graph_module = exported_program.graph_module
+         graph = graph_module.graph
+         modified = False
+         for node in graph.nodes:
+             if not is_target_node(node, list(ops_to_promote.keys())):
+                 continue
+
+             assert len(node.args) == 2
+             lhs, rhs = node.args
+             assert isinstance(lhs, (torch.fx.Node, torch.Tensor, float, int)), type(lhs)
+             assert isinstance(rhs, (torch.fx.Node, torch.Tensor, float, int)), type(rhs)
+             if has_same_dtype(lhs, rhs):
+                 continue
+
+             lhs_val = (
+                 lhs.meta["val"] if isinstance(lhs, torch.fx.Node) else torch.tensor(lhs)
+             )
+             rhs_val = (
+                 rhs.meta["val"] if isinstance(rhs, torch.fx.Node) else torch.tensor(rhs)
+             )
+             type_to_promote: torch.dtype = elementwise_dtypes(
+                 lhs_val, rhs_val, type_promotion_kind=ops_to_promote[node.target]
+             )[1]
+             arg_to_promote = None
+             ori_type = None
+             if lhs_val.dtype == type_to_promote:
+                 ori_type = rhs_val.dtype
+                 arg_to_promote = rhs
+             if rhs_val.dtype == type_to_promote:
+                 ori_type = lhs_val.dtype
+                 arg_to_promote = lhs
+             assert arg_to_promote is not None
+
+             if isinstance(arg_to_promote, torch.fx.Node):
+                 with graph.inserting_after(arg_to_promote):
+                     to_copy = create_node(
+                         graph,
+                         torch.ops.aten._to_copy.default,
+                         (arg_to_promote,),
+                         {"dtype": type_to_promote},
+                         origin=arg_to_promote,
+                     )
+                     # Set the new meta["val"] in advance; it is used below to check that type promotion is valid.
+                     set_new_meta_val(to_copy)
+                 node.update_arg(node.args.index(arg_to_promote), to_copy)
+
+                 modified = True
+                 logger.debug(
+                     f"{arg_to_promote.name}'s dtype was cast from {ori_type} to {type_to_promote}"
+                 )
+             else:
+                 index_to_promote = node.args.index(arg_to_promote)
+                 if isinstance(arg_to_promote, torch.Tensor):
+                     arg_to_promote = arg_to_promote.to(type_to_promote)
+                 else:
+                     # numerical types
+                     numeric_type = to_numeric_type(type_to_promote)
+                     if numeric_type is not None:
+                         arg_to_promote = numeric_type(arg_to_promote)
+                     else:
+                         if self.preserve_ep_invariant:
+                             # ExportedProgram (EP) requires adding a placeholder when
+                             # a tensor is created, which complicates the EP structure and
+                             # is not necessary for circle serialization. We skip this case if
+                             # preserve_ep_invariant = True.
+                             continue
+                         else:
+                             # Create the tensor without a placeholder.
+                             # NOTE This breaks the EP invariant.
+                             arg_to_promote = torch.tensor(arg_to_promote).to(
+                                 type_to_promote
+                             )
+                 node.update_arg(index_to_promote, arg_to_promote)
+
+                 modified = True
+                 logger.debug(
+                     f"{arg_to_promote}'s dtype was cast from {ori_type} to {type_to_promote}"
+                 )
+
+             # Check that type promotion did not change the node's dtype.
+             node_dtype_ori = extract_torch_dtype(node)
+             set_new_meta_val(node)
+             node_dtype = extract_torch_dtype(node)
+             assert (
+                 node_dtype == node_dtype_ori
+             ), "Type casting must not change the node's dtype."
+
+         graph.eliminate_dead_code()
+         graph.lint()
+         graph_module.recompile()
+
+         return PassResult(modified)
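The promotion target computed in `call` comes straight from torch's elementwise type-promotion rules, so the lookup can be reproduced standalone with the same `elementwise_dtypes` helper the pass imports:

```python
import torch
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, elementwise_dtypes

lhs = torch.zeros(2, dtype=torch.int32)
rhs = torch.zeros(2, dtype=torch.float32)

# elementwise_dtypes returns (computation dtype, result dtype); the pass takes index [1].
_, result_dtype = elementwise_dtypes(
    lhs, rhs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
print(result_dtype)  # torch.float32 -> the int32 argument is the one the pass would cast
```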