cache-dit 0.1.3__tar.gz → 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.

Files changed (77)
  1. {cache_dit-0.1.3 → cache_dit-0.1.6}/PKG-INFO +96 -21
  2. {cache_dit-0.1.3 → cache_dit-0.1.6}/README.md +92 -17
  3. {cache_dit-0.1.3 → cache_dit-0.1.6}/bench/bench.py +50 -13
  4. cache_dit-0.1.6/examples/run_cogvideox.py +30 -0
  5. cache_dit-0.1.6/examples/run_mochi.py +25 -0
  6. {cache_dit-0.1.3 → cache_dit-0.1.6}/requirements.txt +3 -3
  7. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/_version.py +2 -2
  8. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dynamic_block_prune/prune_context.py +11 -3
  9. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit.egg-info/PKG-INFO +96 -21
  10. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit.egg-info/SOURCES.txt +2 -0
  11. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit.egg-info/requires.txt +3 -3
  12. {cache_dit-0.1.3 → cache_dit-0.1.6}/.github/workflows/issue.yml +0 -0
  13. {cache_dit-0.1.3 → cache_dit-0.1.6}/.gitignore +0 -0
  14. {cache_dit-0.1.3 → cache_dit-0.1.6}/.pre-commit-config.yaml +0 -0
  15. {cache_dit-0.1.3 → cache_dit-0.1.6}/CONTRIBUTE.md +0 -0
  16. {cache_dit-0.1.3 → cache_dit-0.1.6}/LICENSE +0 -0
  17. {cache_dit-0.1.3 → cache_dit-0.1.6}/MANIFEST.in +0 -0
  18. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F12B12S4_R0.2_S16.png +0 -0
  19. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F12B16S4_R0.08_S6.png +0 -0
  20. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F16B16S2_R0.2_S14.png +0 -0
  21. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F16B16S4_R0.2_S13.png +0 -0
  22. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F1B0S1_R0.08_S11.png +0 -0
  23. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F1B0S1_R0.2_S19.png +0 -0
  24. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F8B0S2_R0.12_S12.png +0 -0
  25. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F8B16S1_R0.2_S18.png +0 -0
  26. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F8B8S1_R0.08_S9.png +0 -0
  27. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F8B8S1_R0.12_S12.png +0 -0
  28. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F8B8S1_R0.15_S15.png +0 -0
  29. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCache.png +0 -0
  30. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png +0 -0
  31. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png +0 -0
  32. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png +0 -0
  33. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png +0 -0
  34. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.07_P52.3_T12.53s.png +0 -0
  35. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.08_P52.4_T12.52s.png +0 -0
  36. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.09_P59.2_T10.81s.png +0 -0
  37. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.12_P59.5_T10.76s.png +0 -0
  38. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.12_P63.0_T9.90s.png +0 -0
  39. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.1_P62.8_T9.95s.png +0 -0
  40. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png +0 -0
  41. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.3_P63.1_T9.79s.png +0 -0
  42. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/NONE_R0.08_S0.png +0 -0
  43. {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/cache-dit.png +0 -0
  44. {cache_dit-0.1.3 → cache_dit-0.1.6}/bench/.gitignore +0 -0
  45. {cache_dit-0.1.3 → cache_dit-0.1.6}/docs/.gitignore +0 -0
  46. {cache_dit-0.1.3 → cache_dit-0.1.6}/examples/.gitignore +0 -0
  47. {cache_dit-0.1.3 → cache_dit-0.1.6}/examples/run_flux.py +0 -0
  48. {cache_dit-0.1.3 → cache_dit-0.1.6}/pyproject.toml +0 -0
  49. {cache_dit-0.1.3 → cache_dit-0.1.6}/pytest.ini +0 -0
  50. {cache_dit-0.1.3 → cache_dit-0.1.6}/setup.cfg +0 -0
  51. {cache_dit-0.1.3 → cache_dit-0.1.6}/setup.py +0 -0
  52. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/__init__.py +0 -0
  53. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/__init__.py +0 -0
  54. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
  55. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dual_block_cache/cache_context.py +0 -0
  56. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +0 -0
  57. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -0
  58. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -0
  59. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -0
  60. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
  61. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +0 -0
  62. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -0
  63. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -0
  64. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -0
  65. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
  66. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/first_block_cache/cache_context.py +0 -0
  67. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -0
  68. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +0 -0
  69. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -0
  70. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +0 -0
  71. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -0
  72. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/taylorseer.py +0 -0
  73. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/utils.py +0 -0
  74. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/logger.py +0 -0
  75. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/primitives.py +0 -0
  76. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit.egg-info/dependency_links.txt +0 -0
  77. {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit.egg-info/top_level.txt +0 -0
--- cache_dit-0.1.3/PKG-INFO
+++ cache_dit-0.1.6/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.1.3
+Version: 0.1.6
 Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -10,9 +10,9 @@ Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: packaging
-Requires-Dist: torch
-Requires-Dist: transformers
-Requires-Dist: diffusers
+Requires-Dist: torch>=2.5.1
+Requires-Dist: transformers>=4.51.3
+Requires-Dist: diffusers>=0.33.1
 Provides-Extra: all
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
@@ -44,10 +44,10 @@ Dynamic: requires-python
 <img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
 <img src=https://static.pepy.tech/badge/cache-dit >
 <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
-<img src=https://img.shields.io/badge/Release-v0.1.3-brightgreen.svg >
+<img src=https://img.shields.io/badge/Release-v0.1.6-brightgreen.svg >
 </div>
 <p align="center">
-DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT provides <br>a series of training-free, UNet-style cache accelerators for DiT: DBCache, DBPrune, FBCache, etc.
+DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>offers a set of training-free cache accelerators for DiT: 🔥DBCache, DBPrune, FBCache, etc🔥
 </p>
 </div>

@@ -55,7 +55,7 @@ Dynamic: requires-python

 <div align="center">
 <p align="center">
-<h3>DBCache: Dual Block Caching for Diffusion Transformers</h3>
+<h3>🔥 DBCache: Dual Block Caching for Diffusion Transformers</h3>
 </p>
 </div>

@@ -77,7 +77,7 @@ Dynamic: requires-python

 <div align="center">
 <p align="center">
-DBCache, <b> L20x4 </b>, Steps: 20, case to show the texture recovery ability of DBCache
+DBCache, <b> L20x4 </b>, Steps: 20, case to show the texture recovery ability of DBCache
 </p>
 </div>

@@ -85,7 +85,7 @@ These case studies demonstrate that even with relatively high thresholds (such a

 <div align="center">
 <p align="center">
-<h3>DBPrune: Dynamic Block Prune with Residual Caching</h3>
+<h3>🔥 DBPrune: Dynamic Block Prune with Residual Caching</h3>
 </p>
 </div>

@@ -102,10 +102,10 @@ These case studies demonstrate that even with relatively high thresholds (such a
 </p>
 </div>

-Moreover, both DBCache and DBPrune are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference.
+Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference.

 <p align="center">
-♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️
+♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️
 </p>

 ## ©️Citations
@@ -135,7 +135,7 @@ The **CacheDiT** codebase was adapted from FBCache's implementation at the [Para
 - [🎉First Block Cache](#fbcache)
 - [⚡️Dynamic Block Prune](#dbprune)
 - [🎉Context Parallelism](#context-parallelism)
-- [⚡️Torch Compile](#compile)
+- [🔥Torch Compile](#compile)
 - [🎉Supported Models](#supported)
 - [👋Contribute](#contribute)
 - [©️License](#license)
@@ -167,7 +167,7 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
 - **Fn**: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
 - **Bn**: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
 - **warmup_steps**: (default: 0) DBCache does not apply the caching strategy when the number of running steps is less than or equal to this value, ensuring the model sufficiently learns basic features during warmup.
-- **max_cached_steps**: (default: -1) DBCache disables the caching strategy when the running steps exceed this value to prevent precision degradation.
+- **max_cached_steps**: (default: -1) DBCache disables the caching strategy when the previous cached steps exceed this value to prevent precision degradation.
 - **residual_diff_threshold**: The value of residual diff threshold, a higher value leads to faster performance at the cost of lower precision.

 For a good balance between performance and precision, DBCache is configured by default with **F8B8**, 8 warmup steps, and unlimited cached steps.
@@ -210,6 +210,17 @@ cache_options = {
 }
 ```

+<div align="center">
+<p align="center">
+DBCache, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+</p>
+</div>
+
+|Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
+
 ## 🎉FBCache: First Block Cache

 <div id="fbcache"></div>
@@ -261,12 +272,47 @@ pipe = FluxPipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
 ).to("cuda")

-# Using DBPrune
+# Using DBPrune with default options
 cache_options = CacheType.default_options(CacheType.DBPrune)

 apply_cache_on_pipe(pipe, **cache_options)
 ```

+We have also brought the designs from DBCache to DBPrune to make it a more general and customizable block prune algorithm. You can specify the values of **Fn** and **Bn** for higher precision, or set up the non-prune blocks list **non_prune_blocks_ids** to avoid aggressive pruning. For example:
+
+```python
+# Custom options for DBPrune
+cache_options = {
+    "cache_type": CacheType.DBPrune,
+    "residual_diff_threshold": 0.05,
+    # Never prune the first `Fn` and last `Bn` blocks.
+    "Fn_compute_blocks": 8,  # default 1
+    "Bn_compute_blocks": 8,  # default 0
+    "warmup_steps": 8,  # default -1
+    # Disables the pruning strategy when the previous
+    # pruned steps greater than this value.
+    "max_pruned_steps": 12,  # default, -1 means no limit
+    # Enable dynamic prune threshold within step, higher
+    # `max_dynamic_prune_threshold` value may introduce a more
+    # ageressive pruning strategy.
+    "enable_dynamic_prune_threshold": True,
+    "max_dynamic_prune_threshold": 2 * 0.05,
+    # (New thresh) = mean(previous_block_diffs_within_step) * 1.25
+    # (New thresh) = ((New thresh) if (New thresh) <
+    # max_dynamic_prune_threshold else residual_diff_threshold)
+    "dynamic_prune_threshold_relax_ratio": 1.25,
+    # The step interval to update residual cache. For example,
+    # 2: means the update steps will be [0, 2, 4, ...].
+    "residual_cache_update_interval": 1,
+    # You can set non-prune blocks to avoid ageressive pruning.
+    # For example, FLUX.1 has 19 + 38 blocks, so we can set it
+    # to 0, 2, 4, ..., 56, etc.
+    "non_prune_blocks_ids": [],
+}
+
+apply_cache_on_pipe(pipe, **cache_options)
+```
+
 <div align="center">
 <p align="center">
 DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
@@ -288,14 +334,19 @@ apply_cache_on_pipe(pipe, **cache_options)
 pip3 install para-attn # or install `para-attn` from sources.
 ```

-Then, you can run **DBCache** with **Context Parallelism** on 4 GPUs:
+Then, you can run **DBCache** or **DBPrune** with **Context Parallelism** on 4 GPUs:

 ```python
+import torch.distributed as dist
 from diffusers import FluxPipeline
 from para_attn.context_parallel import init_context_parallel_mesh
 from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
 from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

+# Init distributed process group
+dist.init_process_group()
+torch.cuda.set_device(dist.get_rank())
+
 pipe = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     torch_dtype=torch.bfloat16,
@@ -308,13 +359,31 @@ parallelize_pipe(
     )
 )

-# DBCache with F8B8 from this library
+# DBPrune with default options from this library
 apply_cache_on_pipe(
-    pipe, **CacheType.default_options(CacheType.DBCache)
+    pipe, **CacheType.default_options(CacheType.DBPrune)
 )
+
+dist.destroy_process_group()
+```
+Then, run the python test script with `torchrun`:
+```bash
+torchrun --nproc_per_node=4 parallel_cache.py
 ```

-## ⚡️Torch Compile
+<div align="center">
+<p align="center">
+DBPrune, <b> L20x4 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+</p>
+</div>
+
+|Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
+|8.54s (L20x4)|7.20s (L20x4)|6.61s (L20x4)|6.09s (L20x4)|5.54s (L20x4)|4.22s (L20x4)|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
+
+## 🔥Torch Compile

 <div id="compile"></div>

@@ -322,7 +391,7 @@ apply_cache_on_pipe(
 ```python
 apply_cache_on_pipe(
-    pipe, **CacheType.default_options(CacheType.DBCache)
+    pipe, **CacheType.default_options(CacheType.DBPrune)
 )
 # Compile the Transformer module
 pipe.transformer = torch.compile(pipe.transformer)
@@ -333,7 +402,13 @@ However, users intending to use **CacheDiT** for DiT with **dynamic input shapes
 torch._dynamo.config.recompile_limit = 96 # default is 8
 torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
 ```
-Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
+Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode. Here is the case of **DBPrune + torch.compile**.
+
+<div align="center">
+<p align="center">
+DBPrune + compile, Steps: 28, "A cat holding a sign that says hello world with complex background"
+</p>
+</div>

 ## 🎉Supported Models

@@ -346,7 +421,7 @@ Otherwise, the recompile_limit error may be triggered, causing the module to fal
 ## 👋Contribute
 <div id="contribute"></div>

-How to contribute? Star this repo or check [CONTRIBUTE.md](./CONTRIBUTE.md).
+How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](./CONTRIBUTE.md).

 ## ©️License
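For quick reference, the DBCache knobs described above can be wired up end to end. A minimal sketch: the `cache_type`, `warmup_steps`, `max_cached_steps`, and `residual_diff_threshold` keys come straight from the option bullets in the diff, while the `Fn_compute_blocks`/`Bn_compute_blocks` names are assumed by analogy with the DBPrune dict shown above, so treat the exact schema as illustrative rather than confirmed.

```python
import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# DBCache with the documented F8B8 default shape, spelled out explicitly.
cache_options = {
    "cache_type": CacheType.DBCache,
    "warmup_steps": 8,        # no caching during the first 8 steps
    "max_cached_steps": -1,   # -1 means unlimited cached steps
    # Assumed key names, mirroring the DBPrune options above:
    "Fn_compute_blocks": 8,   # always compute the first 8 blocks
    "Bn_compute_blocks": 8,   # always compute the last 8 blocks
    "residual_diff_threshold": 0.12,
}
apply_cache_on_pipe(pipe, **cache_options)

image = pipe(
    "A cat holding a sign that says hello world with complex background",
    num_inference_steps=28,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("dbcache_f8b8.png")
```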
--- cache_dit-0.1.3/README.md
+++ cache_dit-0.1.6/README.md
(README.md picks up the same content changes as the PKG-INFO diff above, minus the packaging-metadata hunks; only the offsets differ. Hunks: @@ -9,10 +9,10 @@, @@ -20,7 +20,7 @@, @@ -42,7 +42,7 @@, @@ -50,7 +50,7 @@, @@ -67,10 +67,10 @@, @@ -100,7 +100,7 @@, @@ -132,7 +132,7 @@, @@ -175,6 +175,17 @@, @@ -226,12 +237,47 @@, @@ -253,14 +299,19 @@, @@ -273,13 +324,31 @@, @@ -287,7 +356,7 @@, @@ -298,7 +367,13 @@, @@ -311,7 +386,7 @@.)
--- cache_dit-0.1.3/bench/bench.py
+++ cache_dit-0.1.6/bench/bench.py
@@ -16,6 +16,7 @@ def get_args() -> argparse.ArgumentParser:
     # General arguments
     parser.add_argument("--steps", type=int, default=28)
     parser.add_argument("--repeats", type=int, default=2)
+    parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--cache", type=str, default=None)
     parser.add_argument("--alter", action="store_true", default=False)
     parser.add_argument("--l1-diff", action="store_true", default=False)
@@ -26,12 +27,9 @@ def get_args() -> argparse.ArgumentParser:
     parser.add_argument("--warmup-steps", type=int, default=0)
     parser.add_argument("--max-cached-steps", type=int, default=-1)
     parser.add_argument("--max-pruned-steps", type=int, default=-1)
-    parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument(
-        "--compile",
-        action="store_true",
-        default=False,
-    )
+    parser.add_argument("--ulysses", type=int, default=None)
+    parser.add_argument("--compile", action="store_true", default=False)
+    parser.add_argument("--gen-device", type=str, default="cuda")
     return parser.parse_args()


@@ -116,10 +114,41 @@ def main():
     args = get_args()
     logger.info(f"Arguments: {args}")

-    pipe = FluxPipeline.from_pretrained(
-        os.environ.get("FLUX_DIR", "black-forest-labs/FLUX.1-dev"),
-        torch_dtype=torch.bfloat16,
-    ).to("cuda")
+    # Context Parallel from ParaAttention
+    if args.ulysses is not None:
+        try:
+            import torch.distributed as dist
+            from para_attn.context_parallel import init_context_parallel_mesh
+            from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
+
+            # Initialize distributed process group
+            dist.init_process_group()
+            torch.cuda.set_device(dist.get_rank())
+
+            logger.info(f"Ulysses: {args.ulysses}")
+
+            pipe = FluxPipeline.from_pretrained(
+                os.environ.get("FLUX_DIR", "black-forest-labs/FLUX.1-dev"),
+                torch_dtype=torch.bfloat16,
+            ).to("cuda")
+
+            parallelize_pipe(
+                pipe, mesh=init_context_parallel_mesh(
+                    pipe.device.type, max_ulysses_dim_size=args.ulysses
+                )
+            )
+        except ImportError as e:
+            logger.error(
+                "para-attn is not installed, please install it "
+                "with `pip install para-attn.`"
+            )
+            args.ulysses = None
+            raise e
+    else:
+        pipe = FluxPipeline.from_pretrained(
+            os.environ.get("FLUX_DIR", "black-forest-labs/FLUX.1-dev"),
+            torch_dtype=torch.bfloat16,
+        ).to("cuda")

     cache_options, cache_type = get_cache_options(args.cache, args)

@@ -149,7 +178,7 @@ def main():
         image = pipe(
             "A cat holding a sign that says hello world with complex background",
             num_inference_steps=args.steps,
-            generator=torch.Generator("cuda").manual_seed(args.seed),
+            generator=torch.Generator(args.gen_device).manual_seed(args.seed),
         ).images[0]
         end = time.time()
         all_times.append(end - start)
@@ -191,19 +220,27 @@ def main():
         f"Actual Blocks: {actual_blocks}\n"
         f"Pruned Blocks: {pruned_blocks}"
     )
+    ulysses = 0 if args.ulysses is None else args.ulysses
     if len(actual_blocks) > 0:
         save_name = (
-            f"{cache_type}_R{args.rdt}_P{pruned_ratio:.1f}_"
+            f"U{ulysses}_C{int(args.compile)}_{cache_type}_"
+            f"R{args.rdt}_P{pruned_ratio:.1f}_"
             f"T{mean_time:.2f}s.png"
         )
     else:
         save_name = (
-            f"{cache_type}_R{args.rdt}_S{cached_stepes}_"
+            f"U{ulysses}_C{int(args.compile)}_{cache_type}_"
+            f"R{args.rdt}_S{cached_stepes}_"
             f"T{mean_time:.2f}s.png"
        )
     image.save(save_name)
     logger.info(f"Image saved as {save_name}")

+    if args.ulysses is not None:
+        import torch.distributed as dist
+        dist.destroy_process_group()
+        logger.info("Distributed process group destroyed.")
+

 if __name__ == "__main__":
     main()
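Given the reworked flags, a plausible way to drive the benchmark looks like this (the `--cache` value and the `--rdt` flag are assumptions inferred from `get_cache_options(args.cache, args)` and the `args.rdt` reference in the save-name logic; they are not spelled out in these hunks):

```bash
# Single-GPU benchmark run (flag values illustrative)
python3 bench.py --steps 28 --repeats 2 --seed 0 \
    --cache DBPRUNE --rdt 0.05 --gen-device cuda

# With --ulysses set, the script takes the para-attn branch above
# and calls dist.init_process_group(), so launch it under torchrun:
torchrun --nproc_per_node=4 bench.py --ulysses 4 --cache DBPRUNE --compile
```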
--- /dev/null
+++ cache_dit-0.1.6/examples/run_cogvideox.py
@@ -0,0 +1,30 @@
+import torch
+from diffusers import CogVideoXPipeline
+from diffusers.utils import export_to_video
+from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
+
+pipe = CogVideoXPipeline.from_pretrained(
+    "THUDM/CogVideoX-5b",
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# Default options, F8B8, good balance between performance and precision
+cache_options = CacheType.default_options(CacheType.DBCache)
+
+apply_cache_on_pipe(pipe, **cache_options)
+
+pipe.vae.enable_slicing()
+pipe.vae.enable_tiling()
+
+prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
+video = pipe(
+    prompt=prompt,
+    num_videos_per_prompt=1,
+    num_inference_steps=50,
+    num_frames=49,
+    guidance_scale=6,
+    generator=torch.Generator("cuda").manual_seed(0),
+).frames[0]
+
+print("Saving video to cogvideox.mp4")
+export_to_video(video, "cogvideox.mp4", fps=8)
--- /dev/null
+++ cache_dit-0.1.6/examples/run_mochi.py
@@ -0,0 +1,25 @@
+import torch
+from diffusers import MochiPipeline
+from diffusers.utils import export_to_video
+from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
+
+pipe = MochiPipeline.from_pretrained(
+    "genmo/mochi-1-preview",
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# Default options, F8B8, good balance between performance and precision
+cache_options = CacheType.default_options(CacheType.DBCache)
+
+apply_cache_on_pipe(pipe, **cache_options)
+
+pipe.enable_vae_tiling()
+
+prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
+video = pipe(
+    prompt,
+    num_frames=84,
+).frames[0]
+
+print("Saving video to mochi.mp4")
+export_to_video(video, "mochi.mp4", fps=30)
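Both new example scripts are self-contained, so assuming the referenced checkpoints can be downloaded (or are already cached locally), they should run directly:

```bash
python3 examples/run_cogvideox.py   # writes cogvideox.mp4 (49 frames, 8 fps)
python3 examples/run_mochi.py       # writes mochi.mp4 (84 frames, 30 fps)
```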
--- cache_dit-0.1.3/requirements.txt
+++ cache_dit-0.1.6/requirements.txt
@@ -1,6 +1,6 @@
 # Example requirement, can be anything that pip knows
 # install with `pip install -r requirements.txt`, and make sure that CI does the same
 packaging
-torch
-transformers
-diffusers
+torch>=2.5.1 # torch 2.7.0 is preferred, but 2.5.1 is the minimum required version
+transformers>=4.51.3
+diffusers>=0.33.1
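The previously unpinned dependencies now carry version floors. To install the same minimums outside of requirements.txt:

```bash
pip3 install "torch>=2.5.1" "transformers>=4.51.3" "diffusers>=0.33.1"
```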
--- cache_dit-0.1.3/src/cache_dit/_version.py
+++ cache_dit-0.1.6/src/cache_dit/_version.py
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE

-__version__ = version = '0.1.3'
-__version_tuple__ = version_tuple = (0, 1, 3)
+__version__ = version = '0.1.6'
+__version_tuple__ = version_tuple = (0, 1, 6)
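A quick sanity check that the bump took effect, reading only fields this module defines:

```bash
python3 -c "from cache_dit._version import __version__, version_tuple; print(__version__, version_tuple)"
# expected: 0.1.6 (0, 1, 6)
```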
--- cache_dit-0.1.3/src/cache_dit/cache_factory/dynamic_block_prune/prune_context.py
+++ cache_dit-0.1.6/src/cache_dit/cache_factory/dynamic_block_prune/prune_context.py
@@ -562,7 +562,7 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
         torch._dynamo.graph_break()

         add_pruned_block(self.pruned_blocks_step)
-        add_actual_block(self._num_transformer_blocks)
+        add_actual_block(self.num_transformer_blocks)
         patch_pruned_stats(self.transformer)

         return (
@@ -577,7 +577,7 @@ class DBPrunedTransformerBlocks(torch.nn.Module):

     @property
     @torch.compiler.disable
-    def _num_transformer_blocks(self):
+    def num_transformer_blocks(self):
         # Total number of transformer blocks, including single transformer blocks.
         num_blocks = len(self.transformer_blocks)
         if self.single_transformer_blocks is not None:
@@ -597,7 +597,7 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
     @torch.compiler.disable
     def _non_prune_blocks_ids(self):
         # Never prune the first `Fn` and last `Bn` blocks.
-        num_blocks = self._num_transformer_blocks
+        num_blocks = self.num_transformer_blocks
         Fn_compute_blocks_ = (
             Fn_compute_blocks()
             if Fn_compute_blocks() < num_blocks
@@ -627,6 +627,10 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
         ]
         return sorted(non_prune_blocks_ids)

+    # @torch.compile(dynamic=True)
+    # mark this function as compile with dynamic=True will
+    # cause precision degradate, so, we choose to disable it
+    # now, until we find a better solution or fixed the bug.
     @torch.compiler.disable
     def _compute_single_hidden_states_residual(
         self,
@@ -663,6 +667,10 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
             single_encoder_hidden_states_residual,
         )

+    # @torch.compile(dynamic=True)
+    # mark this function as compile with dynamic=True will
+    # cause precision degradate, so, we choose to disable it
+    # now, until we find a better solution or fixed the bug.
     @torch.compiler.disable
     def _split_single_hidden_states(
         self,
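The rename from `_num_transformer_blocks` to `num_transformer_blocks` promotes the block count to the class's public surface, while `@torch.compiler.disable` keeps this Python-level bookkeeping out of compiled graphs. A stripped-down sketch of just that pattern (the real `DBPrunedTransformerBlocks` carries far more state; this mirrors only the property shown in the hunk):

```python
import torch


class BlocksWrapper(torch.nn.Module):
    """Minimal sketch: count blocks across both block lists."""

    def __init__(self, transformer_blocks, single_transformer_blocks=None):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList(transformer_blocks)
        self.single_transformer_blocks = (
            torch.nn.ModuleList(single_transformer_blocks)
            if single_transformer_blocks is not None
            else None
        )

    @property
    @torch.compiler.disable  # keep the count out of torch.compile graphs
    def num_transformer_blocks(self):
        # Total number of transformer blocks, including single transformer blocks.
        num_blocks = len(self.transformer_blocks)
        if self.single_transformer_blocks is not None:
            num_blocks += len(self.single_transformer_blocks)
        return num_blocks


wrapper = BlocksWrapper([torch.nn.Identity()] * 19, [torch.nn.Identity()] * 38)
assert wrapper.num_transformer_blocks == 57  # FLUX.1's 19 + 38 block layout
```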
--- cache_dit-0.1.3/src/cache_dit.egg-info/PKG-INFO
+++ cache_dit-0.1.6/src/cache_dit.egg-info/PKG-INFO
(The egg-info copy changes identically to the top-level PKG-INFO diff above: same hunks, same offsets, same content.)
--- cache_dit-0.1.3/src/cache_dit.egg-info/SOURCES.txt
+++ cache_dit-0.1.6/src/cache_dit.egg-info/SOURCES.txt
@@ -40,7 +40,9 @@ bench/.gitignore
 bench/bench.py
 docs/.gitignore
 examples/.gitignore
+examples/run_cogvideox.py
 examples/run_flux.py
+examples/run_mochi.py
 src/cache_dit/__init__.py
 src/cache_dit/_version.py
 src/cache_dit/logger.py
--- cache_dit-0.1.3/src/cache_dit.egg-info/requires.txt
+++ cache_dit-0.1.6/src/cache_dit.egg-info/requires.txt
@@ -1,7 +1,7 @@
 packaging
-torch
-transformers
-diffusers
+torch>=2.5.1
+transformers>=4.51.3
+diffusers>=0.33.1

 [all]