tico 0.1.0.dev250715__py3-none-any.whl → 0.1.0.dev250716__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/passes/cast_clamp_mixed_type_args.py +169 -0
- tico/passes/ops.py +1 -0
- tico/utils/convert.py +2 -0
- tico/utils/model.py +2 -1
- {tico-0.1.0.dev250715.dist-info → tico-0.1.0.dev250716.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250715.dist-info → tico-0.1.0.dev250716.dist-info}/RECORD +11 -10
- {tico-0.1.0.dev250715.dist-info → tico-0.1.0.dev250716.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250715.dist-info → tico-0.1.0.dev250716.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250715.dist-info → tico-0.1.0.dev250716.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250715.dist-info → tico-0.1.0.dev250716.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
|
|
21
21
|
from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
|
22
22
|
|
23
23
|
# THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
|
24
|
-
__version__ = "0.1.0.
|
24
|
+
__version__ = "0.1.0.dev250716"
|
25
25
|
|
26
26
|
MINIMUM_SUPPORTED_VERSION = "2.5.0"
|
27
27
|
SECURE_TORCH_VERSION = "2.6.0"
|
@@ -0,0 +1,169 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.passes import ops
|
23
|
+
|
24
|
+
from tico.serialize.circle_mapping import extract_torch_dtype
|
25
|
+
from tico.utils import logging
|
26
|
+
from tico.utils.graph import create_node
|
27
|
+
from tico.utils.passes import PassBase, PassResult
|
28
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
29
|
+
from tico.utils.utils import is_target_node, set_new_meta_val
|
30
|
+
from tico.utils.validate_args_kwargs import ClampArgs
|
31
|
+
|
32
|
+
|
33
|
+
@trace_graph_diff_on_pass
class CastClampMixedTypeArgs(PassBase):
    """
    This pass ensures consistent dtypes for clamp operations by:
    1. Converting min/max arguments to match output dtype when provided
    2. Inserting cast operations when input dtype differs from output dtype

    Behavior Examples:
    - When input dtype differs from output:
        Inserts _to_copy operation to convert input
    - When min/max dtype differs from output:
        Converts min/max values to output dtype

    Scalar bounds are Python numbers, so they are matched against the output
    dtype by *kind* (floating vs. integral) rather than by exact tensor dtype:
    a Python float becomes `float` for any floating output (float16, bfloat16,
    float32, float64) and a Python int becomes `int` for any integral output.

    (Case 1, if input dtype is different from output dtype)
    [before]

          input             min(or max)
       (dtype=int)         (dtype=float)
            |                   |
          clamp <---------------+
            |
          output
       (dtype=float)

    [after]

          input             min(or max)
       (dtype=int)         (dtype=float)
            |                   |
           cast                 |
      (in=int, out=float)       |
            |                   |
          clamp <---------------+
            |
          output
       (dtype=float)

    (Case 2, if min(or max) dtype is different from output dtype)
    [before]

          input             min(or max)
       (dtype=float)        (dtype=int)
            |                   |
          clamp <---------------+
            |
          output
       (dtype=float)

    [after]

          input             min(or max)
       (dtype=float)       (dtype=float)
            |                   |
          clamp <---------------+
            |
          output
       (dtype=float)
    """

    def __init__(self):
        super().__init__()

    def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> bool:
        """
        Legalize the operand types of a single clamp node.

        Args:
            exported_program: Program whose graph contains `node`.
            node: A node matched by `ops.aten.clamp`.

        Returns:
            True if `node` or its surrounding graph was changed.

        Raises:
            AssertionError: if a scalar bound is present and the clamp output
                dtype is neither floating-point nor a non-bool integer type.
        """
        logger = logging.getLogger(__name__)
        modified = False

        graph = exported_program.graph_module.graph

        # clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
        args = ClampArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        input_node = args.input
        input_dtype = extract_torch_dtype(input_node)
        output_dtype = extract_torch_dtype(node)

        def _convert_arg(arg, arg_name: str, arg_idx: int) -> bool:
            """
            Coerce one scalar bound to the Python type matching `output_dtype`.

            `arg_idx` / `arg_name` are the bound's positional index and keyword
            name in the clamp schema. The slot is addressed by position, not by
            searching `node.args` for the value, so equal min/max values
            (e.g. clamp(x, 0, 0)) and keyword-passed bounds are handled.

            Returns True only when the stored value actually changed type.
            """
            # Absent bound, or a non-scalar (e.g. tensor-valued) bound: there
            # is no Python scalar to coerce.
            if arg is None or not isinstance(arg, (int, float)):
                return False

            if output_dtype.is_floating_point:
                target_type = float
            else:
                assert not output_dtype.is_complex and output_dtype != torch.bool, (
                    f"Unsupported clamp output dtype {output_dtype} at {node.name}"
                )
                target_type = int

            # Already the right kind: report no change so the pass settles
            # instead of reporting a modification on every run.
            if type(arg) is target_type:
                return False

            converted = target_type(arg)
            if arg_idx < len(node.args):
                node.update_arg(arg_idx, converted)
            else:
                # The bound was supplied as a keyword argument.
                node.update_kwarg(arg_name, converted)
            logger.debug(
                f"Casting {arg_name} value from {type(arg).__name__} to {output_dtype} for clamp operation at {node.name}"
            )
            return True

        modified |= _convert_arg(args.min, "min", 1)
        modified |= _convert_arg(args.max, "max", 2)

        if input_dtype != output_dtype:
            logger.debug(
                f"Inserting cast from {input_dtype} to {output_dtype} for input {input_node.name}"
            )
            with graph.inserting_after(input_node):
                to_copy = create_node(
                    graph,
                    torch.ops.aten._to_copy.default,
                    (input_node,),
                    {"dtype": output_dtype},
                    origin=input_node,
                )
                set_new_meta_val(to_copy)
            # Rewire wherever the input is referenced (positional or keyword)
            # instead of locating it via a value search in `node.args`.
            node.replace_input_with(input_node, to_copy)

            modified = True

        return modified

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Apply `convert` to every clamp node in the program's graph.

        Dead code is then eliminated, the graph is linted, and the graph module
        is recompiled.

        Returns:
            PassResult flagging whether any clamp node was changed.
        """
        target_op = ops.aten.clamp

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not is_target_node(node, target_op):
                continue

            modified |= self.convert(exported_program, node)

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
tico/passes/ops.py
CHANGED
tico/utils/convert.py
CHANGED
@@ -35,6 +35,7 @@ from tico.experimental.quantization.passes.remove_weight_dequant_op import (
|
|
35
35
|
RemoveWeightDequantOp,
|
36
36
|
)
|
37
37
|
from tico.passes.cast_aten_where_arg_type import CastATenWhereArgType
|
38
|
+
from tico.passes.cast_clamp_mixed_type_args import CastClampMixedTypeArgs
|
38
39
|
from tico.passes.cast_mixed_type_args import CastMixedTypeArgs
|
39
40
|
from tico.passes.const_prop_pass import ConstPropPass
|
40
41
|
from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d
|
@@ -251,6 +252,7 @@ def convert_exported_module_to_circle(
|
|
251
252
|
ConvertConv1dToConv2d(),
|
252
253
|
*LowerToSlicePasses(),
|
253
254
|
FuseLeadingUnsqueezeReshape(),
|
255
|
+
CastClampMixedTypeArgs(),
|
254
256
|
]
|
255
257
|
)
|
256
258
|
circle_legalize.run(exported_program)
|
tico/utils/model.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14
14
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
+
from pathlib import Path
|
17
18
|
from typing import Any
|
18
19
|
|
19
20
|
from tico.interpreter import infer
|
@@ -32,6 +33,6 @@ class CircleModel:
|
|
32
33
|
buf = bytes(f.read())
|
33
34
|
return CircleModel(buf)
|
34
35
|
|
35
|
-
def save(self, circle_path: str) -> None:
|
36
|
+
def save(self, circle_path: str | Path) -> None:
|
36
37
|
with open(circle_path, "wb") as f:
|
37
38
|
f.write(self.circle_binary)
|
@@ -1,4 +1,4 @@
|
|
1
|
-
tico/__init__.py,sha256=
|
1
|
+
tico/__init__.py,sha256=7Bu_kNw98Z5aDMnphDfqKQ00FlgKKT8iym416i5oNBI,1743
|
2
2
|
tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
|
3
3
|
tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
|
4
4
|
tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
|
@@ -61,6 +61,7 @@ tico/interpreter/infer.py,sha256=1ZFe3DVMR2mlwBosoedqoL0-CGN_01CKLgMgxuw62KA,486
|
|
61
61
|
tico/interpreter/interpreter.py,sha256=tGbluCbrehTCqBu8mtGDNzby_ieJ2ry8_RH_eC0CQxk,3828
|
62
62
|
tico/passes/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
63
63
|
tico/passes/cast_aten_where_arg_type.py,sha256=ybtGj1L7_2zGyfb_G-_y1N1mRgKHVq6fBZc-9-fH9sA,7229
|
64
|
+
tico/passes/cast_clamp_mixed_type_args.py,sha256=m3_HpXLywWmWERfE5lM5PgvjBod7C4BWu_Q-TkRyO8k,5387
|
64
65
|
tico/passes/cast_mixed_type_args.py,sha256=ArJpIPnQP1LPNaWIwee13Nbj749_awFKDO-pAYvdYvI,7618
|
65
66
|
tico/passes/const_prop_pass.py,sha256=QOeR2u3fo9ZhWXRhfAUW1dTtuWgqgoqdDJoQ516UDbQ,11532
|
66
67
|
tico/passes/convert_conv1d_to_conv2d.py,sha256=7YljWJQBX5vBUMgGgRv8TvbJ9UpEL9hf4ZU3dNUhEZ8,5301
|
@@ -84,7 +85,7 @@ tico/passes/lower_pow2_to_mul.py,sha256=nfJXa9ZTZMiLg6ownSyvkM4KF2z9tZW34Q3CCWI_
|
|
84
85
|
tico/passes/lower_to_resize_nearest_neighbor.py,sha256=N6F56Of8Aiv-KIiYLHnh33WX72W60ZVQSBEYWHdYqNQ,9005
|
85
86
|
tico/passes/lower_to_slice.py,sha256=0qAX3WzZdyMFDW4DiO9b5JFXd4rL1-0doBT6lJvaw_I,7260
|
86
87
|
tico/passes/merge_consecutive_cat.py,sha256=BYmiU170DsrHQMj7gMe7U6ZpndrX-S4OpvJweDdspec,2701
|
87
|
-
tico/passes/ops.py,sha256=
|
88
|
+
tico/passes/ops.py,sha256=cSj3Sk2x2cOE9b8oU5pmSa_rHr-iX2lORzu3N_UHMSQ,2967
|
88
89
|
tico/passes/remove_nop.py,sha256=Hf91p_EJAOC6DyWNthash0_UWtEcNc_M7znamQfYQ5Y,2686
|
89
90
|
tico/passes/remove_redundant_assert_nodes.py,sha256=IONd3xBy6I8tH6_Y1eN3_eCHH7WTC8soBgjXzOju9cQ,1612
|
90
91
|
tico/passes/remove_redundant_expand.py,sha256=5SIqN7eIIcqF68tlrB31n1482jSBSBOgKb1wddLX6lw,2197
|
@@ -182,14 +183,14 @@ tico/serialize/operators/op_view.py,sha256=5EMww-ve17Vm9XPuV03Tn7vJsjpU2J8U4d_FO
|
|
182
183
|
tico/serialize/operators/op_where.py,sha256=doE81GSwygrPBm3JIfN9w7kKXxeIYKxgk0eoY22QIcg,2845
|
183
184
|
tico/serialize/operators/utils.py,sha256=lXGpEJW1h8U_-gfc6EWjvvSiq3yJ9P-v1v3EMRT_pSk,2954
|
184
185
|
tico/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
185
|
-
tico/utils/convert.py,sha256=
|
186
|
+
tico/utils/convert.py,sha256=w4l7fnqbiVACOU5-OXr8Ebyl4EMeeBz6vwUSuOS_CtI,12977
|
186
187
|
tico/utils/define.py,sha256=Ypgp7YffM4pgPl4Zh6TmogSn1OxGBMRw_e09qYGflZk,1467
|
187
188
|
tico/utils/diff_graph.py,sha256=_eDGGPDPYQD4b--MXX0DLoVgSt_wLfNPt47UlolLLR4,5272
|
188
189
|
tico/utils/errors.py,sha256=f3csJjgbXG9W1aHhqEcou008Aor19W57X8oT5Hx8w1M,954
|
189
190
|
tico/utils/graph.py,sha256=Y6aODsnc_-9l61oanknb7K1jqJ8B35iPypOKkM0Qkk0,9149
|
190
191
|
tico/utils/installed_packages.py,sha256=J0FTwnkCGs0MxRWoCMYAqiwH7Z0GWFDLV--x-IndSp4,1017
|
191
192
|
tico/utils/logging.py,sha256=IlbBWscsaHidI0dNqro1HEXAbIcbkR3BD5ukLy2m95k,1286
|
192
|
-
tico/utils/model.py,sha256=
|
193
|
+
tico/utils/model.py,sha256=pPOIjD0qjQirLibiRxxfjOR6efimOcDAd9R-74eus-k,1282
|
193
194
|
tico/utils/padding.py,sha256=jyNhGmlLZfruWZ6n5hll8RZOFg85iCZP8OJqnHGS97g,3293
|
194
195
|
tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
|
195
196
|
tico/utils/pytree_utils.py,sha256=jrk3N6X6LiUnBCX_gM1K9nywbVAJBVnszlTAgeIeDUc,5219
|
@@ -202,9 +203,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
|
202
203
|
tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
|
203
204
|
tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
|
204
205
|
tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
|
205
|
-
tico-0.1.0.
|
206
|
-
tico-0.1.0.
|
207
|
-
tico-0.1.0.
|
208
|
-
tico-0.1.0.
|
209
|
-
tico-0.1.0.
|
210
|
-
tico-0.1.0.
|
206
|
+
tico-0.1.0.dev250716.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
|
207
|
+
tico-0.1.0.dev250716.dist-info/METADATA,sha256=ol8z8iSa9e8WfJcDpebI1h_hXSQdi3jXwIOzIp-wAO4,8430
|
208
|
+
tico-0.1.0.dev250716.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
|
209
|
+
tico-0.1.0.dev250716.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
|
210
|
+
tico-0.1.0.dev250716.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
|
211
|
+
tico-0.1.0.dev250716.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|