tico 0.1.0.dev250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,93 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from pathlib import Path
|
16
|
+
|
17
|
+
import numpy as np
|
18
|
+
import torch
|
19
|
+
from cffi import FFI
|
20
|
+
|
21
|
+
|
22
|
+
class Interpreter:
    """
    Python wrapper for the C++ luci-interpreter class in ONE using CFFI.

    This class provides a Python interface to the underlying C++ luci-interpreter class in ONE,
    preserving the original C++ API. Each method corresponds to a method in the C++ class,
    with additional error handling implemented to ensure that C++ exceptions are captured and
    translated into Python errors.

    Note that each method calls `check_for_errors` at the end of the body to catch any C++
    exception reported through `get_last_error` and re-raise it as a Python `RuntimeError`.
    This ensures that errors in the C++ library do not cause undefined behavior in Python.

    The native handle is released by `delete()`, which is idempotent and is also invoked
    from `__del__`. Calling `interpret`/`writeInputTensor`/`readOutputTensor` after
    `delete()` raises `RuntimeError` instead of handing a freed pointer to C++.
    """

    def __init__(self, circle_binary: bytes):
        """
        Load the CFFI interpreter library and create an interpreter for a circle model.

        Args:
            circle_binary: Serialized circle model, passed to `Interpreter_new`.

        Raises:
            RuntimeError: If `libcircle_interpreter_cffi.so` is not installed, or if the
                C++ side reports an error while creating the interpreter.
        """
        # Defined first so that `delete()` (and therefore `__del__`) can tell that no
        # native interpreter exists yet if anything below raises.
        self.intp = None
        self.ffi = FFI()
        self.ffi.cdef(
            """
            typedef struct InterpreterWrapper InterpreterWrapper;

            const char *get_last_error(void);
            void clear_last_error(void);
            InterpreterWrapper *Interpreter_new(const uint8_t *data, const size_t data_size);
            void Interpreter_delete(InterpreterWrapper *intp);
            void Interpreter_interpret(InterpreterWrapper *intp);
            void Interpreter_writeInputTensor(InterpreterWrapper *intp, const int input_idx, const void *data, size_t input_size);
            void Interpreter_readOutputTensor(InterpreterWrapper *intp, const int output_idx, void *output, size_t output_size);
            """
        )
        # TODO Check if one-compiler version is compatible. Whether it has .so file or not for CFFI.
        intp_lib_path = Path("/usr/share/one/lib/libcircle_interpreter_cffi.so")
        if not intp_lib_path.is_file():
            raise RuntimeError("Please install one-compiler for circle inference.")
        self.C = self.ffi.dlopen(str(intp_lib_path))

        # Initialize interpreter
        self.intp = self.C.Interpreter_new(circle_binary, len(circle_binary))
        self.check_for_errors()

    def _require_intp(self):
        """Return the live native handle, raising RuntimeError if it was already deleted."""
        intp = getattr(self, "intp", None)
        if intp is None:
            raise RuntimeError("Interpreter has already been deleted.")
        return intp

    def delete(self):
        """
        Release the native interpreter via `Interpreter_delete`.

        Safe to call more than once and on a partially constructed instance: only the
        first call on a live, non-NULL handle reaches the C++ side.
        """
        intp = getattr(self, "intp", None)
        if intp is None:
            return
        # Forget the handle before calling into C so a second call (e.g. `__del__`
        # after an explicit `delete()`) can never free it twice.
        self.intp = None
        if intp == self.ffi.NULL:
            return
        self.C.Interpreter_delete(intp)
        self.check_for_errors()

    def interpret(self):
        """Invoke `Interpreter_interpret` on the native interpreter."""
        self.C.Interpreter_interpret(self._require_intp())
        self.check_for_errors()

    def writeInputTensor(self, input_idx: int, input_data: torch.Tensor):
        """
        Copy `input_data` into the interpreter input at `input_idx`.

        The tensor is detached and moved to CPU first so `.numpy()` also works for
        tensors that require grad or live on another device (no-ops for plain CPU
        tensors). The byte size passed to C++ is that of the contiguous copy.
        """
        intp = self._require_intp()
        input_as_numpy = input_data.detach().cpu().numpy()
        # cffi.from_buffer() only accepts C-contiguous array.
        input_as_numpy = np.ascontiguousarray(input_as_numpy)
        c_input = self.ffi.from_buffer(input_as_numpy)
        self.C.Interpreter_writeInputTensor(
            intp, input_idx, c_input, input_as_numpy.nbytes
        )
        self.check_for_errors()

    def readOutputTensor(self, output_idx: int, output: np.ndarray):
        """
        Copy interpreter output `output_idx` into the preallocated `output` array.

        `output.nbytes` is passed as the destination size. NOTE(review): `output` is
        presumably expected to be writable and C-contiguous with the exact output
        size — confirm against the C++ side.
        """
        intp = self._require_intp()
        c_output = self.ffi.from_buffer(output)
        self.C.Interpreter_readOutputTensor(intp, output_idx, c_output, output.nbytes)
        self.check_for_errors()

    def check_for_errors(self):
        """
        Raise RuntimeError if the C++ library recorded an error, clearing it first.

        A NULL error pointer is treated as "no error" rather than passed to
        `ffi.string`. Undecodable bytes in the message are replaced, not raised.
        """
        err_ptr = self.C.get_last_error()
        if err_ptr == self.ffi.NULL:
            return
        error_message = self.ffi.string(err_ptr).decode("utf-8", errors="replace")
        if error_message:
            self.C.clear_last_error()
            raise RuntimeError(f"C++ Exception: {error_message}")

    def __del__(self):
        self.delete()
