tico 0.1.0.dev250609__py3-none-any.whl → 0.1.0.dev250610__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tico/__init__.py CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
21
21
  from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
22
22
 
23
23
  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
24
- __version__ = "0.1.0.dev250609"
24
+ __version__ = "0.1.0.dev250610"
25
25
 
26
26
  MINIMUM_SUPPORTED_VERSION = "2.5.0"
27
27
  SECURE_TORCH_VERSION = "2.6.0"
@@ -0,0 +1,110 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Sequence
16
+
17
+ import torch
18
+ from torch.export import ExportedProgram
19
+
20
+ from tico.passes import ops
21
+ from tico.serialize.circle_mapping import extract_shape
22
+ from tico.utils import logging
23
+ from tico.utils.passes import PassBase, PassResult
24
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
25
+ from tico.utils.utils import is_single_use_target_node
26
+ from tico.utils.validate_args_kwargs import PermuteArgs, ReshapeArgs
27
+
28
+
29
+ def _is_leading_unsqueeze(target: Sequence[int], permuted: Sequence[int]) -> bool:
30
+ """
31
+ True if `target` == [1]*k + permuted, k>=1.
32
+ """
33
+ k = len(target) - len(permuted)
34
+ return (
35
+ k > 0 and all(d == 1 for d in target[:k]) and list(target[k:]) == list(permuted)
36
+ )
37
+
38
+
39
@trace_graph_diff_on_pass
class FuseLeadingUnsqueezeReshape(PassBase):
    """
    Fuse reshape → permute → reshape where the second reshape only
    prepends one-sized dims (unsqueeze) to the permuted tensor.

    [BEFORE]
      x - aten.reshape(s1) - aten.permute(p) - aten.reshape([1]*k + p(s1))
    [AFTER]
      x - aten.reshape([1]*k + s1) - aten.permute(list(range(k)) + [d+k for d in p])

    The fusion only fires when the permute and the front reshape each have a
    single user, so the replaced chain becomes dead and is removed.
    """

    def call(self, ep: ExportedProgram) -> PassResult:
        """
        Rewrite every matching reshape → permute → reshape chain in `ep`.

        Args:
            ep: Exported program; its graph module is modified in place.

        Returns:
            PassResult flagged as modified if at least one chain was fused.
        """
        logger = logging.getLogger(__name__)

        gm = ep.graph_module
        graph = gm.graph
        modified = False
        for reshape_back in graph.nodes:
            if (
                reshape_back.op != "call_function"
                or reshape_back.target not in ops.aten.reshape
            ):
                continue
            reshape_back_args = ReshapeArgs(*reshape_back.args, **reshape_back.kwargs)  # type: ignore[arg-type]
            permute = reshape_back_args.input

            # Intermediate nodes must feed only this chain; otherwise they would
            # stay alive after the rewrite and nothing would be eliminated.
            if not is_single_use_target_node(permute, ops.aten.permute):
                continue
            permute_args = PermuteArgs(*permute.args, **permute.kwargs)  # type: ignore[arg-type]
            reshape_front, permute_dims = permute_args.input, permute_args.dims

            if not is_single_use_target_node(reshape_front, ops.aten.reshape):
                continue
            reshape_front_args = ReshapeArgs(*reshape_front.args, **reshape_front.kwargs)  # type: ignore[arg-type]
            reshape_front_input, reshape_front_size = (
                reshape_front_args.input,
                reshape_front_args.size,
            )

            # ---- condition: only leading unsqueeze ------------------------
            back_shape = extract_shape(reshape_back)
            permute_shape = extract_shape(permute)

            if not _is_leading_unsqueeze(back_shape, permute_shape):
                continue

            # ---- create new reshape & new permute -------------------------
            k = len(back_shape) - len(permute_shape)
            # aten.permute accepts negative dims. Normalize them against the
            # original rank before shifting by k; otherwise a dim such as -1
            # would be shifted onto the wrong axis of the higher-rank tensor.
            rank = len(permute_dims)
            new_p_dims = list(range(k)) + [d % rank + k for d in permute_dims]
            with graph.inserting_before(permute):
                new_shape = [1] * k + list(reshape_front_size)
                r_new = graph.call_function(
                    torch.ops.aten.reshape.default,
                    args=(reshape_front_input, new_shape),
                )
                p_new = graph.call_function(
                    torch.ops.aten.permute.default, args=(r_new, new_p_dims)
                )

            # propagate_meta copies reshape_back's meta onto p_new; both have
            # the output shape [1]*k + permute_shape.
            # NOTE(review): r_new receives no meta here — confirm a later pass
            # (or the serializer) does not require meta["val"] on it.
            reshape_back.replace_all_uses_with(p_new, propagate_meta=True)
            modified = True
            logger.debug(f"{reshape_back.name} is fused to {r_new.name}")

        if modified:
            # The old reshape/permute/reshape chain now has no users.
            graph.eliminate_dead_code()
            graph.lint()
            gm.recompile()

        return PassResult(modified)
