tico 0.1.0.dev250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
tico/utils/convert.py
ADDED
@@ -0,0 +1,292 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import operator
|
16
|
+
import os
|
17
|
+
from typing import Any, Dict, Optional, Tuple
|
18
|
+
|
19
|
+
import torch
|
20
|
+
from torch.export import export, ExportedProgram
|
21
|
+
|
22
|
+
from tico.config import CompileConfigBase, get_default_config
|
23
|
+
from tico.experimental.quantization.passes.fold_quant_ops import FoldQuantOps
|
24
|
+
from tico.experimental.quantization.passes.insert_quantize_on_dtype_mismatch import (
|
25
|
+
InsertQuantizeOnDtypeMismatch,
|
26
|
+
)
|
27
|
+
from tico.experimental.quantization.passes.propagate_qparam_backward import (
|
28
|
+
PropagateQParamBackward,
|
29
|
+
)
|
30
|
+
from tico.experimental.quantization.passes.propagate_qparam_forward import (
|
31
|
+
PropagateQParamForward,
|
32
|
+
)
|
33
|
+
from tico.experimental.quantization.passes.remove_weight_dequant_op import (
|
34
|
+
RemoveWeightDequantOp,
|
35
|
+
)
|
36
|
+
from tico.passes.cast_aten_where_arg_type import CastATenWhereArgType
|
37
|
+
from tico.passes.cast_mixed_type_args import CastMixedTypeArgs
|
38
|
+
from tico.passes.const_prop_pass import ConstPropPass
|
39
|
+
from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d
|
40
|
+
from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape
|
41
|
+
from tico.passes.convert_repeat_to_expand_copy import ConvertRepeatToExpandCopy
|
42
|
+
from tico.passes.convert_to_relu6 import ConvertToReLU6
|
43
|
+
from tico.passes.decompose_addmm import DecomposeAddmm
|
44
|
+
from tico.passes.decompose_batch_norm import DecomposeBatchNorm
|
45
|
+
from tico.passes.decompose_fake_quantize import DecomposeFakeQuantize
|
46
|
+
from tico.passes.decompose_fake_quantize_tensor_qparams import (
|
47
|
+
DecomposeFakeQuantizeTensorQParams,
|
48
|
+
)
|
49
|
+
from tico.passes.decompose_group_norm import DecomposeGroupNorm
|
50
|
+
from tico.passes.decompose_grouped_conv2d import DecomposeGroupedConv2d
|
51
|
+
from tico.passes.decompose_slice_scatter import DecomposeSliceScatter
|
52
|
+
from tico.passes.extract_dtype_kwargs import ExtractDtypeKwargsPass
|
53
|
+
from tico.passes.fill_meta_val import FillMetaVal
|
54
|
+
from tico.passes.fuse_redundant_reshape_to_mean import FuseRedundantReshapeToMean
|
55
|
+
from tico.passes.legalize_causal_mask_value import LegalizeCausalMaskValue
|
56
|
+
from tico.passes.legalize_predefined_layout_operators import (
|
57
|
+
LegalizePreDefinedLayoutOperators,
|
58
|
+
)
|
59
|
+
from tico.passes.lower_pow2_to_mul import LowerPow2ToMul
|
60
|
+
from tico.passes.lower_to_resize_nearest_neighbor import LowerToResizeNearestNeighbor
|
61
|
+
from tico.passes.lower_to_slice import LowerToSlice
|
62
|
+
from tico.passes.merge_consecutive_cat import MergeConsecutiveCat
|
63
|
+
from tico.passes.remove_nop import RemoveNop
|
64
|
+
from tico.passes.remove_redundant_assert_nodes import RemoveRedundantAssertionNodes
|
65
|
+
from tico.passes.remove_redundant_expand import RemoveRedundantExpand
|
66
|
+
from tico.passes.remove_redundant_permute import passes as RemoveRedundantPermutePasses
|
67
|
+
from tico.passes.remove_redundant_reshape import passes as RemoveRedundantViewPasses
|
68
|
+
from tico.passes.remove_redundant_slice import RemoveRedundantSlice
|
69
|
+
from tico.passes.remove_redundant_to_copy import RemoveRedundantToCopy
|
70
|
+
from tico.passes.restore_linear import RestoreLinear
|
71
|
+
from tico.passes.segment_index_select import SegmentIndexSelectConst
|
72
|
+
from tico.serialize.circle_serializer import build_circle
|
73
|
+
from tico.serialize.operators.node_visitor import get_support_targets
|
74
|
+
from tico.utils import logging
|
75
|
+
from tico.utils.errors import NotYetSupportedError
|
76
|
+
from tico.utils.model import CircleModel
|
77
|
+
from tico.utils.passes import PassManager
|
78
|
+
from tico.utils.trace_decorators import (
|
79
|
+
trace_const_diff_on_func,
|
80
|
+
trace_graph_diff_on_func,
|
81
|
+
)
|
82
|
+
from tico.utils.utils import has_quantization_ops, SuppressWarning
|
83
|
+
|
84
|
+
|
85
|
+
@trace_const_diff_on_func
@trace_graph_diff_on_func
def traced_run_decompositions(exported_program: ExportedProgram):
    """
    Run `ExportedProgram.run_decompositions()` while preserving selected operators.

    `run_decompositions()` converts all Conv-related Ops to generic `aten.convolution`.
    Those would have to be re-converted to specific circle ops such as CircleConv2D,
    TransposeConv, etc., so Conv-related Ops (and the other ops listed below) are not
    decomposed and are converted directly to circle ops instead.

    The way ops are preserved depends on the installed PyTorch version.

    Args:
        exported_program: program to decompose.

    Returns:
        The ExportedProgram returned by `run_decompositions()`.

    Raises:
        RuntimeError: if `torch.__version__` does not start with 2.5, 2.6, 2.7 or 2.8.
    """

    def _decompose_keeping_ops(ep: ExportedProgram):
        # torch 2.5 path: ops to keep are handed over through `_preserve_ops`.
        kept_ops = (
            torch.ops.aten.conv2d.default,
            torch.ops.aten.conv2d.padding,
            torch.ops.aten.conv1d.default,
            torch.ops.aten.conv1d.padding,
            torch.ops.aten.instance_norm.default,
            torch.ops.aten._safe_softmax.default,
            torch.ops.aten.relu6.default,  # Do not decompose to hardtanh
            torch.ops.aten.linear.default,
        )
        return ep.run_decompositions(_preserve_ops=kept_ops)

    def _decompose_with_pruned_table(ep: ExportedProgram):
        # torch 2.6 - 2.8 path: ops to keep are removed from the default
        # decomposition table before it is applied.
        table = torch.export.default_decompositions()  # type: ignore[attr-defined]
        kept_ops = (
            torch.ops.aten.conv2d.default,
            torch.ops.aten.conv2d.padding,
            torch.ops.aten.conv1d.default,
            torch.ops.aten.conv1d.padding,
            torch.ops.aten.instance_norm.default,
            torch.ops.aten._safe_softmax.default,
            torch.ops.aten.relu6.default,  # Do not decompose to hardtanh
            torch.ops.aten.prelu.default,
            torch.ops.aten.linear.default,
        )
        for kept in kept_ops:
            if kept in table:
                del table[kept]

        return ep.run_decompositions(decomp_table=table)

    installed = torch.__version__
    if installed.startswith("2.5"):
        return _decompose_keeping_ops(exported_program)
    if installed.startswith(("2.6", "2.7", "2.8")):
        return _decompose_with_pruned_table(exported_program)
    raise RuntimeError(f"Unsupported PyTorch version: {torch.__version__}")