|
tico/passes/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
# DO NOT REMOVE THIS FILE
|
@@ -0,0 +1,185 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Tuple, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.serialize.circle_mapping import extract_torch_dtype
|
23
|
+
from tico.utils import logging
|
24
|
+
from tico.utils.passes import PassBase, PassResult
|
25
|
+
from tico.utils.trace_decorators import (
|
26
|
+
trace_const_diff_on_pass,
|
27
|
+
trace_graph_diff_on_pass,
|
28
|
+
)
|
29
|
+
from tico.utils.utils import set_new_meta_val
|
30
|
+
|
31
|
+
|
32
|
+
# Rank of each dtype that `aten.where.self` branches may be reconciled between.
# The lower-ranked branch is always cast to the dtype of the higher-ranked one.
dtype_ranking = {
    dtype: rank
    for rank, dtype in enumerate((torch.int32, torch.int64, torch.float32))
}
|
37
|
+
|
38
|
+
|
39
|
+
def sort_by_dtype(
    result_true: torch.fx.Node, result_false: torch.fx.Node
) -> Tuple[torch.fx.Node, torch.fx.Node]:
    """
    Order the two `aten.where.self` branch nodes by `dtype_ranking`.

    Returns:
        `(not_to_cast, to_cast)`: the node with the higher-ranked dtype first,
        followed by the node that should be cast to that dtype.

    Raises:
        KeyError: If either node's dtype is not a key of `dtype_ranking`.
        AssertionError: If both nodes have the same rank. This is raised
            explicitly (not via `assert`) so that it still fires under
            `python -O` instead of silently returning None.
    """
    true_rank = dtype_ranking[extract_torch_dtype(result_true)]
    false_rank = dtype_ranking[extract_torch_dtype(result_false)]
    if true_rank == false_rank:
        raise AssertionError(
            "There is no case that the dtype_ranking of the nodes are the same"
        )
    if true_rank > false_rank:
        return result_true, result_false
    return result_false, result_true
|
49
|
+
|
50
|
+
|
51
|
+
def check_if_covered_by_float(tensor: torch.Tensor) -> bool:
    """
    Return True if every value of `tensor` lies within [-2**24, 2**24].

    That is the range in which float32 represents every integer exactly, so an
    integer tensor passing this check can be cast to float32 without loss.
    An empty tensor trivially passes (its `min()`/`max()` would raise).
    """
    if tensor.numel() == 0:
        return True
    # About the min/max range, please refer to https://en.wikipedia.org/wiki/Single-precision_floating-point_format#Precision_limitations_on_integer_values
    if tensor.min() < -(2**24) or tensor.max() > 2**24:
        return False
    return True
