tico 0.1.0.dev250713__py3-none-any.whl → 0.1.0.dev250715__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tico/__init__.py CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
21
21
  from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
22
22
 
23
23
  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
24
- __version__ = "0.1.0.dev250713"
24
+ __version__ = "0.1.0.dev250715"
25
25
 
26
26
  MINIMUM_SUPPORTED_VERSION = "2.5.0"
27
27
  SECURE_TORCH_VERSION = "2.6.0"
@@ -0,0 +1,134 @@
1
+ import threading
2
+
3
+ import torch
4
+ from packaging.version import Version
5
+
6
+ from tico.utils import logging
7
+ from tico.utils.installed_packages import is_transformers_installed
8
+
9
+ __all__ = ["register_dynamic_cache"]
10
+
11
+
12
def register_dynamic_cache():
    """Register DynamicCache as a PyTree node via the shared PyTreeRegistryHelper."""
    helper = PyTreeRegistryHelper()
    helper.register_dynamic_cache()
14
+
15
+
16
class PyTreeRegistryHelper:
    """
    Thread-safe singleton helper class for registering custom PyTree nodes.

    This class registers DynamicCache as a PyTree node for torch.export
    compatibility. The registration is only performed for transformers
    versions below 4.50.0.

    Thread Safety:
    - A class-level threading.Lock() guards creation of the singleton instance
    - The same lock serializes calls to register_dynamic_cache()
    """

    _instance = None  # The singleton instance, created lazily by __new__
    _has_called = False  # True once registration completed (or was not needed)
    _lock = threading.Lock()  # Guards both _instance and _has_called

    def __init__(self):
        """Intentionally empty: all state of the helper lives on the class."""
        pass

    def __new__(cls, *args, **kwargs):
        """
        Return the singleton instance, creating it on first use.

        Returns:
            PyTreeRegistryHelper: The singleton instance of this class
        """
        # Unlocked fast path; the check is repeated under the lock so that two
        # threads racing on the very first call cannot both create an instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def register_dynamic_cache(self):
        """
        Registers DynamicCache as a PyTree node for torch.export compatibility.

        This method is thread-safe and idempotent: once a call has completed,
        later calls return immediately. A call that raises is not counted as
        completed, so the registration can be attempted again.

        Note:
            The registration is only performed for transformers versions below 4.50.0.

        Raises:
            ImportError: If transformers package is not installed
        """
        state = type(self)
        with state._lock:  # Serialize concurrent registration attempts
            logger = logging.getLogger(__name__)
            if state._has_called:
                logger.debug("register_dynamic_cache already called, skipping")
                return

            logger.info("Registering DynamicCache PyTree node")

            # is_transformers_installed is a function and has to be called.
            # Testing the function object itself is always truthy, which made
            # this guard unreachable.
            if not is_transformers_installed():
                raise ImportError("transformers package is not installed")

            import transformers

            if Version(transformers.__version__) < Version("4.50.0"):
                self._register_dynamic_cache_pytree()

            # Recorded only after everything above succeeded, so a failed
            # attempt does not turn every later call into a silent no-op.
            state._has_called = True

    @staticmethod
    def _register_dynamic_cache_pytree():
        """Install the DynamicCache flatten/unflatten functions into torch."""
        from transformers.cache_utils import DynamicCache

        # Parsed once at registration time instead of on every flatten call.
        has_torch_2_6_0 = Version(torch.__version__) >= Version("2.6.0")

        def _cache_as_dict(cache):
            # The two attributes that are exposed to torch as PyTree children.
            return {
                "key_cache": cache.key_cache,
                "value_cache": cache.value_cache,
            }

        def _flatten_dynamic_cache(dynamic_cache: DynamicCache):
            if not isinstance(dynamic_cache, DynamicCache):
                raise RuntimeError(
                    "This pytree flattening function should only be applied to DynamicCache"
                )
            if not has_torch_2_6_0:
                logger = logging.getLogger(__name__)
                # NOTE(review): warning_once is not part of the standard
                # logging.Logger API - confirm that the logger returned here
                # provides it.
                logger.warning_once(
                    "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions."
                )
            return torch.utils._pytree._dict_flatten(_cache_as_dict(dynamic_cache))

        def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache):
            return torch.utils._pytree._dict_flatten_with_keys(
                _cache_as_dict(dynamic_cache)
            )

        def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
            dictionary = torch.utils._pytree._dict_unflatten(values, context)
            cache = DynamicCache()
            # Restore every flattened attribute on a fresh cache object.
            for k, v in dictionary.items():
                setattr(cache, k, v)
            return cache

        def _flatten_dynamic_cache_for_fx(cache, spec):
            return torch.fx._pytree._dict_flatten_spec(_cache_as_dict(cache), spec)

        torch.utils._pytree.register_pytree_node(
            DynamicCache,
            _flatten_dynamic_cache,
            _unflatten_dynamic_cache,
            serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
            flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
        )
        # TODO: This won't be needed in torch 2.7+.
        torch.fx._pytree.register_pytree_flatten_spec(
            DynamicCache, _flatten_dynamic_cache_for_fx
        )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tico