|
140
|
+
|
141
|
+
|
142
|
+
def check_unsupported_target(exported_program: ExportedProgram):
    """
    Verify that every `call_function` node of the graph has a supported target.

    A target is supported when it is among the targets returned by
    `get_support_targets()`, or when it is `operator.getitem`.
    Each unsupported node is reported to the error log together with the
    `stack_trace` entry of its meta before the exception is raised.

    Args:
        exported_program: program whose graph nodes are checked.

    Raises:
        NotYetSupportedError: if at least one unsupported node is found.
    """
    logger = logging.getLogger(__name__)

    # Ignore `getitem` since it is no-op for multiple outputs.
    # A set gives constant-time membership tests while walking the graph.
    supported_targets = {*get_support_targets(), operator.getitem}

    unsupported = [
        n
        for n in exported_program.graph.nodes
        if n.op == "call_function" and n.target not in supported_targets
    ]
    if not unsupported:
        return

    for node in unsupported:
        # Not every callable carries `__name__`; fall back to the target itself
        # so that reporting can never mask the NotYetSupportedError below.
        target_name = getattr(node.target, "__name__", node.target)
        logger.error(
            f"NOT SUPPORTED OPERATOR\n\t(op) {target_name}\n\t(trace) {node.meta.get('stack_trace')}"
        )
    raise NotYetSupportedError("NOT SUPPORTED OPERATOR IN GRAPH MODULE")
