cache-dit 1.0.8__py3-none-any.whl → 1.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/block_adapters/__init__.py +37 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +75 -4
- cache_dit/cache_factory/block_adapters/block_registers.py +44 -17
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +72 -30
- cache_dit/cache_factory/cache_contexts/cache_config.py +5 -3
- cache_dit/cache_factory/cache_contexts/cache_manager.py +125 -4
- cache_dit/cache_factory/cache_contexts/context_manager.py +9 -2
- cache_dit/cache_factory/cache_contexts/prune_manager.py +15 -2
- cache_dit/cache_factory/cache_interface.py +102 -28
- cache_dit/cache_factory/forward_pattern.py +14 -14
- cache_dit/parallelism/backends/native_diffusers/__init__.py +0 -3
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +74 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +254 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +17 -49
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +3 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +159 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +58 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +12 -0
- cache_dit/parallelism/parallel_backend.py +2 -0
- cache_dit/parallelism/parallel_config.py +10 -3
- cache_dit/parallelism/parallel_interface.py +14 -5
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +28 -9
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/utils.py +56 -20
- {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/METADATA +24 -13
- {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/RECORD +45 -29
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/cache_adapters/cache_adapter.py

@@ -5,10 +5,11 @@ import functools
 from contextlib import ExitStack
 from typing import Dict, List, Tuple, Any, Union, Callable, Optional

-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, ModelMixin

 from cache_dit.cache_factory.cache_types import CacheType
 from cache_dit.cache_factory.block_adapters import BlockAdapter
+from cache_dit.cache_factory.block_adapters import FakeDiffusionPipeline
 from cache_dit.cache_factory.block_adapters import ParamsModifier
 from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
 from cache_dit.cache_factory.cache_contexts import ContextManager

@@ -32,6 +33,9 @@ class CachedAdapter:
         pipe_or_adapter: Union[
             DiffusionPipeline,
             BlockAdapter,
+            # Transformer-only
+            torch.nn.Module,
+            ModelMixin,
         ],
         **context_kwargs,
     ) -> Union[

@@ -42,7 +46,9 @@
             pipe_or_adapter is not None
         ), "pipe or block_adapter can not both None!"

-        if isinstance(
+        if isinstance(
+            pipe_or_adapter, (DiffusionPipeline, torch.nn.Module, ModelMixin)
+        ):
             if BlockAdapterRegistry.is_supported(pipe_or_adapter):
                 logger.info(
                     f"{pipe_or_adapter.__class__.__name__} is officially "

@@ -52,16 +58,22 @@
                 block_adapter = BlockAdapterRegistry.get_adapter(
                     pipe_or_adapter
                 )
+                assert block_adapter is not None, (
+                    f"BlockAdapter for {pipe_or_adapter.__class__.__name__} "
+                    "should not be None!"
+                )
                 if params_modifiers := context_kwargs.pop(
                     "params_modifiers",
                     None,
                 ):
                     block_adapter.params_modifiers = params_modifiers

-
-
-
-
+                block_adapter = cls.cachify(block_adapter, **context_kwargs)
+                if isinstance(pipe_or_adapter, DiffusionPipeline):
+                    return block_adapter.pipe
+
+                return block_adapter.transformer
+
             else:
                 raise ValueError(
                     f"{pipe_or_adapter.__class__.__name__} is not officially supported "

@@ -178,8 +190,6 @@
         context_kwargs = cls.check_context_kwargs(
             block_adapter, **context_kwargs
         )
-        # Apply cache on pipeline: wrap cache context
-        pipe_cls_name = block_adapter.pipe.__class__.__name__

         # Each Pipeline should have it's own context manager instance.
         # Different transformers (Wan2.2, etc) should shared the same

@@ -189,38 +199,58 @@
             "cache_config", None
         )
         assert cache_config is not None, "cache_config can not be None."
+        # Apply cache on pipeline: wrap cache context
+        pipe_cls_name = block_adapter.pipe.__class__.__name__
         context_manager = ContextManager(
             name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
             cache_type=cache_config.cache_type,
+            # Force use persistent_context for FakeDiffusionPipeline
+            persistent_context=isinstance(
+                block_adapter.pipe, FakeDiffusionPipeline
+            ),
         )
-        block_adapter.pipe._context_manager = context_manager  # instance level
-
         flatten_contexts, contexts_kwargs = cls.modify_context_params(
             block_adapter, **context_kwargs
         )
-        original_call = block_adapter.pipe.__class__.__call__

-
-
-
-
-
-
-
-
-
-
-
-
-
+        block_adapter.pipe._context_manager = context_manager  # instance level
+
+        if not context_manager.persistent_context:
+
+            original_call = block_adapter.pipe.__class__.__call__
+
+            @functools.wraps(original_call)
+            def new_call(self, *args, **kwargs):
+                with ExitStack() as stack:
+                    # cache context will be reset for each pipe inference
+                    for context_name, context_kwargs in zip(
+                        flatten_contexts, contexts_kwargs
+                    ):
+                        stack.enter_context(
+                            context_manager.enter_context(
+                                context_manager.reset_context(
+                                    context_name,
+                                    **context_kwargs,
+                                ),
+                            )
                         )
-                )
-
-
-
+                    outputs = original_call(self, *args, **kwargs)
+                    cls.apply_stats_hooks(block_adapter)
+                    return outputs
+
+            block_adapter.pipe.__class__.__call__ = new_call
+            block_adapter.pipe.__class__._original_call = original_call
+
+        else:
+            # Init persistent cache context for transformer
+            for context_name, context_kwargs in zip(
+                flatten_contexts, contexts_kwargs
+            ):
+                context_manager.reset_context(
+                    context_name,
+                    **context_kwargs,
+                )

-        block_adapter.pipe.__class__.__call__ = new_call
-        block_adapter.pipe.__class__._original_call = original_call
         block_adapter.pipe.__class__._is_cached = True

         cls.apply_params_hooks(block_adapter, contexts_kwargs)

@@ -349,6 +379,7 @@
             blocks_name,
             unique_blocks_name,
             dummy_blocks_names,
+            block_adapter,
         )

         return block_adapter.transformer

@@ -361,6 +392,7 @@
         blocks_name: List[str],
         unique_blocks_name: List[str],
         dummy_blocks_names: List[str],
+        block_adapter: BlockAdapter,
     ) -> torch.nn.Module:
         dummy_blocks = torch.nn.ModuleList()

@@ -387,6 +419,8 @@
         # re-apply hooks to transformer after cache applied.
         # from diffusers.hooks.hooks import HookFunctionReference, HookRegistry
         # from diffusers.hooks.group_offloading import apply_group_offloading
+        context_manager: ContextManager = block_adapter.pipe._context_manager
+        assert isinstance(context_manager, ContextManager._supported_managers)

         def new_forward(self, *args, **kwargs):
             with ExitStack() as stack:

@@ -406,6 +440,13 @@
                     )
                 )
                 outputs = original_forward(*args, **kwargs)
+
+                if (
+                    context_manager.persistent_context
+                    and context_manager.is_pre_refreshed()
+                ):
+                    cls.apply_stats_hooks(block_adapter)
+
                 return outputs

         def new_forward_with_hf_hook(self, *args, **kwargs):

@@ -509,6 +550,7 @@
             params_shift += len(blocks)

     @classmethod
+    @torch.compiler.disable
     def apply_stats_hooks(
         cls,
         block_adapter: BlockAdapter,
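The cache_adapter.py hunks above split CachedAdapter into two wiring strategies: for a real DiffusionPipeline, the pipeline class's __call__ is wrapped so that every inference re-enters (and resets) its cache contexts; for transformer-only usage (FakeDiffusionPipeline), a persistent context is initialized once up front. Below is a minimal, self-contained sketch of just the call-wrapping pattern; ToyContext and ToyPipe are illustrative stand-ins, not cache-dit classes.

    import functools
    from contextlib import ExitStack

    class ToyContext:
        # Illustrative stand-in for a cache context.
        def __init__(self, name):
            self.name = name
        def __enter__(self):
            print(f"reset + enter cache context: {self.name}")
            return self
        def __exit__(self, *exc):
            print(f"exit cache context: {self.name}")

    class ToyPipe:
        def __call__(self, prompt):
            return f"image for {prompt!r}"

    def wrap_call(pipe_cls, context_names):
        original_call = pipe_cls.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            # Each pipeline invocation (re)enters all of its cache contexts,
            # mirroring the non-persistent branch of CachedAdapter above.
            with ExitStack() as stack:
                for name in context_names:
                    stack.enter_context(ToyContext(name))
                return original_call(self, *args, **kwargs)

        pipe_cls.__call__ = new_call
        pipe_cls._original_call = original_call

    wrap_call(ToyPipe, ["transformer_blocks"])
    print(ToyPipe()("a cat"))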
cache_dit/cache_factory/cache_contexts/cache_config.py

@@ -56,6 +56,11 @@ class BasicCacheConfig:
     # Compute separate diff values for CFG and non-CFG step, default True. If False, we will
     # use the computed diff from current non-CFG transformer step for current CFG step.
     cfg_diff_compute_separate: bool = True
+    # num_inference_steps (`int`, *optional*, defaults to None):
+    # num_inference_steps for DiffusionPipeline, used to adjust some internal settings
+    # for better caching performance. For example, we will refresh the cache once the
+    # executed steps exceed num_inference_steps if num_inference_steps is provided.
+    num_inference_steps: Optional[int] = None

     def update(self, **kwargs) -> "BasicCacheConfig":
         for key, value in kwargs.items():

@@ -108,9 +113,6 @@ class ExtraCacheConfig:
     # downsample_factor (`int`, *optional*, defaults to 1):
     # Downsample factor for Fn buffer, in order the save GPU memory.
     downsample_factor: int = 1
-    # num_inference_steps (`int`, *optional*, defaults to -1):
-    # num_inference_steps for DiffusionPipeline, for future use.
-    num_inference_steps: int = -1


 @dataclasses.dataclass
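With these hunks, num_inference_steps moves out of ExtraCacheConfig (where it defaulted to -1 and was unused) and into BasicCacheConfig as an Optional[int] that drives cache refreshing. A hedged usage sketch follows; it assumes DBCacheConfig is re-exported at the package top level, as the enable_cache docstring examples later in this diff suggest.

    import torch
    import cache_dit
    from cache_dit import DBCacheConfig  # assumed top-level re-export
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Telling the cache how many scheduler steps one run uses lets it refresh
    # its context once a full round of inference has finished.
    cache_dit.enable_cache(
        pipe,
        cache_config=DBCacheConfig(num_inference_steps=28),
    )
    image = pipe("a photo of a cat", num_inference_steps=28).images[0]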
cache_dit/cache_factory/cache_contexts/cache_manager.py

@@ -7,6 +7,7 @@ import torch.distributed as dist

 from cache_dit.cache_factory.cache_contexts.calibrators import CalibratorBase
 from cache_dit.cache_factory.cache_contexts.cache_context import (
+    BasicCacheConfig,
     CachedContext,
 )
 from cache_dit.logger import init_logger

@@ -21,23 +22,143 @@ class ContextNotExistError(Exception):
 class CachedContextManager:
     # Each Pipeline should have it's own context manager instance.

-    def __init__(self, name: str = None):
+    def __init__(self, name: str = None, persistent_context: bool = False):
         self.name = name
         self._current_context: CachedContext = None
         self._cached_context_manager: Dict[str, CachedContext] = {}
+        # Whether to create new context automatically when setting
+        # a non-exist context name. Persistent context is useful when
+        # the pipeline class is not provided and users want to use
+        # cache-dit in a transformer-only way.
+        self._persistent_context = persistent_context
+        self._current_step_refreshed: bool = False
+
+    @property
+    def persistent_context(self) -> bool:
+        return self._persistent_context
+
+    @property
+    def current_context(self) -> CachedContext:
+        return self._current_context
+
+    @property
+    @torch.compiler.disable
+    def current_step_refreshed(self) -> bool:
+        return self._current_step_refreshed
+
+    @torch.compiler.disable
+    def is_pre_refreshed(self) -> bool:
+        _context = self._current_context
+        if _context is None:
+            return False
+
+        num_inference_steps = _context.cache_config.num_inference_steps
+        if num_inference_steps is not None:
+            current_step = _context.get_current_step()  # e.g, 0~49,50~99,...
+            return current_step == num_inference_steps - 1
+        return False

+    @torch.compiler.disable
     def new_context(self, *args, **kwargs) -> CachedContext:
+        if self._persistent_context:
+            cache_config: BasicCacheConfig = kwargs.get("cache_config", None)
+            assert (
+                cache_config is not None
+                and cache_config.num_inference_steps is not None
+            ), (
+                "When persistent_context is True, num_inference_steps "
+                "must be set in cache_config for proper cache refreshing."
+                f"\nkwargs: {kwargs}"
+            )
         _context = CachedContext(*args, **kwargs)
+        # NOTE: Patch args and kwargs for implicit refresh.
+        _context._init_args = args  # maybe empty tuple: ()
+        _context._init_kwargs = kwargs  # maybe empty dict: {}
         self._cached_context_manager[_context.name] = _context
         return _context

-
+    @torch.compiler.disable
+    def maybe_refresh(
+        self,
+        cached_context: Optional[CachedContext | str] = None,
+    ) -> bool:
+        if cached_context is None:
+            _context = self._current_context
+            assert _context is not None, "Current context is not set!"
+
         if isinstance(cached_context, CachedContext):
-
+            _context = cached_context
         else:
             if cached_context not in self._cached_context_manager:
                 raise ContextNotExistError("Context not exist!")
-
+            _context = self._cached_context_manager[cached_context]
+
+        if self._persistent_context:
+            assert _context.cache_config.num_inference_steps is not None, (
+                "When persistent_context is True, num_inference_steps must be set "
+                "in cache_config for proper cache refreshing."
+            )
+
+        num_inference_steps = _context.cache_config.num_inference_steps
+        if num_inference_steps is not None:
+            current_step = _context.get_current_step()  # e.g, 0~49,50~99,...
+            # Another round of inference, need to refresh cache context.
+            if current_step >= num_inference_steps:
+                if logger.isEnabledFor(logging.DEBUG):
+                    logger.debug(
+                        f"Refreshing cache context '{_context.name}' "
+                        f"as current step: {current_step} >= "
+                        f"num_inference_steps: {num_inference_steps}."
+                    )
+                return True
+        return False
+
+    @torch.compiler.disable
+    def set_context(
+        self,
+        cached_context: CachedContext | str,
+        *args,
+        **kwargs,
+    ) -> CachedContext:
+        if isinstance(cached_context, CachedContext):
+            self._current_context = cached_context
+        else:
+            if cached_context not in self._cached_context_manager:
+                if not self._persistent_context:
+                    raise ContextNotExistError(
+                        "Context not exist and persistent_context is False. Please "
+                        "create new context first or set persistent_context=True."
+                    )
+                else:
+                    # Create new context if not exist
+                    if any((bool(args), bool(kwargs))):
+                        kwargs["name"] = cached_context
+                        self._current_context = self.new_context(
+                            *args, **kwargs
+                        )
+                    else:
+                        raise ValueError(
+                            "To create new context, please provide args and kwargs."
+                        )
+            else:
+                self._current_context = self._cached_context_manager[
+                    cached_context
+                ]
+
+        if self.maybe_refresh(self._current_context):
+            if not any((bool(args), bool(kwargs))):
+                assert hasattr(self._current_context, "_init_args")
+                assert hasattr(self._current_context, "_init_kwargs")
+                args = self._current_context._init_args
+                kwargs = self._current_context._init_kwargs
+
+            self._current_context = self.reset_context(
+                self._current_context, *args, **kwargs
+            )
+            self._current_step_refreshed = True
+        else:
+            self._current_step_refreshed = False
+
         return self._current_context

     def get_context(self, name: str = None) -> CachedContext:
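The refresh mechanism above stashes each context's constructor arguments (_init_args / _init_kwargs) and rebuilds the context once its step counter reaches num_inference_steps. The following toy sketch (not the real classes) illustrates that pattern:

    class ToyContext:
        # Illustrative stand-in for CachedContext.
        def __init__(self, name, num_inference_steps):
            self.name = name
            self.num_inference_steps = num_inference_steps
            self.step = 0
            # mirror the _init_kwargs patching done in new_context above
            self._init_kwargs = dict(name=name, num_inference_steps=num_inference_steps)

    def maybe_refresh(ctx):
        # True once a full round of inference has been consumed.
        return ctx.step >= ctx.num_inference_steps

    ctx = ToyContext("transformer", num_inference_steps=2)
    for _ in range(5):
        if maybe_refresh(ctx):
            ctx = ToyContext(**ctx._init_kwargs)  # implicit refresh
        ctx.step += 1
        print(f"{ctx.name}: step {ctx.step}")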
cache_dit/cache_factory/cache_contexts/context_manager.py

@@ -20,10 +20,17 @@ class ContextManager:
         cls,
         cache_type: CacheType,
         name: str = "default",
+        persistent_context: bool = False,
     ) -> CachedContextManager | PrunedContextManager:
         if cache_type == CacheType.DBCache:
-            return CachedContextManager(
+            return CachedContextManager(
+                name=name,
+                persistent_context=persistent_context,
+            )
         elif cache_type == CacheType.DBPrune:
-            return PrunedContextManager(
+            return PrunedContextManager(
+                name=name,
+                persistent_context=persistent_context,
+            )
         else:
             raise ValueError(f"Unsupported cache_type: {cache_type}.")
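ContextManager is the factory that maps a CacheType to the concrete manager and now threads persistent_context through to it. A minimal usage sketch, using the import paths that appear elsewhere in this diff:

    from cache_dit.cache_factory.cache_types import CacheType
    from cache_dit.cache_factory.cache_contexts import ContextManager

    # DBCache -> CachedContextManager, DBPrune -> PrunedContextManager.
    # persistent_context keeps contexts alive across calls, which is what the
    # transformer-only (FakeDiffusionPipeline) path relies on.
    manager = ContextManager(
        cache_type=CacheType.DBCache,
        name="FluxPipeline_example",
        persistent_context=True,
    )
    print(type(manager).__name__)  # CachedContextManager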
cache_dit/cache_factory/cache_contexts/prune_manager.py

@@ -3,6 +3,7 @@ import functools
 from typing import Dict, List, Tuple, Union

 from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    BasicCacheConfig,
     CachedContextManager,
 )
 from cache_dit.cache_factory.cache_contexts.prune_context import (

@@ -16,15 +17,27 @@ logger = init_logger(__name__)
 class PrunedContextManager(CachedContextManager):
     # Reuse CachedContextManager for Dynamic Block Prune

-    def __init__(self, name: str = None):
-        super().__init__(name)
+    def __init__(self, name: str = None, **kwargs):
+        super().__init__(name, **kwargs)
         # Overwrite for Dynamic Block Prune
         self._current_context: PrunedContext = None
         self._cached_context_manager: Dict[str, PrunedContext] = {}

     # Overwrite for Dynamic Block Prune
     def new_context(self, *args, **kwargs) -> PrunedContext:
+        if self._persistent_context:
+            cache_config: BasicCacheConfig = kwargs.get("cache_config", None)
+            assert (
+                cache_config is not None
+                and cache_config.num_inference_steps is not None
+            ), (
+                "When persistent_context is True, num_inference_steps "
+                "must be set in cache_config for proper cache refreshing."
+            )
         _context = PrunedContext(*args, **kwargs)
+        # NOTE: Patch args and kwargs for implicit refresh.
+        _context._init_args = args  # maybe empty tuple: ()
+        _context._init_kwargs = kwargs  # maybe empty dict: {}
         self._cached_context_manager[_context.name] = _context
         return _context

cache_dit/cache_factory/cache_interface.py

@@ -1,5 +1,6 @@
+import torch
 from typing import Any, Tuple, List, Union, Optional
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, ModelMixin
 from cache_dit.cache_factory.cache_types import CacheType
 from cache_dit.cache_factory.block_adapters import BlockAdapter
 from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry

@@ -22,13 +23,18 @@ def enable_cache(
     pipe_or_adapter: Union[
         DiffusionPipeline,
         BlockAdapter,
+        # Transformer-only
+        torch.nn.Module,
+        ModelMixin,
     ],
     # BasicCacheConfig, DBCacheConfig, DBPruneConfig, etc.
-    cache_config:
-
-
-
-
+    cache_config: Optional[
+        Union[
+            BasicCacheConfig,
+            DBCacheConfig,
+            DBPruneConfig,
+        ]
+    ] = None,
     # Calibrator config: TaylorSeerCalibratorConfig, etc.
     calibrator_config: Optional[CalibratorConfig] = None,
     # Modify cache context params for specific blocks.

@@ -45,6 +51,9 @@
     **kwargs,
 ) -> Union[
     DiffusionPipeline,
+    # Transformer-only
+    torch.nn.Module,
+    ModelMixin,
     BlockAdapter,
 ]:
     r"""
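enable_cache now also accepts a bare transformer (torch.nn.Module / ModelMixin) and, in that case, returns the cached transformer instead of a pipeline. A hedged transformer-only sketch follows; DBCacheConfig as a top-level export is an assumption, and note that this path runs with a persistent context, which requires num_inference_steps to be set in the cache config (see the assertions in cache_manager.py above).

    import torch
    import cache_dit
    from cache_dit import DBCacheConfig  # assumed top-level re-export
    from diffusers import FluxTransformer2DModel

    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )
    # Transformer-only caching uses a persistent cache context, so the cache
    # must be told how many steps one round of inference takes.
    transformer = cache_dit.enable_cache(
        transformer,
        cache_config=DBCacheConfig(num_inference_steps=28),
    )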
@@ -74,7 +83,7 @@
     with minimal code changes.

     Args:
-        pipe_or_adapter (`DiffusionPipeline` or `
+        pipe_or_adapter (`DiffusionPipeline`, `BlockAdapter` or `Transformer`, *required*):
             The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
             For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
             for the usgae of BlockAdapter.

@@ -117,6 +126,10 @@
             Whether to compute separate difference values for CFG and non-CFG steps, default is True.
             If False, we will use the computed difference from the current non-CFG transformer step
             for the current CFG step.
+            num_inference_steps (`int`, *optional*, defaults to None):
+                num_inference_steps for DiffusionPipeline, used to adjust some internal settings
+                for better caching performance. For example, we will refresh the cache once the
+                executed steps exceed num_inference_steps if num_inference_steps is provided.

         calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
             Config for calibrator. If calibrator_config is not None, it means the user wants to use DBCache

@@ -135,10 +148,22 @@
             Config for Parallelism. If parallelism_config is not None, it means the user wants to enable
             parallelism for cache-dit. Please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/parallel_config.py
             for more details of ParallelismConfig.
+            backend: (`ParallelismBackend`, *required*, defaults to "ParallelismBackend.NATIVE_DIFFUSER"):
+                Parallelism backend, currently only NATIVE_DIFFUSER and NVTIVE_PYTORCH are supported.
+                For context parallelism, only NATIVE_DIFFUSER backend is supported, for tensor parallelism,
+                only NATIVE_PYTORCH backend is supported.
             ulysses_size: (`int`, *optional*, defaults to None):
                 The size of Ulysses cluster. If ulysses_size is not None, enable Ulysses style parallelism.
+                This setting is only valid when backend is NATIVE_DIFFUSER.
             ring_size: (`int`, *optional*, defaults to None):
                 The size of ring for ring parallelism. If ring_size is not None, enable ring attention.
+                This setting is only valid when backend is NATIVE_DIFFUSER.
+            tp_size: (`int`, *optional*, defaults to None):
+                The size of tensor parallelism. If tp_size is not None, enable tensor parallelism.
+                This setting is only valid when backend is NATIVE_PYTORCH.
+            parallel_kwargs: (`dict`, *optional*, defaults to {}):
+                Additional kwargs for parallelism backends. For example, for NATIVE_DIFFUSER backend,
+                it can include `cp_plan` and `attention_backend` arguments for `Context Parallelism`.

         kwargs (`dict`, *optional*, defaults to {})
             Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
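The new docstring entries describe the fields carried by ParallelismConfig: backend selects NATIVE_DIFFUSER (context parallelism via ulysses_size / ring_size) or NATIVE_PYTORCH (tensor parallelism via tp_size), plus free-form parallel_kwargs. A hedged construction sketch follows; the import paths and constructor signature are assumptions based on the file layout in this diff. Either config would then be passed as parallelism_config= to cache_dit.enable_cache.

    from cache_dit.parallelism.parallel_config import ParallelismConfig  # assumed path
    from cache_dit.parallelism.parallel_backend import ParallelismBackend  # assumed path

    # Context parallelism (Ulysses) on 2 ranks via the diffusers-native backend.
    cp_config = ParallelismConfig(
        backend=ParallelismBackend.NATIVE_DIFFUSER,
        ulysses_size=2,
    )

    # Tensor parallelism on 2 ranks via the native PyTorch backend.
    tp_config = ParallelismConfig(
        backend=ParallelismBackend.NATIVE_PYTORCH,
        tp_size=2,
    )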
@@ -154,13 +179,27 @@
     >>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
     >>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
     """
+    # Precheck for compatibility of different configurations
+    if cache_config is None:
+        if parallelism_config is None:
+            # Set default cache config only when parallelism is not enabled
+            logger.info("cache_config is None, using default DBCacheConfig")
+            cache_config = DBCacheConfig()
+        else:
+            # Allow empty cache_config when parallelism is enabled
+            logger.warning(
+                "Parallelism is enabled and cache_config is None. Please manually "
+                "set cache_config to avoid potential compatibility issues. "
+                "Otherwise, cache will not be enabled."
+            )
+
     # Collect cache context kwargs
     context_kwargs = {}
     if (cache_type := context_kwargs.get("cache_type", None)) is not None:
         if cache_type == CacheType.NONE:
             return pipe_or_adapter

-    #
+    # NOTE: Deprecated cache config params. These parameters are now retained
     # for backward compatibility but will be removed in the future.
     deprecated_kwargs = {
         "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),

@@ -196,9 +235,9 @@
     if cache_config is not None:
         context_kwargs["cache_config"] = cache_config

-    #
+    # NOTE: Deprecated taylorseer params. These parameters are now retained
     # for backward compatibility but will be removed in the future.
-    if (
+    if cache_config is not None and (
         kwargs.get("enable_taylorseer", None) is not None
         or kwargs.get("enable_encoder_taylorseer", None) is not None
     ):

@@ -226,16 +265,25 @@
     if params_modifiers is not None:
         context_kwargs["params_modifiers"] = params_modifiers

-    if
-
+    if cache_config is not None:
+        if isinstance(
             pipe_or_adapter,
-
-    )
+            (DiffusionPipeline, BlockAdapter, torch.nn.Module, ModelMixin),
+        ):
+            pipe_or_adapter = CachedAdapter.apply(
+                pipe_or_adapter,
+                **context_kwargs,
+            )
+        else:
+            raise ValueError(
+                f"type: {type(pipe_or_adapter)} is not valid, "
+                "Please pass DiffusionPipeline or BlockAdapter"
+                "for the 1's position param: pipe_or_adapter"
+            )
     else:
-
-
-        "
-        "for the 1's position param: pipe_or_adapter"
+        logger.warning(
+            "cache_config is None, skip enabling cache for "
+            f"{pipe_or_adapter.__class__.__name__}."
         )

     # NOTE: Users should always enable parallelism after applying

@@ -244,19 +292,45 @@
         assert isinstance(
             parallelism_config, ParallelismConfig
         ), "parallelism_config should be of type ParallelismConfig."
+
+        transformers = []
         if isinstance(pipe_or_adapter, DiffusionPipeline):
-
+            adapter = BlockAdapterRegistry.get_adapter(pipe_or_adapter)
+            if adapter is None:
+                assert hasattr(pipe_or_adapter, "transformer"), (
+                    "The given DiffusionPipeline does not have "
+                    "a 'transformer' attribute, cannot enable "
+                    "parallelism."
+                )
+                transformers = [pipe_or_adapter.transformer]
+            else:
+                adapter = BlockAdapter.normalize(adapter, unique=False)
+                transformers = BlockAdapter.flatten(adapter.transformer)
         else:
-
-
-
-
-
-
+            if not BlockAdapter.is_normalized(pipe_or_adapter):
+                pipe_or_adapter = BlockAdapter.normalize(
+                    pipe_or_adapter, unique=False
+                )
+            transformers = BlockAdapter.flatten(pipe_or_adapter.transformer)
+
+        if len(transformers) == 0:
+            logger.warning(
+                "No transformer is detected in the "
+                "BlockAdapter, skip enabling parallelism."
+            )
+            return pipe_or_adapter
+
+        if len(transformers) > 1:
+            logger.warning(
+                "Multiple transformers are detected in the "
+                "BlockAdapter, all transfomers will be "
+                "enabled for parallelism."
+            )
+        for i, transformer in enumerate(transformers):
+            # Enable parallelism for the transformer inplace
+            transformers[i] = enable_parallelism(
+                transformer, parallelism_config
             )
-        transformer = BlockAdapter.flatten(pipe_or_adapter.transformer)[0]
-        # Enable parallelism for the transformer inplace
-        transformer = enable_parallelism(transformer, parallelism_config)
     return pipe_or_adapter

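Putting the pieces together, here is a hedged end-to-end sketch that enables caching and parallelism in one call, then inspects and removes the cache; summary and disable_cache come from the docstring example above, while the config imports are assumptions as noted earlier. A multi-rank run like this would be launched with torchrun.

    import torch
    import cache_dit
    from cache_dit import DBCacheConfig  # assumed top-level re-export
    from cache_dit.parallelism.parallel_config import ParallelismConfig  # assumed path
    from cache_dit.parallelism.parallel_backend import ParallelismBackend  # assumed path
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    cache_dit.enable_cache(
        pipe,
        cache_config=DBCacheConfig(num_inference_steps=28),
        parallelism_config=ParallelismConfig(
            backend=ParallelismBackend.NATIVE_DIFFUSER,
            ulysses_size=2,  # requires 2 ranks, e.g. torchrun --nproc_per_node=2
        ),
    )

    image = pipe("a photo of a cat", num_inference_steps=28).images[0]
    stats = cache_dit.summary(pipe)  # cache acceleration stats
    cache_dit.disable_cache(pipe)    # restore the original pipeline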