3
- Version: 0.1.0.dev250713
3
+ Version: 0.1.0.dev250715
4
4
  Summary: Convert exported Torch module to circle
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- tico/__init__.py,sha256=LDO-ry5VXGAZCkBH5Tb-m5CIIBwYSY1NqxdYJ6EtezQ,1743
1
+ tico/__init__.py,sha256=MYWB0f9ftIZXXj1q1Sdv4Qn0EGgO27twfOAD_gDGNVQ,1743
2
2
  tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
3
3
  tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
4
4
  tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
@@ -192,6 +192,7 @@ tico/utils/logging.py,sha256=IlbBWscsaHidI0dNqro1HEXAbIcbkR3BD5ukLy2m95k,1286
192
192
  tico/utils/model.py,sha256=Uqc92AnJXQ2pbvctS2z2F3Ku3yNrwXZ9O33hZVis7is,1250
193
193
  tico/utils/padding.py,sha256=jyNhGmlLZfruWZ6n5hll8RZOFg85iCZP8OJqnHGS97g,3293
194
194
  tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
195
+ tico/utils/pytree_utils.py,sha256=jrk3N6X6LiUnBCX_gM1K9nywbVAJBVnszlTAgeIeDUc,5219
195
196
  tico/utils/register_custom_op.py,sha256=3-Yl6iYmx1qQA2igNHt4hYhQhQMkdPb7gF50LIY8yvc,27350
196
197
  tico/utils/serialize.py,sha256=AQXMBOLu-Kg2Rn-qbqsAtHndjZAZIavlKA0QFgJREHM,1420
197
198
  tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
@@ -201,9 +202,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
201
202
  tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
202
203
  tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
203
204
  tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
204
- tico-0.1.0.dev250713.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
205
- tico-0.1.0.dev250713.dist-info/METADATA,sha256=FhV3CCFUbKieEpdRVt0MvAOV6sOVV-h3aKPBsrQttHY,8430
206
- tico-0.1.0.dev250713.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
207
- tico-0.1.0.dev250713.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
208
- tico-0.1.0.dev250713.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
209
- tico-0.1.0.dev250713.dist-info/RECORD,,
205
+ tico-0.1.0.dev250715.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
206
+ tico-0.1.0.dev250715.dist-info/METADATA,sha256=XNqsTtUt8jSqU2EsY3sm3RJDKwRGerDxGkH9eMXwOQk,8430
207
+ tico-0.1.0.dev250715.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
208
+ tico-0.1.0.dev250715.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
209
+ tico-0.1.0.dev250715.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
210
+ tico-0.1.0.dev250715.dist-info/RECORD,,