tico 0.1.0.dev250609__py3-none-any.whl → 0.1.0.dev250610__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/passes/fuse_leading_unsqueeze_reshape.py +110 -0
- tico/utils/convert.py +2 -0
- tico/utils/utils.py +28 -0
- {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250610.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250610.dist-info}/RECORD +10 -9
- {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250610.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250610.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250610.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250610.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
|
|
21
21
|
from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
|
22
22
|
|
23
23
|
# THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
|
24
|
-
__version__ = "0.1.0.
|
24
|
+
__version__ = "0.1.0.dev250610"
|
25
25
|
|
26
26
|
MINIMUM_SUPPORTED_VERSION = "2.5.0"
|
27
27
|
SECURE_TORCH_VERSION = "2.6.0"
|
@@ -0,0 +1,110 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Sequence
|
16
|
+
|
17
|
+
import torch
|
18
|
+
from torch.export import ExportedProgram
|
19
|
+
|
20
|
+
from tico.passes import ops
|
21
|
+
from tico.serialize.circle_mapping import extract_shape
|
22
|
+
from tico.utils import logging
|
23
|
+
from tico.utils.passes import PassBase, PassResult
|
24
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
25
|
+
from tico.utils.utils import is_single_use_target_node
|
26
|
+
from tico.utils.validate_args_kwargs import PermuteArgs, ReshapeArgs
|
27
|
+
|
28
|
+
|
29
|
+
def _is_leading_unsqueeze(target: Sequence[int], permuted: Sequence[int]) -> bool:
|
30
|
+
"""
|
31
|
+
True if `target` == [1]*k + permuted, k>=1.
|
32
|
+
"""
|
33
|
+
k = len(target) - len(permuted)
|
34
|
+
return (
|
35
|
+
k > 0 and all(d == 1 for d in target[:k]) and list(target[k:]) == list(permuted)
|
36
|
+
)
|
37
|
+
|
38
|
+
|
39
|
+
@trace_graph_diff_on_pass
class FuseLeadingUnsqueezeReshape(PassBase):
    """
    Fuse reshape → permute → reshape where the second reshape only
    prepends one-sized dims (unsqueeze) to the permuted tensor.

    [BEFORE]
      x - aten.reshape(s1) - aten.permute(p) - aten.reshape([1]*k + p(s1))
    [AFTER]
      x - aten.reshape([1]*k + s1) - aten.permute(list(range(k)) + [d+k for d in p])

    The rewrite only fires when the intermediate permute and the front
    reshape each have exactly one user. The replaced nodes are therefore
    dead afterwards and are removed by dead-code elimination.
    """

    def call(self, ep: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        gm = ep.graph_module
        graph = gm.graph
        modified = False
        for reshape_back in graph.nodes:
            if (
                reshape_back.op != "call_function"
                or reshape_back.target not in ops.aten.reshape
            ):
                continue
            reshape_back_args = ReshapeArgs(*reshape_back.args, **reshape_back.kwargs)  # type: ignore[arg-type]
            permute = reshape_back_args.input

            if not is_single_use_target_node(permute, ops.aten.permute):
                continue
            permute_args = PermuteArgs(*permute.args, **permute.kwargs)  # type: ignore[arg-type]
            reshape_front, permute_dims = permute_args.input, permute_args.dims

            if not is_single_use_target_node(reshape_front, ops.aten.reshape):
                continue
            reshape_front_args = ReshapeArgs(*reshape_front.args, **reshape_front.kwargs)  # type: ignore[arg-type]
            reshape_front_input, reshape_front_size = (
                reshape_front_args.input,
                reshape_front_args.size,
            )

            # ---- condition: only leading unsqueeze ------------------------
            back_shape = extract_shape(reshape_back)
            permute_shape = extract_shape(permute)

            if not _is_leading_unsqueeze(back_shape, permute_shape):
                continue

            # ---- create new reshape & new permute -------------------------
            k = len(back_shape) - len(permute_shape)
            # Rank of the permute input; used to normalize negative dims
            # (e.g. -1) before shifting them past the k new leading axes.
            # Shifting a raw negative dim would collide with the new axes.
            rank = len(permute_dims)
            with graph.inserting_before(permute):
                new_shape = [1] * k + list(reshape_front_size)
                r_new = graph.call_function(
                    torch.ops.aten.reshape.default,
                    args=(reshape_front_input, new_shape),
                )
                # Give the new reshape a meta value so later passes and the
                # serializer can read its shape. It is derived from the front
                # reshape's fake value (same data, k extra leading 1-dims).
                # The concrete `.shape` is used rather than `new_shape`
                # because the latter may contain -1 or graph nodes.
                front_val = reshape_front.meta.get("val")
                if isinstance(front_val, torch.Tensor):
                    r_new.meta["val"] = front_val.reshape(
                        [1] * k + list(front_val.shape)
                    )
                new_p_dims = list(range(k)) + [
                    d % rank + k for d in permute_dims
                ]  # normalize, then shift by k
                p_new = graph.call_function(
                    torch.ops.aten.permute.default, args=(r_new, new_p_dims)
                )

            # p_new inherits reshape_back's meta (its output shape is identical).
            reshape_back.replace_all_uses_with(p_new, propagate_meta=True)
            modified = True
            logger.debug("%s is fused to %s", reshape_back.name, r_new.name)

        if modified:
            graph.eliminate_dead_code()
            graph.lint()
            gm.recompile()

        return PassResult(modified)
|
tico/utils/convert.py
CHANGED
@@ -51,6 +51,7 @@ from tico.passes.decompose_grouped_conv2d import DecomposeGroupedConv2d
|
|
51
51
|
from tico.passes.decompose_slice_scatter import DecomposeSliceScatter
|
52
52
|
from tico.passes.extract_dtype_kwargs import ExtractDtypeKwargsPass
|
53
53
|
from tico.passes.fill_meta_val import FillMetaVal
|
54
|
+
from tico.passes.fuse_leading_unsqueeze_reshape import FuseLeadingUnsqueezeReshape
|
54
55
|
from tico.passes.fuse_redundant_reshape_to_mean import FuseRedundantReshapeToMean
|
55
56
|
from tico.passes.legalize_causal_mask_value import LegalizeCausalMaskValue
|
56
57
|
from tico.passes.legalize_predefined_layout_operators import (
|
@@ -225,6 +226,7 @@ def convert_exported_module_to_circle(
|
|
225
226
|
LowerPow2ToMul(),
|
226
227
|
ConvertConv1dToConv2d(),
|
227
228
|
*LowerToSlicePasses(),
|
229
|
+
FuseLeadingUnsqueezeReshape(),
|
228
230
|
]
|
229
231
|
)
|
230
232
|
circle_legalize.run(exported_program)
|
tico/utils/utils.py
CHANGED
@@ -378,3 +378,31 @@ def broadcastable(
|
|
378
378
|
if dim_a != 1 and dim_b != 1 and dim_a != dim_b:
|
379
379
|
return False
|
380
380
|
return True
|
381
|
+
|
382
|
+
|
383
|
+
def is_single_use_target_node(
|
384
|
+
node: torch.fx.Node, target_ops: list[torch._ops.OpOverload] | torch._ops.OpOverload
|
385
|
+
):
|
386
|
+
"""
|
387
|
+
Check whether a given node is a `call_function` node that matches one of the specified targets
|
388
|
+
and is used by only one other node.
|
389
|
+
|
390
|
+
Args:
|
391
|
+
node (torch.fx.Node): The node to check.
|
392
|
+
target_ops (Iterable[Callable]): A list or set of target operations to match (e.g., ops.aten.reshape).
|
393
|
+
|
394
|
+
Returns:
|
395
|
+
bool: True if the node is a call_function, its target is in `target_ops`, and it has exactly one user.
|
396
|
+
"""
|
397
|
+
if not isinstance(target_ops, list):
|
398
|
+
target_ops = [target_ops]
|
399
|
+
assert all(isinstance(t, torch._ops.OpOverload) for t in target_ops), target_ops
|
400
|
+
|
401
|
+
if node.op != "call_function":
|
402
|
+
return False
|
403
|
+
if node.target not in target_ops:
|
404
|
+
return False
|
405
|
+
if len(node.users) != 1:
|
406
|
+
return False
|
407
|
+
|
408
|
+
return True
|
@@ -1,4 +1,4 @@
|
|
1
|
-
tico/__init__.py,sha256=
|
1
|
+
tico/__init__.py,sha256=PHlL6DvW31YH0rgvY-y2LlEJDbNJSK7FO1KX1O5BGMM,1743
|
2
2
|
tico/pt2_to_circle.py,sha256=PPmFNw20jw2Z2VyM3ln9pX__jTzBOAZiv0gT5a-p-Y8,2666
|
3
3
|
tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
|
4
4
|
tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
|
@@ -75,6 +75,7 @@ tico/passes/decompose_grouped_conv2d.py,sha256=KJhH6PX7l9k9T8KBV8JDAvaSfJuUnRo_j
|
|
75
75
|
tico/passes/decompose_slice_scatter.py,sha256=ko9p8v-zY5rOx4aSpWomwSdSWb1lIF32gnU7ik5xgII,5604
|
76
76
|
tico/passes/extract_dtype_kwargs.py,sha256=hfGJ_GfZULbBmLif2AJkhPHVifhucxBiLoQI862Yejk,4303
|
77
77
|
tico/passes/fill_meta_val.py,sha256=Xbam6Aq90ZfWItZw1dgLIwH_q8RCiU5JodKNqkj-ink,1797
|
78
|
+
tico/passes/fuse_leading_unsqueeze_reshape.py,sha256=1hqX9urmV4kQXGWCEJ22XNH_3M2Og5v09N2wwV87FoQ,4304
|
78
79
|
tico/passes/fuse_redundant_reshape_to_mean.py,sha256=SzHLLL2Yfj4k1c2L5i4PVI9EFUilHRfIS6S-hahKFCM,3702
|
79
80
|
tico/passes/legalize_causal_mask_value.py,sha256=KZc_UPk7CGPXO35JOu6dVrOzRJx-ZpJoVvuzz-GvQek,4080
|
80
81
|
tico/passes/legalize_predefined_layout_operators.py,sha256=N2TtJInjSTk-E5afnkDXXbo9v4zTM7yzsjna3VoihMw,15895
|
@@ -177,7 +178,7 @@ tico/serialize/operators/op_view.py,sha256=5EMww-ve17Vm9XPuV03Tn7vJsjpU2J8U4d_FO
|
|
177
178
|
tico/serialize/operators/op_where.py,sha256=qZDFKVQ2u8HB0Jjpto_1A4g-jfpEprAbBT3PmIGzdoY,2855
|
178
179
|
tico/serialize/operators/utils.py,sha256=lXGpEJW1h8U_-gfc6EWjvvSiq3yJ9P-v1v3EMRT_pSk,2954
|
179
180
|
tico/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
180
|
-
tico/utils/convert.py,sha256=
|
181
|
+
tico/utils/convert.py,sha256=nbpBJue9_ezb8Dj9L9dWcNeVlZd0soGJu_fvrL6-c5c,11810
|
181
182
|
tico/utils/define.py,sha256=Ypgp7YffM4pgPl4Zh6TmogSn1OxGBMRw_e09qYGflZk,1467
|
182
183
|
tico/utils/diff_graph.py,sha256=_eDGGPDPYQD4b--MXX0DLoVgSt_wLfNPt47UlolLLR4,5272
|
183
184
|
tico/utils/errors.py,sha256=f3csJjgbXG9W1aHhqEcou008Aor19W57X8oT5Hx8w1M,954
|
@@ -189,15 +190,15 @@ tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
|
|
189
190
|
tico/utils/register_custom_op.py,sha256=iRQvdqlBqrJxq_pNkvJyDIJD_SYtCUl88wwbbuvSwlk,22952
|
190
191
|
tico/utils/serialize.py,sha256=AQXMBOLu-Kg2Rn-qbqsAtHndjZAZIavlKA0QFgJREHM,1420
|
191
192
|
tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
|
192
|
-
tico/utils/utils.py,sha256=
|
193
|
+
tico/utils/utils.py,sha256=A5YeUXIjSY5HnlmtnMfWm6wx3i6Feq5dSra1NJsZV2o,13253
|
193
194
|
tico/utils/validate_args_kwargs.py,sha256=P4aMnr9EhNCtc_AgJPpuezfQbqFfDn0lhJSWqmumLZ8,25054
|
194
195
|
tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
195
196
|
tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
|
196
197
|
tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
|
197
198
|
tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
|
198
|
-
tico-0.1.0.
|
199
|
-
tico-0.1.0.
|
200
|
-
tico-0.1.0.
|
201
|
-
tico-0.1.0.
|
202
|
-
tico-0.1.0.
|
203
|
-
tico-0.1.0.
|
199
|
+
tico-0.1.0.dev250610.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
|
200
|
+
tico-0.1.0.dev250610.dist-info/METADATA,sha256=9nX12pxSQ-4h2gPPm0vYX0EW3wkn00HtmLF1E-xIabE,8633
|
201
|
+
tico-0.1.0.dev250610.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
|
202
|
+
tico-0.1.0.dev250610.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
|
203
|
+
tico-0.1.0.dev250610.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
|
204
|
+
tico-0.1.0.dev250610.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|