cache-dit 0.1.8-py3-none-any.whl → 0.2.1-py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of cache-dit might be problematic.

@@ -0,0 +1,295 @@
+ # Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/hunyuan_video.py
+ import functools
+ import unittest
+ from typing import Any, Dict, Optional, Union
+
+ import torch
+ from diffusers import DiffusionPipeline, HunyuanVideoTransformer3DModel
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+ from diffusers.utils import (
+     scale_lora_layers,
+     unscale_lora_layers,
+     USE_PEFT_BACKEND,
+ )
+
+ from cache_dit.cache_factory.first_block_cache import cache_context
+ from cache_dit.logger import init_logger
+
+ try:
+     from para_attn.para_attn_interface import SparseKVAttnMode
+
+     def is_sparse_kv_attn_available():
+         return True
+
+ except ImportError:
+
+     class SparseKVAttnMode:
+         def __enter__(self):
+             pass
+
+         def __exit__(self, exc_type, exc_value, traceback):
+             pass
+
+     def is_sparse_kv_attn_available():
+         return False
+
+
+ logger = init_logger(__name__)  # pylint: disable=invalid-name
+
+
+ def apply_cache_on_transformer(
+     transformer: HunyuanVideoTransformer3DModel,
+ ):
+     if getattr(transformer, "_is_cached", False):
+         return transformer
+
+     cached_transformer_blocks = torch.nn.ModuleList(
+         [
+             cache_context.CachedTransformerBlocks(
+                 transformer.transformer_blocks
+                 + transformer.single_transformer_blocks,
+                 transformer=transformer,
+             )
+         ]
+     )
+     dummy_single_transformer_blocks = torch.nn.ModuleList()
+
+     original_forward = transformer.forward
+
+     @functools.wraps(transformer.__class__.forward)
+     def new_forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.LongTensor,
+         encoder_hidden_states: torch.Tensor,
+         encoder_attention_mask: torch.Tensor,
+         pooled_projections: torch.Tensor,
+         guidance: torch.Tensor = None,
+         attention_kwargs: Optional[Dict[str, Any]] = None,
+         return_dict: bool = True,
+         **kwargs,
+     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+         with (
+             unittest.mock.patch.object(
+                 self,
+                 "transformer_blocks",
+                 cached_transformer_blocks,
+             ),
+             unittest.mock.patch.object(
+                 self,
+                 "single_transformer_blocks",
+                 dummy_single_transformer_blocks,
+             ),
+         ):
+             if getattr(self, "_is_parallelized", False):
+                 return original_forward(
+                     hidden_states,
+                     timestep,
+                     encoder_hidden_states,
+                     encoder_attention_mask,
+                     pooled_projections,
+                     guidance=guidance,
+                     attention_kwargs=attention_kwargs,
+                     return_dict=return_dict,
+                     **kwargs,
+                 )
+             else:
+                 if attention_kwargs is not None:
+                     attention_kwargs = attention_kwargs.copy()
+                     lora_scale = attention_kwargs.pop("scale", 1.0)
+                 else:
+                     lora_scale = 1.0
+
+                 if USE_PEFT_BACKEND:
+                     # weight the lora layers by setting `lora_scale` for each PEFT layer
+                     scale_lora_layers(self, lora_scale)
+                 else:
+                     if (
+                         attention_kwargs is not None
+                         and attention_kwargs.get("scale", None) is not None
+                     ):
+                         logger.warning(
+                             "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                         )
+
+                 batch_size, num_channels, num_frames, height, width = (
+                     hidden_states.shape
+                 )
+                 p, p_t = self.config.patch_size, self.config.patch_size_t
+                 post_patch_num_frames = num_frames // p_t
+                 post_patch_height = height // p
+                 post_patch_width = width // p
+
+                 # 1. RoPE
+                 image_rotary_emb = self.rope(hidden_states)
+
+                 # 2. Conditional embeddings
+                 temb = self.time_text_embed(
+                     timestep, guidance, pooled_projections
+                 )
+                 hidden_states = self.x_embedder(hidden_states)
+                 encoder_hidden_states = self.context_embedder(
+                     encoder_hidden_states, timestep, encoder_attention_mask
+                 )
+
+                 # 3. Attention mask preparation
+                 latent_sequence_length = hidden_states.shape[1]
+                 latent_attention_mask = torch.ones(
+                     batch_size,
+                     1,
+                     latent_sequence_length,
+                     device=hidden_states.device,
+                     dtype=torch.bool,
+                 )  # [B, 1, N]
+                 attention_mask = torch.cat(
+                     [
+                         latent_attention_mask,
+                         encoder_attention_mask.unsqueeze(1).to(torch.bool),
+                     ],
+                     dim=-1,
+                 )  # [B, 1, N + M]
+
+                 with SparseKVAttnMode():
+                     # 4. Transformer blocks
+                     hidden_states, encoder_hidden_states = (
+                         self.call_transformer_blocks(
+                             hidden_states,
+                             encoder_hidden_states,
+                             temb,
+                             attention_mask,
+                             image_rotary_emb,
+                         )
+                     )
+
+                 # 5. Output projection
+                 hidden_states = self.norm_out(hidden_states, temb)
+                 hidden_states = self.proj_out(hidden_states)
+
+                 hidden_states = hidden_states.reshape(
+                     batch_size,
+                     post_patch_num_frames,
+                     post_patch_height,
+                     post_patch_width,
+                     -1,
+                     p_t,
+                     p,
+                     p,
+                 )
+                 hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
+                 hidden_states = (
+                     hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+                 )
+
+                 hidden_states = hidden_states.to(timestep.dtype)
+
+                 if USE_PEFT_BACKEND:
+                     # remove `lora_scale` from each PEFT layer
+                     unscale_lora_layers(self, lora_scale)
+
+                 if not return_dict:
+                     return (hidden_states,)
+
+                 return Transformer2DModelOutput(sample=hidden_states)
+
+     transformer.forward = new_forward.__get__(transformer)
+
+     def call_transformer_blocks(
+         self, hidden_states, encoder_hidden_states, *args, **kwargs
+     ):
+         if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+             def create_custom_forward(module, return_dict=None):
+                 def custom_forward(*inputs):
+                     if return_dict is not None:
+                         return module(*inputs, return_dict=return_dict)
+                     else:
+                         return module(*inputs)
+
+                 return custom_forward
+
+             ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}
+
+             for block in self.transformer_blocks:
+                 hidden_states, encoder_hidden_states = (
+                     torch.utils.checkpoint.checkpoint(
+                         create_custom_forward(block),
+                         hidden_states,
+                         encoder_hidden_states,
+                         *args,
+                         **kwargs,
+                         **ckpt_kwargs,
+                     )
+                 )
+
+             for block in self.single_transformer_blocks:
+                 hidden_states, encoder_hidden_states = (
+                     torch.utils.checkpoint.checkpoint(
+                         create_custom_forward(block),
+                         hidden_states,
+                         encoder_hidden_states,
+                         *args,
+                         **kwargs,
+                         **ckpt_kwargs,
+                     )
+                 )
+
+         else:
+             for block in self.transformer_blocks:
+                 hidden_states, encoder_hidden_states = block(
+                     hidden_states, encoder_hidden_states, *args, **kwargs
+                 )
+
+             for block in self.single_transformer_blocks:
+                 hidden_states, encoder_hidden_states = block(
+                     hidden_states, encoder_hidden_states, *args, **kwargs
+                 )
+
+         return hidden_states, encoder_hidden_states
+
+     transformer.call_transformer_blocks = call_transformer_blocks.__get__(
+         transformer
+     )
+
+     transformer._is_cached = True
+
+     return transformer
+
+
+ def apply_cache_on_pipe(
+     pipe: DiffusionPipeline,
+     *,
+     shallow_patch: bool = False,
+     residual_diff_threshold=0.06,
+     downsample_factor=1,
+     warmup_steps=0,
+     max_cached_steps=-1,
+     **kwargs,
+ ):
+     cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
+         default_attrs={
+             "residual_diff_threshold": residual_diff_threshold,
+             "downsample_factor": downsample_factor,
+             "warmup_steps": warmup_steps,
+             "max_cached_steps": max_cached_steps,
+         },
+         **kwargs,
+     )
+     if not getattr(pipe, "_is_cached", False):
+         original_call = pipe.__class__.__call__
+
+         @functools.wraps(original_call)
+         def new_call(self, *args, **kwargs):
+             with cache_context.cache_context(
+                 cache_context.create_cache_context(
+                     **cache_kwargs,
+                 )
+             ):
+                 return original_call(self, *args, **kwargs)
+
+         pipe.__class__.__call__ = new_call
+         pipe.__class__._is_cached = True
+
+     if not shallow_patch:
+         apply_cache_on_transformer(pipe.transformer, **kwargs)
+
+     return pipe
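
For orientation, here is a minimal usage sketch of the adapter added above. The `apply_cache_on_pipe` import path and keyword defaults are taken from the diff itself; the `HunyuanVideoPipeline` class, the checkpoint id, and the generation settings are assumptions for illustration, not part of this release.

```python
import torch
from diffusers import HunyuanVideoPipeline

from cache_dit.cache_factory.first_block_cache.diffusers_adapters.hunyuan_video import (
    apply_cache_on_pipe,
)

# Assumed checkpoint id and dtype, for illustration only.
pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Keyword defaults mirror the apply_cache_on_pipe signature in the hunk above.
apply_cache_on_pipe(
    pipe,
    residual_diff_threshold=0.06,  # higher threshold -> more cache reuse
    warmup_steps=0,                # steps to run before caching (per the name)
    max_cached_steps=-1,           # default from the diff; presumably no cap
)

# Assumed prompt and step count, for illustration only.
video = pipe(
    prompt="A cat holding a sign that says hello world",
    num_inference_steps=30,
).frames[0]
```

As the diff shows, `apply_cache_on_pipe` wraps `pipe.__class__.__call__` in a cache context on every call and, unless `shallow_patch=True` is passed, also patches the transformer's `forward` via `apply_cache_on_transformer`.
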
@@ -4,13 +4,13 @@ import functools
  import unittest

  import torch
- from diffusers import DiffusionPipeline, HunyuanVideoTransformer3DModel
+ from diffusers import DiffusionPipeline, WanTransformer3DModel

  from cache_dit.cache_factory.first_block_cache import cache_context


  def apply_cache_on_transformer(
-     transformer: HunyuanVideoTransformer3DModel,
+     transformer: WanTransformer3DModel,
  ):
      if getattr(transformer, "_is_cached", False):
          return transformer
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: cache_dit
- Version: 0.1.8
+ Version: 0.2.1
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
  Author: DefTruth, vipshop.com, etc.
  Maintainer: DefTruth, vipshop.com, etc
@@ -44,31 +44,18 @@ Dynamic: requires-python
  <img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
  <img src=https://static.pepy.tech/badge/cache-dit >
  <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
- <img src=https://img.shields.io/badge/Release-v0.1.8-brightgreen.svg >
+ <img src=https://img.shields.io/badge/Release-v0.2.1-brightgreen.svg >
  </div>
  <p align="center">
  DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>offers a set of training-free cache accelerators for DiT: 🔥DBCache, DBPrune, FBCache, etc🔥
  </p>
- <p align="center">
- <h3> 🔥Supported Models🔥</h2>
- <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀FLUX.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
- <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
- <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Mochi</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
- <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Wan2.1</b>: 🔜DBCache, 🔜DBPrune, ✔️FBCache🔥</a> <br> <br>
- <b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
- </p>
  </div>

+ ## 👋 Highlight

- <!--
- ## 🎉Supported Models
- <div id="supported"></div>
- - [🚀FLUX.1](https://github.com/vipshop/cache-dit/raw/main/examples): *✔️DBCache, ✔️DBPrune, ✔️FBCache*
- - [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples): *✔️DBCache, ✔️DBPrune, ✔️FBCache*
- - [🚀Mochi](https://github.com/vipshop/cache-dit/raw/main/examples): *✔️DBCache, ✔️DBPrune, ✔️FBCache*
- - [🚀Wan2.1**](https://github.com/vipshop/cache-dit/raw/main/examples): *🔜DBCache, 🔜DBPrune, ✔️FBCache*
- -->
+ <div id="reference"></div>

+ The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work! The **FBCache** support for Mochi, FLUX.1, CogVideoX, Wan2.1, and HunyuanVideo is directly adapted from the original [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache).

  ## 🤗 Introduction

@@ -110,6 +97,12 @@ These case studies demonstrate that even with relatively high thresholds (such a

  **DBPrune**: We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, referred to as DBPrune. DBPrune caches each block's hidden states and residuals, then **dynamically prunes** blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals.

+ <div align="center">
+ <p align="center">
+ DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+ </p>
+ </div>
+
  |Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
  |:---:|:---:|:---:|:---:|:---:|:---:|
  |24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
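
The DBPrune description in the hunk above boils down to a small decision rule. The sketch below is only an illustration of that idea under the stated assumptions; it is not the package's actual `prune_context` implementation, and all names in it are hypothetical.

```python
import torch


def should_prune_block(
    cached_hidden: torch.Tensor,
    curr_hidden: torch.Tensor,
    threshold: float = 0.06,
) -> bool:
    """Illustrative DBPrune-style test: prune a block when the relative L1
    distance between the current and the cached hidden states is small."""
    diff = (curr_hidden - cached_hidden).abs().mean()
    norm = cached_hidden.abs().mean() + 1e-8  # avoid division by zero
    return (diff / norm).item() < threshold


# When a block is pruned, its output is approximated from the cached residual:
#     hidden_states = hidden_states + cached_residual
```
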
@@ -117,11 +110,11 @@ These case studies demonstrate that even with relatively high thresholds (such a

  <div align="center">
  <p align="center">
- DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+ <h3>🔥 Context Parallelism and Torch Compile</h3>
  </p>
- </div>
+ </div>

- **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. Moreover, **CacheDiT** are designed to work compatibly with `torch.compile`. You can easily use CacheDiT with torch.compile to further achieve a better performance.
+ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. By the way, CacheDiT is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance.

  <div align="center">
  <p align="center">
@@ -131,11 +124,16 @@ These case studies demonstrate that even with relatively high thresholds (such a

  |Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
  |:---:|:---:|:---:|:---:|:---:|:---:|
- |+L20x1:24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
- |+compile:20.43s|16.25s|14.12s|13.41s|12s|8.86s|
+ |+compile:20.43s|16.25s|14.12s|13.41s|12.00s|8.86s|
  |+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
  |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|

+ <div align="center">
+ <p align="center">
+ <b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
+ </p>
+ </div>
+
  ## ©️Citations

  ```BibTeX
@@ -148,17 +146,12 @@ These case studies demonstrate that even with relatively high thresholds (such a
  }
  ```

- ## 👋Reference
-
- <div id="reference"></div>
-
- The **CacheDiT** codebase was adapted from FBCache's implementation at the [ParaAttention](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). We would like to express our sincere gratitude for this excellent work!
-
  ## 📖Contents

  <div id="contents"></div>

  - [⚙️Installation](#️installation)
+ - [🔥Supported Models](#supported)
  - [⚡️Dual Block Cache](#dbcache)
  - [🎉First Block Cache](#fbcache)
  - [⚡️Dynamic Block Prune](#dbprune)
@@ -182,6 +175,30 @@ Or you can install the latest develop version from GitHub:
  pip3 install git+https://github.com/vipshop/cache-dit.git
  ```

+ ## 🔥Supported Models
+
+ <div id="supported"></div>
+
+ - [🚀FLUX.1](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀Mochi](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀CogVideoX1.5](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀Wan2.1](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
+
+
+ <!--
+ <p align="center">
+ <h4> 🔥Supported Models🔥</h4>
+ <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀FLUX.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+ <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Mochi</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+ <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+ <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX1.5</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+ <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Wan2.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+ <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀HunyuanVideo</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+ </p>
+ -->
+
  ## ⚡️DBCache: Dual Block Cache

  <div id="dbcache"></div>
@@ -339,6 +356,9 @@ cache_options = {
  apply_cache_on_pipe(pipe, **cache_options)
  ```

+ > [!Important]
+ > Please note that for GPUs with lower VRAM, DBPrune may not be suitable for use on video DiTs, as it caches the hidden states and residuals of each block, leading to higher GPU memory requirements. In such cases, please use DBCache, which only caches the hidden states and residuals of 2 blocks.
+
  <div align="center">
  <p align="center">
  DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
@@ -396,26 +416,12 @@ Then, run the python test script with `torchrun`:
  ```bash
  torchrun --nproc_per_node=4 parallel_cache.py
  ```
- <!--
-
- <div align="center">
- <p align="center">
- DBPrune, <b> L20x4 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
- </p>
- </div>
-
- |Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
- |:---:|:---:|:---:|:---:|:---:|:---:|
- |+L20x1:24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
- |+L20x4:8.54s|7.20s|6.61s|6.09s|5.54s|4.22s|
- |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
- -->

  ## 🔥Torch Compile

  <div id="compile"></div>

- **CacheDiT** are designed to work compatibly with `torch.compile`. You can easily use CacheDiT with torch.compile to further achieve a better performance. For example:
+ By the way, **CacheDiT** is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance. For example:

  ```python
  apply_cache_on_pipe(
@@ -430,22 +436,6 @@ torch._dynamo.config.recompile_limit = 96 # default is 8
  torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
  ```

- <!--
-
- <div align="center">
- <p align="center">
- DBPrune + <b>torch.compile</b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
- </p>
- </div>
-
- |Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
- |:---:|:---:|:---:|:---:|:---:|:---:|
- |+L20x1:24.8s|19.4s|16.8s|15.9s|14.2s|10.6s|
- |+compile:20.4s|16.5s|14.1s|13.4s|12s|8.8s|
- |+L20x4:7.7s|6.6s|6.0s|5.8s|5.2s|3.9s|
- |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
- -->
-
  ## 👋Contribute
  <div id="contribute"></div>

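
Condensing the two torch.compile hunks above into one sketch: the dynamo limit settings are quoted from the README content in the diff, while reusing the HunyuanVideo pipeline from the earlier sketch and compiling only `pipe.transformer` are assumptions for illustration, not the package's documented recipe.

```python
import torch

from cache_dit.cache_factory.first_block_cache.diffusers_adapters.hunyuan_video import (
    apply_cache_on_pipe,
)

# `pipe` is assumed to be an already-constructed HunyuanVideo pipeline
# (see the usage sketch after the first hunk above).
apply_cache_on_pipe(pipe, residual_diff_threshold=0.06)

# Raise the dynamo recompile limits, as the README hunk above does.
torch._dynamo.config.recompile_limit = 96  # default is 8
torch._dynamo.config.accumulated_recompile_limit = 2048  # default is 256

# Assumption for illustration: compile only the transformer module.
pipe.transformer = torch.compile(pipe.transformer)
```
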
@@ -1,31 +1,36 @@
  cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cache_dit/_version.py,sha256=AjUi5zEL_BoWoXMXR1FnWc3mD6FHX7snDXjDHVLoens,511
+ cache_dit/_version.py,sha256=UoNvMtd4wCG76RwoSpNCUtaFyTwakGcZolfjXzNVSMY,511
  cache_dit/logger.py,sha256=dKfNe_RRk9HJwfgHGeRR1f0LbskJpKdGmISCbL9roQs,3443
  cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
  cache_dit/cache_factory/__init__.py,sha256=5RNuhWakvvqrOV4vkqrEBA7d-V1LwcNSsjtW14mkqK8,5255
  cache_dit/cache_factory/taylorseer.py,sha256=0W29ykJg3MnyLAB2KFicsl11Xe41cDYPgI60bquG_NY,2495
  cache_dit/cache_factory/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=EJ-uhA2-sWMW1jNDhcBtjHDqSn8lUzfKbYoPfZDQhZU,49665
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=C6tfXHpdY8YFV3gk74dr_IpYH4bO4ItbPCQYud3NgAM,1667
+ cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=huudVxz-SF3wAY_9vkViFCBhFKm5IzLvXR686u82pbM,50430
+ cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=ySmO_0IuSm5VrYdi9ccGVHoFVkgtPZMJGq_OMoyl0Q8,2003
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=1_n-RFMiL3v2SjhSfFrPH5Mn5Dq9z4BesVK8GN_nh2g,2404
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=UbE6nIF-EtA92QxIZVMzIssdZKQSPAVX1hchF9R8drU,2754
+ cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha256=k3fnlVRTSFkEfWAJ136EMVCeOHPklsArKWFQGMMyLuM,10102
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=qxMu1L3ycT8F-uxpGsmFQBY_BH1vDiGIOXgS_Qbb7dM,2391
+ cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=M_O91sVaHSJgptv6OY5MhT4vD3gQAUTETEmP1kLdE6c,2713
  cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=YRDwZ_16yjThpgVgDv6YaIB4QCE9nEkE-MOru0jOd50,35026
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=8IjJjZOs5XRzsj7Ni2MXpR2Z1PUyRSONIhmfAn1G0eM,1667
+ cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=WNBk4GeekjG2Ln1pAer20TVHpWyGyyHN50DdqWuPhc0,2003
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py,sha256=ORJpdkXkgziDUo-rpebC6pUemgYaDCoeu0cwwLz175U,2407
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py,sha256=KbEkLSsHtS6xwLWNh3jlOlXRyGRdrI2pWV1zyQxMTj4,2757
+ cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py,sha256=v3SgLJQPbEqSVy2sYVGMhptJqCe-XPXL7LIV7GEacZg,10105
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py,sha256=rgeXfww-7WX6URSDg7mF1HuxSmYmoJVjMVoNGuxjwxc,2395
+ cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py,sha256=3-Bg9oPdLcIFZqqSpBGU3Ps1DJ9J8rslP5X7Ow1EHmc,2713
  cache_dit/cache_factory/first_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=DpDhtK095PlrvACf7sbjOt2-QpVkV1arr1qGEKJqgaQ,23502
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py,sha256=h80pgoZ3l8qC4rbm9KY0jSN8hOsmGgyvvFxD-xznHdw,1959
+ cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py,sha256=-FFgA2MoudEo7uDacg4aWgm1KwfLZFsEDTVxatgbq9M,2146
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py,sha256=qO5CWyurtwW30mvOe6cxeQPTSXLDlPJcezm72zEjDq8,2375
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py,sha256=Dcd4OzABCtyQCZNX2KNnUTdVoO1E1ApM7P8gcVYzcK0,2733
+ cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py,sha256=OL7W4ukYlZz0IDmBR1zVV6XT3Mgciglj9Hqzv1wUAkQ,10092
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py,sha256=lQTClo52OwPbNEE4jiBZQhfC7hbtYqnYIABp_vbm_dk,2363
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py,sha256=IVH-lroOzvYb4XKLk9MOw54EtijBtuzVaKcVGz0KlBA,2656
- cache_dit-0.1.8.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
- cache_dit-0.1.8.dist-info/METADATA,sha256=sAYGKro4VfeE_SHrZA8X0BcHfw9y3YY_Qcj9ONkbemE,23952
- cache_dit-0.1.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- cache_dit-0.1.8.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
- cache_dit-0.1.8.dist-info/RECORD,,
+ cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py,sha256=dBNzHBECAuTTA1a7kLdvZL20YzaKTAS3iciVLzKKEWA,2638
+ cache_dit-0.2.1.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+ cache_dit-0.2.1.dist-info/METADATA,sha256=YbS-gmVFpGfTmaKNTUbXfWlfQa-RCoz0TyRqCtHEGJc,22700
+ cache_dit-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ cache_dit-0.2.1.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+ cache_dit-0.2.1.dist-info/RECORD,,