cache-dit 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit might be problematic.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/adapters.py +47 -5
- cache_dit/cache_factory/dual_block_cache/__init__.py +4 -0
- cache_dit/cache_factory/dual_block_cache/cache_blocks.py +487 -0
- cache_dit/cache_factory/dual_block_cache/cache_context.py +10 -860
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +4 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +5 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +14 -4
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +5 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +5 -2
- cache_dit/cache_factory/{first_block_cache/diffusers_adapters/mochi.py → dual_block_cache/diffusers_adapters/qwen_image.py} +14 -12
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +7 -4
- cache_dit/cache_factory/dynamic_block_prune/__init__.py +4 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +4 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +10 -4
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +5 -2
- cache_dit/cache_factory/{first_block_cache/diffusers_adapters/cogvideox.py → dynamic_block_prune/diffusers_adapters/qwen_image.py} +28 -23
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py +276 -0
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +228 -516
- cache_dit/cache_factory/patch/flux.py +241 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/METADATA +22 -80
- cache_dit-0.2.16.dist-info/RECORD +47 -0
- cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
- cache_dit-0.2.14.dist-info/RECORD +0 -49
- /cache_dit/cache_factory/{first_block_cache → patch}/__init__.py +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/top_level.txt +0 -0
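Among the additions above, new `qwen_image.py` adapters appear under both `dual_block_cache/diffusers_adapters` and `dynamic_block_prune/diffusers_adapters`, which is what backs the Qwen-Image support announced in the 0.2.16 METADATA. A minimal sketch of how those adapters would typically be exercised is shown below; the `QwenImagePipeline` class name and the `Qwen/Qwen-Image` model id are assumptions for illustration and are not part of this diff.

```python
import torch
from diffusers import QwenImagePipeline  # assumed pipeline class, not from this diff
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

# Load a Qwen-Image pipeline (model id assumed for illustration).
pipe = QwenImagePipeline.from_pretrained(
    "Qwen/Qwen-Image",
    torch_dtype=torch.bfloat16,
).to("cuda")

# apply_cache_on_pipe dispatches to the new qwen_image.py adapter when
# DBCache (or DBPrune) options are applied to this pipeline.
apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.DBCache))
```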
cache_dit/cache_factory/patch/flux.py (new file)

@@ -0,0 +1,241 @@
import inspect

import torch
import numpy as np
from typing import Tuple, Optional, Dict, Any, Union
from diffusers import FluxTransformer2DModel
from diffusers.models.transformers.transformer_flux import (
    FluxSingleTransformerBlock,
    Transformer2DModelOutput,
)
from diffusers.utils import (
    USE_PEFT_BACKEND,
    scale_lora_layers,
    unscale_lora_layers,
)


from cache_dit.logger import init_logger

logger = init_logger(__name__)


# copy from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
def __patch_single_forward__(
    self: FluxSingleTransformerBlock,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    temb: torch.Tensor,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    text_seq_len = encoder_hidden_states.shape[1]
    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

    residual = hidden_states
    norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
    mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
    joint_attention_kwargs = joint_attention_kwargs or {}
    attn_output = self.attn(
        hidden_states=norm_hidden_states,
        image_rotary_emb=image_rotary_emb,
        **joint_attention_kwargs,
    )

    hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
    gate = gate.unsqueeze(1)
    hidden_states = gate * self.proj_out(hidden_states)
    hidden_states = residual + hidden_states
    if hidden_states.dtype == torch.float16:
        hidden_states = hidden_states.clip(-65504, 65504)

    encoder_hidden_states, hidden_states = (
        hidden_states[:, :text_seq_len],
        hidden_states[:, text_seq_len:],
    )
    return encoder_hidden_states, hidden_states


# copy from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L631
def __patch_transformer_forward__(
    self: FluxTransformer2DModel,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    pooled_projections: torch.Tensor = None,
    timestep: torch.LongTensor = None,
    img_ids: torch.Tensor = None,
    txt_ids: torch.Tensor = None,
    guidance: torch.Tensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_block_samples=None,
    controlnet_single_block_samples=None,
    return_dict: bool = True,
    controlnet_blocks_repeat: bool = False,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
    if joint_attention_kwargs is not None:
        joint_attention_kwargs = joint_attention_kwargs.copy()
        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        if (
            joint_attention_kwargs is not None
            and joint_attention_kwargs.get("scale", None) is not None
        ):
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )

    hidden_states = self.x_embedder(hidden_states)

    timestep = timestep.to(hidden_states.dtype) * 1000
    if guidance is not None:
        guidance = guidance.to(hidden_states.dtype) * 1000

    temb = (
        self.time_text_embed(timestep, pooled_projections)
        if guidance is None
        else self.time_text_embed(timestep, guidance, pooled_projections)
    )
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)

    if txt_ids.ndim == 3:
        logger.warning(
            "Passing `txt_ids` 3d torch.Tensor is deprecated."
            "Please remove the batch dimension and pass it as a 2d torch Tensor"
        )
        txt_ids = txt_ids[0]
    if img_ids.ndim == 3:
        logger.warning(
            "Passing `img_ids` 3d torch.Tensor is deprecated."
            "Please remove the batch dimension and pass it as a 2d torch Tensor"
        )
        img_ids = img_ids[0]

    ids = torch.cat((txt_ids, img_ids), dim=0)
    image_rotary_emb = self.pos_embed(ids)

    if (
        joint_attention_kwargs is not None
        and "ip_adapter_image_embeds" in joint_attention_kwargs
    ):
        ip_adapter_image_embeds = joint_attention_kwargs.pop(
            "ip_adapter_image_embeds"
        )
        ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
        joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

    for index_block, block in enumerate(self.transformer_blocks):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            encoder_hidden_states, hidden_states = (
                self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    joint_attention_kwargs,
                )
            )

        else:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )

        # controlnet residual
        if controlnet_block_samples is not None:
            interval_control = len(self.transformer_blocks) / len(
                controlnet_block_samples
            )
            interval_control = int(np.ceil(interval_control))
            # For Xlabs ControlNet.
            if controlnet_blocks_repeat:
                hidden_states = (
                    hidden_states
                    + controlnet_block_samples[
                        index_block % len(controlnet_block_samples)
                    ]
                )
            else:
                hidden_states = (
                    hidden_states
                    + controlnet_block_samples[index_block // interval_control]
                )

    for index_block, block in enumerate(self.single_transformer_blocks):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            encoder_hidden_states, hidden_states = (
                self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    joint_attention_kwargs,
                )
            )

        else:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )

        # controlnet residual
        if controlnet_single_block_samples is not None:
            interval_control = len(self.single_transformer_blocks) / len(
                controlnet_single_block_samples
            )
            interval_control = int(np.ceil(interval_control))
            hidden_states = (
                hidden_states
                + controlnet_single_block_samples[
                    index_block // interval_control
                ]
            )

    hidden_states = self.norm_out(hidden_states, temb)
    output = self.proj_out(hidden_states)

    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)


def maybe_patch_flux_transformer(
    transformer: FluxTransformer2DModel,
) -> FluxTransformer2DModel:
    single_forward_parameters = inspect.signature(
        transformer.single_transformer_blocks[0].forward
    ).parameters.keys()
    if "encoder_hidden_states" not in single_forward_parameters:
        logger.warning("Patch FluxSingleTransformerBlock for cache-dit.")
        for block in transformer.single_transformer_blocks:
            block.forward = __patch_single_forward__.__get__(block)

        assert not getattr(transformer, "_is_parallelized", False), (
            "Please call apply_cache_on_pipe before Parallelize, "
            "the __patch_transformer_forward__ will overwrite the "
            "parallized forward and cause a downgrade of performance."
        )
        transformer.forward = __patch_transformer_forward__.__get__(transformer)
        transformer._is_patched = True

    return transformer
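The helper above is normally not called by users directly; the Flux adapters presumably invoke it when caching is enabled on a Flux pipeline. Still, a minimal sketch of a direct call may help illustrate what the patch does; the pipeline loading details below are assumptions for illustration, not taken from this diff.

```python
import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory.patch.flux import maybe_patch_flux_transformer

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# If FluxSingleTransformerBlock.forward does not yet accept
# `encoder_hidden_states`, the helper rebinds the patched forwards above so the
# single blocks take and return (encoder_hidden_states, hidden_states) as
# cache-dit expects; otherwise it leaves the transformer unchanged.
pipe.transformer = maybe_patch_flux_transformer(pipe.transformer)
```

Per the assertion inside the helper, this patching must happen before the transformer is parallelized, since overwriting an already-parallelized forward would degrade performance.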
{cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.14
+Version: 0.2.16
 Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -52,20 +52,22 @@ Dynamic: requires-python
   <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
   <img src=https://img.shields.io/badge/Release-v0.2-brightgreen.svg >
 </div>
-
-DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. <br> CacheDiT offers a set of training-free cache accelerators for Diffusion Transformers: <br> <b>🔥<a href="#dbcache">DBCache</a>, <a href="#dbprune">DBPrune</a>, <a href="#taylorseer">Hybrid TaylorSeer</a>, <a href="#cfg">Hybrid Cache CFG</a>, <a href="#fbcache">FBCache</a></b>, etc🔥
-</p>
+  🔥<a href="#dbcache">DBCache</a> | <a href="#dbprune">DBPrune</a> | <a href="#taylorseer">Hybrid TaylorSeer</a> | <a href="#cfg">Hybrid Cache CFG</a> | <a href="#fbcache">FBCache</a>🔥
 </div>
 
 <div align="center">
   <p align="center">
-
+    ♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️
   </p>
 </div>
 
-## 🔥News
-
-- [2025-
+## 🔥News
+
+- [2025-08-12] 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
+- [2025-08-11] 🔥[Qwen-Image](https://github.com/QwenLM/Qwen-Image) is supported now! Please check [run_qwen_image.py](./examples/run_qwen_image.py) as an example.
+- [2025-08-10] 🔥[FLUX.1-Kontext-dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) is supported! Please check [run_flux_kontext.py](./examples/run_flux_kontext.py) as an example.
+- [2025-07-18] 🎉First caching mechanism in [🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/huggingface/flux-fast/pull/13).
+- [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! **3.3x** speedup for FLUX.1 on NVIDIA L20 with `cache-dit`.
 
 ## 📖Contents
 
@@ -78,7 +80,6 @@ Dynamic: requires-python
 - [⚡️Hybrid Cache CFG](#cfg)
 - [🎉First Block Cache](#fbcache)
 - [⚡️Dynamic Block Prune](#dbprune)
-- [🎉Context Parallelism](#context-parallelism)
 - [🔥Torch Compile](#compile)
 - [⚙️Metrics CLI](#metrics)
 - [👋Contribute](#contribute)
@@ -104,8 +105,10 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
 
 <div id="supported"></div>
 
+- [🚀Qwen-Image](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀FLUX.1-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀FLUX.1-Fill-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀FLUX.1-Kontext-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀mochi-1-preview](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀CogVideoX1.5](https://github.com/vipshop/cache-dit/raw/main/examples)
@@ -157,7 +160,7 @@ These case studies demonstrate that even with relatively high thresholds (such a
 - **max_cached_steps**: (default: -1) DBCache disables the caching strategy when the previous cached steps exceed this value to prevent precision degradation.
 - **residual_diff_threshold**: The value of residual diff threshold, a higher value leads to faster performance at the cost of lower precision.
 
-For a good balance between performance and precision, DBCache is configured by default with **
+For a good balance between performance and precision, DBCache is configured by default with **F8B0**, 8 warmup steps, and unlimited cached steps.
 
 ```python
 from diffusers import FluxPipeline
@@ -168,16 +171,16 @@ pipe = FluxPipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
 ).to("cuda")
 
-# Default options,
+# Default options, F8B0, good balance between performance and precision
 cache_options = CacheType.default_options(CacheType.DBCache)
 
-# Custom options,
+# Custom options, F8B0, higher precision
 cache_options = {
     "cache_type": CacheType.DBCache,
     "warmup_steps": 8,
-    "max_cached_steps":
-    "Fn_compute_blocks": 8,
-    "Bn_compute_blocks":
+    "max_cached_steps": -1,  # -1 means no limit
+    "Fn_compute_blocks": 8,  # Fn, F8, etc.
+    "Bn_compute_blocks": 0,  # Bn, B0, etc.
     "residual_diff_threshold": 0.12,
 }
 
@@ -253,7 +256,7 @@ cache_options = {
     # should set do_separate_classifier_free_guidance as False.
     # For example, set it as True for Wan 2.1 and set it as False
     # for FLUX.1, HunyuanVideo, CogVideoX, Mochi.
-    "do_separate_classifier_free_guidance": True,
+    "do_separate_classifier_free_guidance": True,  # Wan 2.1, Qwen-Image
     # Compute cfg forward first or not, default False, namely,
     # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
     "cfg_compute_first": False,
@@ -270,7 +273,7 @@ cache_options = {
 
 
 
-**DBCache** is a more general cache algorithm than **FBCache**. When Fn=1 and Bn=0, DBCache behaves identically to FBCache. Therefore, you can
+**DBCache** is a more general cache algorithm than **FBCache**. When Fn=1 and Bn=0, DBCache behaves identically to FBCache. Therefore, you can use configure **DBCache** with **F1B0** settings to achieve the same functionality.
 
 ```python
 from diffusers import FluxPipeline
@@ -281,15 +284,12 @@ pipe = FluxPipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
 ).to("cuda")
 
-# Using FBCache directly
-cache_options = CacheType.default_options(CacheType.FBCache)
-
 # Or using DBCache with F1B0.
 # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
 cache_options = {
     "cache_type": CacheType.DBCache,
     "warmup_steps": 8,
-    "max_cached_steps":
+    "max_cached_steps": -1,  # -1 means no limit
     "Fn_compute_blocks": 1,  # Fn, F1, etc.
     "Bn_compute_blocks": 0,  # Bn, B0, etc.
     "residual_diff_threshold": 0.12,
@@ -370,64 +370,6 @@ apply_cache_on_pipe(pipe, **cache_options)
 |24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
 |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
 
-## 🎉Context Parallelism
-
-<div id="context-parallelism"></div>
-
-**CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can **easily tap into** its **Context Parallelism** features for distributed inference. Firstly, install `para-attn` from PyPI:
-
-```bash
-pip3 install para-attn # or install `para-attn` from sources.
-```
-
-Then, you can run **DBCache** or **DBPrune** with **Context Parallelism** on 4 GPUs:
-
-```python
-import torch.distributed as dist
-from diffusers import FluxPipeline
-from para_attn.context_parallel import init_context_parallel_mesh
-from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
-from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
-
-# Init distributed process group
-dist.init_process_group()
-torch.cuda.set_device(dist.get_rank())
-
-pipe = FluxPipeline.from_pretrained(
-    "black-forest-labs/FLUX.1-dev",
-    torch_dtype=torch.bfloat16,
-).to("cuda")
-
-# Context Parallel from ParaAttention
-parallelize_pipe(
-    pipe, mesh=init_context_parallel_mesh(
-        pipe.device.type, max_ulysses_dim_size=4
-    )
-)
-
-# DBPrune with default options from this library
-apply_cache_on_pipe(
-    pipe, **CacheType.default_options(CacheType.DBPrune)
-)
-
-dist.destroy_process_group()
-```
-Then, run the python test script with `torchrun`:
-```bash
-torchrun --nproc_per_node=4 parallel_cache.py
-```
-
-<div align="center">
-  <p align="center">
-    DBPrune + <b>torch.compile + context parallelism</b> <br>Steps: 28, "A cat holding a sign that says hello world with complex background"
-  </p>
-</div>
-
-|Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
-|:---:|:---:|:---:|:---:|:---:|:---:|
-|+compile:20.43s|16.25s|14.12s|13.41s|12.00s|8.86s|
-|+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
-|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
 
 ## 🔥Torch Compile
 
@@ -490,7 +432,7 @@ How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](
 
 <div id="license"></div>
 
-The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work! We have followed the original License from
+The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work! We have followed the original License from FBCache, please check [LICENSE](https://github.com/vipshop/cache-dit/raw/main/LICENSE) for more details.
 
 ## ©️Citations
 
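For reference, the DBCache snippets scattered through the METADATA hunks above assemble into a short end-to-end flow roughly like the following sketch. The `apply_cache_on_pipe` import path and the final `apply_cache_on_pipe(pipe, **cache_options)` call appear elsewhere in the README shown in this diff, while the prompt and generation kwargs are illustrative assumptions.

```python
import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# DBCache F8B0: 8 compute blocks at the front (Fn), 0 at the back (Bn).
cache_options = {
    "cache_type": CacheType.DBCache,
    "warmup_steps": 8,
    "max_cached_steps": -1,   # -1 means no limit
    "Fn_compute_blocks": 8,   # Fn, F8, etc.
    "Bn_compute_blocks": 0,   # Bn, B0, etc.
    "residual_diff_threshold": 0.12,
}
apply_cache_on_pipe(pipe, **cache_options)

# Generation call is illustrative; prompt and kwargs are assumptions.
image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=28,
).images[0]
```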
cache_dit-0.2.16.dist-info/RECORD (new file)

@@ -0,0 +1,47 @@
cache_dit/__init__.py,sha256=0-B173-fLi3IA8nJXoS71zK0zD33Xplysd9skmLfEOY,171
cache_dit/_version.py,sha256=uFGhweCFKwebVyMUvDALfnhYcWJQj8O3h_9xJIOhTtk,513
cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
cache_dit/utils.py,sha256=4cFNh0asch6Zgsixq0bS1ElfwBu_6BG5ZSmaa1khjyg,144
cache_dit/cache_factory/__init__.py,sha256=2UuUJ-CRXrLbv_ZoC2nV3qPHoipfqeWvO7xZO3CxOD4,263
cache_dit/cache_factory/adapters.py,sha256=3iHIkhkb_2s1f-W-jw0bCToZyLYvbJlPpxASv4EqrqU,6714
cache_dit/cache_factory/taylorseer.py,sha256=LKSNo2ode69EVo9xrxjxAMEjz0yDGiGADeDYnEqddA8,3987
cache_dit/cache_factory/utils.py,sha256=V-Mb5Jn07geEUUWo4QAfh6pmSzkL-2OGDn0VAXbG6hQ,1799
cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=gkouVv5IgcsiTOQ5_I-a3S3TnJifNZhrwdkrO1KRCqw,120
cache_dit/cache_factory/dual_block_cache/cache_blocks.py,sha256=M9R_6t-X6vrsSNvSMQLn24I_fZ9EpFK4RlRzgAtEdac,18407
cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=28g3lqgsarCk-u880QF0yl967glXIZrWpUfJYKJZqxg,39872
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=uSqF5aD2-feHB25vEbx1STBQVjWVAOn_wYTdAEmS4NU,2045
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=RNF24Ysuddo5cjv4hUT4s3C4C3pgud5YJ1OklvtrrlU,2286
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=URWqROg_ANseraFf-on5ZqvJzD7tH3WouboFaJgOAkk,2853
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha256=y0Ab_tTJOixTaK-Remg5IZLCDGIgBOZFszI32PnK9gc,9981
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=nqpbVjT0vFecIIqVGBvCUIBysOtFFjh-MhrXw4VPBGo,2278
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/qwen_image.py,sha256=qhwpGDxx5RBbumSArc38z4aMv5YVwxH3eiyIpzjXycQ,2322
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=P9-r0HNofBO8rWpB4pAQvGOEbsfiHUAKV16Noq8ZzWg,2610
cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=d_cbjrkBaprzQw4HJI3sRtqMzrhDVL56moLkpSGnqO4,123
cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py,sha256=6L9WhuWP6GHFRLnGRO_XZ1oE4v9xm-yqQ7cgNP_fOdY,9704
cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=94e0VKmbU09hb-s8JQ9vOYTJtZ07gva7XT9OQm8EBSc,25530
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=9WLorIqd7m_dIDs6pPPj-lcd9e56fdvGM2D3DDdWwEU,2045
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py,sha256=Ny1f3gvtcbVbTvGXNdyv7YC2Ez1rGonttx7xnsvfHM4,2299
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py,sha256=ibANvG29Sa2CU2pG_IndkEYarshuQDWGTP3klnAStV0,2792
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py,sha256=ylNgdkCgFwWbuCb2sbyIV6XoHTd8MzGbgMO0t1Z6D50,9994
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py,sha256=iMZd473vkh2hsJRIlKN2DCsCRJmsV4JSB7SZzZdg4eI,2291
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/qwen_image.py,sha256=0MgHcT4EVUcUnq8-g4GO2p11MzMtRhQp5D3wG9t_KlA,2398
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py,sha256=VOdnpIZRbcyKHP7e4yRPSlS5hZSyq4dcp9fmHOfI-Fc,2630
cache_dit/cache_factory/patch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
cache_dit/cache_factory/patch/flux.py,sha256=zRUWZlt02vZGJcK2WOfmSbGR4UOUvsFvSBDgeBNZxh8,8813
cache_dit/compile/__init__.py,sha256=DfMdPleFFGADXLsr7zXui8BTz_y9futY6rNmNdh9y7k,63
cache_dit/compile/utils.py,sha256=N4A55_8uIbEd-S4xyJPcrdKceI2MGM9BTIhJE63jyL4,3786
cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE,575
cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
cache_dit-0.2.16.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
cache_dit-0.2.16.dist-info/METADATA,sha256=y8oidHX3B0iZXfjym_mHtGgB0fGE4fiEA6JIlIeWaRo,22769
cache_dit-0.2.16.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
cache_dit-0.2.16.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
cache_dit-0.2.16.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
cache_dit-0.2.16.dist-info/RECORD,,