cache-dit 0.1.3__tar.gz → 0.1.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- {cache_dit-0.1.3 → cache_dit-0.1.6}/PKG-INFO +96 -21
- {cache_dit-0.1.3 → cache_dit-0.1.6}/README.md +92 -17
- {cache_dit-0.1.3 → cache_dit-0.1.6}/bench/bench.py +50 -13
- cache_dit-0.1.6/examples/run_cogvideox.py +30 -0
- cache_dit-0.1.6/examples/run_mochi.py +25 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/requirements.txt +3 -3
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/_version.py +2 -2
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dynamic_block_prune/prune_context.py +11 -3
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit.egg-info/PKG-INFO +96 -21
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit.egg-info/SOURCES.txt +2 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit.egg-info/requires.txt +3 -3
- {cache_dit-0.1.3 → cache_dit-0.1.6}/.github/workflows/issue.yml +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/.gitignore +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/.pre-commit-config.yaml +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/CONTRIBUTE.md +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/LICENSE +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/MANIFEST.in +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F12B12S4_R0.2_S16.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F12B16S4_R0.08_S6.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F16B16S2_R0.2_S14.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F16B16S4_R0.2_S13.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F1B0S1_R0.08_S11.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F1B0S1_R0.2_S19.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F8B0S2_R0.12_S12.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F8B16S1_R0.2_S18.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F8B8S1_R0.08_S9.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F8B8S1_R0.12_S12.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCACHE_F8B8S1_R0.15_S15.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBCache.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.07_P52.3_T12.53s.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.08_P52.4_T12.52s.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.09_P59.2_T10.81s.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.12_P59.5_T10.76s.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.12_P63.0_T9.90s.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.1_P62.8_T9.95s.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/DBPRUNE_F1B0_R0.3_P63.1_T9.79s.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/NONE_R0.08_S0.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/assets/cache-dit.png +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/bench/.gitignore +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/docs/.gitignore +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/examples/.gitignore +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/examples/run_flux.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/pyproject.toml +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/pytest.ini +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/setup.cfg +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/setup.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/__init__.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/__init__.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dual_block_cache/cache_context.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/first_block_cache/cache_context.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/taylorseer.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/utils.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/logger.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/primitives.py +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit.egg-info/dependency_links.txt +0 -0
- {cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit.egg-info/top_level.txt +0 -0
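The README.md changes in the diff below document DBCache's tunable parameters (Fn, Bn, warmup_steps, max_cached_steps, residual_diff_threshold) and give a full custom-options example for DBPrune, but the corresponding DBCache options dict falls outside the changed hunks and is not captured in this diff. The snippet below is a minimal, hypothetical sketch of such a DBCache configuration applied to FLUX.1-dev; the dictionary keys are assumptions inferred by analogy with the DBPrune options shown in the diff, not the verbatim cache-dit 0.1.6 API.

```python
import torch
from diffusers import FluxPipeline

from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Hypothetical custom DBCache options. Key names follow the DBPrune
# example in the README diff below and may differ from the real API.
cache_options = {
    "cache_type": CacheType.DBCache,
    # F8B8: always compute the first 8 and last 8 transformer blocks.
    "Fn_compute_blocks": 8,
    "Bn_compute_blocks": 8,
    # Do not cache during the first 8 denoising steps.
    "warmup_steps": 8,
    # -1 means there is no limit on the number of cached steps.
    "max_cached_steps": -1,
    # A higher threshold means more caching: faster, but less precise.
    "residual_diff_threshold": 0.12,
}

apply_cache_on_pipe(pipe, **cache_options)

image = pipe(
    "A cat holding a sign that says hello world with complex background",
    num_inference_steps=28,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("dbcache_f8b8.png")
```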
{cache_dit-0.1.3 → cache_dit-0.1.6}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.1.3
+Version: 0.1.6
 Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc

@@ -10,9 +10,9 @@ Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: packaging
-Requires-Dist: torch
-Requires-Dist: transformers
-Requires-Dist: diffusers
+Requires-Dist: torch>=2.5.1
+Requires-Dist: transformers>=4.51.3
+Requires-Dist: diffusers>=0.33.1
 Provides-Extra: all
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"

The remaining PKG-INFO hunks update the embedded long description and match the README.md changes shown below, at PKG-INFO line offsets.

{cache_dit-0.1.3 → cache_dit-0.1.6}/README.md
RENAMED
@@ -9,10 +9,10 @@
   <img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
   <img src=https://static.pepy.tech/badge/cache-dit >
   <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
-  <img src=https://img.shields.io/badge/Release-v0.1.
+  <img src=https://img.shields.io/badge/Release-v0.1.6-brightgreen.svg >
 </div>
 <p align="center">
-  DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT
+  DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>offers a set of training-free cache accelerators for DiT: 🔥DBCache, DBPrune, FBCache, etc🔥
 </p>
 </div>

@@ -20,7 +20,7 @@

 <div align="center">
 <p align="center">
-  <h3
+  <h3>🔥 DBCache: Dual Block Caching for Diffusion Transformers</h3>
 </p>
 </div>

@@ -42,7 +42,7 @@

 <div align="center">
 <p align="center">
-
+  DBCache, <b> L20x4 </b>, Steps: 20, case to show the texture recovery ability of DBCache
 </p>
 </div>

@@ -50,7 +50,7 @@ These case studies demonstrate that even with relatively high thresholds (such a

 <div align="center">
 <p align="center">
-  <h3
+  <h3>🔥 DBPrune: Dynamic Block Prune with Residual Caching</h3>
 </p>
 </div>

@@ -67,10 +67,10 @@ These case studies demonstrate that even with relatively high thresholds (such a
 </p>
 </div>

-Moreover,
+Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference.

 <p align="center">
-
+  ♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️
 </p>

 ## ©️Citations

@@ -100,7 +100,7 @@ The **CacheDiT** codebase was adapted from FBCache's implementation at the [Para
 - [🎉First Block Cache](#fbcache)
 - [⚡️Dynamic Block Prune](#dbprune)
 - [🎉Context Parallelism](#context-parallelism)
-- [
+- [🔥Torch Compile](#compile)
 - [🎉Supported Models](#supported)
 - [👋Contribute](#contribute)
 - [©️License](#license)

@@ -132,7 +132,7 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
 - **Fn**: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
 - **Bn**: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
 - **warmup_steps**: (default: 0) DBCache does not apply the caching strategy when the number of running steps is less than or equal to this value, ensuring the model sufficiently learns basic features during warmup.
-- **max_cached_steps**: (default: -1) DBCache disables the caching strategy when the
+- **max_cached_steps**: (default: -1) DBCache disables the caching strategy when the previous cached steps exceed this value to prevent precision degradation.
 - **residual_diff_threshold**: The value of residual diff threshold, a higher value leads to faster performance at the cost of lower precision.

 For a good balance between performance and precision, DBCache is configured by default with **F8B8**, 8 warmup steps, and unlimited cached steps.

@@ -175,6 +175,17 @@ cache_options = {
 }
 ```

+<div align="center">
+<p align="center">
+DBCache, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+</p>
+</div>
+
+|Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
+
 ## 🎉FBCache: First Block Cache

 <div id="fbcache"></div>

@@ -226,12 +237,47 @@ pipe = FluxPipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
 ).to("cuda")

-# Using DBPrune
+# Using DBPrune with default options
 cache_options = CacheType.default_options(CacheType.DBPrune)

 apply_cache_on_pipe(pipe, **cache_options)
 ```

+We have also brought the designs from DBCache to DBPrune to make it a more general and customizable block prune algorithm. You can specify the values of **Fn** and **Bn** for higher precision, or set up the non-prune blocks list **non_prune_blocks_ids** to avoid aggressive pruning. For example:
+
+```python
+# Custom options for DBPrune
+cache_options = {
+    "cache_type": CacheType.DBPrune,
+    "residual_diff_threshold": 0.05,
+    # Never prune the first `Fn` and last `Bn` blocks.
+    "Fn_compute_blocks": 8,  # default 1
+    "Bn_compute_blocks": 8,  # default 0
+    "warmup_steps": 8,  # default -1
+    # Disables the pruning strategy when the previous
+    # pruned steps greater than this value.
+    "max_pruned_steps": 12,  # default, -1 means no limit
+    # Enable dynamic prune threshold within step, higher
+    # `max_dynamic_prune_threshold` value may introduce a more
+    # ageressive pruning strategy.
+    "enable_dynamic_prune_threshold": True,
+    "max_dynamic_prune_threshold": 2 * 0.05,
+    # (New thresh) = mean(previous_block_diffs_within_step) * 1.25
+    # (New thresh) = ((New thresh) if (New thresh) <
+    # max_dynamic_prune_threshold else residual_diff_threshold)
+    "dynamic_prune_threshold_relax_ratio": 1.25,
+    # The step interval to update residual cache. For example,
+    # 2: means the update steps will be [0, 2, 4, ...].
+    "residual_cache_update_interval": 1,
+    # You can set non-prune blocks to avoid ageressive pruning.
+    # For example, FLUX.1 has 19 + 38 blocks, so we can set it
+    # to 0, 2, 4, ..., 56, etc.
+    "non_prune_blocks_ids": [],
+}
+
+apply_cache_on_pipe(pipe, **cache_options)
+```
+
 <div align="center">
 <p align="center">
 DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"

@@ -253,14 +299,19 @@ apply_cache_on_pipe(pipe, **cache_options)
 pip3 install para-attn # or install `para-attn` from sources.
 ```

-Then, you can run **DBCache** with **Context Parallelism** on 4 GPUs:
+Then, you can run **DBCache** or **DBPrune** with **Context Parallelism** on 4 GPUs:

 ```python
+import torch.distributed as dist
 from diffusers import FluxPipeline
 from para_attn.context_parallel import init_context_parallel_mesh
 from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
 from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

+# Init distributed process group
+dist.init_process_group()
+torch.cuda.set_device(dist.get_rank())
+
 pipe = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     torch_dtype=torch.bfloat16,

@@ -273,13 +324,31 @@ parallelize_pipe(
     )
 )

-#
+# DBPrune with default options from this library
 apply_cache_on_pipe(
-    pipe, **CacheType.default_options(CacheType.
+    pipe, **CacheType.default_options(CacheType.DBPrune)
 )
+
+dist.destroy_process_group()
+```
+Then, run the python test script with `torchrun`:
+```bash
+torchrun --nproc_per_node=4 parallel_cache.py
 ```

-
+<div align="center">
+<p align="center">
+DBPrune, <b> L20x4 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+</p>
+</div>
+
+|Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
+|8.54s (L20x4)|7.20s (L20x4)|6.61s (L20x4)|6.09s (L20x4)|5.54s (L20x4)|4.22s (L20x4)|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
+
 ## 🔥Torch Compile

 <div id="compile"></div>

@@ -287,7 +356,7 @@ apply_cache_on_pipe(

 ```python
 apply_cache_on_pipe(
-    pipe, **CacheType.default_options(CacheType.
+    pipe, **CacheType.default_options(CacheType.DBPrune)
 )
 # Compile the Transformer module
 pipe.transformer = torch.compile(pipe.transformer)

@@ -298,7 +367,13 @@ However, users intending to use **CacheDiT** for DiT with **dynamic input shapes
 torch._dynamo.config.recompile_limit = 96  # default is 8
 torch._dynamo.config.accumulated_recompile_limit = 2048  # default is 256
 ```
-Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
+Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode. Here is the case of **DBPrune + torch.compile**.
+
+<div align="center">
+<p align="center">
+DBPrune + compile, Steps: 28, "A cat holding a sign that says hello world with complex background"
+</p>
+</div>

 ## 🎉Supported Models

@@ -311,7 +386,7 @@ Otherwise, the recompile_limit error may be triggered, causing the module to fal
 ## 👋Contribute
 <div id="contribute"></div>

-How to contribute? Star this repo or check [CONTRIBUTE.md](./CONTRIBUTE.md).
+How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](./CONTRIBUTE.md).

 ## ©️License

{cache_dit-0.1.3 → cache_dit-0.1.6}/bench/bench.py
RENAMED
@@ -16,6 +16,7 @@ def get_args() -> argparse.ArgumentParser:
     # General arguments
     parser.add_argument("--steps", type=int, default=28)
     parser.add_argument("--repeats", type=int, default=2)
+    parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--cache", type=str, default=None)
     parser.add_argument("--alter", action="store_true", default=False)
     parser.add_argument("--l1-diff", action="store_true", default=False)

@@ -26,12 +27,9 @@ def get_args() -> argparse.ArgumentParser:
     parser.add_argument("--warmup-steps", type=int, default=0)
     parser.add_argument("--max-cached-steps", type=int, default=-1)
     parser.add_argument("--max-pruned-steps", type=int, default=-1)
-    parser.add_argument("--
-    parser.add_argument(
-
-        action="store_true",
-        default=False,
-    )
+    parser.add_argument("--ulysses", type=int, default=None)
+    parser.add_argument("--compile", action="store_true", default=False)
+    parser.add_argument("--gen-device", type=str, default="cuda")
     return parser.parse_args()

@@ -116,10 +114,41 @@ def main():
     args = get_args()
     logger.info(f"Arguments: {args}")

-
-
-
-
+    # Context Parallel from ParaAttention
+    if args.ulysses is not None:
+        try:
+            import torch.distributed as dist
+            from para_attn.context_parallel import init_context_parallel_mesh
+            from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
+
+            # Initialize distributed process group
+            dist.init_process_group()
+            torch.cuda.set_device(dist.get_rank())
+
+            logger.info(f"Ulysses: {args.ulysses}")
+
+            pipe = FluxPipeline.from_pretrained(
+                os.environ.get("FLUX_DIR", "black-forest-labs/FLUX.1-dev"),
+                torch_dtype=torch.bfloat16,
+            ).to("cuda")
+
+            parallelize_pipe(
+                pipe, mesh=init_context_parallel_mesh(
+                    pipe.device.type, max_ulysses_dim_size=args.ulysses
+                )
+            )
+        except ImportError as e:
+            logger.error(
+                "para-attn is not installed, please install it "
+                "with `pip install para-attn.`"
+            )
+            args.ulysses = None
+            raise e
+    else:
+        pipe = FluxPipeline.from_pretrained(
+            os.environ.get("FLUX_DIR", "black-forest-labs/FLUX.1-dev"),
+            torch_dtype=torch.bfloat16,
+        ).to("cuda")

     cache_options, cache_type = get_cache_options(args.cache, args)

@@ -149,7 +178,7 @@ def main():
         image = pipe(
             "A cat holding a sign that says hello world with complex background",
             num_inference_steps=args.steps,
-            generator=torch.Generator(
+            generator=torch.Generator(args.gen_device).manual_seed(args.seed),
         ).images[0]
         end = time.time()
         all_times.append(end - start)

@@ -191,19 +220,27 @@ def main():
         f"Actual Blocks: {actual_blocks}\n"
         f"Pruned Blocks: {pruned_blocks}"
     )
+    ulysses = 0 if args.ulysses is None else args.ulysses
     if len(actual_blocks) > 0:
         save_name = (
-            f"{
+            f"U{ulysses}_C{int(args.compile)}_{cache_type}_"
+            f"R{args.rdt}_P{pruned_ratio:.1f}_"
             f"T{mean_time:.2f}s.png"
         )
     else:
         save_name = (
-            f"{
+            f"U{ulysses}_C{int(args.compile)}_{cache_type}_"
+            f"R{args.rdt}_S{cached_stepes}_"
             f"T{mean_time:.2f}s.png"
         )
     image.save(save_name)
     logger.info(f"Image saved as {save_name}")

+    if args.ulysses is not None:
+        import torch.distributed as dist
+        dist.destroy_process_group()
+        logger.info("Distributed process group destroyed.")
+

 if __name__ == "__main__":
     main()

cache_dit-0.1.6/examples/run_cogvideox.py
ADDED
@@ -0,0 +1,30 @@
+import torch
+from diffusers import CogVideoXPipeline
+from diffusers.utils import export_to_video
+from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
+
+pipe = CogVideoXPipeline.from_pretrained(
+    "THUDM/CogVideoX-5b",
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# Default options, F8B8, good balance between performance and precision
+cache_options = CacheType.default_options(CacheType.DBCache)
+
+apply_cache_on_pipe(pipe, **cache_options)
+
+pipe.vae.enable_slicing()
+pipe.vae.enable_tiling()
+
+prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
+video = pipe(
+    prompt=prompt,
+    num_videos_per_prompt=1,
+    num_inference_steps=50,
+    num_frames=49,
+    guidance_scale=6,
+    generator=torch.Generator("cuda").manual_seed(0),
+).frames[0]
+
+print("Saving video to cogvideox.mp4")
+export_to_video(video, "cogvideox.mp4", fps=8)

cache_dit-0.1.6/examples/run_mochi.py
ADDED
@@ -0,0 +1,25 @@
+import torch
+from diffusers import MochiPipeline
+from diffusers.utils import export_to_video
+from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
+
+pipe = MochiPipeline.from_pretrained(
+    "genmo/mochi-1-preview",
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# Default options, F8B8, good balance between performance and precision
+cache_options = CacheType.default_options(CacheType.DBCache)
+
+apply_cache_on_pipe(pipe, **cache_options)
+
+pipe.enable_vae_tiling()
+
+prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
+video = pipe(
+    prompt,
+    num_frames=84,
+).frames[0]
+
+print("Saving video to mochi.mp4")
+export_to_video(video, "mochi.mp4", fps=30)

{cache_dit-0.1.3 → cache_dit-0.1.6}/requirements.txt
RENAMED
@@ -1,6 +1,6 @@
 # Example requirement, can be anything that pip knows
 # install with `pip install -r requirements.txt`, and make sure that CI does the same
 packaging
-torch
-transformers
-diffusers
+torch>=2.5.1 # torch 2.7.0 is preferred, but 2.5.1 is the minimum required version
+transformers>=4.51.3
+diffusers>=0.33.1

{cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit/cache_factory/dynamic_block_prune/prune_context.py
RENAMED
@@ -562,7 +562,7 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
         torch._dynamo.graph_break()

         add_pruned_block(self.pruned_blocks_step)
-        add_actual_block(self.
+        add_actual_block(self.num_transformer_blocks)
         patch_pruned_stats(self.transformer)

         return (

@@ -577,7 +577,7 @@ class DBPrunedTransformerBlocks(torch.nn.Module):

     @property
     @torch.compiler.disable
-    def
+    def num_transformer_blocks(self):
         # Total number of transformer blocks, including single transformer blocks.
         num_blocks = len(self.transformer_blocks)
         if self.single_transformer_blocks is not None:

@@ -597,7 +597,7 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
     @torch.compiler.disable
     def _non_prune_blocks_ids(self):
         # Never prune the first `Fn` and last `Bn` blocks.
-        num_blocks = self.
+        num_blocks = self.num_transformer_blocks
         Fn_compute_blocks_ = (
             Fn_compute_blocks()
             if Fn_compute_blocks() < num_blocks

@@ -627,6 +627,10 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
         ]
         return sorted(non_prune_blocks_ids)

+    # @torch.compile(dynamic=True)
+    # mark this function as compile with dynamic=True will
+    # cause precision degradate, so, we choose to disable it
+    # now, until we find a better solution or fixed the bug.
     @torch.compiler.disable
     def _compute_single_hidden_states_residual(
         self,

@@ -663,6 +667,10 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
             single_encoder_hidden_states_residual,
         )

+    # @torch.compile(dynamic=True)
+    # mark this function as compile with dynamic=True will
+    # cause precision degradate, so, we choose to disable it
+    # now, until we find a better solution or fixed the bug.
     @torch.compiler.disable
     def _split_single_hidden_states(
         self,

{cache_dit-0.1.3 → cache_dit-0.1.6}/src/cache_dit.egg-info/PKG-INFO
RENAMED
Same changes as PKG-INFO above.

All files listed above with +0 -0 are unchanged between the two versions.