tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev250715__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/utils/pytree_utils.py +134 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev250715.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev250715.dist-info}/RECORD +8 -7
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev250715.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev250715.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev250715.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev250715.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
|
|
21
21
|
from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
|
22
22
|
|
23
23
|
# THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
|
24
|
-
__version__ = "0.1.0.
|
24
|
+
__version__ = "0.1.0.dev250715"
|
25
25
|
|
26
26
|
MINIMUM_SUPPORTED_VERSION = "2.5.0"
|
27
27
|
SECURE_TORCH_VERSION = "2.6.0"
|
@@ -0,0 +1,134 @@
|
|
1
|
+
import threading
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from packaging.version import Version
|
5
|
+
|
6
|
+
from tico.utils import logging
|
7
|
+
from tico.utils.installed_packages import is_transformers_installed
|
8
|
+
|
9
|
+
__all__ = ["register_dynamic_cache"]
|
10
|
+
|
11
|
+
|
12
|
+
def register_dynamic_cache():
    """Register ``DynamicCache`` as a PyTree node via the shared helper.

    Convenience entry point: obtains the ``PyTreeRegistryHelper`` singleton and
    delegates to its ``register_dynamic_cache`` method, which skips the work on
    repeat calls.
    """
    helper = PyTreeRegistryHelper()
    helper.register_dynamic_cache()
|
14
|
+
|
15
|
+
|
16
|
+
class PyTreeRegistryHelper:
    """
    Thread-safe singleton helper class for registering custom PyTree nodes.

    This class provides functionality to register DynamicCache as a PyTree node
    for torch.export compatibility. This registration is only needed for
    transformers versions below 4.50.0.

    Thread Safety:
    - Uses a class-level threading.Lock() to ensure thread-safe singleton instantiation
    - Uses the same lock to protect the registration process from concurrent calls
    """

    _instance = None  # Class variable to hold the singleton instance
    # True only once registration has completed (or was found unnecessary);
    # a call that raises leaves it False so the failure is not masked later.
    _has_called = False
    _lock = threading.Lock()  # Class-level lock for thread-safe operations

    def __init__(self):
        """No-op initializer.

        ``__new__`` always returns the shared instance, so this runs on every
        ``PyTreeRegistryHelper()`` call; it must not reset any state.
        """

    def __new__(cls, *args, **kwargs):
        """
        Return the process-wide singleton, creating it on first use.

        Double-checked locking: the unlocked ``is None`` test keeps the common
        path lock-free, and the re-check under the lock stops two threads from
        both creating an instance.

        Returns:
            PyTreeRegistryHelper: The singleton instance of this class
        """
        if cls._instance is None:
            with cls._lock:  # Acquire lock for thread-safe instantiation
                if cls._instance is None:  # Double-check after acquiring lock
                    cls._instance = super().__new__(cls)
        return cls._instance

    def register_dynamic_cache(self):
        """
        Registers DynamicCache as a PyTree node for torch.export compatibility.

        Thread-safe and idempotent: the work runs under the class lock and is
        skipped once a previous call has completed successfully. A call that
        raises does not mark the helper as done, so a later call raises again
        instead of silently returning.

        Note:
            This registration is only needed for transformers versions below
            4.50.0; on newer versions the call returns without registering.

        Raises:
            ImportError: If transformers package is not installed
        """
        cls = type(self)
        logger = logging.getLogger(__name__)

        with cls._lock:  # Acquire lock for thread-safe registration
            if cls._has_called:
                logger.debug("register_dynamic_cache already called, skipping")
                return

            # `is_transformers_installed` must be *called*: a bare function
            # object is always truthy, so testing it directly never fails.
            if not is_transformers_installed():
                raise ImportError("transformers package is not installed")

            import transformers

            if Version(transformers.__version__) >= Version("4.50.0"):
                # Registration is only needed below 4.50.0 (see class docstring).
                cls._has_called = True
                return

            logger.info("Registering DynamicCache PyTree node")

            from transformers.cache_utils import DynamicCache

            # The torch version cannot change at runtime; compare it once here
            # rather than on every flatten.
            has_torch_2_6_0 = Version(torch.__version__) >= Version("2.6.0")

            def _cache_to_dict(cache):
                # Single definition of which DynamicCache fields become pytree
                # children; shared by all flatten variants below.
                return {
                    "key_cache": getattr(cache, "key_cache"),
                    "value_cache": getattr(cache, "value_cache"),
                }

            def _flatten_dynamic_cache(dynamic_cache: DynamicCache):
                if not isinstance(dynamic_cache, DynamicCache):
                    raise RuntimeError(
                        "This pytree flattening function should only be applied to DynamicCache"
                    )
                if not has_torch_2_6_0:
                    logger.warning_once(
                        "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions."
                    )
                return torch.utils._pytree._dict_flatten(_cache_to_dict(dynamic_cache))

            def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache):
                return torch.utils._pytree._dict_flatten_with_keys(
                    _cache_to_dict(dynamic_cache)
                )

            def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
                # Rebuild an empty cache, then restore each flattened field onto it.
                dictionary = torch.utils._pytree._dict_unflatten(values, context)
                cache = DynamicCache()
                for k, v in dictionary.items():
                    setattr(cache, k, v)
                return cache

            def _flatten_dynamic_cache_for_fx(cache, spec):
                return torch.fx._pytree._dict_flatten_spec(_cache_to_dict(cache), spec)

            torch.utils._pytree.register_pytree_node(
                DynamicCache,
                _flatten_dynamic_cache,
                _unflatten_dynamic_cache,
                serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
                flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
            )
            # TODO: This won't be needed in torch 2.7+.
            torch.fx._pytree.register_pytree_flatten_spec(
                DynamicCache, _flatten_dynamic_cache_for_fx
            )

            # Mark done only after both registrations succeeded, so an error
            # above is not hidden by a silent skip on the next call.
            cls._has_called = True
