tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev250716__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tico/__init__.py CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
21
21
  from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
22
22
 
23
23
  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
24
- __version__ = "0.1.0.dev250714"
24
+ __version__ = "0.1.0.dev250716"
25
25
 
26
26
  MINIMUM_SUPPORTED_VERSION = "2.5.0"
27
27
  SECURE_TORCH_VERSION = "2.6.0"
@@ -0,0 +1,169 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.passes import ops
23
+
24
+ from tico.serialize.circle_mapping import extract_torch_dtype
25
+ from tico.utils import logging
26
+ from tico.utils.graph import create_node
27
+ from tico.utils.passes import PassBase, PassResult
28
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
29
+ from tico.utils.utils import is_target_node, set_new_meta_val
30
+ from tico.utils.validate_args_kwargs import ClampArgs
31
+
32
+
33
@trace_graph_diff_on_pass
class CastClampMixedTypeArgs(PassBase):
    """
    Make the operand dtypes of `aten.clamp` agree with the clamp's output dtype.

    This pass ensures consistent dtypes for clamp operations by:
      1. Converting scalar min/max arguments to the Python scalar kind
         (float or int) that matches the output dtype, when provided
      2. Inserting a cast (`aten._to_copy`) when the input dtype differs from
         the output dtype

    Behavior Examples:
      - When input dtype differs from output:
          Inserts _to_copy operation to convert input
      - When min/max dtype differs from output:
          Converts min/max values to output dtype

    (Case 1, if input dtype is different from output dtype)
      [before]

          input            min(or max)
        (dtype=int)       (dtype=float)
            |                   |
          clamp <---------------+
            |
          output
        (dtype=float)

      [after]

          input                    min(or max)
        (dtype=int)               (dtype=float)
            |                          |
           cast                        |
        (in=int, out=float)            |
            |                          |
          clamp <----------------------+
            |
          output
        (dtype=float)

    (Case 2, if min(or max) dtype is different from output dtype)
      [before]

          input            min(or max)
        (dtype=float)     (dtype=int)
            |                   |
          clamp <---------------+
            |
          output
        (dtype=float)

      [after]

          input            min(or max)
        (dtype=float)     (dtype=float)
            |                   |
          clamp <---------------+
            |
          output
        (dtype=float)
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _set_operand(node: torch.fx.Node, position: int, keyword: str, value) -> None:
        """Replace the clamp operand living in schema slot `position` / `keyword`."""
        # Locate the operand by its schema slot, never by searching for its value:
        # a value search fails when the operand is passed as a kwarg, and picks
        # the wrong slot when min and max compare equal (e.g. clamp(x, 0, 0)).
        if position < len(node.args):
            node.update_arg(position, value)
        else:
            node.update_kwarg(keyword, value)

    def _convert_scalar_arg(
        self,
        node: torch.fx.Node,
        position: int,
        keyword: str,
        value,
        output_dtype: torch.dtype,
        logger,
    ) -> bool:
        """
        Convert a scalar min/max operand to the scalar kind of `output_dtype`.

        Args:
            node: The clamp node being legalized.
            position: Positional index of the operand in the clamp schema.
            keyword: Keyword name of the operand in the clamp schema.
            value: The operand value (None when the bound is absent).
            output_dtype: dtype of the clamp's output.
            logger: Logger used for the debug trace.

        Returns:
            True if the operand was rewritten.
        """
        if value is None or not isinstance(value, (int, float)):
            # Absent bound, or a non-scalar operand (e.g. a tensor node) that
            # this pass does not rewrite.
            return False

        converted = float(value) if output_dtype.is_floating_point else int(value)
        if type(converted) is type(value):
            # Already the right scalar kind; nothing to do.
            return False

        arg_dtype = torch.tensor(value).dtype  # only used for the debug message
        self._set_operand(node, position, keyword, converted)
        logger.debug(
            f"Casting {keyword} value from {arg_dtype} to {output_dtype} for clamp operation at {node.name}"
        )
        return True

    def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> bool:
        """
        Legalize operand dtypes of a single clamp `node`.

        Args:
            exported_program: Program whose graph contains `node`.
            node: A clamp node.

        Returns:
            True if the node's operands were changed or a cast was inserted.
        """
        logger = logging.getLogger(__name__)
        graph = exported_program.graph_module.graph

        # clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
        args = ClampArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_node = args.input

        input_dtype = extract_torch_dtype(input_node)
        output_dtype = extract_torch_dtype(node)

        modified = False
        modified |= self._convert_scalar_arg(
            node, 1, "min", args.min, output_dtype, logger
        )
        modified |= self._convert_scalar_arg(
            node, 2, "max", args.max, output_dtype, logger
        )

        if input_dtype != output_dtype:
            logger.debug(
                f"Inserting cast from {input_dtype} to {output_dtype} for input {input_node.name}"
            )
            with graph.inserting_after(input_node):
                to_copy = create_node(
                    graph,
                    torch.ops.aten._to_copy.default,
                    (input_node,),
                    {"dtype": output_dtype},
                    origin=input_node,
                )
            set_new_meta_val(to_copy)
            # The input is the schema's `self` operand (slot 0).
            self._set_operand(node, 0, "self", to_copy)
            modified = True

        return modified

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Apply `convert` to every clamp node, then clean up and recompile."""
        target_op = ops.aten.clamp

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not is_target_node(node, target_op):
                continue

            modified |= self.convert(exported_program, node)

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
tico/passes/ops.py CHANGED
@@ -73,6 +73,7 @@ class AtenOps:
73
73
  torch.ops.aten.view.default,