|
161
|
+
|
162
|
+
|
163
|
+
def convert_exported_module_to_circle(
    exported_program: ExportedProgram,
    config: Optional[CompileConfigBase] = None,
) -> bytes:
    """
    Convert an ExportedProgram to a circle program.

    The program goes through these stages, in order:
      1. decomposition of fake-quantize ops,
      2. `traced_run_decompositions()`,
      3. circle legalization passes that keep the ExportedProgram invariant,
      4. legalization passes that may break the invariant,
      5. quantization passes (only if the graph has quantization ops),
      6. unsupported-target check and `build_circle()`.

    Args:
        exported_program: program to convert (must be core aten).
        config: compile configuration. When omitted, `get_default_config()` is
            called for this invocation, so calls never share a default object.

    Returns:
        The result of `build_circle()`.

    Raises:
        NotYetSupportedError: raised by `check_unsupported_target()` when the
            legalized graph still contains an unsupported operator.
    """
    if config is None:
        config = get_default_config()

    logger = logging.getLogger(__name__)
    logger.debug("Input ExportedProgram (must be core aten)")
    logger.debug(exported_program)

    # PRE-EDGE PASSES
    #
    # Here are the passes that run before to_edge() conversion.
    # Let's decompose nodes that are not Aten Canonical, which can't be converted to the edge IR.
    decompose_quantize_op = PassManager(
        passes=[
            DecomposeFakeQuantize(),
            DecomposeFakeQuantizeTensorQParams(),
        ]
    )
    decompose_quantize_op.run(exported_program)

    # This pass should be run before 'RestoreLinear' and after 'decompose_quantize_op'.
    # TODO run pass regardless of the orders.
    with SuppressWarning(UserWarning, ".*quantize_per_tensor"), SuppressWarning(
        UserWarning,
        ".*TF32 acceleration on top of oneDNN is available for Intel GPUs.*",
    ):
        # Warning details:
        #   ...site-packages/torch/_subclasses/functional_tensor.py:364
        #   UserWarning: At pre-dispatch tracing, we assume that any custom op marked with
        #     CompositeImplicitAutograd and have functional schema are safe to not decompose.
        exported_program = traced_run_decompositions(exported_program)

    # TODO Distinguish legalize and optimize
    circle_legalize = PassManager(
        passes=[
            FillMetaVal(),
            ExtractDtypeKwargsPass(),
            RemoveNop(),
            ConvertLayoutOpToReshape(),
            RestoreLinear(),
            ConvertToReLU6(),
            DecomposeAddmm(),
            DecomposeSliceScatter(),
            DecomposeGroupNorm(),
            DecomposeBatchNorm(),
            DecomposeGroupedConv2d(),
            CastATenWhereArgType(),
            ConvertRepeatToExpandCopy(),
            *RemoveRedundantPermutePasses(),
            RemoveRedundantAssertionNodes(),
            RemoveRedundantExpand(),
            RemoveRedundantSlice(),
            FuseRedundantReshapeToMean(),
            *RemoveRedundantViewPasses(),
            RemoveRedundantToCopy(),
            MergeConsecutiveCat(),
            CastMixedTypeArgs(preserve_ep_invariant=True),
            ConstPropPass(),
            SegmentIndexSelectConst(),
            LegalizeCausalMaskValue(enabled=config.get("legalize_causal_mask_value")),
            LowerToResizeNearestNeighbor(),
            LegalizePreDefinedLayoutOperators(),
            LowerPow2ToMul(),
            ConvertConv1dToConv2d(),
            LowerToSlice(),
        ]
    )
    circle_legalize.run(exported_program)

    # After this stage, ExportedProgram invariant is broken, i.e.,
    # graph can have a constant torch.tensor not lifted to a placeholder
    invariant_breaking_legalize = PassManager(
        passes=[
            FillMetaVal(),
            CastMixedTypeArgs(preserve_ep_invariant=False),
        ]
    )
    invariant_breaking_legalize.run(exported_program)

    # TODO Give an option to enable quantization to user
    enable_quantization = has_quantization_ops(exported_program.graph)
    if enable_quantization:
        quantize_graph = PassManager(
            passes=[
                FoldQuantOps(),
                RemoveWeightDequantOp(),
                PropagateQParamForward(),
                PropagateQParamBackward(),
                InsertQuantizeOnDtypeMismatch(),
            ]
        )
        quantize_graph.run(exported_program)

    check_unsupported_target(exported_program)
    circle_program = build_circle(exported_program)

    return circle_program