tico/utils/convert.py CHANGED
@@ -51,6 +51,7 @@ from tico.passes.decompose_grouped_conv2d import DecomposeGroupedConv2d
51
51
  from tico.passes.decompose_slice_scatter import DecomposeSliceScatter
52
52
  from tico.passes.extract_dtype_kwargs import ExtractDtypeKwargsPass
53
53
  from tico.passes.fill_meta_val import FillMetaVal
54
+ from tico.passes.fuse_leading_unsqueeze_reshape import FuseLeadingUnsqueezeReshape
54
55
  from tico.passes.fuse_redundant_reshape_to_mean import FuseRedundantReshapeToMean
55
56
  from tico.passes.legalize_causal_mask_value import LegalizeCausalMaskValue
56
57
  from tico.passes.legalize_predefined_layout_operators import (
@@ -225,6 +226,7 @@ def convert_exported_module_to_circle(
225
226
  LowerPow2ToMul(),
226
227
  ConvertConv1dToConv2d(),
227
228
  *LowerToSlicePasses(),
229
+ FuseLeadingUnsqueezeReshape(),
228
230
  ]
229
231
  )
230
232
  circle_legalize.run(exported_program)
tico/utils/utils.py CHANGED
@@ -378,3 +378,31 @@ def broadcastable(
378
378
  if dim_a != 1 and dim_b != 1 and dim_a != dim_b:
379
379
  return False
380
380
  return True
381
+
382
+
383
+ def is_single_use_target_node(
384
+ node: torch.fx.Node, target_ops: list[torch._ops.OpOverload] | torch._ops.OpOverload
385
+ ):
386
+ """
387
+ Check whether a given node is a `call_function` node that matches one of the specified targets
388
+ and is used by only one other node.
389
+
390
+ Args:
391
+ node (torch.fx.Node): The node to check.
392
+ target_ops (Iterable[Callable]): A list or set of target operations to match (e.g., ops.aten.reshape).
393
+
394
+ Returns:
395
+ bool: True if the node is a call_function, its target is in `target_ops`, and it has exactly one user.
396
+ """
397
+ if not isinstance(target_ops, list):
398
+ target_ops = [target_ops]
399
+ assert all(isinstance(t, torch._ops.OpOverload) for t in target_ops), target_ops
400
+
401
+ if node.op != "call_function":
402
+ return False
403
+ if node.target not in target_ops:
404
+ return False
405
+ if len(node.users) != 1:
406
+ return False
407
+
408
+ return True
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tico
3
- Version: 0.1.0.dev250609
3
+ Version: 0.1.0.dev250610
4
4
  Summary: Convert exported Torch module to circle
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- tico/__init__.py,sha256=mEQHwb4j_KQrsDMO6gp2USb5RCYNhd0MOxL4P_EGMW8,1743
1
+ tico/__init__.py,sha256=PHlL6DvW31YH0rgvY-y2LlEJDbNJSK7FO1KX1O5BGMM,1743
2
2
  tico/pt2_to_circle.py,sha256=PPmFNw20jw2Z2VyM3ln9pX__jTzBOAZiv0gT5a-p-Y8,2666
3
3
  tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
4
4
  tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
@@ -75,6 +75,7 @@ tico/passes/decompose_grouped_conv2d.py,sha256=KJhH6PX7l9k9T8KBV8JDAvaSfJuUnRo_j
75
75
  tico/passes/decompose_slice_scatter.py,sha256=ko9p8v-zY5rOx4aSpWomwSdSWb1lIF32gnU7ik5xgII,5604
76
76
  tico/passes/extract_dtype_kwargs.py,sha256=hfGJ_GfZULbBmLif2AJkhPHVifhucxBiLoQI862Yejk,4303
77
77
  tico/passes/fill_meta_val.py,sha256=Xbam6Aq90ZfWItZw1dgLIwH_q8RCiU5JodKNqkj-ink,1797
78
+ tico/passes/fuse_leading_unsqueeze_reshape.py,sha256=1hqX9urmV4kQXGWCEJ22XNH_3M2Og5v09N2wwV87FoQ,4304
78
79
  tico/passes/fuse_redundant_reshape_to_mean.py,sha256=SzHLLL2Yfj4k1c2L5i4PVI9EFUilHRfIS6S-hahKFCM,3702
79
80
  tico/passes/legalize_causal_mask_value.py,sha256=KZc_UPk7CGPXO35JOu6dVrOzRJx-ZpJoVvuzz-GvQek,4080
80
81
  tico/passes/legalize_predefined_layout_operators.py,sha256=N2TtJInjSTk-E5afnkDXXbo9v4zTM7yzsjna3VoihMw,15895
@@ -177,7 +178,7 @@ tico/serialize/operators/op_view.py,sha256=5EMww-ve17Vm9XPuV03Tn7vJsjpU2J8U4d_FO
177
178
  tico/serialize/operators/op_where.py,sha256=qZDFKVQ2u8HB0Jjpto_1A4g-jfpEprAbBT3PmIGzdoY,2855
178
179
  tico/serialize/operators/utils.py,sha256=lXGpEJW1h8U_-gfc6EWjvvSiq3yJ9P-v1v3EMRT_pSk,2954
179
180
  tico/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
180
- tico/utils/convert.py,sha256=KCllPnvQ8bjEYR1yI72s9aNBp7Py1CzIEEpYSYZcu60,11684
181
+ tico/utils/convert.py,sha256=nbpBJue9_ezb8Dj9L9dWcNeVlZd0soGJu_fvrL6-c5c,11810
181
182
  tico/utils/define.py,sha256=Ypgp7YffM4pgPl4Zh6TmogSn1OxGBMRw_e09qYGflZk,1467
182
183
  tico/utils/diff_graph.py,sha256=_eDGGPDPYQD4b--MXX0DLoVgSt_wLfNPt47UlolLLR4,5272
183
184
  tico/utils/errors.py,sha256=f3csJjgbXG9W1aHhqEcou008Aor19W57X8oT5Hx8w1M,954
@@ -189,15 +190,15 @@ tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
189
190
  tico/utils/register_custom_op.py,sha256=iRQvdqlBqrJxq_pNkvJyDIJD_SYtCUl88wwbbuvSwlk,22952
190
191
  tico/utils/serialize.py,sha256=AQXMBOLu-Kg2Rn-qbqsAtHndjZAZIavlKA0QFgJREHM,1420
191
192
  tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
192
- tico/utils/utils.py,sha256=7t4U3Bs7SV7y3PCdgE9bEAAAI2u1mzQEUP81oP_XHHU,12334
193
+ tico/utils/utils.py,sha256=A5YeUXIjSY5HnlmtnMfWm6wx3i6Feq5dSra1NJsZV2o,13253
193
194
  tico/utils/validate_args_kwargs.py,sha256=P4aMnr9EhNCtc_AgJPpuezfQbqFfDn0lhJSWqmumLZ8,25054
194
195
  tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
195
196
  tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
196
197
  tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
197
198
  tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
198
- tico-0.1.0.dev250609.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
199
- tico-0.1.0.dev250609.dist-info/METADATA,sha256=ZpNUuShOOV1hMrNe84u05IoeX4UZn1fj1_TEU4ZlI1Y,8633
200
- tico-0.1.0.dev250609.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
201
- tico-0.1.0.dev250609.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
202
- tico-0.1.0.dev250609.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
203
- tico-0.1.0.dev250609.dist-info/RECORD,,
199
+ tico-0.1.0.dev250610.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
200
+ tico-0.1.0.dev250610.dist-info/METADATA,sha256=9nX12pxSQ-4h2gPPm0vYX0EW3wkn00HtmLF1E-xIabE,8633
201
+ tico-0.1.0.dev250610.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
202
+ tico-0.1.0.dev250610.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
203
+ tico-0.1.0.dev250610.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
204
+ tico-0.1.0.dev250610.dist-info/RECORD,,