cache-dit 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff reflects the changes between publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/dual_block_cache/cache_context.py +282 -57
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +0 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +0 -1
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -1
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +1 -3
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +0 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +0 -1
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +0 -2
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +23 -23
- cache_dit/cache_factory/first_block_cache/cache_context.py +3 -11
- cache_dit/cache_factory/taylorseer.py +29 -0
- cache_dit/compile/__init__.py +1 -0
- cache_dit/compile/utils.py +94 -0
- cache_dit/custom_ops/__init__.py +0 -0
- cache_dit/custom_ops/triton_taylorseer.py +0 -0
- cache_dit/logger.py +28 -0
- {cache_dit-0.2.1.dist-info → cache_dit-0.2.3.dist-info}/METADATA +76 -39
- {cache_dit-0.2.1.dist-info → cache_dit-0.2.3.dist-info}/RECORD +27 -23
- {cache_dit-0.2.1.dist-info → cache_dit-0.2.3.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.1.dist-info → cache_dit-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.1.dist-info → cache_dit-0.2.3.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/dynamic_block_prune/prune_context.py CHANGED

@@ -55,6 +55,7 @@ class DBPPruneContext:
         default_factory=lambda: defaultdict(list),
     )

+    @torch.compiler.disable
     def get_residual_diff_threshold(self):
         residual_diff_threshold = self.residual_diff_threshold
         if self.l1_hidden_states_diff_threshold is not None:
@@ -98,19 +99,24 @@ class DBPPruneContext:
         )
         return residual_diff_threshold

+    @torch.compiler.disable
     def get_buffer(self, name):
         return self.buffers.get(name)

+    @torch.compiler.disable
     def set_buffer(self, name, buffer):
         self.buffers[name] = buffer

+    @torch.compiler.disable
     def remove_buffer(self, name):
         if name in self.buffers:
             del self.buffers[name]

+    @torch.compiler.disable
     def clear_buffers(self):
         self.buffers.clear()

+    @torch.compiler.disable
     def mark_step_begin(self):
         self.executed_steps += 1
         if self.get_current_step() == 0:
@@ -118,12 +124,15 @@ class DBPPruneContext:
             self.actual_blocks.clear()
             self.residual_diffs.clear()

+    @torch.compiler.disable
     def add_pruned_block(self, num_blocks):
         self.pruned_blocks.append(num_blocks)

+    @torch.compiler.disable
     def add_actual_block(self, num_blocks):
         self.actual_blocks.append(num_blocks)

+    @torch.compiler.disable
     def add_residual_diff(self, diff):
         if isinstance(diff, torch.Tensor):
             diff = diff.item()
@@ -141,9 +150,11 @@ class DBPPruneContext:
                 f"residual diff: {diff:.6f}"
             )

+    @torch.compiler.disable
     def get_current_step(self):
         return self.executed_steps - 1

+    @torch.compiler.disable
     def is_in_warmup(self):
         return self.get_current_step() < self.warmup_steps

@@ -348,11 +359,19 @@ def collect_prune_kwargs(default_attrs: dict, **kwargs):
         )
         for attr in prune_attrs
     }
+
     # Manually set sequence fields, such as non_prune_blocks_ids
-
-
-
-    )
+    def _safe_set_sequence_field(
+        field_name: str,
+        default_value: Any = None,
+    ):
+        if field_name not in prune_kwargs:
+            prune_kwargs[field_name] = kwargs.pop(
+                field_name,
+                default_value,
+            )
+
+    _safe_set_sequence_field("non_prune_blocks_ids", [])

     assert default_attrs is not None, "default_attrs must be set before"
     for attr in prune_attrs:
@@ -627,10 +646,6 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
         ]
         return sorted(non_prune_blocks_ids)

-    # @torch.compile(dynamic=True)
-    # mark this function as compile with dynamic=True will
-    # cause precision degradate, so, we choose to disable it
-    # now, until we find a better solution or fixed the bug.
     @torch.compiler.disable
     def _compute_single_hidden_states_residual(
         self,
@@ -667,10 +682,6 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
             single_encoder_hidden_states_residual,
         )

-    # @torch.compile(dynamic=True)
-    # mark this function as compile with dynamic=True will
-    # cause precision degradate, so, we choose to disable it
-    # now, until we find a better solution or fixed the bug.
     @torch.compiler.disable
     def _split_single_hidden_states(
         self,
@@ -969,17 +980,6 @@ def patch_pruned_stats(
     if transformer is None:
         return

-    pruned_transformer_blocks = getattr(transformer, "transformer_blocks", None)
-    if pruned_transformer_blocks is None:
-        return
-
-    if isinstance(pruned_transformer_blocks, torch.nn.ModuleList):
-        pruned_transformer_blocks = pruned_transformer_blocks[0]
-    if not isinstance(
-        pruned_transformer_blocks, DBPrunedTransformerBlocks
-    ) or not isinstance(transformer, torch.nn.Module):
-        return
-
     # TODO: Patch more pruned stats to the transformer
     transformer._pruned_blocks = get_pruned_blocks()
     transformer._pruned_steps = get_pruned_steps()

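Most of the churn above is the addition of `@torch.compiler.disable` to the context's per-step bookkeeping methods. A minimal sketch of that pattern on a recent PyTorch (a toy class, not cache-dit's actual API) shows why counter updates and `.item()` calls are kept out of the compiled graph:

```python
import torch


class StepStats:
    """Toy stand-in for the bookkeeping that DBPPruneContext keeps per step."""

    def __init__(self):
        self.executed_steps = 0
        self.residual_diffs = []

    # Keeping Python-level bookkeeping out of the traced graph avoids graph
    # breaks and recompiles inside the compiled transformer blocks.
    @torch.compiler.disable
    def mark_step_begin(self):
        self.executed_steps += 1

    @torch.compiler.disable
    def add_residual_diff(self, diff):
        if isinstance(diff, torch.Tensor):
            diff = diff.item()  # host sync is fine here; it is never traced
        self.residual_diffs.append(diff)


stats = StepStats()
stats.mark_step_begin()
stats.add_residual_diff(torch.tensor(0.07))
print(stats.executed_steps, len(stats.residual_diffs))  # 1 1
```
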
cache_dit/cache_factory/first_block_cache/cache_context.py CHANGED

@@ -370,6 +370,9 @@ def apply_prev_hidden_states_residual(
         hidden_states = hidden_states_residual + hidden_states

         hidden_states = hidden_states.contiguous()
+        # NOTE: We should also support taylorseer for
+        # encoder_hidden_states approximation. Please
+        # use DBCache instead.
     else:
         hidden_states_residual = get_hidden_states_residual()
         assert (
@@ -711,17 +714,6 @@ def patch_cached_stats(
     if transformer is None:
         return

-    cached_transformer_blocks = getattr(transformer, "transformer_blocks", None)
-    if cached_transformer_blocks is None:
-        return
-
-    if isinstance(cached_transformer_blocks, torch.nn.ModuleList):
-        cached_transformer_blocks = cached_transformer_blocks[0]
-    if not isinstance(
-        cached_transformer_blocks, CachedTransformerBlocks
-    ) or not isinstance(transformer, torch.nn.Module):
-        return
-
     # TODO: Patch more cached stats to the transformer
     transformer._cached_steps = get_cached_steps()
     transformer._residual_diffs = get_residual_diffs()

cache_dit/cache_factory/taylorseer.py CHANGED

@@ -1,5 +1,7 @@
 # Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/taylorseer.py
+# Reference: https://github.com/Shenyi-Z/TaylorSeer/TaylorSeer-FLUX/src/flux/taylor_utils/__init__.py
 import math
+import torch


 class TaylorSeer:
@@ -17,6 +19,7 @@ class TaylorSeer:
         self.compute_step_map = compute_step_map
         self.reset_cache()

+    @torch.compiler.disable
     def reset_cache(self):
         self.state = {
             "dY_prev": [None] * self.ORDER,
@@ -25,6 +28,7 @@ class TaylorSeer:
         self.current_step = -1
         self.last_non_approximated_step = -1

+    @torch.compiler.disable
     def should_compute_full(self, step=None):
         step = self.current_step if step is None else step
         if self.compute_step_map is not None:
@@ -36,7 +40,13 @@ class TaylorSeer:
                 return True
         return False

+    @torch.compiler.disable
     def approximate_derivative(self, Y):
+        # n-th order Taylor expansion:
+        # Y(t) = Y(0) + dY(0)/dt * t + d^2Y(0)/dt^2 * t^2 / 2!
+        # + ... + d^nY(0)/dt^n * t^n / n!
+        # TODO: Custom Triton/CUDA kernel for better performance,
+        # especially for large n_derivatives.
         dY_current = [None] * self.ORDER
         dY_current[0] = Y
         window = self.current_step - self.last_non_approximated_step
@@ -49,7 +59,10 @@ class TaylorSeer:
                 break
         return dY_current

+    @torch.compiler.disable
     def approximate_value(self):
+        # TODO: Custom Triton/CUDA kernel for better performance,
+        # especially for large n_derivatives.
         elapsed = self.current_step - self.last_non_approximated_step
         output = 0
         for i, derivative in enumerate(self.state["dY_current"]):
@@ -59,14 +72,30 @@ class TaylorSeer:
                 break
         return output

+    @torch.compiler.disable
     def mark_step_begin(self):
         self.current_step += 1

+    @torch.compiler.disable
     def update(self, Y):
+        # Directly call this method will ingnore the warmup
+        # policy and force full computation.
+        # Assume warmup steps is 3, and n_derivatives is 3.
+        # step 0: dY_prev    = [None, None,    None,    None   ]
+        #         dY_current = [Y0,   None,    None,    None   ]
+        # step 1: dY_prev    = [Y0,   None,    None,    None   ]
+        #         dY_current = [Y1,   dY1,     None,    None   ]
+        # step 2: dY_prev    = [Y1,   dY1,     None,    None   ]
+        #         dY_current = [Y2,   dY2/Y1,  dY2/dY1, None   ]
+        # step 3: dY_prev    = [Y2,   dY2/Y1,  dY2/dY1, None   ],
+        #         dY_current = [Y3,   dY3/Y2,  dY3/dY2, dY3/dY1]
+        # step 4: dY_prev    = [Y3,   dY3/Y2,  dY3/dY2, dY3/dY1]
+        #         dY_current = [Y4,   dY4/Y3,  dY4/dY3, dY4/dY2]
         self.state["dY_prev"] = self.state["dY_current"]
         self.state["dY_current"] = self.approximate_derivative(Y)
         self.last_non_approximated_step = self.current_step

+    @torch.compiler.disable
     def step(self, Y):
         self.mark_step_begin()
         if self.should_compute_full():

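The comments added above describe TaylorSeer's finite-difference and Taylor-expansion scheme, but the loop bodies themselves fall outside these hunks. Below is a small self-contained sketch of that scheme; the helper names are ours for illustration, not the package's API (see cache_dit/cache_factory/taylorseer.py for the real implementation):

```python
import math
import torch


def finite_difference_derivatives(y, dy_prev, window, order):
    # dY_current[i+1] ~= (dY_current[i] - dY_prev[i]) / window, as sketched in
    # the comments added to TaylorSeer.approximate_derivative above.
    dy_current = [None] * order
    dy_current[0] = y
    for i in range(order - 1):
        if dy_prev[i] is None or window <= 0:
            break
        dy_current[i + 1] = (dy_current[i] - dy_prev[i]) / window
    return dy_current


def taylor_forecast(dy_current, elapsed):
    # Y(t0 + elapsed) ~= sum_i d^iY(t0) * elapsed^i / i!
    output = 0
    for i, derivative in enumerate(dy_current):
        if derivative is None:
            break
        output = output + derivative * (elapsed ** i) / math.factorial(i)
    return output


# Toy run on a scalar "feature" growing linearly: 1.0 -> 2.0 -> forecast 3.0.
prev = finite_difference_derivatives(torch.tensor(1.0), [None] * 3, 1, 3)
curr = finite_difference_derivatives(torch.tensor(2.0), prev, 1, 3)
print(taylor_forecast(curr, 1))  # tensor(3.)
```
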
cache_dit/compile/__init__.py ADDED

@@ -0,0 +1 @@
+from cache_dit.compile.utils import set_custom_compile_configs

cache_dit/compile/utils.py ADDED

@@ -0,0 +1,94 @@
+import os
+
+import torch
+from cache_dit.logger import init_logger, logging_rank_0
+
+logger = init_logger(__name__)
+
+
+def epilogue_prologue_fusion_enabled(**kwargs) -> bool:
+    mode = kwargs.get("epilogue_prologue_fusion", False)
+    CACHE_DIT_EPILOGUE_PROLOGUE_FUSION = bool(
+        int(os.environ.get("CACHE_DIT_EPILOGUE_PROLOGUE_FUSION", "0"))
+    )
+
+    if CACHE_DIT_EPILOGUE_PROLOGUE_FUSION:
+        logging_rank_0(
+            logger,
+            "CACHE_DIT_EPILOGUE_PROLOGUE_FUSION is set to 1. \n"
+            "Force enable epilogue and prologue fusion.",
+        )
+
+    return CACHE_DIT_EPILOGUE_PROLOGUE_FUSION or mode
+
+
+def set_custom_compile_configs(
+    cuda_graphs: bool = False,
+    force_disable_compile_caches: bool = False,
+    use_fast_math: bool = False,
+    **kwargs,  # other kwargs
+):
+    # Alway increase recompile_limit for dynamic shape compilation
+    torch._dynamo.config.recompile_limit = 96  # default is 8
+    torch._dynamo.config.accumulated_recompile_limit = 2048  # default is 256
+    # Handle compiler caches
+    # https://github.com/vllm-project/vllm/blob/23baa2180b0ebba5ae94073ba9b8e93f88b75486/vllm/compilation/compiler_interface.py#L270
+    torch._inductor.config.fx_graph_cache = True
+    torch._inductor.config.fx_graph_remote_cache = False
+    # https://github.com/pytorch/pytorch/issues/153791
+    torch._inductor.config.autotune_local_cache = False
+
+    FORCE_DISABLE_CUSTOM_COMPILE_CONFIG = (
+        os.environ.get("CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG", "0")
+        == "1"
+    )
+    if FORCE_DISABLE_CUSTOM_COMPILE_CONFIG:
+        logging_rank_0(
+            logger,
+            "CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG is set to 1. \n"
+            "Force disable custom compile config.",
+        )
+        return
+
+    # Enable compute comm overlap
+    torch._inductor.config.reorder_for_compute_comm_overlap = True
+    # L20 64 GB/s, PCIe; A100/A800 NVLink 300 GB/s.
+    torch._inductor.config.intra_node_bw = (
+        64 if "L20" in torch.cuda.get_device_name() else 300
+    )
+
+    # Below are default settings for torch.compile, you can change
+    # them to your needs and test the performance
+    torch._inductor.config.max_fusion_size = 64
+    torch._inductor.config.max_pointwise_cat_inputs = 8
+    torch._inductor.config.triton.cudagraphs = cuda_graphs
+    torch._inductor.config.triton.use_block_ptr = False
+    torch._inductor.config.triton.codegen_upcast_to_fp32 = True
+
+    # Copy from https://pytorch.org/blog/accelerating-generative-ai-3/
+    torch._inductor.config.conv_1x1_as_mm = True
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.coordinate_descent_check_all_directions = True
+    torch._inductor.config.epilogue_fusion = False
+
+    # Enable epilogue and prologue fusion
+    if epilogue_prologue_fusion_enabled(**kwargs):
+        torch._inductor.config.epilogue_fusion = True
+        torch._inductor.config.prologue_fusion = True
+        torch._inductor.config.epilogue_fusion_first = True
+
+    # Dead code elimination
+    torch._inductor.config.dce = True  # default is False
+
+    # May need to force disable all cache
+    if force_disable_compile_caches:
+        torch._inductor.config.force_disable_caches = True
+        torch._inductor.config.fx_graph_cache = False
+        torch._inductor.config.fx_graph_remote_cache = False
+        torch._inductor.config.autotune_local_cache = False  # default is True
+
+    # Use fast math
+    if hasattr(torch._inductor.config, "use_fast_math"):
+        torch._inductor.config.use_fast_math = use_fast_math
+    if hasattr(torch._inductor.config, "cuda.use_fast_math"):
+        torch._inductor.config.cuda.use_fast_math = use_fast_math

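`set_custom_compile_configs` is re-exported from `cache_dit.compile` (see the `__init__.py` hunk above). A minimal usage sketch, assuming a Diffusers pipeline whose DiT lives at `pipe.transformer`:

```python
import torch
from cache_dit.compile import set_custom_compile_configs


def compile_transformer(pipe, cuda_graphs=False):
    # Raise the dynamo recompile limits and apply the inductor defaults shown
    # above before compiling. Pass epilogue_prologue_fusion=True (or set the
    # CACHE_DIT_EPILOGUE_PROLOGUE_FUSION env var) to opt into prologue and
    # epilogue fusion.
    set_custom_compile_configs(cuda_graphs=cuda_graphs)
    pipe.transformer = torch.compile(pipe.transformer, dynamic=True)
    return pipe
```
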
cache_dit/custom_ops/__init__.py
File without changes

cache_dit/custom_ops/triton_taylorseer.py
File without changes

cache_dit/logger.py CHANGED

@@ -1,6 +1,7 @@
 import logging
 import os
 import sys
+import torch.distributed as dist

 _FORMAT = "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s"
 _DATE_FORMAT = "%m-%d %H:%M:%S"
@@ -95,3 +96,30 @@ def init_logger(name: str):
     logger.addHandler(_inference_log_file_handler[pid])
     logger.propagate = False
     return logger
+
+
+def logging_rank_0(
+    logger: logging.Logger, message: str, level: int = logging.INFO
+):
+    if not isinstance(logger, logging.Logger):
+        raise TypeError("logger must be an instance of logging.Logger")
+    if not isinstance(message, str):
+        raise TypeError("message must be a string")
+    if not isinstance(level, int):
+        raise TypeError("level must be an integer representing a logging level")
+
+    def _logging_msg():
+        if level == logging.DEBUG:
+            logger.debug(message)
+        elif level == logging.WARNING:
+            logger.warning(message)
+        elif level == logging.ERROR:
+            logger.error(message)
+        else:
+            logger.info(message)
+
+    if dist.is_initialized():
+        if dist.get_rank() == 0:
+            _logging_msg()
+    else:
+        _logging_msg()

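A minimal usage sketch for the new helper; messages are emitted on every process until `torch.distributed` is initialized, and only on rank 0 afterwards:

```python
import logging

from cache_dit.logger import init_logger, logging_rank_0

logger = init_logger(__name__)

logging_rank_0(logger, "Applying cache options on the transformer...")
logging_rank_0(logger, "Residual diff threshold is quite high.", level=logging.WARNING)
```
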
{cache_dit-0.2.1.dist-info → cache_dit-0.2.3.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.1
+Version: 0.2.3
 Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -37,31 +37,31 @@ Dynamic: requires-python
 <p align="center">
 <h2>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h2>
 </p>
-<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit.png >
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit-v1.png >
 <div align='center'>
 <img src=https://img.shields.io/badge/Language-Python-brightgreen.svg >
 <img src=https://img.shields.io/badge/PRs-welcome-9cf.svg >
 <img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
 <img src=https://static.pepy.tech/badge/cache-dit >
 <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
-<img src=https://img.shields.io/badge/Release-v0.2.
+<img src=https://img.shields.io/badge/Release-v0.2.2-brightgreen.svg >
 </div>
 <p align="center">
-DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>
+DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT offers <br>a set of training-free cache accelerators for DiT: <b>🔥<a href="#dbcache">DBCache</a>, <a href="#dbprune">DBPrune</a>, <a href="#taylorseer">TaylorSeer</a>, <a href="#fbcache">FBCache</a></b>, etc🔥
 </p>
 </div>

-
-
-<
-
-
+<div align="center">
+<p align="center">
+<b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
+</p>
+</div>

 ## 🤗 Introduction

 <div align="center">
 <p align="center">
-<h3>🔥
+<h3>🔥DBCache: Dual Block Caching for Diffusion Transformers</h3>
 </p>
 </div>

@@ -77,9 +77,9 @@ The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi
 |:---:|:---:|:---:|:---:|:---:|:---:|
 |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
 |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
-|**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.
+|**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.12)**|**F8B16 (0.20)**|**F8B20 (0.20)**|
 |27.85s|6.04s|5.88s|5.77s|6.01s|6.20s|
-|<img src=https://github.com/
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_NONE_R0.08.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F1B0_R0.08.png width=105px> |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B8_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B12_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B16_R0.2.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B20_R0.2.png width=105px>|

 <div align="center">
 <p align="center">
@@ -91,7 +91,7 @@ These case studies demonstrate that even with relatively high thresholds (such a

 <div align="center">
 <p align="center">
-<h3>🔥
+<h3>🔥DBPrune: Dynamic Block Prune with Residual Caching</h3>
 </p>
 </div>

@@ -110,11 +110,11 @@ These case studies demonstrate that even with relatively high thresholds (such a

 <div align="center">
 <p align="center">
-<h3>🔥
+<h3>🔥Context Parallelism and Torch Compile</h3>
 </p>
 </div>

-Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference.
+Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. CacheDiT is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance.

 <div align="center">
 <p align="center">
@@ -128,12 +128,6 @@ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand w
 |+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
 |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|

-<div align="center">
-<p align="center">
-<b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
-</p>
-</div>
-
 ## ©️Citations

 ```BibTeX
@@ -146,6 +140,12 @@ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand w
 }
 ```

+## 👋Reference
+
+<div id="reference"></div>
+
+The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work!
+
 ## 📖Contents

 <div id="contents"></div>
@@ -153,6 +153,7 @@ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand w
 - [⚙️Installation](#️installation)
 - [🔥Supported Models](#supported)
 - [⚡️Dual Block Cache](#dbcache)
+- [🔥Hybrid TaylorSeer](#taylorseer)
 - [🎉First Block Cache](#fbcache)
 - [⚡️Dynamic Block Prune](#dbprune)
 - [🎉Context Parallelism](#context-parallelism)
@@ -187,28 +188,19 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
 - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)


-<!--
-<p align="center">
-<h4> 🔥Supported Models🔥</h4>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀FLUX.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Mochi</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX1.5</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Wan2.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀HunyuanVideo</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-</p>
--->
-
 ## ⚡️DBCache: Dual Block Cache

 <div id="dbcache"></div>

-
+

 **DBCache** provides configurable parameters for custom optimization, enabling a balanced trade-off between performance and precision:

 - **Fn**: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
 - **Bn**: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
+
+
+
 - **warmup_steps**: (default: 0) DBCache does not apply the caching strategy when the number of running steps is less than or equal to this value, ensuring the model sufficiently learns basic features during warmup.
 - **max_cached_steps**: (default: -1) DBCache disables the caching strategy when the previous cached steps exceed this value to prevent precision degradation.
 - **residual_diff_threshold**: The value of residual diff threshold, a higher value leads to faster performance at the cost of lower precision.
@@ -264,11 +256,54 @@ cache_options = {
 |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
 |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|

+## 🔥Hybrid TaylorSeer
+
+<div id="taylorseer"></div>
+
+We have supported the [TaylorSeers: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers](https://arxiv.org/pdf/2503.06923) algorithm to further improve the precision of DBCache in cases where the cached steps are large, namely, **Hybrid TaylorSeer + DBCache**. At timesteps with significant intervals, the feature similarity in diffusion models decreases substantially, significantly harming the generation quality.
+
+$$
+\mathcal{F}\_{\text {pred }, m}\left(x_{t-k}^l\right)=\mathcal{F}\left(x_t^l\right)+\sum_{i=1}^m \frac{\Delta^i \mathcal{F}\left(x_t^l\right)}{i!\cdot N^i}(-k)^i
+$$
+
+**TaylorSeer** employs a differential method to approximate the higher-order derivatives of features and predict features in future timesteps with Taylor series expansion. The TaylorSeer implemented in CacheDiT supports both hidden states and residual cache types. That is $\mathcal{F}\_{\text {pred }, m}\left(x_{t-k}^l\right)$ can be a residual cache or a hidden-state cache.
+
+```python
+cache_options = {
+    # TaylorSeer options
+    "enable_taylorseer": True,
+    "enable_encoder_taylorseer": True,
+    # Taylorseer cache type cache be hidden_states or residual.
+    "taylorseer_cache_type": "residual",
+    # Higher values of n_derivatives will lead to longer
+    # computation time but may improve precision significantly.
+    "taylorseer_kwargs": {
+        "n_derivatives": 2, # default is 2.
+    },
+    "warmup_steps": 3, # n_derivatives + 1
+    "residual_diff_threshold": 0.12,
+}
+```
+
+> [!Important]
+> Please note that if you have used TaylorSeer as the calibrator for approximate hidden states, the **Bn** param of DBCache can be set to **0**. In essence, DBCache's Bn is also act as a calibrator, so you can choose either Bn > 0 or TaylorSeer. We recommend using the configuration scheme of **TaylorSeer** + **DBCache FnB0**.
+
+<div align="center">
+<p align="center">
+<b>DBCache F1B0 + TaylorSeer</b>, L20x1, Steps: 28, <br>"A cat holding a sign that says hello world with complex background"
+</p>
+</div>
+
+|Baseline(L20x1)|F1B0 (0.12)|+TaylorSeer|F1B0 (0.15)|+TaylorSeer|+compile|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|24.85s|12.85s|12.86s|10.27s|10.28s|8.48s|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.12_S14_T12.85s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.12_S14_T12.86s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.15_S17_T10.27s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T10.28s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T8.48s.png width=105px>|
+
 ## 🎉FBCache: First Block Cache

 <div id="fbcache"></div>

-
+

 **DBCache** is a more general cache algorithm than **FBCache**. When Fn=1 and Bn=0, DBCache behaves identically to FBCache. Therefore, you can either use the original FBCache implementation directly or configure **DBCache** with **F1B0** settings to achieve the same functionality.

@@ -302,7 +337,7 @@ apply_cache_on_pipe(pipe, **cache_options)

 <div id="dbprune"></div>

-
+

 We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, which is referred to as **DBPrune**. DBPrune caches each block's hidden states and residuals, then dynamically prunes blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals. DBPrune is currently in the experimental phase, and we kindly invite you to stay tuned for upcoming updates.

@@ -389,7 +424,7 @@ from para_attn.context_parallel import init_context_parallel_mesh
 from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
 from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

-
+# Init distributed process group
 dist.init_process_group()
 torch.cuda.set_device(dist.get_rank())

@@ -436,14 +471,16 @@ torch._dynamo.config.recompile_limit = 96 # default is 8
 torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
 ```

+Please check [bench.py](./bench/bench.py) for more details.
+
 ## 👋Contribute
 <div id="contribute"></div>

-How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](
+How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](https://github.com/vipshop/cache-dit/raw/main/CONTRIBUTE.md).

 ## ©️License

 <div id="license"></div>


-We have followed the original License from [ParaAttention](https://github.com/chengzeyi/ParaAttention), please check [LICENSE](
+We have followed the original License from [ParaAttention](https://github.com/chengzeyi/ParaAttention), please check [LICENSE](https://github.com/vipshop/cache-dit/raw/main/LICENSE) for more details.