cache-dit 0.1.7.tar.gz → 0.1.8.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic.

Files changed (103)
  1. {cache_dit-0.1.7 → cache_dit-0.1.8}/PKG-INFO +55 -21
  2. {cache_dit-0.1.7 → cache_dit-0.1.8}/README.md +54 -20
  3. cache_dit-0.1.8/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png +0 -0
  4. cache_dit-0.1.8/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png +0 -0
  5. cache_dit-0.1.8/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png +0 -0
  6. cache_dit-0.1.8/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png +0 -0
  7. cache_dit-0.1.8/assets/U0_C1_DBPRUNE_F1B0_R0.05_P41.6_T12.70s.png +0 -0
  8. cache_dit-0.1.8/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png +0 -0
  9. cache_dit-0.1.8/assets/U0_C1_DBPRUNE_F8B8_R0.08_P23.1_T16.14s.png +0 -0
  10. cache_dit-0.1.8/assets/U0_C1_NONE_R0.08_S0_T20.43s.png +0 -0
  11. cache_dit-0.1.8/assets/U4_C1_DBPRUNE_F1B0_R0.03_P27.3_T6.62s.png +0 -0
  12. cache_dit-0.1.8/assets/U4_C1_DBPRUNE_F1B0_R0.03_P27.3_T6.63s.png +0 -0
  13. cache_dit-0.1.8/assets/U4_C1_DBPRUNE_F1B0_R0.045_P38.2_T5.81s.png +0 -0
  14. cache_dit-0.1.8/assets/U4_C1_DBPRUNE_F1B0_R0.045_P38.2_T5.82s.png +0 -0
  15. cache_dit-0.1.8/assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.06s.png +0 -0
  16. cache_dit-0.1.8/assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.07s.png +0 -0
  17. cache_dit-0.1.8/assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.08s.png +0 -0
  18. cache_dit-0.1.8/assets/U4_C1_DBPRUNE_F1B0_R0.055_P45.1_T5.27s.png +0 -0
  19. cache_dit-0.1.8/assets/U4_C1_DBPRUNE_F1B0_R0.055_P45.1_T5.28s.png +0 -0
  20. cache_dit-0.1.8/assets/U4_C1_DBPRUNE_F1B0_R0.2_P59.5_T3.95s.png +0 -0
  21. cache_dit-0.1.8/assets/U4_C1_DBPRUNE_F1B0_R0.2_P59.5_T3.96s.png +0 -0
  22. cache_dit-0.1.8/assets/U4_C1_NONE_R0.08_S0_T7.78s.png +0 -0
  23. cache_dit-0.1.8/assets/U4_C1_NONE_R0.08_S0_T7.79s.png +0 -0
  24. {cache_dit-0.1.7 → cache_dit-0.1.8}/bench/bench.py +22 -6
  25. {cache_dit-0.1.7 → cache_dit-0.1.8}/examples/.gitignore +0 -1
  26. cache_dit-0.1.8/examples/data/cup.png +0 -0
  27. cache_dit-0.1.8/examples/data/cup_mask.png +0 -0
  28. cache_dit-0.1.8/examples/run_cogvideox.py +46 -0
  29. {cache_dit-0.1.7 → cache_dit-0.1.8}/examples/run_flux.py +5 -1
  30. cache_dit-0.1.8/examples/run_flux_fill.py +32 -0
  31. {cache_dit-0.1.7 → cache_dit-0.1.8}/examples/run_mochi.py +9 -2
  32. cache_dit-0.1.8/examples/run_wan.py +49 -0
  33. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/_version.py +2 -2
  34. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/dynamic_block_prune/prune_context.py +2 -2
  35. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit.egg-info/PKG-INFO +55 -21
  36. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit.egg-info/SOURCES.txt +25 -0
  37. cache_dit-0.1.7/examples/run_cogvideox.py +0 -30
  38. {cache_dit-0.1.7 → cache_dit-0.1.8}/.github/workflows/issue.yml +0 -0
  39. {cache_dit-0.1.7 → cache_dit-0.1.8}/.gitignore +0 -0
  40. {cache_dit-0.1.7 → cache_dit-0.1.8}/.pre-commit-config.yaml +0 -0
  41. {cache_dit-0.1.7 → cache_dit-0.1.8}/CONTRIBUTE.md +0 -0
  42. {cache_dit-0.1.7 → cache_dit-0.1.8}/LICENSE +0 -0
  43. {cache_dit-0.1.7 → cache_dit-0.1.8}/MANIFEST.in +0 -0
  44. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBCACHE_F12B12S4_R0.2_S16.png +0 -0
  45. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBCACHE_F12B16S4_R0.08_S6.png +0 -0
  46. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBCACHE_F16B16S2_R0.2_S14.png +0 -0
  47. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBCACHE_F16B16S4_R0.2_S13.png +0 -0
  48. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBCACHE_F1B0S1_R0.08_S11.png +0 -0
  49. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBCACHE_F1B0S1_R0.2_S19.png +0 -0
  50. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBCACHE_F8B0S2_R0.12_S12.png +0 -0
  51. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBCACHE_F8B16S1_R0.2_S18.png +0 -0
  52. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBCACHE_F8B8S1_R0.08_S9.png +0 -0
  53. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBCACHE_F8B8S1_R0.12_S12.png +0 -0
  54. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBCACHE_F8B8S1_R0.15_S15.png +0 -0
  55. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBCache.png +0 -0
  56. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png +0 -0
  57. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png +0 -0
  58. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png +0 -0
  59. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png +0 -0
  60. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBPRUNE_F1B0_R0.07_P52.3_T12.53s.png +0 -0
  61. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBPRUNE_F1B0_R0.08_P52.4_T12.52s.png +0 -0
  62. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBPRUNE_F1B0_R0.09_P59.2_T10.81s.png +0 -0
  63. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBPRUNE_F1B0_R0.12_P59.5_T10.76s.png +0 -0
  64. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBPRUNE_F1B0_R0.12_P63.0_T9.90s.png +0 -0
  65. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBPRUNE_F1B0_R0.1_P62.8_T9.95s.png +0 -0
  66. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png +0 -0
  67. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/DBPRUNE_F1B0_R0.3_P63.1_T9.79s.png +0 -0
  68. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/NONE_R0.08_S0.png +0 -0
  69. {cache_dit-0.1.7 → cache_dit-0.1.8}/assets/cache-dit.png +0 -0
  70. {cache_dit-0.1.7 → cache_dit-0.1.8}/bench/.gitignore +0 -0
  71. {cache_dit-0.1.7 → cache_dit-0.1.8}/docs/.gitignore +0 -0
  72. {cache_dit-0.1.7 → cache_dit-0.1.8}/pyproject.toml +0 -0
  73. {cache_dit-0.1.7 → cache_dit-0.1.8}/pytest.ini +0 -0
  74. {cache_dit-0.1.7 → cache_dit-0.1.8}/requirements.txt +0 -0
  75. {cache_dit-0.1.7 → cache_dit-0.1.8}/setup.cfg +0 -0
  76. {cache_dit-0.1.7 → cache_dit-0.1.8}/setup.py +0 -0
  77. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/__init__.py +0 -0
  78. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/__init__.py +0 -0
  79. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
  80. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/dual_block_cache/cache_context.py +0 -0
  81. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +0 -0
  82. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -0
  83. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -0
  84. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -0
  85. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
  86. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +0 -0
  87. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -0
  88. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -0
  89. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -0
  90. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
  91. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/first_block_cache/cache_context.py +0 -0
  92. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -0
  93. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +0 -0
  94. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -0
  95. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +0 -0
  96. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -0
  97. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/taylorseer.py +0 -0
  98. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/utils.py +0 -0
  99. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/logger.py +0 -0
  100. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/primitives.py +0 -0
  101. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit.egg-info/dependency_links.txt +0 -0
  102. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit.egg-info/requires.txt +0 -0
  103. {cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit.egg-info/top_level.txt +0 -0

{cache_dit-0.1.7 → cache_dit-0.1.8}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: cache_dit
- Version: 0.1.7
+ Version: 0.1.8
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
  Author: DefTruth, vipshop.com, etc.
  Maintainer: DefTruth, vipshop.com, etc
@@ -35,7 +35,7 @@ Dynamic: requires-python

  <div align="center">
  <p align="center">
- <h3>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h3>
+ <h2>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h2>
  </p>
  <img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit.png >
  <div align='center'>
@@ -44,13 +44,32 @@ Dynamic: requires-python
  <img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
  <img src=https://static.pepy.tech/badge/cache-dit >
  <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
- <img src=https://img.shields.io/badge/Release-v0.1.7-brightgreen.svg >
+ <img src=https://img.shields.io/badge/Release-v0.1.8-brightgreen.svg >
  </div>
  <p align="center">
  DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>offers a set of training-free cache accelerators for DiT: 🔥DBCache, DBPrune, FBCache, etc🔥
  </p>
+ <p align="center">
+ <h3> 🔥Supported Models🔥</h2>
+ <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀FLUX.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+ <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+ <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Mochi</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+ <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Wan2.1</b>: 🔜DBCache, 🔜DBPrune, ✔️FBCache🔥</a> <br> <br>
+ <b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
+ </p>
  </div>

+
+ <!--
+ ## 🎉Supported Models
+ <div id="supported"></div>
+ - [🚀FLUX.1](https://github.com/vipshop/cache-dit/raw/main/examples): *✔️DBCache, ✔️DBPrune, ✔️FBCache*
+ - [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples): *✔️DBCache, ✔️DBPrune, ✔️FBCache*
+ - [🚀Mochi](https://github.com/vipshop/cache-dit/raw/main/examples): *✔️DBCache, ✔️DBPrune, ✔️FBCache*
+ - [🚀Wan2.1**](https://github.com/vipshop/cache-dit/raw/main/examples): *🔜DBCache, 🔜DBPrune, ✔️FBCache*
+ -->
+
+
  ## 🤗 Introduction

  <div align="center">
@@ -102,11 +121,20 @@ These case studies demonstrate that even with relatively high thresholds (such a
  </p>
  </div>

- Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference.
+ **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. Moreover, **CacheDiT** are designed to work compatibly with `torch.compile`. You can easily use CacheDiT with torch.compile to further achieve a better performance.
+
+ <div align="center">
+ <p align="center">
+ DBPrune + <b>torch.compile + context parallelism</b> <br>Steps: 28, "A cat holding a sign that says hello world with complex background"
+ </p>
+ </div>

- <p align="center">
- ♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️
- </p>
+ |Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
+ |:---:|:---:|:---:|:---:|:---:|:---:|
+ |+L20x1:24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
+ |+compile:20.43s|16.25s|14.12s|13.41s|12s|8.86s|
+ |+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
+ |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|

  ## ©️Citations

@@ -136,11 +164,9 @@ The **CacheDiT** codebase was adapted from FBCache's implementation at the [Para
  - [⚡️Dynamic Block Prune](#dbprune)
  - [🎉Context Parallelism](#context-parallelism)
  - [🔥Torch Compile](#compile)
- - [🎉Supported Models](#supported)
  - [👋Contribute](#contribute)
  - [©️License](#license)

-
  ## ⚙️Installation

  <div id="installation"></div>
@@ -370,6 +396,7 @@ Then, run the python test script with `torchrun`:
  ```bash
  torchrun --nproc_per_node=4 parallel_cache.py
  ```
+ <!--

  <div align="center">
  <p align="center">
@@ -377,17 +404,18 @@ torchrun --nproc_per_node=4 parallel_cache.py
  </p>
  </div>

- |Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
+ |Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
  |:---:|:---:|:---:|:---:|:---:|:---:|
- |24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
- |8.54s (L20x4)|7.20s (L20x4)|6.61s (L20x4)|6.09s (L20x4)|5.54s (L20x4)|4.22s (L20x4)|
+ |+L20x1:24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
+ |+L20x4:8.54s|7.20s|6.61s|6.09s|5.54s|4.22s|
  |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
+ -->

  ## 🔥Torch Compile

  <div id="compile"></div>

- **CacheDiT** are designed to work compatibly with `torch.compile`. For example:
+ **CacheDiT** are designed to work compatibly with `torch.compile`. You can easily use CacheDiT with torch.compile to further achieve a better performance. For example:

  ```python
  apply_cache_on_pipe(
@@ -396,21 +424,27 @@ apply_cache_on_pipe(
  # Compile the Transformer module
  pipe.transformer = torch.compile(pipe.transformer)
  ```
- However, users intending to use **CacheDiT** for DiT with **dynamic input shapes** should consider increasing the **recompile** **limit** of `torch._dynamo` to achieve better performance.
-
+ However, users intending to use **CacheDiT** for DiT with **dynamic input shapes** should consider increasing the **recompile** **limit** of `torch._dynamo`. Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
  ```python
  torch._dynamo.config.recompile_limit = 96 # default is 8
  torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
  ```
- Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.

- ## 🎉Supported Models
+ <!--

- <div id="supported"></div>
+ <div align="center">
+ <p align="center">
+ DBPrune + <b>torch.compile</b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+ </p>
+ </div>

- - [🚀FLUX.1](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
- - [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
- - [🚀Mochi](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
+ |Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
+ |:---:|:---:|:---:|:---:|:---:|:---:|
+ |+L20x1:24.8s|19.4s|16.8s|15.9s|14.2s|10.6s|
+ |+compile:20.4s|16.5s|14.1s|13.4s|12s|8.8s|
+ |+L20x4:7.7s|6.6s|6.0s|5.8s|5.2s|3.9s|
+ |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
+ -->

  ## 👋Contribute
  <div id="contribute"></div>

{cache_dit-0.1.7 → cache_dit-0.1.8}/README.md

@@ -1,6 +1,6 @@
  <div align="center">
  <p align="center">
- <h3>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h3>
+ <h2>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h2>
  </p>
  <img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit.png >
  <div align='center'>
@@ -9,13 +9,32 @@
  <img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
  <img src=https://static.pepy.tech/badge/cache-dit >
  <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
- <img src=https://img.shields.io/badge/Release-v0.1.7-brightgreen.svg >
+ <img src=https://img.shields.io/badge/Release-v0.1.8-brightgreen.svg >
  </div>
  <p align="center">
  DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>offers a set of training-free cache accelerators for DiT: 🔥DBCache, DBPrune, FBCache, etc🔥
  </p>
+ <p align="center">
+ <h3> 🔥Supported Models🔥</h2>
+ <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀FLUX.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+ <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+ <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Mochi</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+ <a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Wan2.1</b>: 🔜DBCache, 🔜DBPrune, ✔️FBCache🔥</a> <br> <br>
+ <b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
+ </p>
  </div>

+
+ <!--
+ ## 🎉Supported Models
+ <div id="supported"></div>
+ - [🚀FLUX.1](https://github.com/vipshop/cache-dit/raw/main/examples): *✔️DBCache, ✔️DBPrune, ✔️FBCache*
+ - [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples): *✔️DBCache, ✔️DBPrune, ✔️FBCache*
+ - [🚀Mochi](https://github.com/vipshop/cache-dit/raw/main/examples): *✔️DBCache, ✔️DBPrune, ✔️FBCache*
+ - [🚀Wan2.1**](https://github.com/vipshop/cache-dit/raw/main/examples): *🔜DBCache, 🔜DBPrune, ✔️FBCache*
+ -->
+
+
  ## 🤗 Introduction

  <div align="center">
@@ -67,11 +86,20 @@ These case studies demonstrate that even with relatively high thresholds (such a
  </p>
  </div>

- Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference.
+ **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. Moreover, **CacheDiT** are designed to work compatibly with `torch.compile`. You can easily use CacheDiT with torch.compile to further achieve a better performance.
+
+ <div align="center">
+ <p align="center">
+ DBPrune + <b>torch.compile + context parallelism</b> <br>Steps: 28, "A cat holding a sign that says hello world with complex background"
+ </p>
+ </div>

- <p align="center">
- ♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️
- </p>
+ |Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
+ |:---:|:---:|:---:|:---:|:---:|:---:|
+ |+L20x1:24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
+ |+compile:20.43s|16.25s|14.12s|13.41s|12s|8.86s|
+ |+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
+ |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|

  ## ©️Citations

@@ -101,11 +129,9 @@ The **CacheDiT** codebase was adapted from FBCache's implementation at the [Para
  - [⚡️Dynamic Block Prune](#dbprune)
  - [🎉Context Parallelism](#context-parallelism)
  - [🔥Torch Compile](#compile)
- - [🎉Supported Models](#supported)
  - [👋Contribute](#contribute)
  - [©️License](#license)

-
  ## ⚙️Installation

  <div id="installation"></div>
@@ -335,6 +361,7 @@ Then, run the python test script with `torchrun`:
  ```bash
  torchrun --nproc_per_node=4 parallel_cache.py
  ```
+ <!--

  <div align="center">
  <p align="center">
@@ -342,17 +369,18 @@ torchrun --nproc_per_node=4 parallel_cache.py
  </p>
  </div>

- |Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
+ |Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
  |:---:|:---:|:---:|:---:|:---:|:---:|
- |24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
- |8.54s (L20x4)|7.20s (L20x4)|6.61s (L20x4)|6.09s (L20x4)|5.54s (L20x4)|4.22s (L20x4)|
+ |+L20x1:24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
+ |+L20x4:8.54s|7.20s|6.61s|6.09s|5.54s|4.22s|
  |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
+ -->

  ## 🔥Torch Compile

  <div id="compile"></div>

- **CacheDiT** are designed to work compatibly with `torch.compile`. For example:
+ **CacheDiT** are designed to work compatibly with `torch.compile`. You can easily use CacheDiT with torch.compile to further achieve a better performance. For example:

  ```python
  apply_cache_on_pipe(
@@ -361,21 +389,27 @@ apply_cache_on_pipe(
  # Compile the Transformer module
  pipe.transformer = torch.compile(pipe.transformer)
  ```
- However, users intending to use **CacheDiT** for DiT with **dynamic input shapes** should consider increasing the **recompile** **limit** of `torch._dynamo` to achieve better performance.
-
+ However, users intending to use **CacheDiT** for DiT with **dynamic input shapes** should consider increasing the **recompile** **limit** of `torch._dynamo`. Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
  ```python
  torch._dynamo.config.recompile_limit = 96 # default is 8
  torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
  ```
- Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.

- ## 🎉Supported Models
+ <!--

- <div id="supported"></div>
+ <div align="center">
+ <p align="center">
+ DBPrune + <b>torch.compile</b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+ </p>
+ </div>

- - [🚀FLUX.1](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
- - [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
- - [🚀Mochi](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
+ |Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
+ |:---:|:---:|:---:|:---:|:---:|:---:|
+ |+L20x1:24.8s|19.4s|16.8s|15.9s|14.2s|10.6s|
+ |+compile:20.4s|16.5s|14.1s|13.4s|12s|8.8s|
+ |+L20x4:7.7s|6.6s|6.0s|5.8s|5.2s|3.9s|
+ |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
+ -->

  ## 👋Contribute
  <div id="contribute"></div>

{cache_dit-0.1.7 → cache_dit-0.1.8}/bench/bench.py

@@ -3,7 +3,7 @@ import argparse
  import torch
  import time

- from diffusers import FluxPipeline
+ from diffusers import FluxPipeline, FluxTransformer2DModel
  from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
  from cache_dit.logger import init_logger

@@ -110,6 +110,7 @@ def get_cache_options(cache_type: CacheType, args: argparse.Namespace):
  return cache_options, cache_type_str


+ @torch.no_grad()
  def main():
  args = get_args()
  logger.info(f"Arguments: {args}")
@@ -119,7 +120,9 @@ def main():
  try:
  import torch.distributed as dist
  from para_attn.context_parallel import init_context_parallel_mesh
- from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
+ from para_attn.context_parallel.diffusers_adapters import (
+ parallelize_pipe,
+ )

  # Initialize distributed process group
  dist.init_process_group()
@@ -133,9 +136,10 @@ def main():
  ).to("cuda")

  parallelize_pipe(
- pipe, mesh=init_context_parallel_mesh(
+ pipe,
+ mesh=init_context_parallel_mesh(
  pipe.device.type, max_ulysses_dim_size=args.ulysses
- )
+ ),
  )
  except ImportError as e:
  logger.error(
@@ -148,7 +152,7 @@ def main():
  pipe = FluxPipeline.from_pretrained(
  os.environ.get("FLUX_DIR", "black-forest-labs/FLUX.1-dev"),
  torch_dtype=torch.bfloat16,
- ).to("cuda")
+ ).to("cuda")

  cache_options, cache_type = get_cache_options(args.cache, args)

@@ -165,7 +169,18 @@ def main():
  torch._dynamo.config.accumulated_recompile_limit = (
  2048 # default is 256
  )
- pipe.transformer = torch.compile(pipe.transformer, mode="default")
+ if isinstance(pipe.transformer, FluxTransformer2DModel):
+ logger.warning(
+ "Only compile transformer blocks not the whole model "
+ "for FluxTransformer2DModel to keep higher precision."
+ )
+ for module in pipe.transformer.transformer_blocks:
+ module.compile()
+ for module in pipe.transformer.single_transformer_blocks:
+ module.compile()
+ else:
+ logger.info("Compiling the transformer with default mode.")
+ pipe.transformer = torch.compile(pipe.transformer, mode="default")

  all_times = []
  cached_stepes = 0
@@ -238,6 +253,7 @@ def main():

  if args.ulysses is not None:
  import torch.distributed as dist
+
  dist.destroy_process_group()
  logger.info("Distributed process group destroyed.")

{cache_dit-0.1.7 → cache_dit-0.1.8}/examples/.gitignore

@@ -164,5 +164,4 @@ _version.py
  report*.html

  .DS_Store
-
  *.png

Binary files (new, no text diff): cache_dit-0.1.8/examples/data/cup.png, cache_dit-0.1.8/examples/data/cup_mask.png

cache_dit-0.1.8/examples/run_cogvideox.py (new file)

@@ -0,0 +1,46 @@
+ import os
+ import torch
+ from diffusers import CogVideoXPipeline
+ from diffusers.utils import export_to_video
+ from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
+
+ pipe = CogVideoXPipeline.from_pretrained(
+ os.environ.get(
+ "COGVIDEOX_DIR",
+ "THUDM/CogVideoX-5b",
+ ),
+ torch_dtype=torch.bfloat16,
+ ).to("cuda")
+
+ # Default options, F8B8, good balance between performance and precision
+ cache_options = CacheType.default_options(CacheType.DBCache)
+
+ apply_cache_on_pipe(pipe, **cache_options)
+
+ pipe.vae.enable_slicing()
+ pipe.vae.enable_tiling()
+
+ prompt = (
+ "A panda, dressed in a small, red jacket and a tiny hat, "
+ "sits on a wooden stool in a serene bamboo forest. The "
+ "panda's fluffy paws strum a miniature acoustic guitar, "
+ "producing soft, melodic tunes. Nearby, a few other pandas "
+ "gather, watching curiously and some clapping in rhythm. "
+ "Sunlight filters through the tall bamboo, casting a gentle "
+ "glow on the scene. The panda's face is expressive, showing "
+ "concentration and joy as it plays. The background includes "
+ "a small, flowing stream and vibrant green foliage, enhancing "
+ "the peaceful and magical atmosphere of this unique musical "
+ "performance."
+ )
+ video = pipe(
+ prompt=prompt,
+ num_videos_per_prompt=1,
+ num_inference_steps=50,
+ num_frames=49,
+ guidance_scale=6,
+ generator=torch.Generator("cuda").manual_seed(0),
+ ).frames[0]
+
+ print("Saving video to cogvideox.mp4")
+ export_to_video(video, "cogvideox.mp4", fps=8)

{cache_dit-0.1.7 → cache_dit-0.1.8}/examples/run_flux.py

@@ -1,9 +1,13 @@
+ import os
  import torch
  from diffusers import FluxPipeline
  from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

  pipe = FluxPipeline.from_pretrained(
- "black-forest-labs/FLUX.1-dev",
+ os.environ.get(
+ "FLUX_DIR",
+ "black-forest-labs/FLUX.1-dev",
+ ),
  torch_dtype=torch.bfloat16,
  ).to("cuda")

cache_dit-0.1.8/examples/run_flux_fill.py (new file)

@@ -0,0 +1,32 @@
+ import os
+ import torch
+ from diffusers import FluxFillPipeline
+ from diffusers.utils import load_image
+ from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
+
+ pipe = FluxFillPipeline.from_pretrained(
+ os.environ.get(
+ "FLUX_FILL_DIR",
+ "black-forest-labs/FLUX.1-Fill-dev",
+ ),
+ torch_dtype=torch.bfloat16,
+ ).to("cuda")
+
+
+ # Default options, F8B8, good balance between performance and precision
+ cache_options = CacheType.default_options(CacheType.DBCache)
+
+ apply_cache_on_pipe(pipe, **cache_options)
+
+ image = pipe(
+ prompt="a white paper cup",
+ image=load_image("data/cup.png"),
+ mask_image=load_image("data/cup_mask.png"),
+ guidance_scale=30,
+ num_inference_steps=28,
+ max_sequence_length=512,
+ generator=torch.Generator("cuda").manual_seed(0),
+ ).images[0]
+
+ print("Saving image to flux-fill.png")
+ image.save("flux-fill.png")

{cache_dit-0.1.7 → cache_dit-0.1.8}/examples/run_mochi.py

@@ -1,10 +1,14 @@
+ import os
  import torch
  from diffusers import MochiPipeline
  from diffusers.utils import export_to_video
  from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

  pipe = MochiPipeline.from_pretrained(
- "genmo/mochi-1-preview",
+ os.environ.get(
+ "MOCHI_DIR",
+ "genmo/mochi-1-preview",
+ ),
  torch_dtype=torch.bfloat16,
  ).to("cuda")

@@ -15,7 +19,10 @@ apply_cache_on_pipe(pipe, **cache_options)

  pipe.enable_vae_tiling()

- prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
+ prompt = (
+ "Close-up of a chameleon's eye, with its scaly skin "
+ "changing color. Ultra high resolution 4k."
+ )
  video = pipe(
  prompt,
  num_frames=84,

cache_dit-0.1.8/examples/run_wan.py (new file)

@@ -0,0 +1,49 @@
+ import os
+ import torch
+ from diffusers import WanPipeline
+ from diffusers.utils import export_to_video
+ from diffusers.schedulers.scheduling_unipc_multistep import (
+ UniPCMultistepScheduler,
+ )
+ from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
+
+ height, width = 480, 832
+ pipe = WanPipeline.from_pretrained(
+ os.environ.get(
+ "WAN_DIR",
+ "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+ ),
+ torch_dtype=torch.bfloat16,
+ )
+
+ # flow shift should be 3.0 for 480p images, 5.0 for 720p images
+ if hasattr(pipe, "scheduler") and pipe.scheduler is not None:
+ # Use the UniPCMultistepScheduler with the specified flow shift
+ flow_shift = 3.0 if height == 480 else 5.0
+ pipe.scheduler = UniPCMultistepScheduler.from_config(
+ pipe.scheduler.config,
+ flow_shift=flow_shift,
+ )
+
+ pipe.to("cuda")
+
+ apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.FBCache))
+
+ # Enable memory savings
+ pipe.enable_model_cpu_offload()
+ pipe.enable_vae_tiling()
+
+ video = pipe(
+ prompt=(
+ "An astronaut dancing vigorously on the moon with earth "
+ "flying past in the background, hyperrealistic"
+ ),
+ negative_prompt="",
+ height=480,
+ width=832,
+ num_frames=81,
+ num_inference_steps=30,
+ ).frames[0]
+
+ print("Saving video to wan.mp4")
+ export_to_video(video, "wan.mp4", fps=15)

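Note that run_wan.py, like the other example scripts touched in this release, resolves its model location through `os.environ.get(...)` with a Hugging Face Hub id as the fallback (`WAN_DIR` here; `FLUX_DIR`, `FLUX_FILL_DIR`, `MOCHI_DIR`, and `COGVIDEOX_DIR` in the other scripts). A small sketch of how that lookup behaves; the local path below is purely illustrative.

```python
import os

# With the variable unset, the example falls back to the Hub id baked into it.
model_path = os.environ.get("WAN_DIR", "Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
print(model_path)  # -> Wan-AI/Wan2.1-T2V-1.3B-Diffusers

# Exporting WAN_DIR before launching the script (illustrative path) redirects
# the same lookup to a local checkout instead of downloading from the Hub.
os.environ["WAN_DIR"] = "/models/Wan2.1-T2V-1.3B-Diffusers"
model_path = os.environ.get("WAN_DIR", "Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
print(model_path)  # -> /models/Wan2.1-T2V-1.3B-Diffusers
```
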
{cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/_version.py

@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.1.7'
- __version_tuple__ = version_tuple = (0, 1, 7)
+ __version__ = version = '0.1.8'
+ __version_tuple__ = version_tuple = (0, 1, 8)

{cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit/cache_factory/dynamic_block_prune/prune_context.py (whitespace-only changes to two comment lines)

@@ -628,7 +628,7 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
  return sorted(non_prune_blocks_ids)

  # @torch.compile(dynamic=True)
- # mark this function as compile with dynamic=True will
+ # mark this function as compile with dynamic=True will
  # cause precision degradate, so, we choose to disable it
  # now, until we find a better solution or fixed the bug.
  @torch.compiler.disable
@@ -668,7 +668,7 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
  )

  # @torch.compile(dynamic=True)
- # mark this function as compile with dynamic=True will
+ # mark this function as compile with dynamic=True will
  # cause precision degradate, so, we choose to disable it
  # now, until we find a better solution or fixed the bug.
  @torch.compiler.disable

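The context lines above explain the pattern: compiling these helpers with `torch.compile(dynamic=True)` degraded precision, so they are decorated with `@torch.compiler.disable` and left in eager mode instead. A minimal, self-contained illustration of that decorator follows; the function names are hypothetical and not the library's actual methods.

```python
import torch


@torch.compiler.disable
def pick_block_ids(scores: torch.Tensor) -> torch.Tensor:
    # Always runs eagerly: torch.compile breaks the graph around this call,
    # which is the same mechanism the prune_context.py methods above rely on.
    return torch.argsort(scores)


@torch.compile
def caller(scores: torch.Tensor) -> torch.Tensor:
    ids = pick_block_ids(scores)
    return ids * 2
```
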
{cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit.egg-info/PKG-INFO

The hunks for this file are identical to the PKG-INFO diff shown above (the egg-info metadata is a generated copy of the top-level PKG-INFO): the version is bumped from 0.1.7 to 0.1.8, the Release badge is updated, the Supported Models block and the DBPrune + torch.compile benchmark table are added, the star request moves into the header, and the old Supported Models section is commented out.

{cache_dit-0.1.7 → cache_dit-0.1.8}/src/cache_dit.egg-info/SOURCES.txt

@@ -35,6 +35,27 @@ assets/DBPRUNE_F1B0_R0.1_P62.8_T9.95s.png
  assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png
  assets/DBPRUNE_F1B0_R0.3_P63.1_T9.79s.png
  assets/NONE_R0.08_S0.png
+ assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png
+ assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png
+ assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png
+ assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png
+ assets/U0_C1_DBPRUNE_F1B0_R0.05_P41.6_T12.70s.png
+ assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png
+ assets/U0_C1_DBPRUNE_F8B8_R0.08_P23.1_T16.14s.png
+ assets/U0_C1_NONE_R0.08_S0_T20.43s.png
+ assets/U4_C1_DBPRUNE_F1B0_R0.03_P27.3_T6.62s.png
+ assets/U4_C1_DBPRUNE_F1B0_R0.03_P27.3_T6.63s.png
+ assets/U4_C1_DBPRUNE_F1B0_R0.045_P38.2_T5.81s.png
+ assets/U4_C1_DBPRUNE_F1B0_R0.045_P38.2_T5.82s.png
+ assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.06s.png
+ assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.07s.png
+ assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.08s.png
+ assets/U4_C1_DBPRUNE_F1B0_R0.055_P45.1_T5.27s.png
+ assets/U4_C1_DBPRUNE_F1B0_R0.055_P45.1_T5.28s.png
+ assets/U4_C1_DBPRUNE_F1B0_R0.2_P59.5_T3.95s.png
+ assets/U4_C1_DBPRUNE_F1B0_R0.2_P59.5_T3.96s.png
+ assets/U4_C1_NONE_R0.08_S0_T7.78s.png
+ assets/U4_C1_NONE_R0.08_S0_T7.79s.png
  assets/cache-dit.png
  bench/.gitignore
  bench/bench.py
@@ -42,7 +63,11 @@ docs/.gitignore
  examples/.gitignore
  examples/run_cogvideox.py
  examples/run_flux.py
+ examples/run_flux_fill.py
  examples/run_mochi.py
+ examples/run_wan.py
+ examples/data/cup.png
+ examples/data/cup_mask.png
  src/cache_dit/__init__.py
  src/cache_dit/_version.py
  src/cache_dit/logger.py

cache_dit-0.1.7/examples/run_cogvideox.py (removed; replaced by the new cache_dit-0.1.8/examples/run_cogvideox.py above)

@@ -1,30 +0,0 @@
- import torch
- from diffusers import CogVideoXPipeline
- from diffusers.utils import export_to_video
- from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
-
- pipe = CogVideoXPipeline.from_pretrained(
- "THUDM/CogVideoX-5b",
- torch_dtype=torch.bfloat16,
- ).to("cuda")
-
- # Default options, F8B8, good balance between performance and precision
- cache_options = CacheType.default_options(CacheType.DBCache)
-
- apply_cache_on_pipe(pipe, **cache_options)
-
- pipe.vae.enable_slicing()
- pipe.vae.enable_tiling()
-
- prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
- video = pipe(
- prompt=prompt,
- num_videos_per_prompt=1,
- num_inference_steps=50,
- num_frames=49,
- guidance_scale=6,
- generator=torch.Generator("cuda").manual_seed(0),
- ).frames[0]
-
- print("Saving video to cogvideox.mp4")
- export_to_video(video, "cogvideox.mp4", fps=8)