74
74
  torch.ops.aten.view_copy.default,
75
75
  ]
76
+ self._to_copy = [torch.ops.aten._to_copy.default]
76
77
 
77
78
 
78
79
  aten = AtenOps()
tico/utils/convert.py CHANGED
@@ -35,6 +35,7 @@ from tico.experimental.quantization.passes.remove_weight_dequant_op import (
35
35
  RemoveWeightDequantOp,
36
36
  )
37
37
  from tico.passes.cast_aten_where_arg_type import CastATenWhereArgType
38
+ from tico.passes.cast_clamp_mixed_type_args import CastClampMixedTypeArgs
38
39
  from tico.passes.cast_mixed_type_args import CastMixedTypeArgs
39
40
  from tico.passes.const_prop_pass import ConstPropPass
40
41
  from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d
@@ -251,6 +252,7 @@ def convert_exported_module_to_circle(
251
252
  ConvertConv1dToConv2d(),
252
253
  *LowerToSlicePasses(),
253
254
  FuseLeadingUnsqueezeReshape(),
255
+ CastClampMixedTypeArgs(),
254
256
  ]
255
257
  )
256
258
  circle_legalize.run(exported_program)
tico/utils/model.py CHANGED
@@ -14,6 +14,7 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
+ from pathlib import Path
17
18
  from typing import Any
18
19
 
19
20
  from tico.interpreter import infer
@@ -32,6 +33,6 @@ class CircleModel:
32
33
  buf = bytes(f.read())
33
34
  return CircleModel(buf)
34
35
 
35
- def save(self, circle_path: str) -> None:
36
+ def save(self, circle_path: str | Path) -> None:
36
37
  with open(circle_path, "wb") as f:
37
38
  f.write(self.circle_binary)
@@ -0,0 +1,134 @@
1
+ import threading
2
+
3
+ import torch
4
+ from packaging.version import Version
5
+
6
+ from tico.utils import logging
7
+ from tico.utils.installed_packages import is_transformers_installed
8
+
9
+ __all__ = ["register_dynamic_cache"]
10
+
11
+
12
def register_dynamic_cache():
    """
    Register `DynamicCache` as a PyTree node through the shared helper singleton.

    Repeated calls are de-duplicated by the helper's `_has_called` flag.
    """
    helper = PyTreeRegistryHelper()
    helper.register_dynamic_cache()
