cache-dit 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/__init__.py +5 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/dual_block_cache/cache_context.py +85 -45
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +40 -22
- cache_dit/cache_factory/first_block_cache/cache_context.py +3 -3
- cache_dit/utils.py +7 -0
- {cache_dit-0.2.13.dist-info → cache_dit-0.2.14.dist-info}/METADATA +55 -103
- {cache_dit-0.2.13.dist-info → cache_dit-0.2.14.dist-info}/RECORD +12 -11
- {cache_dit-0.2.13.dist-info → cache_dit-0.2.14.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.13.dist-info → cache_dit-0.2.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.13.dist-info → cache_dit-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.13.dist-info → cache_dit-0.2.14.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py CHANGED
cache_dit/_version.py CHANGED
cache_dit/cache_factory/dual_block_cache/cache_context.py CHANGED

````diff
@@ -8,8 +8,9 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union
 
 import torch
 
-import cache_dit.primitives as
+import cache_dit.primitives as primitives
 from cache_dit.cache_factory.taylorseer import TaylorSeer
+from cache_dit.utils import is_diffusers_at_least_0_3_5
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
````
````diff
@@ -772,8 +773,8 @@ def are_two_tensors_similar(
     mean_t1 = t1.abs().mean()
 
     if parallelized:
-        mean_diff =
-        mean_t1 =
+        mean_diff = primitives.all_reduce_sync(mean_diff, "avg")
+        mean_t1 = primitives.all_reduce_sync(mean_t1, "avg")
 
     # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
     # Futher, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,
````
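The function above decides cache reuse from a relative L1 distance between consecutive hidden states; this release completes the parallelized branch so both statistics are averaged across ranks before the division. A minimal standalone sketch of that logic, with `torch.distributed.all_reduce` standing in for the library's `primitives.all_reduce_sync` (simplified names; NCCL-style `AVG` reduction assumed):

```python
import torch
import torch.distributed as dist


def relative_l1_diff(t1: torch.Tensor, t2: torch.Tensor, parallelized: bool = False) -> float:
    # D = mean(|t1 - t2|) / mean(|t1|); D == 0 means t1 == t2.
    mean_diff = (t1 - t2).abs().mean()
    mean_t1 = t1.abs().mean()
    if parallelized:
        # Average the local statistics across ranks so every rank reaches
        # the same cache hit/miss decision (stand-in for all_reduce_sync).
        dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)
        dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG)
    return (mean_diff / mean_t1).item()
```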
````diff
@@ -1365,6 +1366,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
     ):
         original_hidden_states = hidden_states
         original_encoder_hidden_states = encoder_hidden_states
+        # This condition branch is mainly for FLUX series.
         if self.single_transformer_blocks is not None:
             for block in self.transformer_blocks[Fn_compute_blocks() :]:
                 hidden_states = block(
````
````diff
@@ -1381,22 +1383,32 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                     hidden_states,
                 )
 
-            hidden_states = torch.cat(
-                [encoder_hidden_states, hidden_states], dim=1
-            )
-            for block in self._Mn_single_transformer_blocks():
-                hidden_states = block(
-                    hidden_states,
-                    *args,
-                    **kwargs,
+            # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
+            if is_diffusers_at_least_0_3_5():
+                for block in self._Mn_single_transformer_blocks():
+                    encoder_hidden_states, hidden_states = block(
+                        hidden_states,
+                        encoder_hidden_states,
+                        *args,
+                        **kwargs,
+                    )
+            else:
+                hidden_states = torch.cat(
+                    [encoder_hidden_states, hidden_states], dim=1
+                )
+                for block in self._Mn_single_transformer_blocks():
+                    hidden_states = block(
+                        hidden_states,
+                        *args,
+                        **kwargs,
+                    )
+                encoder_hidden_states, hidden_states = hidden_states.split(
+                    [
+                        encoder_hidden_states.shape[1],
+                        hidden_states.shape[1] - encoder_hidden_states.shape[1],
+                    ],
+                    dim=1,
                 )
-            encoder_hidden_states, hidden_states = hidden_states.split(
-                [
-                    encoder_hidden_states.shape[1],
-                    hidden_states.shape[1] - encoder_hidden_states.shape[1],
-                ],
-                dim=1,
-            )
         else:
             for block in self._Mn_transformer_blocks():
                 hidden_states = block(
````
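The gist of this hunk: in diffusers>=0.35 the FLUX single transformer blocks take the image and text streams as separate arguments and return both, while older releases expect one tensor concatenated along the sequence dimension. A simplified sketch of the shim (illustrative names, not the library's exact method):

```python
import torch


def run_single_blocks(blocks, hidden_states, encoder_hidden_states, new_signature):
    # new_signature=True models diffusers>=0.35 single blocks, which accept
    # and return (encoder_hidden_states, hidden_states) separately.
    if new_signature:
        for block in blocks:
            encoder_hidden_states, hidden_states = block(
                hidden_states, encoder_hidden_states
            )
    else:
        # Legacy path: concatenate along the sequence dim, run, split back.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        for block in blocks:
            hidden_states = block(hidden_states)
        encoder_hidden_states, hidden_states = hidden_states.split(
            [
                encoder_hidden_states.shape[1],
                hidden_states.shape[1] - encoder_hidden_states.shape[1],
            ],
            dim=1,
        )
    return encoder_hidden_states, hidden_states
```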
````diff
@@ -1789,43 +1801,71 @@ class DBCachedTransformerBlocks(torch.nn.Module):
 
         original_hidden_states = hidden_states
         original_encoder_hidden_states = encoder_hidden_states
+        # This condition branch is mainly for FLUX series.
         if self.single_transformer_blocks is not None:
             assert Bn_compute_blocks() <= len(self.single_transformer_blocks), (
                 f"Bn_compute_blocks {Bn_compute_blocks()} must be less than "
                 f"the number of single transformer blocks {len(self.single_transformer_blocks)}"
             )
-            (… 12 removed lines; content not shown in this diff view …)
+            if is_diffusers_at_least_0_3_5():
+                if len(Bn_compute_blocks_ids()) > 0:
+                    # NOTE: Reuse _compute_and_cache_transformer_block here.
+                    for i, block in enumerate(
+                        self._Bn_single_transformer_blocks()
+                    ):
+                        hidden_states, encoder_hidden_states = (
+                            self._compute_and_cache_transformer_block(
+                                i,
+                                block,
+                                hidden_states,
+                                encoder_hidden_states,
+                                *args,
+                                **kwargs,
+                            )
+                        )
+                else:
+                    # Compute all Bn blocks if no specific Bn compute blocks ids are set.
+                    for block in self._Bn_single_transformer_blocks():
+                        encoder_hidden_states, hidden_states = block(
                             hidden_states,
+                            encoder_hidden_states,
                             *args,
                             **kwargs,
                         )
-            )
             else:
-                (… 6 removed lines; content not shown in this diff view …)
-                )
-                (… 7 removed lines; content not shown in this diff view …)
+                hidden_states = torch.cat(
+                    [encoder_hidden_states, hidden_states], dim=1
+                )
+                if len(Bn_compute_blocks_ids()) > 0:
+                    for i, block in enumerate(
+                        self._Bn_single_transformer_blocks()
+                    ):
+                        hidden_states = (
+                            self._compute_and_cache_single_transformer_block(
+                                i,
+                                original_hidden_states,
+                                original_encoder_hidden_states,
+                                block,
+                                hidden_states,
+                                *args,
+                                **kwargs,
+                            )
+                        )
+                else:
+                    # Compute all Bn blocks if no specific Bn compute blocks ids are set.
+                    for block in self._Bn_single_transformer_blocks():
+                        hidden_states = block(
+                            hidden_states,
+                            *args,
+                            **kwargs,
+                        )
+                encoder_hidden_states, hidden_states = hidden_states.split(
+                    [
+                        encoder_hidden_states.shape[1],
+                        hidden_states.shape[1] - encoder_hidden_states.shape[1],
+                    ],
+                    dim=1,
+                )
         else:
             assert Bn_compute_blocks() <= len(self.transformer_blocks), (
                 f"Bn_compute_blocks {Bn_compute_blocks()} must be less than "
````
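Both branches route selected Bn blocks through `_compute_and_cache_transformer_block` / `_compute_and_cache_single_transformer_block`: blocks named in `Bn_compute_blocks_ids()` are recomputed while the rest are approximated from cache. A deliberately simplified illustration of that compute-or-reuse pattern (hypothetical helper, not the library's actual signature):

```python
import torch


def compute_or_reuse(i, block, hidden_states, residual_cache, compute_ids):
    # Blocks whose index is listed are recomputed and their residual cached;
    # all other blocks are approximated by replaying the cached residual.
    if i in compute_ids:
        out = block(hidden_states)
        residual_cache[i] = out - hidden_states
        return out
    return hidden_states + residual_cache[i]
```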
cache_dit/cache_factory/dynamic_block_prune/prune_context.py CHANGED

````diff
@@ -7,7 +7,8 @@ from typing import Any, Dict, List, Optional, Union
 
 import torch
 
-import cache_dit.primitives as
+import cache_dit.primitives as primitives
+from cache_dit.utils import is_diffusers_at_least_0_3_5
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
````
````diff
@@ -446,8 +447,8 @@ def are_two_tensors_similar(
     mean_t1 = t1.abs().mean()
 
     if parallelized:
-        mean_diff =
-        mean_t1 =
+        mean_diff = primitives.all_reduce_sync(mean_diff, "avg")
+        mean_t1 = primitives.all_reduce_sync(mean_t1, "avg")
 
     # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
     # Futher, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,
````
````diff
@@ -936,27 +937,44 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
         )
 
         if self.single_transformer_blocks is not None:
-            (… 12 removed lines; content not shown in this diff view …)
+            # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
+            if is_diffusers_at_least_0_3_5():
+                for j, block in enumerate(self.single_transformer_blocks):
+                    # NOTE: Reuse _compute_or_prune_transformer_block here.
+                    hidden_states, encoder_hidden_states = (
+                        self._compute_or_prune_transformer_block(
+                            j + len(self.transformer_blocks),
+                            block,
+                            hidden_states,
+                            encoder_hidden_states,
+                            *args,
+                            **kwargs,
+                        )
+                    )
+            else:
+                hidden_states = torch.cat(
+                    [encoder_hidden_states, hidden_states], dim=1
                 )
+                for j, block in enumerate(self.single_transformer_blocks):
+                    hidden_states = (
+                        self._compute_or_prune_single_transformer_block(
+                            j + len(self.transformer_blocks),
+                            original_hidden_states,
+                            original_encoder_hidden_states,
+                            block,
+                            hidden_states,
+                            *args,
+                            **kwargs,
+                        )
+                    )
 
-            (… 7 removed lines; content not shown in this diff view …)
+                encoder_hidden_states, hidden_states = hidden_states.split(
+                    [
+                        encoder_hidden_states.shape[1],
+                        hidden_states.shape[1] - encoder_hidden_states.shape[1],
+                    ],
+                    dim=1,
+                )
 
         hidden_states = (
             hidden_states.reshape(-1)
````
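For orientation, the `_compute_or_prune_*` helpers implement DBPrune's core trade, as described in this package's METADATA: if a block's input is close enough (relative L1) to what it saw at the previous step, the block is skipped and its output approximated from the cached residual. A stripped-down illustration (hypothetical names; reuses `relative_l1_diff` from the earlier sketch):

```python
def prune_or_compute(block, hidden_states, cache, threshold):
    # Skip the block when its input barely moved since the last step and
    # replay the cached residual; otherwise compute and refresh the cache.
    if cache and relative_l1_diff(hidden_states, cache["last_input"]) < threshold:
        return hidden_states + cache["residual"]
    out = block(hidden_states)
    cache["last_input"] = hidden_states
    cache["residual"] = out - hidden_states
    return out
```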
cache_dit/cache_factory/first_block_cache/cache_context.py CHANGED

````diff
@@ -7,7 +7,7 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union
 
 import torch
 
-import cache_dit.primitives as
+import cache_dit.primitives as primitives
 from cache_dit.cache_factory.taylorseer import TaylorSeer
 from cache_dit.logger import init_logger
 
````
````diff
@@ -348,8 +348,8 @@ def are_two_tensors_similar(
     mean_diff = (t1 - t2).abs().mean()
     mean_t1 = t1.abs().mean()
     if parallelized:
-        mean_diff =
-        mean_t1 =
+        mean_diff = primitives.all_reduce_sync(mean_diff, "avg")
+        mean_t1 = primitives.all_reduce_sync(mean_t1, "avg")
     diff = (mean_diff / mean_t1).item()
 
     add_residual_diff(diff)
````
cache_dit/utils.py ADDED
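The new module's seven lines are not shown in this view, but given how `is_diffusers_at_least_0_3_5` is used above, it is presumably a small version gate along these lines (a hypothetical reconstruction, not the actual file; `0_3_5` reads as diffusers>=0.35.0):

```python
# Hypothetical sketch of cache_dit/utils.py; the real file is not
# shown in this diff view.
import diffusers
from packaging import version


def is_diffusers_at_least_0_3_5() -> bool:
    # True when the installed diffusers exposes the >=0.35 block signatures.
    return version.parse(diffusers.__version__) >= version.parse("0.35.0")
```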
{cache_dit-0.2.13.dist-info → cache_dit-0.2.14.dist-info}/METADATA CHANGED

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.13
+Version: 0.2.14
 Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
````
````diff
@@ -53,7 +53,7 @@ Dynamic: requires-python
 <img src=https://img.shields.io/badge/Release-v0.2-brightgreen.svg >
 </div>
 <p align="center">
-    DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT offers
+    DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. <br> CacheDiT offers a set of training-free cache accelerators for Diffusion Transformers: <br> <b>🔥<a href="#dbcache">DBCache</a>, <a href="#dbprune">DBPrune</a>, <a href="#taylorseer">Hybrid TaylorSeer</a>, <a href="#cfg">Hybrid Cache CFG</a>, <a href="#fbcache">FBCache</a></b>, etc🔥
 </p>
 </div>
 
````
````diff
@@ -67,95 +67,6 @@ Dynamic: requires-python
 - [2025-07-18] 🎉First caching mechanism in **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** with **[cache-dit](https://github.com/vipshop/cache-dit)**, also check the **[PR](https://github.com/huggingface/flux-fast/pull/13)**.
 - [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! A forked version of **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** that **makes flux-fast even faster** with **[cache-dit](https://github.com/vipshop/cache-dit)**, **3.3x** speedup on NVIDIA L20.
 
-## 🤗 Introduction
-
-<div align="center">
-<p align="center">
-<h3>🔥DBCache: Dual Block Caching for Diffusion Transformers</h3>
-</p>
-</div>
-
-**DBCache**: **Dual Block Caching** for Diffusion Transformers. We have enhanced `FBCache` into a more general and customizable cache algorithm, namely `DBCache`, enabling it to achieve fully `UNet-style` cache acceleration for DiT models. Different configurations of compute blocks (**F8B12**, etc.) can be customized in DBCache. Moreover, it can be entirely **training**-**free**. DBCache can strike a perfect **balance** between performance and precision!
-
-<div align="center">
-<p align="center">
-DBCache, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
-</p>
-</div>
-
-|Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
-|:---:|:---:|:---:|:---:|:---:|:---:|
-|24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
-|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
-|**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.12)**|**F8B16 (0.20)**|**F8B20 (0.20)**|
-|27.85s|6.04s|5.88s|5.77s|6.01s|6.20s|
-|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_NONE_R0.08.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F1B0_R0.08.png width=105px> |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B8_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B12_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B16_R0.2.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B20_R0.2.png width=105px>|
-
-<div align="center">
-<p align="center">
-DBCache, <b> L20x4 </b>, Steps: 20, case to show the texture recovery ability of DBCache
-</p>
-</div>
-
-These case studies demonstrate that even with relatively high thresholds (such as 0.12, 0.15, 0.2, etc.) under the DBCache **F12B12** or **F8B16** configuration, the detailed texture of the kitten's fur, colored cloth, and the clarity of text can still be preserved. This suggests that users can leverage DBCache to effectively balance performance and precision in their workflows!
-
-<div align="center">
-<p align="center">
-<h3>🔥DBPrune: Dynamic Block Prune with Residual Caching</h3>
-</p>
-</div>
-
-**DBPrune**: We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, referred to as DBPrune. DBPrune caches each block's hidden states and residuals, then **dynamically prunes** blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals.
-
-<div align="center">
-<p align="center">
-DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
-</p>
-</div>
-
-|Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
-|:---:|:---:|:---:|:---:|:---:|:---:|
-|24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
-|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
-
-<div align="center">
-<p align="center">
-<h3>🔥Context Parallelism and Torch Compile</h3>
-</p>
-</div>
-
-Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. CacheDiT is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance.
-
-<div align="center">
-<p align="center">
-DBPrune + <b>torch.compile + context parallelism</b> <br>Steps: 28, "A cat holding a sign that says hello world with complex background"
-</p>
-</div>
-
-|Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
-|:---:|:---:|:---:|:---:|:---:|:---:|
-|+compile:20.43s|16.25s|14.12s|13.41s|12.00s|8.86s|
-|+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
-|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
-
-## ©️Citations
-
-```BibTeX
-@misc{CacheDiT@2025,
-title={CacheDiT: A Training-free and Easy-to-use cache acceleration Toolbox for Diffusion Transformers},
-url={https://github.com/vipshop/cache-dit.git},
-note={Open-source software available at https://github.com/vipshop/cache-dit.git},
-author={vipshop.com},
-year={2025}
-}
-```
-
-## 👋Reference
-
-<div id="reference"></div>
-
-The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work!
-
 ## 📖Contents
 
 <div id="contents"></div>
````
````diff
@@ -172,6 +83,7 @@ The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi
 - [⚙️Metrics CLI](#metrics)
 - [👋Contribute](#contribute)
 - [©️License](#license)
+- [©️Citations](#citations)
 
 ## ⚙️Installation
 
````
````diff
@@ -208,6 +120,32 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
 
 ![](https://github.com/vipshop/cache-dit/raw/main/assets/dbcache-v1.png)
 
+
+**DBCache**: **Dual Block Caching** for Diffusion Transformers. We have enhanced `FBCache` into a more general and customizable cache algorithm, namely `DBCache`, enabling it to achieve fully `UNet-style` cache acceleration for DiT models. Different configurations of compute blocks (**F8B12**, etc.) can be customized in DBCache. Moreover, it can be entirely **training**-**free**. DBCache can strike a perfect **balance** between performance and precision!
+
+<div align="center">
+<p align="center">
+DBCache, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+</p>
+</div>
+
+|Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
+|**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.12)**|**F8B16 (0.20)**|**F8B20 (0.20)**|
+|27.85s|6.04s|5.88s|5.77s|6.01s|6.20s|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_NONE_R0.08.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F1B0_R0.08.png width=105px> |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B8_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B12_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B16_R0.2.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B20_R0.2.png width=105px>|
+
+<div align="center">
+<p align="center">
+DBCache, <b> L20x4 </b>, Steps: 20, case to show the texture recovery ability of DBCache
+</p>
+</div>
+
+These case studies demonstrate that even with relatively high thresholds (such as 0.12, 0.15, 0.2, etc.) under the DBCache **F12B12** or **F8B16** configuration, the detailed texture of the kitten's fur, colored cloth, and the clarity of text can still be preserved. This suggests that users can leverage DBCache to effectively balance performance and precision in their workflows!
+
+
 **DBCache** provides configurable parameters for custom optimization, enabling a balanced trade-off between performance and precision:
 
 - **Fn**: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
````
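As a concrete picture of how those Fn/Bn knobs are set, the README configures DBCache through a plain options dict passed to the pipeline adapter. A representative configuration, sketched from this release's documented API (model ID and values are illustrative):

```python
import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# F8B8 with a 0.12 residual-diff threshold: the first 8 blocks fit the
# L1 diff, the last 8 blocks refine the cached output.
cache_options = {
    "cache_type": CacheType.DBCache,
    "warmup_steps": 8,
    "Fn_compute_blocks": 8,
    "Bn_compute_blocks": 8,
    "residual_diff_threshold": 0.12,
}
apply_cache_on_pipe(pipe, **cache_options)
```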
````diff
@@ -259,17 +197,6 @@ cache_options = {
 }
 ```
 
-<div align="center">
-<p align="center">
-DBCache, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
-</p>
-</div>
-
-|Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
-|:---:|:---:|:---:|:---:|:---:|:---:|
-|24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
-|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
-
 ## 🔥Hybrid TaylorSeer
 
 <div id="taylorseer"></div>
````
````diff
@@ -490,6 +417,18 @@ Then, run the python test script with `torchrun`:
 torchrun --nproc_per_node=4 parallel_cache.py
 ```
 
+<div align="center">
+<p align="center">
+DBPrune + <b>torch.compile + context parallelism</b> <br>Steps: 28, "A cat holding a sign that says hello world with complex background"
+</p>
+</div>
+
+|Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|+compile:20.43s|16.25s|14.12s|13.41s|12.00s|8.86s|
+|+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
+
 ## 🔥Torch Compile
 
 <div id="compile"></div>
````
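The table above stacks three independent levers: block pruning, torch.compile, and 4-GPU context parallelism. Per the README, CacheDiT is designed to remain compatible with torch.compile; a minimal sketch of combining the two (assuming the `pipe` and `cache_options` from the DBCache example above):

```python
# Enable the cache/prune adapter first, then compile the DiT backbone.
apply_cache_on_pipe(pipe, **cache_options)
pipe.transformer = torch.compile(pipe.transformer)
```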
````diff
@@ -551,5 +490,18 @@ How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](
 
 <div id="license"></div>
 
+The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work! We have followed the original License from [FBCache](https://github.com/chengzeyi/ParaAttention), please check [LICENSE](https://github.com/vipshop/cache-dit/raw/main/LICENSE) for more details.
 
-
+## ©️Citations
+
+<div id="citations"></div>
+
+```BibTeX
+@misc{CacheDiT@2025,
+title={CacheDiT: A Training-free and Easy-to-use cache acceleration Toolbox for Diffusion Transformers},
+url={https://github.com/vipshop/cache-dit.git},
+note={Open-source software available at https://github.com/vipshop/cache-dit.git},
+author={vipshop.com},
+year={2025}
+}
+```
````
{cache_dit-0.2.13.dist-info → cache_dit-0.2.14.dist-info}/RECORD CHANGED

````diff
@@ -1,13 +1,14 @@
-cache_dit/__init__.py,sha256=
-cache_dit/_version.py,sha256=
+cache_dit/__init__.py,sha256=0-B173-fLi3IA8nJXoS71zK0zD33Xplysd9skmLfEOY,171
+cache_dit/_version.py,sha256=ut2sCt69XoYh0A1_KAmfCg1IKkN6zwqhu2eMFWAhMbQ,513
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
 cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
+cache_dit/utils.py,sha256=4cFNh0asch6Zgsixq0bS1ElfwBu_6BG5ZSmaa1khjyg,144
 cache_dit/cache_factory/__init__.py,sha256=iYQwLwB_XLoYl0OB9unZGDbBtrYvZaLkOAmhGRwdW2E,191
 cache_dit/cache_factory/adapters.py,sha256=QMCaXnmqM7NT7sx4bCF1mMLn-QcXX9h1RmgLAypDedg,5256
 cache_dit/cache_factory/taylorseer.py,sha256=LKSNo2ode69EVo9xrxjxAMEjz0yDGiGADeDYnEqddA8,3987
 cache_dit/cache_factory/utils.py,sha256=V-Mb5Jn07geEUUWo4QAfh6pmSzkL-2OGDn0VAXbG6hQ,1799
 cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=
+cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=sJ9yxQlcrX4qkPln94FrL0WDe2WIn3_UD2-Mk8YtjSw,73301
 cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=krNAICf-aS3JLmSG8vOB9tpLa04uYRcABsC8PMbVUKY,1870
 cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=3xUjvDzor9AkBkDUc0N7kZqM86MIdajuigesnicNzXE,2260
 cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=cIsov6Pf0dRyddqkzTA2CU-jSDotof8LQr-HIoY9T9M,2615
@@ -15,7 +16,7 @@ cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha
 cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=8W9m-WeEVE2ytYi9udKEA8Wtb0EnvP3eT2A1Tu-d29k,2252
 cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=EREHM5E1wxnL-uRXRAEege4HXraRp1oD_r1Zx4CsiKk,2596
 cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=
+cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=1qarKAsEFiaaN2_ghko2dqGz_R7BTQSOyGtb_eQq38Y,35716
 cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=hVBTXj9MMGFGVezT3j8MntFRBiphSaUL4YhSOd8JtuY,1870
 cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py,sha256=KP8NxtHAKzzBOoX0lhvlMgY_5dmP4Z3T5TOfwl4SSyg,2273
 cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py,sha256=kCB7lL4OIq8TZn-baMIF8D_PVPTFW60omCMVQCb8ebs,2628
@@ -23,7 +24,7 @@ cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py,
 cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py,sha256=zXgoRDDjus3a2WSjtNh4ERtQp20ceb6nzohHMDlo2zY,2265
 cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py,sha256=PA7nuLgfAelnaI8usQx0Kxi8XATzMapyR1WndEdFoZA,2604
 cache_dit/cache_factory/first_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=
+cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=qn4zWJ_eEMIPYzrxXoslunxbzK0WueuNtC54Pp5Q57k,23241
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py,sha256=-FFgA2MoudEo7uDacg4aWgm1KwfLZFsEDTVxatgbq9M,2146
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py,sha256=qO5CWyurtwW30mvOe6cxeQPTSXLDlPJcezm72zEjDq8,2375
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py,sha256=Dcd4OzABCtyQCZNX2KNnUTdVoO1E1ApM7P8gcVYzcK0,2733
@@ -40,9 +41,9 @@ cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,1706
 cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
 cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
 cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
-cache_dit-0.2.13.dist-info/licenses/LICENSE,sha256=
-cache_dit-0.2.13.dist-info/METADATA,sha256=
-cache_dit-0.2.13.dist-info/WHEEL,sha256=
-cache_dit-0.2.13.dist-info/entry_points.txt,sha256=
-cache_dit-0.2.13.dist-info/top_level.txt,sha256=
-cache_dit-0.2.13.dist-info/RECORD,,
+cache_dit-0.2.14.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.14.dist-info/METADATA,sha256=EyZN75JcVcvTc5bopHXfl6w-nA-ro9Uit2Sjy5DU66A,25198
+cache_dit-0.2.14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.14.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-0.2.14.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.14.dist-info/RECORD,,
````
{cache_dit-0.2.13.dist-info → cache_dit-0.2.14.dist-info}/WHEEL: file without changes
{cache_dit-0.2.13.dist-info → cache_dit-0.2.14.dist-info}/entry_points.txt: file without changes
{cache_dit-0.2.13.dist-info → cache_dit-0.2.14.dist-info}/licenses/LICENSE: file without changes
{cache_dit-0.2.13.dist-info → cache_dit-0.2.14.dist-info}/top_level.txt: file without changes