cache-dit 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

cache_dit/__init__.py CHANGED
@@ -0,0 +1,5 @@
1
+ try:
2
+ from ._version import version as __version__, version_tuple
3
+ except ImportError:
4
+ __version__ = "unknown version"
5
+ version_tuple = (0, 0, "unknown version")
cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.12'
21
- __version_tuple__ = version_tuple = (0, 2, 12)
20
+ __version__ = version = '0.2.14'
21
+ __version_tuple__ = version_tuple = (0, 2, 14)
@@ -8,8 +8,9 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union
8
8
 
9
9
  import torch
10
10
 
11
- import cache_dit.primitives as DP
11
+ import cache_dit.primitives as primitives
12
12
  from cache_dit.cache_factory.taylorseer import TaylorSeer
13
+ from cache_dit.utils import is_diffusers_at_least_0_3_5
13
14
  from cache_dit.logger import init_logger
14
15
 
15
16
  logger = init_logger(__name__)
@@ -772,8 +773,8 @@ def are_two_tensors_similar(
772
773
  mean_t1 = t1.abs().mean()
773
774
 
774
775
  if parallelized:
775
- mean_diff = DP.all_reduce_sync(mean_diff, "avg")
776
- mean_t1 = DP.all_reduce_sync(mean_t1, "avg")
776
+ mean_diff = primitives.all_reduce_sync(mean_diff, "avg")
777
+ mean_t1 = primitives.all_reduce_sync(mean_t1, "avg")
777
778
 
778
779
  # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
779
780
  # Further, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,
@@ -1365,6 +1366,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
1365
1366
  ):
1366
1367
  original_hidden_states = hidden_states
1367
1368
  original_encoder_hidden_states = encoder_hidden_states
1369
+ # This condition branch is mainly for FLUX series.
1368
1370
  if self.single_transformer_blocks is not None:
1369
1371
  for block in self.transformer_blocks[Fn_compute_blocks() :]:
1370
1372
  hidden_states = block(
@@ -1381,22 +1383,32 @@ class DBCachedTransformerBlocks(torch.nn.Module):
1381
1383
  hidden_states,
1382
1384
  )
1383
1385
 
