cache-dit 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +3 -6
- cache_dit/cache_factory/block_adapters/block_adapters.py +21 -64
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +82 -21
- cache_dit/cache_factory/cache_blocks/__init__.py +4 -0
- cache_dit/cache_factory/cache_blocks/offload_utils.py +115 -0
- cache_dit/cache_factory/cache_blocks/pattern_base.py +3 -0
- cache_dit/cache_factory/cache_contexts/__init__.py +10 -8
- cache_dit/cache_factory/cache_contexts/cache_context.py +186 -117
- cache_dit/cache_factory/cache_contexts/cache_manager.py +63 -131
- cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +132 -0
- cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/foca.py +1 -1
- cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/taylorseer.py +7 -2
- cache_dit/cache_factory/cache_interface.py +128 -111
- cache_dit/cache_factory/params_modifier.py +87 -0
- cache_dit/metrics/__init__.py +3 -1
- cache_dit/utils.py +12 -21
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/METADATA +200 -434
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/RECORD +27 -31
- cache_dit/cache_factory/cache_adapters/v2/__init__.py +0 -3
- cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +0 -524
- cache_dit/cache_factory/cache_contexts/taylorseer.py +0 -102
- cache_dit/cache_factory/cache_contexts/v2/__init__.py +0 -13
- cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +0 -288
- cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +0 -799
- cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +0 -81
- /cache_dit/cache_factory/cache_blocks/{utils.py → pattern_utils.py} +0 -0
- /cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/base.py +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
|
@@ -18,6 +18,7 @@ from cache_dit.cache_factory import BlockAdapter
|
|
|
18
18
|
from cache_dit.cache_factory import ParamsModifier
|
|
19
19
|
from cache_dit.cache_factory import ForwardPattern
|
|
20
20
|
from cache_dit.cache_factory import PatchFunctor
|
|
21
|
+
from cache_dit.cache_factory import BasicCacheConfig
|
|
21
22
|
from cache_dit.cache_factory import CalibratorConfig
|
|
22
23
|
from cache_dit.cache_factory import TaylorSeerCalibratorConfig
|
|
23
24
|
from cache_dit.cache_factory import FoCaCalibratorConfig
|
cache_dit/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.3.
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 3,
|
|
31
|
+
__version__ = version = '0.3.3'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 3, 3)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -3,25 +3,22 @@ from cache_dit.cache_factory.cache_types import cache_type
|
|
|
3
3
|
from cache_dit.cache_factory.cache_types import block_range
|
|
4
4
|
|
|
5
5
|
from cache_dit.cache_factory.forward_pattern import ForwardPattern
|
|
6
|
-
|
|
6
|
+
from cache_dit.cache_factory.params_modifier import ParamsModifier
|
|
7
7
|
from cache_dit.cache_factory.patch_functors import PatchFunctor
|
|
8
8
|
|
|
9
9
|
from cache_dit.cache_factory.block_adapters import BlockAdapter
|
|
10
|
-
from cache_dit.cache_factory.block_adapters import ParamsModifier
|
|
11
10
|
from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
|
|
12
11
|
|
|
13
12
|
from cache_dit.cache_factory.cache_contexts import CachedContext
|
|
13
|
+
from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
|
|
14
14
|
from cache_dit.cache_factory.cache_contexts import CachedContextManager
|
|
15
|
-
from cache_dit.cache_factory.cache_contexts import
|
|
16
|
-
from cache_dit.cache_factory.cache_contexts import CachedContextManagerV2
|
|
17
|
-
from cache_dit.cache_factory.cache_contexts import CalibratorConfig # no V1
|
|
15
|
+
from cache_dit.cache_factory.cache_contexts import CalibratorConfig
|
|
18
16
|
from cache_dit.cache_factory.cache_contexts import TaylorSeerCalibratorConfig
|
|
19
17
|
from cache_dit.cache_factory.cache_contexts import FoCaCalibratorConfig
|
|
20
18
|
|
|
21
19
|
from cache_dit.cache_factory.cache_blocks import CachedBlocks
|
|
22
20
|
|
|
23
21
|
from cache_dit.cache_factory.cache_adapters import CachedAdapter
|
|
24
|
-
from cache_dit.cache_factory.cache_adapters import CachedAdapterV2
|
|
25
22
|
|
|
26
23
|
from cache_dit.cache_factory.cache_interface import enable_cache
|
|
27
24
|
from cache_dit.cache_factory.cache_interface import disable_cache
|
|
@@ -7,73 +7,15 @@ from collections.abc import Iterable
|
|
|
7
7
|
from typing import Any, Tuple, List, Optional, Union
|
|
8
8
|
|
|
9
9
|
from diffusers import DiffusionPipeline
|
|
10
|
-
from cache_dit.cache_factory.forward_pattern import ForwardPattern
|
|
11
10
|
from cache_dit.cache_factory.patch_functors import PatchFunctor
|
|
12
|
-
from cache_dit.cache_factory.
|
|
11
|
+
from cache_dit.cache_factory.forward_pattern import ForwardPattern
|
|
12
|
+
from cache_dit.cache_factory.params_modifier import ParamsModifier
|
|
13
13
|
|
|
14
14
|
from cache_dit.logger import init_logger
|
|
15
15
|
|
|
16
16
|
logger = init_logger(__name__)
|
|
17
17
|
|
|
18
18
|
|
|
19
|
-
class ParamsModifier:
|
|
20
|
-
def __init__(
|
|
21
|
-
self,
|
|
22
|
-
# Cache context kwargs
|
|
23
|
-
Fn_compute_blocks: Optional[int] = None,
|
|
24
|
-
Bn_compute_blocks: Optional[int] = None,
|
|
25
|
-
max_warmup_steps: Optional[int] = None,
|
|
26
|
-
max_cached_steps: Optional[int] = None,
|
|
27
|
-
max_continuous_cached_steps: Optional[int] = None,
|
|
28
|
-
residual_diff_threshold: Optional[float] = None,
|
|
29
|
-
# Cache CFG or not
|
|
30
|
-
enable_separate_cfg: Optional[bool] = None,
|
|
31
|
-
cfg_compute_first: Optional[bool] = None,
|
|
32
|
-
cfg_diff_compute_separate: Optional[bool] = None,
|
|
33
|
-
# Hybird TaylorSeer
|
|
34
|
-
enable_taylorseer: Optional[bool] = None,
|
|
35
|
-
enable_encoder_taylorseer: Optional[bool] = None,
|
|
36
|
-
taylorseer_cache_type: Optional[str] = None,
|
|
37
|
-
taylorseer_order: Optional[int] = None,
|
|
38
|
-
# New param only for v2 API
|
|
39
|
-
calibrator_config: Optional[CalibratorConfig] = None,
|
|
40
|
-
**other_cache_context_kwargs,
|
|
41
|
-
):
|
|
42
|
-
self._context_kwargs = other_cache_context_kwargs.copy()
|
|
43
|
-
self._maybe_update_param("Fn_compute_blocks", Fn_compute_blocks)
|
|
44
|
-
self._maybe_update_param("Bn_compute_blocks", Bn_compute_blocks)
|
|
45
|
-
self._maybe_update_param("max_warmup_steps", max_warmup_steps)
|
|
46
|
-
self._maybe_update_param("max_cached_steps", max_cached_steps)
|
|
47
|
-
self._maybe_update_param(
|
|
48
|
-
"max_continuous_cached_steps", max_continuous_cached_steps
|
|
49
|
-
)
|
|
50
|
-
self._maybe_update_param(
|
|
51
|
-
"residual_diff_threshold", residual_diff_threshold
|
|
52
|
-
)
|
|
53
|
-
self._maybe_update_param("enable_separate_cfg", enable_separate_cfg)
|
|
54
|
-
self._maybe_update_param("cfg_compute_first", cfg_compute_first)
|
|
55
|
-
self._maybe_update_param(
|
|
56
|
-
"cfg_diff_compute_separate", cfg_diff_compute_separate
|
|
57
|
-
)
|
|
58
|
-
# V1 only supports the Taylorseer calibrator. We have decided to
|
|
59
|
-
# keep this code for API compatibility reasons.
|
|
60
|
-
if calibrator_config is None:
|
|
61
|
-
self._maybe_update_param("enable_taylorseer", enable_taylorseer)
|
|
62
|
-
self._maybe_update_param(
|
|
63
|
-
"enable_encoder_taylorseer", enable_encoder_taylorseer
|
|
64
|
-
)
|
|
65
|
-
self._maybe_update_param(
|
|
66
|
-
"taylorseer_cache_type", taylorseer_cache_type
|
|
67
|
-
)
|
|
68
|
-
self._maybe_update_param("taylorseer_order", taylorseer_order)
|
|
69
|
-
else:
|
|
70
|
-
self._maybe_update_param("calibrator_config", calibrator_config)
|
|
71
|
-
|
|
72
|
-
def _maybe_update_param(self, key: str, value: Any):
|
|
73
|
-
if value is not None:
|
|
74
|
-
self._context_kwargs[key] = value
|
|
75
|
-
|
|
76
|
-
|
|
77
19
|
@dataclasses.dataclass
|
|
78
20
|
class BlockAdapter:
|
|
79
21
|
|
|
@@ -123,10 +65,12 @@ class BlockAdapter:
|
|
|
123
65
|
] = None
|
|
124
66
|
|
|
125
67
|
# modify cache context params for specific blocks.
|
|
126
|
-
params_modifiers:
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
68
|
+
params_modifiers: Optional[
|
|
69
|
+
Union[
|
|
70
|
+
ParamsModifier,
|
|
71
|
+
List[ParamsModifier],
|
|
72
|
+
List[List[ParamsModifier]],
|
|
73
|
+
]
|
|
130
74
|
] = None
|
|
131
75
|
|
|
132
76
|
check_forward_pattern: bool = True
|
|
@@ -169,6 +113,19 @@ class BlockAdapter:
|
|
|
169
113
|
if any((self.pipe is not None, self.transformer is not None)):
|
|
170
114
|
self.maybe_fill_attrs()
|
|
171
115
|
self.maybe_patchify()
|
|
116
|
+
self.maybe_skip_checks()
|
|
117
|
+
|
|
118
|
+
def maybe_skip_checks(self):
|
|
119
|
+
if getattr(self.transformer, "_hf_hook", None) is not None:
|
|
120
|
+
logger.warning("_hf_hook is not None, force skip pattern check!")
|
|
121
|
+
self.check_forward_pattern = False
|
|
122
|
+
self.check_num_outputs = False
|
|
123
|
+
elif getattr(self.transformer, "_diffusers_hook", None) is not None:
|
|
124
|
+
logger.warning(
|
|
125
|
+
"_diffusers_hook is not None, force skip pattern check!"
|
|
126
|
+
)
|
|
127
|
+
self.check_forward_pattern = False
|
|
128
|
+
self.check_num_outputs = False
|
|
172
129
|
|
|
173
130
|
def maybe_fill_attrs(self):
|
|
174
131
|
# NOTE: This func should be call before normalize.
|
|
@@ -1,10 +1,8 @@
|
|
|
1
1
|
import torch
|
|
2
|
-
|
|
3
2
|
import unittest
|
|
4
3
|
import functools
|
|
5
|
-
|
|
6
4
|
from contextlib import ExitStack
|
|
7
|
-
from typing import Dict, List, Tuple, Any, Union, Callable
|
|
5
|
+
from typing import Dict, List, Tuple, Any, Union, Callable, Optional
|
|
8
6
|
|
|
9
7
|
from diffusers import DiffusionPipeline
|
|
10
8
|
|
|
@@ -13,8 +11,10 @@ from cache_dit.cache_factory.block_adapters import BlockAdapter
|
|
|
13
11
|
from cache_dit.cache_factory.block_adapters import ParamsModifier
|
|
14
12
|
from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
|
|
15
13
|
from cache_dit.cache_factory.cache_contexts import CachedContextManager
|
|
14
|
+
from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
|
|
15
|
+
from cache_dit.cache_factory.cache_contexts import CalibratorConfig
|
|
16
16
|
from cache_dit.cache_factory.cache_blocks import CachedBlocks
|
|
17
|
-
from cache_dit.cache_factory.cache_blocks
|
|
17
|
+
from cache_dit.cache_factory.cache_blocks import (
|
|
18
18
|
patch_cached_stats,
|
|
19
19
|
remove_cached_stats,
|
|
20
20
|
)
|
|
@@ -55,6 +55,12 @@ class CachedAdapter:
|
|
|
55
55
|
block_adapter = BlockAdapterRegistry.get_adapter(
|
|
56
56
|
pipe_or_adapter
|
|
57
57
|
)
|
|
58
|
+
if params_modifiers := cache_context_kwargs.pop(
|
|
59
|
+
"params_modifiers",
|
|
60
|
+
None,
|
|
61
|
+
):
|
|
62
|
+
block_adapter.params_modifiers = params_modifiers
|
|
63
|
+
|
|
58
64
|
return cls.cachify(
|
|
59
65
|
block_adapter,
|
|
60
66
|
**cache_context_kwargs,
|
|
@@ -69,6 +75,12 @@ class CachedAdapter:
|
|
|
69
75
|
logger.info(
|
|
70
76
|
"Adapting Cache Acceleration using custom BlockAdapter!"
|
|
71
77
|
)
|
|
78
|
+
if pipe_or_adapter.params_modifiers is None:
|
|
79
|
+
if params_modifiers := cache_context_kwargs.pop(
|
|
80
|
+
"params_modifiers", None
|
|
81
|
+
):
|
|
82
|
+
pipe_or_adapter.params_modifiers = params_modifiers
|
|
83
|
+
|
|
72
84
|
return cls.cachify(
|
|
73
85
|
pipe_or_adapter,
|
|
74
86
|
**cache_context_kwargs,
|
|
@@ -114,33 +126,36 @@ class CachedAdapter:
|
|
|
114
126
|
**cache_context_kwargs,
|
|
115
127
|
):
|
|
116
128
|
# Check cache_context_kwargs
|
|
117
|
-
|
|
129
|
+
cache_config: BasicCacheConfig = cache_context_kwargs[
|
|
130
|
+
"cache_config"
|
|
131
|
+
] # ref
|
|
132
|
+
assert cache_config is not None, "cache_config can not be None."
|
|
133
|
+
if cache_config.enable_separate_cfg is None:
|
|
118
134
|
# Check cfg for some specific case if users don't set it as True
|
|
119
135
|
if BlockAdapterRegistry.has_separate_cfg(block_adapter):
|
|
120
|
-
|
|
136
|
+
cache_config.enable_separate_cfg = True
|
|
121
137
|
logger.info(
|
|
122
138
|
f"Use custom 'enable_separate_cfg' from BlockAdapter: True. "
|
|
123
139
|
f"Pipeline: {block_adapter.pipe.__class__.__name__}."
|
|
124
140
|
)
|
|
125
141
|
else:
|
|
126
|
-
|
|
142
|
+
cache_config.enable_separate_cfg = (
|
|
127
143
|
BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
|
|
128
144
|
)
|
|
129
145
|
logger.info(
|
|
130
146
|
f"Use default 'enable_separate_cfg' from block adapter "
|
|
131
|
-
f"register: {
|
|
147
|
+
f"register: {cache_config.enable_separate_cfg}, "
|
|
132
148
|
f"Pipeline: {block_adapter.pipe.__class__.__name__}."
|
|
133
149
|
)
|
|
134
150
|
else:
|
|
135
151
|
logger.info(
|
|
136
152
|
f"Use custom 'enable_separate_cfg' from cache context "
|
|
137
|
-
f"kwargs: {
|
|
153
|
+
f"kwargs: {cache_config.enable_separate_cfg}. "
|
|
138
154
|
f"Pipeline: {block_adapter.pipe.__class__.__name__}."
|
|
139
155
|
)
|
|
140
156
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
) is not None:
|
|
157
|
+
cache_type = cache_context_kwargs.pop("cache_type", None)
|
|
158
|
+
if cache_type is not None:
|
|
144
159
|
assert (
|
|
145
160
|
cache_type == CacheType.DBCache
|
|
146
161
|
), "Custom cache setting only support for DBCache now!"
|
|
@@ -176,7 +191,7 @@ class CachedAdapter:
|
|
|
176
191
|
block_adapter.pipe._cache_manager = cache_manager # instance level
|
|
177
192
|
|
|
178
193
|
flatten_contexts, contexts_kwargs = cls.modify_context_params(
|
|
179
|
-
block_adapter,
|
|
194
|
+
block_adapter, **cache_context_kwargs
|
|
180
195
|
)
|
|
181
196
|
|
|
182
197
|
original_call = block_adapter.pipe.__class__.__call__
|
|
@@ -212,7 +227,6 @@ class CachedAdapter:
|
|
|
212
227
|
def modify_context_params(
|
|
213
228
|
cls,
|
|
214
229
|
block_adapter: BlockAdapter,
|
|
215
|
-
cache_manager: CachedContextManager,
|
|
216
230
|
**cache_context_kwargs,
|
|
217
231
|
) -> Tuple[List[str], List[Dict[str, Any]]]:
|
|
218
232
|
|
|
@@ -230,6 +244,8 @@ class CachedAdapter:
|
|
|
230
244
|
contexts_kwargs[i]["name"] = flatten_contexts[i]
|
|
231
245
|
|
|
232
246
|
if block_adapter.params_modifiers is None:
|
|
247
|
+
for i in range(len(contexts_kwargs)):
|
|
248
|
+
cls._config_messages(**contexts_kwargs[i])
|
|
233
249
|
return flatten_contexts, contexts_kwargs
|
|
234
250
|
|
|
235
251
|
flatten_modifiers: List[ParamsModifier] = BlockAdapter.flatten(
|
|
@@ -242,12 +258,26 @@ class CachedAdapter:
|
|
|
242
258
|
contexts_kwargs[i].update(
|
|
243
259
|
flatten_modifiers[i]._context_kwargs,
|
|
244
260
|
)
|
|
245
|
-
contexts_kwargs[i]
|
|
246
|
-
default_attrs={}, **contexts_kwargs[i]
|
|
247
|
-
)
|
|
261
|
+
cls._config_messages(**contexts_kwargs[i])
|
|
248
262
|
|
|
249
263
|
return flatten_contexts, contexts_kwargs
|
|
250
264
|
|
|
265
|
+
@classmethod
|
|
266
|
+
def _config_messages(cls, **contexts_kwargs):
|
|
267
|
+
cache_config: BasicCacheConfig = contexts_kwargs.get(
|
|
268
|
+
"cache_config", None
|
|
269
|
+
)
|
|
270
|
+
calibrator_config: CalibratorConfig = contexts_kwargs.get(
|
|
271
|
+
"calibrator_config", None
|
|
272
|
+
)
|
|
273
|
+
if cache_config is not None:
|
|
274
|
+
message = f"Collected Cache Config: {cache_config.strify()}"
|
|
275
|
+
if calibrator_config is not None:
|
|
276
|
+
message += f", Calibrator Config: {calibrator_config.strify(details=True)}"
|
|
277
|
+
else:
|
|
278
|
+
message += ", Calibrator Config: None"
|
|
279
|
+
logger.info(message)
|
|
280
|
+
|
|
251
281
|
@classmethod
|
|
252
282
|
def mock_blocks(
|
|
253
283
|
cls,
|
|
@@ -298,7 +328,19 @@ class CachedAdapter:
|
|
|
298
328
|
|
|
299
329
|
assert isinstance(dummy_blocks_names, list)
|
|
300
330
|
|
|
301
|
-
|
|
331
|
+
from accelerate import hooks
|
|
332
|
+
|
|
333
|
+
_hf_hook: Optional[hooks.ModelHook] = None
|
|
334
|
+
|
|
335
|
+
if getattr(transformer, "_hf_hook", None) is not None:
|
|
336
|
+
_hf_hook = transformer._hf_hook # hooks from accelerate.hooks
|
|
337
|
+
|
|
338
|
+
# TODO: remove group offload hooks the re-apply after cache applied.
|
|
339
|
+
# hooks = _diffusers_hook.hooks.copy(); _diffusers_hook.hooks.clear()
|
|
340
|
+
# re-apply hooks to transformer after cache applied.
|
|
341
|
+
# from diffusers.hooks.hooks import HookFunctionReference, HookRegistry
|
|
342
|
+
# from diffusers.hooks.group_offloading import apply_group_offloading
|
|
343
|
+
|
|
302
344
|
def new_forward(self, *args, **kwargs):
|
|
303
345
|
with ExitStack() as stack:
|
|
304
346
|
for name, context_name in zip(
|
|
@@ -316,9 +358,27 @@ class CachedAdapter:
|
|
|
316
358
|
self, dummy_name, dummy_blocks
|
|
317
359
|
)
|
|
318
360
|
)
|
|
319
|
-
|
|
361
|
+
outputs = original_forward(*args, **kwargs)
|
|
362
|
+
return outputs
|
|
363
|
+
|
|
364
|
+
def new_forward_with_hf_hook(self, *args, **kwargs):
|
|
365
|
+
# Compatible with model cpu offload
|
|
366
|
+
if _hf_hook is not None and hasattr(_hf_hook, "pre_forward"):
|
|
367
|
+
args, kwargs = _hf_hook.pre_forward(self, *args, **kwargs)
|
|
368
|
+
|
|
369
|
+
outputs = new_forward(self, *args, **kwargs)
|
|
370
|
+
|
|
371
|
+
if _hf_hook is not None and hasattr(_hf_hook, "post_forward"):
|
|
372
|
+
outputs = _hf_hook.post_forward(self, outputs)
|
|
373
|
+
|
|
374
|
+
return outputs
|
|
375
|
+
|
|
376
|
+
# NOTE: Still can't fully compatible with group offloading
|
|
377
|
+
transformer.forward = functools.update_wrapper(
|
|
378
|
+
functools.partial(new_forward_with_hf_hook, transformer),
|
|
379
|
+
new_forward_with_hf_hook,
|
|
380
|
+
)
|
|
320
381
|
|
|
321
|
-
transformer.forward = new_forward.__get__(transformer)
|
|
322
382
|
transformer._original_forward = original_forward
|
|
323
383
|
transformer._is_cached = True
|
|
324
384
|
|
|
@@ -335,7 +395,8 @@ class CachedAdapter:
|
|
|
335
395
|
total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
|
|
336
396
|
assert hasattr(block_adapter.pipe, "_cache_manager")
|
|
337
397
|
assert isinstance(
|
|
338
|
-
block_adapter.pipe._cache_manager,
|
|
398
|
+
block_adapter.pipe._cache_manager,
|
|
399
|
+
CachedContextManager,
|
|
339
400
|
)
|
|
340
401
|
|
|
341
402
|
for i in range(len(block_adapter.transformer)):
|
|
@@ -12,6 +12,10 @@ from cache_dit.cache_factory.cache_blocks.pattern_0_1_2 import (
|
|
|
12
12
|
from cache_dit.cache_factory.cache_blocks.pattern_3_4_5 import (
|
|
13
13
|
CachedBlocks_Pattern_3_4_5,
|
|
14
14
|
)
|
|
15
|
+
from cache_dit.cache_factory.cache_blocks.pattern_utils import (
|
|
16
|
+
patch_cached_stats,
|
|
17
|
+
remove_cached_stats,
|
|
18
|
+
)
|
|
15
19
|
|
|
16
20
|
from cache_dit.logger import init_logger
|
|
17
21
|
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import asyncio
|
|
3
|
+
import logging
|
|
4
|
+
from contextlib import contextmanager
|
|
5
|
+
from typing import Generator, Optional, List
|
|
6
|
+
from diffusers.hooks.group_offloading import _is_group_offload_enabled
|
|
7
|
+
from cache_dit.logger import init_logger
|
|
8
|
+
|
|
9
|
+
logger = init_logger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@torch.compiler.disable
|
|
13
|
+
@contextmanager
|
|
14
|
+
def maybe_onload(
|
|
15
|
+
block: torch.nn.Module,
|
|
16
|
+
reference_tensor: torch.Tensor,
|
|
17
|
+
pending_tasks: List[asyncio.Task] = [],
|
|
18
|
+
) -> Generator:
|
|
19
|
+
|
|
20
|
+
if not _is_group_offload_enabled(block):
|
|
21
|
+
yield block
|
|
22
|
+
return
|
|
23
|
+
|
|
24
|
+
original_devices: Optional[List[torch.device]] = None
|
|
25
|
+
if hasattr(block, "parameters"):
|
|
26
|
+
params = list(block.parameters())
|
|
27
|
+
if params:
|
|
28
|
+
original_devices = [param.data.device for param in params]
|
|
29
|
+
|
|
30
|
+
target_device: torch.device = reference_tensor.device
|
|
31
|
+
move_task: Optional[asyncio.Task] = None
|
|
32
|
+
need_restore: bool = False
|
|
33
|
+
|
|
34
|
+
try:
|
|
35
|
+
if original_devices is not None:
|
|
36
|
+
unique_devices = list(set(original_devices))
|
|
37
|
+
if len(unique_devices) > 1 or unique_devices[0] != target_device:
|
|
38
|
+
if logger.isEnabledFor(logging.DEBUG):
|
|
39
|
+
logger.debug(
|
|
40
|
+
f"Onloading from {unique_devices} to {target_device}"
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
has_meta_params = any(
|
|
44
|
+
dev.type == "meta" for dev in original_devices
|
|
45
|
+
)
|
|
46
|
+
if has_meta_params: # compatible with sequential cpu offload
|
|
47
|
+
block = block.to_empty(device=target_device)
|
|
48
|
+
else:
|
|
49
|
+
block = block.to(target_device, non_blocking=False)
|
|
50
|
+
need_restore = True
|
|
51
|
+
yield block
|
|
52
|
+
finally:
|
|
53
|
+
if need_restore and original_devices:
|
|
54
|
+
|
|
55
|
+
async def restore_device():
|
|
56
|
+
for param, original_device in zip(
|
|
57
|
+
block.parameters(), original_devices
|
|
58
|
+
):
|
|
59
|
+
param.data = await asyncio.to_thread(
|
|
60
|
+
lambda p, d: p.to(d, non_blocking=True),
|
|
61
|
+
param.data, # type: torch.Tensor
|
|
62
|
+
original_device, # type: torch.device
|
|
63
|
+
) # type: ignore[assignment]
|
|
64
|
+
|
|
65
|
+
loop = get_event_loop()
|
|
66
|
+
move_task = loop.create_task(restore_device())
|
|
67
|
+
if move_task:
|
|
68
|
+
pending_tasks.append(move_task)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_event_loop() -> asyncio.AbstractEventLoop:
|
|
72
|
+
try:
|
|
73
|
+
loop = asyncio.get_running_loop()
|
|
74
|
+
except RuntimeError:
|
|
75
|
+
try:
|
|
76
|
+
loop = asyncio.get_event_loop()
|
|
77
|
+
except RuntimeError:
|
|
78
|
+
loop = asyncio.new_event_loop()
|
|
79
|
+
asyncio.set_event_loop(loop)
|
|
80
|
+
|
|
81
|
+
if not loop.is_running():
|
|
82
|
+
|
|
83
|
+
def run_loop() -> None:
|
|
84
|
+
asyncio.set_event_loop(loop)
|
|
85
|
+
loop.run_forever()
|
|
86
|
+
|
|
87
|
+
import threading
|
|
88
|
+
|
|
89
|
+
if not any(t.name == "_my_loop" for t in threading.enumerate()):
|
|
90
|
+
threading.Thread(
|
|
91
|
+
target=run_loop, name="_my_loop", daemon=True
|
|
92
|
+
).start()
|
|
93
|
+
|
|
94
|
+
return loop
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@torch.compiler.disable
|
|
98
|
+
def maybe_offload(
|
|
99
|
+
pending_tasks: List[asyncio.Task],
|
|
100
|
+
) -> None:
|
|
101
|
+
if not pending_tasks:
|
|
102
|
+
return
|
|
103
|
+
|
|
104
|
+
loop = get_event_loop()
|
|
105
|
+
|
|
106
|
+
async def gather_tasks():
|
|
107
|
+
return await asyncio.gather(*pending_tasks)
|
|
108
|
+
|
|
109
|
+
future = asyncio.run_coroutine_threadsafe(gather_tasks(), loop)
|
|
110
|
+
try:
|
|
111
|
+
future.result(timeout=30.0)
|
|
112
|
+
except Exception as e:
|
|
113
|
+
logger.error(f"May Offload Error: {e}")
|
|
114
|
+
|
|
115
|
+
pending_tasks.clear()
|
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
import inspect
|
|
2
|
+
import asyncio
|
|
2
3
|
import torch
|
|
3
4
|
import torch.distributed as dist
|
|
4
5
|
|
|
6
|
+
from typing import List
|
|
5
7
|
from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
|
|
6
8
|
from cache_dit.cache_factory.cache_contexts.cache_manager import (
|
|
7
9
|
CachedContextManager,
|
|
@@ -45,6 +47,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
|
|
|
45
47
|
self.cache_prefix = cache_prefix
|
|
46
48
|
self.cache_context = cache_context
|
|
47
49
|
self.cache_manager = cache_manager
|
|
50
|
+
self.pending_tasks: List[asyncio.Task] = []
|
|
48
51
|
|
|
49
52
|
self._check_forward_pattern()
|
|
50
53
|
logger.info(
|
|
@@ -1,12 +1,14 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
CachedContextManager,
|
|
5
|
-
)
|
|
6
|
-
from cache_dit.cache_factory.cache_contexts.v2 import (
|
|
7
|
-
CachedContextV2,
|
|
8
|
-
CachedContextManagerV2,
|
|
1
|
+
from cache_dit.cache_factory.cache_contexts.calibrators import (
|
|
2
|
+
Calibrator,
|
|
3
|
+
CalibratorBase,
|
|
9
4
|
CalibratorConfig,
|
|
10
5
|
TaylorSeerCalibratorConfig,
|
|
11
6
|
FoCaCalibratorConfig,
|
|
12
7
|
)
|
|
8
|
+
from cache_dit.cache_factory.cache_contexts.cache_context import (
|
|
9
|
+
CachedContext,
|
|
10
|
+
BasicCacheConfig,
|
|
11
|
+
)
|
|
12
|
+
from cache_dit.cache_factory.cache_contexts.cache_manager import (
|
|
13
|
+
CachedContextManager,
|
|
14
|
+
)
|