14
+
15
+
16
class PyTreeRegistryHelper:
    """
    Thread-safe singleton helper class for registering custom PyTree nodes.

    Registers `transformers.cache_utils.DynamicCache` as a PyTree node for
    torch.export compatibility. Registration is only performed when the
    installed transformers version is below 4.50.0; for newer versions the
    call returns without registering anything.

    Thread Safety:
    - A class-level threading.Lock() guards singleton instantiation in __new__
    - The same lock serializes register_dynamic_cache so registration runs at
      most once
    """

    _instance = None  # The singleton instance, created lazily by __new__
    _has_called = False  # Set only after registration succeeded or was unnecessary
    _lock = threading.Lock()  # Guards both instantiation and registration

    def __init__(self):
        """
        Intentionally empty.

        `__new__` always returns the shared instance and `__init__` runs on
        every `PyTreeRegistryHelper()` call, so it must not reset any state.
        """
        pass

    def __new__(cls, *args, **kwargs):
        """
        Return the singleton instance, creating it under the lock on first use.

        Returns:
            PyTreeRegistryHelper: The singleton instance of this class
        """
        if cls._instance is None:
            with cls._lock:
                # Re-check after acquiring the lock: another thread may have
                # created the instance while we were waiting.
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def register_dynamic_cache(self):
        """
        Register DynamicCache as a PyTree node for torch.export compatibility.

        Calls after a completed registration (or a completed no-op on
        transformers >= 4.50.0) return immediately. If a call raises, the
        completion flag stays unset so a later call tries again.

        Raises:
            ImportError: If the transformers package is not installed.
        """
        logger = logging.getLogger(__name__)
        cls = type(self)
        with cls._lock:
            if cls._has_called:
                logger.debug("register_dynamic_cache already called, skipping")
                return

            # `is_transformers_installed` is a function; it must be called.
            # Testing the bare function object is always truthy.
            if not is_transformers_installed():
                raise ImportError("transformers package is not installed")

            import transformers

            if Version(transformers.__version__) >= Version("4.50.0"):
                # Nothing to register on these versions.
                cls._has_called = True
                return

            logger.info("Registering DynamicCache PyTree node")
            self._register_dynamic_cache_pytree()
            # Mark as done only once registration actually succeeded.
            cls._has_called = True

    @staticmethod
    def _register_dynamic_cache_pytree():
        """Register DynamicCache with torch.utils._pytree and torch.fx._pytree."""
        from transformers.cache_utils import DynamicCache

        def _cache_to_dict(cache):
            # Only key_cache/value_cache are carried through (un)flattening.
            return {
                "key_cache": getattr(cache, "key_cache"),
                "value_cache": getattr(cache, "value_cache"),
            }

        def _flatten_dynamic_cache(dynamic_cache: DynamicCache):
            if not isinstance(dynamic_cache, DynamicCache):
                raise RuntimeError(
                    "This pytree flattening function should only be applied to DynamicCache"
                )
            if Version(torch.__version__) < Version("2.6.0"):
                logging.getLogger(__name__).warning_once(
                    "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions."
                )
            return torch.utils._pytree._dict_flatten(_cache_to_dict(dynamic_cache))

        def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache):
            return torch.utils._pytree._dict_flatten_with_keys(
                _cache_to_dict(dynamic_cache)
            )

        def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
            dictionary = torch.utils._pytree._dict_unflatten(values, context)
            # Attributes not present in the dict keep DynamicCache() defaults.
            cache = DynamicCache()
            for key, value in dictionary.items():
                setattr(cache, key, value)
            return cache

        def _flatten_dynamic_cache_for_fx(cache, spec):
            return torch.fx._pytree._dict_flatten_spec(_cache_to_dict(cache), spec)

        torch.utils._pytree.register_pytree_node(
            DynamicCache,
            _flatten_dynamic_cache,
            _unflatten_dynamic_cache,
            serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
            flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
        )
        # TODO: This won't be needed in torch 2.7+.
        torch.fx._pytree.register_pytree_flatten_spec(
            DynamicCache, _flatten_dynamic_cache_for_fx
        )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tico
