cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/offload_utils.py +115 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/caching/cache_interface.py +358 -0
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
- cache_dit/caching/utils.py +68 -0
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/metrics/metrics.py +3 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- cache_dit-1.0.14.dist-info/METADATA +301 -0
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
- cache_dit/cache_factory/cache_blocks/utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
- cache_dit/cache_factory/cache_interface.py +0 -217
- cache_dit/cache_factory/patch_functors/__init__.py +0 -12
- cache_dit/cache_factory/utils.py +0 -57
- cache_dit-0.3.2.dist-info/METADATA +0 -753
- cache_dit-0.3.2.dist-info/RECORD +0 -56
- cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
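Most of the moved files above reflect two renames: `cache_dit/cache_factory/*` became `cache_dit/caching/*`, and `cache_dit/custom_ops/*` became `cache_dit/kernels/*`. As a hedged orientation only, a migration sketch for code that previously imported from the deep `cache_factory` module paths (top-level re-exports in `cache_dit/__init__.py` may make this unnecessary; not verified against the wheel):

```py
# Hypothetical import migration for the 0.3.2 -> 1.0.14 layout change.
# Old (0.3.2) deep-module imports:
#   from cache_dit.cache_factory.cache_types import CacheType
#   from cache_dit.cache_factory.forward_pattern import ForwardPattern
# New (1.0.14) locations, per the file list above:
from cache_dit.caching.cache_types import CacheType
from cache_dit.caching.forward_pattern import ForwardPattern
```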
cache_dit/caching/cache_interface.py (new file)

@@ -0,0 +1,358 @@
+import torch
+from typing import Any, Tuple, List, Union, Optional
+from diffusers import DiffusionPipeline, ModelMixin
+from cache_dit.caching.cache_types import CacheType
+from cache_dit.caching.block_adapters import BlockAdapter
+from cache_dit.caching.block_adapters import BlockAdapterRegistry
+from cache_dit.caching.cache_adapters import CachedAdapter
+from cache_dit.caching.cache_contexts import BasicCacheConfig
+from cache_dit.caching.cache_contexts import DBCacheConfig
+from cache_dit.caching.cache_contexts import DBPruneConfig
+from cache_dit.caching.cache_contexts import CalibratorConfig
+from cache_dit.caching.params_modifier import ParamsModifier
+from cache_dit.parallelism import ParallelismConfig
+from cache_dit.parallelism import enable_parallelism
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def enable_cache(
+    # DiffusionPipeline or BlockAdapter
+    pipe_or_adapter: Union[
+        DiffusionPipeline,
+        BlockAdapter,
+        # Transformer-only
+        torch.nn.Module,
+        ModelMixin,
+    ],
+    # BasicCacheConfig, DBCacheConfig, DBPruneConfig, etc.
+    cache_config: Optional[
+        Union[
+            BasicCacheConfig,
+            DBCacheConfig,
+            DBPruneConfig,
+        ]
+    ] = None,
+    # Calibrator config: TaylorSeerCalibratorConfig, etc.
+    calibrator_config: Optional[CalibratorConfig] = None,
+    # Modify cache context params for specific blocks.
+    params_modifiers: Optional[
+        Union[
+            ParamsModifier,
+            List[ParamsModifier],
+            List[List[ParamsModifier]],
+        ]
+    ] = None,
+    # Config for Parallelism
+    parallelism_config: Optional[ParallelismConfig] = None,
+    # Other cache context kwargs: Deprecated cache kwargs
+    **kwargs,
+) -> Union[
+    DiffusionPipeline,
+    # Transformer-only
+    torch.nn.Module,
+    ModelMixin,
+    BlockAdapter,
+]:
+    r"""
+    The `enable_cache` function serves as a unified caching interface designed to optimize the performance
+    of diffusion transformer models by implementing an intelligent caching mechanism known as `DBCache`.
+    This API is engineered to be compatible with nearly `all` diffusion transformer architectures that
+    feature transformer blocks adhering to standard input-output patterns, eliminating the need for
+    architecture-specific modifications.
+
+    By strategically caching intermediate outputs of transformer blocks during the diffusion process,
+    `DBCache` significantly reduces redundant computations without compromising generation quality.
+    The caching mechanism works by tracking residual differences between consecutive steps, allowing
+    the model to reuse previously computed features when these differences fall below a configurable
+    threshold. This approach maintains a balance between computational efficiency and output precision.
+
+    The default configuration (`F8B0, 8 warmup steps, unlimited cached steps`) is carefully tuned to
+    provide an optimal tradeoff for most common use cases. The "F8B0" configuration indicates that
+    the first 8 transformer blocks are used to compute stable feature differences, while no final
+    blocks are employed for additional fusion. The warmup phase ensures the model establishes
+    sufficient feature representation before caching begins, preventing potential degradation of
+    output quality.
+
+    This function seamlessly integrates with both standard diffusion pipelines and custom block
+    adapters, making it versatile for various deployment scenarios—from research prototyping to
+    production environments where inference speed is critical. By abstracting the complexity of
+    caching logic behind a simple interface, it enables developers to enhance model performance
+    with minimal code changes.
+
+    Args:
+        pipe_or_adapter (`DiffusionPipeline`, `BlockAdapter` or `Transformer`, *required*):
+            The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
+            For example: cache_dit.enable_cache(FluxPipeline(...)).
+
+        cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
+            Basic DBCache config for cache context, defaults to BasicCacheConfig(). The configurable params listed belows:
+            Fn_compute_blocks: (`int`, *required*, defaults to 8):
+                Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information at time step t,
+                enabling the calculation of a more stable L1 difference and delivering more accurate information
+                to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
+                for more details of DBCache.
+            Bn_compute_blocks: (`int`, *required*, defaults to 0):
+                Further fuses approximate information in the **last n** Transformer blocks to enhance
+                prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
+                that use residual cache.
+            residual_diff_threshold (`float`, *required*, defaults to 0.08):
+                the value of residual diff threshold, a higher value leads to faster performance at the
+                cost of lower precision.
+            max_warmup_steps (`int`, *required*, defaults to 8):
+                DBCache does not apply the caching strategy when the number of running steps is less than
+                or equal to this value, ensuring the model sufficiently learns basic features during warmup.
+            warmup_interval (`int`, *required*, defaults to 1):
+                Skip interval in warmup steps, e.g., when warmup_interval is 2, only 0, 2, 4, ... steps
+                in warmup steps will be computed, others will use dynamic cache.
+            max_cached_steps (`int`, *required*, defaults to -1):
+                DBCache disables the caching strategy when the previous cached steps exceed this value to
+                prevent precision degradation.
+            max_continuous_cached_steps (`int`, *required*, defaults to -1):
+                DBCache disables the caching strategy when the previous continous cached steps exceed this value to
+                prevent precision degradation.
+            enable_separate_cfg (`bool`, *required*, defaults to None):
+                Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
+                and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
+                CogVideoX, HunyuanVideo, Mochi, etc.
+            cfg_compute_first (`bool`, *required*, defaults to False):
+                Whether to compute cfg forward first, default is False, meaning:
+                0, 2, 4, ..., -> non-CFG step;
+                1, 3, 5, ... -> CFG step.
+            cfg_diff_compute_separate (`bool`, *required*, defaults to True):
+                Whether to compute separate difference values for CFG and non-CFG steps, default is True.
+                If False, we will use the computed difference from the current non-CFG transformer step
+                for the current CFG step.
+            num_inference_steps (`int`, *optional*, defaults to None):
+                num_inference_steps for DiffusionPipeline, used to adjust some internal settings
+                for better caching performance. For example, we will refresh the cache once the
+                executed steps exceed num_inference_steps if num_inference_steps is provided.
+
+        calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
+            Config for calibrator. If calibrator_config is not None, it means the user wants to use DBCache
+            with a specific calibrator, such as taylorseer, foca, and so on.
+
+        params_modifiers ('ParamsModifier', *optional*, defaults to None):
+            Modify cache context params for specific blocks. The configurable params listed belows:
+            cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
+                The same as 'cache_config' param in cache_dit.enable_cache() interface.
+            calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
+                The same as 'calibrator_config' param in cache_dit.enable_cache() interface.
+            **kwargs: (`dict`, *optional*, defaults to {}):
+                The same as 'kwargs' param in cache_dit.enable_cache() interface.
+
+        parallelism_config (`ParallelismConfig`, *optional*, defaults to None):
+            Config for Parallelism. If parallelism_config is not None, it means the user wants to enable
+            parallelism for cache-dit. Please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/parallel_config.py
+            for more details of ParallelismConfig.
+            backend: (`ParallelismBackend`, *required*, defaults to "ParallelismBackend.NATIVE_DIFFUSER"):
+                Parallelism backend, currently only NATIVE_DIFFUSER and NVTIVE_PYTORCH are supported.
+                For context parallelism, only NATIVE_DIFFUSER backend is supported, for tensor parallelism,
+                only NATIVE_PYTORCH backend is supported.
+            ulysses_size: (`int`, *optional*, defaults to None):
+                The size of Ulysses cluster. If ulysses_size is not None, enable Ulysses style parallelism.
+                This setting is only valid when backend is NATIVE_DIFFUSER.
+            ring_size: (`int`, *optional*, defaults to None):
+                The size of ring for ring parallelism. If ring_size is not None, enable ring attention.
+                This setting is only valid when backend is NATIVE_DIFFUSER.
+            tp_size: (`int`, *optional*, defaults to None):
+                The size of tensor parallelism. If tp_size is not None, enable tensor parallelism.
+                This setting is only valid when backend is NATIVE_PYTORCH.
+            parallel_kwargs: (`dict`, *optional*, defaults to {}):
+                Additional kwargs for parallelism backends. For example, for NATIVE_DIFFUSER backend,
+                it can include `cp_plan` and `attention_backend` arguments for `Context Parallelism`.
+
+        kwargs (`dict`, *optional*, defaults to {})
+            Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/cache_contexts/cache_context.py
+            for more details.
+
+    Examples:
+    ```py
+    >>> import cache_dit
+    >>> from diffusers import DiffusionPipeline
+    >>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image") # Can be any diffusion pipeline
+    >>> cache_dit.enable_cache(pipe) # One-line code with default cache options.
+    >>> output = pipe(...) # Just call the pipe as normal.
+    >>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
+    >>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
+    """
+    # Precheck for compatibility of different configurations
+    if cache_config is None:
+        if parallelism_config is None:
+            # Set default cache config only when parallelism is not enabled
+            logger.info("cache_config is None, using default DBCacheConfig")
+            cache_config = DBCacheConfig()
+        else:
+            # Allow empty cache_config when parallelism is enabled
+            logger.warning(
+                "Parallelism is enabled and cache_config is None. Please manually "
+                "set cache_config to avoid potential compatibility issues. "
+                "Otherwise, cache will not be enabled."
+            )
+
+    # Collect cache context kwargs
+    context_kwargs = {}
+    if (cache_type := context_kwargs.get("cache_type", None)) is not None:
+        if cache_type == CacheType.NONE:
+            return pipe_or_adapter
+
+    # NOTE: Deprecated cache config params. These parameters are now retained
+    # for backward compatibility but will be removed in the future.
+    deprecated_kwargs = {
+        "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
+        "Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
+        "max_warmup_steps": kwargs.get("max_warmup_steps", None),
+        "max_cached_steps": kwargs.get("max_cached_steps", None),
+        "max_continuous_cached_steps": kwargs.get(
+            "max_continuous_cached_steps", None
+        ),
+        "residual_diff_threshold": kwargs.get("residual_diff_threshold", None),
+        "enable_separate_cfg": kwargs.get("enable_separate_cfg", None),
+        "cfg_compute_first": kwargs.get("cfg_compute_first", None),
+        "cfg_diff_compute_separate": kwargs.get(
+            "cfg_diff_compute_separate", None
+        ),
+    }
+
+    deprecated_kwargs = {
+        k: v for k, v in deprecated_kwargs.items() if v is not None
+    }
+
+    if deprecated_kwargs:
+        logger.warning(
+            "Manually settup DBCache context without BasicCacheConfig is "
+            "deprecated and will be removed in the future, please use "
+            "`cache_config` parameter instead!"
+        )
+        if cache_config is not None:
+            cache_config.update(**deprecated_kwargs)
+        else:
+            cache_config = BasicCacheConfig(**deprecated_kwargs)
+
+    if cache_config is not None:
+        context_kwargs["cache_config"] = cache_config
+
+    # NOTE: Deprecated taylorseer params. These parameters are now retained
+    # for backward compatibility but will be removed in the future.
+    if cache_config is not None and (
+        kwargs.get("enable_taylorseer", None) is not None
+        or kwargs.get("enable_encoder_taylorseer", None) is not None
+    ):
+        logger.warning(
+            "Manually settup TaylorSeer calibrator without TaylorSeerCalibratorConfig is "
+            "deprecated and will be removed in the future, please use "
+            "`calibrator_config` parameter instead!"
+        )
+        from cache_dit.caching.cache_contexts.calibrators import (
+            TaylorSeerCalibratorConfig,
+        )
+
+        calibrator_config = TaylorSeerCalibratorConfig(
+            enable_calibrator=kwargs.get("enable_taylorseer"),
+            enable_encoder_calibrator=kwargs.get("enable_encoder_taylorseer"),
+            calibrator_cache_type=kwargs.get(
+                "taylorseer_cache_type", "residual"
+            ),
+            taylorseer_order=kwargs.get("taylorseer_order", 1),
+        )
+
+    if calibrator_config is not None:
+        context_kwargs["calibrator_config"] = calibrator_config
+
+    if params_modifiers is not None:
+        context_kwargs["params_modifiers"] = params_modifiers
+
+    if cache_config is not None:
+        if isinstance(
+            pipe_or_adapter,
+            (DiffusionPipeline, BlockAdapter, torch.nn.Module, ModelMixin),
+        ):
+            pipe_or_adapter = CachedAdapter.apply(
+                pipe_or_adapter,
+                **context_kwargs,
+            )
+        else:
+            raise ValueError(
+                f"type: {type(pipe_or_adapter)} is not valid, "
+                "Please pass DiffusionPipeline or BlockAdapter"
+                "for the 1's position param: pipe_or_adapter"
+            )
+    else:
+        logger.warning(
+            "cache_config is None, skip enabling cache for "
+            f"{pipe_or_adapter.__class__.__name__}."
+        )
+
+    # NOTE: Users should always enable parallelism after applying
+    # cache to avoid hooks conflict.
+    if parallelism_config is not None:
+        assert isinstance(
+            parallelism_config, ParallelismConfig
+        ), "parallelism_config should be of type ParallelismConfig."
+
+        transformers = []
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            adapter = BlockAdapterRegistry.get_adapter(pipe_or_adapter)
+            if adapter is None:
+                assert hasattr(pipe_or_adapter, "transformer"), (
+                    "The given DiffusionPipeline does not have "
+                    "a 'transformer' attribute, cannot enable "
+                    "parallelism."
+                )
+                transformers = [pipe_or_adapter.transformer]
+            else:
+                adapter = BlockAdapter.normalize(adapter, unique=False)
+                transformers = BlockAdapter.flatten(adapter.transformer)
+        else:
+            if not BlockAdapter.is_normalized(pipe_or_adapter):
+                pipe_or_adapter = BlockAdapter.normalize(
+                    pipe_or_adapter, unique=False
+                )
+            transformers = BlockAdapter.flatten(pipe_or_adapter.transformer)
+
+        if len(transformers) == 0:
+            logger.warning(
+                "No transformer is detected in the "
+                "BlockAdapter, skip enabling parallelism."
+            )
+            return pipe_or_adapter
+
+        if len(transformers) > 1:
+            logger.warning(
+                "Multiple transformers are detected in the "
+                "BlockAdapter, all transfomers will be "
+                "enabled for parallelism."
+            )
+        for i, transformer in enumerate(transformers):
+            # Enable parallelism for the transformer inplace
+            transformers[i] = enable_parallelism(
+                transformer, parallelism_config
+            )
+    return pipe_or_adapter
+
+
+def disable_cache(
+    pipe_or_adapter: Union[
+        DiffusionPipeline,
+        BlockAdapter,
+    ],
+):
+    CachedAdapter.maybe_release_hooks(pipe_or_adapter)
+    logger.warning(
+        f"Cache Acceleration is disabled for: "
+        f"{pipe_or_adapter.__class__.__name__}."
+    )
+
+
+def supported_pipelines(
+    **kwargs,
+) -> Tuple[int, List[str]]:
+    return BlockAdapterRegistry.supported_pipelines(**kwargs)
+
+
+def get_adapter(
+    pipe: DiffusionPipeline | str | Any,
+) -> BlockAdapter:
+    return BlockAdapterRegistry.get_adapter(pipe)
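To make the new interface concrete, here is a hedged usage sketch built only from names visible in this hunk (`enable_cache`, `BasicCacheConfig`, `TaylorSeerCalibratorConfig`, `cache_dit.summary`, `cache_dit.disable_cache`). The accepted constructor keywords are taken from the docstring and the deprecated-kwargs fallback above, not verified against the released wheel:

```py
>>> import cache_dit
>>> from diffusers import DiffusionPipeline
>>> from cache_dit.caching.cache_contexts import BasicCacheConfig
>>> from cache_dit.caching.cache_contexts.calibrators import TaylorSeerCalibratorConfig
>>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")
>>> cache_dit.enable_cache(
...     pipe,
...     # F8B0 with 8 warmup steps, i.e. the documented defaults made explicit.
...     cache_config=BasicCacheConfig(
...         Fn_compute_blocks=8,
...         Bn_compute_blocks=0,
...         max_warmup_steps=8,
...         residual_diff_threshold=0.08,
...     ),
...     # Optional TaylorSeer calibrator; field names mirror the deprecated-kwargs path above.
...     calibrator_config=TaylorSeerCalibratorConfig(
...         enable_calibrator=True,
...         enable_encoder_calibrator=True,
...         calibrator_cache_type="residual",
...         taylorseer_order=1,
...     ),
... )
>>> output = pipe(...)               # run the pipeline as normal
>>> stats = cache_dit.summary(pipe)  # inspect cache acceleration stats
>>> cache_dit.disable_cache(pipe)
```

A `ParallelismConfig` would be passed the same way via `parallelism_config=`; as the code above notes, `enable_cache` applies parallelism after caching to avoid hook conflicts.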
cache_dit/{cache_factory → caching}/cache_types.py

@@ -6,7 +6,8 @@ logger = init_logger(__name__)
 
 class CacheType(Enum):
     NONE = "NONE"
-    DBCache = "Dual_Block_Cache"
+    DBCache = "DBCache"  # "Dual_Block_Cache"
+    DBPrune = "DBPrune"  # "Dynamic_Block_Prune"
 
     @staticmethod
     def type(type_hint: "CacheType | str") -> "CacheType":
@@ -14,6 +15,9 @@ class CacheType(Enum):
             return type_hint
         return cache_type(type_hint)
 
+    def __str__(self) -> str:
+        return self.value
+
 
 def cache_type(type_hint: "CacheType | str") -> "CacheType":
     if type_hint is None:
@@ -21,7 +25,6 @@ def cache_type(type_hint: "CacheType | str") -> "CacheType":
 
     if isinstance(type_hint, CacheType):
         return type_hint
-
     elif type_hint.upper() in (
        "DUAL_BLOCK_CACHE",
        "DB_CACHE",
@@ -29,6 +32,20 @@
        "DB",
    ):
        return CacheType.DBCache
+    elif type_hint.upper() in (
+        "DYNAMIC_BLOCK_PRUNE",
+        "DB_PRUNE",
+        "DBPRUNE",
+        "DBP",
+    ):
+        return CacheType.DBPrune
+    elif type_hint.upper() in (
+        "NONE",
+        "NO_CACHE",
+        "NOCACHE",
+        "NC",
+    ):
+        return CacheType.NONE
     return CacheType.NONE
 
 
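A small sketch of the string-alias handling this hunk adds (the `cache_type` helper is defined at module level next to `CacheType`, so the import below is assumed to work as shown):

```py
>>> from cache_dit.caching.cache_types import CacheType, cache_type
>>> cache_type("DBPRUNE")     # also DYNAMIC_BLOCK_PRUNE / DB_PRUNE / DBP
<CacheType.DBPrune: 'DBPrune'>
>>> cache_type("NO_CACHE")    # also NONE / NOCACHE / NC
<CacheType.NONE: 'NONE'>
>>> str(CacheType.DBCache)    # __str__ now returns the short value
'DBCache'
```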
cache_dit/{cache_factory → caching}/forward_pattern.py

@@ -20,33 +20,33 @@ class ForwardPattern(Enum):
 
     Pattern_0 = (
         True,  # Return_H_First
-        False,
-        False,
+        False,  # Return_H_Only
+        False,  # Forward_H_only
         ("hidden_states", "encoder_hidden_states"),  # In
         ("hidden_states", "encoder_hidden_states"),  # Out
         True,  # Supported
     )
 
     Pattern_1 = (
-        False,
-        False,
-        False,
+        False,  # Return_H_First
+        False,  # Return_H_Only
+        False,  # Forward_H_only
         ("hidden_states", "encoder_hidden_states"),  # In
         ("encoder_hidden_states", "hidden_states"),  # Out
         True,  # Supported
     )
 
     Pattern_2 = (
-        False,
+        False,  # Return_H_First
         True,  # Return_H_Only
-        False,
+        False,  # Forward_H_only
         ("hidden_states", "encoder_hidden_states"),  # In
-        ("hidden_states",),
+        ("hidden_states",),  # Out
         True,  # Supported
     )
 
     Pattern_3 = (
-        False,
+        False,  # Return_H_First
         True,  # Return_H_Only
         True,  # Forward_H_only
         ("hidden_states",),  # In
@@ -56,18 +56,18 @@ class ForwardPattern(Enum):
 
     Pattern_4 = (
         True,  # Return_H_First
-        False,
+        False,  # Return_H_Only
         True,  # Forward_H_only
-        ("hidden_states",),
+        ("hidden_states",),  # In
         ("hidden_states", "encoder_hidden_states"),  # Out
         True,  # Supported
     )
 
     Pattern_5 = (
-        False,
-        False,
+        False,  # Return_H_First
+        False,  # Return_H_Only
         True,  # Forward_H_only
-        ("hidden_states",),
+        ("hidden_states",),  # In
         ("encoder_hidden_states", "hidden_states"),  # Out
         True,  # Supported
     )
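The labels added above describe how a model's transformer blocks consume and return `hidden_states` / `encoder_hidden_states`. For orientation only, a minimal hypothetical illustration (not code from the package) of block signatures matching the In/Out tuples of Pattern_0 and Pattern_1:

```py
import torch


class Pattern0LikeBlock(torch.nn.Module):
    # In: (hidden_states, encoder_hidden_states) -> Out: (hidden_states, encoder_hidden_states)
    def forward(self, hidden_states, encoder_hidden_states):
        # ... attention / MLP would run here ...
        return hidden_states, encoder_hidden_states  # Return_H_First = True


class Pattern1LikeBlock(torch.nn.Module):
    # In: (hidden_states, encoder_hidden_states) -> Out: (encoder_hidden_states, hidden_states)
    def forward(self, hidden_states, encoder_hidden_states):
        # ... attention / MLP would run here ...
        return encoder_hidden_states, hidden_states  # Return_H_First = False
```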
cache_dit/{cache_factory → caching}/params_modifier.py

@@ -1,7 +1,7 @@
 from typing import Optional
 
-from cache_dit.
-from cache_dit.
+from cache_dit.caching.cache_contexts import BasicCacheConfig
+from cache_dit.caching.cache_contexts import CalibratorConfig
 
 from cache_dit.logger import init_logger
 
@@ -11,7 +11,7 @@ logger = init_logger(__name__)
 class ParamsModifier:
     def __init__(
         self,
-        #
+        # BasicCacheConfig, DBCacheConfig, DBPruneConfig, etc.
         cache_config: BasicCacheConfig = None,
         # Calibrator config: TaylorSeerCalibratorConfig, etc.
         calibrator_config: Optional[CalibratorConfig] = None,
@@ -22,7 +22,7 @@ class ParamsModifier:
 
         # WARNING: Deprecated cache config params. These parameters are now retained
         # for backward compatibility but will be removed in the future.
-
+        deprecated_kwargs = {
             "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
             "Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
             "max_warmup_steps": kwargs.get("max_warmup_steps", None),
@@ -40,20 +40,20 @@ class ParamsModifier:
             ),
         }
 
-
-            k: v for k, v in
+        deprecated_kwargs = {
+            k: v for k, v in deprecated_kwargs.items() if v is not None
         }
 
-        if
+        if deprecated_kwargs:
             logger.warning(
                 "Manually settup DBCache context without BasicCacheConfig is "
                 "deprecated and will be removed in the future, please use "
                 "`cache_config` parameter instead!"
             )
             if cache_config is not None:
-                cache_config.update(**
+                cache_config.update(**deprecated_kwargs)
             else:
-                cache_config = BasicCacheConfig(**
+                cache_config = BasicCacheConfig(**deprecated_kwargs)
 
         if cache_config is not None:
             self._context_kwargs["cache_config"] = cache_config
@@ -68,7 +68,7 @@ class ParamsModifier:
                 "deprecated and will be removed in the future, please use "
                 "`calibrator_config` parameter instead!"
             )
-            from cache_dit.
+            from cache_dit.caching.cache_contexts.calibrators import (
                 TaylorSeerCalibratorConfig,
             )
 
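A brief sketch of the two construction styles this hunk supports: the preferred explicit `cache_config`, and the deprecated bare kwargs that are folded into a `BasicCacheConfig` (with the warning shown above). Keyword names come from the deprecated-kwargs dict in the hunk:

```py
>>> from cache_dit.caching.cache_contexts import BasicCacheConfig
>>> from cache_dit.caching.params_modifier import ParamsModifier
>>> # Preferred: explicit config object for the blocks this modifier targets.
>>> modifier = ParamsModifier(cache_config=BasicCacheConfig(residual_diff_threshold=0.12))
>>> # Deprecated: bare kwargs still work but emit the warning and are wrapped internally.
>>> legacy = ParamsModifier(residual_diff_threshold=0.12)
>>> # Either can be handed to cache_dit.enable_cache(..., params_modifiers=[modifier]).
```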
cache_dit/caching/patch_functors/__init__.py (new file)

@@ -0,0 +1,15 @@
+from cache_dit.caching.patch_functors.functor_base import PatchFunctor
+from cache_dit.caching.patch_functors.functor_dit import DiTPatchFunctor
+from cache_dit.caching.patch_functors.functor_flux import FluxPatchFunctor
+from cache_dit.caching.patch_functors.functor_chroma import (
+    ChromaPatchFunctor,
+)
+from cache_dit.caching.patch_functors.functor_hidream import (
+    HiDreamPatchFunctor,
+)
+from cache_dit.caching.patch_functors.functor_hunyuan_dit import (
+    HunyuanDiTPatchFunctor,
+)
+from cache_dit.caching.patch_functors.functor_qwen_image_controlnet import (
+    QwenImageControlNetPatchFunctor,
+)
cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py

@@ -6,7 +6,7 @@ from diffusers.models.transformers.dit_transformer_2d import (
     DiTTransformer2DModel,
     Transformer2DModelOutput,
 )
-from cache_dit.
+from cache_dit.caching.patch_functors.functor_base import (
     PatchFunctor,
 )
 from cache_dit.logger import init_logger
@@ -13,7 +13,7 @@ from diffusers.utils import (
     scale_lora_layers,
     unscale_lora_layers,
 )
-from cache_dit.
+from cache_dit.caching.patch_functors.functor_base import (
     PatchFunctor,
 )
 from cache_dit.logger import init_logger
cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py

@@ -362,9 +362,7 @@ def __patch_transformer_forward__(
         )
         if hidden_states_masks is not None:
             # NOTE: Patched
-            cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[
-                self.double_stream_blocks[-1].block._block_id
-            ]
+            cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[0]
             encoder_attention_mask_ones = torch.ones(
                 (
                     batch_size,
cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py

@@ -5,7 +5,7 @@ from diffusers.models.transformers.hunyuan_transformer_2d import (
     HunyuanDiTBlock,
     Transformer2DModelOutput,
 )
-from cache_dit.
+from cache_dit.caching.patch_functors.functor_base import (
     PatchFunctor,
 )
 from cache_dit.logger import init_logger