|
260
|
+
|
261
|
+
|
262
|
+
def convert(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    strict: bool = True,
    config: Optional[CompileConfigBase] = None,
) -> CircleModel:
    """
    Export a torch module and convert it to a circle model.

    Args:
        mod: module to export.
        args: positional example inputs forwarded to `torch.export.export`.
        kwargs: keyword example inputs forwarded to `torch.export.export`.
        strict: forwarded to `torch.export.export` as `strict`.
        config: compile configuration. When omitted, `get_default_config()` is
            called for this invocation, so calls never share a default object.

    Returns:
        CircleModel wrapping the converted circle binary.
    """
    if config is None:
        config = get_default_config()

    # The export is done with gradient tracking disabled.
    with torch.no_grad():
        exported_program = export(mod, args, kwargs, strict=strict)

    circle_binary = convert_exported_module_to_circle(exported_program, config=config)

    return CircleModel(circle_binary)
|
275
|
+
|
276
|
+
|
277
|
+
def convert_from_exported_program(
    exported_program: ExportedProgram,
    config: Optional[CompileConfigBase] = None,
) -> CircleModel:
    """
    Convert an already exported program to a circle model.

    Args:
        exported_program: program to convert.
        config: compile configuration. When omitted, `get_default_config()` is
            called for this invocation, so calls never share a default object.

    Returns:
        CircleModel wrapping the converted circle binary.
    """
    if config is None:
        config = get_default_config()

    circle_binary = convert_exported_module_to_circle(exported_program, config=config)

    return CircleModel(circle_binary)
|
284
|
+
|
285
|
+
|
286
|
+
def convert_from_pt2(
    pt2_path: str | os.PathLike, config: Optional[CompileConfigBase] = None
) -> CircleModel:
    """
    Load an ExportedProgram with `torch.export.load` and convert it to a circle model.

    Args:
        pt2_path: path handed to `torch.export.load`.
        config: compile configuration. When omitted, `get_default_config()` is
            called for this invocation, so calls never share a default object.

    Returns:
        CircleModel wrapping the converted circle binary.
    """
    if config is None:
        config = get_default_config()

    exported_program = torch.export.load(pt2_path)
    circle_binary = convert_exported_module_to_circle(exported_program, config=config)

    return CircleModel(circle_binary)
|
tico/utils/define.py
ADDED
@@ -0,0 +1,35 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List
|
16
|
+
|
17
|
+
from circle_schema import circle
|
18
|
+
|
19
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
20
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
21
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
22
|
+
|
23
|
+
|
24
|
+
def define_pad_node(
    graph: CircleSubgraph, op_codes: Dict[OpCode, int], inputs: List, outputs: List
) -> circle.Operator.OperatorT:
    """
    Build a circle PAD operator and attach default-constructed PadOptions to it.

    Args:
        graph: subgraph handed to `create_builtin_operator`.
        op_codes: opcode table handed to `get_op_index` to resolve the PAD index.
        inputs: operator inputs, forwarded to `create_builtin_operator`.
        outputs: operator outputs, forwarded to `create_builtin_operator`.

    Returns:
        The operator with `builtinOptionsType` and `builtinOptions` set for PAD.
    """
    pad_opcode_index = get_op_index(
        circle.BuiltinOperator.BuiltinOperator.PAD, op_codes
    )
    pad_op = create_builtin_operator(graph, pad_opcode_index, inputs, outputs)

    # Mark the options union as PadOptions and store a fresh options table.
    pad_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.PadOptions
    pad_op.builtinOptions = circle.PadOptions.PadOptionsT()

    return pad_op
