cache-dit 0.2.37__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/__init__.py +3 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +7 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +16 -6
- cache_dit/cache_factory/cache_adapters/__init__.py +2 -0
- cache_dit/cache_factory/{cache_adapters.py → cache_adapters/cache_adapter.py} +6 -6
- cache_dit/cache_factory/cache_adapters/v2/__init__.py +3 -0
- cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +524 -0
- cache_dit/cache_factory/cache_contexts/__init__.py +7 -0
- cache_dit/cache_factory/cache_contexts/v2/__init__.py +13 -0
- cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +288 -0
- cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +799 -0
- cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +81 -0
- cache_dit/cache_factory/cache_contexts/v2/calibrators/base.py +27 -0
- cache_dit/cache_factory/cache_contexts/v2/calibrators/foca.py +26 -0
- cache_dit/cache_factory/cache_contexts/v2/calibrators/taylorseer.py +105 -0
- cache_dit/cache_factory/cache_interface.py +39 -12
- cache_dit/utils.py +17 -7
- {cache_dit-0.2.37.dist-info → cache_dit-0.3.1.dist-info}/METADATA +57 -42
- {cache_dit-0.2.37.dist-info → cache_dit-0.3.1.dist-info}/RECORD +24 -14
- {cache_dit-0.2.37.dist-info → cache_dit-0.3.1.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.37.dist-info → cache_dit-0.3.1.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.37.dist-info → cache_dit-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.37.dist-info → cache_dit-0.3.1.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
|
@@ -18,6 +18,9 @@ from cache_dit.cache_factory import BlockAdapter
|
|
|
18
18
|
from cache_dit.cache_factory import ParamsModifier
|
|
19
19
|
from cache_dit.cache_factory import ForwardPattern
|
|
20
20
|
from cache_dit.cache_factory import PatchFunctor
|
|
21
|
+
from cache_dit.cache_factory import CalibratorConfig
|
|
22
|
+
from cache_dit.cache_factory import TaylorSeerCalibratorConfig
|
|
23
|
+
from cache_dit.cache_factory import FoCaCalibratorConfig
|
|
21
24
|
from cache_dit.cache_factory import supported_pipelines
|
|
22
25
|
from cache_dit.cache_factory import get_adapter
|
|
23
26
|
from cache_dit.compile import set_compile_configs
|
cache_dit/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.
|
|
32
|
-
__version_tuple__ = version_tuple = (0,
|
|
31
|
+
__version__ = version = '0.3.1'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 3, 1)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -12,9 +12,16 @@ from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
|
|
|
12
12
|
|
|
13
13
|
from cache_dit.cache_factory.cache_contexts import CachedContext
|
|
14
14
|
from cache_dit.cache_factory.cache_contexts import CachedContextManager
|
|
15
|
+
from cache_dit.cache_factory.cache_contexts import CachedContextV2
|
|
16
|
+
from cache_dit.cache_factory.cache_contexts import CachedContextManagerV2
|
|
17
|
+
from cache_dit.cache_factory.cache_contexts import CalibratorConfig # no V1
|
|
18
|
+
from cache_dit.cache_factory.cache_contexts import TaylorSeerCalibratorConfig
|
|
19
|
+
from cache_dit.cache_factory.cache_contexts import FoCaCalibratorConfig
|
|
20
|
+
|
|
15
21
|
from cache_dit.cache_factory.cache_blocks import CachedBlocks
|
|
16
22
|
|
|
17
23
|
from cache_dit.cache_factory.cache_adapters import CachedAdapter
|
|
24
|
+
from cache_dit.cache_factory.cache_adapters import CachedAdapterV2
|
|
18
25
|
|
|
19
26
|
from cache_dit.cache_factory.cache_interface import enable_cache
|
|
20
27
|
from cache_dit.cache_factory.cache_interface import disable_cache
|
|
@@ -9,6 +9,7 @@ from typing import Any, Tuple, List, Optional, Union
|
|
|
9
9
|
from diffusers import DiffusionPipeline
|
|
10
10
|
from cache_dit.cache_factory.forward_pattern import ForwardPattern
|
|
11
11
|
from cache_dit.cache_factory.patch_functors import PatchFunctor
|
|
12
|
+
from cache_dit.cache_factory.cache_contexts import CalibratorConfig
|
|
12
13
|
|
|
13
14
|
from cache_dit.logger import init_logger
|
|
14
15
|
|
|
@@ -34,6 +35,8 @@ class ParamsModifier:
|
|
|
34
35
|
enable_encoder_taylorseer: Optional[bool] = None,
|
|
35
36
|
taylorseer_cache_type: Optional[str] = None,
|
|
36
37
|
taylorseer_order: Optional[int] = None,
|
|
38
|
+
# New param only for v2 API
|
|
39
|
+
calibrator_config: Optional[CalibratorConfig] = None,
|
|
37
40
|
**other_cache_context_kwargs,
|
|
38
41
|
):
|
|
39
42
|
self._context_kwargs = other_cache_context_kwargs.copy()
|
|
@@ -52,12 +55,19 @@ class ParamsModifier:
|
|
|
52
55
|
self._maybe_update_param(
|
|
53
56
|
"cfg_diff_compute_separate", cfg_diff_compute_separate
|
|
54
57
|
)
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
58
|
+
# V1 only supports the Taylorseer calibrator. We have decided to
|
|
59
|
+
# keep this code for API compatibility reasons.
|
|
60
|
+
if calibrator_config is None:
|
|
61
|
+
self._maybe_update_param("enable_taylorseer", enable_taylorseer)
|
|
62
|
+
self._maybe_update_param(
|
|
63
|
+
"enable_encoder_taylorseer", enable_encoder_taylorseer
|
|
64
|
+
)
|
|
65
|
+
self._maybe_update_param(
|
|
66
|
+
"taylorseer_cache_type", taylorseer_cache_type
|
|
67
|
+
)
|
|
68
|
+
self._maybe_update_param("taylorseer_order", taylorseer_order)
|
|
69
|
+
else:
|
|
70
|
+
self._maybe_update_param("calibrator_config", calibrator_config)
|
|
61
71
|
|
|
62
72
|
def _maybe_update_param(self, key: str, value: Any):
|
|
63
73
|
if value is not None:
|
|
@@ -8,12 +8,12 @@ from typing import Dict, List, Tuple, Any, Union, Callable
|
|
|
8
8
|
|
|
9
9
|
from diffusers import DiffusionPipeline
|
|
10
10
|
|
|
11
|
-
from cache_dit.cache_factory import CacheType
|
|
12
|
-
from cache_dit.cache_factory import BlockAdapter
|
|
13
|
-
from cache_dit.cache_factory import ParamsModifier
|
|
14
|
-
from cache_dit.cache_factory import BlockAdapterRegistry
|
|
15
|
-
from cache_dit.cache_factory import CachedContextManager
|
|
16
|
-
from cache_dit.cache_factory import CachedBlocks
|
|
11
|
+
from cache_dit.cache_factory.cache_types import CacheType
|
|
12
|
+
from cache_dit.cache_factory.block_adapters import BlockAdapter
|
|
13
|
+
from cache_dit.cache_factory.block_adapters import ParamsModifier
|
|
14
|
+
from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
|
|
15
|
+
from cache_dit.cache_factory.cache_contexts import CachedContextManager
|
|
16
|
+
from cache_dit.cache_factory.cache_blocks import CachedBlocks
|
|
17
17
|
from cache_dit.cache_factory.cache_blocks.utils import (
|
|
18
18
|
patch_cached_stats,
|
|
19
19
|
remove_cached_stats,
|
|
@@ -0,0 +1,524 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
import unittest
|
|
4
|
+
import functools
|
|
5
|
+
|
|
6
|
+
from contextlib import ExitStack
|
|
7
|
+
from typing import Dict, List, Tuple, Any, Union, Callable
|
|
8
|
+
|
|
9
|
+
from diffusers import DiffusionPipeline
|
|
10
|
+
|
|
11
|
+
from cache_dit.cache_factory.cache_types import CacheType
|
|
12
|
+
from cache_dit.cache_factory.block_adapters import BlockAdapter
|
|
13
|
+
from cache_dit.cache_factory.block_adapters import ParamsModifier
|
|
14
|
+
from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
|
|
15
|
+
from cache_dit.cache_factory.cache_contexts import CachedContextManagerV2
|
|
16
|
+
from cache_dit.cache_factory.cache_blocks import CachedBlocks
|
|
17
|
+
from cache_dit.cache_factory.cache_blocks.utils import (
|
|
18
|
+
patch_cached_stats,
|
|
19
|
+
remove_cached_stats,
|
|
20
|
+
)
|
|
21
|
+
from cache_dit.logger import init_logger
|
|
22
|
+
|
|
23
|
+
logger = init_logger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Unified Cached Adapter
|
|
27
|
+
class CachedAdapterV2:
    """Apply V2 cache acceleration to a diffusers pipeline via monkey-patching.

    Two layers are patched:

    * the pipeline class's ``__call__``, so every inference enters a freshly
      reset cache context (one per unique blocks name) managed by a
      ``CachedContextManagerV2`` stored on the pipeline instance;
    * each transformer instance's ``forward``, so that for the duration of a
      forward pass its block-list attributes are swapped for ``CachedBlocks``
      wrappers (and any "dummy" block lists for an empty ``ModuleList``).

    Every entry point is a classmethod; instances exist only so the adapter can
    be used as a callable (``CachedAdapterV2()(pipe, ...)``).
    """

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`apply`."""
        return self.apply(*args, **kwargs)

    @classmethod
    def apply(
        cls,
        pipe_or_adapter: Union[
            DiffusionPipeline,
            BlockAdapter,
        ],
        **cache_context_kwargs,
    ) -> Union[
        DiffusionPipeline,
        BlockAdapter,
    ]:
        """Enable caching on a registered pipeline or a custom BlockAdapter.

        Args:
            pipe_or_adapter: A ``DiffusionPipeline`` known to
                ``BlockAdapterRegistry``, or a user-built ``BlockAdapter``.
            **cache_context_kwargs: Cache context options forwarded to
                :meth:`cachify`.

        Returns:
            The pipeline itself when a ``DiffusionPipeline`` was passed;
            otherwise the ``BlockAdapter`` returned by :meth:`cachify`.

        Raises:
            ValueError: If a ``DiffusionPipeline`` is passed that the registry
                does not support.
        """
        assert (
            pipe_or_adapter is not None
        ), "pipe or block_adapter can not both None!"

        if isinstance(pipe_or_adapter, DiffusionPipeline):
            if BlockAdapterRegistry.is_supported(pipe_or_adapter):
                logger.info(
                    f"{pipe_or_adapter.__class__.__name__} is officially "
                    "supported by cache-dit. Use it's pre-defined BlockAdapter "
                    "directly!"
                )
                block_adapter = BlockAdapterRegistry.get_adapter(
                    pipe_or_adapter
                )
                return cls.cachify(
                    block_adapter,
                    **cache_context_kwargs,
                ).pipe
            else:
                raise ValueError(
                    f"{pipe_or_adapter.__class__.__name__} is not officially supported "
                    "by cache-dit, please set BlockAdapter instead!"
                )
        else:
            assert isinstance(pipe_or_adapter, BlockAdapter)
            logger.info(
                "Adapting Cache Acceleration using custom BlockAdapter!"
            )
            return cls.cachify(
                pipe_or_adapter,
                **cache_context_kwargs,
            )

    @classmethod
    def cachify(
        cls,
        block_adapter: BlockAdapter,
        **cache_context_kwargs,
    ) -> BlockAdapter:
        """Normalize ``block_adapter`` and install the pipeline/transformer hooks.

        If ``block_adapter.auto`` is set, the adapter is first rebuilt with
        ``BlockAdapter.auto_block_adapter``. When
        ``BlockAdapter.check_block_adapter`` rejects the adapter, it is
        returned without any caching applied.

        Returns:
            The (possibly rebuilt and normalized) ``BlockAdapter``.
        """
        if block_adapter.auto:
            block_adapter = BlockAdapter.auto_block_adapter(
                block_adapter,
            )

        if BlockAdapter.check_block_adapter(block_adapter):

            # 0. Must normalize block_adapter before apply cache
            block_adapter = BlockAdapter.normalize(block_adapter)
            if BlockAdapter.is_cached(block_adapter):
                return block_adapter

            # 1. Apply cache on pipeline: wrap cache context. This must run
            # before mock_blocks, because collect_cached_blocks reads the
            # ``_cache_manager`` that create_context attaches to the pipe.
            cls.create_context(
                block_adapter,
                **cache_context_kwargs,
            )

            # 2. Apply cache on transformer: mock cached blocks
            cls.mock_blocks(
                block_adapter,
            )

        return block_adapter

    @classmethod
    def check_context_kwargs(
        cls,
        block_adapter: BlockAdapter,
        **cache_context_kwargs,
    ):
        """Resolve ``enable_separate_cfg`` and validate ``cache_type``.

        When ``enable_separate_cfg`` is missing or ``None`` it is filled in
        from ``BlockAdapterRegistry.has_separate_cfg``. Any ``cache_type`` key
        is removed from the returned kwargs.

        Returns:
            A new kwargs dict (the caller's dict is not mutated).

        Raises:
            AssertionError: If ``cache_type`` is given and is not
                ``CacheType.DBCache``.
        """
        # ``.get`` rather than ``[...]``: direct callers of ``apply``/``cachify``
        # may omit the key entirely, which should mean "auto-detect" rather
        # than raising KeyError.
        if cache_context_kwargs.get("enable_separate_cfg") is None:
            # Check cfg for some specific case if users don't set it as True
            if BlockAdapterRegistry.has_separate_cfg(block_adapter):
                cache_context_kwargs["enable_separate_cfg"] = True
                logger.info(
                    f"Use custom 'enable_separate_cfg' from BlockAdapter: True. "
                    f"Pipeline: {block_adapter.pipe.__class__.__name__}."
                )
            else:
                cache_context_kwargs["enable_separate_cfg"] = (
                    BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
                )
                logger.info(
                    f"Use default 'enable_separate_cfg' from block adapter "
                    f"register: {cache_context_kwargs['enable_separate_cfg']}, "
                    f"Pipeline: {block_adapter.pipe.__class__.__name__}."
                )
        else:
            logger.info(
                f"Use custom 'enable_separate_cfg' from cache context "
                f"kwargs: {cache_context_kwargs['enable_separate_cfg']}. "
                f"Pipeline: {block_adapter.pipe.__class__.__name__}."
            )

        if (
            cache_type := cache_context_kwargs.pop("cache_type", None)
        ) is not None:
            assert (
                cache_type == CacheType.DBCache
            ), "Custom cache setting only support for DBCache now!"

        return cache_context_kwargs

    @classmethod
    def create_context(
        cls,
        block_adapter: BlockAdapter,
        **cache_context_kwargs,
    ) -> DiffusionPipeline:
        """Wrap the pipeline class's ``__call__`` with per-call cache contexts.

        A new ``CachedContextManagerV2`` is created and stored on the pipeline
        instance as ``_cache_manager``. The patched ``__call__`` resets and
        enters one context per flattened unique blocks name, runs the original
        call, then refreshes the cache stats hooks.

        Returns:
            ``block_adapter.pipe`` (unchanged if it was already cached).
        """
        BlockAdapter.assert_normalized(block_adapter)

        if BlockAdapter.is_cached(block_adapter.pipe):
            return block_adapter.pipe

        # Check cache_context_kwargs
        cache_context_kwargs = cls.check_context_kwargs(
            block_adapter, **cache_context_kwargs
        )
        # Apply cache on pipeline: wrap cache context
        pipe_cls_name = block_adapter.pipe.__class__.__name__

        # Each pipeline should have its own context manager instance.
        # Different transformers (Wan2.2, etc.) should share the same
        # cache manager but with different cache contexts (according
        # to their unique instance id).
        cache_manager = CachedContextManagerV2(
            name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
        )
        block_adapter.pipe._cache_manager = cache_manager  # instance level

        flatten_contexts, contexts_kwargs = cls.modify_context_params(
            block_adapter, cache_manager, **cache_context_kwargs
        )

        original_call = block_adapter.pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with ExitStack() as stack:
                # cache context will be reset for each pipe inference
                for context_name, context_kwargs in zip(
                    flatten_contexts, contexts_kwargs
                ):
                    stack.enter_context(
                        cache_manager.enter_context(
                            cache_manager.reset_context(
                                context_name,
                                **context_kwargs,
                            ),
                        )
                    )
                outputs = original_call(self, *args, **kwargs)
                cls.apply_stats_hooks(block_adapter)
                return outputs

        # NOTE(review): ``__call__`` and ``_is_cached`` are patched on the
        # pipeline *class*, but ``new_call`` closes over this instance's
        # ``cache_manager``/``block_adapter``. A second instance of the same
        # pipeline class would therefore run through the first instance's
        # contexts. Confirm only one cached instance per class is supported.
        block_adapter.pipe.__class__.__call__ = new_call
        block_adapter.pipe.__class__._original_call = original_call
        block_adapter.pipe.__class__._is_cached = True

        cls.apply_params_hooks(block_adapter, contexts_kwargs)

        return block_adapter.pipe

    @classmethod
    def modify_context_params(
        cls,
        block_adapter: BlockAdapter,
        cache_manager: CachedContextManagerV2,
        **cache_context_kwargs,
    ) -> Tuple[List[str], List[Dict[str, Any]]]:
        """Build per-context kwargs, applying any ParamsModifiers.

        Each flattened unique blocks name gets a shallow copy of
        ``cache_context_kwargs`` with ``"name"`` set to that context name. The
        first ``min(#contexts, #modifiers)`` entries are then updated with the
        matching modifier's ``_context_kwargs``. They are re-normalized through
        ``cache_manager.collect_cache_kwargs`` (only its first return value is
        kept).

        Returns:
            ``(flatten_contexts, contexts_kwargs)``, two lists in the same order.
        """
        flatten_contexts = BlockAdapter.flatten(
            block_adapter.unique_blocks_name
        )
        # Shallow copy of the shared kwargs per context, tagged with its name.
        contexts_kwargs = [
            {**cache_context_kwargs, "name": context_name}
            for context_name in flatten_contexts
        ]

        if block_adapter.params_modifiers is None:
            return flatten_contexts, contexts_kwargs

        flatten_modifiers: List[ParamsModifier] = BlockAdapter.flatten(
            block_adapter.params_modifiers,
        )

        # Modifiers are matched to contexts positionally; extra contexts keep
        # the shared kwargs and extra modifiers are ignored.
        for i in range(
            min(len(contexts_kwargs), len(flatten_modifiers)),
        ):
            contexts_kwargs[i].update(
                flatten_modifiers[i]._context_kwargs,
            )
            contexts_kwargs[i], _ = cache_manager.collect_cache_kwargs(
                default_attrs={}, **contexts_kwargs[i]
            )

        return flatten_contexts, contexts_kwargs

    @classmethod
    def mock_blocks(
        cls,
        block_adapter: BlockAdapter,
    ) -> List[torch.nn.Module]:
        """Patch ``forward`` of every transformer in the adapter.

        Returns:
            ``block_adapter.transformer`` (unchanged if already cached).
        """
        BlockAdapter.assert_normalized(block_adapter)

        if BlockAdapter.is_cached(block_adapter.transformer):
            return block_adapter.transformer

        # Apply cache on transformer: mock cached transformer blocks
        for (
            cached_blocks,
            transformer,
            blocks_name,
            unique_blocks_name,
            dummy_blocks_names,
        ) in zip(
            cls.collect_cached_blocks(block_adapter),
            block_adapter.transformer,
            block_adapter.blocks_name,
            block_adapter.unique_blocks_name,
            block_adapter.dummy_blocks_names,
        ):
            cls.mock_transformer(
                cached_blocks,
                transformer,
                blocks_name,
                unique_blocks_name,
                dummy_blocks_names,
            )

        return block_adapter.transformer

    @classmethod
    def mock_transformer(
        cls,
        cached_blocks: Dict[str, torch.nn.ModuleList],
        transformer: torch.nn.Module,
        blocks_name: List[str],
        unique_blocks_name: List[str],
        dummy_blocks_names: List[str],
    ) -> torch.nn.Module:
        """Replace ``transformer.forward`` with a block-swapping wrapper.

        During each forward call, the attribute named ``blocks_name[k]`` is
        temporarily set to ``cached_blocks[unique_blocks_name[k]]``. Every
        attribute in ``dummy_blocks_names`` is set to an empty ``ModuleList``.
        Originals are restored when the call exits. The original bound
        ``forward`` is kept as ``transformer._original_forward``.

        Returns:
            The same ``transformer`` object, now marked ``_is_cached``.
        """
        # ``import unittest`` alone does not load the ``unittest.mock``
        # submodule; import it explicitly rather than relying on another
        # library having imported it first.
        from unittest import mock

        # Empty list: if the original forward iterates a dummy block list,
        # that loop becomes a no-op while patched.
        dummy_blocks = torch.nn.ModuleList()

        original_forward = transformer.forward

        assert isinstance(dummy_blocks_names, list)

        @functools.wraps(original_forward)
        def new_forward(self, *args, **kwargs):
            with ExitStack() as stack:
                for name, context_name in zip(
                    blocks_name,
                    unique_blocks_name,
                ):
                    stack.enter_context(
                        mock.patch.object(
                            self, name, cached_blocks[context_name]
                        )
                    )
                for dummy_name in dummy_blocks_names:
                    stack.enter_context(
                        mock.patch.object(self, dummy_name, dummy_blocks)
                    )
                # original_forward is already bound to the transformer.
                return original_forward(*args, **kwargs)

        transformer.forward = new_forward.__get__(transformer)
        transformer._original_forward = original_forward
        transformer._is_cached = True

        return transformer

    @classmethod
    def collect_cached_blocks(
        cls,
        block_adapter: BlockAdapter,
    ) -> List[Dict[str, torch.nn.ModuleList]]:
        """Wrap every block group in a ``CachedBlocks`` bound to its context.

        Returns:
            One dict per transformer, mapping each unique blocks name to a
            single-element ``ModuleList`` holding its ``CachedBlocks``.

        Raises:
            AssertionError: If :meth:`create_context` has not attached a
                ``CachedContextManagerV2`` to the pipeline yet.
        """
        BlockAdapter.assert_normalized(block_adapter)

        total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
        assert hasattr(block_adapter.pipe, "_cache_manager")
        assert isinstance(
            block_adapter.pipe._cache_manager, CachedContextManagerV2
        )

        for i in range(len(block_adapter.transformer)):

            cached_blocks_bind_context = {}
            for j in range(len(block_adapter.blocks[i])):
                cached_blocks_bind_context[
                    block_adapter.unique_blocks_name[i][j]
                ] = torch.nn.ModuleList(
                    [
                        CachedBlocks(
                            # 0. Transformer blocks configuration
                            block_adapter.blocks[i][j],
                            transformer=block_adapter.transformer[i],
                            forward_pattern=block_adapter.forward_pattern[i][j],
                            check_forward_pattern=block_adapter.check_forward_pattern,
                            check_num_outputs=block_adapter.check_num_outputs,
                            # 1. Cache context configuration
                            cache_prefix=block_adapter.blocks_name[i][j],
                            cache_context=block_adapter.unique_blocks_name[i][
                                j
                            ],
                            cache_manager=block_adapter.pipe._cache_manager,
                        )
                    ]
                )

            total_cached_blocks.append(cached_blocks_bind_context)

        return total_cached_blocks

    @classmethod
    def apply_params_hooks(
        cls,
        block_adapter: BlockAdapter,
        contexts_kwargs: List[Dict],
    ):
        """Attach resolved cache parameters to pipeline, transformers and blocks.

        ``contexts_kwargs`` is indexed with a running offset (``params_shift``)
        over transformers and then their block groups. This assumes
        ``BlockAdapter.flatten`` orders contexts transformer-major, matching
        :meth:`modify_context_params` (TODO confirm).
        """
        block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]

        params_shift = 0
        for i in range(len(block_adapter.transformer)):

            # NOTE(review): the transformer receives the adapter's whole
            # ``forward_pattern`` (not ``forward_pattern[i]``), unlike the
            # blocks below. Confirm this is intended.
            block_adapter.transformer[i]._forward_pattern = (
                block_adapter.forward_pattern
            )
            block_adapter.transformer[i]._has_separate_cfg = (
                block_adapter.has_separate_cfg
            )
            # A transformer carries the kwargs of its first block group.
            block_adapter.transformer[i]._cache_context_kwargs = (
                contexts_kwargs[params_shift]
            )

            blocks = block_adapter.blocks[i]
            for j in range(len(blocks)):
                blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
                blocks[j]._cache_context_kwargs = contexts_kwargs[
                    params_shift + j
                ]

            params_shift += len(blocks)

    @classmethod
    def apply_stats_hooks(
        cls,
        block_adapter: BlockAdapter,
    ):
        """Refresh cached stats on each transformer and each block group.

        A transformer's stats are taken from its last unique blocks context;
        each block group uses its own context.
        """
        cache_manager = block_adapter.pipe._cache_manager

        for i in range(len(block_adapter.transformer)):
            patch_cached_stats(
                block_adapter.transformer[i],
                cache_context=block_adapter.unique_blocks_name[i][-1],
                cache_manager=cache_manager,
            )
            for blocks, unique_name in zip(
                block_adapter.blocks[i],
                block_adapter.unique_blocks_name[i],
            ):
                patch_cached_stats(
                    blocks,
                    cache_context=unique_name,
                    cache_manager=cache_manager,
                )

    @classmethod
    def maybe_release_hooks(
        cls,
        pipe_or_adapter: Union[
            DiffusionPipeline,
            BlockAdapter,
        ],
    ):
        """Undo the patches installed by :meth:`apply`.

        This runs in three passes: the ``forward``/``__call__`` hooks (which
        also clears and drops the cache manager), then the attached
        ``_forward_pattern``/``_has_separate_cfg``/``_cache_context_kwargs``
        attributes, then the cached stats. Each step is guarded with
        ``hasattr``, so calling this on an un-cached object is harmless.
        """

        # release model hooks
        def _release_blocks_hooks(blocks):
            # Blocks are swapped only during forward; nothing to undo here.
            return

        def _release_transformer_hooks(transformer):
            if hasattr(transformer, "_original_forward"):
                original_forward = transformer._original_forward
                transformer.forward = original_forward.__get__(transformer)
                del transformer._original_forward
            if hasattr(transformer, "_is_cached"):
                del transformer._is_cached

        def _release_pipeline_hooks(pipe):
            if hasattr(pipe, "_original_call"):
                original_call = pipe.__class__._original_call
                pipe.__class__.__call__ = original_call
                del pipe.__class__._original_call
            if hasattr(pipe, "_cache_manager"):
                cache_manager = pipe._cache_manager
                if isinstance(cache_manager, CachedContextManagerV2):
                    cache_manager.clear_contexts()
                del pipe._cache_manager
            if hasattr(pipe, "_is_cached"):
                del pipe.__class__._is_cached

        cls.release_hooks(
            pipe_or_adapter,
            _release_blocks_hooks,
            _release_transformer_hooks,
            _release_pipeline_hooks,
        )

        # release params hooks
        def _release_blocks_params(blocks):
            if hasattr(blocks, "_forward_pattern"):
                del blocks._forward_pattern
            if hasattr(blocks, "_cache_context_kwargs"):
                del blocks._cache_context_kwargs

        def _release_transformer_params(transformer):
            if hasattr(transformer, "_forward_pattern"):
                del transformer._forward_pattern
            if hasattr(transformer, "_has_separate_cfg"):
                del transformer._has_separate_cfg
            if hasattr(transformer, "_cache_context_kwargs"):
                del transformer._cache_context_kwargs
            for blocks in BlockAdapter.find_blocks(transformer):
                _release_blocks_params(blocks)

        def _release_pipeline_params(pipe):
            if hasattr(pipe, "_cache_context_kwargs"):
                del pipe._cache_context_kwargs

        cls.release_hooks(
            pipe_or_adapter,
            _release_blocks_params,
            _release_transformer_params,
            _release_pipeline_params,
        )

        # release stats hooks
        cls.release_hooks(
            pipe_or_adapter,
            remove_cached_stats,
            remove_cached_stats,
            remove_cached_stats,
        )

    @classmethod
    def release_hooks(
        cls,
        pipe_or_adapter: Union[
            DiffusionPipeline,
            BlockAdapter,
        ],
        _release_blocks: Callable,
        _release_transformer: Callable,
        _release_pipeline: Callable,
    ):
        """Apply release callbacks to a pipeline or a normalized BlockAdapter.

        For a ``DiffusionPipeline``, the pipeline callback runs, then the
        transformer callback runs on ``pipe.transformer`` and
        ``pipe.transformer_2`` when present. The blocks callback is not called
        in this branch. For a ``BlockAdapter``, the pipeline, every flattened
        transformer and every flattened block group are released. Any other
        input is ignored.
        """
        if isinstance(pipe_or_adapter, DiffusionPipeline):
            pipe = pipe_or_adapter
            _release_pipeline(pipe)
            if hasattr(pipe, "transformer"):
                _release_transformer(pipe.transformer)
            if hasattr(pipe, "transformer_2"):  # Wan 2.2
                _release_transformer(pipe.transformer_2)
        elif isinstance(pipe_or_adapter, BlockAdapter):
            adapter = pipe_or_adapter
            BlockAdapter.assert_normalized(adapter)
            _release_pipeline(adapter.pipe)
            for transformer in BlockAdapter.flatten(adapter.transformer):
                _release_transformer(transformer)
            for blocks in BlockAdapter.flatten(adapter.blocks):
                _release_blocks(blocks)
|
|
@@ -3,3 +3,10 @@ from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
|
|
|
3
3
|
from cache_dit.cache_factory.cache_contexts.cache_manager import (
|
|
4
4
|
CachedContextManager,
|
|
5
5
|
)
|
|
6
|
+
from cache_dit.cache_factory.cache_contexts.v2 import (
|
|
7
|
+
CachedContextV2,
|
|
8
|
+
CachedContextManagerV2,
|
|
9
|
+
CalibratorConfig,
|
|
10
|
+
TaylorSeerCalibratorConfig,
|
|
11
|
+
FoCaCalibratorConfig,
|
|
12
|
+
)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from cache_dit.cache_factory.cache_contexts.v2.calibrators import (
|
|
2
|
+
Calibrator,
|
|
3
|
+
CalibratorBase,
|
|
4
|
+
CalibratorConfig,
|
|
5
|
+
TaylorSeerCalibratorConfig,
|
|
6
|
+
FoCaCalibratorConfig,
|
|
7
|
+
)
|
|
8
|
+
from cache_dit.cache_factory.cache_contexts.v2.cache_context_v2 import (
|
|
9
|
+
CachedContextV2,
|
|
10
|
+
)
|
|
11
|
+
from cache_dit.cache_factory.cache_contexts.v2.cache_manager_v2 import (
|
|
12
|
+
CachedContextManagerV2,
|
|
13
|
+
)
|