tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +9 -1
- tico/config/base.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_aten_where_arg_type.py +1 -1
- tico/passes/cast_clamp_mixed_type_args.py +169 -0
- tico/passes/cast_mixed_type_args.py +4 -2
- tico/passes/const_prop_pass.py +1 -1
- tico/passes/convert_conv1d_to_conv2d.py +1 -1
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_addmm.py +0 -3
- tico/passes/decompose_batch_norm.py +2 -2
- tico/passes/decompose_fake_quantize.py +0 -3
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
- tico/passes/decompose_group_norm.py +0 -3
- tico/passes/legalize_predefined_layout_operators.py +2 -11
- tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
- tico/passes/lower_to_slice.py +1 -1
- tico/passes/merge_consecutive_cat.py +1 -1
- tico/passes/ops.py +1 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -6
- tico/passes/remove_redundant_reshape.py +5 -5
- tico/passes/segment_index_select.py +1 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
- tico/quantization/algorithm/gptq/quantizer.py +292 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
- tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
- tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
- tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
- tico/{experimental/quantization → quantization}/public_interface.py +19 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +125 -0
- tico/serialize/circle_graph.py +12 -4
- tico/serialize/circle_mapping.py +76 -2
- tico/serialize/circle_serializer.py +253 -148
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_any.py +7 -14
- tico/serialize/operators/op_avg_pool2d.py +11 -4
- tico/serialize/operators/op_clamp.py +5 -7
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_conv2d.py +14 -6
- tico/serialize/operators/op_copy.py +26 -3
- tico/serialize/operators/op_cumsum.py +3 -1
- tico/serialize/operators/op_depthwise_conv2d.py +17 -7
- tico/serialize/operators/op_full_like.py +0 -2
- tico/serialize/operators/op_index_select.py +8 -1
- tico/serialize/operators/op_instance_norm.py +0 -6
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_log1p.py +3 -2
- tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
- tico/serialize/operators/op_mm.py +15 -131
- tico/serialize/operators/op_mul.py +2 -8
- tico/serialize/operators/op_pow.py +3 -1
- tico/serialize/operators/op_repeat.py +12 -3
- tico/serialize/operators/op_reshape.py +1 -1
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/serialize/operators/op_softmax.py +7 -14
- tico/serialize/operators/op_split_with_sizes.py +16 -8
- tico/serialize/operators/op_transpose_conv.py +11 -8
- tico/serialize/operators/op_view.py +2 -1
- tico/serialize/quant_param.py +5 -5
- tico/utils/convert.py +30 -17
- tico/utils/dtype.py +42 -0
- tico/utils/graph.py +1 -1
- tico/utils/model.py +2 -1
- tico/utils/padding.py +2 -2
- tico/utils/pytree_utils.py +134 -0
- tico/utils/record_input.py +102 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/serialize.py +16 -3
- tico/utils/signature.py +247 -0
- tico/utils/torch_compat.py +52 -0
- tico/utils/utils.py +50 -58
- tico/utils/validate_args_kwargs.py +38 -3
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
- tico-0.1.0.dev251102.dist-info/RECORD +271 -0
- tico/experimental/quantization/__init__.py +0 -1
- tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
- tico-0.1.0.dev250714.dist-info/RECORD +0 -209
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
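Most of the renames above move the quantization stack out of `tico/experimental/`, so imports need the shorter path in this release. A hedged sketch of the migration (module paths are taken from the rename list above; the specific names imported from them are not confirmed here):

```python
# 0.1.0.dev250714 and earlier
# from tico.experimental.quantization import quantizer, public_interface

# 0.1.0.dev251102 and later
from tico.quantization import quantizer, public_interface
```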
tico/__init__.py
CHANGED
@@ -20,8 +20,16 @@ from packaging.version import Version
 from tico.config import CompileConfigV1, get_default_config
 from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
 
+__all__ = [
+    "CompileConfigV1",
+    "get_default_config",
+    "convert",
+    "convert_from_exported_program",
+    "convert_from_pt2",
+]
+
 # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
-__version__ = "0.1.0.dev250714"
+__version__ = "0.1.0.dev251102"
 
 MINIMUM_SUPPORTED_VERSION = "2.5.0"
 SECURE_TORCH_VERSION = "2.6.0"
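The `__all__` list makes the package's public surface explicit. A minimal usage sketch of that surface (the model, example inputs, and output path below are placeholders, and the return type of `convert` is assumed, not taken from this diff):

```python
import torch
import tico

# Placeholder model and inputs; any torch.nn.Module follows the same flow.
model = torch.nn.Linear(4, 2).eval()
example_inputs = (torch.randn(1, 4),)

# tico.convert is re-exported at the package root per the __all__ above;
# the returned object is assumed to be a Circle model handle with save().
circle_model = tico.convert(model, example_inputs)
circle_model.save("linear.circle")
```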
tico/config/base.py
CHANGED
tico/config/v1.py
CHANGED
@@ -20,6 +20,11 @@ from tico.config.base import CompileConfigBase
 @dataclass
 class CompileConfigV1(CompileConfigBase):
     legalize_causal_mask_value: bool = False
+    remove_constant_input: bool = False
+    convert_lhs_const_mm_to_fc: bool = False
+    convert_rhs_const_mm_to_fc: bool = True
+    convert_single_batch_lhs_const_bmm_to_fc: bool = False
+    convert_expand_to_slice_cat: bool = False
 
     def get(self, name: str):
         return super().get(name)
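The new fields gate optional graph rewrites added in this release (for example, `convert_expand_to_slice_cat` enables the `ConvertExpandToSliceCat` pass shown further down). A minimal sketch of flipping one of them, assuming `get_default_config()` returns a `CompileConfigV1` instance:

```python
from tico import get_default_config

# Start from the defaults and opt in to the expand -> slice/cat rewrite.
config = get_default_config()
config.convert_expand_to_slice_cat = True

# get() is the accessor defined on CompileConfigV1 in this diff.
assert config.get("convert_expand_to_slice_cat") is True
```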
tico/passes/cast_aten_where_arg_type.py
CHANGED
@@ -176,7 +176,7 @@ class CastATenWhereArgType(PassBase):
         node_dtype = extract_torch_dtype(node)
         assert (
             node_dtype == node_dtype_ori
-        ),
+        ), "Type casting doesn't change node's dtype."
 
         logger.debug(
             f"{to_cast.name}'s dtype was casted from {buf_data.dtype} to {dtype_to_cast}"
tico/passes/cast_clamp_mixed_type_args.py
ADDED
@@ -0,0 +1,169 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch.export import ExportedProgram
+
+from tico.passes import ops
+
+from tico.serialize.circle_mapping import extract_torch_dtype
+from tico.utils import logging
+from tico.utils.graph import create_node
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.utils import is_target_node, set_new_meta_val
+from tico.utils.validate_args_kwargs import ClampArgs
+
+
+@trace_graph_diff_on_pass
+class CastClampMixedTypeArgs(PassBase):
+    """
+    This pass ensures consistent dtypes for clamp operations by:
+    1. Converting min/max arguments to match output dtype when provided
+    2. Inserting cast operations when input dtype differs from output dtype
+
+    Behavior Examples:
+    - When input dtype differs from output:
+        Inserts _to_copy operation to convert input
+    - When min/max dtype differs from output:
+        Converts min/max values to output dtype
+
+    (Case 1, if input dtype is different from output dtype)
+    [before]
+
+        input          min(or max)
+     (dtype=int)      (dtype=float)
+          |                 |
+        clamp <-------------+
+          |
+        output
+     (dtype=float)
+
+    [after]
+
+        input          min(or max)
+     (dtype=int)      (dtype=float)
+          |                 |
+         cast               |
+    (in=int, out=float)     |
+          |                 |
+        clamp <-------------+
+          |
+        output
+     (dtype=float)
+
+    (Case 2, if min(or max) dtype is different from output dtype)
+    [before]
+
+        input          min(or max)
+     (dtype=float)     (dtype=int)
+          |                 |
+        clamp <-------------+
+          |
+        output
+     (dtype=float)
+
+    [after]
+
+        input          min(or max)
+     (dtype=float)    (dtype=float)
+          |                 |
+        clamp <-------------+
+          |
+        output
+     (dtype=float)
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> bool:
+        logger = logging.getLogger(__name__)
+        modified = False
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+
+        # clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
+        args = ClampArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        input = args.input
+        min = args.min
+        max = args.max
+
+        input_dtype = extract_torch_dtype(input)
+        output_dtype = extract_torch_dtype(node)
+
+        def _convert_arg(arg, arg_name: str):
+            if arg is None:
+                return False
+
+            arg_dtype = torch.tensor(arg).dtype
+            arg_idx = node.args.index(arg)
+            if arg_dtype != output_dtype:
+                assert output_dtype in [torch.float, torch.int]
+                if output_dtype == torch.float:
+                    arg = float(arg)
+                else:
+                    arg = int(arg)
+                node.update_arg(arg_idx, arg)
+                logger.debug(
+                    f"Casting {arg_name} value from {arg_dtype} to {output_dtype} for clamp operation at {node.name}"
+                )
+                return True
+            return False
+
+        modified |= _convert_arg(min, "min")
+        modified |= _convert_arg(max, "max")
+
+        if input_dtype != output_dtype:
+            logger.debug(
+                f"Inserting cast from {input_dtype} to {output_dtype} for input {input.name}"
+            )
+            with graph.inserting_after(input):
+                to_copy = create_node(
+                    graph,
+                    torch.ops.aten._to_copy.default,
+                    (input,),
+                    {"dtype": output_dtype},
+                    origin=input,
+                )
+                set_new_meta_val(to_copy)
+                node.update_arg(node.args.index(input), to_copy)
+
+            modified = True
+
+        return modified
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        target_op = ops.aten.clamp
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        modified = False
+        for node in graph.nodes:
+            if not is_target_node(node, target_op):
+                continue
+
+            modified |= self.convert(exported_program, node)
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        return PassResult(modified)
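In eager terms, the rewrite performed by this pass amounts to the following sketch (plain PyTorch, not the pass itself; tensors and dtypes are placeholders):

```python
import torch

x_int = torch.tensor([1, 5, 9])        # input dtype: int64
out_dtype = torch.float32              # dtype expected at the clamp output

# Case 1: input dtype differs from output -> cast the input first
# (the graph-level equivalent is the inserted aten._to_copy with dtype=float).
x_cast = x_int.to(out_dtype)

# Case 2: min/max dtype differs from output -> convert the scalar literals.
min_val, max_val = 2, 8                # int literals in the original graph
min_val, max_val = float(min_val), float(max_val)

y = torch.clamp(x_cast, min=min_val, max=max_val)
assert y.dtype == out_dtype            # every operand now matches the output dtype
```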
tico/passes/cast_mixed_type_args.py
CHANGED
@@ -41,6 +41,8 @@ ops_to_promote = {
     torch.ops.aten.ge.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     torch.ops.aten.gt.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     torch.ops.aten.gt.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+    torch.ops.aten.le.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+    torch.ops.aten.le.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     torch.ops.aten.mul.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     torch.ops.aten.minimum.default: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     torch.ops.aten.ne.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
@@ -124,7 +126,7 @@ class CastMixedTypeArgs(PassBase):
         if rhs_val.dtype == type_to_promote:
             ori_type = lhs_val.dtype
             arg_to_promote = lhs
-            assert arg_to_promote
+            assert arg_to_promote is not None
 
         if isinstance(arg_to_promote, torch.fx.Node):
             with graph.inserting_after(arg_to_promote):
@@ -178,7 +180,7 @@ class CastMixedTypeArgs(PassBase):
         node_dtype = extract_torch_dtype(node)
         assert (
             node_dtype == node_dtype_ori
-        ),
+        ), "Type casting doesn't change node's dtype."
 
         graph.eliminate_dead_code()
         graph.lint()
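The two new `aten.le` entries reuse the DEFAULT elementwise promotion kind, i.e. a mixed int/float comparison is computed after promoting the integer operand. A quick eager-mode illustration of that promotion rule (not the pass itself):

```python
import torch

ints = torch.tensor([1, 2, 3])                     # int64
threshold = torch.tensor(1.5)                      # float32

# DEFAULT promotion: the int operand is promoted to float before comparing.
promoted = ints.to(torch.promote_types(ints.dtype, threshold.dtype))

assert torch.equal(torch.le(ints, threshold), torch.le(promoted, threshold))
print(torch.le(ints, threshold))                   # tensor([ True, False, False])
```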
tico/passes/const_prop_pass.py
CHANGED
@@ -301,7 +301,7 @@ class ConstPropPass(PassBase):
         graph.eliminate_dead_code()
         graph_module.recompile()
 
-        logger.debug(
+        logger.debug("Constant nodes are propagated")
 
         # Constant folding can be done with only one time run. Let's set `modified` to False.
         modified = False
         return PassResult(modified)
tico/passes/convert_conv1d_to_conv2d.py
CHANGED
@@ -19,7 +19,7 @@ if TYPE_CHECKING:
     import torch
     from torch.export import ExportedProgram
 
-from tico.serialize.
+from tico.serialize.circle_mapping import extract_shape
 from tico.utils import logging
 from tico.utils.errors import NotYetSupportedError
 from tico.utils.graph import create_node
tico/passes/convert_expand_to_slice_cat.py
ADDED
@@ -0,0 +1,153 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch.export import ExportedProgram
+
+from tico.passes import ops
+from tico.serialize.circle_mapping import extract_shape
+from tico.utils import logging
+from tico.utils.graph import create_node
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.utils import is_target_node
+from tico.utils.validate_args_kwargs import ExpandArgs, ReshapeArgs
+
+
+@trace_graph_diff_on_pass
+class ConvertExpandToSliceCat(PassBase):
+    """
+    This pass replaces `aten.reshape` + `aten.expand` pattern by rewriting it using
+    a series of `aten.slice` and `aten.cat` operations.
+
+    This pass is specialized for expand of KVCache.
+    - Expects (batch, num_key_value_heads, seq_len, head_dim) as input shape of reshape
+    """
+
+    def __init__(self, enabled: bool = False):
+        super().__init__()
+        self.enabled = enabled
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        if not self.enabled:
+            return PassResult(False)
+
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        modified = False
+
+        # This pass handles expand on EXPAND_DIM only
+        CAT_DIM = 1
+        EXPAND_DIM = 2
+
+        for node in graph.nodes:
+            if not isinstance(node, torch.fx.Node) or not is_target_node(
+                node, ops.aten.reshape
+            ):
+                continue
+
+            post_reshape = node
+            post_reshape_args = ReshapeArgs(*post_reshape.args, **post_reshape.kwargs)
+            post_reshape_input = post_reshape_args.input
+
+            if not isinstance(post_reshape_input, torch.fx.Node) or not is_target_node(
+                post_reshape_input, ops.aten.expand
+            ):
+                continue
+
+            expand = post_reshape_input
+            expand_args = ExpandArgs(*expand.args, **expand.kwargs)
+            expand_input = expand_args.input
+            expand_shape = extract_shape(expand)
+
+            if not isinstance(expand_input, torch.fx.Node) or not is_target_node(
+                expand_input, ops.aten.reshape
+            ):
+                continue
+
+            pre_reshape = expand_input
+            pre_reshape_args = ReshapeArgs(*pre_reshape.args, **pre_reshape.kwargs)
+            pre_reshape_input = pre_reshape_args.input
+            pre_reshape_shape = extract_shape(pre_reshape)
+
+            if pre_reshape_shape[EXPAND_DIM] != 1:
+                continue
+
+            reshape_input_shape = extract_shape(pre_reshape_input)
+
+            if len(expand_shape) != len(pre_reshape_shape):
+                continue
+
+            # Ensure all dimensions *except* at EXPAND_DIM are identical.
+            if not (
+                expand_shape[:EXPAND_DIM] == pre_reshape_shape[:EXPAND_DIM]
+                and expand_shape[EXPAND_DIM + 1 :]
+                == pre_reshape_shape[EXPAND_DIM + 1 :]
+            ):
+                continue
+
+            # Ensure the expansion dimension is a clean multiple.
+            if expand_shape[EXPAND_DIM] % pre_reshape_shape[EXPAND_DIM] != 0:
+                continue
+
+            expand_ratio = expand_shape[EXPAND_DIM] // pre_reshape_shape[EXPAND_DIM]
+
+            if expand_ratio <= 1:
+                continue
+
+            cat_nodes = []
+
+            for i in range(reshape_input_shape[CAT_DIM]):
+                with graph.inserting_before(expand):
+                    slice_copy_args = (pre_reshape_input, CAT_DIM, i, i + 1, 1)
+                    slice_node = create_node(
+                        graph,
+                        torch.ops.aten.slice.Tensor,
+                        args=slice_copy_args,
+                        origin=expand,
+                    )
+                with graph.inserting_after(slice_node):
+                    cat_args = ([slice_node] * expand_ratio, CAT_DIM)
+                    cat_node = create_node(
+                        graph,
+                        torch.ops.aten.cat.default,
+                        args=cat_args,
+                        origin=expand,
+                    )
+                    cat_nodes.append(cat_node)
+
+            with graph.inserting_after(expand):
+                cat_args = (cat_nodes, CAT_DIM)
+                cat_node = create_node(
+                    graph,
+                    torch.ops.aten.cat.default,
+                    args=cat_args,
+                    origin=expand,
+                )
+                expand.replace_all_uses_with(cat_node)
+
+            modified = True
+            logger.debug(f"{expand.name} is replaced with {cat_node.name} operators")
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        return PassResult(modified)