|
tico/utils/diff_graph.py
ADDED
@@ -0,0 +1,181 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from copy import deepcopy
|
16
|
+
from difflib import ndiff
|
17
|
+
from functools import reduce
|
18
|
+
from logging import DEBUG
|
19
|
+
from typing import Optional, TYPE_CHECKING
|
20
|
+
|
21
|
+
if TYPE_CHECKING:
|
22
|
+
import torch.fx
|
23
|
+
import torch
|
24
|
+
|
25
|
+
from tico.utils.logging import getLogger, LOG_LEVEL
|
26
|
+
|
27
|
+
|
28
|
+
def strdiff(a: str, b: str):
    """
    Get difference in two strings as if linux `diff` command does

    The inputs are compared line by line with `difflib.ndiff`; only the lines
    reported as removed ("-") or added ("+") are kept.

    Every reported line is newline-terminated. Without that, a last line that
    has no trailing newline would be glued to the next reported line
    (e.g. "- x+ y").

    Args:
        a (str): original text
        b (str): changed text

    Returns:
        str: concatenation of the "-"/"+" lines, or "" if nothing differs
    """
    assert isinstance(a, str), f"{a} must be str, type: {type(a)}"
    assert isinstance(b, str), f"{b} must be str, type: {type(b)}"

    delta = ndiff(a.splitlines(keepends=True), b.splitlines(keepends=True))
    changed = [
        line if line.endswith("\n") else f"{line}\n"
        for line in delta
        if line.startswith(("-", "+"))
    ]
    return "".join(changed)
|
40
|
+
|
41
|
+
|
42
|
+
def disable_when(predicate):
    """
    Disable function only if predicate is true

    `predicate` is a plain value (not a callable); its truthiness is checked
    when the decorator is applied to a function.

    - If it is truthy, the decorated function is replaced with a no-op that
      accepts any arguments and returns None. The no-op keeps the metadata of
      the decorated function (name, docstring, ...).
    - Otherwise the decorated function is returned unchanged.
    """
    from functools import wraps

    def _inner_disable_when(func):
        if not predicate:
            return func

        @wraps(func)
        def nop(*args, **kwargs):
            pass

        return nop

    return _inner_disable_when
|
58
|
+
|
59
|
+
|
60
|
+
# Log level compared against LOG_LEVEL by the `disable_when(...)` decorators
# in this module.
LOGGER_THRESHOLD = DEBUG
# Start-point graph snapshot written by `capture` and `log`; None when nothing
# is captured.
graph_captured: Optional[str | torch.fx.Graph] = None
# Start-point constant size written by `capture_const` and `log_const`; None
# until something is captured.
const_size_captured: Optional[int] = None
|
63
|
+
|
64
|
+
|
65
|
+
def get_const_size(ep: torch.export.ExportedProgram) -> int:
|
66
|
+
"""
|
67
|
+
Return const tensor's size in **byte**
|
68
|
+
"""
|
69
|
+
|
70
|
+
def const_size(items):
|
71
|
+
const_sum = 0
|
72
|
+
for _, tensor in items:
|
73
|
+
if len(tensor.size()) == 0:
|
74
|
+
# scalar tensor
|
75
|
+
const_sum += tensor.dtype.itemsize
|
76
|
+
else:
|
77
|
+
const_sum += (
|
78
|
+
reduce(lambda x, y: x * y, list(tensor.size()))
|
79
|
+
* tensor.dtype.itemsize
|
80
|
+
)
|
81
|
+
return const_sum
|
82
|
+
|
83
|
+
constant_tensor_sum = 0
|
84
|
+
|
85
|
+
constant_tensor_sum += const_size(ep.state_dict.items())
|
86
|
+
constant_tensor_sum += const_size(ep.constants.items())
|
87
|
+
|
88
|
+
return constant_tensor_sum
|
89
|
+
|
90
|
+
|
91
|
+
@disable_when(LOG_LEVEL > LOGGER_THRESHOLD)
def capture_const(ep: torch.export.ExportedProgram):
    """
    Capture the start-point constant size for const-diff.

    The value of `get_const_size(ep)` is stored in the module-level
    `const_size_captured`.

    Args:
        ep (torch.export.ExportedProgram): program whose constant size is captured
    """
    global const_size_captured

    assert isinstance(ep, torch.export.ExportedProgram)
    const_size_captured = get_const_size(ep)
