tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,93 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from pathlib import Path
16
+
17
+ import numpy as np
18
+ import torch
19
+ from cffi import FFI
20
+
21
+
22
class Interpreter:
    """
    Python wrapper for the C++ luci-interpreter class in ONE, using CFFI.

    This class provides a Python interface to the underlying C++ luci-interpreter class in ONE,
    preserving the original C++ API. Each method corresponds to a method in the C++ class,
    with additional error handling implemented to ensure that C++ exceptions are captured and
    translated into Python errors.

    Every method that calls into the C library invokes `check_for_errors` afterwards, which
    reads the library's last-error string and re-raises it as a Python `RuntimeError`.

    The native interpreter is released by `delete()`. `delete()` is idempotent and is also
    invoked from `__del__`, so an explicit `delete()` followed by garbage collection does not
    free the native handle twice. Using the object after `delete()` raises `RuntimeError`
    instead of handing a dangling pointer to C.
    """

    def __init__(self, circle_binary: bytes):
        """
        Load the CFFI shared library and create a native interpreter for `circle_binary`.

        Args:
            circle_binary: Serialized circle model.

        Raises:
            RuntimeError: If the CFFI shared library is not installed, or if the C++ side
                reports an error while constructing the interpreter.
        """
        # Mark the handle as absent first so that `__del__` stays safe even when
        # construction fails part-way (e.g. the shared library is missing).
        self.intp = None
        self.ffi = FFI()
        self.ffi.cdef(
            """
            typedef struct InterpreterWrapper InterpreterWrapper;

            const char *get_last_error(void);
            void clear_last_error(void);
            InterpreterWrapper *Interpreter_new(const uint8_t *data, const size_t data_size);
            void Interpreter_delete(InterpreterWrapper *intp);
            void Interpreter_interpret(InterpreterWrapper *intp);
            void Interpreter_writeInputTensor(InterpreterWrapper *intp, const int input_idx, const void *data, size_t input_size);
            void Interpreter_readOutputTensor(InterpreterWrapper *intp, const int output_idx, void *output, size_t output_size);
            """
        )
        # TODO Check if one-compiler version is compatible. Whether it has .so file or not for CFFI.
        intp_lib_path = Path("/usr/share/one/lib/libcircle_interpreter_cffi.so")
        if not intp_lib_path.is_file():
            raise RuntimeError("Please install one-compiler for circle inference.")
        self.C = self.ffi.dlopen(str(intp_lib_path))

        # Initialize interpreter
        self.intp = self.C.Interpreter_new(circle_binary, len(circle_binary))
        self.check_for_errors()

    def _get_handle(self):
        """Return the native interpreter handle, or raise if it was already released."""
        intp = getattr(self, "intp", None)
        if intp is None:
            raise RuntimeError("Interpreter has already been deleted.")
        return intp

    def delete(self):
        """
        Release the native interpreter.

        Safe to call more than once. The handle is cleared before the native delete, so
        later calls (including the one made by `__del__`) are no-ops.
        """
        intp = getattr(self, "intp", None)
        if intp is None:
            return
        # Clear the handle first so that an error below can never lead to a second free.
        self.intp = None
        if intp == self.ffi.NULL:
            return
        self.C.Interpreter_delete(intp)
        self.check_for_errors()

    def interpret(self):
        """Run inference on the tensors previously written with `writeInputTensor`."""
        self.C.Interpreter_interpret(self._get_handle())
        self.check_for_errors()

    def writeInputTensor(self, input_idx: int, input_data: torch.Tensor):
        """
        Copy `input_data` into the interpreter's input tensor `input_idx`.

        The tensor is detached and moved to CPU before conversion, so tensors that
        require grad (or that live on another device) are accepted.
        """
        intp = self._get_handle()
        input_as_numpy = input_data.detach().cpu().numpy()
        # cffi.from_buffer() only accepts C-contiguous array.
        input_as_numpy = np.ascontiguousarray(input_as_numpy)
        c_input = self.ffi.from_buffer(input_as_numpy)
        self.C.Interpreter_writeInputTensor(
            intp, input_idx, c_input, input_as_numpy.nbytes
        )
        self.check_for_errors()

    def readOutputTensor(self, output_idx: int, output: np.ndarray):
        """
        Copy the interpreter's output tensor `output_idx` into `output` in place.

        `output` must be a writable, C-contiguous array of the expected byte size;
        `ffi.from_buffer` rejects non-contiguous buffers.
        """
        intp = self._get_handle()
        c_output = self.ffi.from_buffer(output)
        self.C.Interpreter_readOutputTensor(intp, output_idx, c_output, output.nbytes)
        self.check_for_errors()

    def check_for_errors(self):
        """
        Raise the C++ library's pending error, if any, as a Python exception.

        Raises:
            RuntimeError: With message "C++ Exception: <message>" when the library
                reports a non-empty last-error string. The error is cleared first.
        """
        err_ptr = self.C.get_last_error()
        # Treat a NULL pointer like an empty message: no pending error.
        if err_ptr == self.ffi.NULL:
            return
        # errors="replace" keeps a non-UTF-8 message from masking the real C++ error.
        error_message = self.ffi.string(err_ptr).decode("utf-8", errors="replace")
        if error_message:
            self.C.clear_last_error()
            raise RuntimeError(f"C++ Exception: {error_message}")

    def __del__(self):
        self.delete()
