cache-dit 0.1.7__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit might be problematic.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +8 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +295 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +99 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +12 -4
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +295 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +99 -0
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +2 -2
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +4 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +295 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +2 -2
- {cache_dit-0.1.7.dist-info → cache_dit-0.2.0.dist-info}/METADATA +45 -40
- {cache_dit-0.1.7.dist-info → cache_dit-0.2.0.dist-info}/RECORD +16 -11
- {cache_dit-0.1.7.dist-info → cache_dit-0.2.0.dist-info}/WHEEL +0 -0
- {cache_dit-0.1.7.dist-info → cache_dit-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.1.7.dist-info → cache_dit-0.2.0.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py (new file, @@ -0,0 +1,295 @@):

```python
# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/hunyuan_video.py
import functools
import unittest
from typing import Any, Dict, Optional, Union

import torch
from diffusers import DiffusionPipeline, HunyuanVideoTransformer3DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import (
    scale_lora_layers,
    unscale_lora_layers,
    USE_PEFT_BACKEND,
)

from cache_dit.cache_factory.first_block_cache import cache_context
from cache_dit.logger import init_logger

try:
    from para_attn.para_attn_interface import SparseKVAttnMode

    def is_sparse_kv_attn_available():
        return True

except ImportError:

    class SparseKVAttnMode:
        def __enter__(self):
            pass

        def __exit__(self, exc_type, exc_value, traceback):
            pass

    def is_sparse_kv_attn_available():
        return False


logger = init_logger(__name__)  # pylint: disable=invalid-name


def apply_cache_on_transformer(
    transformer: HunyuanVideoTransformer3DModel,
):
    if getattr(transformer, "_is_cached", False):
        return transformer

    cached_transformer_blocks = torch.nn.ModuleList(
        [
            cache_context.CachedTransformerBlocks(
                transformer.transformer_blocks
                + transformer.single_transformer_blocks,
                transformer=transformer,
            )
        ]
    )
    dummy_single_transformer_blocks = torch.nn.ModuleList()

    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        pooled_projections: torch.Tensor,
        guidance: torch.Tensor = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        with (
            unittest.mock.patch.object(
                self,
                "transformer_blocks",
                cached_transformer_blocks,
            ),
            unittest.mock.patch.object(
                self,
                "single_transformer_blocks",
                dummy_single_transformer_blocks,
            ),
        ):
            if getattr(self, "_is_parallelized", False):
                return original_forward(
                    hidden_states,
                    timestep,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    pooled_projections,
                    guidance=guidance,
                    attention_kwargs=attention_kwargs,
                    return_dict=return_dict,
                    **kwargs,
                )
            else:
                if attention_kwargs is not None:
                    attention_kwargs = attention_kwargs.copy()
                    lora_scale = attention_kwargs.pop("scale", 1.0)
                else:
                    lora_scale = 1.0

                if USE_PEFT_BACKEND:
                    # weight the lora layers by setting `lora_scale` for each PEFT layer
                    scale_lora_layers(self, lora_scale)
                else:
                    if (
                        attention_kwargs is not None
                        and attention_kwargs.get("scale", None) is not None
                    ):
                        logger.warning(
                            "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                        )

                batch_size, num_channels, num_frames, height, width = (
                    hidden_states.shape
                )
                p, p_t = self.config.patch_size, self.config.patch_size_t
                post_patch_num_frames = num_frames // p_t
                post_patch_height = height // p
                post_patch_width = width // p

                # 1. RoPE
                image_rotary_emb = self.rope(hidden_states)

                # 2. Conditional embeddings
                temb = self.time_text_embed(
                    timestep, guidance, pooled_projections
                )
                hidden_states = self.x_embedder(hidden_states)
                encoder_hidden_states = self.context_embedder(
                    encoder_hidden_states, timestep, encoder_attention_mask
                )

                # 3. Attention mask preparation
                latent_sequence_length = hidden_states.shape[1]
                latent_attention_mask = torch.ones(
                    batch_size,
                    1,
                    latent_sequence_length,
                    device=hidden_states.device,
                    dtype=torch.bool,
                )  # [B, 1, N]
                attention_mask = torch.cat(
                    [
                        latent_attention_mask,
                        encoder_attention_mask.unsqueeze(1).to(torch.bool),
                    ],
                    dim=-1,
                )  # [B, 1, N + M]

                with SparseKVAttnMode():
                    # 4. Transformer blocks
                    hidden_states, encoder_hidden_states = (
                        self.call_transformer_blocks(
                            hidden_states,
                            encoder_hidden_states,
                            temb,
                            attention_mask,
                            image_rotary_emb,
                        )
                    )

                # 5. Output projection
                hidden_states = self.norm_out(hidden_states, temb)
                hidden_states = self.proj_out(hidden_states)

                hidden_states = hidden_states.reshape(
                    batch_size,
                    post_patch_num_frames,
                    post_patch_height,
                    post_patch_width,
                    -1,
                    p_t,
                    p,
                    p,
                )
                hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
                hidden_states = (
                    hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
                )

                hidden_states = hidden_states.to(timestep.dtype)

                if USE_PEFT_BACKEND:
                    # remove `lora_scale` from each PEFT layer
                    unscale_lora_layers(self, lora_scale)

                if not return_dict:
                    return (hidden_states,)

                return Transformer2DModelOutput(sample=hidden_states)

    transformer.forward = new_forward.__get__(transformer)

    def call_transformer_blocks(
        self, hidden_states, encoder_hidden_states, *args, **kwargs
    ):
        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}

            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = (
                    torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        *args,
                        **kwargs,
                        **ckpt_kwargs,
                    )
                )

            for block in self.single_transformer_blocks:
                hidden_states, encoder_hidden_states = (
                    torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        *args,
                        **kwargs,
                        **ckpt_kwargs,
                    )
                )

        else:
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = block(
                    hidden_states, encoder_hidden_states, *args, **kwargs
                )

            for block in self.single_transformer_blocks:
                hidden_states, encoder_hidden_states = block(
                    hidden_states, encoder_hidden_states, *args, **kwargs
                )

        return hidden_states, encoder_hidden_states

    transformer.call_transformer_blocks = call_transformer_blocks.__get__(
        transformer
    )

    transformer._is_cached = True

    return transformer


def apply_cache_on_pipe(
    pipe: DiffusionPipeline,
    *,
    shallow_patch: bool = False,
    residual_diff_threshold=0.06,
    downsample_factor=1,
    warmup_steps=0,
    max_cached_steps=-1,
    **kwargs,
):
    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
        default_attrs={
            "residual_diff_threshold": residual_diff_threshold,
            "downsample_factor": downsample_factor,
            "warmup_steps": warmup_steps,
            "max_cached_steps": max_cached_steps,
        },
        **kwargs,
    )
    if not getattr(pipe, "_is_cached", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with cache_context.cache_context(
                cache_context.create_cache_context(
                    **cache_kwargs,
                )
            ):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_cached = True

    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    return pipe
```
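For reference, a minimal usage sketch of the adapter added above, assuming the `diffusers_adapters` package re-exports `apply_cache_on_pipe` for dispatch (as the existing FLUX.1/Mochi/CogVideoX adapters do); the checkpoint id and the generation arguments below are illustrative, not taken from this diff:

```python
import torch
from diffusers import HunyuanVideoPipeline
from diffusers.utils import export_to_video

# Assumes the adapters package exposes apply_cache_on_pipe and dispatches
# on the pipeline class, as the other first_block_cache adapters do.
from cache_dit.cache_factory.first_block_cache.diffusers_adapters import (
    apply_cache_on_pipe,
)

# Illustrative checkpoint id; substitute the HunyuanVideo weights you use.
pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Defaults mirror the signature above: residual_diff_threshold=0.06,
# downsample_factor=1, warmup_steps=0, max_cached_steps=-1 (unlimited).
apply_cache_on_pipe(pipe, residual_diff_threshold=0.06)

video = pipe(
    prompt="A cat holding a sign that says hello world",
    num_inference_steps=30,
).frames[0]
export_to_video(video, "output.mp4", fps=15)
```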
cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py (+2 -2):

```diff
@@ -4,13 +4,13 @@ import functools
 import unittest
 
 import torch
-from diffusers import DiffusionPipeline,
+from diffusers import DiffusionPipeline, WanTransformer3DModel
 
 from cache_dit.cache_factory.first_block_cache import cache_context
 
 
 def apply_cache_on_transformer(
-    transformer:
+    transformer: WanTransformer3DModel,
 ):
     if getattr(transformer, "_is_cached", False):
         return transformer
```
{cache_dit-0.1.7.dist-info → cache_dit-0.2.0.dist-info}/METADATA (+45 -40):

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.
+Version: 0.2.0
 Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -35,7 +35,7 @@ Dynamic: requires-python
 
 <div align="center">
 <p align="center">
-<
+<h2>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h2>
 </p>
 <img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit.png >
 <div align='center'>
@@ -44,13 +44,28 @@ Dynamic: requires-python
 <img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
 <img src=https://static.pepy.tech/badge/cache-dit >
 <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
-<img src=https://img.shields.io/badge/Release-v0.
+<img src=https://img.shields.io/badge/Release-v0.2.0-brightgreen.svg >
 </div>
 <p align="center">
 DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>offers a set of training-free cache accelerators for DiT: 🔥DBCache, DBPrune, FBCache, etc🔥
 </p>
+<p align="center">
+<h4> 🔥Supported Models🔥</h4>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀FLUX.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Mochi</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX1.5</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Wan2.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀HunyuanVideo</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+</p>
 </div>
 
+## 👋 Highlight
+
+<div id="reference"></div>
+
+The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work! The **FBCache** support for Mochi, FLUX.1, CogVideoX, Wan2.1, and HunyuanVideo is directly adapted from the original [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache).
+
 ## 🤗 Introduction
 
 <div align="center">
@@ -91,6 +106,12 @@ These case studies demonstrate that even with relatively high thresholds (such a
 
 **DBPrune**: We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, referred to as DBPrune. DBPrune caches each block's hidden states and residuals, then **dynamically prunes** blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals.
 
+<div align="center">
+<p align="center">
+DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+</p>
+</div>
+
 |Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
 |:---:|:---:|:---:|:---:|:---:|:---:|
 |24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
```
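A minimal sketch of the pruning rule described in the DBPrune paragraph above: compare a block's input against the previous step with a relative L1 distance and, when the change is below a threshold, reuse the cached residual instead of running the block. The helper and its names (`maybe_prune_block`, `prev_hidden`, `cached_residual`) are illustrative, not the library's internal API:

```python
import torch

def maybe_prune_block(block, hidden_states, prev_hidden, cached_residual, threshold=0.06):
    # Relative L1 distance between this step's block input and the input
    # seen at the previous step (illustrative decision rule).
    diff = (hidden_states - prev_hidden).abs().mean()
    rel_diff = diff / prev_hidden.abs().mean().clamp(min=1e-6)

    if cached_residual is not None and rel_diff < threshold:
        # "Prune" the block: approximate its output with the cached residual.
        out = hidden_states + cached_residual
    else:
        # Run the block and refresh the cached residual for later steps.
        out = block(hidden_states)
        cached_residual = out - hidden_states

    # Return the output plus the state to carry into the next step.
    return out, hidden_states, cached_residual
```

In DBCache/DBPrune proper, the threshold corresponds to the `residual_diff_threshold` argument of `apply_cache_on_pipe` shown earlier.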
METADATA diff (continued):

```diff
@@ -98,15 +119,29 @@ These case studies demonstrate that even with relatively high thresholds (such a
 
 <div align="center">
 <p align="center">
-
+<h3>🔥 Context Parallelism and Torch Compile</h3>
+</p>
+</div>
+
+Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. By the way, CacheDiT is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance.
+
+<div align="center">
+<p align="center">
+DBPrune + <b>torch.compile + context parallelism</b> <br>Steps: 28, "A cat holding a sign that says hello world with complex background"
 </p>
 </div>
 
-
+|Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|+compile:20.43s|16.25s|14.12s|13.41s|12.00s|8.86s|
+|+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
 
-<
-
-
+<div align="center">
+<p align="center">
+<b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
+</p>
+</div>
 
 ## ©️Citations
 
```
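A minimal end-to-end sketch of the combination the paragraph above describes (context parallelism via ParaAttention plus torch.compile on top of the cache), intended to be launched with `torchrun`. The ParaAttention entry points (`init_context_parallel_mesh`, `parallelize_pipe`), the FLUX.1 checkpoint id, and the FBCache import path are assumptions based on the ParaAttention and cache-dit examples, not taken from this diff:

```python
import torch
import torch.distributed as dist
from diffusers import FluxPipeline

# Assumed ParaAttention entry points for context parallelism.
from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe

# FBCache adapter shown earlier in this release; DBCache/DBPrune expose
# the same apply_cache_on_pipe-style entry points.
from cache_dit.cache_factory.first_block_cache.diffusers_adapters import (
    apply_cache_on_pipe,
)

dist.init_process_group()
torch.cuda.set_device(dist.get_rank())

# Illustrative checkpoint id.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Shard attention across the launched ranks, then enable the cache.
parallelize_pipe(pipe, mesh=init_context_parallel_mesh(pipe.device.type))
apply_cache_on_pipe(pipe, residual_diff_threshold=0.06)

# torch.compile works on top of both.
pipe.transformer = torch.compile(pipe.transformer)

image = pipe(
    "A cat holding a sign that says hello world with complex background",
    num_inference_steps=28,
).images[0]
if dist.get_rank() == 0:
    image.save("flux.png")
dist.destroy_process_group()
```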
METADATA diff (continued):

````diff
@@ -120,12 +155,6 @@ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand w
 }
 ```
 
-## 👋Reference
-
-<div id="reference"></div>
-
-The **CacheDiT** codebase was adapted from FBCache's implementation at the [ParaAttention](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). We would like to express our sincere gratitude for this excellent work!
-
 ## 📖Contents
 
 <div id="contents"></div>
@@ -136,11 +165,9 @@ The **CacheDiT** codebase was adapted from FBCache's implementation at the [Para
 - [⚡️Dynamic Block Prune](#dbprune)
 - [🎉Context Parallelism](#context-parallelism)
 - [🔥Torch Compile](#compile)
-- [🎉Supported Models](#supported)
 - [👋Contribute](#contribute)
 - [©️License](#license)
 
-
 ## ⚙️Installation
 
 <div id="installation"></div>
@@ -371,23 +398,11 @@ Then, run the python test script with `torchrun`:
 torchrun --nproc_per_node=4 parallel_cache.py
 ```
 
-<div align="center">
-<p align="center">
-DBPrune, <b> L20x4 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
-</p>
-</div>
-
-|Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
-|:---:|:---:|:---:|:---:|:---:|:---:|
-|24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
-|8.54s (L20x4)|7.20s (L20x4)|6.61s (L20x4)|6.09s (L20x4)|5.54s (L20x4)|4.22s (L20x4)|
-|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
-
 ## 🔥Torch Compile
 
 <div id="compile"></div>
 
-**CacheDiT**
+By the way, **CacheDiT** is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance. For example:
 
 ```python
 apply_cache_on_pipe(
@@ -396,21 +411,11 @@ apply_cache_on_pipe(
 # Compile the Transformer module
 pipe.transformer = torch.compile(pipe.transformer)
 ```
-However, users intending to use **CacheDiT** for DiT with **dynamic input shapes** should consider increasing the **recompile** **limit** of `torch._dynamo
-
+However, users intending to use **CacheDiT** for DiT with **dynamic input shapes** should consider increasing the **recompile** **limit** of `torch._dynamo`. Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
 ```python
 torch._dynamo.config.recompile_limit = 96 # default is 8
 torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
 ```
-Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
-
-## 🎉Supported Models
-
-<div id="supported"></div>
-
-- [🚀FLUX.1](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
-- [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
-- [🚀Mochi](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
 
 ## 👋Contribute
 <div id="contribute"></div>
````
{cache_dit-0.1.7.dist-info → cache_dit-0.2.0.dist-info}/RECORD (+16 -11):

```diff
@@ -1,5 +1,5 @@
 cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/_version.py,sha256=
+cache_dit/_version.py,sha256=iB5DfB5V6YB5Wo4JmvS-txT42QtmGaWcWp3udRT7zCI,511
 cache_dit/logger.py,sha256=dKfNe_RRk9HJwfgHGeRR1f0LbskJpKdGmISCbL9roQs,3443
 cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
 cache_dit/cache_factory/__init__.py,sha256=5RNuhWakvvqrOV4vkqrEBA7d-V1LwcNSsjtW14mkqK8,5255
@@ -7,25 +7,30 @@ cache_dit/cache_factory/taylorseer.py,sha256=0W29ykJg3MnyLAB2KFicsl11Xe41cDYPgI6
 cache_dit/cache_factory/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=EJ-uhA2-sWMW1jNDhcBtjHDqSn8lUzfKbYoPfZDQhZU,49665
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=ySmO_0IuSm5VrYdi9ccGVHoFVkgtPZMJGq_OMoyl0Q8,2003
 cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=1_n-RFMiL3v2SjhSfFrPH5Mn5Dq9z4BesVK8GN_nh2g,2404
 cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=UbE6nIF-EtA92QxIZVMzIssdZKQSPAVX1hchF9R8drU,2754
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha256=k3fnlVRTSFkEfWAJ136EMVCeOHPklsArKWFQGMMyLuM,10102
 cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=qxMu1L3ycT8F-uxpGsmFQBY_BH1vDiGIOXgS_Qbb7dM,2391
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=M_O91sVaHSJgptv6OY5MhT4vD3gQAUTETEmP1kLdE6c,2713
 cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=
+cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=YRDwZ_16yjThpgVgDv6YaIB4QCE9nEkE-MOru0jOd50,35026
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=WNBk4GeekjG2Ln1pAer20TVHpWyGyyHN50DdqWuPhc0,2003
 cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py,sha256=ORJpdkXkgziDUo-rpebC6pUemgYaDCoeu0cwwLz175U,2407
 cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py,sha256=KbEkLSsHtS6xwLWNh3jlOlXRyGRdrI2pWV1zyQxMTj4,2757
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py,sha256=v3SgLJQPbEqSVy2sYVGMhptJqCe-XPXL7LIV7GEacZg,10105
 cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py,sha256=rgeXfww-7WX6URSDg7mF1HuxSmYmoJVjMVoNGuxjwxc,2395
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py,sha256=JCRIq7kt59nbv7t10FdfmFZd1GawUZu2tYQ_QBB8zXQ,2713
 cache_dit/cache_factory/first_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=DpDhtK095PlrvACf7sbjOt2-QpVkV1arr1qGEKJqgaQ,23502
-cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py,sha256
+cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py,sha256=-FFgA2MoudEo7uDacg4aWgm1KwfLZFsEDTVxatgbq9M,2146
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py,sha256=qO5CWyurtwW30mvOe6cxeQPTSXLDlPJcezm72zEjDq8,2375
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py,sha256=Dcd4OzABCtyQCZNX2KNnUTdVoO1E1ApM7P8gcVYzcK0,2733
+cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py,sha256=OL7W4ukYlZz0IDmBR1zVV6XT3Mgciglj9Hqzv1wUAkQ,10092
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py,sha256=lQTClo52OwPbNEE4jiBZQhfC7hbtYqnYIABp_vbm_dk,2363
-cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py,sha256=
-cache_dit-0.
-cache_dit-0.
-cache_dit-0.
-cache_dit-0.
-cache_dit-0.
+cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py,sha256=dBNzHBECAuTTA1a7kLdvZL20YzaKTAS3iciVLzKKEWA,2638
+cache_dit-0.2.0.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.0.dist-info/METADATA,sha256=WK3Fu8euIwLlm3TXjJws9VzwNjEfcMNvkCSJRt7jEdo,21845
+cache_dit-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.0.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.0.dist-info/RECORD,,
```

{cache_dit-0.1.7.dist-info → cache_dit-0.2.0.dist-info}/WHEEL, licenses/LICENSE, top_level.txt: files without changes.