tico 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +42 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +191 -0
- tico/passes/cast_mixed_type_args.py +187 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +160 -0
- tico/passes/convert_layout_op_to_reshape.py +85 -0
- tico/passes/convert_repeat_to_expand_copy.py +89 -0
- tico/passes/convert_to_relu6.py +181 -0
- tico/passes/decompose_addmm.py +124 -0
- tico/passes/decompose_batch_norm.py +192 -0
- tico/passes/decompose_fake_quantize.py +134 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
- tico/passes/decompose_group_norm.py +275 -0
- tico/passes/decompose_grouped_conv2d.py +209 -0
- tico/passes/decompose_slice_scatter.py +169 -0
- tico/passes/extract_dtype_kwargs.py +122 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +108 -0
- tico/passes/legalize_predefined_layout_operators.py +386 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
- tico/passes/lower_to_slice.py +230 -0
- tico/passes/merge_consecutive_cat.py +80 -0
- tico/passes/ops.py +78 -0
- tico/passes/remove_nop.py +84 -0
- tico/passes/remove_redundant_assert_nodes.py +51 -0
- tico/passes/remove_redundant_expand.py +66 -0
- tico/passes/remove_redundant_permute.py +122 -0
- tico/passes/remove_redundant_reshape.py +436 -0
- tico/passes/remove_redundant_slice.py +62 -0
- tico/passes/remove_redundant_to_copy.py +86 -0
- tico/passes/restore_linear.py +115 -0
- tico/passes/segment_index_select.py +145 -0
- tico/pt2_to_circle.py +105 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +319 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +240 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_abs.py +53 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +150 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +192 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +126 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +186 -0
- tico/serialize/operators/op_copy.py +164 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +95 -0
- tico/serialize/operators/op_depthwise_conv2d.py +199 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_leaky_relu.py +60 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +86 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_dim.py +70 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +177 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +141 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +100 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +99 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +94 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +296 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +282 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +609 -0
- tico/utils/serialize.py +42 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +406 -0
- tico/utils/validate_args_kwargs.py +1149 -0
- tico-0.1.0.dist-info/LICENSE +241 -0
- tico-0.1.0.dist-info/METADATA +354 -0
- tico-0.1.0.dist-info/RECORD +206 -0
- tico-0.1.0.dist-info/WHEEL +5 -0
- tico-0.1.0.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,71 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from abc import ABC, abstractmethod
|
16
|
+
from typing import Any, Dict, Optional
|
17
|
+
|
18
|
+
import torch
|
19
|
+
|
20
|
+
from tico.experimental.quantization.config import BaseConfig
|
21
|
+
|
22
|
+
|
23
|
+
class BaseQuantizer(ABC):
    """
    Common interface for quantizers that run a quantization algorithm on a target model.

    Subclasses must implement both `prepare` and `convert`.
    """

    def __init__(self, config: BaseConfig):
        """
        Keep the quantization configuration on the instance.

        Parameters:
            config (BaseConfig): Quantization configuration parameters.
        """
        self.config = config

    @abstractmethod
    def prepare(
        self,
        model: torch.nn.Module,
        args: Optional[Any] = None,
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Get the model ready for quantization according to the algorithm-specific
        configuration, e.g. by attaching observers or hooks. Example inputs are
        optional and are mainly useful for activation quantization.

        Parameters:
            model: The target PyTorch model.
            args (Any, optional): Positional example inputs required for activation quantization.
            kwargs (Dict[str, Any], optional): Keyword example inputs required for activation quantization.

        Returns:
            The prepared model.
        """

    @abstractmethod
    def convert(self, model):
        """
        Turn the prepared (or calibrated) model into its quantized form, using the
        statistics gathered during calibration.

        Parameters:
            model: The prepared PyTorch model.

        Returns:
            The quantized model.
        """
|
@@ -0,0 +1 @@
|
|
1
|
+
# DO NOT REMOVE THIS FILE
|
@@ -0,0 +1,116 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Any
|
16
|
+
|
17
|
+
import numpy as np
|
18
|
+
import torch
|
19
|
+
from circle_schema import circle
|
20
|
+
|
21
|
+
from tico.interpreter.interpreter import Interpreter
|
22
|
+
from tico.serialize.circle_mapping import np_dtype_from_circle_dtype, to_circle_dtype
|
23
|
+
|
24
|
+
|
25
|
+
def preprocess_inputs(inputs: Any):
    """
    Preprocess user inputs for circle inference.

    1. None inputs are ignored, at any nesting depth.
    2. A list/tuple input is flattened (recursively), as happens when a torch module is exported.
       e.g. inputs = (torch.Tensor, [2,3,4]) -> inputs = (torch.Tensor, 2, 3, 4)

    Parameters:
        inputs: An iterable of input values, possibly containing None and nested lists/tuples.

    Returns:
        A flat tuple of the non-None leaf values, in their original order.
    """
    flattened = []
    for value in inputs:
        # Identity check: `== None` is evaluated element-wise by array-like values
        # (e.g. numpy arrays), which makes the truth test ambiguous and raises.
        if value is None:
            continue
        if isinstance(value, (tuple, list)):
            # Recurse so that every level is handled the same way.
            flattened.extend(preprocess_inputs(value))
        else:
            flattened.append(value)
    return tuple(flattened)
|
46
|
+
|
47
|
+
|
48
|
+
def infer(circle_binary: bytes, *args: Any, **kwargs: Any) -> Any:
    """
    Run a circle model on the given inputs and return its outputs.

    The positional and keyword inputs are flattened, converted to tensors and checked
    against the input count, shapes and dtypes recorded in the circle binary before
    the interpreter is run.

    Parameters:
        circle_binary: Serialized circle model. It must contain exactly one subgraph.
        *args: Positional model inputs.
        **kwargs: Keyword model inputs; their values are appended after `args`.

    Returns:
        A single numpy array when the model has one output, otherwise a list of numpy arrays.

    Raises:
        RuntimeError: If the number, shape or data type of the inputs does not match the model.
    """
    # When converting a model, it is assumed that the order of keyword arguments is maintained.
    user_inputs = preprocess_inputs(args + tuple(kwargs.values()))
    # Cast them to torch.Tensor to make it simple.
    # NOTE(review): the exact-type comparison means tensor subclasses are re-wrapped by
    # torch.tensor as well - confirm this is intended before switching to isinstance.
    user_inputs = tuple(
        user_input if type(user_input) == torch.Tensor else torch.tensor(user_input)
        for user_input in user_inputs
    )

    # Get input spec from circle binary.
    model = circle.Model.Model.GetRootAsModel(circle_binary, 0)
    assert model.SubgraphsLength() == 1
    graph = model.Subgraphs(0)
    model_input_shapes_np = []
    model_input_types_cm = []
    for i in range(graph.InputsLength()):
        in_tensor = graph.Tensors(graph.Inputs(i))
        model_input_shapes_np.append(in_tensor.ShapeAsNumpy())
        model_input_types_cm.append(in_tensor.Type())

    # Check if given inputs' dtype and shape from users match the inputs' from model binary.
    if len(model_input_shapes_np) != len(user_inputs):
        raise RuntimeError(
            f"Mismatch input length: input({len(user_inputs)}) != circle model({len(model_input_shapes_np)})"
        )
    for input_idx, user_input in enumerate(user_inputs):
        # Shape
        if list(user_input.shape) != list(model_input_shapes_np[input_idx]):
            raise RuntimeError(
                f"Mismatch input {input_idx} shape : input({user_input.shape}) != circle model({model_input_shapes_np[input_idx]})"
            )
        # Data type
        user_input_type_cm = to_circle_dtype(user_input.dtype)
        if user_input_type_cm != model_input_types_cm[input_idx]:
            raise RuntimeError(
                f"Mismatch input {input_idx} data type : input({user_input_type_cm}) != circle model({model_input_types_cm[input_idx]})"
            )

    # Initialize interpreter, feed the inputs and run it.
    intp = Interpreter(circle_binary)
    for input_idx, user_input in enumerate(user_inputs):
        intp.writeInputTensor(input_idx, user_input)
    intp.interpret()

    # Allocate one buffer per model output, using the shape and dtype recorded
    # in the circle model, and let the interpreter fill it.
    output = []
    for output_idx in range(graph.OutputsLength()):
        out_tensor = graph.Tensors(graph.Outputs(output_idx))
        result: np.ndarray = np.empty(
            out_tensor.ShapeAsNumpy(),
            dtype=np_dtype_from_circle_dtype(out_tensor.Type()),
        )
        intp.readOutputTensor(output_idx, result)
        output.append(result)

    return output[0] if len(output) == 1 else output
|
@@ -0,0 +1,93 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from pathlib import Path
|
16
|
+
|
17
|
+
import numpy as np
|
18
|
+
import torch
|
19
|
+
from cffi import FFI
|
20
|
+
|
21
|
+
|
22
|
+
class Interpreter:
    """
    Python wrapper for C++ luci-interpreter class in ONE using CFFI.

    This class provides a Python interface to the underlying C++ luci-interpreter class in ONE,
    preserving the original C++ API. Each method corresponds to a method in the C++ class,
    with additional error handling implemented to ensure that C++ exceptions are captured and
    translated into Python errors.

    Note that each method calls `check_for_errors` after the native call to catch any C++
    exceptions and translate them into Python exceptions. This ensures that errors in the C++
    library do not cause undefined behavior in Python.

    The native interpreter is released by `delete()`, which is also invoked from `__del__`.
    `delete()` is idempotent, so calling it explicitly and then letting the object be
    garbage-collected releases the native handle only once.
    """

    def __init__(self, circle_binary: bytes):
        """
        Load the interpreter library and create a native interpreter for the model.

        Parameters:
            circle_binary: Serialized circle model passed to the native interpreter.

        Raises:
            RuntimeError: If the interpreter shared library is not installed, or if the
                native library reports an error while creating the interpreter.
        """
        # Define the handles up front so that `delete`/`__del__` stay safe even when
        # the setup below raises before they are assigned.
        self.C = None
        self.intp = None

        self.ffi = FFI()
        self.ffi.cdef(
            """
            typedef struct InterpreterWrapper InterpreterWrapper;

            const char *get_last_error(void);
            void clear_last_error(void);
            InterpreterWrapper *Interpreter_new(const uint8_t *data, const size_t data_size);
            void Interpreter_delete(InterpreterWrapper *intp);
            void Interpreter_interpret(InterpreterWrapper *intp);
            void Interpreter_writeInputTensor(InterpreterWrapper *intp, const int input_idx, const void *data, size_t input_size);
            void Interpreter_readOutputTensor(InterpreterWrapper *intp, const int output_idx, void *output, size_t output_size);
            """
        )
        # TODO Check if one-compiler version is compatible. Whether it has .so file or not for CFFI.
        intp_lib_path = Path("/usr/share/one/lib/libcircle_interpreter_cffi.so")
        if not intp_lib_path.is_file():
            raise RuntimeError("Please install one-compiler for circle inference.")
        self.C = self.ffi.dlopen(str(intp_lib_path))

        # Initialize interpreter
        self.intp = self.C.Interpreter_new(circle_binary, len(circle_binary))
        self.check_for_errors()

    def delete(self):
        """
        Release the native interpreter. Calling it again is a no-op.

        Raises:
            RuntimeError: If the native library reports an error while deleting.
        """
        intp = getattr(self, "intp", None)
        if intp is None:
            return
        # Drop the handle before the native call so a second call (e.g. from
        # `__del__`) can never pass the same pointer to Interpreter_delete twice.
        self.intp = None
        self.C.Interpreter_delete(intp)
        self.check_for_errors()

    def interpret(self):
        """Run the model on the inputs written so far."""
        self.C.Interpreter_interpret(self.intp)
        self.check_for_errors()

    def writeInputTensor(self, input_idx: int, input_data: torch.Tensor):
        """
        Copy `input_data` into the model input at index `input_idx`.

        Parameters:
            input_idx: Index of the model input.
            input_data: Tensor whose bytes are handed to the native interpreter.
        """
        input_as_numpy = input_data.numpy()
        # cffi.from_buffer() only accepts C-contiguous array.
        input_as_numpy = np.ascontiguousarray(input_as_numpy)
        c_input = self.ffi.from_buffer(input_as_numpy)
        self.C.Interpreter_writeInputTensor(
            self.intp, input_idx, c_input, input_data.nbytes
        )
        self.check_for_errors()

    def readOutputTensor(self, output_idx: int, output: np.ndarray):
        """
        Fill `output` in place with the model output at index `output_idx`.

        Parameters:
            output_idx: Index of the model output.
            output: Pre-allocated array; its buffer and byte size are passed to the native call.
        """
        c_output = self.ffi.from_buffer(output)
        self.C.Interpreter_readOutputTensor(
            self.intp, output_idx, c_output, output.nbytes
        )
        self.check_for_errors()

    def check_for_errors(self):
        """
        Raise the last error recorded by the native library, if any.

        Raises:
            RuntimeError: If the native library reports a non-empty error message.
                The native error state is cleared before raising.
        """
        error_message = self.ffi.string(self.C.get_last_error()).decode("utf-8")
        if error_message:
            self.C.clear_last_error()
            raise RuntimeError(f"C++ Exception: {error_message}")

    def __del__(self):
        # `__init__` may have raised before the library or the handle was set up
        # (e.g. one-compiler is not installed); there is nothing to release then.
        if getattr(self, "C", None) is None or getattr(self, "intp", None) is None:
            return
        self.delete()
|
tico/passes/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
# DO NOT REMOVE THIS FILE
|
@@ -0,0 +1,191 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Tuple, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.serialize.circle_mapping import extract_torch_dtype
|
23
|
+
from tico.utils import logging
|
24
|
+
from tico.utils.graph import create_node
|
25
|
+
from tico.utils.passes import PassBase, PassResult
|
26
|
+
from tico.utils.trace_decorators import (
|
27
|
+
trace_const_diff_on_pass,
|
28
|
+
trace_graph_diff_on_pass,
|
29
|
+
)
|
30
|
+
from tico.utils.utils import is_target_node, set_new_meta_val
|
31
|
+
from tico.utils.validate_args_kwargs import WhereSelfArgs
|
32
|
+
|
33
|
+
|
34
|
+
# Rank of each dtype handled by `sort_by_dtype`. Of two operands, the one with the
# larger rank is returned first; the caller keeps its dtype and casts the other one to it.
# A dtype that is not listed here makes the lookup in `sort_by_dtype` raise KeyError.
dtype_ranking = {
    torch.int32: 0,
    torch.int64: 1,
    torch.float32: 2,
}
|
39
|
+
|
40
|
+
|
41
|
+
def sort_by_dtype(
    result_true: torch.fx.Node, result_false: torch.fx.Node
) -> Tuple[torch.fx.Node, torch.fx.Node]:
    """
    Order two nodes by the rank of their dtypes in `dtype_ranking`, higher rank first.

    Parameters:
        result_true: First node to compare.
        result_false: Second node to compare.

    Returns:
        A pair `(higher, lower)` where `higher` is the node whose dtype has the larger rank.

    Raises:
        KeyError: If the dtype of either node is not listed in `dtype_ranking`.
        AssertionError: If both nodes have the same rank; callers are expected to
            pass nodes with different dtypes.
    """
    true_dtype = extract_torch_dtype(result_true)
    false_dtype = extract_torch_dtype(result_false)
    true_rank = dtype_ranking[true_dtype]
    false_rank = dtype_ranking[false_dtype]
    if true_rank > false_rank:
        return result_true, result_false
    if true_rank < false_rank:
        return result_false, result_true
    # Raised explicitly rather than via `assert False`: an assert statement is removed
    # under `python -O`, which would make this function silently return None.
    raise AssertionError(
        "There is no case that the dtype_ranking of the nodes are the same"
    )
|
51
|
+
|
52
|
+
|
53
|
+
def check_if_covered_by_float(tensor: torch.Tensor) -> bool:
    """
    Check whether every value of `tensor` lies within [-2**24, 2**24].

    Parameters:
        tensor: Tensor whose value range is inspected.

    Returns:
        True if the minimum and maximum are both inside the range, or if the tensor
        is empty (there is no value that could fall outside); False otherwise.
    """
    # min()/max() raise on an empty tensor; an empty tensor is trivially covered.
    if tensor.numel() == 0:
        return True
    # About the min/max range, please refer to https://en.wikipedia.org/wiki/Single-precision_floating-point_format#Precision_limitations_on_integer_values
    limit = 2**24
    return not (tensor.min() < -limit or tensor.max() > limit)
|
58
|
+
|
59
|
+
|
60
|
+
@trace_graph_diff_on_pass
@trace_const_diff_on_pass
class CastATenWhereArgType(PassBase):
    """
    This pass casts the data type of `aten.where.self` operation's arguments.

    This pass is applied when the data types of `aten.where.self` operation's arguments are different.
    If the data types of the arguments, which are denoted `result_true` and `result_false` in the graph below,
    are identical, this pass is not applied.
    The pass is also skipped unless both arguments are graph nodes that are registered as buffers
    in the exported program's graph signature.

    In addition, this pass casts the data type in the direction that avoids data loss.
    For example, if the data type of `result_true` is `float32` and the data type of `result_false` is `int32`,
    then the data type of `result_false` will be cast to `float32`.
    Moreover, in this case, it is checked whether the contents of `result_false` are within the range
    that `float32` can cover.
    If so, the data type of `result_false` will be cast to `float32`.
    If not, RuntimeError will be raised.

    After this pass, the arguments of `aten.where.self` should have the same data type.

    The graph before this pass and the graph after this pass are shown below.
    NOTE The example below denotes the case when `result_false` was cast.

    (before)

    [condition]    [result_true]   [result_false]
         |               |                |
         |               |                |
         +---------------+----------------+
                         |
                         |
                      [where]
                         |
                         |
                     [output]

    (after)

                                   [result_false]
    [condition]    [result_true]          |
         |               |             [cast]
         |               |                |
         +---------------+----------------+
                         |
                         |
                      [where]
                         |
                         |
                     [output]
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Insert an `aten._to_copy` cast in front of every eligible `aten.where.self` node.

        Parameters:
            exported_program: Program whose graph is modified in place.

        Returns:
            PassResult carrying True if at least one cast was inserted, False otherwise.

        Raises:
            RuntimeError: If a buffer has to be cast to float32 but holds values outside
                the range accepted by `check_if_covered_by_float`.
        """
        logger = logging.getLogger(__name__)
        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False

        for node in graph.nodes:
            if not is_target_node(node, torch.ops.aten.where.self):
                continue

            where_args = WhereSelfArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
            result_true, result_false = where_args.input, where_args.other
            # Only node operands are handled; anything else is left untouched.
            if not isinstance(result_true, torch.fx.Node) or not isinstance(
                result_false, torch.fx.Node
            ):
                continue

            ep = exported_program
            # Already guaranteed by the check above; presumably kept for static type narrowing.
            assert isinstance(result_true, torch.fx.Node)
            assert isinstance(result_false, torch.fx.Node)
            # Both operands must be buffers, because the buffer data is inspected below.
            if not (
                result_true.name in ep.graph_signature.inputs_to_buffers
                and result_false.name in ep.graph_signature.inputs_to_buffers
            ):
                continue

            # Check if they have different data types
            true_dtype = extract_torch_dtype(result_true)
            false_dtype = extract_torch_dtype(result_false)
            if true_dtype == false_dtype:
                continue

            node_to_dtype = {result_true: true_dtype, result_false: false_dtype}

            # The operand with the higher-ranked dtype keeps its dtype; the other one is cast.
            not_to_cast, to_cast = sort_by_dtype(result_true, result_false)

            # Look up the actual tensor behind the operand that is going to be cast.
            buf_name_to_data = {name: buf for name, buf in ep.named_buffers()}
            buf_name = ep.graph_signature.inputs_to_buffers[to_cast.name]
            buf_data = buf_name_to_data[buf_name]

            assert isinstance(buf_data, torch.Tensor)

            dtype_to_cast = node_to_dtype[not_to_cast]

            # Casting to float32 is only accepted when the buffer values pass the range check.
            if dtype_to_cast == torch.float32:
                if not check_if_covered_by_float(buf_data):
                    raise RuntimeError(
                        f"{to_cast.name}({buf_data.dtype}) data range is out of {dtype_to_cast} range"
                    )
            with graph_module.graph.inserting_after(to_cast):
                cast = create_node(
                    graph,
                    torch.ops.aten._to_copy.default,
                    args=(to_cast,),
                    kwargs={"dtype": dtype_to_cast},
                    origin=to_cast,
                )
            # set new meta["val"] in advance because we will use it below for checking if type promotion is valid.
            set_new_meta_val(cast)
            # Replace the positional argument that referenced `to_cast` with the cast node.
            node.update_arg(node.args.index(to_cast), cast)

            # check if type promotion is valid.
            # The dtype of the `where` node itself must be the same before and after the cast.
            node_dtype_ori = extract_torch_dtype(node)
            set_new_meta_val(node)
            node_dtype = extract_torch_dtype(node)
            assert (
                node_dtype == node_dtype_ori
            ), f"Type casting doesn't change node's dtype."

            logger.debug(
                f"{to_cast.name}'s dtype was casted from {buf_data.dtype} to {dtype_to_cast}"
            )

            modified = True

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
@@ -0,0 +1,187 @@
|
|
1
|
+
# Portions of this file are adapted from code originally authored by
|
2
|
+
# Meta Platforms, Inc. and affiliates, licensed under the BSD-style
|
3
|
+
# license found in the LICENSE file in the root directory of their source tree.
|
4
|
+
|
5
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
6
|
+
#
|
7
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
8
|
+
# you may not use this file except in compliance with the License.
|
9
|
+
# You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing, software
|
14
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
15
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
16
|
+
# See the License for the specific language governing permissions and
|
17
|
+
# limitations under the License.
|
18
|
+
|
19
|
+
from typing import TYPE_CHECKING
|
20
|
+
|
21
|
+
if TYPE_CHECKING:
|
22
|
+
import torch.fx
|
23
|
+
import torch
|
24
|
+
from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND
|
25
|
+
from torch.export import ExportedProgram
|
26
|
+
|
27
|
+
from tico.serialize.circle_mapping import extract_torch_dtype
|
28
|
+
from tico.utils import logging
|
29
|
+
from tico.utils.graph import create_node
|
30
|
+
from tico.utils.passes import PassBase, PassResult
|
31
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
32
|
+
from tico.utils.utils import is_target_node, set_new_meta_val
|
33
|
+
|
34
|
+
|
35
|
+
# Binary ATen ops whose operands must share a dtype, mapped to the promotion
# kind handed to `elementwise_dtypes`. Every op currently uses DEFAULT; to give
# a single op a different kind, override its entry after this statement.
ops_to_promote = dict.fromkeys(
    (
        torch.ops.aten.add.Tensor,
        torch.ops.aten.div.Tensor,
        torch.ops.aten.eq.Scalar,
        torch.ops.aten.eq.Tensor,
        torch.ops.aten.ge.Scalar,
        torch.ops.aten.ge.Tensor,
        torch.ops.aten.gt.Scalar,
        torch.ops.aten.gt.Tensor,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.minimum.default,
        torch.ops.aten.ne.Scalar,
        torch.ops.aten.ne.Tensor,
        torch.ops.aten.pow.Tensor_Scalar,
        torch.ops.aten.sub.Tensor,
    ),
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
|
51
|
+
|
52
|
+
|
53
|
+
def has_same_dtype(lhs, rhs):
    """Return True when both operands resolve to the same torch dtype.

    Each operand's dtype is resolved as follows:
      * ``torch.fx.Node``: the dtype of its ``meta["val"]``.
      * ``torch.Tensor``: its own dtype.
      * anything else: the dtype ``torch.tensor(value)`` infers for it.
    """
    resolved = []
    for operand in (lhs, rhs):
        if isinstance(operand, torch.fx.Node):
            resolved.append(operand.meta["val"].dtype)
        elif isinstance(operand, torch.Tensor):
            resolved.append(operand.dtype)
        else:
            # Python scalars carry no dtype; use the one torch infers.
            resolved.append(torch.tensor(operand).dtype)

    lhs_dtype, rhs_dtype = resolved
    return lhs_dtype == rhs_dtype
|
70
|
+
|
71
|
+
|
72
|
+
def to_numeric_type(torch_dtype: torch.dtype):
    """Map a torch dtype to the matching Python numeric type.

    Returns ``float`` for ``torch.float32``, ``int`` for ``torch.int64`` and
    ``bool`` for ``torch.bool``. Any other dtype yields ``None``.
    """
    python_type_by_dtype = {
        # torch.float is the same dtype object as torch.float32, so a single
        # entry covers both spellings.
        torch.float32: float,
        torch.int64: int,
        torch.bool: bool,
    }
    return python_type_by_dtype.get(torch_dtype)
|
84
|
+
|
85
|
+
|
86
|
+
@trace_graph_diff_on_pass
class CastMixedTypeArgs(PassBase):
    """Cast the operands of binary ops so that both share one dtype.

    For every node whose target is a key of ``ops_to_promote`` and whose two
    positional operands have different dtypes, the promoted dtype is computed
    with ``elementwise_dtypes`` and the operand that does not already have it
    is cast:

      * ``torch.fx.Node`` operand: an ``aten._to_copy`` node is inserted right
        after it and wired into the op.
      * ``torch.Tensor`` operand: converted with ``Tensor.to``.
      * Python scalar operand: converted with the Python type returned by
        ``to_numeric_type``. When there is none, the operand is left untouched
        if ``preserve_ep_invariant`` is True, otherwise it is replaced by a
        tensor constant (which breaks the ExportedProgram invariant).

    Args:
        preserve_ep_invariant: When True (default), never replace a scalar
            argument with a newly created tensor.
    """

    def __init__(self, preserve_ep_invariant=True):
        super().__init__()
        self.preserve_ep_invariant = preserve_ep_invariant

    @staticmethod
    def _operand_value(operand):
        """Return a tensor-like value describing ``operand`` for type promotion."""
        if isinstance(operand, torch.fx.Node):
            return operand.meta["val"]
        if isinstance(operand, torch.Tensor):
            # Use the tensor as-is; copy-constructing it via torch.tensor() is
            # unnecessary and triggers a UserWarning.
            return operand
        return torch.tensor(operand)

    # TODO Folding float and int values before this pass
    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Run the pass over ``exported_program`` and report whether it changed.

        Returns:
            ``PassResult(modified)`` where ``modified`` is True if at least one
            operand was cast.

        Raises:
            AssertionError: if a target node does not have exactly two
                positional args, an operand has an unsupported type, neither
                operand already has the promoted dtype, or the node's dtype
                differs after the cast.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        # Build the target list once instead of once per visited node.
        target_ops = list(ops_to_promote.keys())
        for node in graph.nodes:
            if not is_target_node(node, target_ops):
                continue

            assert len(node.args) == 2
            lhs, rhs = node.args
            assert isinstance(lhs, (torch.fx.Node, torch.Tensor, float, int)), type(lhs)
            assert isinstance(rhs, (torch.fx.Node, torch.Tensor, float, int)), type(rhs)
            if has_same_dtype(lhs, rhs):
                continue

            lhs_val = self._operand_value(lhs)
            rhs_val = self._operand_value(rhs)
            # Index 1 of the returned pair is the result dtype.
            type_to_promote: torch.dtype = elementwise_dtypes(
                lhs_val, rhs_val, type_promotion_kind=ops_to_promote[node.target]
            )[1]

            # Track the operand by position. Looking it up afterwards with
            # `node.args.index(...)` compares by equality, which can select the
            # wrong operand (e.g. args `(1.0, 1)`) or raise when a scalar is
            # compared against a multi-element tensor.
            index_to_promote = None
            ori_type = None
            if lhs_val.dtype == type_to_promote:
                index_to_promote, ori_type = 1, rhs_val.dtype
            elif rhs_val.dtype == type_to_promote:
                index_to_promote, ori_type = 0, lhs_val.dtype
            assert index_to_promote is not None, (
                f"Neither operand of {node.name} has the promoted dtype "
                f"{type_to_promote} ({lhs_val.dtype}, {rhs_val.dtype})"
            )
            arg_to_promote = node.args[index_to_promote]

            if isinstance(arg_to_promote, torch.fx.Node):
                with graph.inserting_after(arg_to_promote):
                    to_copy = create_node(
                        graph,
                        torch.ops.aten._to_copy.default,
                        (arg_to_promote,),
                        {"dtype": type_to_promote},
                        origin=arg_to_promote,
                    )
                    # set new meta["val"] in advance because we will use it below for checking if type promotion is valid.
                    set_new_meta_val(to_copy)
                    node.update_arg(index_to_promote, to_copy)

                modified = True
                logger.debug(
                    f"{arg_to_promote.name}'s dtype was casted from {ori_type} to {type_to_promote}"
                )
            else:
                if isinstance(arg_to_promote, torch.Tensor):
                    arg_to_promote = arg_to_promote.to(type_to_promote)
                else:
                    # numerical types
                    numeric_type = to_numeric_type(type_to_promote)
                    if numeric_type is not None:
                        arg_to_promote = numeric_type(arg_to_promote)
                    elif self.preserve_ep_invariant:
                        # ExportedProgram (EP) requires to add a placeholder when
                        # a tensor is created, which complicates EP structure but
                        # not necessary for circle serialization. We skip this case if
                        # preserve_ep_invariant = True.
                        continue
                    else:
                        # Create tensor without placeholder
                        # NOTE This breaks EP invariant
                        arg_to_promote = torch.tensor(arg_to_promote).to(
                            type_to_promote
                        )
                node.update_arg(index_to_promote, arg_to_promote)

                modified = True
                logger.debug(
                    f"{arg_to_promote}'s dtype was casted from {ori_type} to {type_to_promote}"
                )

            # check if type promotion is valid: refreshing the node's meta value
            # after the cast must leave the node's own dtype unchanged.
            node_dtype_ori = extract_torch_dtype(node)
            set_new_meta_val(node)
            node_dtype = extract_torch_dtype(node)
            assert (
                node_dtype == node_dtype_ori
            ), "Type casting doesn't change node's dtype."

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|