@@ -0,0 +1 @@
1
+ # DO NOT REMOVE THIS FILE
@@ -0,0 +1,185 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Tuple, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.serialize.circle_mapping import extract_torch_dtype
23
+ from tico.utils import logging
24
+ from tico.utils.passes import PassBase, PassResult
25
+ from tico.utils.trace_decorators import (
26
+ trace_const_diff_on_pass,
27
+ trace_graph_diff_on_pass,
28
+ )
29
+ from tico.utils.utils import set_new_meta_val
30
+
31
+
32
# Rank of each dtype this pass can reconcile. A `where` argument is always cast
# towards the higher-ranked dtype of the two (int32 < int64 < float32).
dtype_ranking = {
    dtype: rank
    for rank, dtype in enumerate((torch.int32, torch.int64, torch.float32))
}
37
+
38
+
39
def sort_by_dtype(
    result_true: torch.fx.Node, result_false: torch.fx.Node
) -> Tuple[torch.fx.Node, torch.fx.Node]:
    """
    Order the two `where` value arguments by the rank of their dtypes.

    Returns:
        `(higher, lower)`: the node whose dtype ranks higher in `dtype_ranking`,
        followed by the other node.

    Raises:
        KeyError: If either node's dtype is not in `dtype_ranking`.
        AssertionError: If both dtypes have the same rank; callers are expected to
            skip equal dtypes before calling this.
    """
    true_rank = dtype_ranking[extract_torch_dtype(result_true)]
    false_rank = dtype_ranking[extract_torch_dtype(result_false)]
    if true_rank > false_rank:
        return result_true, result_false
    if true_rank < false_rank:
        return result_false, result_true
    # Explicit raise instead of `assert False`: under `python -O` the assert is
    # stripped and the function would silently return None.
    raise AssertionError(
        "There is no case that the dtype_ranking of the nodes are the same"
    )
49
+
50
+
51
def check_if_covered_by_float(tensor: torch.Tensor) -> bool:
    """
    Return whether all values of `tensor` lie within [-2**24, 2**24].

    That is the range in which float32 represents every integer exactly; see
    https://en.wikipedia.org/wiki/Single-precision_floating-point_format#Precision_limitations_on_integer_values

    An empty tensor is trivially covered.
    """
    # `min()` / `max()` raise on an empty tensor, so answer that case up front.
    if tensor.numel() == 0:
        return True
    limit = 2**24
    return not (tensor.min() < -limit or tensor.max() > limit)
