cache-dit 0.1.7__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit might be problematic.
- {cache_dit-0.1.7 → cache_dit-0.2.0}/PKG-INFO +45 -40
- {cache_dit-0.1.7 → cache_dit-0.2.0}/README.md +44 -39
- cache_dit-0.2.0/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png +0 -0
- cache_dit-0.2.0/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png +0 -0
- cache_dit-0.2.0/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png +0 -0
- cache_dit-0.2.0/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png +0 -0
- cache_dit-0.2.0/assets/U0_C1_DBPRUNE_F1B0_R0.05_P41.6_T12.70s.png +0 -0
- cache_dit-0.2.0/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png +0 -0
- cache_dit-0.2.0/assets/U0_C1_DBPRUNE_F8B8_R0.08_P23.1_T16.14s.png +0 -0
- cache_dit-0.2.0/assets/U0_C1_NONE_R0.08_S0_T20.43s.png +0 -0
- cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.03_P27.3_T6.62s.png +0 -0
- cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.03_P27.3_T6.63s.png +0 -0
- cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.045_P38.2_T5.81s.png +0 -0
- cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.045_P38.2_T5.82s.png +0 -0
- cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.06s.png +0 -0
- cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.07s.png +0 -0
- cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.08s.png +0 -0
- cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.055_P45.1_T5.27s.png +0 -0
- cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.055_P45.1_T5.28s.png +0 -0
- cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.2_P59.5_T3.95s.png +0 -0
- cache_dit-0.2.0/assets/U4_C1_DBPRUNE_F1B0_R0.2_P59.5_T3.96s.png +0 -0
- cache_dit-0.2.0/assets/U4_C1_NONE_R0.08_S0_T7.78s.png +0 -0
- cache_dit-0.2.0/assets/U4_C1_NONE_R0.08_S0_T7.79s.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/bench/bench.py +22 -6
- {cache_dit-0.1.7 → cache_dit-0.2.0}/examples/.gitignore +1 -1
- cache_dit-0.2.0/examples/README.md +45 -0
- cache_dit-0.2.0/examples/data/cup.png +0 -0
- cache_dit-0.2.0/examples/data/cup_mask.png +0 -0
- cache_dit-0.2.0/examples/requirements.txt +4 -0
- cache_dit-0.2.0/examples/run_cogvideox.py +72 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/examples/run_flux.py +5 -1
- cache_dit-0.2.0/examples/run_flux_fill.py +32 -0
- cache_dit-0.2.0/examples/run_hunyuan_video.py +75 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/examples/run_mochi.py +9 -2
- cache_dit-0.2.0/examples/run_wan.py +54 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/setup.py +1 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/_version.py +2 -2
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +8 -0
- cache_dit-0.2.0/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +295 -0
- cache_dit-0.2.0/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +99 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +12 -4
- cache_dit-0.2.0/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +295 -0
- cache_dit-0.2.0/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +99 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/prune_context.py +2 -2
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +4 -0
- cache_dit-0.2.0/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +295 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +2 -2
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit.egg-info/PKG-INFO +45 -40
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit.egg-info/SOURCES.txt +33 -0
- cache_dit-0.1.7/examples/run_cogvideox.py +0 -30
- {cache_dit-0.1.7 → cache_dit-0.2.0}/.github/workflows/issue.yml +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/.gitignore +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/.pre-commit-config.yaml +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/CONTRIBUTE.md +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/LICENSE +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/MANIFEST.in +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F12B12S4_R0.2_S16.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F12B16S4_R0.08_S6.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F16B16S2_R0.2_S14.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F16B16S4_R0.2_S13.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F1B0S1_R0.08_S11.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F1B0S1_R0.2_S19.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F8B0S2_R0.12_S12.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F8B16S1_R0.2_S18.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F8B8S1_R0.08_S9.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F8B8S1_R0.12_S12.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCACHE_F8B8S1_R0.15_S15.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBCache.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.07_P52.3_T12.53s.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.08_P52.4_T12.52s.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.09_P59.2_T10.81s.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.12_P59.5_T10.76s.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.12_P63.0_T9.90s.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.1_P62.8_T9.95s.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/DBPRUNE_F1B0_R0.3_P63.1_T9.79s.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/NONE_R0.08_S0.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/assets/cache-dit.png +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/bench/.gitignore +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/docs/.gitignore +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/pyproject.toml +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/pytest.ini +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/requirements.txt +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/setup.cfg +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/__init__.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/__init__.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/cache_context.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/cache_context.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/taylorseer.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/cache_factory/utils.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/logger.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit/primitives.py +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit.egg-info/dependency_links.txt +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit.egg-info/requires.txt +0 -0
- {cache_dit-0.1.7 → cache_dit-0.2.0}/src/cache_dit.egg-info/top_level.txt +0 -0
{cache_dit-0.1.7 → cache_dit-0.2.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.1.7
+Version: 0.2.0
 Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -35,7 +35,7 @@ Dynamic: requires-python

 <div align="center">
 <p align="center">
-<
+<h2>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h2>
 </p>
 <img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit.png >
 <div align='center'>
@@ -44,13 +44,28 @@ Dynamic: requires-python
 <img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
 <img src=https://static.pepy.tech/badge/cache-dit >
 <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
-<img src=https://img.shields.io/badge/Release-v0.
+<img src=https://img.shields.io/badge/Release-v0.2.0-brightgreen.svg >
 </div>
 <p align="center">
 DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>offers a set of training-free cache accelerators for DiT: 🔥DBCache, DBPrune, FBCache, etc🔥
 </p>
+<p align="center">
+<h4> 🔥Supported Models🔥</h4>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀FLUX.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Mochi</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX1.5</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Wan2.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀HunyuanVideo</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+</p>
 </div>

+## 👋 Highlight
+
+<div id="reference"></div>
+
+The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work! The **FBCache** support for Mochi, FLUX.1, CogVideoX, Wan2.1, and HunyuanVideo is directly adapted from the original [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache).
+
 ## 🤗 Introduction

 <div align="center">
@@ -91,6 +106,12 @@ These case studies demonstrate that even with relatively high thresholds (such a

 **DBPrune**: We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, referred to as DBPrune. DBPrune caches each block's hidden states and residuals, then **dynamically prunes** blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals.

+<div align="center">
+<p align="center">
+DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+</p>
+</div>
+
 |Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
 |:---:|:---:|:---:|:---:|:---:|:---:|
 |24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
@@ -98,15 +119,29 @@ These case studies demonstrate that even with relatively high thresholds (such a

 <div align="center">
 <p align="center">
-
+<h3>🔥 Context Parallelism and Torch Compile</h3>
+</p>
+</div>
+
+Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. By the way, CacheDiT is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance.
+
+<div align="center">
+<p align="center">
+DBPrune + <b>torch.compile + context parallelism</b> <br>Steps: 28, "A cat holding a sign that says hello world with complex background"
 </p>
 </div>

-
+|Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|+compile:20.43s|16.25s|14.12s|13.41s|12.00s|8.86s|
+|+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|

-<
-
-
+<div align="center">
+<p align="center">
+<b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
+</p>
+</div>

 ## ©️Citations

@@ -120,12 +155,6 @@ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand w
 }
 ```

-## 👋Reference
-
-<div id="reference"></div>
-
-The **CacheDiT** codebase was adapted from FBCache's implementation at the [ParaAttention](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). We would like to express our sincere gratitude for this excellent work!
-
 ## 📖Contents

 <div id="contents"></div>
@@ -136,11 +165,9 @@ The **CacheDiT** codebase was adapted from FBCache's implementation at the [Para
 - [⚡️Dynamic Block Prune](#dbprune)
 - [🎉Context Parallelism](#context-parallelism)
 - [🔥Torch Compile](#compile)
-- [🎉Supported Models](#supported)
 - [👋Contribute](#contribute)
 - [©️License](#license)

-
 ## ⚙️Installation

 <div id="installation"></div>
@@ -371,23 +398,11 @@ Then, run the python test script with `torchrun`:
 torchrun --nproc_per_node=4 parallel_cache.py
 ```

-<div align="center">
-<p align="center">
-DBPrune, <b> L20x4 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
-</p>
-</div>
-
-|Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
-|:---:|:---:|:---:|:---:|:---:|:---:|
-|24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
-|8.54s (L20x4)|7.20s (L20x4)|6.61s (L20x4)|6.09s (L20x4)|5.54s (L20x4)|4.22s (L20x4)|
-|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
-
 ## 🔥Torch Compile

 <div id="compile"></div>

-**CacheDiT**
+By the way, **CacheDiT** is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance. For example:

 ```python
 apply_cache_on_pipe(
@@ -396,21 +411,11 @@ apply_cache_on_pipe(
 # Compile the Transformer module
 pipe.transformer = torch.compile(pipe.transformer)
 ```
-However, users intending to use **CacheDiT** for DiT with **dynamic input shapes** should consider increasing the **recompile** **limit** of `torch._dynamo
-
+However, users intending to use **CacheDiT** for DiT with **dynamic input shapes** should consider increasing the **recompile** **limit** of `torch._dynamo`. Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
 ```python
 torch._dynamo.config.recompile_limit = 96 # default is 8
 torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
 ```
-Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
-
-## 🎉Supported Models
-
-<div id="supported"></div>
-
-- [🚀FLUX.1](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
-- [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
-- [🚀Mochi](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)

 ## 👋Contribute
 <div id="contribute"></div>
{cache_dit-0.1.7 → cache_dit-0.2.0}/README.md

@@ -1,6 +1,6 @@
 <div align="center">
 <p align="center">
-<
+<h2>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h2>
 </p>
 <img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit.png >
 <div align='center'>
@@ -9,13 +9,28 @@
 <img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
 <img src=https://static.pepy.tech/badge/cache-dit >
 <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
-<img src=https://img.shields.io/badge/Release-v0.
+<img src=https://img.shields.io/badge/Release-v0.2.0-brightgreen.svg >
 </div>
 <p align="center">
 DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>offers a set of training-free cache accelerators for DiT: 🔥DBCache, DBPrune, FBCache, etc🔥
 </p>
+<p align="center">
+<h4> 🔥Supported Models🔥</h4>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀FLUX.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Mochi</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX1.5</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Wan2.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀HunyuanVideo</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+</p>
 </div>

+## 👋 Highlight
+
+<div id="reference"></div>
+
+The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work! The **FBCache** support for Mochi, FLUX.1, CogVideoX, Wan2.1, and HunyuanVideo is directly adapted from the original [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache).
+
 ## 🤗 Introduction

 <div align="center">
@@ -56,6 +71,12 @@ These case studies demonstrate that even with relatively high thresholds (such a

 **DBPrune**: We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, referred to as DBPrune. DBPrune caches each block's hidden states and residuals, then **dynamically prunes** blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals.

+<div align="center">
+<p align="center">
+DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+</p>
+</div>
+
 |Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
 |:---:|:---:|:---:|:---:|:---:|:---:|
 |24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
@@ -63,15 +84,29 @@ These case studies demonstrate that even with relatively high thresholds (such a

 <div align="center">
 <p align="center">
-
+<h3>🔥 Context Parallelism and Torch Compile</h3>
+</p>
+</div>
+
+Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. By the way, CacheDiT is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance.
+
+<div align="center">
+<p align="center">
+DBPrune + <b>torch.compile + context parallelism</b> <br>Steps: 28, "A cat holding a sign that says hello world with complex background"
 </p>
 </div>

-
+|Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|+compile:20.43s|16.25s|14.12s|13.41s|12.00s|8.86s|
+|+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|

-<
-
-
+<div align="center">
+<p align="center">
+<b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
+</p>
+</div>

 ## ©️Citations

@@ -85,12 +120,6 @@ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand w
 }
 ```

-## 👋Reference
-
-<div id="reference"></div>
-
-The **CacheDiT** codebase was adapted from FBCache's implementation at the [ParaAttention](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). We would like to express our sincere gratitude for this excellent work!
-
 ## 📖Contents

 <div id="contents"></div>
@@ -101,11 +130,9 @@ The **CacheDiT** codebase was adapted from FBCache's implementation at the [Para
 - [⚡️Dynamic Block Prune](#dbprune)
 - [🎉Context Parallelism](#context-parallelism)
 - [🔥Torch Compile](#compile)
-- [🎉Supported Models](#supported)
 - [👋Contribute](#contribute)
 - [©️License](#license)

-
 ## ⚙️Installation

 <div id="installation"></div>
@@ -336,23 +363,11 @@ Then, run the python test script with `torchrun`:
 torchrun --nproc_per_node=4 parallel_cache.py
 ```

-<div align="center">
-<p align="center">
-DBPrune, <b> L20x4 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
-</p>
-</div>
-
-|Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
-|:---:|:---:|:---:|:---:|:---:|:---:|
-|24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
-|8.54s (L20x4)|7.20s (L20x4)|6.61s (L20x4)|6.09s (L20x4)|5.54s (L20x4)|4.22s (L20x4)|
-|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
-
 ## 🔥Torch Compile

 <div id="compile"></div>

-**CacheDiT**
+By the way, **CacheDiT** is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance. For example:

 ```python
 apply_cache_on_pipe(
@@ -361,21 +376,11 @@ apply_cache_on_pipe(
 # Compile the Transformer module
 pipe.transformer = torch.compile(pipe.transformer)
 ```
-However, users intending to use **CacheDiT** for DiT with **dynamic input shapes** should consider increasing the **recompile** **limit** of `torch._dynamo
-
+However, users intending to use **CacheDiT** for DiT with **dynamic input shapes** should consider increasing the **recompile** **limit** of `torch._dynamo`. Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
 ```python
 torch._dynamo.config.recompile_limit = 96 # default is 8
 torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
 ```
-Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
-
-## 🎉Supported Models
-
-<div id="supported"></div>
-
-- [🚀FLUX.1](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
-- [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)
-- [🚀Mochi](https://github.com/vipshop/cache-dit/raw/main/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters)

 ## 👋Contribute
 <div id="contribute"></div>
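The README changes above describe DBCache, DBPrune, and running CacheDiT together with torch.compile after raising torch._dynamo's recompile limits. A minimal sketch that puts those pieces together, assembled from snippets shown in this diff (the model id and default options follow examples/run_flux.py; treat it as illustrative rather than the package's documented surface):

```python
import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Default DBCache options (F8B8), as used in the example scripts in this release.
apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.DBCache))

# Raise Dynamo's recompile limits before compiling, as the README recommends
# for DiT workloads with dynamic input shapes.
torch._dynamo.config.recompile_limit = 96  # default is 8
torch._dynamo.config.accumulated_recompile_limit = 2048  # default is 256
pipe.transformer = torch.compile(pipe.transformer)

image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=28,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("flux.dbcache.png")
```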
cache_dit-0.2.0/assets/*.png (the 21 new images listed above): binary files, contents not shown.
{cache_dit-0.1.7 → cache_dit-0.2.0}/bench/bench.py

@@ -3,7 +3,7 @@ import argparse
 import torch
 import time

-from diffusers import FluxPipeline
+from diffusers import FluxPipeline, FluxTransformer2DModel
 from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
 from cache_dit.logger import init_logger

@@ -110,6 +110,7 @@ def get_cache_options(cache_type: CacheType, args: argparse.Namespace):
     return cache_options, cache_type_str


+@torch.no_grad()
 def main():
     args = get_args()
     logger.info(f"Arguments: {args}")
@@ -119,7 +120,9 @@ def main():
     try:
         import torch.distributed as dist
         from para_attn.context_parallel import init_context_parallel_mesh
-        from para_attn.context_parallel.diffusers_adapters import
+        from para_attn.context_parallel.diffusers_adapters import (
+            parallelize_pipe,
+        )

         # Initialize distributed process group
         dist.init_process_group()
@@ -133,9 +136,10 @@ def main():
         ).to("cuda")

         parallelize_pipe(
-            pipe,
+            pipe,
+            mesh=init_context_parallel_mesh(
                 pipe.device.type, max_ulysses_dim_size=args.ulysses
-            )
+            ),
         )
     except ImportError as e:
         logger.error(
@@ -148,7 +152,7 @@ def main():
         pipe = FluxPipeline.from_pretrained(
             os.environ.get("FLUX_DIR", "black-forest-labs/FLUX.1-dev"),
             torch_dtype=torch.bfloat16,
-        ).to("cuda")
+        ).to("cuda")

     cache_options, cache_type = get_cache_options(args.cache, args)

@@ -165,7 +169,18 @@ def main():
         torch._dynamo.config.accumulated_recompile_limit = (
             2048  # default is 256
         )
-
+        if isinstance(pipe.transformer, FluxTransformer2DModel):
+            logger.warning(
+                "Only compile transformer blocks not the whole model "
+                "for FluxTransformer2DModel to keep higher precision."
+            )
+            for module in pipe.transformer.transformer_blocks:
+                module.compile()
+            for module in pipe.transformer.single_transformer_blocks:
+                module.compile()
+        else:
+            logger.info("Compiling the transformer with default mode.")
+            pipe.transformer = torch.compile(pipe.transformer, mode="default")

     all_times = []
     cached_stepes = 0
@@ -238,6 +253,7 @@ def main():

     if args.ulysses is not None:
         import torch.distributed as dist
+
         dist.destroy_process_group()
         logger.info("Distributed process group destroyed.")

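The bench.py hunk above stops wrapping the whole FluxTransformer2DModel in torch.compile and instead compiles its double-stream and single-stream blocks one by one, which the added warning ties to keeping higher precision. A standalone sketch of the same pattern; compile_transformer_blockwise is a hypothetical helper name, not part of the package:

```python
import torch
from diffusers import FluxTransformer2DModel


def compile_transformer_blockwise(transformer: torch.nn.Module) -> torch.nn.Module:
    # For FLUX, compile each transformer block in place (Module.compile), as the
    # benchmark script above does; otherwise compile the whole module.
    if isinstance(transformer, FluxTransformer2DModel):
        for block in transformer.transformer_blocks:
            block.compile()
        for block in transformer.single_transformer_blocks:
            block.compile()
        return transformer
    return torch.compile(transformer, mode="default")


# Usage, with a pipeline built and cached as in bench.py:
# pipe.transformer = compile_transformer_blockwise(pipe.transformer)
```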
cache_dit-0.2.0/examples/README.md (new file)

@@ -0,0 +1,45 @@
+# Examples for CacheDiT
+
+## Install requirements
+
+```bash
+pip3 install -r requirements.txt
+```
+
+## Run examples
+
+- FLUX.1-dev
+
+```bash
+python3 run_flux.py
+```
+
+- FLUX.1-Fill-dev
+
+```bash
+python3 run_flux_fill.py
+```
+
+- CogVideoX
+
+```bash
+python3 run_cogvideox.py
+```
+
+- Wan2.1
+
+```bash
+python3 run_wan.py
+```
+
+- Mochi
+
+```bash
+python3 run_mochi.py
+```
+
+- HunyuanVideo
+
+```bash
+python3 run_hunyuan_video.py
+```
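The new example scripts resolve their checkpoints from environment variables with a Hub id fallback: FLUX_DIR (run_flux.py), FLUX_FILL_DIR (run_flux_fill.py), and COGVIDEOX_DIR (run_cogvideox.py) all appear in this diff. A small sketch of pointing run_flux.py at a local copy before launching it; the local path is a placeholder:

```python
import os
import subprocess

# Placeholder local path; run_flux.py falls back to "black-forest-labs/FLUX.1-dev"
# when FLUX_DIR is unset.
env = dict(os.environ, FLUX_DIR="/models/FLUX.1-dev")
subprocess.run(["python3", "run_flux.py"], env=env, check=True)
```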
cache_dit-0.2.0/examples/data/cup.png and cup_mask.png (new): binary files, contents not shown.
cache_dit-0.2.0/examples/run_cogvideox.py (new file)

@@ -0,0 +1,72 @@
+import os
+import torch
+from diffusers.utils import export_to_video
+from diffusers import CogVideoXPipeline, AutoencoderKLCogVideoX
+from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
+
+
+model_id = os.environ.get("COGVIDEOX_DIR", "THUDM/CogVideoX-5b")
+
+
+def is_cogvideox_1_5():
+    return "CogVideoX1.5" in model_id or "THUDM/CogVideoX1.5" in model_id
+
+
+def get_gpu_memory_in_gib():
+    if not torch.cuda.is_available():
+        return 0
+
+    try:
+        total_memory_bytes = torch.cuda.get_device_properties(
+            torch.cuda.current_device(),
+        ).total_memory
+        total_memory_gib = total_memory_bytes / (1024**3)
+        return int(total_memory_gib)
+    except Exception:
+        return 0
+
+
+pipe = CogVideoXPipeline.from_pretrained(
+    model_id,
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# Default options, F8B8, good balance between performance and precision
+cache_options = CacheType.default_options(CacheType.DBCache)
+
+apply_cache_on_pipe(pipe, **cache_options)
+
+pipe.enable_model_cpu_offload()
+assert isinstance(pipe.vae, AutoencoderKLCogVideoX)  # enable type check for IDE
+pipe.vae.enable_slicing()
+pipe.vae.enable_tiling()
+
+prompt = (
+    "A panda, dressed in a small, red jacket and a tiny hat, "
+    "sits on a wooden stool in a serene bamboo forest. The "
+    "panda's fluffy paws strum a miniature acoustic guitar, "
+    "producing soft, melodic tunes. Nearby, a few other pandas "
+    "gather, watching curiously and some clapping in rhythm. "
+    "Sunlight filters through the tall bamboo, casting a gentle "
+    "glow on the scene. The panda's face is expressive, showing "
+    "concentration and joy as it plays. The background includes "
+    "a small, flowing stream and vibrant green foliage, enhancing "
+    "the peaceful and magical atmosphere of this unique musical "
+    "performance."
+)
+video = pipe(
+    prompt=prompt,
+    num_videos_per_prompt=1,
+    num_inference_steps=50,
+    num_frames=(
+        # Avoid OOM for CogVideoX1.5 model on 48GB GPU
+        16
+        if (is_cogvideox_1_5() and get_gpu_memory_in_gib() < 48)
+        else 49
+    ),
+    guidance_scale=6,
+    generator=torch.Generator("cuda").manual_seed(0),
+).frames[0]
+
+print("Saving video to cogvideox.mp4")
+export_to_video(video, "cogvideox.mp4", fps=8)
{cache_dit-0.1.7 → cache_dit-0.2.0}/examples/run_flux.py

@@ -1,9 +1,13 @@
+import os
 import torch
 from diffusers import FluxPipeline
 from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

 pipe = FluxPipeline.from_pretrained(
-
+    os.environ.get(
+        "FLUX_DIR",
+        "black-forest-labs/FLUX.1-dev",
+    ),
     torch_dtype=torch.bfloat16,
 ).to("cuda")

cache_dit-0.2.0/examples/run_flux_fill.py (new file)

@@ -0,0 +1,32 @@
+import os
+import torch
+from diffusers import FluxFillPipeline
+from diffusers.utils import load_image
+from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
+
+pipe = FluxFillPipeline.from_pretrained(
+    os.environ.get(
+        "FLUX_FILL_DIR",
+        "black-forest-labs/FLUX.1-Fill-dev",
+    ),
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+
+# Default options, F8B8, good balance between performance and precision
+cache_options = CacheType.default_options(CacheType.DBCache)
+
+apply_cache_on_pipe(pipe, **cache_options)
+
+image = pipe(
+    prompt="a white paper cup",
+    image=load_image("data/cup.png"),
+    mask_image=load_image("data/cup_mask.png"),
+    guidance_scale=30,
+    num_inference_steps=28,
+    max_sequence_length=512,
+    generator=torch.Generator("cuda").manual_seed(0),
+).images[0]
+
+print("Saving image to flux-fill.png")
+image.save("flux-fill.png")
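run_flux_fill.py above uses the default DBCache options. Given the dynamic_block_prune adapters added elsewhere in this release, the same pipeline can presumably be switched to DBPrune by changing only the options passed to apply_cache_on_pipe. A hedged sketch; the CacheType.DBPrune member is assumed by analogy with CacheType.DBCache and is not used directly anywhere in this diff:

```python
import os
import torch
from diffusers import FluxFillPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = FluxFillPipeline.from_pretrained(
    os.environ.get("FLUX_FILL_DIR", "black-forest-labs/FLUX.1-Fill-dev"),
    torch_dtype=torch.bfloat16,
).to("cuda")

# Assumption: CacheType exposes a DBPrune member mirroring CacheType.DBCache;
# verify against the installed package before relying on it.
apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.DBPrune))
```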