cache-dit 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.

Files changed (27)
  1. cache_dit/_version.py +2 -2
  2. cache_dit/cache_factory/dual_block_cache/cache_context.py +282 -57
  3. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +0 -2
  4. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -2
  5. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -2
  6. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +0 -1
  7. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -1
  8. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +1 -3
  9. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +0 -2
  10. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -2
  11. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -2
  12. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +0 -1
  13. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -2
  14. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +0 -2
  15. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +23 -23
  16. cache_dit/cache_factory/first_block_cache/cache_context.py +3 -11
  17. cache_dit/cache_factory/taylorseer.py +29 -0
  18. cache_dit/compile/__init__.py +1 -0
  19. cache_dit/compile/utils.py +94 -0
  20. cache_dit/custom_ops/__init__.py +0 -0
  21. cache_dit/custom_ops/triton_taylorseer.py +0 -0
  22. cache_dit/logger.py +28 -0
  23. {cache_dit-0.2.1.dist-info → cache_dit-0.2.3.dist-info}/METADATA +76 -39
  24. {cache_dit-0.2.1.dist-info → cache_dit-0.2.3.dist-info}/RECORD +27 -23
  25. {cache_dit-0.2.1.dist-info → cache_dit-0.2.3.dist-info}/WHEEL +0 -0
  26. {cache_dit-0.2.1.dist-info → cache_dit-0.2.3.dist-info}/licenses/LICENSE +0 -0
  27. {cache_dit-0.2.1.dist-info → cache_dit-0.2.3.dist-info}/top_level.txt +0 -0
@@ -55,6 +55,7 @@ class DBPPruneContext:
         default_factory=lambda: defaultdict(list),
     )
 
+    @torch.compiler.disable
     def get_residual_diff_threshold(self):
         residual_diff_threshold = self.residual_diff_threshold
         if self.l1_hidden_states_diff_threshold is not None:
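The @torch.compiler.disable decorators added throughout this class keep its Python-side bookkeeping (step counters, buffers, recorded residual diffs) out of any torch.compile graph, so the compiled transformer does not trace or guard on this mutable state. A minimal sketch of the same pattern, with illustrative names that are not part of cache-dit:

```python
import torch


class StepTracker:
    """Host-side bookkeeping that should stay outside compiled graphs."""

    def __init__(self):
        self.executed_steps = 0

    @torch.compiler.disable  # always run eagerly, even under torch.compile
    def mark_step_begin(self):
        self.executed_steps += 1


tracker = StepTracker()


@torch.compile
def forward(x: torch.Tensor) -> torch.Tensor:
    tracker.mark_step_begin()  # graph break: runs in eager mode
    return x * 2.0
```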
@@ -98,19 +99,24 @@ class DBPPruneContext:
         )
         return residual_diff_threshold
 
+    @torch.compiler.disable
     def get_buffer(self, name):
         return self.buffers.get(name)
 
+    @torch.compiler.disable
     def set_buffer(self, name, buffer):
         self.buffers[name] = buffer
 
+    @torch.compiler.disable
     def remove_buffer(self, name):
         if name in self.buffers:
             del self.buffers[name]
 
+    @torch.compiler.disable
     def clear_buffers(self):
         self.buffers.clear()
 
+    @torch.compiler.disable
     def mark_step_begin(self):
         self.executed_steps += 1
         if self.get_current_step() == 0:
@@ -118,12 +124,15 @@ class DBPPruneContext:
             self.actual_blocks.clear()
             self.residual_diffs.clear()
 
+    @torch.compiler.disable
     def add_pruned_block(self, num_blocks):
         self.pruned_blocks.append(num_blocks)
 
+    @torch.compiler.disable
     def add_actual_block(self, num_blocks):
         self.actual_blocks.append(num_blocks)
 
+    @torch.compiler.disable
     def add_residual_diff(self, diff):
         if isinstance(diff, torch.Tensor):
             diff = diff.item()
@@ -141,9 +150,11 @@ class DBPPruneContext:
             f"residual diff: {diff:.6f}"
         )
 
+    @torch.compiler.disable
     def get_current_step(self):
         return self.executed_steps - 1
 
+    @torch.compiler.disable
     def is_in_warmup(self):
         return self.get_current_step() < self.warmup_steps
 
@@ -348,11 +359,19 @@ def collect_prune_kwargs(default_attrs: dict, **kwargs):
         )
         for attr in prune_attrs
     }
+
     # Manually set sequence fields, such as non_prune_blocks_ids
-    prune_kwargs["non_prune_blocks_ids"] = kwargs.pop(
-        "non_prune_blocks_ids",
-        [],
-    )
+    def _safe_set_sequence_field(
+        field_name: str,
+        default_value: Any = None,
+    ):
+        if field_name not in prune_kwargs:
+            prune_kwargs[field_name] = kwargs.pop(
+                field_name,
+                default_value,
+            )
+
+    _safe_set_sequence_field("non_prune_blocks_ids", [])
 
     assert default_attrs is not None, "default_attrs must be set before"
     for attr in prune_attrs:
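The new _safe_set_sequence_field helper appears intended to populate a sequence-valued option only when it has not already been collected into prune_kwargs, so an earlier value is not silently overwritten. A standalone sketch of that guard, with hypothetical names:

```python
from typing import Any


def safe_set_field(dst: dict, src: dict, key: str, default: Any = None) -> None:
    # Only consume the key from src when dst has no value for it yet,
    # so an already-collected setting is not clobbered.
    if key not in dst:
        dst[key] = src.pop(key, default)


prune_kwargs = {"non_prune_blocks_ids": [0, 1]}           # collected earlier
user_kwargs = {"non_prune_blocks_ids": [5], "other": 1}   # caller overrides
safe_set_field(prune_kwargs, user_kwargs, "non_prune_blocks_ids", [])
assert prune_kwargs["non_prune_blocks_ids"] == [0, 1]     # earlier value kept
```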
@@ -627,10 +646,6 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
         ]
         return sorted(non_prune_blocks_ids)
 
-    # @torch.compile(dynamic=True)
-    # mark this function as compile with dynamic=True will
-    # cause precision degradate, so, we choose to disable it
-    # now, until we find a better solution or fixed the bug.
     @torch.compiler.disable
     def _compute_single_hidden_states_residual(
         self,
@@ -667,10 +682,6 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
             single_encoder_hidden_states_residual,
         )
 
-    # @torch.compile(dynamic=True)
-    # mark this function as compile with dynamic=True will
-    # cause precision degradate, so, we choose to disable it
-    # now, until we find a better solution or fixed the bug.
     @torch.compiler.disable
     def _split_single_hidden_states(
         self,
@@ -969,17 +980,6 @@ def patch_pruned_stats(
     if transformer is None:
         return
 
-    pruned_transformer_blocks = getattr(transformer, "transformer_blocks", None)
-    if pruned_transformer_blocks is None:
-        return
-
-    if isinstance(pruned_transformer_blocks, torch.nn.ModuleList):
-        pruned_transformer_blocks = pruned_transformer_blocks[0]
-    if not isinstance(
-        pruned_transformer_blocks, DBPrunedTransformerBlocks
-    ) or not isinstance(transformer, torch.nn.Module):
-        return
-
     # TODO: Patch more pruned stats to the transformer
     transformer._pruned_blocks = get_pruned_blocks()
     transformer._pruned_steps = get_pruned_steps()
@@ -370,6 +370,9 @@ def apply_prev_hidden_states_residual(
         hidden_states = hidden_states_residual + hidden_states
 
         hidden_states = hidden_states.contiguous()
+        # NOTE: We should also support taylorseer for
+        # encoder_hidden_states approximation. Please
+        # use DBCache instead.
     else:
         hidden_states_residual = get_hidden_states_residual()
         assert (
@@ -711,17 +714,6 @@ def patch_cached_stats(
     if transformer is None:
         return
 
-    cached_transformer_blocks = getattr(transformer, "transformer_blocks", None)
-    if cached_transformer_blocks is None:
-        return
-
-    if isinstance(cached_transformer_blocks, torch.nn.ModuleList):
-        cached_transformer_blocks = cached_transformer_blocks[0]
-    if not isinstance(
-        cached_transformer_blocks, CachedTransformerBlocks
-    ) or not isinstance(transformer, torch.nn.Module):
-        return
-
     # TODO: Patch more cached stats to the transformer
     transformer._cached_steps = get_cached_steps()
     transformer._residual_diffs = get_residual_diffs()
@@ -1,5 +1,7 @@
 # Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/taylorseer.py
+# Reference: https://github.com/Shenyi-Z/TaylorSeer/TaylorSeer-FLUX/src/flux/taylor_utils/__init__.py
 import math
+import torch
 
 
 class TaylorSeer:
@@ -17,6 +19,7 @@ class TaylorSeer:
         self.compute_step_map = compute_step_map
         self.reset_cache()
 
+    @torch.compiler.disable
     def reset_cache(self):
         self.state = {
             "dY_prev": [None] * self.ORDER,
@@ -25,6 +28,7 @@ class TaylorSeer:
         self.current_step = -1
         self.last_non_approximated_step = -1
 
+    @torch.compiler.disable
     def should_compute_full(self, step=None):
         step = self.current_step if step is None else step
         if self.compute_step_map is not None:
@@ -36,7 +40,13 @@ class TaylorSeer:
             return True
         return False
 
+    @torch.compiler.disable
     def approximate_derivative(self, Y):
+        # n-th order Taylor expansion:
+        # Y(t) = Y(0) + dY(0)/dt * t + d^2Y(0)/dt^2 * t^2 / 2!
+        #        + ... + d^nY(0)/dt^n * t^n / n!
+        # TODO: Custom Triton/CUDA kernel for better performance,
+        # especially for large n_derivatives.
         dY_current = [None] * self.ORDER
         dY_current[0] = Y
         window = self.current_step - self.last_non_approximated_step
@@ -49,7 +59,10 @@ class TaylorSeer:
                 break
         return dY_current
 
+    @torch.compiler.disable
     def approximate_value(self):
+        # TODO: Custom Triton/CUDA kernel for better performance,
+        # especially for large n_derivatives.
         elapsed = self.current_step - self.last_non_approximated_step
         output = 0
         for i, derivative in enumerate(self.state["dY_current"]):
@@ -59,14 +72,30 @@ class TaylorSeer:
                 break
         return output
 
+    @torch.compiler.disable
     def mark_step_begin(self):
         self.current_step += 1
 
+    @torch.compiler.disable
     def update(self, Y):
+        # Directly call this method will ingnore the warmup
+        # policy and force full computation.
+        # Assume warmup steps is 3, and n_derivatives is 3.
+        # step 0: dY_prev    = [None, None,   None,    None   ]
+        #         dY_current = [Y0,   None,   None,    None   ]
+        # step 1: dY_prev    = [Y0,   None,   None,    None   ]
+        #         dY_current = [Y1,   dY1,    None,    None   ]
+        # step 2: dY_prev    = [Y1,   dY1,    None,    None   ]
+        #         dY_current = [Y2,   dY2/Y1, dY2/dY1, None   ]
+        # step 3: dY_prev    = [Y2,   dY2/Y1, dY2/dY1, None   ],
+        #         dY_current = [Y3,   dY3/Y2, dY3/dY2, dY3/dY1]
+        # step 4: dY_prev    = [Y3,   dY3/Y2, dY3/dY2, dY3/dY1]
+        #         dY_current = [Y4,   dY4/Y3, dY4/dY3, dY4/dY2]
         self.state["dY_prev"] = self.state["dY_current"]
        self.state["dY_current"] = self.approximate_derivative(Y)
         self.last_non_approximated_step = self.current_step
 
+    @torch.compiler.disable
     def step(self, Y):
         self.mark_step_begin()
         if self.should_compute_full():
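The comments added above describe the idea behind TaylorSeer: estimate higher-order derivatives of a cached feature with stacked finite differences, then extrapolate future steps with a truncated Taylor series. A self-contained sketch of that arithmetic, written independently of the class above (not the package's code):

```python
import torch


def update_derivatives(dY_prev, Y, window):
    # Finite-difference estimates of higher-order derivatives, stacked the way
    # the step-by-step comments above describe: dY[0] is the fresh value and
    # dY[i + 1] = (dY[i] - dY_prev[i]) / window.
    dY = [None] * len(dY_prev)
    dY[0] = Y
    for i in range(len(dY_prev) - 1):
        if dY_prev[i] is None or window <= 0:
            break
        dY[i + 1] = (dY[i] - dY_prev[i]) / window
    return dY


def taylor_predict(dY, elapsed):
    # Truncated Taylor expansion around the last fully computed step:
    # Y(t0 + k) ~= sum_i dY[i] * k^i / i!
    output = torch.zeros_like(dY[0])
    factorial = 1
    for i, derivative in enumerate(dY):
        if derivative is None:
            break
        if i > 0:
            factorial *= i
        output = output + derivative * (elapsed ** i) / factorial
    return output


# Example: a linearly drifting feature is forecast exactly.
feats = [torch.full((2, 2), 3.0 * t) for t in range(4)]
dY = [None, None, None]
for Y in feats[:3]:
    dY = update_derivatives(dY, Y, window=1)
pred = taylor_predict(dY, elapsed=1)  # forecast step 3 from step 2
assert torch.allclose(pred, feats[3])
```

For real diffusion features the forecast is only approximate and degrades as the elapsed gap grows, which is what the warmup and full-compute policy in should_compute_full guard against.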
@@ -0,0 +1 @@
+from cache_dit.compile.utils import set_custom_compile_configs
@@ -0,0 +1,94 @@
+import os
+
+import torch
+from cache_dit.logger import init_logger, logging_rank_0
+
+logger = init_logger(__name__)
+
+
+def epilogue_prologue_fusion_enabled(**kwargs) -> bool:
+    mode = kwargs.get("epilogue_prologue_fusion", False)
+    CACHE_DIT_EPILOGUE_PROLOGUE_FUSION = bool(
+        int(os.environ.get("CACHE_DIT_EPILOGUE_PROLOGUE_FUSION", "0"))
+    )
+
+    if CACHE_DIT_EPILOGUE_PROLOGUE_FUSION:
+        logging_rank_0(
+            logger,
+            "CACHE_DIT_EPILOGUE_PROLOGUE_FUSION is set to 1. \n"
+            "Force enable epilogue and prologue fusion.",
+        )
+
+    return CACHE_DIT_EPILOGUE_PROLOGUE_FUSION or mode
+
+
+def set_custom_compile_configs(
+    cuda_graphs: bool = False,
+    force_disable_compile_caches: bool = False,
+    use_fast_math: bool = False,
+    **kwargs,  # other kwargs
+):
+    # Alway increase recompile_limit for dynamic shape compilation
+    torch._dynamo.config.recompile_limit = 96  # default is 8
+    torch._dynamo.config.accumulated_recompile_limit = 2048  # default is 256
+    # Handle compiler caches
+    # https://github.com/vllm-project/vllm/blob/23baa2180b0ebba5ae94073ba9b8e93f88b75486/vllm/compilation/compiler_interface.py#L270
+    torch._inductor.config.fx_graph_cache = True
+    torch._inductor.config.fx_graph_remote_cache = False
+    # https://github.com/pytorch/pytorch/issues/153791
+    torch._inductor.config.autotune_local_cache = False
+
+    FORCE_DISABLE_CUSTOM_COMPILE_CONFIG = (
+        os.environ.get("CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG", "0")
+        == "1"
+    )
+    if FORCE_DISABLE_CUSTOM_COMPILE_CONFIG:
+        logging_rank_0(
+            logger,
+            "CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG is set to 1. \n"
+            "Force disable custom compile config.",
+        )
+        return
+
+    # Enable compute comm overlap
+    torch._inductor.config.reorder_for_compute_comm_overlap = True
+    # L20 64 GB/s, PCIe; A100/A800 NVLink 300 GB/s.
+    torch._inductor.config.intra_node_bw = (
+        64 if "L20" in torch.cuda.get_device_name() else 300
+    )
+
+    # Below are default settings for torch.compile, you can change
+    # them to your needs and test the performance
+    torch._inductor.config.max_fusion_size = 64
+    torch._inductor.config.max_pointwise_cat_inputs = 8
+    torch._inductor.config.triton.cudagraphs = cuda_graphs
+    torch._inductor.config.triton.use_block_ptr = False
+    torch._inductor.config.triton.codegen_upcast_to_fp32 = True
+
+    # Copy from https://pytorch.org/blog/accelerating-generative-ai-3/
+    torch._inductor.config.conv_1x1_as_mm = True
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.coordinate_descent_check_all_directions = True
+    torch._inductor.config.epilogue_fusion = False
+
+    # Enable epilogue and prologue fusion
+    if epilogue_prologue_fusion_enabled(**kwargs):
+        torch._inductor.config.epilogue_fusion = True
+        torch._inductor.config.prologue_fusion = True
+        torch._inductor.config.epilogue_fusion_first = True
+
+    # Dead code elimination
+    torch._inductor.config.dce = True  # default is False
+
+    # May need to force disable all cache
+    if force_disable_compile_caches:
+        torch._inductor.config.force_disable_caches = True
+        torch._inductor.config.fx_graph_cache = False
+        torch._inductor.config.fx_graph_remote_cache = False
+        torch._inductor.config.autotune_local_cache = False  # default is True
+
+    # Use fast math
+    if hasattr(torch._inductor.config, "use_fast_math"):
+        torch._inductor.config.use_fast_math = use_fast_math
+    if hasattr(torch._inductor.config, "cuda.use_fast_math"):
+        torch._inductor.config.cuda.use_fast_math = use_fast_math
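set_custom_compile_configs() is re-exported from cache_dit.compile (see the one-line __init__.py above) and is meant to be called before compiling a model. A hedged usage sketch; the diffusers pipeline and model id are illustrative, and only behavior visible in utils.py above is assumed:

```python
import torch
from diffusers import FluxPipeline
from cache_dit.compile import set_custom_compile_configs

# Tune dynamo/inductor settings before compiling the transformer.
set_custom_compile_configs(cuda_graphs=False)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.transformer = torch.compile(pipe.transformer, dynamic=True)
```

Setting CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG=1 makes the call return early, leaving only the recompile-limit and graph-cache settings applied, as shown in the function body above.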
cache_dit/logger.py CHANGED
@@ -1,6 +1,7 @@
 import logging
 import os
 import sys
+import torch.distributed as dist
 
 _FORMAT = "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s"
 _DATE_FORMAT = "%m-%d %H:%M:%S"
@@ -95,3 +96,30 @@ def init_logger(name: str):
         logger.addHandler(_inference_log_file_handler[pid])
         logger.propagate = False
     return logger
+
+
+def logging_rank_0(
+    logger: logging.Logger, message: str, level: int = logging.INFO
+):
+    if not isinstance(logger, logging.Logger):
+        raise TypeError("logger must be an instance of logging.Logger")
+    if not isinstance(message, str):
+        raise TypeError("message must be a string")
+    if not isinstance(level, int):
+        raise TypeError("level must be an integer representing a logging level")
+
+    def _logging_msg():
+        if level == logging.DEBUG:
+            logger.debug(message)
+        elif level == logging.WARNING:
+            logger.warning(message)
+        elif level == logging.ERROR:
+            logger.error(message)
+        else:
+            logger.info(message)
+
+    if dist.is_initialized():
+        if dist.get_rank() == 0:
+            _logging_msg()
+    else:
+        _logging_msg()
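logging_rank_0 emits a record only from rank 0 when torch.distributed has been initialized, and logs normally otherwise. A short usage sketch using only the names defined in this file:

```python
import logging
import torch.distributed as dist
from cache_dit.logger import init_logger, logging_rank_0

logger = init_logger(__name__)

# Without a process group, this behaves like a plain logger.info call.
logging_rank_0(logger, "single-process message")

# With torch.distributed initialized, only rank 0 emits the record.
if dist.is_available() and dist.is_initialized():
    logging_rank_0(logger, "printed once across all ranks", level=logging.WARNING)
```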
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.1
+Version: 0.2.3
 Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -37,31 +37,31 @@ Dynamic: requires-python
 <p align="center">
 <h2>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h2>
 </p>
-<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit.png >
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit-v1.png >
 <div align='center'>
 <img src=https://img.shields.io/badge/Language-Python-brightgreen.svg >
 <img src=https://img.shields.io/badge/PRs-welcome-9cf.svg >
 <img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
 <img src=https://static.pepy.tech/badge/cache-dit >
 <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
-<img src=https://img.shields.io/badge/Release-v0.2.1-brightgreen.svg >
+<img src=https://img.shields.io/badge/Release-v0.2.2-brightgreen.svg >
 </div>
 <p align="center">
-DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>offers a set of training-free cache accelerators for DiT: 🔥DBCache, DBPrune, FBCache, etc🔥
+DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT offers <br>a set of training-free cache accelerators for DiT: <b>🔥<a href="#dbcache">DBCache</a>, <a href="#dbprune">DBPrune</a>, <a href="#taylorseer">TaylorSeer</a>, <a href="#fbcache">FBCache</a></b>, etc🔥
 </p>
 </div>
 
-## 👋 Highlight
-
-<div id="reference"></div>
-
-The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work! The **FBCache** support for Mochi, FLUX.1, CogVideoX, Wan2.1, and HunyuanVideo is directly adapted from the original [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache).
+<div align="center">
+<p align="center">
+<b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
+</p>
+</div>
 
 ## 🤗 Introduction
 
 <div align="center">
 <p align="center">
-<h3>🔥 DBCache: Dual Block Caching for Diffusion Transformers</h3>
+<h3>🔥DBCache: Dual Block Caching for Diffusion Transformers</h3>
 </p>
 </div>
 
@@ -77,9 +77,9 @@ The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi
 |:---:|:---:|:---:|:---:|:---:|:---:|
 |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
 |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
-|**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.20)**|**F8B16 (0.20)**|**F8B20 (0.20)**|
+|**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.12)**|**F8B16 (0.20)**|**F8B20 (0.20)**|
 |27.85s|6.04s|5.88s|5.77s|6.01s|6.20s|
-|<img src=https://github.com/user-attachments/assets/70ea57f4-d8f2-415b-8a96-d8315974a5e6 width=105px>|<img src=https://github.com/user-attachments/assets/fc0e1a67-19cc-44aa-bf50-04696e7978a0 width=105px> |<img src=https://github.com/user-attachments/assets/d1434896-628c-436b-95ad-43c085a8629e width=105px>|<img src=https://github.com/user-attachments/assets/aaa42cd2-57de-4c4e-8bfb-913018a8251d width=105px>|<img src=https://github.com/user-attachments/assets/dc0ba2a4-ef7c-436d-8a39-67055deab92f width=105px>|<img src=https://github.com/user-attachments/assets/aede466f-61ed-4256-8df0-fecf8020c5ca width=105px>|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_NONE_R0.08.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F1B0_R0.08.png width=105px> |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B8_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B12_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B16_R0.2.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B20_R0.2.png width=105px>|
 
 <div align="center">
 <p align="center">
@@ -91,7 +91,7 @@ These case studies demonstrate that even with relatively high thresholds (such a
 
 <div align="center">
 <p align="center">
-<h3>🔥 DBPrune: Dynamic Block Prune with Residual Caching</h3>
+<h3>🔥DBPrune: Dynamic Block Prune with Residual Caching</h3>
 </p>
 </div>
 
@@ -110,11 +110,11 @@ These case studies demonstrate that even with relatively high thresholds (such a
 
 <div align="center">
 <p align="center">
-<h3>🔥 Context Parallelism and Torch Compile</h3>
+<h3>🔥Context Parallelism and Torch Compile</h3>
 </p>
 </div>
 
-Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. By the way, CacheDiT is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance.
+Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. CacheDiT is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance.
 
 <div align="center">
 <p align="center">
@@ -128,12 +128,6 @@ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand w
 |+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
 |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
 
-<div align="center">
-<p align="center">
-<b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
-</p>
-</div>
-
 ## ©️Citations
 
 ```BibTeX
@@ -146,6 +140,12 @@ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand w
 }
 ```
 
+## 👋Reference
+
+<div id="reference"></div>
+
+The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work!
+
 ## 📖Contents
 
 <div id="contents"></div>
@@ -153,6 +153,7 @@ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand w
 - [⚙️Installation](#️installation)
 - [🔥Supported Models](#supported)
 - [⚡️Dual Block Cache](#dbcache)
+- [🔥Hybrid TaylorSeer](#taylorseer)
 - [🎉First Block Cache](#fbcache)
 - [⚡️Dynamic Block Prune](#dbprune)
 - [🎉Context Parallelism](#context-parallelism)
@@ -187,28 +188,19 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
 - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
 
 
-<!--
-<p align="center">
-<h4> 🔥Supported Models🔥</h4>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀FLUX.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Mochi</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX1.5</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Wan2.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀HunyuanVideo</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-</p>
--->
-
 ## ⚡️DBCache: Dual Block Cache
 
 <div id="dbcache"></div>
 
-![](https://github.com/user-attachments/assets/c2a382b9-0ccd-46f4-aacc-87857b4a4de8)
+![](https://github.com/vipshop/cache-dit/raw/main/assets/dbcache-v1.png)
 
 **DBCache** provides configurable parameters for custom optimization, enabling a balanced trade-off between performance and precision:
 
 - **Fn**: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
 - **Bn**: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
+
+![](https://github.com/vipshop/cache-dit/raw/main/assets/dbcache-fnbn-v1.png)
+
 - **warmup_steps**: (default: 0) DBCache does not apply the caching strategy when the number of running steps is less than or equal to this value, ensuring the model sufficiently learns basic features during warmup.
 - **max_cached_steps**: (default: -1) DBCache disables the caching strategy when the previous cached steps exceed this value to prevent precision degradation.
 - **residual_diff_threshold**: The value of residual diff threshold, a higher value leads to faster performance at the cost of lower precision.
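A minimal sketch of wiring the parameters above into a pipeline with the apply_cache_on_pipe and CacheType imports that appear later in this README; the cache-type key and enum member are assumptions, and the Fn/Bn key names are not visible in this hunk, so treat this as an outline rather than the documented configuration:

```python
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

cache_options = {
    "cache_type": CacheType.DBCache,   # assumption: exact key/enum name may differ
    "warmup_steps": 0,                 # default: no warmup
    "max_cached_steps": -1,            # default: never disable caching
    "residual_diff_threshold": 0.12,   # higher = faster, lower precision
}
apply_cache_on_pipe(pipe, **cache_options)
```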
@@ -264,11 +256,54 @@ cache_options = {
 |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
 |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
 
+## 🔥Hybrid TaylorSeer
+
+<div id="taylorseer"></div>
+
+We have supported the [TaylorSeers: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers](https://arxiv.org/pdf/2503.06923) algorithm to further improve the precision of DBCache in cases where the cached steps are large, namely, **Hybrid TaylorSeer + DBCache**. At timesteps with significant intervals, the feature similarity in diffusion models decreases substantially, significantly harming the generation quality.
+
+$$
+\mathcal{F}\_{\text {pred }, m}\left(x_{t-k}^l\right)=\mathcal{F}\left(x_t^l\right)+\sum_{i=1}^m \frac{\Delta^i \mathcal{F}\left(x_t^l\right)}{i!\cdot N^i}(-k)^i
+$$
+
+**TaylorSeer** employs a differential method to approximate the higher-order derivatives of features and predict features in future timesteps with Taylor series expansion. The TaylorSeer implemented in CacheDiT supports both hidden states and residual cache types. That is $\mathcal{F}\_{\text {pred }, m}\left(x_{t-k}^l\right)$ can be a residual cache or a hidden-state cache.
+
+```python
+cache_options = {
+    # TaylorSeer options
+    "enable_taylorseer": True,
+    "enable_encoder_taylorseer": True,
+    # Taylorseer cache type cache be hidden_states or residual.
+    "taylorseer_cache_type": "residual",
+    # Higher values of n_derivatives will lead to longer
+    # computation time but may improve precision significantly.
+    "taylorseer_kwargs": {
+        "n_derivatives": 2,  # default is 2.
+    },
+    "warmup_steps": 3,  # n_derivatives + 1
+    "residual_diff_threshold": 0.12,
+}
+```
+
+> [!Important]
+> Please note that if you have used TaylorSeer as the calibrator for approximate hidden states, the **Bn** param of DBCache can be set to **0**. In essence, DBCache's Bn is also act as a calibrator, so you can choose either Bn > 0 or TaylorSeer. We recommend using the configuration scheme of **TaylorSeer** + **DBCache FnB0**.
+
+<div align="center">
+<p align="center">
+<b>DBCache F1B0 + TaylorSeer</b>, L20x1, Steps: 28, <br>"A cat holding a sign that says hello world with complex background"
+</p>
+</div>
+
+|Baseline(L20x1)|F1B0 (0.12)|+TaylorSeer|F1B0 (0.15)|+TaylorSeer|+compile|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|24.85s|12.85s|12.86s|10.27s|10.28s|8.48s|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.12_S14_T12.85s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.12_S14_T12.86s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.15_S17_T10.27s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T10.28s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T8.48s.png width=105px>|
+
 ## 🎉FBCache: First Block Cache
 
 <div id="fbcache"></div>
 
-![](https://github.com/user-attachments/assets/0fb66656-b711-457a-92a7-a830f134272d)
+![](https://github.com/vipshop/cache-dit/raw/main/assets/fbcache-v1.png)
 
 **DBCache** is a more general cache algorithm than **FBCache**. When Fn=1 and Bn=0, DBCache behaves identically to FBCache. Therefore, you can either use the original FBCache implementation directly or configure **DBCache** with **F1B0** settings to achieve the same functionality.
 
@@ -302,7 +337,7 @@ apply_cache_on_pipe(pipe, **cache_options)
 
 <div id="dbprune"></div>
 
-![](https://github.com/user-attachments/assets/932b6360-9533-4352-b176-4c4d84bd4695)
+![](https://github.com/vipshop/cache-dit/raw/main/assets/dbprune-v1.png)
 
 We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, which is referred to as **DBPrune**. DBPrune caches each block's hidden states and residuals, then dynamically prunes blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals. DBPrune is currently in the experimental phase, and we kindly invite you to stay tuned for upcoming updates.
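The pruning decision described here reduces to a relative L1 distance between the hidden states a block received at the previous step and the current ones; when the change falls below residual_diff_threshold the block is skipped and its cached residual is reused. A simplified sketch of that check (not the package's implementation):

```python
import torch


def can_prune(prev_hidden: torch.Tensor, curr_hidden: torch.Tensor, threshold: float) -> bool:
    # Relative mean L1 distance between the block inputs at step t-1 and t.
    diff = (curr_hidden - prev_hidden).abs().mean()
    norm = prev_hidden.abs().mean()
    return (diff / (norm + 1e-8)).item() < threshold


def maybe_skip_block(block, hidden, prev_hidden, cached_residual, threshold):
    if prev_hidden is not None and can_prune(prev_hidden, hidden, threshold):
        # Approximate the pruned block's output with its cached residual.
        return hidden + cached_residual
    return block(hidden)
```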
 
@@ -389,7 +424,7 @@ from para_attn.context_parallel import init_context_parallel_mesh
 from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
 from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
 
-# Init distributed process group
+# Init distributed process group 
 dist.init_process_group()
 torch.cuda.set_device(dist.get_rank())
 
@@ -436,14 +471,16 @@ torch._dynamo.config.recompile_limit = 96 # default is 8
 torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
 ```
 
+Please check [bench.py](./bench/bench.py) for more details.
+
 ## 👋Contribute
 <div id="contribute"></div>
 
-How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](./CONTRIBUTE.md).
+How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](https://github.com/vipshop/cache-dit/raw/main/CONTRIBUTE.md).
 
 ## ©️License
 
 <div id="license"></div>
 
 
-We have followed the original License from [ParaAttention](https://github.com/chengzeyi/ParaAttention), please check [LICENSE](./LICENSE) for more details.
+We have followed the original License from [ParaAttention](https://github.com/chengzeyi/ParaAttention), please check [LICENSE](https://github.com/vipshop/cache-dit/raw/main/LICENSE) for more details.