cache-dit 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

cache_dit/__init__.py CHANGED
@@ -0,0 +1,5 @@
# Package version metadata, read from the sibling ``_version`` module.
try:
    from ._version import version as __version__
    from ._version import version_tuple
except ImportError:
    # ``_version`` (or one of the names above) could not be imported; fall
    # back to placeholder values so that ``import cache_dit`` still succeeds.
    __version__ = "unknown version"
    version_tuple = (0, 0, "unknown version")
cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.13'
21
- __version_tuple__ = version_tuple = (0, 2, 13)
20
+ __version__ = version = '0.2.14'
21
+ __version_tuple__ = version_tuple = (0, 2, 14)
@@ -8,8 +8,9 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union
8
8
 
9
9
  import torch
10
10
 
11
- import cache_dit.primitives as DP
11
+ import cache_dit.primitives as primitives
12
12
  from cache_dit.cache_factory.taylorseer import TaylorSeer
13
+ from cache_dit.utils import is_diffusers_at_least_0_3_5
13
14
  from cache_dit.logger import init_logger
14
15
 
15
16
  logger = init_logger(__name__)
@@ -772,8 +773,8 @@ def are_two_tensors_similar(
772
773
  mean_t1 = t1.abs().mean()
773
774
 
774
775
  if parallelized:
775
- mean_diff = DP.all_reduce_sync(mean_diff, "avg")
776
- mean_t1 = DP.all_reduce_sync(mean_t1, "avg")
776
+ mean_diff = primitives.all_reduce_sync(mean_diff, "avg")
777
+ mean_t1 = primitives.all_reduce_sync(mean_t1, "avg")
777
778
 
778
779
  # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
779
780
  # Futher, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,
@@ -1365,6 +1366,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
1365
1366
  ):
1366
1367
  original_hidden_states = hidden_states
1367
1368
  original_encoder_hidden_states = encoder_hidden_states
1369
+ # This condition branch is mainly for FLUX series.
1368
1370
  if self.single_transformer_blocks is not None:
1369
1371
  for block in self.transformer_blocks[Fn_compute_blocks() :]:
1370
1372
  hidden_states = block(
@@ -1381,22 +1383,32 @@ class DBCachedTransformerBlocks(torch.nn.Module):
1381
1383
  hidden_states,
1382
1384
  )
1383
1385
 
