cache-dit 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two published versions.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/__init__.py +1 -3
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -1
- cache_dit/cache_factory/cache_adapters.py +298 -123
- cache_dit/cache_factory/cache_blocks.py +9 -3
- cache_dit/cache_factory/cache_context.py +85 -15
- cache_dit/cache_factory/cache_interface.py +18 -11
- cache_dit/cache_factory/taylorseer.py +5 -4
- cache_dit/cache_factory/utils.py +1 -1
- cache_dit/utils.py +25 -22
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/METADATA +19 -10
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/RECORD +16 -17
- cache_dit/primitives.py +0 -152
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/cache_context.py
CHANGED
@@ -5,8 +5,8 @@ from collections import defaultdict
 from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple
 
 import torch
+import torch.distributed as dist
 
-import cache_dit.primitives as primitives
 from cache_dit.cache_factory.taylorseer import TaylorSeer
 from cache_dit.logger import init_logger
 
@@ -47,10 +47,11 @@ class DBCacheContext:
 
     # Other settings
     downsample_factor: int = 1
-    num_inference_steps: int = -1  #
-    warmup_steps: int = 0  # DON'T Cache in warmup steps
+    num_inference_steps: int = -1  # for future use
+    max_warmup_steps: int = 0  # DON'T Cache in warmup steps
     # DON'T Cache if the number of cached steps >= max_cached_steps
     max_cached_steps: int = -1  # for both CFG and non-CFG
+    max_continuous_cached_steps: int = -1  # the max continuous cached steps
 
     # Record the steps that have been cached, both cached and non-cache
     executed_steps: int = 0  # cache + non-cache steps pippeline
@@ -89,10 +90,12 @@ class DBCacheContext:
     residual_diffs: DefaultDict[str, float] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
+    continuous_cached_steps: int = 0
     cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
     cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
+    cfg_continuous_cached_steps: int = 0
 
     @torch.compiler.disable
     def __post_init__(self):
@@ -108,17 +111,17 @@ class DBCacheContext:
                 "cfg_diff_compute_separate is enabled."
             )
 
-        if "warmup_steps" not in self.taylorseer_kwargs:
-            # If warmup_steps is not set in taylorseer_kwargs,
-            # set the same as warmup_steps for DBCache
-            self.taylorseer_kwargs["warmup_steps"] = (
-                self.warmup_steps if self.warmup_steps > 0 else 1
+        if "max_warmup_steps" not in self.taylorseer_kwargs:
+            # If max_warmup_steps is not set in taylorseer_kwargs,
+            # set the same as max_warmup_steps for DBCache
+            self.taylorseer_kwargs["max_warmup_steps"] = (
+                self.max_warmup_steps if self.max_warmup_steps > 0 else 1
             )
 
         # Only set n_derivatives as 2 or 3, which is enough for most cases.
        if "n_derivatives" not in self.taylorseer_kwargs:
            self.taylorseer_kwargs["n_derivatives"] = max(
-                2, min(3, self.taylorseer_kwargs["warmup_steps"])
+                2, min(3, self.taylorseer_kwargs["max_warmup_steps"])
            )
 
        if self.enable_taylorseer:
@@ -268,10 +271,31 @@ class DBCacheContext:
 
     @torch.compiler.disable
     def add_cached_step(self):
+        curr_cached_step = self.get_current_step()
         if not self.is_separate_cfg_step():
-            self.cached_steps.append(self.get_current_step())
+            if self.cached_steps:
+                prev_cached_step = self.cached_steps[-1]
+                if curr_cached_step - prev_cached_step == 1:
+                    if self.continuous_cached_steps == 0:
+                        self.continuous_cached_steps += 2
+                    else:
+                        self.continuous_cached_steps += 1
+            else:
+                self.continuous_cached_steps += 1
+
+            self.cached_steps.append(curr_cached_step)
         else:
-            self.cfg_cached_steps.append(self.get_current_step())
+            if self.cfg_cached_steps:
+                prev_cfg_cached_step = self.cfg_cached_steps[-1]
+                if curr_cached_step - prev_cfg_cached_step == 1:
+                    if self.cfg_continuous_cached_steps == 0:
+                        self.cfg_continuous_cached_steps += 2
+                    else:
+                        self.cfg_continuous_cached_steps += 1
+            else:
+                self.cfg_continuous_cached_steps += 1
+
+            self.cfg_cached_steps.append(curr_cached_step)
 
     @torch.compiler.disable
     def get_cached_steps(self):
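The run-length bookkeeping above is easy to misread, so here is a minimal standalone sketch (a hypothetical `replay` helper, not package code) that applies the same rule as the non-CFG branch: the first cached step counts 1, each consecutive follow-up adds 1 (the `+= 2` branch covers the case where the counter was just reset mid-run), and a gap leaves the counter unchanged until `get_can_use_cache` (further below) resets it.

```python
# Standalone illustration of the run-length rule from the hunk above.
def replay(cached_step_sequence):
    cached_steps, continuous = [], 0
    for curr in cached_step_sequence:
        if cached_steps:
            if curr - cached_steps[-1] == 1:
                # first consecutive pair after a reset counts both steps
                continuous += 2 if continuous == 0 else 1
        else:
            continuous += 1  # very first cached step
        cached_steps.append(curr)
        print(f"step {curr}: continuous_cached_steps = {continuous}")

replay([3, 4, 5, 9, 10])  # -> 1, 2, 3, 3 (gap), 4
```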
@@ -301,7 +325,7 @@ class DBCacheContext:
 
     @torch.compiler.disable
     def is_in_warmup(self):
-        return self.get_current_step() < self.warmup_steps
+        return self.get_current_step() < self.max_warmup_steps
 
 
     @torch.compiler.disable
@@ -396,6 +420,27 @@ def get_max_cached_steps():
     return cache_context.max_cached_steps
 
 
+@torch.compiler.disable
+def get_max_continuous_cached_steps():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.max_continuous_cached_steps
+
+
+@torch.compiler.disable
+def get_continuous_cached_steps():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.continuous_cached_steps
+
+
+@torch.compiler.disable
+def get_cfg_continuous_cached_steps():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.cfg_continuous_cached_steps
+
+
 @torch.compiler.disable
 def add_cached_step():
     cache_context = get_current_cache_context()
@@ -744,8 +789,8 @@ def are_two_tensors_similar(
     mean_t1 = t1.abs().mean()
 
     if parallelized:
-
-
+        dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)
+        dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG)
 
     # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
     # Futher, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,
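This hunk replaces the removed `cache_dit.primitives` helpers with plain `torch.distributed` collectives, so every rank averages the same statistics and reaches the same cache decision. A hedged sketch of the surrounding computation, assuming `are_two_tensors_similar` compares a relative L1 difference as the comments in the hunk suggest:

```python
import torch
import torch.distributed as dist

def relative_diff(t1: torch.Tensor, t2: torch.Tensor, parallelized: bool) -> float:
    # D = mean|t1 - t2| / mean|t1|; D == 0 means t1 == t2.
    mean_diff = (t1 - t2).abs().mean()
    mean_t1 = t1.abs().mean()
    if parallelized:
        # Requires an initialized process group; AVG keeps the decision
        # identical across ranks.
        dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)
        dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG)
    return (mean_diff / mean_t1).item()
```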
@@ -1020,6 +1065,7 @@ def get_can_use_cache(
     if is_in_warmup():
         return False
 
+    # max cached steps
     max_cached_steps = get_max_cached_steps()
     if not is_separate_cfg_step():
         cached_steps = get_cached_steps()
@@ -1030,8 +1076,32 @@ def get_can_use_cache(
         if logger.isEnabledFor(logging.DEBUG):
             logger.debug(
                 f"{prefix}, max_cached_steps reached: {max_cached_steps}, "
-                "
+                "can not use cache."
+            )
+        return False
+
+    # max continuous cached steps
+    max_continuous_cached_steps = get_max_continuous_cached_steps()
+    if not is_separate_cfg_step():
+        continuous_cached_steps = get_continuous_cached_steps()
+    else:
+        continuous_cached_steps = get_cfg_continuous_cached_steps()
+
+    if max_continuous_cached_steps >= 0 and (
+        continuous_cached_steps >= max_continuous_cached_steps
+    ):
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.debug(
+                f"{prefix}, max_continuous_cached_steps "
+                f"reached: {max_continuous_cached_steps}, "
+                "can not use cache."
             )
+        # reset continuous cached steps stats
+        cache_context = get_current_cache_context()
+        if not is_separate_cfg_step():
+            cache_context.continuous_cached_steps = 0
+        else:
+            cache_context.cfg_continuous_cached_steps = 0
         return False
 
     if threshold is None or threshold <= 0.0:
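Unlike `max_cached_steps`, tripping this limit also resets the continuous counter, so caching can resume after the forced recompute. A small hypothetical illustration of the gate itself:

```python
# Hypothetical standalone helper (not package code) showing the gate's logic.
def allow_cache(continuous_cached_steps: int, max_continuous_cached_steps: int) -> bool:
    if max_continuous_cached_steps >= 0 and (
        continuous_cached_steps >= max_continuous_cached_steps
    ):
        return False  # in the real code the counter is also reset to 0 here
    return True

assert allow_cache(3, 4) is True
assert allow_cache(4, 4) is False   # run cut off; next step recomputes
assert allow_cache(99, -1) is True  # -1 means no limit
```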
cache_dit/cache_factory/cache_interface.py
CHANGED
@@ -1,7 +1,7 @@
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.cache_types import CacheType
-from cache_dit.cache_factory.cache_adapters import
+from cache_dit.cache_factory.cache_adapters import BlockAdapter
 from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
 
 from cache_dit.logger import init_logger
@@ -11,13 +11,14 @@ logger = init_logger(__name__)
 
 def enable_cache(
     # BlockAdapter & forward pattern
-    pipe_or_adapter: DiffusionPipeline |
+    pipe_or_adapter: DiffusionPipeline | BlockAdapter,
     forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     # Cache context kwargs
     Fn_compute_blocks: int = 8,
     Bn_compute_blocks: int = 0,
-    warmup_steps: int = 8,
+    max_warmup_steps: int = 8,
     max_cached_steps: int = -1,
+    max_continuous_cached_steps: int = -1,
     residual_diff_threshold: float = 0.08,
     # Cache CFG or not
     do_separate_cfg: bool = False,
@@ -38,7 +39,7 @@ def enable_cache(
     with F8B0, 8 warmup steps, and unlimited cached steps.
 
     Args:
-        pipe_or_adapter (`DiffusionPipeline` or `
+        pipe_or_adapter (`DiffusionPipeline` or `BlockAdapter`, *required*):
             The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
             For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
             for the usgae of BlockAdapter.
@@ -54,12 +55,15 @@ def enable_cache(
             Further fuses approximate information in the **last n** Transformer blocks to enhance
             prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
             that use residual cache.
-        warmup_steps (`int`, *required*, defaults to 8):
+        max_warmup_steps (`int`, *required*, defaults to 8):
             DBCache does not apply the caching strategy when the number of running steps is less than
             or equal to this value, ensuring the model sufficiently learns basic features during warmup.
         max_cached_steps (`int`, *required*, defaults to -1):
             DBCache disables the caching strategy when the previous cached steps exceed this value to
             prevent precision degradation.
+        max_continuous_cached_steps (`int`, *required*, defaults to -1):
+            DBCache disables the caching strategy when the previous continous cached steps exceed this value to
+            prevent precision degradation.
         residual_diff_threshold (`float`, *required*, defaults to 0.08):
             he value of residual diff threshold, a higher value leads to faster performance at the
             cost of lower precision.
@@ -106,8 +110,11 @@ def enable_cache(
     cache_context_kwargs["cache_type"] = CacheType.DBCache
     cache_context_kwargs["Fn_compute_blocks"] = Fn_compute_blocks
     cache_context_kwargs["Bn_compute_blocks"] = Bn_compute_blocks
-    cache_context_kwargs["warmup_steps"] = warmup_steps
+    cache_context_kwargs["max_warmup_steps"] = max_warmup_steps
     cache_context_kwargs["max_cached_steps"] = max_cached_steps
+    cache_context_kwargs["max_continuous_cached_steps"] = (
+        max_continuous_cached_steps
+    )
     cache_context_kwargs["residual_diff_threshold"] = residual_diff_threshold
     cache_context_kwargs["do_separate_cfg"] = do_separate_cfg
     cache_context_kwargs["cfg_compute_first"] = cfg_compute_first
@@ -128,22 +135,22 @@ def enable_cache(
             "n_derivatives": taylorseer_order
         }
 
-    if isinstance(pipe_or_adapter,
+    if isinstance(pipe_or_adapter, BlockAdapter):
         return UnifiedCacheAdapter.apply(
             pipe=None,
-
+            block_adapter=pipe_or_adapter,
             forward_pattern=forward_pattern,
             **cache_context_kwargs,
         )
     elif isinstance(pipe_or_adapter, DiffusionPipeline):
         return UnifiedCacheAdapter.apply(
             pipe=pipe_or_adapter,
-
+            block_adapter=None,
             forward_pattern=forward_pattern,
             **cache_context_kwargs,
         )
     else:
         raise ValueError(
-            "Please pass DiffusionPipeline or
-            "
+            "Please pass DiffusionPipeline or BlockAdapter"
+            "for the 1's position param: pipe_or_adapter"
         )
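Taken together with the `enable_cache` changes above, a 0.2.24-style call might look like the following sketch (the FluxPipeline model id is illustrative only):

```python
import torch
import cache_dit
from diffusers import FluxPipeline

# Example pipeline; any supported Diffusers pipeline works the same way.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# New in 0.2.24: `warmup_steps` is renamed to `max_warmup_steps`, and
# `max_continuous_cached_steps` caps how many *consecutive* steps may be
# served from cache (-1 keeps the old unlimited behavior).
cache_dit.enable_cache(
    pipe,
    max_warmup_steps=8,
    max_cached_steps=-1,
    max_continuous_cached_steps=4,  # force a real compute after 4 cache hits
    residual_diff_threshold=0.08,
)
```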
cache_dit/cache_factory/taylorseer.py
CHANGED
@@ -6,13 +6,13 @@ class TaylorSeer:
     def __init__(
         self,
         n_derivatives=2,
-        warmup_steps=1,
+        max_warmup_steps=1,
         skip_interval_steps=1,
         compute_step_map=None,
     ):
         self.n_derivatives = n_derivatives
         self.ORDER = n_derivatives + 1
-        self.warmup_steps = warmup_steps
+        self.max_warmup_steps = max_warmup_steps
         self.skip_interval_steps = skip_interval_steps
         self.compute_step_map = compute_step_map
         self.reset_cache()
@@ -32,8 +32,9 @@ class TaylorSeer:
         if self.compute_step_map is not None:
             return self.compute_step_map[step]
         if (
-            step < self.warmup_steps
-            or (step - self.warmup_steps + 1) % self.skip_interval_steps == 0
+            step < self.max_warmup_steps
+            or (step - self.max_warmup_steps + 1) % self.skip_interval_steps
+            == 0
         ):
             return True
         return False
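With the rename applied, TaylorSeer's schedule reads: compute fully for the first `max_warmup_steps` steps, then every `skip_interval_steps`-th step afterwards. A standalone sketch of the reconstructed condition (not an import from the package):

```python
def should_compute_full(step, max_warmup_steps=1, skip_interval_steps=1):
    # Mirrors the condition in the hunk above.
    return (
        step < max_warmup_steps
        or (step - max_warmup_steps + 1) % skip_interval_steps == 0
    )

# With max_warmup_steps=3 and skip_interval_steps=2, full computes happen at
# steps 0, 1, 2 (warmup), then 4, 6, 8; the rest reuse Taylor estimates.
print([s for s in range(10) if should_compute_full(s, 3, 2)])
# -> [0, 1, 2, 4, 6, 8]
```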
cache_dit/cache_factory/utils.py
CHANGED
cache_dit/utils.py
CHANGED
@@ -27,22 +27,26 @@ class CacheStats:
 
 
 def summary(
-
-
+    pipe_or_transformer: DiffusionPipeline | torch.nn.Module,
+    details: bool = False,
+    logging: bool = True,
+) -> CacheStats:
     cache_stats = CacheStats()
-    pipe_cls_name = pipe.__class__.__name__
+    cls_name = pipe_or_transformer.__class__.__name__
+    if isinstance(pipe_or_transformer, DiffusionPipeline):
+        transformer = pipe_or_transformer.transformer
+    else:
+        transformer = pipe_or_transformer
 
-    if hasattr(pipe.transformer, "_cache_context_kwargs"):
-        cache_options = pipe.transformer._cache_context_kwargs
+    if hasattr(transformer, "_cache_context_kwargs"):
+        cache_options = transformer._cache_context_kwargs
         cache_stats.cache_options = cache_options
         if logging:
-            print(f"\n🤗Cache Options: {pipe_cls_name}\n\n{cache_options}")
+            print(f"\n🤗Cache Options: {cls_name}\n\n{cache_options}")
 
-    if hasattr(pipe.transformer, "_cached_steps"):
-        cached_steps: list[int] = pipe.transformer._cached_steps
-        residual_diffs: dict[str, float] = dict(
-            pipe.transformer._residual_diffs
-        )
+    if hasattr(transformer, "_cached_steps"):
+        cached_steps: list[int] = transformer._cached_steps
+        residual_diffs: dict[str, float] = dict(transformer._residual_diffs)
         cache_stats.cached_steps = cached_steps
         cache_stats.residual_diffs = residual_diffs
 
@@ -57,7 +61,7 @@ def summary(
             qmax = np.max(diffs_values)
 
             print(
-                f"\n⚡️Cache Steps and Residual Diffs Statistics: {pipe_cls_name}\n"
+                f"\n⚡️Cache Steps and Residual Diffs Statistics: {cls_name}\n"
             )
 
             print(
@@ -74,9 +78,7 @@ def summary(
             print("")
 
         if details:
-            print(
-                f"📚Cache Steps and Residual Diffs Details: {pipe_cls_name}\n"
-            )
+            print(f"📚Cache Steps and Residual Diffs Details: {cls_name}\n")
             pprint(
                 f"Cache Steps: {len(cached_steps)}, {cached_steps}",
             )
@@ -85,10 +87,10 @@ def summary(
                 compact=True,
             )
 
-    if hasattr(pipe.transformer, "_cfg_cached_steps"):
-        cfg_cached_steps: list[int] = pipe.transformer._cfg_cached_steps
+    if hasattr(transformer, "_cfg_cached_steps"):
+        cfg_cached_steps: list[int] = transformer._cfg_cached_steps
         cfg_residual_diffs: dict[str, float] = dict(
-            pipe.transformer._cfg_residual_diffs
+            transformer._cfg_residual_diffs
         )
         cache_stats.cfg_cached_steps = cfg_cached_steps
         cache_stats.cfg_residual_diffs = cfg_residual_diffs
@@ -104,7 +106,7 @@ def summary(
             qmax = np.max(cfg_diffs_values)
 
             print(
-                f"\n⚡️CFG Cache Steps and Residual Diffs Statistics: {pipe_cls_name}\n"
+                f"\n⚡️CFG Cache Steps and Residual Diffs Statistics: {cls_name}\n"
             )
 
             print(
@@ -122,7 +124,7 @@ def summary(
 
         if details:
             print(
-                f"📚CFG Cache Steps and Residual Diffs Details: {pipe_cls_name}\n"
+                f"📚CFG Cache Steps and Residual Diffs Details: {cls_name}\n"
             )
             pprint(
                 f"CFG Cache Steps: {len(cfg_cached_steps)}, {cfg_cached_steps}",
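A hedged usage sketch for the widened `summary` signature, assuming `pipe` is a cache-enabled pipeline from the earlier examples; the function now also accepts a bare transformer module:

```python
from cache_dit.utils import summary

# Pipeline input, as before:
stats = summary(pipe, details=True)
# New in 0.2.24: a bare transformer module works too.
stats = summary(pipe.transformer, logging=False)
print(stats.cached_steps)
print(stats.cfg_cached_steps)
```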
@@ -149,9 +151,10 @@ def strify(pipe_or_stats: DiffusionPipeline | CacheStats):
 
     cache_type_str = (
         f"DBCACHE_F{cache_options['Fn_compute_blocks']}"
-        f"B{cache_options['Bn_compute_blocks']}"
-        f"W{cache_options['warmup_steps']}"
+        f"B{cache_options['Bn_compute_blocks']}_"
+        f"W{cache_options['max_warmup_steps']}"
         f"M{max(0, cache_options['max_cached_steps'])}"
+        f"MC{max(0, cache_options['max_continuous_cached_steps'])}_"
         f"T{int(cache_options['enable_taylorseer'])}"
         f"O{cache_options['taylorseer_kwargs']['n_derivatives']}_"
         f"R{cache_options['residual_diff_threshold']}_"
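For reference, composing the f-strings above with illustrative options yields a tag like the one below (standalone sketch, not package code):

```python
# Illustrative option values; the f-strings match the hunk above exactly.
cache_options = {
    "Fn_compute_blocks": 8,
    "Bn_compute_blocks": 0,
    "max_warmup_steps": 8,
    "max_cached_steps": -1,
    "max_continuous_cached_steps": -1,
    "enable_taylorseer": False,
    "taylorseer_kwargs": {"n_derivatives": 2},
    "residual_diff_threshold": 0.08,
}
cache_type_str = (
    f"DBCACHE_F{cache_options['Fn_compute_blocks']}"
    f"B{cache_options['Bn_compute_blocks']}_"
    f"W{cache_options['max_warmup_steps']}"
    f"M{max(0, cache_options['max_cached_steps'])}"
    f"MC{max(0, cache_options['max_continuous_cached_steps'])}_"
    f"T{int(cache_options['enable_taylorseer'])}"
    f"O{cache_options['taylorseer_kwargs']['n_derivatives']}_"
    f"R{cache_options['residual_diff_threshold']}_"
)
print(cache_type_str)  # DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_
```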
{cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.22
+Version: 0.2.24
 Summary: 🤗 CacheDiT: An Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -59,12 +59,13 @@ Dynamic: requires-python
 </p>
 <p align="center">
 🎉Now, <b>cache-dit</b> covers <b>Most</b> mainstream <b>Diffusers'</b> Pipelines</b>🎉<br>
-🔥<b><a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Wan 2.1</a> | <a href="#supported"> ... </a> | <a href="#supported">CogVideoX</a></b>🔥
+🔥<b><a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Wan 2.1/2.2</a> | <a href="#supported"> ... </a> | <a href="#supported">CogVideoX</a></b>🔥
 </p>
 </div>
 
 ## 🔥News
 
+- [2025-08-26] 🎉[**Wan2.2**](https://github.com/Wan-Video) **1.5x⚡️** speedup! Please check [run_wan_2.2.py](./examples/run_wan_2.2.py) as an example.
 - [2025-08-19] 🔥[**Qwen-Image-Edit**](https://github.com/QwenLM/Qwen-Image) **2x⚡️** speedup! Check example [run_qwen_image_edit.py](./examples/run_qwen_image_edit.py).
 - [2025-08-12] 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
 - [2025-08-11] 🔥[**Qwen-Image**](https://github.com/QwenLM/Qwen-Image) **1.8x⚡️** speedup! Please refer [run_qwen_image.py](./examples/run_qwen_image.py) as an example.
@@ -119,6 +120,7 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
 - [🚀FLUX.1-Kontext-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀CogVideoX1.5](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀Wan2.2-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.1-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.1-FLF2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
@@ -166,23 +168,30 @@ cache_dit.enable_cache(pipe)
 output = pipe(...)
 ```
 
-### 🔥
+### 🔥Automatic Block Adapter
 
 But in some cases, you may have a **modified** Diffusion Pipeline or Transformer that is not located in the diffusers library or not officially supported by **cache-dit** at this time. The **BlockAdapter** can help you solve this problems. Please refer to [Qwen-Image w/ BlockAdapter](./examples/run_qwen_image_adapter.py) as an example.
 
 ```python
 from cache_dit import ForwardPattern, BlockAdapter
 
-#
+# Use BlockAdapter with `auto` mode.
+cache_dit.enable_cache(
+    BlockAdapter(pipe=pipe, auto=True),  # Qwen-Image, etc.
+    # Check `📚Forward Pattern Matching` documentation and hack the code of
+    # of Qwen-Image, you will find that it has satisfied `FORWARD_PATTERN_1`.
+    forward_pattern=ForwardPattern.Pattern_1,
+)
+
+# Or, manually setup transformer configurations.
 cache_dit.enable_cache(
     BlockAdapter(
         pipe=pipe,  # Qwen-Image, etc.
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
-
-
-
-        forward_pattern=ForwardPattern.Pattern_1,
+        blocks_name="transformer_blocks",
+    ),
+    forward_pattern=ForwardPattern.Pattern_1,
 )
 ```
 For such situations, **BlockAdapter** can help you quickly apply various cache acceleration features to your own Diffusion Pipelines and Transformers. Please check the [📚BlockAdapter.md](./docs/BlockAdapter.md) for more details.
@@ -231,7 +240,7 @@ cache_dit.enable_cache(pipe)
 # Custom options, F8B8, higher precision
 cache_dit.enable_cache(
     pipe,
-    warmup_steps=8,  # steps do not cache
+    max_warmup_steps=8,  # steps do not cache
     max_cached_steps=-1,  # -1 means no limit
     Fn_compute_blocks=8,  # Fn, F8, etc.
     Bn_compute_blocks=8,  # Bn, B8, etc.
@@ -290,7 +299,7 @@ cache_dit.enable_cache(
     taylorseer_kwargs={
         "n_derivatives": 2,  # default is 2.
     },
-    warmup_steps=3,  # prefer: >= n_derivatives + 1
+    max_warmup_steps=3,  # prefer: >= n_derivatives + 1
     residual_diff_threshold=0.12
 )
 ```
{cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/RECORD
CHANGED
@@ -1,18 +1,17 @@
-cache_dit/__init__.py,sha256=
-cache_dit/_version.py,sha256=
+cache_dit/__init__.py,sha256=KwhX9NfYkWSvDFuuUVeVjcuiZiGS_22y386l8j4afMo,905
+cache_dit/_version.py,sha256=AZPr2DJJAwMsYN7GLT_kjMvP33B8Rgy4O_7h4o_T_88,706
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
-cache_dit/primitives.py,sha256=
-cache_dit/utils.py,sha256=3UgVhfmTFG28w6CV-Rfxp5u1uzLrRozocHwLCTGiQ5M,5865
+cache_dit/utils.py,sha256=kzwF98nzfzIFHSLtCx7Vq4a9aTW42lY-Bth7Oi4jAhg,6083
 cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
-cache_dit/cache_factory/__init__.py,sha256=
-cache_dit/cache_factory/cache_adapters.py,sha256=
-cache_dit/cache_factory/cache_blocks.py,sha256=
-cache_dit/cache_factory/cache_context.py,sha256=
-cache_dit/cache_factory/cache_interface.py,sha256=
+cache_dit/cache_factory/__init__.py,sha256=evWenCin1kuBGa6W5BCKMrDZc1C1R2uVPSg0BjXgdXE,499
+cache_dit/cache_factory/cache_adapters.py,sha256=Yugqljm9tm615srM2BGQlR_tA0QiZo3PbLPceObh4dQ,25988
+cache_dit/cache_factory/cache_blocks.py,sha256=ZeazBsYvLIjI5M_OnLL2xP2W7zMeM0rxVfBBwIVHBRs,18661
+cache_dit/cache_factory/cache_context.py,sha256=Cexr1_uwEkX7v8gB7DSyhCX0SI2dqS_e_ccTR16G2es,41738
+cache_dit/cache_factory/cache_interface.py,sha256=ri8wAxmHOsDW8c6qYP6VquOJQaTSXuOchWXG3PdcYQM,8434
 cache_dit/cache_factory/cache_types.py,sha256=FIFa6ZBfvvSMMHyBBhvarvgg2Y2wbRgITcG_uGylGe0,991
 cache_dit/cache_factory/forward_pattern.py,sha256=B2YeqV2t_zo2Ar8m7qimPBjwQgoXHGp2grPZmEAhi8s,1286
-cache_dit/cache_factory/taylorseer.py,sha256=
-cache_dit/cache_factory/utils.py,sha256=
+cache_dit/cache_factory/taylorseer.py,sha256=etSUIZzDvqW3ScKCbccTPcFaSmxV1T-xAXdk-p3e3wk,3802
+cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
 cache_dit/cache_factory/patch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/cache_factory/patch/flux.py,sha256=iNQ-1RlOgXupZ4uPiEvJ__Ro6vKT_fOKja9JrpMrO78,8998
 cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0,56
@@ -25,9 +24,9 @@ cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,1706
 cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
 cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
 cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
-cache_dit-0.2.22.dist-info/licenses/LICENSE,sha256=
-cache_dit-0.2.22.dist-info/METADATA,sha256=
-cache_dit-0.2.22.dist-info/WHEEL,sha256=
-cache_dit-0.2.22.dist-info/entry_points.txt,sha256=
-cache_dit-0.2.22.dist-info/top_level.txt,sha256=
-cache_dit-0.2.22.dist-info/RECORD,,
+cache_dit-0.2.24.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.24.dist-info/METADATA,sha256=zq_bGjQ_X--m1njAbOob--MwOpTDlUlAzZ3u_MiNiFM,19977
+cache_dit-0.2.24.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.24.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-0.2.24.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.24.dist-info/RECORD,,