3
- Version: 0.1.0.dev250714
3
+ Version: 0.1.0.dev250716
4
4
  Summary: Convert exported Torch module to circle
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- tico/__init__.py,sha256=xbMeC4-9n5Z8Lu1B6yqbmE10G6wmHdhpue7kD7skYtM,1743
1
+ tico/__init__.py,sha256=7Bu_kNw98Z5aDMnphDfqKQ00FlgKKT8iym416i5oNBI,1743
2
2
  tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
3
3
  tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
4
4
  tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
@@ -61,6 +61,7 @@ tico/interpreter/infer.py,sha256=1ZFe3DVMR2mlwBosoedqoL0-CGN_01CKLgMgxuw62KA,486
61
61
  tico/interpreter/interpreter.py,sha256=tGbluCbrehTCqBu8mtGDNzby_ieJ2ry8_RH_eC0CQxk,3828
62
62
  tico/passes/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
63
63
  tico/passes/cast_aten_where_arg_type.py,sha256=ybtGj1L7_2zGyfb_G-_y1N1mRgKHVq6fBZc-9-fH9sA,7229
64
+ tico/passes/cast_clamp_mixed_type_args.py,sha256=m3_HpXLywWmWERfE5lM5PgvjBod7C4BWu_Q-TkRyO8k,5387
64
65
  tico/passes/cast_mixed_type_args.py,sha256=ArJpIPnQP1LPNaWIwee13Nbj749_awFKDO-pAYvdYvI,7618
65
66
  tico/passes/const_prop_pass.py,sha256=QOeR2u3fo9ZhWXRhfAUW1dTtuWgqgoqdDJoQ516UDbQ,11532
66
67
  tico/passes/convert_conv1d_to_conv2d.py,sha256=7YljWJQBX5vBUMgGgRv8TvbJ9UpEL9hf4ZU3dNUhEZ8,5301
@@ -84,7 +85,7 @@ tico/passes/lower_pow2_to_mul.py,sha256=nfJXa9ZTZMiLg6ownSyvkM4KF2z9tZW34Q3CCWI_
84
85
  tico/passes/lower_to_resize_nearest_neighbor.py,sha256=N6F56Of8Aiv-KIiYLHnh33WX72W60ZVQSBEYWHdYqNQ,9005
85
86
  tico/passes/lower_to_slice.py,sha256=0qAX3WzZdyMFDW4DiO9b5JFXd4rL1-0doBT6lJvaw_I,7260
86
87
  tico/passes/merge_consecutive_cat.py,sha256=BYmiU170DsrHQMj7gMe7U6ZpndrX-S4OpvJweDdspec,2701
87
- tico/passes/ops.py,sha256=XzaKC_FpsfJLpnU4JlL9X-HVarWKm8cX0iiRgx9bMOs,2909
88
+ tico/passes/ops.py,sha256=cSj3Sk2x2cOE9b8oU5pmSa_rHr-iX2lORzu3N_UHMSQ,2967
88
89
  tico/passes/remove_nop.py,sha256=Hf91p_EJAOC6DyWNthash0_UWtEcNc_M7znamQfYQ5Y,2686
89
90
  tico/passes/remove_redundant_assert_nodes.py,sha256=IONd3xBy6I8tH6_Y1eN3_eCHH7WTC8soBgjXzOju9cQ,1612
90
91
  tico/passes/remove_redundant_expand.py,sha256=5SIqN7eIIcqF68tlrB31n1482jSBSBOgKb1wddLX6lw,2197
