cache-dit 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/block_adapters/__init__.py +166 -160
- cache_dit/cache_factory/block_adapters/block_adapters.py +195 -125
- cache_dit/cache_factory/block_adapters/block_registers.py +25 -13
- cache_dit/cache_factory/cache_adapters.py +209 -86
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +70 -67
- cache_dit/cache_factory/cache_blocks/utils.py +16 -0
- cache_dit/cache_factory/cache_contexts/cache_manager.py +22 -10
- cache_dit/cache_factory/cache_interface.py +26 -14
- cache_dit/cache_factory/cache_types.py +5 -5
- cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -2
- cache_dit/cache_factory/patch_functors/functor_flux.py +3 -2
- cache_dit/utils.py +168 -55
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/METADATA +34 -55
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/RECORD +21 -21
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/top_level.txt +0 -0
|
@@ -63,6 +63,20 @@ class CachedContextManager:
|
|
|
63
63
|
_context = self.new_context(*args, **kwargs)
|
|
64
64
|
return _context
|
|
65
65
|
|
|
66
|
+
def remove_context(self, cached_context: CachedContext | str):
|
|
67
|
+
if isinstance(cached_context, CachedContext):
|
|
68
|
+
cached_context.clear_buffers()
|
|
69
|
+
if cached_context.name in self._cached_context_manager:
|
|
70
|
+
del self._cached_context_manager[cached_context.name]
|
|
71
|
+
else:
|
|
72
|
+
if cached_context in self._cached_context_manager:
|
|
73
|
+
self._cached_context_manager[cached_context].clear_buffers()
|
|
74
|
+
del self._cached_context_manager[cached_context]
|
|
75
|
+
|
|
76
|
+
def clear_contexts(self):
|
|
77
|
+
for cached_context in self._cached_context_manager:
|
|
78
|
+
self.remove_context(cached_context)
|
|
79
|
+
|
|
66
80
|
@contextlib.contextmanager
|
|
67
81
|
def enter_context(self, cached_context: CachedContext | str):
|
|
68
82
|
old_cached_context = self._current_context
|
|
@@ -719,17 +733,15 @@ class CachedContextManager:
|
|
|
719
733
|
encoder_prefix
|
|
720
734
|
)
|
|
721
735
|
|
|
722
|
-
|
|
723
|
-
encoder_hidden_states_prev is not None
|
|
724
|
-
), f"{prefix}_encoder_buffer must be set before"
|
|
736
|
+
if encoder_hidden_states_prev is not None:
|
|
725
737
|
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
738
|
+
if self.is_encoder_cache_residual():
|
|
739
|
+
encoder_hidden_states = (
|
|
740
|
+
encoder_hidden_states_prev + encoder_hidden_states
|
|
741
|
+
)
|
|
742
|
+
else:
|
|
743
|
+
# If encoder cache is not residual, we use the encoder hidden states directly
|
|
744
|
+
encoder_hidden_states = encoder_hidden_states_prev
|
|
733
745
|
|
|
734
746
|
encoder_hidden_states = encoder_hidden_states.contiguous()
|
|
735
747
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, Tuple, List
|
|
1
|
+
from typing import Any, Tuple, List, Union
|
|
2
2
|
from diffusers import DiffusionPipeline
|
|
3
3
|
from cache_dit.cache_factory.cache_types import CacheType
|
|
4
4
|
from cache_dit.cache_factory.block_adapters import BlockAdapter
|
|
@@ -12,7 +12,10 @@ logger = init_logger(__name__)
|
|
|
12
12
|
|
|
13
13
|
def enable_cache(
|
|
14
14
|
# DiffusionPipeline or BlockAdapter
|
|
15
|
-
pipe_or_adapter:
|
|
15
|
+
pipe_or_adapter: Union[
|
|
16
|
+
DiffusionPipeline,
|
|
17
|
+
BlockAdapter,
|
|
18
|
+
],
|
|
16
19
|
# Cache context kwargs
|
|
17
20
|
Fn_compute_blocks: int = 8,
|
|
18
21
|
Bn_compute_blocks: int = 0,
|
|
@@ -30,7 +33,10 @@ def enable_cache(
|
|
|
30
33
|
taylorseer_cache_type: str = "residual",
|
|
31
34
|
taylorseer_order: int = 2,
|
|
32
35
|
**other_cache_context_kwargs,
|
|
33
|
-
) ->
|
|
36
|
+
) -> Union[
|
|
37
|
+
DiffusionPipeline,
|
|
38
|
+
BlockAdapter,
|
|
39
|
+
]:
|
|
34
40
|
r"""
|
|
35
41
|
Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks
|
|
36
42
|
that match the specific Input and Output patterns).
|
|
@@ -100,11 +106,11 @@ def enable_cache(
|
|
|
100
106
|
>>> cache_dit.enable_cache(pipe) # One-line code with default cache options.
|
|
101
107
|
>>> output = pipe(...) # Just call the pipe as normal.
|
|
102
108
|
>>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
|
|
109
|
+
>>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
|
|
103
110
|
"""
|
|
104
|
-
|
|
105
111
|
# Collect cache context kwargs
|
|
106
112
|
cache_context_kwargs = other_cache_context_kwargs.copy()
|
|
107
|
-
if cache_type := cache_context_kwargs.get("cache_type", None):
|
|
113
|
+
if (cache_type := cache_context_kwargs.get("cache_type", None)) is not None:
|
|
108
114
|
if cache_type == CacheType.NONE:
|
|
109
115
|
return pipe_or_adapter
|
|
110
116
|
|
|
@@ -129,16 +135,9 @@ def enable_cache(
|
|
|
129
135
|
cache_context_kwargs["taylorseer_cache_type"] = taylorseer_cache_type
|
|
130
136
|
cache_context_kwargs["taylorseer_order"] = taylorseer_order
|
|
131
137
|
|
|
132
|
-
if isinstance(pipe_or_adapter, BlockAdapter):
|
|
133
|
-
return CachedAdapter.apply(
|
|
134
|
-
pipe=None,
|
|
135
|
-
block_adapter=pipe_or_adapter,
|
|
136
|
-
**cache_context_kwargs,
|
|
137
|
-
)
|
|
138
|
-
elif isinstance(pipe_or_adapter, DiffusionPipeline):
|
|
138
|
+
if isinstance(pipe_or_adapter, (DiffusionPipeline, BlockAdapter)):
|
|
139
139
|
return CachedAdapter.apply(
|
|
140
|
-
|
|
141
|
-
block_adapter=None,
|
|
140
|
+
pipe_or_adapter,
|
|
142
141
|
**cache_context_kwargs,
|
|
143
142
|
)
|
|
144
143
|
else:
|
|
@@ -149,6 +148,19 @@ def enable_cache(
|
|
|
149
148
|
)
|
|
150
149
|
|
|
151
150
|
|
|
151
|
+
def disable_cache(
|
|
152
|
+
pipe_or_adapter: Union[
|
|
153
|
+
DiffusionPipeline,
|
|
154
|
+
BlockAdapter,
|
|
155
|
+
],
|
|
156
|
+
):
|
|
157
|
+
CachedAdapter.maybe_release_hooks(pipe_or_adapter)
|
|
158
|
+
logger.warning(
|
|
159
|
+
f"Cache Acceleration is disabled for: "
|
|
160
|
+
f"{pipe_or_adapter.__class__.__name__}."
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
|
|
152
164
|
def supported_pipelines(
|
|
153
165
|
**kwargs,
|
|
154
166
|
) -> Tuple[int, List[str]]:
|
|
@@ -22,11 +22,11 @@ def cache_type(type_hint: "CacheType | str") -> "CacheType":
|
|
|
22
22
|
if isinstance(type_hint, CacheType):
|
|
23
23
|
return type_hint
|
|
24
24
|
|
|
25
|
-
elif type_hint.
|
|
26
|
-
"
|
|
27
|
-
"
|
|
28
|
-
"
|
|
29
|
-
"
|
|
25
|
+
elif type_hint.upper() in (
|
|
26
|
+
"DUAL_BLOCK_CACHE",
|
|
27
|
+
"DB_CACHE",
|
|
28
|
+
"DBCACHE",
|
|
29
|
+
"DB",
|
|
30
30
|
):
|
|
31
31
|
return CacheType.DBCache
|
|
32
32
|
return CacheType.NONE
|
|
@@ -30,7 +30,7 @@ class ChromaPatchFunctor(PatchFunctor):
|
|
|
30
30
|
blocks: torch.nn.ModuleList = None,
|
|
31
31
|
**kwargs,
|
|
32
32
|
) -> ChromaTransformer2DModel:
|
|
33
|
-
if
|
|
33
|
+
if hasattr(transformer, "_is_patched"):
|
|
34
34
|
return transformer
|
|
35
35
|
|
|
36
36
|
if blocks is None:
|
|
@@ -56,7 +56,8 @@ class ChromaPatchFunctor(PatchFunctor):
|
|
|
56
56
|
transformer.forward = __patch_transformer_forward__.__get__(
|
|
57
57
|
transformer
|
|
58
58
|
)
|
|
59
|
-
|
|
59
|
+
|
|
60
|
+
transformer._is_patched = is_patched # True or False
|
|
60
61
|
|
|
61
62
|
cls_name = transformer.__class__.__name__
|
|
62
63
|
logger.info(
|
|
@@ -31,7 +31,7 @@ class FluxPatchFunctor(PatchFunctor):
|
|
|
31
31
|
**kwargs,
|
|
32
32
|
) -> FluxTransformer2DModel:
|
|
33
33
|
|
|
34
|
-
if
|
|
34
|
+
if hasattr(transformer, "_is_patched"):
|
|
35
35
|
return transformer
|
|
36
36
|
|
|
37
37
|
if blocks is None:
|
|
@@ -57,7 +57,8 @@ class FluxPatchFunctor(PatchFunctor):
|
|
|
57
57
|
transformer.forward = __patch_transformer_forward__.__get__(
|
|
58
58
|
transformer
|
|
59
59
|
)
|
|
60
|
-
|
|
60
|
+
|
|
61
|
+
transformer._is_patched = is_patched # True or False
|
|
61
62
|
|
|
62
63
|
cls_name = transformer.__class__.__name__
|
|
63
64
|
logger.info(
|
cache_dit/utils.py
CHANGED
|
@@ -5,7 +5,8 @@ import numpy as np
|
|
|
5
5
|
from pprint import pprint
|
|
6
6
|
from diffusers import DiffusionPipeline
|
|
7
7
|
|
|
8
|
-
from typing import Dict, Any
|
|
8
|
+
from typing import Dict, Any, List, Union
|
|
9
|
+
from cache_dit.cache_factory import BlockAdapter
|
|
9
10
|
from cache_dit.logger import init_logger
|
|
10
11
|
|
|
11
12
|
|
|
@@ -29,9 +30,171 @@ class CacheStats:
|
|
|
29
30
|
|
|
30
31
|
|
|
31
32
|
def summary(
|
|
32
|
-
|
|
33
|
+
adapter_or_others: Union[
|
|
34
|
+
BlockAdapter,
|
|
35
|
+
DiffusionPipeline,
|
|
36
|
+
torch.nn.Module,
|
|
37
|
+
],
|
|
33
38
|
details: bool = False,
|
|
34
39
|
logging: bool = True,
|
|
40
|
+
**kwargs,
|
|
41
|
+
) -> List[CacheStats]:
|
|
42
|
+
if adapter_or_others is None:
|
|
43
|
+
return [CacheStats()]
|
|
44
|
+
|
|
45
|
+
if not isinstance(adapter_or_others, BlockAdapter):
|
|
46
|
+
if not isinstance(adapter_or_others, DiffusionPipeline):
|
|
47
|
+
transformer = adapter_or_others
|
|
48
|
+
transformer_2 = None
|
|
49
|
+
else:
|
|
50
|
+
transformer = adapter_or_others.transformer
|
|
51
|
+
transformer_2 = None
|
|
52
|
+
if hasattr(adapter_or_others, "transformer_2"):
|
|
53
|
+
transformer_2 = adapter_or_others.transformer_2
|
|
54
|
+
|
|
55
|
+
if not BlockAdapter.is_cached(transformer):
|
|
56
|
+
return [CacheStats()]
|
|
57
|
+
|
|
58
|
+
blocks_stats: List[CacheStats] = []
|
|
59
|
+
for blocks in BlockAdapter.find_blocks(transformer):
|
|
60
|
+
blocks_stats.append(
|
|
61
|
+
_summary(
|
|
62
|
+
blocks,
|
|
63
|
+
details=details,
|
|
64
|
+
logging=logging,
|
|
65
|
+
**kwargs,
|
|
66
|
+
)
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
if transformer_2 is not None:
|
|
70
|
+
for blocks in BlockAdapter.find_blocks(transformer_2):
|
|
71
|
+
blocks_stats.append(
|
|
72
|
+
_summary(
|
|
73
|
+
blocks,
|
|
74
|
+
details=details,
|
|
75
|
+
logging=logging,
|
|
76
|
+
**kwargs,
|
|
77
|
+
)
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
blocks_stats.append(
|
|
81
|
+
_summary(
|
|
82
|
+
transformer,
|
|
83
|
+
details=details,
|
|
84
|
+
logging=logging,
|
|
85
|
+
**kwargs,
|
|
86
|
+
)
|
|
87
|
+
)
|
|
88
|
+
if transformer_2 is not None:
|
|
89
|
+
blocks_stats.append(
|
|
90
|
+
_summary(
|
|
91
|
+
transformer_2,
|
|
92
|
+
details=details,
|
|
93
|
+
logging=logging,
|
|
94
|
+
**kwargs,
|
|
95
|
+
)
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
blocks_stats = [stats for stats in blocks_stats if stats.cache_options]
|
|
99
|
+
|
|
100
|
+
return blocks_stats if len(blocks_stats) else [CacheStats()]
|
|
101
|
+
|
|
102
|
+
adapter = adapter_or_others
|
|
103
|
+
if not BlockAdapter.check_block_adapter(adapter):
|
|
104
|
+
return [CacheStats()]
|
|
105
|
+
|
|
106
|
+
blocks_stats = []
|
|
107
|
+
flatten_blocks = BlockAdapter.flatten(adapter.blocks)
|
|
108
|
+
for blocks in flatten_blocks:
|
|
109
|
+
blocks_stats.append(
|
|
110
|
+
_summary(
|
|
111
|
+
blocks,
|
|
112
|
+
details=details,
|
|
113
|
+
logging=logging,
|
|
114
|
+
**kwargs,
|
|
115
|
+
)
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
blocks_stats = [stats for stats in blocks_stats if stats.cache_options]
|
|
119
|
+
|
|
120
|
+
return blocks_stats if len(blocks_stats) else [CacheStats()]
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def strify(
|
|
124
|
+
adapter_or_others: Union[
|
|
125
|
+
BlockAdapter,
|
|
126
|
+
DiffusionPipeline,
|
|
127
|
+
CacheStats,
|
|
128
|
+
List[CacheStats],
|
|
129
|
+
Dict[str, Any],
|
|
130
|
+
],
|
|
131
|
+
) -> str:
|
|
132
|
+
if isinstance(adapter_or_others, BlockAdapter):
|
|
133
|
+
stats = summary(adapter_or_others, logging=False)[-1]
|
|
134
|
+
cache_options = stats.cache_options
|
|
135
|
+
cached_steps = len(stats.cached_steps)
|
|
136
|
+
elif isinstance(adapter_or_others, DiffusionPipeline):
|
|
137
|
+
stats = summary(adapter_or_others, logging=False)[-1]
|
|
138
|
+
cache_options = stats.cache_options
|
|
139
|
+
cached_steps = len(stats.cached_steps)
|
|
140
|
+
elif isinstance(adapter_or_others, CacheStats):
|
|
141
|
+
stats = adapter_or_others
|
|
142
|
+
cache_options = stats.cache_options
|
|
143
|
+
cached_steps = len(stats.cached_steps)
|
|
144
|
+
elif isinstance(adapter_or_others, list):
|
|
145
|
+
stats = adapter_or_others[0]
|
|
146
|
+
cache_options = stats.cache_options
|
|
147
|
+
cached_steps = len(stats.cached_steps)
|
|
148
|
+
elif isinstance(adapter_or_others, dict):
|
|
149
|
+
from cache_dit.cache_factory import CacheType
|
|
150
|
+
|
|
151
|
+
# Assume cache_context_kwargs
|
|
152
|
+
cache_options = adapter_or_others
|
|
153
|
+
cached_steps = None
|
|
154
|
+
cache_type = cache_options.get("cache_type", CacheType.NONE)
|
|
155
|
+
|
|
156
|
+
if cache_type == CacheType.NONE:
|
|
157
|
+
return "NONE"
|
|
158
|
+
else:
|
|
159
|
+
raise ValueError(
|
|
160
|
+
"Please set pipe_or_stats param as one of: "
|
|
161
|
+
"DiffusionPipeline | CacheStats | Dict[str, Any]"
|
|
162
|
+
)
|
|
163
|
+
|
|
164
|
+
if not cache_options:
|
|
165
|
+
return "NONE"
|
|
166
|
+
|
|
167
|
+
def get_taylorseer_order():
|
|
168
|
+
taylorseer_order = 0
|
|
169
|
+
if "taylorseer_order" in cache_options:
|
|
170
|
+
taylorseer_order = cache_options["taylorseer_order"]
|
|
171
|
+
return taylorseer_order
|
|
172
|
+
|
|
173
|
+
cache_type_str = (
|
|
174
|
+
f"DBCACHE_F{cache_options.get('Fn_compute_blocks', 1)}"
|
|
175
|
+
f"B{cache_options.get('Bn_compute_blocks', 0)}_"
|
|
176
|
+
f"W{cache_options.get('max_warmup_steps', 0)}"
|
|
177
|
+
f"M{max(0, cache_options.get('max_cached_steps', -1))}"
|
|
178
|
+
f"MC{max(0, cache_options.get('max_continuous_cached_steps', -1))}_"
|
|
179
|
+
f"T{int(cache_options.get('enable_taylorseer', False))}"
|
|
180
|
+
f"O{get_taylorseer_order()}_"
|
|
181
|
+
f"R{cache_options.get('residual_diff_threshold', 0.08)}"
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
if cached_steps:
|
|
185
|
+
cache_type_str += f"_S{cached_steps}"
|
|
186
|
+
|
|
187
|
+
return cache_type_str
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _summary(
|
|
191
|
+
pipe_or_module: Union[
|
|
192
|
+
DiffusionPipeline,
|
|
193
|
+
torch.nn.Module,
|
|
194
|
+
],
|
|
195
|
+
details: bool = False,
|
|
196
|
+
logging: bool = True,
|
|
197
|
+
**kwargs,
|
|
35
198
|
) -> CacheStats:
|
|
36
199
|
cache_stats = CacheStats()
|
|
37
200
|
|
|
@@ -51,6 +214,9 @@ def summary(
|
|
|
51
214
|
cache_stats.cache_options = cache_options
|
|
52
215
|
if logging:
|
|
53
216
|
print(f"\n🤗Cache Options: {cls_name}\n\n{cache_options}")
|
|
217
|
+
else:
|
|
218
|
+
if logging:
|
|
219
|
+
logger.warning(f"Can't find Cache Options for: {cls_name}")
|
|
54
220
|
|
|
55
221
|
if hasattr(module, "_cached_steps"):
|
|
56
222
|
cached_steps: list[int] = module._cached_steps
|
|
@@ -141,56 +307,3 @@ def summary(
|
|
|
141
307
|
)
|
|
142
308
|
|
|
143
309
|
return cache_stats
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
def strify(
|
|
147
|
-
pipe_or_stats: DiffusionPipeline | CacheStats | Dict[str, Any],
|
|
148
|
-
) -> str:
|
|
149
|
-
if isinstance(pipe_or_stats, DiffusionPipeline):
|
|
150
|
-
stats = summary(pipe_or_stats, logging=False)
|
|
151
|
-
cache_options = stats.cache_options
|
|
152
|
-
cached_steps = len(stats.cached_steps)
|
|
153
|
-
elif isinstance(pipe_or_stats, CacheStats):
|
|
154
|
-
stats = pipe_or_stats
|
|
155
|
-
cache_options = stats.cache_options
|
|
156
|
-
cached_steps = len(stats.cached_steps)
|
|
157
|
-
elif isinstance(pipe_or_stats, dict):
|
|
158
|
-
from cache_dit.cache_factory import CacheType
|
|
159
|
-
|
|
160
|
-
# Assume cache_context_kwargs
|
|
161
|
-
cache_options = pipe_or_stats
|
|
162
|
-
cached_steps = None
|
|
163
|
-
cache_type = cache_options.get("cache_type", CacheType.NONE)
|
|
164
|
-
|
|
165
|
-
if cache_type == CacheType.NONE:
|
|
166
|
-
return "NONE"
|
|
167
|
-
else:
|
|
168
|
-
raise ValueError(
|
|
169
|
-
"Please set pipe_or_stats param as one of: "
|
|
170
|
-
"DiffusionPipeline | CacheStats | Dict[str, Any]"
|
|
171
|
-
)
|
|
172
|
-
|
|
173
|
-
if not cache_options:
|
|
174
|
-
return "NONE"
|
|
175
|
-
|
|
176
|
-
def get_taylorseer_order():
|
|
177
|
-
taylorseer_order = 0
|
|
178
|
-
if "taylorseer_order" in cache_options:
|
|
179
|
-
taylorseer_order = cache_options["taylorseer_order"]
|
|
180
|
-
return taylorseer_order
|
|
181
|
-
|
|
182
|
-
cache_type_str = (
|
|
183
|
-
f"DBCACHE_F{cache_options.get('Fn_compute_blocks', 1)}"
|
|
184
|
-
f"B{cache_options.get('Bn_compute_blocks', 0)}_"
|
|
185
|
-
f"W{cache_options.get('max_warmup_steps', 0)}"
|
|
186
|
-
f"M{max(0, cache_options.get('max_cached_steps', -1))}"
|
|
187
|
-
f"MC{max(0, cache_options.get('max_continuous_cached_steps', -1))}_"
|
|
188
|
-
f"T{int(cache_options.get('enable_taylorseer', False))}"
|
|
189
|
-
f"O{get_taylorseer_order()}_"
|
|
190
|
-
f"R{cache_options.get('residual_diff_threshold', 0.08)}"
|
|
191
|
-
)
|
|
192
|
-
|
|
193
|
-
if cached_steps:
|
|
194
|
-
cache_type_str += f"_S{cached_steps}"
|
|
195
|
-
|
|
196
|
-
return cache_type_str
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cache_dit
|
|
3
|
-
Version: 0.2.28
|
|
3
|
+
Version: 0.2.30
|
|
4
4
|
Summary: 🤗 A Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
|
|
5
5
|
Author: DefTruth, vipshop.com, etc.
|
|
6
6
|
Maintainer: DefTruth, vipshop.com, etc
|
|
@@ -43,7 +43,7 @@ Dynamic: requires-python
|
|
|
43
43
|
<div align="center">
|
|
44
44
|
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit-logo.png height="120">
|
|
45
45
|
|
|
46
|
-
|
|
46
|
+
<p align="center">
|
|
47
47
|
A <b>Unified</b> and Training-free <b>Cache Acceleration</b> Toolbox for <b>Diffusion Transformers</b> <br>
|
|
48
48
|
♥️ <b>Cache Acceleration</b> with <b>One-line</b> Code ~ ♥️
|
|
49
49
|
</p>
|
|
@@ -59,26 +59,36 @@ Dynamic: requires-python
|
|
|
59
59
|
🔥<b><a href="#unified">Unified Cache APIs</a> | <a href="#dbcache">DBCache</a> | <a href="#taylorseer">Hybrid TaylorSeer</a> | <a href="#cfg">Hybrid Cache CFG</a></b>🔥
|
|
60
60
|
</p>
|
|
61
61
|
<p align="center">
|
|
62
|
-
🎉Now, <b>cache-dit</b> covers <b>
|
|
62
|
+
🎉Now, <b>cache-dit</b> covers <b>mainstream</b> Diffusers' <b>DiT-based</b> Pipelines🎉<br>
|
|
63
63
|
🔥<b><a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Wan 2.1/2.2</a> | <a href="#supported"> ... </a> | <a href="#supported">CogVideoX</a></b>🔥
|
|
64
64
|
</p>
|
|
65
65
|
</div>
|
|
66
|
+
<div align='center'>
|
|
67
|
+
<img src=./assets/gifs/wan2.2.C0_Q0_NONE.gif width=160px>
|
|
68
|
+
<img src=./assets/gifs/wan2.2.C1_Q0_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=160px>
|
|
69
|
+
<img src=./assets/gifs/wan2.2.C1_Q1_fp8_w8a8_dq_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=160px>
|
|
70
|
+
<p><b>🔥Wan2.2 MoE</b> Baseline | <b><a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:~2.0x↑🎉</b> | +FP8 DQ:<b>~2.4x↑🎉</b></p>
|
|
71
|
+
<img src=./assets/qwen-image.C0_Q0_NONE.png width=160px>
|
|
72
|
+
<img src=./assets/qwen-image.C1_Q0_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S23.png width=160px>
|
|
73
|
+
<img src=./assets/qwen-image.C1_Q1_fp8_w8a8_dq_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S18.png width=160px>
|
|
74
|
+
<p><b>🔥Qwen-Image</b> Baseline | <b><a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:~1.8x↑🎉</b> | +FP8 DQ:<b>~2.2x↑🎉</b><br>♥️ Please consider leaving a <b>⭐️ Star</b> to support us ~ ♥️</p>
|
|
75
|
+
</p>
|
|
76
|
+
</div>
|
|
66
77
|
|
|
67
78
|
## 🔥News
|
|
68
79
|
|
|
69
|
-
- [2025-09-03] 🎉[**Wan2.2-MoE**](https://github.com/Wan-Video) **2.4x
|
|
70
|
-
- [2025-08-19] 🔥[**Qwen-Image-Edit**](https://github.com/QwenLM/Qwen-Image) **2x
|
|
80
|
+
- [2025-09-03] 🎉[**Wan2.2-MoE**](https://github.com/Wan-Video) **2.4x↑🎉** speedup! Please refer to [run_wan_2.2.py](./examples/pipeline/run_wan_2.2.py) as an example.
|
|
81
|
+
- [2025-08-19] 🔥[**Qwen-Image-Edit**](https://github.com/QwenLM/Qwen-Image) **2x↑🎉** speedup! Check the example: [run_qwen_image_edit.py](./examples/pipeline/run_qwen_image_edit.py).
|
|
71
82
|
- [2025-08-12] 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check this [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
|
|
72
|
-
- [2025-08-11] 🔥[**Qwen-Image**](https://github.com/QwenLM/Qwen-Image) **1.8x
|
|
83
|
+
- [2025-08-11] 🔥[**Qwen-Image**](https://github.com/QwenLM/Qwen-Image) **1.8x↑🎉** speedup! Please refer to [run_qwen_image.py](./examples/pipeline/run_qwen_image.py) as an example.
|
|
84
|
+
- [2025-07-13] 🎉[**FLUX.1-Dev**](https://github.com/xlite-dev/flux-faster) **3.3x↑🎉** speedup! NVIDIA L20 with **[cache-dit](https://github.com/vipshop/cache-dit)** + `compile + FP8 DQ`.
|
|
73
85
|
|
|
74
86
|
<details>
|
|
75
87
|
<summary> Previous News </summary>
|
|
76
88
|
|
|
77
89
|
- [2025-09-01] 📚[**Hybrid Forward Pattern**](#unified) is supported! Please check [FLUX.1-dev](./examples/run_flux_adapter.py) as an example.
|
|
78
|
-
- [2025-08-29] 🔥</b>Covers <b>100%</b> Diffusers' <b>DiT-based</b> Pipelines: **[BlockAdapter](#unified) + [Pattern Matching](#unified).**
|
|
79
90
|
- [2025-08-10] 🔥[**FLUX.1-Kontext-dev**](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) is supported! Please refer to [run_flux_kontext.py](./examples/pipeline/run_flux_kontext.py) as an example.
|
|
80
91
|
- [2025-07-18] 🎉First caching mechanism in [🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/huggingface/flux-fast/pull/13).
|
|
81
|
-
- [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! **3.3x** speedup for FLUX.1 on NVIDIA L20 with **[cache-dit](https://github.com/vipshop/cache-dit)**.
|
|
82
92
|
|
|
83
93
|
</details>
|
|
84
94
|
|
|
@@ -119,19 +129,8 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
|
|
|
119
129
|
|
|
120
130
|
<div id="supported"></div>
|
|
121
131
|
|
|
122
|
-
```python
|
|
123
|
-
>>> import cache_dit
|
|
124
|
-
>>> cache_dit.supported_pipelines()
|
|
125
|
-
(31, ['Flux*', 'Mochi*', 'CogVideoX*', 'Wan*', 'HunyuanVideo*', 'QwenImage*', 'LTXVideo*',
|
|
126
|
-
'Allegro*', 'CogView3Plus*', 'CogView4*', 'Cosmos*', 'EasyAnimate*', 'SkyReelsV2*', 'SD3*',
|
|
127
|
-
'ConsisID*', 'DiT*', 'Amused*', 'Bria*', 'HunyuanDiT*', 'HunyuanDiTPAG*', 'Lumina*', 'Lumina2*',
|
|
128
|
-
'OmniGen*', 'PixArt*', 'Sana*', 'ShapE*', 'StableAudio*', 'VisualCloze*', 'AuraFlow*',
|
|
129
|
-
'Chroma*', 'HiDream*'])
|
|
130
|
-
```
|
|
131
|
-
|
|
132
132
|
Currently, **cache-dit** library supports almost **Any** Diffusion Transformers (with **Transformer Blocks** that match the specific Input and Output **patterns**). Please check [🎉Unified Cache APIs](#unified) for more details. Here are just some of the tested models listed:
|
|
133
133
|
|
|
134
|
-
|
|
135
134
|
- [🚀Qwen-Image-Edit](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
136
135
|
- [🚀Qwen-Image](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
137
136
|
- [🚀FLUX.1-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
@@ -143,35 +142,7 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
|
|
|
143
142
|
- [🚀Wan2.1-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
144
143
|
- [🚀Wan2.1-FLF2V](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
145
144
|
- [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
146
|
-
- [🚀HunyuanDiT](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
147
145
|
|
|
148
|
-
<details>
|
|
149
|
-
<summary> More Pipelines </summary>
|
|
150
|
-
|
|
151
|
-
- [🚀mochi-1-preview](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
152
|
-
- [🚀LTXVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
153
|
-
- [🚀Allegro](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
154
|
-
- [🚀CogView3Plus](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
155
|
-
- [🚀CogView4](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
156
|
-
- [🚀Cosmos](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
157
|
-
- [🚀EasyAnimate](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
158
|
-
- [🚀SkyReelsV2](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
159
|
-
- [🚀SD3](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
160
|
-
- [🚀ConsisID](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
161
|
-
- [🚀DiT](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
162
|
-
- [🚀Amused](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
163
|
-
- [🚀HunyuanDiTPAG](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
164
|
-
- [🚀Lumina](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
165
|
-
- [🚀Lumina2](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
166
|
-
- [🚀OmniGen](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
167
|
-
- [🚀PixArt](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
168
|
-
- [🚀Sana](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
169
|
-
- [🚀StableAudio](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
170
|
-
- [🚀VisualCloze](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
171
|
-
- [🚀AuraFlow](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
172
|
-
- [🚀Chroma](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
173
|
-
- [🚀HiDream](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
174
|
-
|
|
175
146
|
</details>
|
|
176
147
|
|
|
177
148
|
## 🎉Unified Cache APIs
|
|
@@ -200,6 +171,9 @@ cache_dit.enable_cache(pipe)
|
|
|
200
171
|
|
|
201
172
|
# Just call the pipe as normal.
|
|
202
173
|
output = pipe(...)
|
|
174
|
+
|
|
175
|
+
# Disable cache and run original pipe.
|
|
176
|
+
cache_dit.disable_cache(pipe)
|
|
203
177
|
```
|
|
204
178
|
|
|
205
179
|
### 🔥Automatic Block Adapter
|
|
@@ -226,7 +200,6 @@ cache_dit.enable_cache(
|
|
|
226
200
|
pipe=pipe, # Qwen-Image, etc.
|
|
227
201
|
transformer=pipe.transformer,
|
|
228
202
|
blocks=pipe.transformer.transformer_blocks,
|
|
229
|
-
blocks_name="transformer_blocks",
|
|
230
203
|
forward_pattern=ForwardPattern.Pattern_1,
|
|
231
204
|
),
|
|
232
205
|
)
|
|
@@ -248,10 +221,6 @@ cache_dit.enable_cache(
|
|
|
248
221
|
pipe.transformer.transformer_blocks,
|
|
249
222
|
pipe.transformer.single_transformer_blocks,
|
|
250
223
|
],
|
|
251
|
-
blocks_name=[
|
|
252
|
-
"transformer_blocks",
|
|
253
|
-
"single_transformer_blocks",
|
|
254
|
-
],
|
|
255
224
|
forward_pattern=[
|
|
256
225
|
ForwardPattern.Pattern_1,
|
|
257
226
|
ForwardPattern.Pattern_3,
|
|
@@ -457,11 +426,21 @@ cache-dit-metrics-cli all -i1 true_dir -i2 test_dir # image dir
|
|
|
457
426
|
|
|
458
427
|
How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](./CONTRIBUTE.md).
|
|
459
428
|
|
|
460
|
-
|
|
429
|
+
<div align='center'>
|
|
430
|
+
<a href="https://star-history.com/#vipshop/cache-dit&Date">
|
|
431
|
+
<picture align='center'>
|
|
432
|
+
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=vipshop/cache-dit&type=Date&theme=dark" />
|
|
433
|
+
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=vipshop/cache-dit&type=Date" />
|
|
434
|
+
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=vipshop/cache-dit&type=Date" width=400px />
|
|
435
|
+
</picture>
|
|
436
|
+
</a>
|
|
437
|
+
</div>
|
|
438
|
+
|
|
439
|
+
## ©️Acknowledgements
|
|
461
440
|
|
|
462
|
-
<div id="
|
|
441
|
+
<div id="Acknowledgements"></div>
|
|
463
442
|
|
|
464
|
-
The **cache-dit** codebase is adapted from FBCache.
|
|
443
|
+
The **cache-dit** codebase is adapted from FBCache. Over time the codebase has diverged substantially, and the **cache-dit** API is no longer compatible with FBCache.
|
|
465
444
|
|
|
466
445
|
## ©️Citations
|
|
467
446
|
|