|
56
|
+
|
57
|
+
|
58
|
+
@trace_graph_diff_on_pass
@trace_const_diff_on_pass
class CastATenWhereArgType(PassBase):
    """
    This pass casts the data type of `aten.where.self` operation's arguments.

    This pass is applied when the data types of `aten.where.self` operation's arguments
    differ. If the data types of the arguments, denoted `result_true` and `result_false`
    in the graph below, are identical, this pass is not applied. It is also only applied
    when both `result_true` and `result_false` are buffer inputs of the exported program.

    In addition, this pass casts in the direction that avoids data loss: the argument
    whose dtype has the lower `dtype_ranking` is cast to the dtype of the other one.
    For example, if the data type of `result_true` is `float32` and the data type of
    `result_false` is `int32`, then `result_false` will be cast to `float32`.
    Moreover, in this case, it is checked whether the contents of `result_false` are
    within the range where `float32` represents integers exactly (|value| <= 2**24).
    If so, `result_false` is cast to `float32`.
    If not, RuntimeError is raised.

    After this pass, the arguments of `aten.where.self` should have the same data type.

    The graph before this pass and the graph after this pass are shown below.
    NOTE Below example denotes the case when `result_false` was cast.

    (before)

        [condition]    [result_true]    [result_false]
             |               |                 |
             |               |                 |
             +---------------+-----------------+
                             |
                             |
                          [where]
                             |
                             |
                         [output]

    (after)

                                        [result_false]
        [condition]    [result_true]          |
             |               |             [cast]
             |               |                |
             +---------------+----------------+
                             |
                             |
                          [where]
                             |
                             |
                         [output]
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)
        ep = exported_program
        graph_module = ep.graph_module
        graph = graph_module.graph
        # This pass only inserts cast nodes, so the signature mapping and buffer
        # contents stay valid for the whole call; compute them at most once.
        inputs_to_buffers = ep.graph_signature.inputs_to_buffers
        buf_name_to_data = None
        modified = False

        for node in graph.nodes:
            if not (
                node.op == "call_function" and node.target == torch.ops.aten.where.self
            ):
                continue

            assert len(node.args) == 3
            _, result_true, result_false = node.args  # condition is not used

            # Only the case where both branches are buffers is handled.
            if not (
                result_true.name in inputs_to_buffers
                and result_false.name in inputs_to_buffers
            ):
                continue

            # Check if they have different data types
            true_dtype = extract_torch_dtype(result_true)
            false_dtype = extract_torch_dtype(result_false)
            if true_dtype == false_dtype:
                continue

            node_to_dtype = {result_true: true_dtype, result_false: false_dtype}
            not_to_cast, to_cast = sort_by_dtype(result_true, result_false)
            # Position of `to_cast` in node.args (0 is the condition).
            arg_index = 1 if to_cast is result_true else 2

            if buf_name_to_data is None:
                buf_name_to_data = dict(ep.named_buffers())
            buf_data = buf_name_to_data[inputs_to_buffers[to_cast.name]]
            assert isinstance(buf_data, torch.Tensor)

            dtype_to_cast = node_to_dtype[not_to_cast]

            if dtype_to_cast == torch.float32:
                if not check_if_covered_by_float(buf_data):
                    raise RuntimeError(
                        f"{to_cast.name}({buf_data.dtype}) data range is out of {dtype_to_cast} range"
                    )
            with graph.inserting_after(to_cast):
                cast = graph.call_function(
                    torch.ops.aten._to_copy.default,
                    args=(to_cast,),
                    kwargs={"dtype": dtype_to_cast},
                )
                # set new meta["val"] in advance because we will use it below for checking if type promotion is valid.
                set_new_meta_val(cast)
            node.update_arg(arg_index, cast)

            # check if type promotion is valid.
            node_dtype_ori = extract_torch_dtype(node)
            set_new_meta_val(node)
            node_dtype = extract_torch_dtype(node)
            assert (
                node_dtype == node_dtype_ori
            ), "Type casting doesn't change node's dtype."

            logger.debug(
                f"{to_cast.name}'s dtype was casted from {buf_data.dtype} to {dtype_to_cast}"
            )

            modified = True

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
@@ -0,0 +1,186 @@
|
|
1
|
+
# Portions of this file are adapted from code originally authored by
|
2
|
+
# Meta Platforms, Inc. and affiliates, licensed under the BSD-style
|
3
|
+
# license found in the LICENSE file in the root directory of their source tree.
|
4
|
+
|
5
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
6
|
+
#
|
7
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
8
|
+
# you may not use this file except in compliance with the License.
|
9
|
+
# You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing, software
|
14
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
15
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
16
|
+
# See the License for the specific language governing permissions and
|
17
|
+
# limitations under the License.
|
18
|
+
|
19
|
+
from typing import TYPE_CHECKING
|
20
|
+
|
21
|
+
if TYPE_CHECKING:
|
22
|
+
import torch.fx
|
23
|
+
import torch
|
24
|
+
from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND
|
25
|
+
from torch.export import ExportedProgram
|
26
|
+
|
27
|
+
from tico.serialize.circle_mapping import extract_torch_dtype
|
28
|
+
from tico.utils import logging
|
29
|
+
from tico.utils.passes import PassBase, PassResult
|
30
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
31
|
+
from tico.utils.utils import set_new_meta_val
|
32
|
+
|
33
|
+
|
34
|
+
# Binary elementwise ATen ops whose mixed-dtype arguments `CastMixedTypeArgs`
# reconciles, each mapped to the promotion kind given to `elementwise_dtypes`.
# Every op currently uses the DEFAULT promotion kind.
ops_to_promote = dict.fromkeys(
    (
        torch.ops.aten.add.Tensor,
        torch.ops.aten.div.Tensor,
        torch.ops.aten.eq.Scalar,
        torch.ops.aten.eq.Tensor,
        torch.ops.aten.ge.Scalar,
        torch.ops.aten.ge.Tensor,
        torch.ops.aten.gt.Scalar,
        torch.ops.aten.gt.Tensor,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.minimum.default,
        torch.ops.aten.ne.Scalar,
        torch.ops.aten.ne.Tensor,
        torch.ops.aten.pow.Tensor_Scalar,
        torch.ops.aten.sub.Tensor,
    ),
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
|
50
|
+
|
51
|
+
|
52
|
+
def has_same_dtype(lhs, rhs):
    """
    Return True if `lhs` and `rhs` resolve to the same torch dtype.

    Each operand may be a `torch.fx.Node` (its `meta["val"].dtype` is used), a
    `torch.Tensor` (its `.dtype`), or any other value, for which the dtype that
    `torch.tensor(value)` infers is used.
    """

    def _dtype_of(operand):
        if isinstance(operand, torch.fx.Node):
            return operand.meta["val"].dtype
        if isinstance(operand, torch.Tensor):
            return operand.dtype
        return torch.tensor(operand).dtype

    return _dtype_of(lhs) == _dtype_of(rhs)
|
69
|
+
|
70
|
+
|
71
|
+
def to_numeric_type(torch_dtype: torch.dtype):
    """
    Map a torch dtype to the Python scalar type used to re-type scalar arguments.

    Returns `float` for float32, `int` for int64, `bool` for bool, and None for
    any other dtype (`torch.float` is an alias of `torch.float32`).
    """
    python_scalar_type_of = {
        torch.float32: float,
        torch.int64: int,
        torch.bool: bool,
    }
    return python_scalar_type_of.get(torch_dtype)
|
83
|
+
|
84
|
+
|
85
|
+
@trace_graph_diff_on_pass
class CastMixedTypeArgs(PassBase):
    """
    Cast one argument of a binary elementwise op so both arguments share a dtype.

    For every `call_function` node whose target is in `ops_to_promote` and whose two
    arguments have different dtypes, the promoted result dtype is computed with
    `elementwise_dtypes(...)[1]`. The argument whose dtype differs from it is cast:

    - a `torch.fx.Node` argument gets an `aten._to_copy` node inserted after it;
    - a `torch.Tensor` constant is converted with `Tensor.to`;
    - a Python scalar is converted to the Python type given by `to_numeric_type`.
      When there is none, the scalar would have to become a tensor constant; that
      case is skipped if `preserve_ep_invariant` is True.

    After each cast the node's `meta["val"]` is recomputed and asserted to keep the
    node's original dtype.

    Args:
        preserve_ep_invariant: If True (default), never introduce tensor constants
            that are not backed by an ExportedProgram placeholder.
    """

    def __init__(self, preserve_ep_invariant=True):
        super().__init__()
        self.preserve_ep_invariant = preserve_ep_invariant

    # TODO Folding float and int values before this pass
    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if node.op != "call_function":
                continue
            if node.target not in ops_to_promote:
                continue

            assert len(node.args) == 2
            lhs, rhs = node.args
            assert isinstance(lhs, (torch.fx.Node, torch.Tensor, float, int)), type(lhs)
            assert isinstance(rhs, (torch.fx.Node, torch.Tensor, float, int)), type(rhs)
            if has_same_dtype(lhs, rhs):
                continue

            lhs_val = self._example_value(lhs)
            rhs_val = self._example_value(rhs)
            type_to_promote: torch.dtype = elementwise_dtypes(
                lhs_val, rhs_val, type_promotion_kind=ops_to_promote[node.target]
            )[1]

            # Track the position of the argument to cast explicitly.
            # `node.args.index(arg)` compares by equality, so for args (2.0, 2)
            # the int `2` would resolve to index 0 and the wrong argument would
            # be rewritten.
            if rhs_val.dtype == type_to_promote:
                index_to_promote, ori_type = 0, lhs_val.dtype
            elif lhs_val.dtype == type_to_promote:
                index_to_promote, ori_type = 1, rhs_val.dtype
            else:
                # Raised explicitly so it is not stripped under `python -O`.
                raise AssertionError(
                    f"Neither argument of {node.name} has the promoted dtype {type_to_promote}"
                )
            arg_to_promote = node.args[index_to_promote]

            if isinstance(arg_to_promote, torch.fx.Node):
                with graph.inserting_after(arg_to_promote):
                    to_copy = graph.call_function(
                        torch.ops.aten._to_copy.default,
                        (arg_to_promote,),
                        {"dtype": type_to_promote},
                    )
                    # set new meta["val"] in advance because we will use it below for checking if type promotion is valid.
                    set_new_meta_val(to_copy)
                node.update_arg(index_to_promote, to_copy)

                modified = True
                logger.debug(
                    f"{arg_to_promote.name}'s dtype was casted from {ori_type} to {type_to_promote}"
                )
            else:
                if isinstance(arg_to_promote, torch.Tensor):
                    promoted = arg_to_promote.to(type_to_promote)
                else:
                    # numerical types
                    numeric_type = to_numeric_type(type_to_promote)
                    if numeric_type is not None:
                        promoted = numeric_type(arg_to_promote)
                    elif self.preserve_ep_invariant:
                        # ExportedProgram (EP) requires to add a placeholder when
                        # a tensor is created, which complicates EP structure but
                        # not necessary for circle serialization. We skip this case if
                        # preserve_ep_invariant = True.
                        continue
                    else:
                        # Create tensor without placeholder
                        # NOTE This breaks EP invariant
                        promoted = torch.tensor(arg_to_promote).to(type_to_promote)
                node.update_arg(index_to_promote, promoted)

                modified = True
                logger.debug(
                    f"{promoted}'s dtype was casted from {ori_type} to {type_to_promote}"
                )

            # check if type promotion is valid.
            node_dtype_ori = extract_torch_dtype(node)
            set_new_meta_val(node)
            node_dtype = extract_torch_dtype(node)
            assert (
                node_dtype == node_dtype_ori
            ), "Type casting doesn't change node's dtype."

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)

    @staticmethod
    def _example_value(arg):
        """Return a tensor carrying `arg`'s dtype, for use with `elementwise_dtypes`.

        Nodes use their `meta["val"]`; tensors are used as-is (avoiding a
        `torch.tensor(tensor)` copy-construct); scalars are wrapped with `torch.tensor`.
        """
        if isinstance(arg, torch.fx.Node):
            return arg.meta["val"]
        if isinstance(arg, torch.Tensor):
            return arg
        return torch.tensor(arg)
|