|
@@ -1,4 +1,4 @@
|
|
1
|
-
tico/__init__.py,sha256=
|
1
|
+
tico/__init__.py,sha256=MYWB0f9ftIZXXj1q1Sdv4Qn0EGgO27twfOAD_gDGNVQ,1743
|
2
2
|
tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
|
3
3
|
tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
|
4
4
|
tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
|
@@ -192,6 +192,7 @@ tico/utils/logging.py,sha256=IlbBWscsaHidI0dNqro1HEXAbIcbkR3BD5ukLy2m95k,1286
|
|
192
192
|
tico/utils/model.py,sha256=Uqc92AnJXQ2pbvctS2z2F3Ku3yNrwXZ9O33hZVis7is,1250
|
193
193
|
tico/utils/padding.py,sha256=jyNhGmlLZfruWZ6n5hll8RZOFg85iCZP8OJqnHGS97g,3293
|
194
194
|
tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
|
195
|
+
tico/utils/pytree_utils.py,sha256=jrk3N6X6LiUnBCX_gM1K9nywbVAJBVnszlTAgeIeDUc,5219
|
195
196
|
tico/utils/register_custom_op.py,sha256=3-Yl6iYmx1qQA2igNHt4hYhQhQMkdPb7gF50LIY8yvc,27350
|
196
197
|
tico/utils/serialize.py,sha256=AQXMBOLu-Kg2Rn-qbqsAtHndjZAZIavlKA0QFgJREHM,1420
|
197
198
|
tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
|
@@ -201,9 +202,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
|
201
202
|
tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
|
202
203
|
tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
|
203
204
|
tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
|
204
|
-
tico-0.1.0.
|
205
|
-
tico-0.1.0.
|
206
|
-
tico-0.1.0.
|
207
|
-
tico-0.1.0.
|
208
|
-
tico-0.1.0.
|
209
|
-
tico-0.1.0.
|
205
|
+
tico-0.1.0.dev250715.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
|
206
|
+
tico-0.1.0.dev250715.dist-info/METADATA,sha256=XNqsTtUt8jSqU2EsY3sm3RJDKwRGerDxGkH9eMXwOQk,8430
|
207
|
+
tico-0.1.0.dev250715.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
|
208
|
+
tico-0.1.0.dev250715.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
|
209
|
+
tico-0.1.0.dev250715.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
|
210
|
+
tico-0.1.0.dev250715.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|