1384
- hidden_states = torch.cat(
1385
- [encoder_hidden_states, hidden_states], dim=1
1386
- )
1387
- for block in self._Mn_single_transformer_blocks():
1388
- hidden_states = block(
1389
- hidden_states,
1390
- *args,
1391
- **kwargs,
1386
+ # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
1387
+ if is_diffusers_at_least_0_3_5():
1388
+ for block in self._Mn_single_transformer_blocks():
1389
+ encoder_hidden_states, hidden_states = block(
1390
+ hidden_states,
1391
+ encoder_hidden_states,
1392
+ *args,
1393
+ **kwargs,
1394
+ )
1395
+ else:
1396
+ hidden_states = torch.cat(
1397
+ [encoder_hidden_states, hidden_states], dim=1
1398
+ )
1399
+ for block in self._Mn_single_transformer_blocks():
1400
+ hidden_states = block(
1401
+ hidden_states,
1402
+ *args,
1403
+ **kwargs,
1404
+ )
1405
+ encoder_hidden_states, hidden_states = hidden_states.split(
1406
+ [
1407
+ encoder_hidden_states.shape[1],
1408
+ hidden_states.shape[1] - encoder_hidden_states.shape[1],
1409
+ ],
1410
+ dim=1,
1392
1411
  )
1393
- encoder_hidden_states, hidden_states = hidden_states.split(
1394
- [
1395
- encoder_hidden_states.shape[1],
1396
- hidden_states.shape[1] - encoder_hidden_states.shape[1],
1397
- ],
1398
- dim=1,
1399
- )
1400
1412
  else:
1401
1413
  for block in self._Mn_transformer_blocks():
1402
1414
  hidden_states = block(
@@ -1789,43 +1801,71 @@ class DBCachedTransformerBlocks(torch.nn.Module):
1789
1801
 
1790
1802
  original_hidden_states = hidden_states
1791
1803
  original_encoder_hidden_states = encoder_hidden_states
1804
+ # This condition branch is mainly for FLUX series.
1792
1805
  if self.single_transformer_blocks is not None:
1793
1806
  assert Bn_compute_blocks() <= len(self.single_transformer_blocks), (
1794
1807
  f"Bn_compute_blocks {Bn_compute_blocks()} must be less than "
1795
1808
  f"the number of single transformer blocks {len(self.single_transformer_blocks)}"
1796
1809
  )
1797
-
1798
- hidden_states = torch.cat(
1799
- [encoder_hidden_states, hidden_states], dim=1
1800
- )
1801
- if len(Bn_compute_blocks_ids()) > 0:
1802
- for i, block in enumerate(self._Bn_single_transformer_blocks()):
1803
- hidden_states = (
1804
- self._compute_and_cache_single_transformer_block(
1805
- i,
1806
- original_hidden_states,
1807
- original_encoder_hidden_states,
1808
- block,
1810
+ if is_diffusers_at_least_0_3_5():
1811
+ if len(Bn_compute_blocks_ids()) > 0:
1812
+ # NOTE: Reuse _compute_and_cache_transformer_block here.
1813
+ for i, block in enumerate(
1814
+ self._Bn_single_transformer_blocks()
1815
+ ):
1816
+ hidden_states, encoder_hidden_states = (
1817
+ self._compute_and_cache_transformer_block(
1818
+ i,
1819
+ block,
1820
+ hidden_states,
1821
+ encoder_hidden_states,
1822
+ *args,
1823
+ **kwargs,
1824
+ )
1825
+ )
1826
+ else:
1827
+ # Compute all Bn blocks if no specific Bn compute blocks ids are set.
1828
+ for block in self._Bn_single_transformer_blocks():
1829
+ encoder_hidden_states, hidden_states = block(
1809
1830
  hidden_states,
1831
+ encoder_hidden_states,
1810
1832
  *args,
1811
1833
  **kwargs,
1812
1834
  )
1813
- )
1814
1835
  else:
1815
- # Compute all Bn blocks if no specific Bn compute blocks ids are set.
1816
- for block in self._Bn_single_transformer_blocks():
1817
- hidden_states = block(
1818
- hidden_states,
1819
- *args,
1820
- **kwargs,
1821
- )
1822
- encoder_hidden_states, hidden_states = hidden_states.split(
1823
- [
1824
- encoder_hidden_states.shape[1],
1825
- hidden_states.shape[1] - encoder_hidden_states.shape[1],
1826
- ],
1827
- dim=1,
1828
- )
1836
+ hidden_states = torch.cat(
1837
+ [encoder_hidden_states, hidden_states], dim=1
1838
+ )
1839
+ if len(Bn_compute_blocks_ids()) > 0:
1840
+ for i, block in enumerate(
1841
+ self._Bn_single_transformer_blocks()
1842
+ ):
1843
+ hidden_states = (
1844
+ self._compute_and_cache_single_transformer_block(
1845
+ i,
1846
+ original_hidden_states,
1847
+ original_encoder_hidden_states,
1848
+ block,
1849
+ hidden_states,
1850
+ *args,
1851
+ **kwargs,
1852
+ )
1853
+ )
1854
+ else:
1855
+ # Compute all Bn blocks if no specific Bn compute blocks ids are set.
1856
+ for block in self._Bn_single_transformer_blocks():
1857
+ hidden_states = block(
1858
+ hidden_states,
1859
+ *args,
1860
+ **kwargs,
1861
+ )
1862
+ encoder_hidden_states, hidden_states = hidden_states.split(
1863
+ [
1864
+ encoder_hidden_states.shape[1],
1865
+ hidden_states.shape[1] - encoder_hidden_states.shape[1],
1866
+ ],
1867
+ dim=1,
1868
+ )
1829
1869
  else:
1830
1870
  assert Bn_compute_blocks() <= len(self.transformer_blocks), (
1831
1871
  f"Bn_compute_blocks {Bn_compute_blocks()} must be less than "
@@ -7,7 +7,8 @@ from typing import Any, Dict, List, Optional, Union
7
7
 
8
8
  import torch
9
9
 
10
- import cache_dit.primitives as DP
10
+ import cache_dit.primitives as primitives
11
+ from cache_dit.utils import is_diffusers_at_least_0_3_5
11
12
  from cache_dit.logger import init_logger
12
13
 
13
14
  logger = init_logger(__name__)
@@ -446,8 +447,8 @@ def are_two_tensors_similar(
446
447
  mean_t1 = t1.abs().mean()
447
448
 
448
449
  if parallelized:
449
- mean_diff = DP.all_reduce_sync(mean_diff, "avg")
450
- mean_t1 = DP.all_reduce_sync(mean_t1, "avg")
450
+ mean_diff = primitives.all_reduce_sync(mean_diff, "avg")
451
+ mean_t1 = primitives.all_reduce_sync(mean_t1, "avg")
451
452
 
452
453
  # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
453
454
  # Futher, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,
@@ -936,27 +937,44 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
936
937
  )
937
938
 
938
939
  if self.single_transformer_blocks is not None:
939
- hidden_states = torch.cat(
940
- [encoder_hidden_states, hidden_states], dim=1
941
- )
942
- for j, block in enumerate(self.single_transformer_blocks):
943
- hidden_states = self._compute_or_prune_single_transformer_block(
944
- j + len(self.transformer_blocks),
945
- original_hidden_states,
946
- original_encoder_hidden_states,
947
- block,
948
- hidden_states,
949
- *args,
950
- **kwargs,
940
+ # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
941
+ if is_diffusers_at_least_0_3_5():
942
+ for j, block in enumerate(self.single_transformer_blocks):
943
+ # NOTE: Reuse _compute_or_prune_transformer_block here.
944
+ hidden_states, encoder_hidden_states = (
945
+ self._compute_or_prune_transformer_block(
946
+ j + len(self.transformer_blocks),
947
+ block,
948
+ hidden_states,
949
+ encoder_hidden_states,
950
+ *args,
951
+ **kwargs,
952
+ )
953
+ )
954
+ else:
955
+ hidden_states = torch.cat(
956
+ [encoder_hidden_states, hidden_states], dim=1
951
957
  )
958
+ for j, block in enumerate(self.single_transformer_blocks):
959
+ hidden_states = (
960
+ self._compute_or_prune_single_transformer_block(
961
+ j + len(self.transformer_blocks),
962
+ original_hidden_states,
963
+ original_encoder_hidden_states,
964
+ block,
965
+ hidden_states,
966
+ *args,
967
+ **kwargs,
968
+ )
969
+ )
952
970
 
953
- encoder_hidden_states, hidden_states = hidden_states.split(
954
- [
955
- encoder_hidden_states.shape[1],
956
- hidden_states.shape[1] - encoder_hidden_states.shape[1],
957
- ],
958
- dim=1,
959
- )
971
+ encoder_hidden_states, hidden_states = hidden_states.split(
972
+ [
973
+ encoder_hidden_states.shape[1],
974
+ hidden_states.shape[1] - encoder_hidden_states.shape[1],
975
+ ],
976
+ dim=1,
977
+ )
960
978
 
961
979
  hidden_states = (
962
980
  hidden_states.reshape(-1)
@@ -7,7 +7,7 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union
7
7
 
8
8
  import torch
9
9
 
10
- import cache_dit.primitives as DP
10
+ import cache_dit.primitives as primitives
11
11
  from cache_dit.cache_factory.taylorseer import TaylorSeer
12
12
  from cache_dit.logger import init_logger
13
13
 
@@ -348,8 +348,8 @@ def are_two_tensors_similar(
348
348
  mean_diff = (t1 - t2).abs().mean()
349
349
  mean_t1 = t1.abs().mean()
350
350
  if parallelized:
351
- mean_diff = DP.all_reduce_sync(mean_diff, "avg")
352
- mean_t1 = DP.all_reduce_sync(mean_t1, "avg")
351
+ mean_diff = primitives.all_reduce_sync(mean_diff, "avg")
352
+ mean_t1 = primitives.all_reduce_sync(mean_t1, "avg")
353
353
  diff = (mean_diff / mean_t1).item()
354
354
 
355
355
  add_residual_diff(diff)
cache_dit/utils.py ADDED
@@ -0,0 +1,7 @@
1
+ import torch
2
+ import diffusers
3
+
4
+
5
+ @torch.compiler.disable
6
+ def is_diffusers_at_least_0_3_5() -> bool:
7
+ return diffusers.__version__ >= "0.35.0"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cache_dit
3
- Version: 0.2.13
3
+ Version: 0.2.14
4
4
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
5
5
  Author: DefTruth, vipshop.com, etc.
6
6
  Maintainer: DefTruth, vipshop.com, etc
@@ -53,7 +53,7 @@ Dynamic: requires-python
53
53
  <img src=https://img.shields.io/badge/Release-v0.2-brightgreen.svg >
54
54
  </div>
55
55
  <p align="center">
56
- DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT offers <br>a set of training-free cache accelerators for DiT: <b>🔥<a href="#dbcache">DBCache</a>, <a href="#dbprune">DBPrune</a>, <a href="#taylorseer">TaylorSeer</a>, <a href="#fbcache">FBCache</a></b>, etc🔥
56
+ DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. <br> CacheDiT offers a set of training-free cache accelerators for Diffusion Transformers: <br> <b>🔥<a href="#dbcache">DBCache</a>, <a href="#dbprune">DBPrune</a>, <a href="#taylorseer">Hybrid TaylorSeer</a>, <a href="#cfg">Hybrid Cache CFG</a>, <a href="#fbcache">FBCache</a></b>, etc🔥
57
57
  </p>
58
58
  </div>
59
59
 
@@ -67,95 +67,6 @@ Dynamic: requires-python
67
67
  - [2025-07-18] 🎉First caching mechanism in **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** with **[cache-dit](https://github.com/vipshop/cache-dit)**, also check the **[PR](https://github.com/huggingface/flux-fast/pull/13)**.
68
68
  - [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! A forked version of **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** that **makes flux-fast even faster** with **[cache-dit](https://github.com/vipshop/cache-dit)**, **3.3x** speedup on NVIDIA L20.
69
69
 
70
- ## 🤗 Introduction
71
-
72
- <div align="center">
73
- <p align="center">
74
- <h3>🔥DBCache: Dual Block Caching for Diffusion Transformers</h3>
75
- </p>
76
- </div>
77
-
78
- **DBCache**: **Dual Block Caching** for Diffusion Transformers. We have enhanced `FBCache` into a more general and customizable cache algorithm, namely `DBCache`, enabling it to achieve fully `UNet-style` cache acceleration for DiT models. Different configurations of compute blocks (**F8B12**, etc.) can be customized in DBCache. Moreover, it can be entirely **training**-**free**. DBCache can strike a perfect **balance** between performance and precision!
79
-
80
- <div align="center">
81
- <p align="center">
82
- DBCache, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
83
- </p>
84
- </div>
85
-
86
- |Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
87
- |:---:|:---:|:---:|:---:|:---:|:---:|
88
- |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
89
- |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
90
- |**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.12)**|**F8B16 (0.20)**|**F8B20 (0.20)**|
91
- |27.85s|6.04s|5.88s|5.77s|6.01s|6.20s|
92
- |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_NONE_R0.08.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F1B0_R0.08.png width=105px> |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B8_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B12_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B16_R0.2.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B20_R0.2.png width=105px>|
93
-
94
- <div align="center">
95
- <p align="center">
96
- DBCache, <b> L20x4 </b>, Steps: 20, case to show the texture recovery ability of DBCache
97
- </p>
98
- </div>
99
-
100
- These case studies demonstrate that even with relatively high thresholds (such as 0.12, 0.15, 0.2, etc.) under the DBCache **F12B12** or **F8B16** configuration, the detailed texture of the kitten's fur, colored cloth, and the clarity of text can still be preserved. This suggests that users can leverage DBCache to effectively balance performance and precision in their workflows!
101
-
102
- <div align="center">
103
- <p align="center">
104
- <h3>🔥DBPrune: Dynamic Block Prune with Residual Caching</h3>
105
- </p>
106
- </div>
107
-
108
- **DBPrune**: We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, referred to as DBPrune. DBPrune caches each block's hidden states and residuals, then **dynamically prunes** blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals.
109
-
110
- <div align="center">
111
- <p align="center">
112
- DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
113
- </p>
114
- </div>
115
-
116
- |Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
117
- |:---:|:---:|:---:|:---:|:---:|:---:|
118
- |24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
119
- |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
120
-
121
- <div align="center">
122
- <p align="center">
123
- <h3>🔥Context Parallelism and Torch Compile</h3>
124
- </p>
125
- </div>
126
-
127
- Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. CacheDiT is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance.
128
-
129
- <div align="center">
130
- <p align="center">
131
- DBPrune + <b>torch.compile + context parallelism</b> <br>Steps: 28, "A cat holding a sign that says hello world with complex background"
132
- </p>
133
- </div>
134
-
135
- |Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
136
- |:---:|:---:|:---:|:---:|:---:|:---:|
137
- |+compile:20.43s|16.25s|14.12s|13.41s|12.00s|8.86s|
138
- |+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
139
- |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
140
-
141
- ## ©️Citations
142
-
143
- ```BibTeX
144
- @misc{CacheDiT@2025,
145
- title={CacheDiT: A Training-free and Easy-to-use cache acceleration Toolbox for Diffusion Transformers},
146
- url={https://github.com/vipshop/cache-dit.git},
147
- note={Open-source software available at https://github.com/vipshop/cache-dit.git},
148
- author={vipshop.com},
149
- year={2025}
150
- }
151
- ```
152
-
153
- ## 👋Reference
154
-
155
- <div id="reference"></div>
156
-
157
- The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work!
158
-
159
70
  ## 📖Contents
160
71
 
161
72
  <div id="contents"></div>
@@ -172,6 +83,7 @@ The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi
172
83
  - [⚙️Metrics CLI](#metrics)
173
84
  - [👋Contribute](#contribute)
174
85
  - [©️License](#license)
86
+ - [©️Citations](#citations)
175
87
 
176
88
  ## ⚙️Installation
177
89
 
@@ -208,6 +120,32 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
208
120
 
209
121
  ![](https://github.com/vipshop/cache-dit/raw/main/assets/dbcache-v1.png)
210
122
 
123
+
124
+ **DBCache**: **Dual Block Caching** for Diffusion Transformers. We have enhanced `FBCache` into a more general and customizable cache algorithm, namely `DBCache`, enabling it to achieve fully `UNet-style` cache acceleration for DiT models. Different configurations of compute blocks (**F8B12**, etc.) can be customized in DBCache. Moreover, it can be entirely **training**-**free**. DBCache can strike a perfect **balance** between performance and precision!
125
+
126
+ <div align="center">
127
+ <p align="center">
128
+ DBCache, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
129
+ </p>
130
+ </div>
131
+
132
+ |Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
133
+ |:---:|:---:|:---:|:---:|:---:|:---:|
134
+ |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
135
+ |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
136
+ |**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.12)**|**F8B16 (0.20)**|**F8B20 (0.20)**|
137
+ |27.85s|6.04s|5.88s|5.77s|6.01s|6.20s|
138
+ |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_NONE_R0.08.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F1B0_R0.08.png width=105px> |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B8_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B12_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B16_R0.2.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B20_R0.2.png width=105px>|
139
+
140
+ <div align="center">
141
+ <p align="center">
142
+ DBCache, <b> L20x4 </b>, Steps: 20, case to show the texture recovery ability of DBCache
143
+ </p>
144
+ </div>
145
+
146
+ These case studies demonstrate that even with relatively high thresholds (such as 0.12, 0.15, 0.2, etc.) under the DBCache **F12B12** or **F8B16** configuration, the detailed texture of the kitten's fur, colored cloth, and the clarity of text can still be preserved. This suggests that users can leverage DBCache to effectively balance performance and precision in their workflows!
147
+
148
+
211
149
  **DBCache** provides configurable parameters for custom optimization, enabling a balanced trade-off between performance and precision:
212
150
 
213
151
  - **Fn**: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
@@ -259,17 +197,6 @@ cache_options = {
259
197
  }
260
198
  ```
261
199
 
262
- <div align="center">
263
- <p align="center">
264
- DBCache, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
265
- </p>
266
- </div>
267
-
268
- |Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
269
- |:---:|:---:|:---:|:---:|:---:|:---:|
270
- |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
271
- |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
272
-
273
200
  ## 🔥Hybrid TaylorSeer
274
201
 
275
202
  <div id="taylorseer"></div>
@@ -490,6 +417,18 @@ Then, run the python test script with `torchrun`:
490
417
  torchrun --nproc_per_node=4 parallel_cache.py
491
418
  ```
492
419
 
420
+ <div align="center">
421
+ <p align="center">
422
+ DBPrune + <b>torch.compile + context parallelism</b> <br>Steps: 28, "A cat holding a sign that says hello world with complex background"
423
+ </p>
424
+ </div>
425
+
426
+ |Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
427
+ |:---:|:---:|:---:|:---:|:---:|:---:|
428
+ |+compile:20.43s|16.25s|14.12s|13.41s|12.00s|8.86s|
429
+ |+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
430
+ |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
431
+
493
432
  ## 🔥Torch Compile
494
433
 
495
434
  <div id="compile"></div>
@@ -551,5 +490,18 @@ How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](
551
490
 
552
491
  <div id="license"></div>
553
492
 
493
+ The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work! We follow the original license of [ParaAttention](https://github.com/chengzeyi/ParaAttention), the project FBCache is part of; please see [LICENSE](https://github.com/vipshop/cache-dit/raw/main/LICENSE) for more details.
554
494
 
555
- We have followed the original License from [ParaAttention](https://github.com/chengzeyi/ParaAttention), please check [LICENSE](https://github.com/vipshop/cache-dit/raw/main/LICENSE) for more details.
495
+ ## ©️Citations
496
+
497
+ <div id="citations"></div>
498
+
499
+ ```BibTeX
500
+ @misc{CacheDiT@2025,
501
+ title={CacheDiT: A Training-free and Easy-to-use cache acceleration Toolbox for Diffusion Transformers},
502
+ url={https://github.com/vipshop/cache-dit.git},
503
+ note={Open-source software available at https://github.com/vipshop/cache-dit.git},
504
+ author={vipshop.com},
505
+ year={2025}
506
+ }
507
+ ```
@@ -1,13 +1,14 @@
1
- cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- cache_dit/_version.py,sha256=2ECxD0Bipdh9vxnyteM0k9jxi9NOpPR7YxTi7Ad1ors,513
1
+ cache_dit/__init__.py,sha256=0-B173-fLi3IA8nJXoS71zK0zD33Xplysd9skmLfEOY,171
2
+ cache_dit/_version.py,sha256=ut2sCt69XoYh0A1_KAmfCg1IKkN6zwqhu2eMFWAhMbQ,513
3
3
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
4
4
  cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
5
+ cache_dit/utils.py,sha256=4cFNh0asch6Zgsixq0bS1ElfwBu_6BG5ZSmaa1khjyg,144
5
6
  cache_dit/cache_factory/__init__.py,sha256=iYQwLwB_XLoYl0OB9unZGDbBtrYvZaLkOAmhGRwdW2E,191
6
7
  cache_dit/cache_factory/adapters.py,sha256=QMCaXnmqM7NT7sx4bCF1mMLn-QcXX9h1RmgLAypDedg,5256
7
8
  cache_dit/cache_factory/taylorseer.py,sha256=LKSNo2ode69EVo9xrxjxAMEjz0yDGiGADeDYnEqddA8,3987
8
9
  cache_dit/cache_factory/utils.py,sha256=V-Mb5Jn07geEUUWo4QAfh6pmSzkL-2OGDn0VAXbG6hQ,1799
9
10
  cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
- cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=itVEb6gT2eZuncAHUmP51ZS0r6v6cGtRvnPjyeXqKH8,71156
11
+ cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=sJ9yxQlcrX4qkPln94FrL0WDe2WIn3_UD2-Mk8YtjSw,73301
11
12
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=krNAICf-aS3JLmSG8vOB9tpLa04uYRcABsC8PMbVUKY,1870
12
13
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=3xUjvDzor9AkBkDUc0N7kZqM86MIdajuigesnicNzXE,2260
13
14
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=cIsov6Pf0dRyddqkzTA2CU-jSDotof8LQr-HIoY9T9M,2615
@@ -15,7 +16,7 @@ cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha
15
16
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=8W9m-WeEVE2ytYi9udKEA8Wtb0EnvP3eT2A1Tu-d29k,2252
16
17
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=EREHM5E1wxnL-uRXRAEege4HXraRp1oD_r1Zx4CsiKk,2596
17
18
  cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=so1wGdb8W0ATwrjv7E5IEZLPcobybaY1HJa6hBYlOOQ,34698
19
+ cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=1qarKAsEFiaaN2_ghko2dqGz_R7BTQSOyGtb_eQq38Y,35716
19
20
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=hVBTXj9MMGFGVezT3j8MntFRBiphSaUL4YhSOd8JtuY,1870
20
21
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py,sha256=KP8NxtHAKzzBOoX0lhvlMgY_5dmP4Z3T5TOfwl4SSyg,2273
21
22
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py,sha256=kCB7lL4OIq8TZn-baMIF8D_PVPTFW60omCMVQCb8ebs,2628
@@ -23,7 +24,7 @@ cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py,
23
24
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py,sha256=zXgoRDDjus3a2WSjtNh4ERtQp20ceb6nzohHMDlo2zY,2265
24
25
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py,sha256=PA7nuLgfAelnaI8usQx0Kxi8XATzMapyR1WndEdFoZA,2604
25
26
  cache_dit/cache_factory/first_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
- cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=tTPwhPLEA7LqGupps1Zy2MycCtLzs22wsW0yUhiiF-U,23217
27
+ cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=qn4zWJ_eEMIPYzrxXoslunxbzK0WueuNtC54Pp5Q57k,23241
27
28
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py,sha256=-FFgA2MoudEo7uDacg4aWgm1KwfLZFsEDTVxatgbq9M,2146
28
29
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py,sha256=qO5CWyurtwW30mvOe6cxeQPTSXLDlPJcezm72zEjDq8,2375
29
30
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py,sha256=Dcd4OzABCtyQCZNX2KNnUTdVoO1E1ApM7P8gcVYzcK0,2733
@@ -40,9 +41,9 @@ cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,1706
40
41
  cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
41
42
  cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
42
43
  cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
43
- cache_dit-0.2.13.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
44
- cache_dit-0.2.13.dist-info/METADATA,sha256=at8DNFeGI5aVnBTi7_6zJgAi_QdgsItpBMzSGl8HEME,28247
45
- cache_dit-0.2.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
46
- cache_dit-0.2.13.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
47
- cache_dit-0.2.13.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
48
- cache_dit-0.2.13.dist-info/RECORD,,
44
+ cache_dit-0.2.14.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
45
+ cache_dit-0.2.14.dist-info/METADATA,sha256=EyZN75JcVcvTc5bopHXfl6w-nA-ro9Uit2Sjy5DU66A,25198
46
+ cache_dit-0.2.14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
47
+ cache_dit-0.2.14.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
48
+ cache_dit-0.2.14.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
49
+ cache_dit-0.2.14.dist-info/RECORD,,