56
+
57
+
58
@trace_graph_diff_on_pass
@trace_const_diff_on_pass
class CastATenWhereArgType(PassBase):
    """
    This pass casts the data type of the `aten.where.self` operation's value arguments.

    This pass is applied when the data types of the `aten.where.self` operation's value
    arguments differ. If the data types of the arguments, denoted `result_true` and
    `result_false` in the graph below, are identical, this pass is not applied.
    Only the case where both `result_true` and `result_false` are buffers of the
    ExportedProgram is handled; other `where` nodes are left untouched.

    In addition, this pass casts towards the higher-ranked dtype in `dtype_ranking`
    (int32 < int64 < float32), the direction intended to avoid data loss.
    For example, if the data type of `result_true` is `float32` and the data type of
    `result_false` is `int32`, then `result_false` will be cast to `float32`.
    Moreover, in this case, it is checked whether the contents of `result_false` lie
    within the range in which `float32` represents integers exactly.
    If so, `result_false` will be cast to `float32`.
    If not, RuntimeError will be raised.

    After this pass, the value arguments of `aten.where.self` have the same data type.

    The graph before this pass and the graph after this pass are shown below.
    NOTE Below example denotes the case when the `result_false` was cast.

    (before)

      [condition]   [result_true]   [result_false]
           |              |                |
           |              |                |
           +--------------+----------------+
                          |
                          |
                       [where]
                          |
                          |
                      [output]

    (after)

                                    [result_false]
      [condition]   [result_true]          |
           |              |             [cast]
           |              |                |
           +--------------+----------------+
                          |
                          |
                       [where]
                          |
                          |
                      [output]
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)
        ep = exported_program
        graph_module = ep.graph_module
        graph = graph_module.graph
        modified = False

        # The pass only inserts `_to_copy` nodes, so the graph signature is unchanged
        # and its buffer mapping can be read once.
        inputs_to_buffers = ep.graph_signature.inputs_to_buffers
        # Built lazily on first use (instead of once per `where` node).
        buf_name_to_data = None

        for node in graph.nodes:
            if node.op != "call_function" or node.target != torch.ops.aten.where.self:
                continue

            assert len(node.args) == 3
            # The first argument (the condition) is not used.
            _, result_true, result_false = node.args

            # Buffer contents are needed for the range check, so only buffers are handled.
            if not (
                result_true.name in inputs_to_buffers
                and result_false.name in inputs_to_buffers
            ):
                continue

            # Check if they have different data types
            true_dtype = extract_torch_dtype(result_true)
            false_dtype = extract_torch_dtype(result_false)
            if true_dtype == false_dtype:
                continue

            node_to_dtype = {result_true: true_dtype, result_false: false_dtype}

            # `to_cast` has the lower-ranked dtype and is cast to `not_to_cast`'s dtype.
            not_to_cast, to_cast = sort_by_dtype(result_true, result_false)
            dtype_to_cast = node_to_dtype[not_to_cast]

            if buf_name_to_data is None:
                buf_name_to_data = dict(ep.named_buffers())
            buf_data = buf_name_to_data[inputs_to_buffers[to_cast.name]]
            assert isinstance(buf_data, torch.Tensor)

            if dtype_to_cast == torch.float32 and not check_if_covered_by_float(
                buf_data
            ):
                raise RuntimeError(
                    f"{to_cast.name}({buf_data.dtype}) data range is out of {dtype_to_cast} range"
                )

            with graph.inserting_after(to_cast):
                cast = graph.call_function(
                    torch.ops.aten._to_copy.default,
                    args=(to_cast,),
                    kwargs={"dtype": dtype_to_cast},
                )
                # set new meta["val"] in advance because we will use it below for checking if type promotion is valid.
                set_new_meta_val(cast)
            # `where.self` args are (condition, result_true, result_false).
            cast_index = 1 if to_cast is result_true else 2
            node.update_arg(cast_index, cast)

            # check if type promotion is valid.
            node_dtype_ori = extract_torch_dtype(node)
            set_new_meta_val(node)
            node_dtype = extract_torch_dtype(node)
            assert (
                node_dtype == node_dtype_ori
            ), f"Type casting must not change the node's dtype ({node_dtype_ori} -> {node_dtype})."

            logger.debug(
                f"{to_cast.name}'s dtype was casted from {buf_data.dtype} to {dtype_to_cast}"
            )

            modified = True

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
@@ -0,0 +1,186 @@
1
+ # Portions of this file are adapted from code originally authored by
2
+ # Meta Platforms, Inc. and affiliates, licensed under the BSD-style
3
+ # license found in the LICENSE file in the root directory of their source tree.
4
+
5
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ from typing import TYPE_CHECKING
20
+
21
+ if TYPE_CHECKING:
22
+ import torch.fx
23
+ import torch
24
+ from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND
25
+ from torch.export import ExportedProgram
26
+
27
+ from tico.serialize.circle_mapping import extract_torch_dtype
28
+ from tico.utils import logging
29
+ from tico.utils.passes import PassBase, PassResult
30
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
31
+ from tico.utils.utils import set_new_meta_val
32
+
33
+
34
# Binary aten ops whose two operands are made to share one dtype, mapped to the
# promotion kind handed to `elementwise_dtypes`. Every entry currently uses DEFAULT.
ops_to_promote = dict.fromkeys(
    (
        torch.ops.aten.add.Tensor,
        torch.ops.aten.div.Tensor,
        torch.ops.aten.eq.Scalar,
        torch.ops.aten.eq.Tensor,
        torch.ops.aten.ge.Scalar,
        torch.ops.aten.ge.Tensor,
        torch.ops.aten.gt.Scalar,
        torch.ops.aten.gt.Tensor,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.minimum.default,
        torch.ops.aten.ne.Scalar,
        torch.ops.aten.ne.Tensor,
        torch.ops.aten.pow.Tensor_Scalar,
        torch.ops.aten.sub.Tensor,
    ),
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
50
+
51
+
52
def has_same_dtype(lhs, rhs):
    """
    Return True when `lhs` and `rhs` resolve to the same torch dtype.

    Each operand may be an FX node (the dtype of its `meta["val"]` is used), a
    tensor, or any other value, whose dtype is inferred by wrapping it with
    `torch.tensor`.
    """

    def _dtype_of(operand):
        if isinstance(operand, torch.fx.Node):
            return operand.meta["val"].dtype
        if isinstance(operand, torch.Tensor):
            return operand.dtype
        return torch.tensor(operand).dtype

    return _dtype_of(lhs) == _dtype_of(rhs)
69
+
70
+
71
def to_numeric_type(torch_dtype: torch.dtype):
    """
    Map a torch dtype to the Python scalar type that represents it.

    Returns `float` for float32, `int` for int64, `bool` for bool, and `None` for
    any other dtype.
    """
    # `torch.float` is an alias of `torch.float32`, so it is covered by that key.
    scalar_types = {
        torch.float32: float,
        torch.int64: int,
        torch.bool: bool,
    }
    return scalar_types.get(torch_dtype)
83
+
84
+
85
@trace_graph_diff_on_pass
class CastMixedTypeArgs(PassBase):
    """
    Make both operands of selected binary aten ops share one dtype.

    For every call to an op in `ops_to_promote` whose two operands have different
    dtypes, the promoted dtype is computed with `elementwise_dtypes`. The operand that
    does not already have that dtype is converted:

    - an FX node gets an `aten._to_copy` cast inserted right after it;
    - a tensor constant is converted with `Tensor.to`;
    - a Python scalar is re-created with the Python type from `to_numeric_type`.
      If there is no such type, the operand is left untouched when
      `preserve_ep_invariant` is True. Otherwise it is replaced by a tensor constant
      without a placeholder, which breaks the ExportedProgram invariant.

    After each rewrite the node's `meta["val"]` is recomputed and asserted to keep the
    dtype it had before the cast.

    Args:
        preserve_ep_invariant: Whether to skip rewrites that would create a tensor
            constant with no matching placeholder.
    """

    def __init__(self, preserve_ep_invariant=True):
        super().__init__()
        self.preserve_ep_invariant = preserve_ep_invariant

    @staticmethod
    def _example_value(arg):
        """Return a tensor whose dtype describes `arg` for type promotion."""
        if isinstance(arg, torch.fx.Node):
            return arg.meta["val"]
        if isinstance(arg, torch.Tensor):
            # Avoid `torch.tensor(tensor)`, which copy-constructs and warns.
            return arg
        return torch.tensor(arg)

    # TODO Folding float and int values before this pass
    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if node.op != "call_function" or node.target not in ops_to_promote:
                continue

            assert len(node.args) == 2
            lhs, rhs = node.args
            assert isinstance(lhs, (torch.fx.Node, torch.Tensor, float, int)), type(lhs)
            assert isinstance(rhs, (torch.fx.Node, torch.Tensor, float, int)), type(rhs)
            if has_same_dtype(lhs, rhs):
                continue

            lhs_val = self._example_value(lhs)
            rhs_val = self._example_value(rhs)
            type_to_promote: torch.dtype = elementwise_dtypes(
                lhs_val, rhs_val, type_promotion_kind=ops_to_promote[node.target]
            )[1]

            # Promote the operand that does not already have the promoted dtype.
            # Track its position explicitly. `node.args.index(...)` compares with `==`:
            # that raises for a multi-element tensor operand and can pick the wrong
            # slot for numerically equal scalars such as `1` and `1.0`.
            if lhs_val.dtype == type_to_promote:
                index_to_promote, ori_type = 1, rhs_val.dtype
            elif rhs_val.dtype == type_to_promote:
                index_to_promote, ori_type = 0, lhs_val.dtype
            else:
                raise AssertionError(
                    f"Neither operand of {node.name} ({lhs_val.dtype}, {rhs_val.dtype}) "
                    f"has the promoted dtype {type_to_promote}"
                )
            arg_to_promote = node.args[index_to_promote]

            if isinstance(arg_to_promote, torch.fx.Node):
                with graph.inserting_after(arg_to_promote):
                    to_copy = graph.call_function(
                        torch.ops.aten._to_copy.default,
                        (arg_to_promote,),
                        {"dtype": type_to_promote},
                    )
                    # set new meta["val"] in advance because we will use it below for checking if type promotion is valid.
                    set_new_meta_val(to_copy)
                node.update_arg(index_to_promote, to_copy)

                modified = True
                logger.debug(
                    f"{arg_to_promote.name}'s dtype was casted from {ori_type} to {type_to_promote}"
                )
            else:
                if isinstance(arg_to_promote, torch.Tensor):
                    arg_to_promote = arg_to_promote.to(type_to_promote)
                else:
                    # numerical types
                    numeric_type = to_numeric_type(type_to_promote)
                    if numeric_type is not None:
                        arg_to_promote = numeric_type(arg_to_promote)
                    elif self.preserve_ep_invariant:
                        # ExportedProgram (EP) requires to add a placeholder when
                        # a tensor is created, which complicates EP structure but
                        # not necessary for circle serialization. We skip this case if
                        # preserve_ep_invariant = True.
                        continue
                    else:
                        # Create tensor without placeholder
                        # NOTE This breaks EP invariant
                        arg_to_promote = torch.tensor(arg_to_promote).to(
                            type_to_promote
                        )
                node.update_arg(index_to_promote, arg_to_promote)

                modified = True
                logger.debug(
                    f"{arg_to_promote}'s dtype was casted from {ori_type} to {type_to_promote}"
                )

            # check if type promotion is valid.
            node_dtype_ori = extract_torch_dtype(node)
            set_new_meta_val(node)
            node_dtype = extract_torch_dtype(node)
            assert (
                node_dtype == node_dtype_ori
            ), f"Type casting must not change the node's dtype ({node_dtype_ori} -> {node_dtype})."

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)