1384
- hidden_states = torch.cat(
1385
- [encoder_hidden_states, hidden_states], dim=1
1386
- )
1387
- for block in self._Mn_single_transformer_blocks():
1388
- hidden_states = block(
1389
- hidden_states,
1390
- *args,
1391
- **kwargs,
1386
+ # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
1387
+ if is_diffusers_at_least_0_3_5():
1388
+ for block in self._Mn_single_transformer_blocks():
1389
+ encoder_hidden_states, hidden_states = block(
1390
+ hidden_states,
1391
+ encoder_hidden_states,
1392
+ *args,
1393
+ **kwargs,
1394
+ )
1395
+ else:
1396
+ hidden_states = torch.cat(
1397
+ [encoder_hidden_states, hidden_states], dim=1
1398
+ )
1399
+ for block in self._Mn_single_transformer_blocks():
1400
+ hidden_states = block(
1401
+ hidden_states,
1402
+ *args,
1403
+ **kwargs,
1404
+ )
1405
+ encoder_hidden_states, hidden_states = hidden_states.split(
1406
+ [
1407
+ encoder_hidden_states.shape[1],
1408
+ hidden_states.shape[1] - encoder_hidden_states.shape[1],
1409
+ ],
1410
+ dim=1,
1392
1411
  )
1393
- encoder_hidden_states, hidden_states = hidden_states.split(
1394
- [
1395
- encoder_hidden_states.shape[1],
1396
- hidden_states.shape[1] - encoder_hidden_states.shape[1],
1397
- ],
1398
- dim=1,
1399
- )
1400
1412
  else:
1401
1413
  for block in self._Mn_transformer_blocks():
1402
1414
  hidden_states = block(
@@ -1789,43 +1801,71 @@ class DBCachedTransformerBlocks(torch.nn.Module):
1789
1801
 
1790
1802
  original_hidden_states = hidden_states
1791
1803
  original_encoder_hidden_states = encoder_hidden_states
1804
+ # This condition branch is mainly for FLUX series.
1792
1805
  if self.single_transformer_blocks is not None:
1793
1806
  assert Bn_compute_blocks() <= len(self.single_transformer_blocks), (
1794
1807
  f"Bn_compute_blocks {Bn_compute_blocks()} must be less than "
1795
1808
  f"the number of single transformer blocks {len(self.single_transformer_blocks)}"
1796
1809
  )
1797
-
1798
- hidden_states = torch.cat(
1799
- [encoder_hidden_states, hidden_states], dim=1
1800
- )
1801
- if len(Bn_compute_blocks_ids()) > 0:
1802
- for i, block in enumerate(self._Bn_single_transformer_blocks()):
1803
- hidden_states = (
1804
- self._compute_and_cache_single_transformer_block(
1805
- i,
1806
- original_hidden_states,
1807
- original_encoder_hidden_states,
1808
- block,
1810
+ if is_diffusers_at_least_0_3_5():
1811
+ if len(Bn_compute_blocks_ids()) > 0:
1812
+ # NOTE: Reuse _compute_and_cache_transformer_block here.
1813
+ for i, block in enumerate(
1814
+ self._Bn_single_transformer_blocks()
1815
+ ):
1816
+ hidden_states, encoder_hidden_states = (
1817
+ self._compute_and_cache_transformer_block(
1818
+ i,
1819
+ block,
1820
+ hidden_states,
1821
+ encoder_hidden_states,
1822
+ *args,
1823
+ **kwargs,
1824
+ )
1825
+ )
1826
+ else:
1827
+ # Compute all Bn blocks if no specific Bn compute blocks ids are set.
1828
+ for block in self._Bn_single_transformer_blocks():
1829
+ encoder_hidden_states, hidden_states = block(
1809
1830
  hidden_states,
1831
+ encoder_hidden_states,
1810
1832
  *args,
1811
1833
  **kwargs,
1812
1834
  )
1813
- )
1814
1835
  else:
1815
- # Compute all Bn blocks if no specific Bn compute blocks ids are set.
1816
- for block in self._Bn_single_transformer_blocks():
1817
- hidden_states = block(
1818
- hidden_states,
1819
- *args,
1820
- **kwargs,
1821
- )
1822
- encoder_hidden_states, hidden_states = hidden_states.split(
1823
- [
1824
- encoder_hidden_states.shape[1],
1825
- hidden_states.shape[1] - encoder_hidden_states.shape[1],
1826
- ],
1827
- dim=1,
1828
- )
1836
+ hidden_states = torch.cat(
1837
+ [encoder_hidden_states, hidden_states], dim=1
1838
+ )
1839
+ if len(Bn_compute_blocks_ids()) > 0:
1840
+ for i, block in enumerate(
1841
+ self._Bn_single_transformer_blocks()
1842
+ ):
1843
+ hidden_states = (
1844
+ self._compute_and_cache_single_transformer_block(
1845
+ i,
1846
+ original_hidden_states,
1847
+ original_encoder_hidden_states,
1848
+ block,
1849
+ hidden_states,
1850
+ *args,
1851
+ **kwargs,
1852
+ )
1853
+ )
1854
+ else:
1855
+ # Compute all Bn blocks if no specific Bn compute blocks ids are set.
1856
+ for block in self._Bn_single_transformer_blocks():
1857
+ hidden_states = block(
1858
+ hidden_states,
1859
+ *args,
1860
+ **kwargs,
1861
+ )
1862
+ encoder_hidden_states, hidden_states = hidden_states.split(
1863
+ [
1864
+ encoder_hidden_states.shape[1],
1865
+ hidden_states.shape[1] - encoder_hidden_states.shape[1],
1866
+ ],
1867
+ dim=1,
1868
+ )
1829
1869
  else:
1830
1870
  assert Bn_compute_blocks() <= len(self.transformer_blocks), (
1831
1871
  f"Bn_compute_blocks {Bn_compute_blocks()} must be less than "
@@ -7,7 +7,8 @@ from typing import Any, Dict, List, Optional, Union
7
7
 
8
8
  import torch
9
9
 
10
- import cache_dit.primitives as DP
10
+ import cache_dit.primitives as primitives
11
+ from cache_dit.utils import is_diffusers_at_least_0_3_5
11
12
  from cache_dit.logger import init_logger
12
13
 
13
14
  logger = init_logger(__name__)
@@ -446,8 +447,8 @@ def are_two_tensors_similar(
446
447
  mean_t1 = t1.abs().mean()
447
448
 
448
449
  if parallelized:
449
- mean_diff = DP.all_reduce_sync(mean_diff, "avg")
450
- mean_t1 = DP.all_reduce_sync(mean_t1, "avg")
450
+ mean_diff = primitives.all_reduce_sync(mean_diff, "avg")
451
+ mean_t1 = primitives.all_reduce_sync(mean_t1, "avg")
451
452
 
452
453
  # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
453
454
  # Further, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,
@@ -936,27 +937,44 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
936
937
  )
937
938
 
938
939
  if self.single_transformer_blocks is not None:
939
- hidden_states = torch.cat(
940
- [encoder_hidden_states, hidden_states], dim=1
941
- )
942
- for j, block in enumerate(self.single_transformer_blocks):
943
- hidden_states = self._compute_or_prune_single_transformer_block(
944
- j + len(self.transformer_blocks),
945
- original_hidden_states,
946
- original_encoder_hidden_states,
947
- block,
948
- hidden_states,
949
- *args,
950
- **kwargs,
940
+ # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
941
+ if is_diffusers_at_least_0_3_5():
942
+ for j, block in enumerate(self.single_transformer_blocks):
943
+ # NOTE: Reuse _compute_or_prune_transformer_block here.
944
+ hidden_states, encoder_hidden_states = (
945
+ self._compute_or_prune_transformer_block(
946
+ j + len(self.transformer_blocks),
947
+ block,
948
+ hidden_states,
949
+ encoder_hidden_states,
950
+ *args,
951
+ **kwargs,
952
+ )
953
+ )
954
+ else:
955
+ hidden_states = torch.cat(
956
+ [encoder_hidden_states, hidden_states], dim=1
951
957
  )
958
+ for j, block in enumerate(self.single_transformer_blocks):
959
+ hidden_states = (
960
+ self._compute_or_prune_single_transformer_block(
961
+ j + len(self.transformer_blocks),
962
+ original_hidden_states,
963
+ original_encoder_hidden_states,
964
+ block,
965
+ hidden_states,
966
+ *args,
967
+ **kwargs,
968
+ )
969
+ )
952
970
 
953
- encoder_hidden_states, hidden_states = hidden_states.split(
954
- [
955
- encoder_hidden_states.shape[1],
956
- hidden_states.shape[1] - encoder_hidden_states.shape[1],
957
- ],
958
- dim=1,
959
- )
971
+ encoder_hidden_states, hidden_states = hidden_states.split(
972
+ [
973
+ encoder_hidden_states.shape[1],
974
+ hidden_states.shape[1] - encoder_hidden_states.shape[1],
975
+ ],
976
+ dim=1,
977
+ )
960
978
 
961
979
  hidden_states = (
962
980
  hidden_states.reshape(-1)
@@ -7,7 +7,7 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union
7
7
 
8
8
  import torch
9
9
 
10
- import cache_dit.primitives as DP
10
+ import cache_dit.primitives as primitives
11
11
  from cache_dit.cache_factory.taylorseer import TaylorSeer
12
12
  from cache_dit.logger import init_logger
13
13
 
@@ -348,8 +348,8 @@ def are_two_tensors_similar(
348
348
  mean_diff = (t1 - t2).abs().mean()
349
349
  mean_t1 = t1.abs().mean()
350
350
  if parallelized:
351
- mean_diff = DP.all_reduce_sync(mean_diff, "avg")
352
- mean_t1 = DP.all_reduce_sync(mean_t1, "avg")
351
+ mean_diff = primitives.all_reduce_sync(mean_diff, "avg")
352
+ mean_t1 = primitives.all_reduce_sync(mean_t1, "avg")
353
353
  diff = (mean_diff / mean_t1).item()
354
354
 
355
355
  add_residual_diff(diff)
@@ -0,0 +1,43 @@
1
+ import builtins as __builtin__
2
+ import contextlib
3
+ import warnings
4
+
5
+ import lpips
6
+ import torch
7
+
8
+ warnings.filterwarnings("ignore")
9
+
10
+ lpips_loss_fn_vgg = None
11
+ lpips_loss_fn_alex = None
12
+
13
+
14
def dummy_print(*args, **kwargs):
    """Accept any arguments and do nothing; stand-in for ``print`` while muted."""
    pass


@contextlib.contextmanager
def disable_print():
    """Temporarily replace the builtin ``print`` with a no-op.

    The swap is done on the ``builtins`` module itself, so it affects every
    caller of ``print`` in the process for the duration of the ``with`` body,
    not only code in this module.

    The original ``print`` is restored when the body exits, including when
    the body raises; the exception then propagates unchanged.
    """
    origin_print = __builtin__.print
    __builtin__.print = dummy_print
    try:
        yield
    finally:
        # Without this, an exception in the managed body would leave the
        # builtin print permanently replaced by the no-op.
        __builtin__.print = origin_print
24
+
25
+
26
def compute_lpips_img(img0, img1, net: str = "alex"):
    """Compute the LPIPS distance between two images.

    The ``lpips.LPIPS`` model for the requested backbone is created on first
    use and cached in a module-level global, so later calls reuse it.
    Construction runs inside ``disable_print()`` (presumably to silence the
    library's setup messages -- confirm against lpips).

    Args:
        img0: First image, passed straight to the LPIPS model.
        img1: Second image, passed straight to the LPIPS model.
            NOTE(review): both are presumably image tensors scaled to
            [-1, 1], as produced by ``compute_lpips_file`` -- confirm.
        net: Backbone name, matched case-insensitively; ``"alex"`` or
            ``"vgg"``.

    Returns:
        The model output converted to a Python scalar via ``.item()``.

    Raises:
        AssertionError: If ``net`` is neither ``"alex"`` nor ``"vgg"``.
    """
    global lpips_loss_fn_vgg
    global lpips_loss_fn_alex
    backbone = net.lower()
    if backbone == "alex":
        if lpips_loss_fn_alex is None:
            with disable_print():
                lpips_loss_fn_alex = lpips.LPIPS(net="alex")
        loss_fn = lpips_loss_fn_alex
    elif backbone == "vgg":
        if lpips_loss_fn_vgg is None:
            with disable_print():
                lpips_loss_fn_vgg = lpips.LPIPS(net="vgg")
        loss_fn = lpips_loss_fn_vgg
    else:
        # Raise explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would let execution fall through
        # to an unbound `loss_fn`. The exception type and message are kept.
        raise AssertionError(f"unsupport net {net}")

    # Inference only; no autograd graph is needed for a metric.
    with torch.no_grad():
        return loss_fn(img0, img1).item()
@@ -14,6 +14,7 @@ from cache_dit.metrics.config import get_metrics_verbose
14
14
  from cache_dit.metrics.config import _IMAGE_EXTENSIONS
15
15
  from cache_dit.metrics.config import _VIDEO_EXTENSIONS
16
16
  from cache_dit.logger import init_logger
17
+ from cache_dit.metrics.lpips import compute_lpips_img
17
18
 
18
19
  logger = init_logger(__name__)
19
20
 
@@ -21,6 +22,35 @@ logger = init_logger(__name__)
21
22
  DISABLE_VERBOSE = not get_metrics_verbose()
22
23
 
23
24
 
25
def compute_lpips_file(
    image_true: np.ndarray | str,
    image_test: np.ndarray | str,
) -> float:
    """Compute LPIPS between a reference image and a test image.

    Each argument given as a ``str`` is treated as a file path: the image is
    loaded, converted to a float32 tensor and normalized with mean 0.5 and
    std 0.5 per channel. Non-string arguments are forwarded unchanged to
    ``compute_lpips_img``.
    NOTE(review): the annotation allows ``np.ndarray``, but no conversion is
    applied to arrays here -- confirm that callers pass data the LPIPS model
    accepts.

    Args:
        image_true: Reference image, or path to it.
        image_test: Image under test, or path to it.

    Returns:
        The value returned by ``compute_lpips_img`` with its default backbone.
    """
    import torch
    from PIL import Image
    from torchvision.transforms.v2.functional import (
        convert_image_dtype,
        normalize,
        pil_to_tensor,
    )

    def load_img_as_tensor(path):
        # Open via a context manager so the file handle is released as soon
        # as the pixels have been read.
        with Image.open(path) as pil:
            # Force three channels: the mean/std below have three entries,
            # which e.g. a four-channel RGBA image would not match.
            img = pil_to_tensor(pil.convert("RGB"))
        img = convert_image_dtype(img, dtype=torch.float32)
        img = normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        return img

    if isinstance(image_true, str):
        image_true = load_img_as_tensor(image_true)
    if isinstance(image_test, str):
        image_test = load_img_as_tensor(image_test)
    return compute_lpips_img(
        image_true,
        image_test,
    )
52
+
53
+
24
54
  def compute_psnr_file(
25
55
  image_true: np.ndarray | str,
26
56
  image_test: np.ndarray | str,
@@ -305,6 +335,11 @@ def compute_video_metric(
305
335
  return None, None
306
336
 
307
337
 
338
+ compute_lpips = partial(
339
+ compute_dir_metric,
340
+ compute_file_func=compute_lpips_file,
341
+ )
342
+
308
343
  compute_psnr = partial(
309
344
  compute_dir_metric,
310
345
  compute_file_func=compute_psnr_file,
@@ -320,6 +355,10 @@ compute_mse = partial(
320
355
  compute_file_func=compute_mse_file,
321
356
  )
322
357
 
358
+ compute_video_lpips = partial(
359
+ compute_video_metric,
360
+ compute_frame_func=compute_lpips_file,
361
+ )
323
362
  compute_video_psnr = partial(
324
363
  compute_video_metric,
325
364
  compute_frame_func=compute_psnr_file,
@@ -335,6 +374,7 @@ compute_video_mse = partial(
335
374
 
336
375
 
337
376
  METRICS_CHOICES = [
377
+ "lpips",
338
378
  "psnr",
339
379
  "ssim",
340
380
  "mse",
@@ -522,6 +562,9 @@ def entrypoint():
522
562
  METRICS_META[msg] = value
523
563
  logger.info(msg)
524
564
 
565
+ if metric == "lpips" or metric == "all":
566
+ img_lpips, n = compute_lpips(img_true, img_test)
567
+ _logging_msg(img_lpips, "lpips", n)
525
568
  if metric == "psnr" or metric == "all":
526
569
  img_psnr, n = compute_psnr(img_true, img_test)
527
570
  _logging_msg(img_psnr, "psnr", n)
@@ -558,6 +601,9 @@ def entrypoint():
558
601
  METRICS_META[msg] = value
559
602
  logger.info(msg)
560
603
 
604
+ if metric == "lpips" or metric == "all":
605
+ video_lpips, n = compute_video_lpips(video_true, video_test)
606
+ _logging_msg(video_lpips, "lpips", n)
561
607
  if metric == "psnr" or metric == "all":
562
608
  video_psnr, n = compute_video_psnr(video_true, video_test)
563
609
  _logging_msg(video_psnr, "psnr", n)
cache_dit/utils.py ADDED
@@ -0,0 +1,7 @@
1
+ import torch
2
+ import diffusers
3
+
4
+
5
+ @torch.compiler.disable
6
+ def is_diffusers_at_least_0_3_5() -> bool:
7
+ return diffusers.__version__ >= "0.35.0"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cache_dit
3
- Version: 0.2.12
3
+ Version: 0.2.14
4
4
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
5
5
  Author: DefTruth, vipshop.com, etc.
6
6
  Maintainer: DefTruth, vipshop.com, etc
@@ -16,6 +16,7 @@ Requires-Dist: transformers>=4.51.3
16
16
  Requires-Dist: diffusers>=0.33.1
17
17
  Requires-Dist: scikit-image
18
18
  Requires-Dist: scipy
19
+ Requires-Dist: lpips==0.1.4
19
20
  Provides-Extra: all
20
21
  Provides-Extra: dev
21
22
  Requires-Dist: pre-commit; extra == "dev"
@@ -52,7 +53,7 @@ Dynamic: requires-python
52
53
  <img src=https://img.shields.io/badge/Release-v0.2-brightgreen.svg >
53
54
  </div>
54
55
  <p align="center">
55
- DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT offers <br>a set of training-free cache accelerators for DiT: <b>🔥<a href="#dbcache">DBCache</a>, <a href="#dbprune">DBPrune</a>, <a href="#taylorseer">TaylorSeer</a>, <a href="#fbcache">FBCache</a></b>, etc🔥
56
+ DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. <br> CacheDiT offers a set of training-free cache accelerators for Diffusion Transformers: <br> <b>🔥<a href="#dbcache">DBCache</a>, <a href="#dbprune">DBPrune</a>, <a href="#taylorseer">Hybrid TaylorSeer</a>, <a href="#cfg">Hybrid Cache CFG</a>, <a href="#fbcache">FBCache</a></b>, etc🔥
56
57
  </p>
57
58
  </div>
58
59
 
@@ -66,95 +67,6 @@ Dynamic: requires-python
66
67
  - [2025-07-18] 🎉First caching mechanism in **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** with **[cache-dit](https://github.com/vipshop/cache-dit)**, also check the **[PR](https://github.com/huggingface/flux-fast/pull/13)**.
67
68
  - [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! A forked version of **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** that **makes flux-fast even faster** with **[cache-dit](https://github.com/vipshop/cache-dit)**, **3.3x** speedup on NVIDIA L20.
68
69
 
69
- ## 🤗 Introduction
70
-
71
- <div align="center">
72
- <p align="center">
73
- <h3>🔥DBCache: Dual Block Caching for Diffusion Transformers</h3>
74
- </p>
75
- </div>
76
-
77
- **DBCache**: **Dual Block Caching** for Diffusion Transformers. We have enhanced `FBCache` into a more general and customizable cache algorithm, namely `DBCache`, enabling it to achieve fully `UNet-style` cache acceleration for DiT models. Different configurations of compute blocks (**F8B12**, etc.) can be customized in DBCache. Moreover, it can be entirely **training**-**free**. DBCache can strike a perfect **balance** between performance and precision!
78
-
79
- <div align="center">
80
- <p align="center">
81
- DBCache, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
82
- </p>
83
- </div>
84
-
85
- |Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
86
- |:---:|:---:|:---:|:---:|:---:|:---:|
87
- |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
88
- |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
89
- |**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.12)**|**F8B16 (0.20)**|**F8B20 (0.20)**|
90
- |27.85s|6.04s|5.88s|5.77s|6.01s|6.20s|
91
- |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_NONE_R0.08.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F1B0_R0.08.png width=105px> |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B8_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B12_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B16_R0.2.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B20_R0.2.png width=105px>|
92
-
93
- <div align="center">
94
- <p align="center">
95
- DBCache, <b> L20x4 </b>, Steps: 20, case to show the texture recovery ability of DBCache
96
- </p>
97
- </div>
98
-
99
- These case studies demonstrate that even with relatively high thresholds (such as 0.12, 0.15, 0.2, etc.) under the DBCache **F12B12** or **F8B16** configuration, the detailed texture of the kitten's fur, colored cloth, and the clarity of text can still be preserved. This suggests that users can leverage DBCache to effectively balance performance and precision in their workflows!
100
-
101
- <div align="center">
102
- <p align="center">
103
- <h3>🔥DBPrune: Dynamic Block Prune with Residual Caching</h3>
104
- </p>
105
- </div>
106
-
107
- **DBPrune**: We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, referred to as DBPrune. DBPrune caches each block's hidden states and residuals, then **dynamically prunes** blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals.
108
-
109
- <div align="center">
110
- <p align="center">
111
- DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
112
- </p>
113
- </div>
114
-
115
- |Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
116
- |:---:|:---:|:---:|:---:|:---:|:---:|
117
- |24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
118
- |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
119
-
120
- <div align="center">
121
- <p align="center">
122
- <h3>🔥Context Parallelism and Torch Compile</h3>
123
- </p>
124
- </div>
125
-
126
- Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. CacheDiT is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance.
127
-
128
- <div align="center">
129
- <p align="center">
130
- DBPrune + <b>torch.compile + context parallelism</b> <br>Steps: 28, "A cat holding a sign that says hello world with complex background"
131
- </p>
132
- </div>
133
-
134
- |Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
135
- |:---:|:---:|:---:|:---:|:---:|:---:|
136
- |+compile:20.43s|16.25s|14.12s|13.41s|12.00s|8.86s|
137
- |+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
138
- |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
139
-
140
- ## ©️Citations
141
-
142
- ```BibTeX
143
- @misc{CacheDiT@2025,
144
- title={CacheDiT: A Training-free and Easy-to-use cache acceleration Toolbox for Diffusion Transformers},
145
- url={https://github.com/vipshop/cache-dit.git},
146
- note={Open-source software available at https://github.com/vipshop/cache-dit.git},
147
- author={vipshop.com},
148
- year={2025}
149
- }
150
- ```
151
-
152
- ## 👋Reference
153
-
154
- <div id="reference"></div>
155
-
156
- The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work!
157
-
158
70
  ## 📖Contents
159
71
 
160
72
  <div id="contents"></div>
@@ -171,6 +83,7 @@ The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi
171
83
  - [⚙️Metrics CLI](#metrics)
172
84
  - [👋Contribute](#contribute)
173
85
  - [©️License](#license)
86
+ - [©️Citations](#citations)
174
87
 
175
88
  ## ⚙️Installation
176
89
 
@@ -207,6 +120,32 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
207
120
 
208
121
  ![](https://github.com/vipshop/cache-dit/raw/main/assets/dbcache-v1.png)
209
122
 
123
+
124
+ **DBCache**: **Dual Block Caching** for Diffusion Transformers. We have enhanced `FBCache` into a more general and customizable cache algorithm, namely `DBCache`, enabling it to achieve fully `UNet-style` cache acceleration for DiT models. Different configurations of compute blocks (**F8B12**, etc.) can be customized in DBCache. Moreover, it can be entirely **training**-**free**. DBCache can strike a perfect **balance** between performance and precision!
125
+
126
+ <div align="center">
127
+ <p align="center">
128
+ DBCache, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
129
+ </p>
130
+ </div>
131
+
132
+ |Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
133
+ |:---:|:---:|:---:|:---:|:---:|:---:|
134
+ |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
135
+ |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
136
+ |**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.12)**|**F8B16 (0.20)**|**F8B20 (0.20)**|
137
+ |27.85s|6.04s|5.88s|5.77s|6.01s|6.20s|
138
+ |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_NONE_R0.08.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F1B0_R0.08.png width=105px> |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B8_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B12_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B16_R0.2.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B20_R0.2.png width=105px>|
139
+
140
+ <div align="center">
141
+ <p align="center">
142
+ DBCache, <b> L20x4 </b>, Steps: 20, case to show the texture recovery ability of DBCache
143
+ </p>
144
+ </div>
145
+
146
+ These case studies demonstrate that even with relatively high thresholds (such as 0.12, 0.15, 0.2, etc.) under the DBCache **F12B12** or **F8B16** configuration, the detailed texture of the kitten's fur, colored cloth, and the clarity of text can still be preserved. This suggests that users can leverage DBCache to effectively balance performance and precision in their workflows!
147
+
148
+
210
149
  **DBCache** provides configurable parameters for custom optimization, enabling a balanced trade-off between performance and precision:
211
150
 
212
151
  - **Fn**: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
@@ -258,17 +197,6 @@ cache_options = {
258
197
  }
259
198
  ```
260
199
 
261
- <div align="center">
262
- <p align="center">
263
- DBCache, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
264
- </p>
265
- </div>
266
-
267
- |Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
268
- |:---:|:---:|:---:|:---:|:---:|:---:|
269
- |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
270
- |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
271
-
272
200
  ## 🔥Hybrid TaylorSeer
273
201
 
274
202
  <div id="taylorseer"></div>
@@ -489,6 +417,18 @@ Then, run the python test script with `torchrun`:
489
417
  torchrun --nproc_per_node=4 parallel_cache.py
490
418
  ```
491
419
 
420
+ <div align="center">
421
+ <p align="center">
422
+ DBPrune + <b>torch.compile + context parallelism</b> <br>Steps: 28, "A cat holding a sign that says hello world with complex background"
423
+ </p>
424
+ </div>
425
+
426
+ |Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
427
+ |:---:|:---:|:---:|:---:|:---:|:---:|
428
+ |+compile:20.43s|16.25s|14.12s|13.41s|12.00s|8.86s|
429
+ |+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
430
+ |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
431
+
492
432
  ## 🔥Torch Compile
493
433
 
494
434
  <div id="compile"></div>
@@ -550,5 +490,18 @@ How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](
550
490
 
551
491
  <div id="license"></div>
552
492
 
493
+ The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work! We have followed the original License from [FBCache](https://github.com/chengzeyi/ParaAttention); please check [LICENSE](https://github.com/vipshop/cache-dit/raw/main/LICENSE) for more details.
553
494
 
554
- We have followed the original License from [ParaAttention](https://github.com/chengzeyi/ParaAttention), please check [LICENSE](https://github.com/vipshop/cache-dit/raw/main/LICENSE) for more details.
495
+ ## ©️Citations
496
+
497
+ <div id="citations"></div>
498
+
499
+ ```BibTeX
500
+ @misc{CacheDiT@2025,
501
+ title={CacheDiT: A Training-free and Easy-to-use cache acceleration Toolbox for Diffusion Transformers},
502
+ url={https://github.com/vipshop/cache-dit.git},
503
+ note={Open-source software available at https://github.com/vipshop/cache-dit.git},
504
+ author={vipshop.com},
505
+ year={2025}
506
+ }
507
+ ```
@@ -1,13 +1,14 @@
1
- cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- cache_dit/_version.py,sha256=7CFcHKqzy7OglwTX58ipGg1TD8MpDORqWvEBO3W1dHI,513
1
+ cache_dit/__init__.py,sha256=0-B173-fLi3IA8nJXoS71zK0zD33Xplysd9skmLfEOY,171
2
+ cache_dit/_version.py,sha256=ut2sCt69XoYh0A1_KAmfCg1IKkN6zwqhu2eMFWAhMbQ,513
3
3
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
4
4
  cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
5
+ cache_dit/utils.py,sha256=4cFNh0asch6Zgsixq0bS1ElfwBu_6BG5ZSmaa1khjyg,144
5
6
  cache_dit/cache_factory/__init__.py,sha256=iYQwLwB_XLoYl0OB9unZGDbBtrYvZaLkOAmhGRwdW2E,191
6
7
  cache_dit/cache_factory/adapters.py,sha256=QMCaXnmqM7NT7sx4bCF1mMLn-QcXX9h1RmgLAypDedg,5256
7
8
  cache_dit/cache_factory/taylorseer.py,sha256=LKSNo2ode69EVo9xrxjxAMEjz0yDGiGADeDYnEqddA8,3987
8
9
  cache_dit/cache_factory/utils.py,sha256=V-Mb5Jn07geEUUWo4QAfh6pmSzkL-2OGDn0VAXbG6hQ,1799
9
10
  cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
- cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=itVEb6gT2eZuncAHUmP51ZS0r6v6cGtRvnPjyeXqKH8,71156
11
+ cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=sJ9yxQlcrX4qkPln94FrL0WDe2WIn3_UD2-Mk8YtjSw,73301
11
12
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=krNAICf-aS3JLmSG8vOB9tpLa04uYRcABsC8PMbVUKY,1870
12
13
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=3xUjvDzor9AkBkDUc0N7kZqM86MIdajuigesnicNzXE,2260
13
14
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=cIsov6Pf0dRyddqkzTA2CU-jSDotof8LQr-HIoY9T9M,2615
@@ -15,7 +16,7 @@ cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha
15
16
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=8W9m-WeEVE2ytYi9udKEA8Wtb0EnvP3eT2A1Tu-d29k,2252
16
17
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=EREHM5E1wxnL-uRXRAEege4HXraRp1oD_r1Zx4CsiKk,2596
17
18
  cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=so1wGdb8W0ATwrjv7E5IEZLPcobybaY1HJa6hBYlOOQ,34698
19
+ cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=1qarKAsEFiaaN2_ghko2dqGz_R7BTQSOyGtb_eQq38Y,35716
19
20
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=hVBTXj9MMGFGVezT3j8MntFRBiphSaUL4YhSOd8JtuY,1870
20
21
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py,sha256=KP8NxtHAKzzBOoX0lhvlMgY_5dmP4Z3T5TOfwl4SSyg,2273
21
22
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py,sha256=kCB7lL4OIq8TZn-baMIF8D_PVPTFW60omCMVQCb8ebs,2628
@@ -23,7 +24,7 @@ cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py,
23
24
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py,sha256=zXgoRDDjus3a2WSjtNh4ERtQp20ceb6nzohHMDlo2zY,2265
24
25
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py,sha256=PA7nuLgfAelnaI8usQx0Kxi8XATzMapyR1WndEdFoZA,2604
25
26
  cache_dit/cache_factory/first_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
- cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=tTPwhPLEA7LqGupps1Zy2MycCtLzs22wsW0yUhiiF-U,23217
27
+ cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=qn4zWJ_eEMIPYzrxXoslunxbzK0WueuNtC54Pp5Q57k,23241
27
28
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py,sha256=-FFgA2MoudEo7uDacg4aWgm1KwfLZFsEDTVxatgbq9M,2146
28
29
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py,sha256=qO5CWyurtwW30mvOe6cxeQPTSXLDlPJcezm72zEjDq8,2375
29
30
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py,sha256=Dcd4OzABCtyQCZNX2KNnUTdVoO1E1ApM7P8gcVYzcK0,2733
@@ -38,10 +39,11 @@ cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE
38
39
  cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
39
40
  cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
40
41
  cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
41
- cache_dit/metrics/metrics.py,sha256=1TTbfaj_-vdUfxopLnc5kVrXs5rMpAoSi8D0ItYdPu8,26439
42
- cache_dit-0.2.12.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
43
- cache_dit-0.2.12.dist-info/METADATA,sha256=-AIWGVOFsY-nhMkDeFErUFcELTWmza96-0IUN3od88A,28219
44
- cache_dit-0.2.12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
45
- cache_dit-0.2.12.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
46
- cache_dit-0.2.12.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
47
- cache_dit-0.2.12.dist-info/RECORD,,
42
+ cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
43
+ cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
44
+ cache_dit-0.2.14.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
45
+ cache_dit-0.2.14.dist-info/METADATA,sha256=EyZN75JcVcvTc5bopHXfl6w-nA-ro9Uit2Sjy5DU66A,25198
46
+ cache_dit-0.2.14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
47
+ cache_dit-0.2.14.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
48
+ cache_dit-0.2.14.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
49
+ cache_dit-0.2.14.dist-info/RECORD,,