@@ -182,16 +183,17 @@ tico/serialize/operators/op_view.py,sha256=5EMww-ve17Vm9XPuV03Tn7vJsjpU2J8U4d_FO
182
183
  tico/serialize/operators/op_where.py,sha256=doE81GSwygrPBm3JIfN9w7kKXxeIYKxgk0eoY22QIcg,2845
183
184
  tico/serialize/operators/utils.py,sha256=lXGpEJW1h8U_-gfc6EWjvvSiq3yJ9P-v1v3EMRT_pSk,2954
184
185
  tico/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
185
- tico/utils/convert.py,sha256=11Ps0i4-3Fmcts_PZO5Eo8rRL7FcLV33eHz_G97MnCg,12865
186
+ tico/utils/convert.py,sha256=w4l7fnqbiVACOU5-OXr8Ebyl4EMeeBz6vwUSuOS_CtI,12977
186
187
  tico/utils/define.py,sha256=Ypgp7YffM4pgPl4Zh6TmogSn1OxGBMRw_e09qYGflZk,1467
187
188
  tico/utils/diff_graph.py,sha256=_eDGGPDPYQD4b--MXX0DLoVgSt_wLfNPt47UlolLLR4,5272
188
189
  tico/utils/errors.py,sha256=f3csJjgbXG9W1aHhqEcou008Aor19W57X8oT5Hx8w1M,954
189
190
  tico/utils/graph.py,sha256=Y6aODsnc_-9l61oanknb7K1jqJ8B35iPypOKkM0Qkk0,9149
190
191
  tico/utils/installed_packages.py,sha256=J0FTwnkCGs0MxRWoCMYAqiwH7Z0GWFDLV--x-IndSp4,1017
191
192
  tico/utils/logging.py,sha256=IlbBWscsaHidI0dNqro1HEXAbIcbkR3BD5ukLy2m95k,1286
192
- tico/utils/model.py,sha256=Uqc92AnJXQ2pbvctS2z2F3Ku3yNrwXZ9O33hZVis7is,1250
193
+ tico/utils/model.py,sha256=pPOIjD0qjQirLibiRxxfjOR6efimOcDAd9R-74eus-k,1282
193
194
  tico/utils/padding.py,sha256=jyNhGmlLZfruWZ6n5hll8RZOFg85iCZP8OJqnHGS97g,3293
194
195
  tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
196
+ tico/utils/pytree_utils.py,sha256=jrk3N6X6LiUnBCX_gM1K9nywbVAJBVnszlTAgeIeDUc,5219
195
197
  tico/utils/register_custom_op.py,sha256=3-Yl6iYmx1qQA2igNHt4hYhQhQMkdPb7gF50LIY8yvc,27350
196
198
  tico/utils/serialize.py,sha256=AQXMBOLu-Kg2Rn-qbqsAtHndjZAZIavlKA0QFgJREHM,1420
197
199
  tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
@@ -201,9 +203,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
201
203
  tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
202
204
  tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
203
205
  tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
204
- tico-0.1.0.dev250714.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
205
- tico-0.1.0.dev250714.dist-info/METADATA,sha256=5RPA0JUl-L5MOz1_Ix5-8SZSpn628DHjCpNZHPvrO68,8430
206
- tico-0.1.0.dev250714.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
207
- tico-0.1.0.dev250714.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
208
- tico-0.1.0.dev250714.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
209
- tico-0.1.0.dev250714.dist-info/RECORD,,
206
+ tico-0.1.0.dev250716.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
207
+ tico-0.1.0.dev250716.dist-info/METADATA,sha256=ol8z8iSa9e8WfJcDpebI1h_hXSQdi3jXwIOzIp-wAO4,8430
208
+ tico-0.1.0.dev250716.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
209
+ tico-0.1.0.dev250716.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
210
+ tico-0.1.0.dev250716.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
211
+ tico-0.1.0.dev250716.dist-info/RECORD,,