cache-dit 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +3 -6
- cache_dit/cache_factory/block_adapters/block_adapters.py +8 -64
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +47 -14
- cache_dit/cache_factory/cache_contexts/__init__.py +10 -8
- cache_dit/cache_factory/cache_contexts/cache_context.py +186 -117
- cache_dit/cache_factory/cache_contexts/cache_manager.py +63 -131
- cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +132 -0
- cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/foca.py +1 -1
- cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/taylorseer.py +7 -2
- cache_dit/cache_factory/cache_interface.py +128 -111
- cache_dit/cache_factory/params_modifier.py +87 -0
- cache_dit/metrics/__init__.py +3 -1
- cache_dit/utils.py +12 -21
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/METADATA +78 -64
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/RECORD +23 -28
- cache_dit/cache_factory/cache_adapters/v2/__init__.py +0 -3
- cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +0 -524
- cache_dit/cache_factory/cache_contexts/taylorseer.py +0 -102
- cache_dit/cache_factory/cache_contexts/v2/__init__.py +0 -13
- cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +0 -288
- cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +0 -799
- cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +0 -81
- /cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/base.py +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
|
@@ -18,6 +18,7 @@ from cache_dit.cache_factory import BlockAdapter
|
|
|
18
18
|
from cache_dit.cache_factory import ParamsModifier
|
|
19
19
|
from cache_dit.cache_factory import ForwardPattern
|
|
20
20
|
from cache_dit.cache_factory import PatchFunctor
|
|
21
|
+
from cache_dit.cache_factory import BasicCacheConfig
|
|
21
22
|
from cache_dit.cache_factory import CalibratorConfig
|
|
22
23
|
from cache_dit.cache_factory import TaylorSeerCalibratorConfig
|
|
23
24
|
from cache_dit.cache_factory import FoCaCalibratorConfig
|
cache_dit/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.3.1'
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 3, 1)
|
|
31
|
+
__version__ = version = '0.3.2'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 3, 2)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -3,25 +3,22 @@ from cache_dit.cache_factory.cache_types import cache_type
|
|
|
3
3
|
from cache_dit.cache_factory.cache_types import block_range
|
|
4
4
|
|
|
5
5
|
from cache_dit.cache_factory.forward_pattern import ForwardPattern
|
|
6
|
-
|
|
6
|
+
from cache_dit.cache_factory.params_modifier import ParamsModifier
|
|
7
7
|
from cache_dit.cache_factory.patch_functors import PatchFunctor
|
|
8
8
|
|
|
9
9
|
from cache_dit.cache_factory.block_adapters import BlockAdapter
|
|
10
|
-
from cache_dit.cache_factory.block_adapters import ParamsModifier
|
|
11
10
|
from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
|
|
12
11
|
|
|
13
12
|
from cache_dit.cache_factory.cache_contexts import CachedContext
|
|
13
|
+
from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
|
|
14
14
|
from cache_dit.cache_factory.cache_contexts import CachedContextManager
|
|
15
|
-
from cache_dit.cache_factory.cache_contexts import CachedContextV2
|
|
16
|
-
from cache_dit.cache_factory.cache_contexts import CachedContextManagerV2
|
|
17
|
-
from cache_dit.cache_factory.cache_contexts import CalibratorConfig # no V1
|
|
15
|
+
from cache_dit.cache_factory.cache_contexts import CalibratorConfig
|
|
18
16
|
from cache_dit.cache_factory.cache_contexts import TaylorSeerCalibratorConfig
|
|
19
17
|
from cache_dit.cache_factory.cache_contexts import FoCaCalibratorConfig
|
|
20
18
|
|
|
21
19
|
from cache_dit.cache_factory.cache_blocks import CachedBlocks
|
|
22
20
|
|
|
23
21
|
from cache_dit.cache_factory.cache_adapters import CachedAdapter
|
|
24
|
-
from cache_dit.cache_factory.cache_adapters import CachedAdapterV2
|
|
25
22
|
|
|
26
23
|
from cache_dit.cache_factory.cache_interface import enable_cache
|
|
27
24
|
from cache_dit.cache_factory.cache_interface import disable_cache
|
|
@@ -7,73 +7,15 @@ from collections.abc import Iterable
|
|
|
7
7
|
from typing import Any, Tuple, List, Optional, Union
|
|
8
8
|
|
|
9
9
|
from diffusers import DiffusionPipeline
|
|
10
|
-
from cache_dit.cache_factory.forward_pattern import ForwardPattern
|
|
11
10
|
from cache_dit.cache_factory.patch_functors import PatchFunctor
|
|
12
|
-
from cache_dit.cache_factory.
|
|
11
|
+
from cache_dit.cache_factory.forward_pattern import ForwardPattern
|
|
12
|
+
from cache_dit.cache_factory.params_modifier import ParamsModifier
|
|
13
13
|
|
|
14
14
|
from cache_dit.logger import init_logger
|
|
15
15
|
|
|
16
16
|
logger = init_logger(__name__)
|
|
17
17
|
|
|
18
18
|
|
|
19
|
-
class ParamsModifier:
|
|
20
|
-
def __init__(
|
|
21
|
-
self,
|
|
22
|
-
# Cache context kwargs
|
|
23
|
-
Fn_compute_blocks: Optional[int] = None,
|
|
24
|
-
Bn_compute_blocks: Optional[int] = None,
|
|
25
|
-
max_warmup_steps: Optional[int] = None,
|
|
26
|
-
max_cached_steps: Optional[int] = None,
|
|
27
|
-
max_continuous_cached_steps: Optional[int] = None,
|
|
28
|
-
residual_diff_threshold: Optional[float] = None,
|
|
29
|
-
# Cache CFG or not
|
|
30
|
-
enable_separate_cfg: Optional[bool] = None,
|
|
31
|
-
cfg_compute_first: Optional[bool] = None,
|
|
32
|
-
cfg_diff_compute_separate: Optional[bool] = None,
|
|
33
|
-
# Hybird TaylorSeer
|
|
34
|
-
enable_taylorseer: Optional[bool] = None,
|
|
35
|
-
enable_encoder_taylorseer: Optional[bool] = None,
|
|
36
|
-
taylorseer_cache_type: Optional[str] = None,
|
|
37
|
-
taylorseer_order: Optional[int] = None,
|
|
38
|
-
# New param only for v2 API
|
|
39
|
-
calibrator_config: Optional[CalibratorConfig] = None,
|
|
40
|
-
**other_cache_context_kwargs,
|
|
41
|
-
):
|
|
42
|
-
self._context_kwargs = other_cache_context_kwargs.copy()
|
|
43
|
-
self._maybe_update_param("Fn_compute_blocks", Fn_compute_blocks)
|
|
44
|
-
self._maybe_update_param("Bn_compute_blocks", Bn_compute_blocks)
|
|
45
|
-
self._maybe_update_param("max_warmup_steps", max_warmup_steps)
|
|
46
|
-
self._maybe_update_param("max_cached_steps", max_cached_steps)
|
|
47
|
-
self._maybe_update_param(
|
|
48
|
-
"max_continuous_cached_steps", max_continuous_cached_steps
|
|
49
|
-
)
|
|
50
|
-
self._maybe_update_param(
|
|
51
|
-
"residual_diff_threshold", residual_diff_threshold
|
|
52
|
-
)
|
|
53
|
-
self._maybe_update_param("enable_separate_cfg", enable_separate_cfg)
|
|
54
|
-
self._maybe_update_param("cfg_compute_first", cfg_compute_first)
|
|
55
|
-
self._maybe_update_param(
|
|
56
|
-
"cfg_diff_compute_separate", cfg_diff_compute_separate
|
|
57
|
-
)
|
|
58
|
-
# V1 only supports the Taylorseer calibrator. We have decided to
|
|
59
|
-
# keep this code for API compatibility reasons.
|
|
60
|
-
if calibrator_config is None:
|
|
61
|
-
self._maybe_update_param("enable_taylorseer", enable_taylorseer)
|
|
62
|
-
self._maybe_update_param(
|
|
63
|
-
"enable_encoder_taylorseer", enable_encoder_taylorseer
|
|
64
|
-
)
|
|
65
|
-
self._maybe_update_param(
|
|
66
|
-
"taylorseer_cache_type", taylorseer_cache_type
|
|
67
|
-
)
|
|
68
|
-
self._maybe_update_param("taylorseer_order", taylorseer_order)
|
|
69
|
-
else:
|
|
70
|
-
self._maybe_update_param("calibrator_config", calibrator_config)
|
|
71
|
-
|
|
72
|
-
def _maybe_update_param(self, key: str, value: Any):
|
|
73
|
-
if value is not None:
|
|
74
|
-
self._context_kwargs[key] = value
|
|
75
|
-
|
|
76
|
-
|
|
77
19
|
@dataclasses.dataclass
|
|
78
20
|
class BlockAdapter:
|
|
79
21
|
|
|
@@ -123,10 +65,12 @@ class BlockAdapter:
|
|
|
123
65
|
] = None
|
|
124
66
|
|
|
125
67
|
# modify cache context params for specific blocks.
|
|
126
|
-
params_modifiers:
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
68
|
+
params_modifiers: Optional[
|
|
69
|
+
Union[
|
|
70
|
+
ParamsModifier,
|
|
71
|
+
List[ParamsModifier],
|
|
72
|
+
List[List[ParamsModifier]],
|
|
73
|
+
]
|
|
130
74
|
] = None
|
|
131
75
|
|
|
132
76
|
check_forward_pattern: bool = True
|
|
@@ -13,6 +13,8 @@ from cache_dit.cache_factory.block_adapters import BlockAdapter
|
|
|
13
13
|
from cache_dit.cache_factory.block_adapters import ParamsModifier
|
|
14
14
|
from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
|
|
15
15
|
from cache_dit.cache_factory.cache_contexts import CachedContextManager
|
|
16
|
+
from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
|
|
17
|
+
from cache_dit.cache_factory.cache_contexts import CalibratorConfig
|
|
16
18
|
from cache_dit.cache_factory.cache_blocks import CachedBlocks
|
|
17
19
|
from cache_dit.cache_factory.cache_blocks.utils import (
|
|
18
20
|
patch_cached_stats,
|
|
@@ -55,6 +57,12 @@ class CachedAdapter:
|
|
|
55
57
|
block_adapter = BlockAdapterRegistry.get_adapter(
|
|
56
58
|
pipe_or_adapter
|
|
57
59
|
)
|
|
60
|
+
if params_modifiers := cache_context_kwargs.pop(
|
|
61
|
+
"params_modifiers",
|
|
62
|
+
None,
|
|
63
|
+
):
|
|
64
|
+
block_adapter.params_modifiers = params_modifiers
|
|
65
|
+
|
|
58
66
|
return cls.cachify(
|
|
59
67
|
block_adapter,
|
|
60
68
|
**cache_context_kwargs,
|
|
@@ -69,6 +77,12 @@ class CachedAdapter:
|
|
|
69
77
|
logger.info(
|
|
70
78
|
"Adapting Cache Acceleration using custom BlockAdapter!"
|
|
71
79
|
)
|
|
80
|
+
if pipe_or_adapter.params_modifiers is None:
|
|
81
|
+
if params_modifiers := cache_context_kwargs.pop(
|
|
82
|
+
"params_modifiers", None
|
|
83
|
+
):
|
|
84
|
+
pipe_or_adapter.params_modifiers = params_modifiers
|
|
85
|
+
|
|
72
86
|
return cls.cachify(
|
|
73
87
|
pipe_or_adapter,
|
|
74
88
|
**cache_context_kwargs,
|
|
@@ -114,33 +128,36 @@ class CachedAdapter:
|
|
|
114
128
|
**cache_context_kwargs,
|
|
115
129
|
):
|
|
116
130
|
# Check cache_context_kwargs
|
|
117
|
-
|
|
131
|
+
cache_config: BasicCacheConfig = cache_context_kwargs[
|
|
132
|
+
"cache_config"
|
|
133
|
+
] # ref
|
|
134
|
+
assert cache_config is not None, "cache_config can not be None."
|
|
135
|
+
if cache_config.enable_separate_cfg is None:
|
|
118
136
|
# Check cfg for some specific case if users don't set it as True
|
|
119
137
|
if BlockAdapterRegistry.has_separate_cfg(block_adapter):
|
|
120
|
-
|
|
138
|
+
cache_config.enable_separate_cfg = True
|
|
121
139
|
logger.info(
|
|
122
140
|
f"Use custom 'enable_separate_cfg' from BlockAdapter: True. "
|
|
123
141
|
f"Pipeline: {block_adapter.pipe.__class__.__name__}."
|
|
124
142
|
)
|
|
125
143
|
else:
|
|
126
|
-
|
|
144
|
+
cache_config.enable_separate_cfg = (
|
|
127
145
|
BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
|
|
128
146
|
)
|
|
129
147
|
logger.info(
|
|
130
148
|
f"Use default 'enable_separate_cfg' from block adapter "
|
|
131
|
-
f"register: {
|
|
149
|
+
f"register: {cache_config.enable_separate_cfg}, "
|
|
132
150
|
f"Pipeline: {block_adapter.pipe.__class__.__name__}."
|
|
133
151
|
)
|
|
134
152
|
else:
|
|
135
153
|
logger.info(
|
|
136
154
|
f"Use custom 'enable_separate_cfg' from cache context "
|
|
137
|
-
f"kwargs: {
|
|
155
|
+
f"kwargs: {cache_config.enable_separate_cfg}. "
|
|
138
156
|
f"Pipeline: {block_adapter.pipe.__class__.__name__}."
|
|
139
157
|
)
|
|
140
158
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
) is not None:
|
|
159
|
+
cache_type = cache_context_kwargs.pop("cache_type", None)
|
|
160
|
+
if cache_type is not None:
|
|
144
161
|
assert (
|
|
145
162
|
cache_type == CacheType.DBCache
|
|
146
163
|
), "Custom cache setting only support for DBCache now!"
|
|
@@ -176,7 +193,7 @@ class CachedAdapter:
|
|
|
176
193
|
block_adapter.pipe._cache_manager = cache_manager # instance level
|
|
177
194
|
|
|
178
195
|
flatten_contexts, contexts_kwargs = cls.modify_context_params(
|
|
179
|
-
block_adapter,
|
|
196
|
+
block_adapter, **cache_context_kwargs
|
|
180
197
|
)
|
|
181
198
|
|
|
182
199
|
original_call = block_adapter.pipe.__class__.__call__
|
|
@@ -212,7 +229,6 @@ class CachedAdapter:
|
|
|
212
229
|
def modify_context_params(
|
|
213
230
|
cls,
|
|
214
231
|
block_adapter: BlockAdapter,
|
|
215
|
-
cache_manager: CachedContextManager,
|
|
216
232
|
**cache_context_kwargs,
|
|
217
233
|
) -> Tuple[List[str], List[Dict[str, Any]]]:
|
|
218
234
|
|
|
@@ -230,6 +246,8 @@ class CachedAdapter:
|
|
|
230
246
|
contexts_kwargs[i]["name"] = flatten_contexts[i]
|
|
231
247
|
|
|
232
248
|
if block_adapter.params_modifiers is None:
|
|
249
|
+
for i in range(len(contexts_kwargs)):
|
|
250
|
+
cls._config_messages(**contexts_kwargs[i])
|
|
233
251
|
return flatten_contexts, contexts_kwargs
|
|
234
252
|
|
|
235
253
|
flatten_modifiers: List[ParamsModifier] = BlockAdapter.flatten(
|
|
@@ -242,12 +260,26 @@ class CachedAdapter:
|
|
|
242
260
|
contexts_kwargs[i].update(
|
|
243
261
|
flatten_modifiers[i]._context_kwargs,
|
|
244
262
|
)
|
|
245
|
-
contexts_kwargs[i]
|
|
246
|
-
default_attrs={}, **contexts_kwargs[i]
|
|
247
|
-
)
|
|
263
|
+
cls._config_messages(**contexts_kwargs[i])
|
|
248
264
|
|
|
249
265
|
return flatten_contexts, contexts_kwargs
|
|
250
266
|
|
|
267
|
+
@classmethod
|
|
268
|
+
def _config_messages(cls, **contexts_kwargs):
|
|
269
|
+
cache_config: BasicCacheConfig = contexts_kwargs.get(
|
|
270
|
+
"cache_config", None
|
|
271
|
+
)
|
|
272
|
+
calibrator_config: CalibratorConfig = contexts_kwargs.get(
|
|
273
|
+
"calibrator_config", None
|
|
274
|
+
)
|
|
275
|
+
if cache_config is not None:
|
|
276
|
+
message = f"Collected Cache Config: {cache_config.strify()}"
|
|
277
|
+
if calibrator_config is not None:
|
|
278
|
+
message += f", Calibrator Config: {calibrator_config.strify(details=True)}"
|
|
279
|
+
else:
|
|
280
|
+
message += ", Calibrator Config: None"
|
|
281
|
+
logger.info(message)
|
|
282
|
+
|
|
251
283
|
@classmethod
|
|
252
284
|
def mock_blocks(
|
|
253
285
|
cls,
|
|
@@ -335,7 +367,8 @@ class CachedAdapter:
|
|
|
335
367
|
total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
|
|
336
368
|
assert hasattr(block_adapter.pipe, "_cache_manager")
|
|
337
369
|
assert isinstance(
|
|
338
|
-
block_adapter.pipe._cache_manager,
|
|
370
|
+
block_adapter.pipe._cache_manager,
|
|
371
|
+
CachedContextManager,
|
|
339
372
|
)
|
|
340
373
|
|
|
341
374
|
for i in range(len(block_adapter.transformer)):
|
|
@@ -1,12 +1,14 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
CachedContextManager,
|
|
5
|
-
)
|
|
6
|
-
from cache_dit.cache_factory.cache_contexts.v2 import (
|
|
7
|
-
CachedContextV2,
|
|
8
|
-
CachedContextManagerV2,
|
|
1
|
+
from cache_dit.cache_factory.cache_contexts.calibrators import (
|
|
2
|
+
Calibrator,
|
|
3
|
+
CalibratorBase,
|
|
9
4
|
CalibratorConfig,
|
|
10
5
|
TaylorSeerCalibratorConfig,
|
|
11
6
|
FoCaCalibratorConfig,
|
|
12
7
|
)
|
|
8
|
+
from cache_dit.cache_factory.cache_contexts.cache_context import (
|
|
9
|
+
CachedContext,
|
|
10
|
+
BasicCacheConfig,
|
|
11
|
+
)
|
|
12
|
+
from cache_dit.cache_factory.cache_contexts.cache_manager import (
|
|
13
|
+
CachedContextManager,
|
|
14
|
+
)
|