|
97
|
+
|
98
|
+
|
99
|
+
@disable_when(LOG_LEVEL > LOGGER_THRESHOLD)
def log_const(ep: torch.export.ExportedProgram, title: str, recapture: bool):
    """
    Log how the total constant size changed since the captured start point.

    Nothing is logged when the size did not change. Otherwise two debug lines
    are written: the size difference in bytes, and the old/new sizes with a
    percentage computed as `diff / (new + old) * 100`.

    Args:
        ep (torch.export.ExportedProgram): program to measure
        title (str): Title in log
        recapture (bool): store the new size as the next start point
    """
    global const_size_captured

    assert isinstance(ep, torch.export.ExportedProgram)
    assert const_size_captured is not None

    const_size = get_const_size(ep)
    const_size_diff = const_size - const_size_captured

    # print differences
    if const_size_diff:
        logger = getLogger(__name__)
        prefix = f"[{title}]" if title else ""

        if const_size_diff > 0:
            const_size_inc_dec = "has changed (increased)"
        else:
            const_size_inc_dec = "has changed (decreased)"

        size_total = const_size + const_size_captured
        if size_total == 0:
            percentage_avg_str = "N/A"
        else:
            percentage_avg = float(const_size_diff) / float(size_total) * 100
            # A negative value already carries its own sign.
            sign = "+" if percentage_avg > 0 else ""
            percentage_avg_str = f"{sign}{percentage_avg:.2f}%"

        logger.debug(
            f"{prefix} Total const size {const_size_inc_dec} by {const_size_diff} Bytes"
        )
        logger.debug(f"{const_size_captured}B -> {const_size}B ({percentage_avg_str})")

    if recapture:
        const_size_captured = const_size
|
138
|
+
|
139
|
+
|
140
|
+
@disable_when(LOG_LEVEL > LOGGER_THRESHOLD)
def capture(graph: torch.fx.Graph):
    """
    Capture the start-point graph for graph-diff.

    The string form of the graph is stored in the module-level
    `graph_captured`.

    Args:
        graph (torch.fx.Graph): graph to capture
    """
    global graph_captured

    assert isinstance(graph, torch.fx.Graph)
    graph_captured = str(graph)
|
152
|
+
|
153
|
+
|
154
|
+
@disable_when(LOG_LEVEL > LOGGER_THRESHOLD)
def log(graph: torch.fx.Graph, title: str, recapture: bool):
    """
    Capture the end-point graph for graph-diff.
    String diff lines will be printed to debug logger if enabled.

    Args:
        graph (torch.fx.Graph): graph to capture
        title (str): Title in log
        recapture (bool): recapture the graph
    """
    global graph_captured

    assert isinstance(graph, torch.fx.Graph)

    logger = getLogger(__name__)
    graph_text = str(graph)
    diff = strdiff(f"{graph_captured}\n", f"{graph_text}\n")
    prefix = f"[{title}]" if title else ""
    if len(diff) > 0:
        logger.debug(f"{prefix} Graph is changed.")
        logger.debug(f"\n{diff}")

    # Keep the snapshot as text, which is all the next diff needs;
    # reset it when no recapture is requested.
    graph_captured = graph_text if recapture else None
|
179
|
+
|
180
|
+
|
181
|
+
# TODO diff graph signature
|
tico/utils/errors.py
ADDED
@@ -0,0 +1,35 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
|
16
|
+
class CircleExirError(Exception):
    """Base class for custom exceptions in project"""
|
20
|
+
|
21
|
+
|
22
|
+
class NotYetSupportedError(CircleExirError):
    """
    Not yet supported feature or functionality
    """
|
28
|
+
|
29
|
+
|
30
|
+
class InvalidArgumentError(CircleExirError):
    """
    Invalid argument, which is never allowed
    """
|