cache-dit 0.1.8__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- {cache_dit-0.1.8 → cache_dit-0.2.0}/PKG-INFO +27 -56
- {cache_dit-0.1.8 → cache_dit-0.2.0}/README.md +26 -55
- {cache_dit-0.1.8 → cache_dit-0.2.0}/examples/.gitignore +1 -0
- cache_dit-0.2.0/examples/README.md +45 -0
- cache_dit-0.2.0/examples/requirements.txt +4 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/examples/run_cogvideox.py +32 -6
- cache_dit-0.2.0/examples/run_hunyuan_video.py +75 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/examples/run_wan.py +10 -5
- {cache_dit-0.1.8 → cache_dit-0.2.0}/setup.py +1 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/_version.py +2 -2
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +8 -0
- cache_dit-0.2.0/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +295 -0
- cache_dit-0.2.0/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +99 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +12 -4
- cache_dit-0.2.0/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +295 -0
- cache_dit-0.2.0/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +99 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +4 -0
- cache_dit-0.2.0/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +295 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +2 -2
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit.egg-info/PKG-INFO +27 -56
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit.egg-info/SOURCES.txt +8 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/.github/workflows/issue.yml +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/.gitignore +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/.pre-commit-config.yaml +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/CONTRIBUTE.md +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/LICENSE +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/MANIFEST.in +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBCACHE_F12B12S4_R0.2_S16.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBCACHE_F12B16S4_R0.08_S6.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBCACHE_F16B16S2_R0.2_S14.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBCACHE_F16B16S4_R0.2_S13.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBCACHE_F1B0S1_R0.08_S11.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBCACHE_F1B0S1_R0.2_S19.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBCACHE_F8B0S2_R0.12_S12.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBCACHE_F8B16S1_R0.2_S18.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBCACHE_F8B8S1_R0.08_S9.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBCACHE_F8B8S1_R0.12_S12.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBCACHE_F8B8S1_R0.15_S15.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBCache.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.07_P52.3_T12.53s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.08_P52.4_T12.52s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.09_P59.2_T10.81s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.12_P59.5_T10.76s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.12_P63.0_T9.90s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.1_P62.8_T9.95s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.3_P63.1_T9.79s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/NONE_R0.08_S0.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U0_C1_DBPRUNE_F1B0_R0.05_P41.6_T12.70s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U0_C1_DBPRUNE_F8B8_R0.08_P23.1_T16.14s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U0_C1_NONE_R0.08_S0_T20.43s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U4_C1_DBPRUNE_F1B0_R0.03_P27.3_T6.62s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U4_C1_DBPRUNE_F1B0_R0.03_P27.3_T6.63s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U4_C1_DBPRUNE_F1B0_R0.045_P38.2_T5.81s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U4_C1_DBPRUNE_F1B0_R0.045_P38.2_T5.82s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.06s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.07s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.08s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U4_C1_DBPRUNE_F1B0_R0.055_P45.1_T5.27s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U4_C1_DBPRUNE_F1B0_R0.055_P45.1_T5.28s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U4_C1_DBPRUNE_F1B0_R0.2_P59.5_T3.95s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U4_C1_DBPRUNE_F1B0_R0.2_P59.5_T3.96s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U4_C1_NONE_R0.08_S0_T7.78s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/U4_C1_NONE_R0.08_S0_T7.79s.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/assets/cache-dit.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/bench/.gitignore +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/bench/bench.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/docs/.gitignore +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/examples/data/cup.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/examples/data/cup_mask.png +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/examples/run_flux.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/examples/run_flux_fill.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/examples/run_mochi.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/pyproject.toml +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/pytest.ini +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/requirements.txt +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/setup.cfg +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/__init__.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/__init__.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/cache_context.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/prune_context.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/cache_context.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/taylorseer.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/cache_factory/utils.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/logger.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit/primitives.py +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit.egg-info/dependency_links.txt +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit.egg-info/requires.txt +0 -0
- {cache_dit-0.1.8 → cache_dit-0.2.0}/src/cache_dit.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cache_dit
|
|
3
|
-
Version: 0.1.8
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
|
|
5
5
|
Author: DefTruth, vipshop.com, etc.
|
|
6
6
|
Maintainer: DefTruth, vipshop.com, etc
|
|
@@ -44,31 +44,27 @@ Dynamic: requires-python
|
|
|
44
44
|
<img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
|
|
45
45
|
<img src=https://static.pepy.tech/badge/cache-dit >
|
|
46
46
|
<img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
|
|
47
|
-
<img src=https://img.shields.io/badge/Release-v0.1.8-brightgreen.svg >
|
|
47
|
+
<img src=https://img.shields.io/badge/Release-v0.2.0-brightgreen.svg >
|
|
48
48
|
</div>
|
|
49
49
|
<p align="center">
|
|
50
50
|
DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>offers a set of training-free cache accelerators for DiT: 🔥DBCache, DBPrune, FBCache, etc🔥
|
|
51
51
|
</p>
|
|
52
52
|
<p align="center">
|
|
53
|
-
<
|
|
53
|
+
<h4> 🔥Supported Models🔥</h4>
|
|
54
54
|
<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀FLUX.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
|
|
55
|
-
<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
|
|
56
55
|
<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Mochi</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
|
|
57
|
-
<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀
|
|
58
|
-
<
|
|
56
|
+
<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
|
|
57
|
+
<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX1.5</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
|
|
58
|
+
<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Wan2.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
|
|
59
|
+
<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀HunyuanVideo</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
|
|
59
60
|
</p>
|
|
60
61
|
</div>
|
|
61
62
|
|
|
63
|
+
## 👋 Highlight
|
|
62
64
|
|
|
63
|
-
|
|
64
|
-
## 🎉Supported Models
|
|
65
|
-
<div id="supported"></div>
|
|
66
|
-
- [🚀FLUX.1](https://github.com/vipshop/cache-dit/raw/main/examples): *✔️DBCache, ✔️DBPrune, ✔️FBCache*
|
|
67
|
-
- [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples): *✔️DBCache, ✔️DBPrune, ✔️FBCache*
|
|
68
|
-
- [🚀Mochi](https://github.com/vipshop/cache-dit/raw/main/examples): *✔️DBCache, ✔️DBPrune, ✔️FBCache*
|
|
69
|
-
- [🚀Wan2.1**](https://github.com/vipshop/cache-dit/raw/main/examples): *🔜DBCache, 🔜DBPrune, ✔️FBCache*
|
|
70
|
-
-->
|
|
65
|
+
<div id="reference"></div>
|
|
71
66
|
|
|
67
|
+
The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work! The **FBCache** support for Mochi, FLUX.1, CogVideoX, Wan2.1, and HunyuanVideo is directly adapted from the original [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache).
|
|
72
68
|
|
|
73
69
|
## 🤗 Introduction
|
|
74
70
|
|
|
@@ -110,6 +106,12 @@ These case studies demonstrate that even with relatively high thresholds (such a
|
|
|
110
106
|
|
|
111
107
|
**DBPrune**: We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, referred to as DBPrune. DBPrune caches each block's hidden states and residuals, then **dynamically prunes** blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals.
|
|
112
108
|
|
|
109
|
+
<div align="center">
|
|
110
|
+
<p align="center">
|
|
111
|
+
DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
|
|
112
|
+
</p>
|
|
113
|
+
</div>
|
|
114
|
+
|
|
113
115
|
|Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
|
|
114
116
|
|:---:|:---:|:---:|:---:|:---:|:---:|
|
|
115
117
|
|24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
|
|
@@ -117,11 +119,11 @@ These case studies demonstrate that even with relatively high thresholds (such a
|
|
|
117
119
|
|
|
118
120
|
<div align="center">
|
|
119
121
|
<p align="center">
|
|
120
|
-
|
|
122
|
+
<h3>🔥 Context Parallelism and Torch Compile</h3>
|
|
121
123
|
</p>
|
|
122
|
-
</div>
|
|
124
|
+
</div>
|
|
123
125
|
|
|
124
|
-
**CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference.
|
|
126
|
+
Moreover, **CacheDiT** is a **plug-and-play** solution that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. CacheDiT is also designed to be compatible with **torch.compile**, so you can combine the two for even better performance.
|
|
125
127
|
|
|
126
128
|
<div align="center">
|
|
127
129
|
<p align="center">
|
|
@@ -131,11 +133,16 @@ These case studies demonstrate that even with relatively high thresholds (such a
|
|
|
131
133
|
|
|
132
134
|
|Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
|
|
133
135
|
|:---:|:---:|:---:|:---:|:---:|:---:|
|
|
134
|
-
|+
|
|
135
|
-
|+compile:20.43s|16.25s|14.12s|13.41s|12s|8.86s|
|
|
136
|
+
|+compile:20.43s|16.25s|14.12s|13.41s|12.00s|8.86s|
|
|
136
137
|
|+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
|
|
137
138
|
|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
|
|
138
139
|
|
|
140
|
+
<div align="center">
|
|
141
|
+
<p align="center">
|
|
142
|
+
<b>♥️ Please consider leaving a ⭐️ Star to support us ~ ♥️</b>
|
|
143
|
+
</p>
|
|
144
|
+
</div>
|
|
145
|
+
|
|
139
146
|
## ©️Citations
|
|
140
147
|
|
|
141
148
|
```BibTeX
|
|
@@ -148,12 +155,6 @@ These case studies demonstrate that even with relatively high thresholds (such a
|
|
|
148
155
|
}
|
|
149
156
|
```
|
|
150
157
|
|
|
151
|
-
## 👋Reference
|
|
152
|
-
|
|
153
|
-
<div id="reference"></div>
|
|
154
|
-
|
|
155
|
-
The **CacheDiT** codebase was adapted from FBCache's implementation at the [ParaAttention](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). We would like to express our sincere gratitude for this excellent work!
|
|
156
|
-
|
|
157
158
|
## 📖Contents
|
|
158
159
|
|
|
159
160
|
<div id="contents"></div>
|
|
@@ -396,26 +397,12 @@ Then, run the python test script with `torchrun`:
|
|
|
396
397
|
```bash
|
|
397
398
|
torchrun --nproc_per_node=4 parallel_cache.py
|
|
398
399
|
```
|
|
399
|
-
<!--
|
|
400
|
-
|
|
401
|
-
<div align="center">
|
|
402
|
-
<p align="center">
|
|
403
|
-
DBPrune, <b> L20x4 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
|
|
404
|
-
</p>
|
|
405
|
-
</div>
|
|
406
|
-
|
|
407
|
-
|Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
|
|
408
|
-
|:---:|:---:|:---:|:---:|:---:|:---:|
|
|
409
|
-
|+L20x1:24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
|
|
410
|
-
|+L20x4:8.54s|7.20s|6.61s|6.09s|5.54s|4.22s|
|
|
411
|
-
|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
|
|
412
|
-
-->
|
|
413
400
|
|
|
414
401
|
## 🔥Torch Compile
|
|
415
402
|
|
|
416
403
|
<div id="compile"></div>
|
|
417
404
|
|
|
418
|
-
**CacheDiT**
|
|
405
|
+
**CacheDiT** is designed to be compatible with **torch.compile**, which you can use alongside CacheDiT for even better performance. For example:
|
|
419
406
|
|
|
420
407
|
```python
|
|
421
408
|
apply_cache_on_pipe(
|
|
@@ -430,22 +417,6 @@ torch._dynamo.config.recompile_limit = 96 # default is 8
|
|
|
430
417
|
torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
|
|
431
418
|
```
|
|
432
419
|
|
|
433
|
-
<!--
|
|
434
|
-
|
|
435
|
-
<div align="center">
|
|
436
|
-
<p align="center">
|
|
437
|
-
DBPrune + <b>torch.compile</b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
|
|
438
|
-
</p>
|
|
439
|
-
</div>
|
|
440
|
-
|
|
441
|
-
|Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
|
|
442
|
-
|:---:|:---:|:---:|:---:|:---:|:---:|
|
|
443
|
-
|+L20x1:24.8s|19.4s|16.8s|15.9s|14.2s|10.6s|
|
|
444
|
-
|+compile:20.4s|16.5s|14.1s|13.4s|12s|8.8s|
|
|
445
|
-
|+L20x4:7.7s|6.6s|6.0s|5.8s|5.2s|3.9s|
|
|
446
|
-
|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
|
|
447
|
-
-->
|
|
448
|
-
|
|
449
420
|
## 👋Contribute
|
|
450
421
|
<div id="contribute"></div>
|
|
451
422
|
|
|
@@ -9,31 +9,27 @@
|
|
|
9
9
|
<img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
|
|
10
10
|
<img src=https://static.pepy.tech/badge/cache-dit >
|
|
11
11
|
<img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
|
|
12
|
-
<img src=https://img.shields.io/badge/Release-v0.1.8-brightgreen.svg >
|
|
12
|
+
<img src=https://img.shields.io/badge/Release-v0.2.0-brightgreen.svg >
|
|
13
13
|
</div>
|
|
14
14
|
<p align="center">
|
|
15
15
|
DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>offers a set of training-free cache accelerators for DiT: 🔥DBCache, DBPrune, FBCache, etc🔥
|
|
16
16
|
</p>
|
|
17
17
|
<p align="center">
|
|
18
|
-
<
|
|
18
|
+
<h4> 🔥Supported Models🔥</h4>
|
|
19
19
|
<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀FLUX.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
|
|
20
|
-
<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
|
|
21
20
|
<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Mochi</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
|
|
22
|
-
<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀
|
|
23
|
-
<
|
|
21
|
+
<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
|
|
22
|
+
<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX1.5</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
|
|
23
|
+
<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Wan2.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
|
|
24
|
+
<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀HunyuanVideo</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
|
|
24
25
|
</p>
|
|
25
26
|
</div>
|
|
26
27
|
|
|
28
|
+
## 👋 Highlight
|
|
27
29
|
|
|
28
|
-
|
|
29
|
-
## 🎉Supported Models
|
|
30
|
-
<div id="supported"></div>
|
|
31
|
-
- [🚀FLUX.1](https://github.com/vipshop/cache-dit/raw/main/examples): *✔️DBCache, ✔️DBPrune, ✔️FBCache*
|
|
32
|
-
- [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples): *✔️DBCache, ✔️DBPrune, ✔️FBCache*
|
|
33
|
-
- [🚀Mochi](https://github.com/vipshop/cache-dit/raw/main/examples): *✔️DBCache, ✔️DBPrune, ✔️FBCache*
|
|
34
|
-
- [🚀Wan2.1**](https://github.com/vipshop/cache-dit/raw/main/examples): *🔜DBCache, 🔜DBPrune, ✔️FBCache*
|
|
35
|
-
-->
|
|
30
|
+
<div id="reference"></div>
|
|
36
31
|
|
|
32
|
+
The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work! The **FBCache** support for Mochi, FLUX.1, CogVideoX, Wan2.1, and HunyuanVideo is directly adapted from the original [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache).
|
|
37
33
|
|
|
38
34
|
## 🤗 Introduction
|
|
39
35
|
|
|
@@ -75,6 +71,12 @@ These case studies demonstrate that even with relatively high thresholds (such a
|
|
|
75
71
|
|
|
76
72
|
**DBPrune**: We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, referred to as DBPrune. DBPrune caches each block's hidden states and residuals, then **dynamically prunes** blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals.
|
|
77
73
|
|
|
74
|
+
<div align="center">
|
|
75
|
+
<p align="center">
|
|
76
|
+
DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
|
|
77
|
+
</p>
|
|
78
|
+
</div>
|
|
79
|
+
|
|
78
80
|
|Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
|
|
79
81
|
|:---:|:---:|:---:|:---:|:---:|:---:|
|
|
80
82
|
|24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
|
|
@@ -82,11 +84,11 @@ These case studies demonstrate that even with relatively high thresholds (such a
|
|
|
82
84
|
|
|
83
85
|
<div align="center">
|
|
84
86
|
<p align="center">
|
|
85
|
-
|
|
87
|
+
<h3>🔥 Context Parallelism and Torch Compile</h3>
|
|
86
88
|
</p>
|
|
87
|
-
</div>
|
|
89
|
+
</div>
|
|
88
90
|
|
|
89
|
-
**CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference.
|
|
91
|
+
Moreover, **CacheDiT** is a **plug-and-play** solution that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. CacheDiT is also designed to be compatible with **torch.compile**, so you can combine the two for even better performance.
|
|
90
92
|
|
|
91
93
|
<div align="center">
|
|
92
94
|
<p align="center">
|
|
@@ -96,11 +98,16 @@ These case studies demonstrate that even with relatively high thresholds (such a
|
|
|
96
98
|
|
|
97
99
|
|Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
|
|
98
100
|
|:---:|:---:|:---:|:---:|:---:|:---:|
|
|
99
|
-
|+
|
|
100
|
-
|+compile:20.43s|16.25s|14.12s|13.41s|12s|8.86s|
|
|
101
|
+
|+compile:20.43s|16.25s|14.12s|13.41s|12.00s|8.86s|
|
|
101
102
|
|+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
|
|
102
103
|
|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
|
|
103
104
|
|
|
105
|
+
<div align="center">
|
|
106
|
+
<p align="center">
|
|
107
|
+
<b>♥️ Please consider leaving a ⭐️ Star to support us ~ ♥️</b>
|
|
108
|
+
</p>
|
|
109
|
+
</div>
|
|
110
|
+
|
|
104
111
|
## ©️Citations
|
|
105
112
|
|
|
106
113
|
```BibTeX
|
|
@@ -113,12 +120,6 @@ These case studies demonstrate that even with relatively high thresholds (such a
|
|
|
113
120
|
}
|
|
114
121
|
```
|
|
115
122
|
|
|
116
|
-
## 👋Reference
|
|
117
|
-
|
|
118
|
-
<div id="reference"></div>
|
|
119
|
-
|
|
120
|
-
The **CacheDiT** codebase was adapted from FBCache's implementation at the [ParaAttention](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). We would like to express our sincere gratitude for this excellent work!
|
|
121
|
-
|
|
122
123
|
## 📖Contents
|
|
123
124
|
|
|
124
125
|
<div id="contents"></div>
|
|
@@ -361,26 +362,12 @@ Then, run the python test script with `torchrun`:
|
|
|
361
362
|
```bash
|
|
362
363
|
torchrun --nproc_per_node=4 parallel_cache.py
|
|
363
364
|
```
|
|
364
|
-
<!--
|
|
365
|
-
|
|
366
|
-
<div align="center">
|
|
367
|
-
<p align="center">
|
|
368
|
-
DBPrune, <b> L20x4 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
|
|
369
|
-
</p>
|
|
370
|
-
</div>
|
|
371
|
-
|
|
372
|
-
|Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
|
|
373
|
-
|:---:|:---:|:---:|:---:|:---:|:---:|
|
|
374
|
-
|+L20x1:24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
|
|
375
|
-
|+L20x4:8.54s|7.20s|6.61s|6.09s|5.54s|4.22s|
|
|
376
|
-
|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
|
|
377
|
-
-->
|
|
378
365
|
|
|
379
366
|
## 🔥Torch Compile
|
|
380
367
|
|
|
381
368
|
<div id="compile"></div>
|
|
382
369
|
|
|
383
|
-
**CacheDiT**
|
|
370
|
+
**CacheDiT** is designed to be compatible with **torch.compile**, which you can use alongside CacheDiT for even better performance. For example:
|
|
384
371
|
|
|
385
372
|
```python
|
|
386
373
|
apply_cache_on_pipe(
|
|
@@ -395,22 +382,6 @@ torch._dynamo.config.recompile_limit = 96 # default is 8
|
|
|
395
382
|
torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
|
|
396
383
|
```
|
|
397
384
|
|
|
398
|
-
<!--
|
|
399
|
-
|
|
400
|
-
<div align="center">
|
|
401
|
-
<p align="center">
|
|
402
|
-
DBPrune + <b>torch.compile</b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
|
|
403
|
-
</p>
|
|
404
|
-
</div>
|
|
405
|
-
|
|
406
|
-
|Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
|
|
407
|
-
|:---:|:---:|:---:|:---:|:---:|:---:|
|
|
408
|
-
|+L20x1:24.8s|19.4s|16.8s|15.9s|14.2s|10.6s|
|
|
409
|
-
|+compile:20.4s|16.5s|14.1s|13.4s|12s|8.8s|
|
|
410
|
-
|+L20x4:7.7s|6.6s|6.0s|5.8s|5.2s|3.9s|
|
|
411
|
-
|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
|
|
412
|
-
-->
|
|
413
|
-
|
|
414
385
|
## 👋Contribute
|
|
415
386
|
<div id="contribute"></div>
|
|
416
387
|
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# Examples for CacheDiT
|
|
2
|
+
|
|
3
|
+
## Install requirements
|
|
4
|
+
|
|
5
|
+
```bash
|
|
6
|
+
pip3 install -r requirements.txt
|
|
7
|
+
```
|
|
8
|
+
|
|
9
|
+
## Run examples
|
|
10
|
+
|
|
11
|
+
- FLUX.1-dev
|
|
12
|
+
|
|
13
|
+
```bash
|
|
14
|
+
python3 run_flux.py
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
- FLUX.1-Fill-dev
|
|
18
|
+
|
|
19
|
+
```bash
|
|
20
|
+
python3 run_flux_fill.py
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
- CogVideoX
|
|
24
|
+
|
|
25
|
+
```bash
|
|
26
|
+
python3 run_cogvideox.py
|
|
27
|
+
```
|
|
28
|
+
|
|
29
|
+
- Wan2.1
|
|
30
|
+
|
|
31
|
+
```bash
|
|
32
|
+
python3 run_wan.py
|
|
33
|
+
```
|
|
34
|
+
|
|
35
|
+
- Mochi
|
|
36
|
+
|
|
37
|
+
```bash
|
|
38
|
+
python3 run_mochi.py
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
- HunyuanVideo
|
|
42
|
+
|
|
43
|
+
```bash
|
|
44
|
+
python3 run_hunyuan_video.py
|
|
45
|
+
```
|
|
@@ -1,14 +1,33 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import torch
|
|
3
|
-
from diffusers import CogVideoXPipeline
|
|
4
3
|
from diffusers.utils import export_to_video
|
|
4
|
+
from diffusers import CogVideoXPipeline, AutoencoderKLCogVideoX
|
|
5
5
|
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
|
|
6
6
|
|
|
7
|
+
|
|
8
|
+
model_id = os.environ.get("COGVIDEOX_DIR", "THUDM/CogVideoX-5b")
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def is_cogvideox_1_5():
|
|
12
|
+
return "CogVideoX1.5" in model_id or "THUDM/CogVideoX1.5" in model_id
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_gpu_memory_in_gib():
|
|
16
|
+
if not torch.cuda.is_available():
|
|
17
|
+
return 0
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
total_memory_bytes = torch.cuda.get_device_properties(
|
|
21
|
+
torch.cuda.current_device(),
|
|
22
|
+
).total_memory
|
|
23
|
+
total_memory_gib = total_memory_bytes / (1024**3)
|
|
24
|
+
return int(total_memory_gib)
|
|
25
|
+
except Exception:
|
|
26
|
+
return 0
|
|
27
|
+
|
|
28
|
+
|
|
7
29
|
pipe = CogVideoXPipeline.from_pretrained(
|
|
8
|
-
|
|
9
|
-
"COGVIDEOX_DIR",
|
|
10
|
-
"THUDM/CogVideoX-5b",
|
|
11
|
-
),
|
|
30
|
+
model_id,
|
|
12
31
|
torch_dtype=torch.bfloat16,
|
|
13
32
|
).to("cuda")
|
|
14
33
|
|
|
@@ -17,6 +36,8 @@ cache_options = CacheType.default_options(CacheType.DBCache)
|
|
|
17
36
|
|
|
18
37
|
apply_cache_on_pipe(pipe, **cache_options)
|
|
19
38
|
|
|
39
|
+
pipe.enable_model_cpu_offload()
|
|
40
|
+
assert isinstance(pipe.vae, AutoencoderKLCogVideoX) # enable type check for IDE
|
|
20
41
|
pipe.vae.enable_slicing()
|
|
21
42
|
pipe.vae.enable_tiling()
|
|
22
43
|
|
|
@@ -37,7 +58,12 @@ video = pipe(
|
|
|
37
58
|
prompt=prompt,
|
|
38
59
|
num_videos_per_prompt=1,
|
|
39
60
|
num_inference_steps=50,
|
|
40
|
-
num_frames=
|
|
61
|
+
num_frames=(
|
|
62
|
+
# Avoid OOM for CogVideoX1.5 model on 48GB GPU
|
|
63
|
+
16
|
|
64
|
+
if (is_cogvideox_1_5() and get_gpu_memory_in_gib() < 48)
|
|
65
|
+
else 49
|
|
66
|
+
),
|
|
41
67
|
guidance_scale=6,
|
|
42
68
|
generator=torch.Generator("cuda").manual_seed(0),
|
|
43
69
|
).frames[0]
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/first_block_cache_examples/run_hunyuan_video.py
|
|
2
|
+
import os
|
|
3
|
+
import torch
|
|
4
|
+
from diffusers.utils import export_to_video
|
|
5
|
+
from diffusers import (
|
|
6
|
+
HunyuanVideoPipeline,
|
|
7
|
+
HunyuanVideoTransformer3DModel,
|
|
8
|
+
AutoencoderKLHunyuanVideo,
|
|
9
|
+
)
|
|
10
|
+
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
|
|
11
|
+
|
|
12
|
+
model_id = os.environ.get("HUNYAN_DIR", "tencent/HunyuanVideo")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_gpu_memory_in_gib():
|
|
16
|
+
if not torch.cuda.is_available():
|
|
17
|
+
return 0
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
total_memory_bytes = torch.cuda.get_device_properties(
|
|
21
|
+
torch.cuda.current_device(),
|
|
22
|
+
).total_memory
|
|
23
|
+
total_memory_gib = total_memory_bytes / (1024**3)
|
|
24
|
+
return int(total_memory_gib)
|
|
25
|
+
except Exception:
|
|
26
|
+
return 0
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
|
30
|
+
model_id,
|
|
31
|
+
subfolder="transformer",
|
|
32
|
+
torch_dtype=torch.bfloat16,
|
|
33
|
+
revision="refs/pr/18",
|
|
34
|
+
)
|
|
35
|
+
pipe = HunyuanVideoPipeline.from_pretrained(
|
|
36
|
+
model_id,
|
|
37
|
+
transformer=transformer,
|
|
38
|
+
torch_dtype=torch.float16,
|
|
39
|
+
revision="refs/pr/18",
|
|
40
|
+
).to("cuda")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# Default options, F8B8, good balance between performance and precision
|
|
44
|
+
apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.DBCache))
|
|
45
|
+
|
|
46
|
+
assert isinstance(
|
|
47
|
+
pipe.vae, AutoencoderKLHunyuanVideo
|
|
48
|
+
) # enable type check for IDE
|
|
49
|
+
|
|
50
|
+
# Enable memory savings
|
|
51
|
+
pipe.enable_model_cpu_offload()
|
|
52
|
+
if get_gpu_memory_in_gib() <= 48:
|
|
53
|
+
pipe.vae.enable_tiling(
|
|
54
|
+
# Make it runnable on GPUs with 48GB memory
|
|
55
|
+
tile_sample_min_height=128,
|
|
56
|
+
tile_sample_stride_height=96,
|
|
57
|
+
tile_sample_min_width=128,
|
|
58
|
+
tile_sample_stride_width=96,
|
|
59
|
+
tile_sample_min_num_frames=32,
|
|
60
|
+
tile_sample_stride_num_frames=24,
|
|
61
|
+
)
|
|
62
|
+
else:
|
|
63
|
+
pipe.vae.enable_tiling()
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
output = pipe(
|
|
67
|
+
prompt="A cat walks on the grass, realistic",
|
|
68
|
+
height=720,
|
|
69
|
+
width=1280,
|
|
70
|
+
num_frames=129,
|
|
71
|
+
num_inference_steps=30,
|
|
72
|
+
).frames[0]
|
|
73
|
+
|
|
74
|
+
print("Saving video to hunyuan_video.mp4")
|
|
75
|
+
export_to_video(output, "hunyuan_video.mp4", fps=15)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import torch
|
|
3
|
-
from diffusers import WanPipeline
|
|
3
|
+
from diffusers import WanPipeline, AutoencoderKLWan
|
|
4
4
|
from diffusers.utils import export_to_video
|
|
5
5
|
from diffusers.schedulers.scheduling_unipc_multistep import (
|
|
6
6
|
UniPCMultistepScheduler,
|
|
@@ -27,11 +27,16 @@ if hasattr(pipe, "scheduler") and pipe.scheduler is not None:
|
|
|
27
27
|
|
|
28
28
|
pipe.to("cuda")
|
|
29
29
|
|
|
30
|
-
|
|
30
|
+
# Default options, F8B8, good balance between performance and precision
|
|
31
|
+
apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.DBCache))
|
|
31
32
|
|
|
32
33
|
# Enable memory savings
|
|
33
34
|
pipe.enable_model_cpu_offload()
|
|
34
|
-
|
|
35
|
+
|
|
36
|
+
# Wan currently requires installing diffusers from source
|
|
37
|
+
assert isinstance(pipe.vae, AutoencoderKLWan) # enable type check for IDE
|
|
38
|
+
pipe.vae.enable_tiling()
|
|
39
|
+
pipe.vae.enable_slicing()
|
|
35
40
|
|
|
36
41
|
video = pipe(
|
|
37
42
|
prompt=(
|
|
@@ -39,8 +44,8 @@ video = pipe(
|
|
|
39
44
|
"flying past in the background, hyperrealistic"
|
|
40
45
|
),
|
|
41
46
|
negative_prompt="",
|
|
42
|
-
height=
|
|
43
|
-
width=
|
|
47
|
+
height=height,
|
|
48
|
+
width=width,
|
|
44
49
|
num_frames=81,
|
|
45
50
|
num_inference_steps=30,
|
|
46
51
|
).frames[0]
|
|
@@ -13,6 +13,10 @@ def apply_db_cache_on_transformer(transformer, *args, **kwargs):
|
|
|
13
13
|
adapter_name = "mochi"
|
|
14
14
|
elif transformer_cls_name.startswith("CogVideoX"):
|
|
15
15
|
adapter_name = "cogvideox"
|
|
16
|
+
elif transformer_cls_name.startswith("Wan"):
|
|
17
|
+
adapter_name = "wan"
|
|
18
|
+
elif transformer_cls_name.startswith("HunyuanVideo"):
|
|
19
|
+
adapter_name = "hunyuan_video"
|
|
16
20
|
else:
|
|
17
21
|
raise ValueError(
|
|
18
22
|
f"Unknown transformer class name: {transformer_cls_name}"
|
|
@@ -35,6 +39,10 @@ def apply_db_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
|
|
|
35
39
|
adapter_name = "mochi"
|
|
36
40
|
elif pipe_cls_name.startswith("CogVideoX"):
|
|
37
41
|
adapter_name = "cogvideox"
|
|
42
|
+
elif pipe_cls_name.startswith("Wan"):
|
|
43
|
+
adapter_name = "wan"
|
|
44
|
+
elif pipe_cls_name.startswith("HunyuanVideo"):
|
|
45
|
+
adapter_name = "hunyuan_video"
|
|
38
46
|
else:
|
|
39
47
|
raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
|
|
40
48
|
|