tico 0.1.0.dev250715__py3-none-any.whl → 0.1.0.dev250717__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tico/__init__.py CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
21
21
  from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
22
22
 
23
23
  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
24
- __version__ = "0.1.0.dev250715"
24
+ __version__ = "0.1.0.dev250717"
25
25
 
26
26
  MINIMUM_SUPPORTED_VERSION = "2.5.0"
27
27
  SECURE_TORCH_VERSION = "2.6.0"
@@ -0,0 +1,169 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.passes import ops
23
+
24
+ from tico.serialize.circle_mapping import extract_torch_dtype
25
+ from tico.utils import logging
26
+ from tico.utils.graph import create_node
27
+ from tico.utils.passes import PassBase, PassResult
28
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
29
+ from tico.utils.utils import is_target_node, set_new_meta_val
30
+ from tico.utils.validate_args_kwargs import ClampArgs
31
+
32
+
33
@trace_graph_diff_on_pass
class CastClampMixedTypeArgs(PassBase):
    """
    Make the operand dtypes of `aten.clamp` consistent with its output dtype.

    This pass:
      1. Converts Python-scalar `min`/`max` arguments to the Python scalar type
         matching the output dtype (``float`` for floating-point outputs,
         ``int`` for integral outputs). Bounds that are ``None`` or tensors
         (fx Nodes) are left untouched.
      2. Inserts an `aten._to_copy` cast on the input when the input dtype
         differs from the output dtype.

    (Case 1, if input dtype is different from output dtype)
    [before]

      input            min(or max)
      (dtype=int)      (dtype=float)
         |                |
       clamp <------------+
         |
       output
      (dtype=float)

    [after]

      input                min(or max)
      (dtype=int)          (dtype=float)
         |                    |
        cast                  |
      (in=int, out=float)     |
         |                    |
       clamp <----------------+
         |
       output
      (dtype=float)

    (Case 2, if min(or max) dtype is different from output dtype)
    [before]

      input            min(or max)
      (dtype=float)    (dtype=int)
         |                |
       clamp <------------+
         |
       output
      (dtype=float)

    [after]

      input            min(or max)
      (dtype=float)    (dtype=float)
         |                |
       clamp <------------+
         |
       output
      (dtype=float)
    """

    def __init__(self):
        super().__init__()

    def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> bool:
        """
        Legalize a single clamp node in place.

        Args:
            exported_program: Program that owns `node`. Any inserted cast node
                is added to its graph.
            node: An `aten.clamp` node.

        Returns:
            True if `node` or its graph was actually changed, False otherwise.
        """
        logger = logging.getLogger(__name__)
        modified = False

        graph = exported_program.graph_module.graph

        # clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
        args = ClampArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        input = args.input
        input_dtype = extract_torch_dtype(input)
        output_dtype = extract_torch_dtype(node)

        def _convert_arg(arg, arg_name: str, arg_pos: int) -> bool:
            """Rewrite one scalar bound to match `output_dtype`; return True if changed."""
            # Only Python scalars are rewritten. None means the bound is absent;
            # tensor-valued bounds (fx Nodes) are not scalars and are left alone.
            if isinstance(arg, bool) or not isinstance(arg, (int, float)):
                return False

            if output_dtype.is_floating_point:
                new_arg = float(arg)
            elif output_dtype.is_complex or output_dtype == torch.bool:
                # No scalar conversion rule for these output dtypes; leave as-is.
                return False
            else:
                new_arg = int(arg)

            # Already the matching Python type: report no change, so the pass
            # does not look perpetually "modified".
            if type(new_arg) is type(arg):
                return False

            # Update the bound by its known keyword/position instead of a
            # value lookup. `node.args.index(value)` picks the wrong slot when
            # min == max and raises when the bound was given as a kwarg.
            if arg_name in node.kwargs:
                node.update_kwarg(arg_name, new_arg)
            else:
                node.update_arg(arg_pos, new_arg)
            logger.debug(
                f"Casting {arg_name} value from {type(arg).__name__} to {output_dtype} for clamp operation at {node.name}"
            )
            return True

        # Positions follow the schema above: self=0, min=1, max=2.
        modified |= _convert_arg(args.min, "min", 1)
        modified |= _convert_arg(args.max, "max", 2)

        if input_dtype != output_dtype:
            logger.debug(
                f"Inserting cast from {input_dtype} to {output_dtype} for input {input.name}"
            )
            with graph.inserting_after(input):
                to_copy = create_node(
                    graph,
                    torch.ops.aten._to_copy.default,
                    (input,),
                    {"dtype": output_dtype},
                    origin=input,
                )
            set_new_meta_val(to_copy)
            # `self` is always the first positional argument of clamp
            # (ClampArgs could not bind it from a kwarg otherwise).
            node.update_arg(0, to_copy)

            modified = True

        return modified

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Run `convert` on every clamp node, then clean up and recompile the graph."""
        target_op = ops.aten.clamp

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not is_target_node(node, target_op):
                continue

            modified |= self.convert(exported_program, node)

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
tico/passes/ops.py CHANGED
@@ -73,6 +73,7 @@ class AtenOps:
73
73
  torch.ops.aten.view.default,
74
74
  torch.ops.aten.view_copy.default,
75
75
  ]
76
+ self._to_copy = [torch.ops.aten._to_copy.default]
76
77
 
77
78
 
78
79
  aten = AtenOps()
tico/utils/convert.py CHANGED
@@ -35,6 +35,7 @@ from tico.experimental.quantization.passes.remove_weight_dequant_op import (
35
35
  RemoveWeightDequantOp,
36
36
  )
37
37
  from tico.passes.cast_aten_where_arg_type import CastATenWhereArgType
38
+ from tico.passes.cast_clamp_mixed_type_args import CastClampMixedTypeArgs
38
39
  from tico.passes.cast_mixed_type_args import CastMixedTypeArgs
39
40
  from tico.passes.const_prop_pass import ConstPropPass
40
41
  from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d
@@ -251,6 +252,7 @@ def convert_exported_module_to_circle(
251
252
  ConvertConv1dToConv2d(),
252
253
  *LowerToSlicePasses(),
253
254
  FuseLeadingUnsqueezeReshape(),
255
+ CastClampMixedTypeArgs(),
254
256
  ]
255
257
  )
256
258
  circle_legalize.run(exported_program)
tico/utils/model.py CHANGED
@@ -14,6 +14,7 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
+ from pathlib import Path
17
18
  from typing import Any
18
19
 
19
20
  from tico.interpreter import infer
@@ -32,6 +33,6 @@ class CircleModel:
32
33
  buf = bytes(f.read())
33
34
  return CircleModel(buf)
34
35
 
35
- def save(self, circle_path: str) -> None:
36
+ def save(self, circle_path: str | Path) -> None:
36
37
  with open(circle_path, "wb") as f:
37
38
  f.write(self.circle_binary)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tico
3
- Version: 0.1.0.dev250715
3
+ Version: 0.1.0.dev250717
4
4
  Summary: Convert exported Torch module to circle
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- tico/__init__.py,sha256=MYWB0f9ftIZXXj1q1Sdv4Qn0EGgO27twfOAD_gDGNVQ,1743
1
+ tico/__init__.py,sha256=8WsnAhznDCSGOK_vrdZdi2apsz1wqJDKSfCjR4LTG8c,1743
2
2
  tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
3
3
  tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
4
4
  tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
@@ -61,6 +61,7 @@ tico/interpreter/infer.py,sha256=1ZFe3DVMR2mlwBosoedqoL0-CGN_01CKLgMgxuw62KA,486
61
61
  tico/interpreter/interpreter.py,sha256=tGbluCbrehTCqBu8mtGDNzby_ieJ2ry8_RH_eC0CQxk,3828
62
62
  tico/passes/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
63
63
  tico/passes/cast_aten_where_arg_type.py,sha256=ybtGj1L7_2zGyfb_G-_y1N1mRgKHVq6fBZc-9-fH9sA,7229
64
+ tico/passes/cast_clamp_mixed_type_args.py,sha256=m3_HpXLywWmWERfE5lM5PgvjBod7C4BWu_Q-TkRyO8k,5387
64
65
  tico/passes/cast_mixed_type_args.py,sha256=ArJpIPnQP1LPNaWIwee13Nbj749_awFKDO-pAYvdYvI,7618
65
66
  tico/passes/const_prop_pass.py,sha256=QOeR2u3fo9ZhWXRhfAUW1dTtuWgqgoqdDJoQ516UDbQ,11532
66
67
  tico/passes/convert_conv1d_to_conv2d.py,sha256=7YljWJQBX5vBUMgGgRv8TvbJ9UpEL9hf4ZU3dNUhEZ8,5301
@@ -84,7 +85,7 @@ tico/passes/lower_pow2_to_mul.py,sha256=nfJXa9ZTZMiLg6ownSyvkM4KF2z9tZW34Q3CCWI_
84
85
  tico/passes/lower_to_resize_nearest_neighbor.py,sha256=N6F56Of8Aiv-KIiYLHnh33WX72W60ZVQSBEYWHdYqNQ,9005
85
86
  tico/passes/lower_to_slice.py,sha256=0qAX3WzZdyMFDW4DiO9b5JFXd4rL1-0doBT6lJvaw_I,7260
86
87
  tico/passes/merge_consecutive_cat.py,sha256=BYmiU170DsrHQMj7gMe7U6ZpndrX-S4OpvJweDdspec,2701
87
- tico/passes/ops.py,sha256=XzaKC_FpsfJLpnU4JlL9X-HVarWKm8cX0iiRgx9bMOs,2909
88
+ tico/passes/ops.py,sha256=cSj3Sk2x2cOE9b8oU5pmSa_rHr-iX2lORzu3N_UHMSQ,2967
88
89
  tico/passes/remove_nop.py,sha256=Hf91p_EJAOC6DyWNthash0_UWtEcNc_M7znamQfYQ5Y,2686
89
90
  tico/passes/remove_redundant_assert_nodes.py,sha256=IONd3xBy6I8tH6_Y1eN3_eCHH7WTC8soBgjXzOju9cQ,1612
90
91
  tico/passes/remove_redundant_expand.py,sha256=5SIqN7eIIcqF68tlrB31n1482jSBSBOgKb1wddLX6lw,2197
@@ -182,14 +183,14 @@ tico/serialize/operators/op_view.py,sha256=5EMww-ve17Vm9XPuV03Tn7vJsjpU2J8U4d_FO
182
183
  tico/serialize/operators/op_where.py,sha256=doE81GSwygrPBm3JIfN9w7kKXxeIYKxgk0eoY22QIcg,2845
183
184
  tico/serialize/operators/utils.py,sha256=lXGpEJW1h8U_-gfc6EWjvvSiq3yJ9P-v1v3EMRT_pSk,2954
184
185
  tico/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
185
- tico/utils/convert.py,sha256=11Ps0i4-3Fmcts_PZO5Eo8rRL7FcLV33eHz_G97MnCg,12865
186
+ tico/utils/convert.py,sha256=w4l7fnqbiVACOU5-OXr8Ebyl4EMeeBz6vwUSuOS_CtI,12977
186
187
  tico/utils/define.py,sha256=Ypgp7YffM4pgPl4Zh6TmogSn1OxGBMRw_e09qYGflZk,1467
187
188
  tico/utils/diff_graph.py,sha256=_eDGGPDPYQD4b--MXX0DLoVgSt_wLfNPt47UlolLLR4,5272
188
189
  tico/utils/errors.py,sha256=f3csJjgbXG9W1aHhqEcou008Aor19W57X8oT5Hx8w1M,954
189
190
  tico/utils/graph.py,sha256=Y6aODsnc_-9l61oanknb7K1jqJ8B35iPypOKkM0Qkk0,9149
190
191
  tico/utils/installed_packages.py,sha256=J0FTwnkCGs0MxRWoCMYAqiwH7Z0GWFDLV--x-IndSp4,1017
191
192
  tico/utils/logging.py,sha256=IlbBWscsaHidI0dNqro1HEXAbIcbkR3BD5ukLy2m95k,1286
192
- tico/utils/model.py,sha256=Uqc92AnJXQ2pbvctS2z2F3Ku3yNrwXZ9O33hZVis7is,1250
193
+ tico/utils/model.py,sha256=pPOIjD0qjQirLibiRxxfjOR6efimOcDAd9R-74eus-k,1282
193
194
  tico/utils/padding.py,sha256=jyNhGmlLZfruWZ6n5hll8RZOFg85iCZP8OJqnHGS97g,3293
194
195
  tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
195
196
  tico/utils/pytree_utils.py,sha256=jrk3N6X6LiUnBCX_gM1K9nywbVAJBVnszlTAgeIeDUc,5219
@@ -202,9 +203,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
202
203
  tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
203
204
  tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
204
205
  tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
205
- tico-0.1.0.dev250715.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
206
- tico-0.1.0.dev250715.dist-info/METADATA,sha256=XNqsTtUt8jSqU2EsY3sm3RJDKwRGerDxGkH9eMXwOQk,8430
207
- tico-0.1.0.dev250715.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
208
- tico-0.1.0.dev250715.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
209
- tico-0.1.0.dev250715.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
210
- tico-0.1.0.dev250715.dist-info/RECORD,,
206
+ tico-0.1.0.dev250717.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
207
+ tico-0.1.0.dev250717.dist-info/METADATA,sha256=AxsK-qqfRS2Cd0fJ8ChPI-pVZHvX5Kt1XeB7SMkdyKc,8430
208
+ tico-0.1.0.dev250717.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
209
+ tico-0.1.0.dev250717.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
210
+ tico-0.1.0.dev250717.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
211
+ tico-0.1.0.dev250717.dist-info/RECORD,,