cache-dit 0.1.7__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit has been flagged as possibly problematic.

Files changed (111)
  1. {cache_dit-0.1.7 → cache_dit-0.2.0}/PKG-INFO +45 -40
  2. {cache_dit-0.1.7 → cache_dit-0.2.0}/README.md +44 -39
  3. cache_dit-0.2.0/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png +0 -0
  4. cache_dit-0.2.0/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png +0 -0
  5. cache_dit-0.2.0/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png +0 -0
  6. cache_dit-0.2.0/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png +0 -0
  7. cache_dit-0.2.0/assets/U0_C1_DBPRUNE_F1B0_R0.05_P41.6_T12.70s.png +0 -0
  8. cache_dit-0.2.0/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png +0 -0
  9. cache_dit-0.2.0/assets/U0_C1_DBPRUNE_F8B8_R0.08_P23.1_T16.14s.png +0 -0
  10. cache_dit-0.2.0/assets/U0_C1_NONE_R0.08_S0_T20.43s.png +0 -0
  11. cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.03_P27.3_T6.62s.png +0 -0
  12. cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.03_P27.3_T6.63s.png +0 -0
  13. cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.045_P38.2_T5.81s.png +0 -0
  14. cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.045_P38.2_T5.82s.png +0 -0
  15. cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.06s.png +0 -0
  16. cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.07s.png +0 -0
  17. cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.08s.png +0 -0
  18. cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.055_P45.1_T5.27s.png +0 -0
  19. cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.055_P45.1_T5.28s.png +0 -0
  20. cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.2_P59.5_T3.95s.png +0 -0
  21. cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.2_P59.5_T3.96s.png +0 -0
  22. cache_dit-0.2.0/assets/U4_C1_NONE_R0.08_S0_T7.78s.png +0 -0
  23. cache_dit-0.2.0/assets/U4_C1_NONE_R0.08_S0_T7.79s.png +0 -0
  24. {cache_dit-0.1.7 → cache_dit-0.2.0}/bench/bench.py +22 -6
  25. {cache_dit-0.1.7 → cache_dit-0.2.0}/examples/.gitignore +1 -1
  26. cache_dit-0.2.0/examples/README.md +45 -0
  27. cache_dit-0.2.0/examples/data/cup.png +0 -0
  28. cache_dit-0.2.0/examples/data/cup_mask.png +0 -0
  29. cache_dit-0.2.0/examples/requirements.txt +4 -0
  30. cache_dit-0.2.0/examples/run_cogvideox.py +72 -0
  31. {cache_dit-0.1.7 → cache_dit-0.2.0}/examples/run_flux.py +5 -1
  32. cache_dit-0.2.0/examples/run_flux_fill.py +32 -0
  33. cache_dit-0.2.0/examples/run_hunyuan_video.py +75 -0
  34. {cache_dit-0.1.7 → cache_dit-0.2.0}/examples/run_mochi.py +9 -2
  35. cache_dit-0.2.0/examples/run_wan.py +54 -0
  36. {cache_dit-0.1.7 → cache_dit-0.2.0}/setup.py +1 -0
  37. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/_version.py +2 -2
  38. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +8 -0
  39. cache_dit-0.2.0/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +295 -0
  40. cache_dit-0.2.0/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +99 -0
  41. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +12 -4
  42. cache_dit-0.2.0/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +295 -0
  43. cache_dit-0.2.0/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +99 -0
  44. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/prune_context.py +2 -2
  45. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +4 -0
  46. cache_dit-0.2.0/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +295 -0
  47. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +2 -2
  48. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit.egg-info/PKG-INFO +45 -40
  49. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit.egg-info/SOURCES.txt +33 -0
  50. cache_dit-0.1.7/examples/run_cogvideox.py +0 -30
  51. {cache_dit-0.1.7 → cache_dit-0.2.0}/.github/workflows/issue.yml +0 -0
  52. {cache_dit-0.1.7 → cache_dit-0.2.0}/.gitignore +0 -0
  53. {cache_dit-0.1.7 → cache_dit-0.2.0}/.pre-commit-config.yaml +0 -0
  54. {cache_dit-0.1.7 → cache_dit-0.2.0}/CONTRIBUTE.md +0 -0
  55. {cache_dit-0.1.7 → cache_dit-0.2.0}/LICENSE +0 -0
  56. {cache_dit-0.1.7 → cache_dit-0.2.0}/MANIFEST.in +0 -0
  57. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F12B12S4_R0.2_S16.png +0 -0
  58. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F12B16S4_R0.08_S6.png +0 -0
  59. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F16B16S2_R0.2_S14.png +0 -0
  60. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F16B16S4_R0.2_S13.png +0 -0
  61. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F1B0S1_R0.08_S11.png +0 -0
  62. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F1B0S1_R0.2_S19.png +0 -0
  63. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F8B0S2_R0.12_S12.png +0 -0
  64. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F8B16S1_R0.2_S18.png +0 -0
  65. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F8B8S1_R0.08_S9.png +0 -0
  66. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F8B8S1_R0.12_S12.png +0 -0
  67. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F8B8S1_R0.15_S15.png +0 -0
  68. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCache.png +0 -0
  69. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png +0 -0
  70. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png +0 -0
  71. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png +0 -0
  72. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png +0 -0
  73. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.07_P52.3_T12.53s.png +0 -0
  74. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.08_P52.4_T12.52s.png +0 -0
  75. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.09_P59.2_T10.81s.png +0 -0
  76. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.12_P59.5_T10.76s.png +0 -0
  77. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.12_P63.0_T9.90s.png +0 -0
  78. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.1_P62.8_T9.95s.png +0 -0
  79. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png +0 -0
  80. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.3_P63.1_T9.79s.png +0 -0
  81. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/NONE_R0.08_S0.png +0 -0
  82. {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/cache-dit.png +0 -0
  83. {cache_dit-0.1.7 → cache_dit-0.2.0}/bench/.gitignore +0 -0
  84. {cache_dit-0.1.7 → cache_dit-0.2.0}/docs/.gitignore +0 -0
  85. {cache_dit-0.1.7 → cache_dit-0.2.0}/pyproject.toml +0 -0
  86. {cache_dit-0.1.7 → cache_dit-0.2.0}/pytest.ini +0 -0
  87. {cache_dit-0.1.7 → cache_dit-0.2.0}/requirements.txt +0 -0
  88. {cache_dit-0.1.7 → cache_dit-0.2.0}/setup.cfg +0 -0
  89. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/__init__.py +0 -0
  90. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/__init__.py +0 -0
  91. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
  92. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/cache_context.py +0 -0
  93. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -0
  94. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -0
  95. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -0
  96. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
  97. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -0
  98. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -0
  99. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -0
  100. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
  101. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/cache_context.py +0 -0
  102. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +0 -0
  103. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -0
  104. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +0 -0
  105. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/taylorseer.py +0 -0
  106. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/utils.py +0 -0
  107. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/logger.py +0 -0
  108. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/primitives.py +0 -0
  109. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit.egg-info/dependency_links.txt +0 -0
  110. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit.egg-info/requires.txt +0 -0
  111. {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit.egg-info/top_level.txt +0 -0
--- cache_dit-0.1.7/PKG-INFO
+++ cache_dit-0.2.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.1.7
+Version: 0.2.0
 Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -35,7 +35,7 @@ Dynamic: requires-python
 
 <div align="center">
 <p align="center">
-<h3>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h3>
+<h2>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h2>
 </p>
 <img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit.png >
 <div align='center'>
@@ -44,13 +44,28 @@ Dynamic: requires-python
 <img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
 <img src=https://static.pepy.tech/badge/cache-dit >
 <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
-<img src=https://img.shields.io/badge/Release-v0.1.7-brightgreen.svg >
+<img src=https://img.shields.io/badge/Release-v0.2.0-brightgreen.svg >
 </div>
 <p align="center">
 DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>offers a set of training-free cache accelerators for DiT: 🔥DBCache, DBPrune, FBCache, etc🔥
 </p>
+<p align="center">
+<h4> 🔥Supported Models🔥</h4>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀FLUX.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Mochi</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX1.5</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Wan2.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀HunyuanVideo</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+</p>
 </div>
 
+## 👋 Highlight
+
+<div id="reference"></div>
+
+The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work! The **FBCache** support for Mochi, FLUX.1, CogVideoX, Wan2.1, and HunyuanVideo is directly adapted from the original [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache).
+
 ## 🤗 Introduction
 
 <div align="center">
@@ -91,6 +106,12 @@ These case studies demonstrate that even with relatively high thresholds (such a
 
 **DBPrune**: We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, referred to as DBPrune. DBPrune caches each block's hidden states and residuals, then **dynamically prunes** blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals.
 
+<div align="center">
+<p align="center">
+DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+</p>
+</div>
+
 |Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
 |:---:|:---:|:---:|:---:|:---:|:---:|
 |24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
@@ -98,15 +119,29 @@ These case studies demonstrate that even with relatively high thresholds (such a
 
 <div align="center">
 <p align="center">
-DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+<h3>🔥 Context Parallelism and Torch Compile</h3>
+</p>
+</div>
+
+Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. By the way, CacheDiT is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance.
+
+<div align="center">
+<p align="center">
+DBPrune + <b>torch.compile + context parallelism</b> <br>Steps: 28, "A cat holding a sign that says hello world with complex background"
 </p>
 </div>
 
-Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference.
+|Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|+compile:20.43s|16.25s|14.12s|13.41s|12.00s|8.86s|
+|+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
 
-<p align="center">
-♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️
-</p>
+<div align="center">
+<p align="center">
+<b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
+</p>
+</div>
 
 ## ©️Citations
 
@@ -120,12 +155,6 @@ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand w
 }
 ```
 
-## 👋Reference
-
-<div id="reference"></div>
-
-The **CacheDiT** codebase was adapted from FBCache's implementation at the [ParaAttention](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). We would like to express our sincere gratitude for this excellent work!
-
 ## 📖Contents
 
 <div id="contents"></div>
@@ -136,11 +165,9 @@ The **CacheDiT** codebase was adapted from FBCache's implementation at the [Para
 - [⚡️Dynamic Block Prune](#dbprune)
 - [🎉Context Parallelism](#context-parallelism)
 - [🔥Torch Compile](#compile)
-- [🎉Supported Models](#supported)
 - [👋Contribute](#contribute)
 - [©️License](#license)
 
-
 ## ⚙️Installation
 
 <div id="installation"></div>
@@ -371,23 +398,11 @@ Then, run the python test script with `torchrun`:
 torchrun --nproc_per_node=4 parallel_cache.py
 ```
 
-<div align="center">
-<p align="center">
-DBPrune, <b> L20x4 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
-</p>
-</div>
-
-|Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
-|:---:|:---:|:---:|:---:|:---:|:---:|
-|24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
-|8.54s (L20x4)|7.20s (L20x4)|6.61s (L20x4)|6.09s (L20x4)|5.54s (L20x4)|4.22s (L20x4)|
-|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
-
 ## 🔥Torch Compile
 
 <div id="compile"></div>
 
-**CacheDiT** are designed to work compatibly with `torch.compile`. For example:
+By the way, **CacheDiT** is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance. For example:
 
 ```python
 apply_cache_on_pipe(
@@ -396,21 +411,11 @@ apply_cache_on_pipe(
 # Compile the Transformer module
 pipe.transformer = torch.compile(pipe.transformer)
 ```
-However, users intending to use **CacheDiT** for DiT with **dynamic input shapes** should consider increasing the **recompile** **limit** of `torch._dynamo` to achieve better performance.
-
+However, users intending to use **CacheDiT** for DiT with **dynamic input shapes** should consider increasing the **recompile** **limit** of `torch._dynamo`. Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
 ```python
 torch._dynamo.config.recompile_limit = 96 # default is 8
 torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
 ```
-Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
-
-## 🎉Supported Models
-
-<div id="supported"></div>
-
-- [🚀FLUX.1](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
-- [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
-- [🚀Mochi](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
 
 ## 👋Contribute
 <div id="contribute"></div>
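
The DBPrune description in the diff above (cache each block's hidden states and residuals, prune by L1 distance between successive inputs, approximate pruned outputs from cached residuals) reduces to a short decision rule. The sketch below is illustrative only, not the cache-dit implementation; `maybe_prune_block`, the cache layout, and the 0.08 threshold are all hypothetical.

```python
import torch


def maybe_prune_block(block, hidden_states, cache, threshold=0.08):
    # Illustrative DBPrune-style rule: skip the block when its input has
    # barely moved since the previous step (relative L1 distance).
    prev_input = cache.get("prev_input")
    if prev_input is not None:
        diff = (hidden_states - prev_input).abs().mean()
        if diff / (prev_input.abs().mean() + 1e-8) < threshold:
            # Pruned: approximate the block output with the cached residual.
            return hidden_states + cache["residual"]
    # Not pruned: run the block and refresh the cached input/residual.
    output = block(hidden_states)
    cache["prev_input"] = hidden_states.detach()
    cache["residual"] = (output - hidden_states).detach()
    return output
```
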
--- cache_dit-0.1.7/README.md
+++ cache_dit-0.2.0/README.md
@@ -1,6 +1,6 @@
 <div align="center">
 <p align="center">
-<h3>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h3>
+<h2>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h2>
 </p>
 <img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit.png >
 <div align='center'>
@@ -9,13 +9,28 @@
 <img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
 <img src=https://static.pepy.tech/badge/cache-dit >
 <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
-<img src=https://img.shields.io/badge/Release-v0.1.7-brightgreen.svg >
+<img src=https://img.shields.io/badge/Release-v0.2.0-brightgreen.svg >
 </div>
 <p align="center">
 DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>offers a set of training-free cache accelerators for DiT: 🔥DBCache, DBPrune, FBCache, etc🔥
 </p>
+<p align="center">
+<h4> 🔥Supported Models🔥</h4>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀FLUX.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Mochi</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX1.5</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Wan2.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀HunyuanVideo</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+</p>
 </div>
 
+## 👋 Highlight
+
+<div id="reference"></div>
+
+The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work! The **FBCache** support for Mochi, FLUX.1, CogVideoX, Wan2.1, and HunyuanVideo is directly adapted from the original [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache).
+
 ## 🤗 Introduction
 
 <div align="center">
@@ -56,6 +71,12 @@ These case studies demonstrate that even with relatively high thresholds (such a
 
 **DBPrune**: We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, referred to as DBPrune. DBPrune caches each block's hidden states and residuals, then **dynamically prunes** blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals.
 
+<div align="center">
+<p align="center">
+DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+</p>
+</div>
+
 |Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
 |:---:|:---:|:---:|:---:|:---:|:---:|
 |24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
@@ -63,15 +84,29 @@ These case studies demonstrate that even with relatively high thresholds (such a
 
 <div align="center">
 <p align="center">
-DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+<h3>🔥 Context Parallelism and Torch Compile</h3>
+</p>
+</div>
+
+Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. By the way, CacheDiT is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance.
+
+<div align="center">
+<p align="center">
+DBPrune + <b>torch.compile + context parallelism</b> <br>Steps: 28, "A cat holding a sign that says hello world with complex background"
 </p>
 </div>
 
-Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference.
+|Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|+compile:20.43s|16.25s|14.12s|13.41s|12.00s|8.86s|
+|+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
 
-<p align="center">
-♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️
-</p>
+<div align="center">
+<p align="center">
+<b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
+</p>
+</div>
 
 ## ©️Citations
 
@@ -85,12 +120,6 @@ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand w
 }
 ```
 
-## 👋Reference
-
-<div id="reference"></div>
-
-The **CacheDiT** codebase was adapted from FBCache's implementation at the [ParaAttention](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). We would like to express our sincere gratitude for this excellent work!
-
 ## 📖Contents
 
 <div id="contents"></div>
@@ -101,11 +130,9 @@ The **CacheDiT** codebase was adapted from FBCache's implementation at the [Para
 - [⚡️Dynamic Block Prune](#dbprune)
 - [🎉Context Parallelism](#context-parallelism)
 - [🔥Torch Compile](#compile)
-- [🎉Supported Models](#supported)
 - [👋Contribute](#contribute)
 - [©️License](#license)
 
-
 ## ⚙️Installation
 
 <div id="installation"></div>
@@ -336,23 +363,11 @@ Then, run the python test script with `torchrun`:
 torchrun --nproc_per_node=4 parallel_cache.py
 ```
 
-<div align="center">
-<p align="center">
-DBPrune, <b> L20x4 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
-</p>
-</div>
-
-|Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
-|:---:|:---:|:---:|:---:|:---:|:---:|
-|24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
-|8.54s (L20x4)|7.20s (L20x4)|6.61s (L20x4)|6.09s (L20x4)|5.54s (L20x4)|4.22s (L20x4)|
-|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
-
 ## 🔥Torch Compile
 
 <div id="compile"></div>
 
-**CacheDiT** are designed to work compatibly with `torch.compile`. For example:
+By the way, **CacheDiT** is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance. For example:
 
 ```python
 apply_cache_on_pipe(
@@ -361,21 +376,11 @@ apply_cache_on_pipe(
 # Compile the Transformer module
 pipe.transformer = torch.compile(pipe.transformer)
 ```
-However, users intending to use **CacheDiT** for DiT with **dynamic input shapes** should consider increasing the **recompile** **limit** of `torch._dynamo` to achieve better performance.
-
+However, users intending to use **CacheDiT** for DiT with **dynamic input shapes** should consider increasing the **recompile** **limit** of `torch._dynamo`. Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
 ```python
 torch._dynamo.config.recompile_limit = 96 # default is 8
 torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
 ```
-Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
-
-## 🎉Supported Models
-
-<div id="supported"></div>
-
-- [🚀FLUX.1](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
-- [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
-- [🚀Mochi](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
 
 ## 👋Contribute
 <div id="contribute"></div>
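
Taken together, the README changes above spell out the intended v0.2.0 usage pattern: apply a cache, raise the dynamo recompile limits, then compile. A minimal end-to-end sketch using only APIs shown in this diff (`apply_cache_on_pipe`, `CacheType.default_options`); the prompt and step count follow the README's benchmark settings, and the output filename is arbitrary:

```python
import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# DBCache with the default F8B8 options, as in the bundled examples.
apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.DBCache))

# Per the README: raise recompile limits before compiling so DiT workloads
# with dynamic input shapes do not fall back to eager mode.
torch._dynamo.config.recompile_limit = 96  # default is 8
torch._dynamo.config.accumulated_recompile_limit = 2048  # default is 256
pipe.transformer = torch.compile(pipe.transformer)

image = pipe(
    "A cat holding a sign that says hello world with complex background",
    num_inference_steps=28,
).images[0]
image.save("flux-dbcache-compile.png")
```
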
--- cache_dit-0.1.7/bench/bench.py
+++ cache_dit-0.2.0/bench/bench.py
@@ -3,7 +3,7 @@ import argparse
 import torch
 import time
 
-from diffusers import FluxPipeline
+from diffusers import FluxPipeline, FluxTransformer2DModel
 from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
 from cache_dit.logger import init_logger
 
@@ -110,6 +110,7 @@ def get_cache_options(cache_type: CacheType, args: argparse.Namespace):
     return cache_options, cache_type_str
 
 
+@torch.no_grad()
 def main():
     args = get_args()
     logger.info(f"Arguments: {args}")
@@ -119,7 +120,9 @@
     try:
         import torch.distributed as dist
         from para_attn.context_parallel import init_context_parallel_mesh
-        from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
+        from para_attn.context_parallel.diffusers_adapters import (
+            parallelize_pipe,
+        )
 
         # Initialize distributed process group
         dist.init_process_group()
@@ -133,9 +136,10 @@
         ).to("cuda")
 
         parallelize_pipe(
-            pipe, mesh=init_context_parallel_mesh(
+            pipe,
+            mesh=init_context_parallel_mesh(
                 pipe.device.type, max_ulysses_dim_size=args.ulysses
-            )
+            ),
         )
     except ImportError as e:
         logger.error(
@@ -148,7 +152,7 @@
         pipe = FluxPipeline.from_pretrained(
             os.environ.get("FLUX_DIR", "black-forest-labs/FLUX.1-dev"),
             torch_dtype=torch.bfloat16,
-        ).to("cuda")
+        ).to("cuda")
 
     cache_options, cache_type = get_cache_options(args.cache, args)
 
@@ -165,7 +169,18 @@
         torch._dynamo.config.accumulated_recompile_limit = (
            2048 # default is 256
        )
-        pipe.transformer = torch.compile(pipe.transformer, mode="default")
+        if isinstance(pipe.transformer, FluxTransformer2DModel):
+            logger.warning(
+                "Only compile transformer blocks not the whole model "
+                "for FluxTransformer2DModel to keep higher precision."
+            )
+            for module in pipe.transformer.transformer_blocks:
+                module.compile()
+            for module in pipe.transformer.single_transformer_blocks:
+                module.compile()
+        else:
+            logger.info("Compiling the transformer with default mode.")
+            pipe.transformer = torch.compile(pipe.transformer, mode="default")
 
     all_times = []
     cached_stepes = 0
@@ -238,6 +253,7 @@
 
     if args.ulysses is not None:
         import torch.distributed as dist
+
         dist.destroy_process_group()
         logger.info("Distributed process group destroyed.")
 
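
The bench.py hunk above switches FLUX.1 to per-block compilation, with the stated rationale that compiling transformer blocks individually keeps higher precision than compiling the whole model. The same pattern in isolation, assuming `pipe` is an already-loaded `FluxPipeline`:

```python
import torch
from diffusers import FluxTransformer2DModel

if isinstance(pipe.transformer, FluxTransformer2DModel):
    # Compile block-by-block instead of wrapping the whole transformer;
    # per the bench script this keeps higher precision for FLUX.1.
    for block in pipe.transformer.transformer_blocks:
        block.compile()
    for block in pipe.transformer.single_transformer_blocks:
        block.compile()
else:
    # Fall back to whole-module compilation for other DiT architectures.
    pipe.transformer = torch.compile(pipe.transformer, mode="default")
```
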
--- cache_dit-0.1.7/examples/.gitignore
+++ cache_dit-0.2.0/examples/.gitignore
@@ -164,5 +164,5 @@ _version.py
 report*.html
 
 .DS_Store
-
 *.png
+*.mp4
--- /dev/null
+++ cache_dit-0.2.0/examples/README.md
@@ -0,0 +1,45 @@
+# Examples for CacheDiT
+
+## Install requirements
+
+```bash
+pip3 install -r requirements.txt
+```
+
+## Run examples
+
+- FLUX.1-dev
+
+```bash
+python3 run_flux.py
+```
+
+- FLUX.1-Fill-dev
+
+```bash
+python3 run_flux_fill.py
+```
+
+- CogVideoX
+
+```bash
+python3 run_cogvideox.py
+```
+
+- Wan2.1
+
+```bash
+python3 run_wan.py
+```
+
+- Mochi
+
+```bash
+python3 run_mochi.py
+```
+
+- HunyuanVideo
+
+```bash
+python3 run_hunyuan_video.py
+```
Binary files added: examples/data/cup.png, examples/data/cup_mask.png
--- /dev/null
+++ cache_dit-0.2.0/examples/requirements.txt
@@ -0,0 +1,4 @@
+imageio-ffmpeg
+# wan currently requires installing from source
+diffusers @ git+https://github.com/huggingface/diffusers
+ftfy
--- /dev/null
+++ cache_dit-0.2.0/examples/run_cogvideox.py
@@ -0,0 +1,72 @@
+import os
+import torch
+from diffusers.utils import export_to_video
+from diffusers import CogVideoXPipeline, AutoencoderKLCogVideoX
+from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
+
+
+model_id = os.environ.get("COGVIDEOX_DIR", "THUDM/CogVideoX-5b")
+
+
+def is_cogvideox_1_5():
+    return "CogVideoX1.5" in model_id or "THUDM/CogVideoX1.5" in model_id
+
+
+def get_gpu_memory_in_gib():
+    if not torch.cuda.is_available():
+        return 0
+
+    try:
+        total_memory_bytes = torch.cuda.get_device_properties(
+            torch.cuda.current_device(),
+        ).total_memory
+        total_memory_gib = total_memory_bytes / (1024**3)
+        return int(total_memory_gib)
+    except Exception:
+        return 0
+
+
+pipe = CogVideoXPipeline.from_pretrained(
+    model_id,
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# Default options, F8B8, good balance between performance and precision
+cache_options = CacheType.default_options(CacheType.DBCache)
+
+apply_cache_on_pipe(pipe, **cache_options)
+
+pipe.enable_model_cpu_offload()
+assert isinstance(pipe.vae, AutoencoderKLCogVideoX)  # enable type check for IDE
+pipe.vae.enable_slicing()
+pipe.vae.enable_tiling()
+
+prompt = (
+    "A panda, dressed in a small, red jacket and a tiny hat, "
+    "sits on a wooden stool in a serene bamboo forest. The "
+    "panda's fluffy paws strum a miniature acoustic guitar, "
+    "producing soft, melodic tunes. Nearby, a few other pandas "
+    "gather, watching curiously and some clapping in rhythm. "
+    "Sunlight filters through the tall bamboo, casting a gentle "
+    "glow on the scene. The panda's face is expressive, showing "
+    "concentration and joy as it plays. The background includes "
+    "a small, flowing stream and vibrant green foliage, enhancing "
+    "the peaceful and magical atmosphere of this unique musical "
+    "performance."
+)
+video = pipe(
+    prompt=prompt,
+    num_videos_per_prompt=1,
+    num_inference_steps=50,
+    num_frames=(
+        # Avoid OOM for CogVideoX1.5 model on 48GB GPU
+        16
+        if (is_cogvideox_1_5() and get_gpu_memory_in_gib() < 48)
+        else 49
+    ),
+    guidance_scale=6,
+    generator=torch.Generator("cuda").manual_seed(0),
+).frames[0]
+
+print("Saving video to cogvideox.mp4")
+export_to_video(video, "cogvideox.mp4", fps=8)
--- cache_dit-0.1.7/examples/run_flux.py
+++ cache_dit-0.2.0/examples/run_flux.py
@@ -1,9 +1,13 @@
+import os
 import torch
 from diffusers import FluxPipeline
 from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
 
 pipe = FluxPipeline.from_pretrained(
-    "black-forest-labs/FLUX.1-dev",
+    os.environ.get(
+        "FLUX_DIR",
+        "black-forest-labs/FLUX.1-dev",
+    ),
     torch_dtype=torch.bfloat16,
 ).to("cuda")
 
--- /dev/null
+++ cache_dit-0.2.0/examples/run_flux_fill.py
@@ -0,0 +1,32 @@
+import os
+import torch
+from diffusers import FluxFillPipeline
+from diffusers.utils import load_image
+from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
+
+pipe = FluxFillPipeline.from_pretrained(
+    os.environ.get(
+        "FLUX_FILL_DIR",
+        "black-forest-labs/FLUX.1-Fill-dev",
+    ),
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+
+# Default options, F8B8, good balance between performance and precision
+cache_options = CacheType.default_options(CacheType.DBCache)
+
+apply_cache_on_pipe(pipe, **cache_options)
+
+image = pipe(
+    prompt="a white paper cup",
+    image=load_image("data/cup.png"),
+    mask_image=load_image("data/cup_mask.png"),
+    guidance_scale=30,
+    num_inference_steps=28,
+    max_sequence_length=512,
+    generator=torch.Generator("cuda").manual_seed(0),
+).images[0]
+
+print("Saving image to flux-fill.png")
+image.save("flux-fill.png")
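
This release also ships new DBCache/DBPrune/FBCache adapters for Wan2.1 and HunyuanVideo (the wan.py and hunyuan_video.py adapter files in the list above), plus run_wan.py and run_hunyuan_video.py examples whose bodies are not shown in this diff. Below is a hypothetical Wan2.1 sketch following the same pattern as the examples above; the checkpoint id, prompt, and fps are assumptions, and diffusers must be installed from source per examples/requirements.txt:

```python
import torch
from diffusers import WanPipeline
from diffusers.utils import export_to_video
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # assumed checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

# Default options, F8B8, good balance between performance and precision
apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.DBCache))

video = pipe(
    prompt="A cat walks on the grass, realistic",
    num_inference_steps=50,
).frames[0]

print("Saving video to wan.mp4")
export_to_video(video, "wan.mp4", fps=16)
```
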