cache_dit-1.0.1-py3-none-any.whl → cache_dit-1.0.2-py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their public registry and is provided for informational purposes only.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/block_adapters/__init__.py +24 -8
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +10 -6
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +99 -34
- cache_dit/cache_factory/cache_blocks/pattern_base.py +80 -35
- cache_dit/cache_factory/cache_contexts/__init__.py +1 -0
- cache_dit/cache_factory/cache_contexts/cache_manager.py +9 -2
- cache_dit/cache_factory/cache_interface.py +38 -13
- cache_dit/cache_factory/patch_functors/__init__.py +3 -0
- cache_dit/cache_factory/patch_functors/functor_qwen_image_controlnet.py +263 -0
- cache_dit/quantize/quantize_ao.py +4 -0
- {cache_dit-1.0.1.dist-info → cache_dit-1.0.2.dist-info}/METADATA +69 -63
- {cache_dit-1.0.1.dist-info → cache_dit-1.0.2.dist-info}/RECORD +17 -16
- {cache_dit-1.0.1.dist-info → cache_dit-1.0.2.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.1.dist-info → cache_dit-1.0.2.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.1.dist-info → cache_dit-1.0.2.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-1.0.1.dist-info → cache_dit-1.0.2.dist-info}/top_level.txt +0 -0
cache_dit/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '1.0.1'
-__version_tuple__ = version_tuple = (1, 0, 1)
+__version__ = version = '1.0.2'
+__version_tuple__ = version_tuple = (1, 0, 2)
 
 __commit_id__ = commit_id = None

cache_dit/cache_factory/block_adapters/__init__.py
CHANGED

@@ -143,14 +143,30 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import QwenImageTransformer2DModel
 
     assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-    return BlockAdapter(
-        pipe=pipe,
-        transformer=pipe.transformer,
-        blocks=pipe.transformer.transformer_blocks,
-        forward_pattern=ForwardPattern.Pattern_1,
-        has_separate_cfg=True,
-        **kwargs,
-    )
+
+    pipe_cls_name: str = pipe.__class__.__name__
+    if pipe_cls_name.startswith("QwenImageControlNet"):
+        from cache_dit.cache_factory.patch_functors import (
+            QwenImageControlNetPatchFunctor,
+        )
+
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=pipe.transformer,
+            blocks=pipe.transformer.transformer_blocks,
+            forward_pattern=ForwardPattern.Pattern_1,
+            patch_functor=QwenImageControlNetPatchFunctor(),
+            has_separate_cfg=True,
+        )
+    else:
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=pipe.transformer,
+            blocks=pipe.transformer.transformer_blocks,
+            forward_pattern=ForwardPattern.Pattern_1,
+            has_separate_cfg=True,
+            **kwargs,
+        )
 
 
 @BlockAdapterRegistry.register("LTX")
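
The adapter registration above is what `cache_dit.enable_cache(...)` consults for Qwen-Image pipelines. A hedged illustration only: the fragment assumes `pipe` is an already-built QwenImage ControlNet pipeline, i.e. any class whose name starts with "QwenImageControlNet", such as one built from the InstantX/Qwen-Image-ControlNet-Inpainting weights mentioned in the README later in this diff.

```python
import cache_dit

# `pipe` is an assumed, pre-constructed QwenImage ControlNet pipeline.
# The registered qwenimage_adapter detects the "QwenImageControlNet" class-name
# prefix and routes it through QwenImageControlNetPatchFunctor before caching.
cache_dit.enable_cache(pipe)
output = pipe(...)  # then call the pipeline as usual
```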

cache_dit/cache_factory/cache_adapters/cache_adapter.py
CHANGED

@@ -14,10 +14,6 @@ from cache_dit.cache_factory.cache_contexts import CachedContextManager
 from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
 from cache_dit.cache_factory.cache_contexts import CalibratorConfig
 from cache_dit.cache_factory.cache_blocks import CachedBlocks
-from cache_dit.cache_factory.cache_blocks import (
-    patch_cached_stats,
-    remove_cached_stats,
-)
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
@@ -167,7 +163,7 @@ class CachedAdapter:
         cls,
         block_adapter: BlockAdapter,
         **cache_context_kwargs,
-    ) ->
+    ) -> Tuple[List[str], List[Dict[str, Any]]]:
 
         BlockAdapter.assert_normalized(block_adapter)
 
@@ -221,7 +217,7 @@ class CachedAdapter:
 
         cls.apply_params_hooks(block_adapter, contexts_kwargs)
 
-        return
+        return flatten_contexts, contexts_kwargs
 
     @classmethod
     def modify_context_params(
@@ -470,6 +466,10 @@ class CachedAdapter:
         cls,
         block_adapter: BlockAdapter,
     ):
+        from cache_dit.cache_factory.cache_blocks import (
+            patch_cached_stats,
+        )
+
         cache_manager = block_adapter.pipe._cache_manager
 
         for i in range(len(block_adapter.transformer)):
@@ -557,6 +557,10 @@ class CachedAdapter:
         )
 
         # release stats hooks
+        from cache_dit.cache_factory.cache_blocks import (
+            remove_cached_stats,
+        )
+
         cls.release_hooks(
             pipe_or_adapter,
             remove_cached_stats,

cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py
CHANGED

@@ -1,6 +1,9 @@
 import torch
 
 from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CacheNotExistError,
+)
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,
 )
@@ -16,6 +19,70 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         ForwardPattern.Pattern_5,
     ]
 
+    def call_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        # Call all blocks to process the hidden states without cache.
+        new_encoder_hidden_states = None
+        for block in self.transformer_blocks:
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, new_encoder_hidden_states = self._process_outputs(
+                hidden_states
+            )
+
+        return hidden_states, new_encoder_hidden_states
+
+    @torch.compiler.disable
+    def _process_outputs(
+        self, hidden_states: torch.Tensor | tuple
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        # Process the outputs for the block.
+        new_encoder_hidden_states = None
+        if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+            if len(hidden_states) == 2:
+                if isinstance(hidden_states[1], torch.Tensor):
+                    hidden_states, new_encoder_hidden_states = hidden_states
+                    if not self.forward_pattern.Return_H_First:
+                        hidden_states, new_encoder_hidden_states = (
+                            new_encoder_hidden_states,
+                            hidden_states,
+                        )
+                elif isinstance(hidden_states[0], torch.Tensor):
+                    hidden_states = hidden_states[0]
+                else:
+                    raise ValueError("Unexpected hidden_states format.")
+            else:
+                assert (
+                    len(hidden_states) == 1
+                ), f"Unexpected output length: {len(hidden_states)}"
+                hidden_states = hidden_states[0]
+        return hidden_states, new_encoder_hidden_states
+
+    @torch.compiler.disable
+    def _forward_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        new_encoder_hidden_states: torch.Tensor | None,
+    ) -> (
+        torch.Tensor
+        | tuple[torch.Tensor, torch.Tensor]
+        | tuple[torch.Tensor, None]
+    ):
+        if self.forward_pattern.Return_H_Only:
+            return hidden_states
+        else:
+            if self.forward_pattern.Return_H_First:
+                return (hidden_states, new_encoder_hidden_states)
+            else:
+                return (new_encoder_hidden_states, hidden_states)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -23,8 +90,19 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         **kwargs,
     ):
         # Use it's own cache context.
-        self.cache_manager.set_context(self.cache_context)
-        self._check_cache_params()
+        try:
+            self.cache_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except CacheNotExistError as e:
+            logger.warning(f"Cache context not exist: {e}, skip cache.")
+            hidden_states, new_encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._forward_outputs(
+                hidden_states, new_encoder_hidden_states
+            )
 
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
@@ -35,7 +113,9 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             **kwargs,
         )
 
-        Fn_hidden_states_residual = hidden_states - original_hidden_states
+        Fn_hidden_states_residual = hidden_states - original_hidden_states.to(
+            hidden_states.device
+        )
         del original_hidden_states
 
         self.cache_manager.mark_step_begin()
@@ -147,15 +227,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
 
         torch._dynamo.graph_break()
 
-        return (
-            hidden_states
-            if self.forward_pattern.Return_H_Only
-            else (
-                (hidden_states, new_encoder_hidden_states)
-                if self.forward_pattern.Return_H_First
-                else (new_encoder_hidden_states, hidden_states)
-            )
-        )
+        return self._forward_outputs(hidden_states, new_encoder_hidden_states)
 
     def call_Fn_blocks(
         self,
@@ -170,13 +242,9 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             *args,
             **kwargs,
         )
-        if not isinstance(hidden_states, torch.Tensor):
-            hidden_states, new_encoder_hidden_states = hidden_states
-            if not self.forward_pattern.Return_H_First:
-                hidden_states, new_encoder_hidden_states = (
-                    new_encoder_hidden_states,
-                    hidden_states,
-                )
+        hidden_states, new_encoder_hidden_states = self._process_outputs(
+            hidden_states
+        )
 
         return hidden_states, new_encoder_hidden_states
 
@@ -194,16 +262,16 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             *args,
             **kwargs,
         )
-        if not isinstance(hidden_states, torch.Tensor):
-            hidden_states, new_encoder_hidden_states = hidden_states
-            if not self.forward_pattern.Return_H_First:
-                hidden_states, new_encoder_hidden_states = (
-                    new_encoder_hidden_states,
-                    hidden_states,
-                )
+
+        hidden_states, new_encoder_hidden_states = self._process_outputs(
+            hidden_states
+        )
+
         # compute hidden_states residual
         hidden_states = hidden_states.contiguous()
-        hidden_states_residual = hidden_states - original_hidden_states
+        hidden_states_residual = hidden_states - original_hidden_states.to(
+            hidden_states.device
+        )
 
         return (
             hidden_states,
@@ -227,12 +295,9 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             *args,
             **kwargs,
         )
-        if not isinstance(hidden_states, torch.Tensor):
-            hidden_states, new_encoder_hidden_states = hidden_states
-            if not self.forward_pattern.Return_H_First:
-                hidden_states, new_encoder_hidden_states = (
-                    new_encoder_hidden_states,
-                    hidden_states,
-                )
+
+        hidden_states, new_encoder_hidden_states = self._process_outputs(
+            hidden_states
+        )
 
         return hidden_states, new_encoder_hidden_states

cache_dit/cache_factory/cache_blocks/pattern_base.py
CHANGED

@@ -1,12 +1,11 @@
 import inspect
-import asyncio
 import torch
 import torch.distributed as dist
 
-from typing import List
 from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
     CachedContextManager,
+    CacheNotExistError,
 )
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.logger import init_logger
@@ -47,7 +46,6 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         self.cache_prefix = cache_prefix
         self.cache_context = cache_context
         self.cache_manager = cache_manager
-        self.pending_tasks: List[asyncio.Task] = []
 
         self._check_forward_pattern()
         logger.info(
@@ -111,6 +109,62 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
 
+    def call_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        # Call all blocks to process the hidden states without cache.
+        for block in self.transformer_blocks:
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            if not isinstance(hidden_states, torch.Tensor):
+                hidden_states, encoder_hidden_states = hidden_states
+                if not self.forward_pattern.Return_H_First:
+                    hidden_states, encoder_hidden_states = (
+                        encoder_hidden_states,
+                        hidden_states,
+                    )
+
+        return hidden_states, encoder_hidden_states
+
+    @torch.compiler.disable
+    def _process_outputs(
+        self,
+        hidden_states: torch.Tensor | tuple,
+        encoder_hidden_states: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        if not isinstance(hidden_states, torch.Tensor):
+            hidden_states, encoder_hidden_states = hidden_states
+            if not self.forward_pattern.Return_H_First:
+                hidden_states, encoder_hidden_states = (
+                    encoder_hidden_states,
+                    hidden_states,
+                )
+        return hidden_states, encoder_hidden_states
+
+    @torch.compiler.disable
+    def _forward_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None] | torch.Tensor:
+        return (
+            hidden_states
+            if self.forward_pattern.Return_H_Only
+            else (
+                (hidden_states, encoder_hidden_states)
+                if self.forward_pattern.Return_H_First
+                else (encoder_hidden_states, hidden_states)
+            )
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -119,8 +173,19 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         **kwargs,
     ):
         # Use it's own cache context.
-        self.cache_manager.set_context(self.cache_context)
-        self._check_cache_params()
+        try:
+            self.cache_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except CacheNotExistError as e:
+            logger.warning(f"Cache context not exist: {e}, skip cache.")
+            # Call all blocks to process the hidden states.
+            hidden_states, encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._forward_outputs(hidden_states, encoder_hidden_states)
 
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
@@ -239,15 +304,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # patch cached stats for blocks or remove it.
         torch._dynamo.graph_break()
 
-        return (
-            hidden_states
-            if self.forward_pattern.Return_H_Only
-            else (
-                (hidden_states, encoder_hidden_states)
-                if self.forward_pattern.Return_H_First
-                else (encoder_hidden_states, hidden_states)
-            )
-        )
+        return self._forward_outputs(hidden_states, encoder_hidden_states)
 
     @torch.compiler.disable
     def _is_parallelized(self):
@@ -322,13 +379,9 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
             *args,
             **kwargs,
         )
-        if not isinstance(hidden_states, torch.Tensor):
-            hidden_states, encoder_hidden_states = hidden_states
-            if not self.forward_pattern.Return_H_First:
-                hidden_states, encoder_hidden_states = (
-                    encoder_hidden_states,
-                    hidden_states,
-                )
+        hidden_states, encoder_hidden_states = self._process_outputs(
+            hidden_states, encoder_hidden_states
+        )
 
         return hidden_states, encoder_hidden_states
 
@@ -348,13 +401,9 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
             *args,
             **kwargs,
         )
-        if not isinstance(hidden_states, torch.Tensor):
-            hidden_states, encoder_hidden_states = hidden_states
-            if not self.forward_pattern.Return_H_First:
-                hidden_states, encoder_hidden_states = (
-                    encoder_hidden_states,
-                    hidden_states,
-                )
+        hidden_states, encoder_hidden_states = self._process_outputs(
+            hidden_states, encoder_hidden_states
+        )
 
         # compute hidden_states residual
         hidden_states = hidden_states.contiguous()
@@ -396,12 +445,8 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
             *args,
             **kwargs,
         )
-        if not isinstance(hidden_states, torch.Tensor):
-            hidden_states, encoder_hidden_states = hidden_states
-            if not self.forward_pattern.Return_H_First:
-                hidden_states, encoder_hidden_states = (
-                    encoder_hidden_states,
-                    hidden_states,
-                )
+        hidden_states, encoder_hidden_states = self._process_outputs(
+            hidden_states, encoder_hidden_states
+        )
 
         return hidden_states, encoder_hidden_states

cache_dit/cache_factory/cache_contexts/cache_manager.py
CHANGED

@@ -14,6 +14,10 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
+class CacheNotExistError(Exception):
+    pass
+
+
 class CachedContextManager:
     # Each Pipeline should have it's own context manager instance.
 
@@ -27,16 +31,19 @@ class CachedContextManager:
             self._cached_context_manager[_context.name] = _context
         return _context
 
-    def set_context(self, cached_context: CachedContext | str):
+    def set_context(self, cached_context: CachedContext | str) -> CachedContext:
         if isinstance(cached_context, CachedContext):
             self._current_context = cached_context
         else:
+            if cached_context not in self._cached_context_manager:
+                raise CacheNotExistError("Context not exist!")
             self._current_context = self._cached_context_manager[cached_context]
+        return self._current_context
 
     def get_context(self, name: str = None) -> CachedContext:
         if name is not None:
             if name not in self._cached_context_manager:
-                raise
+                raise CacheNotExistError("Context not exist!")
             return self._cached_context_manager[name]
         return self._current_context
 
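
The new `CacheNotExistError` is what lets the cached block wrappers above fall back to plain execution instead of crashing when a named context was never registered. A minimal, hypothetical sketch of that pattern (it assumes `cache_manager` is an existing `CachedContextManager` instance and uses a made-up context name):

```python
from cache_dit.cache_factory.cache_contexts.cache_manager import CacheNotExistError

try:
    # "missing-context" was never created, so set_context() raises.
    cache_manager.set_context("missing-context")
except CacheNotExistError:
    # Fall back to running the transformer blocks without caching,
    # mirroring what CachedBlocks_Pattern_Base.forward() now does.
    pass
```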

cache_dit/cache_factory/cache_interface.py
CHANGED

@@ -38,23 +38,43 @@ def enable_cache(
     BlockAdapter,
 ]:
     r"""
-
-
-
-
-
+    The `enable_cache` function serves as a unified caching interface designed to optimize the performance
+    of diffusion transformer models by implementing an intelligent caching mechanism known as `DBCache`.
+    This API is engineered to be compatible with nearly `all` diffusion transformer architectures that
+    feature transformer blocks adhering to standard input-output patterns, eliminating the need for
+    architecture-specific modifications.
+
+    By strategically caching intermediate outputs of transformer blocks during the diffusion process,
+    `DBCache` significantly reduces redundant computations without compromising generation quality.
+    The caching mechanism works by tracking residual differences between consecutive steps, allowing
+    the model to reuse previously computed features when these differences fall below a configurable
+    threshold. This approach maintains a balance between computational efficiency and output precision.
+
+    The default configuration (`F8B0, 8 warmup steps, unlimited cached steps`) is carefully tuned to
+    provide an optimal tradeoff for most common use cases. The "F8B0" configuration indicates that
+    the first 8 transformer blocks are used to compute stable feature differences, while no final
+    blocks are employed for additional fusion. The warmup phase ensures the model establishes
+    sufficient feature representation before caching begins, preventing potential degradation of
+    output quality.
+
+    This function seamlessly integrates with both standard diffusion pipelines and custom block
+    adapters, making it versatile for various deployment scenarios—from research prototyping to
+    production environments where inference speed is critical. By abstracting the complexity of
+    caching logic behind a simple interface, it enables developers to enhance model performance
+    with minimal code changes.
 
     Args:
         pipe_or_adapter (`DiffusionPipeline` or `BlockAdapter`, *required*):
             The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
             For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
             for the usgae of BlockAdapter.
+
         cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
             Basic DBCache config for cache context, defaults to BasicCacheConfig(). The configurable params listed belows:
             Fn_compute_blocks: (`int`, *required*, defaults to 8):
-                Specifies that `DBCache` uses the
-
-
+                Specifies that `DBCache` uses the**first n**Transformer blocks to fit the information at time step t,
+                enabling the calculation of a more stable L1 difference and delivering more accurate information
+                to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
                 for more details of DBCache.
             Bn_compute_blocks: (`int`, *required*, defaults to 0):
                 Further fuses approximate information in the **last n** Transformer blocks to enhance
@@ -77,14 +97,18 @@ def enable_cache(
                 and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
                 CogVideoX, HunyuanVideo, Mochi, etc.
             cfg_compute_first (`bool`, *required*, defaults to False):
-
+                Whether to compute cfg forward first, default is False, meaning:
+                0, 2, 4, ..., -> non-CFG step;
                 1, 3, 5, ... -> CFG step.
             cfg_diff_compute_separate (`bool`, *required*, defaults to True):
-
-                use the computed
+                Whether to compute separate difference values for CFG and non-CFG steps, default is True.
+                If False, we will use the computed difference from the current non-CFG transformer step
+                for the current CFG step.
+
         calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
-            Config for calibrator
-            with specific calibrator, such as taylorseer, foca, and so on.
+            Config for calibrator. If calibrator_config is not None, it means the user wants to use DBCache
+            with a specific calibrator, such as taylorseer, foca, and so on.
+
         params_modifiers ('ParamsModifier', *optional*, defaults to None):
             Modify cache context params for specific blocks. The configurable params listed belows:
             cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
@@ -93,6 +117,7 @@ def enable_cache(
                 The same as 'calibrator_config' param in cache_dit.enable_cache() interface.
             **kwargs: (`dict`, *optional*, defaults to {}):
                 The same as 'kwargs' param in cache_dit.enable_cache() interface.
+
         kwargs (`dict`, *optional*, defaults to {})
             Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
             for more details.
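
Since the expanded docstring spells out the default F8B0 scheme (the first 8 blocks fitted, no trailing fusion blocks, 8 warmup steps), a short usage sketch may help. The `BasicCacheConfig` field names follow the parameter list above and the pipeline id comes from the README quoted later in this diff; treat the exact call as illustrative rather than canonical:

```python
import cache_dit
from diffusers import DiffusionPipeline
from cache_dit.cache_factory.cache_contexts import BasicCacheConfig

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")

# Spell out the documented "F8B0" default explicitly: the first 8 transformer
# blocks compute the residual difference, no trailing blocks are fused.
cache_dit.enable_cache(
    pipe,
    cache_config=BasicCacheConfig(Fn_compute_blocks=8, Bn_compute_blocks=0),
)
output = pipe("a cup of coffee on a wooden table")  # illustrative prompt
```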

cache_dit/cache_factory/patch_functors/__init__.py
CHANGED

@@ -10,3 +10,6 @@ from cache_dit.cache_factory.patch_functors.functor_hidream import (
 from cache_dit.cache_factory.patch_functors.functor_hunyuan_dit import (
     HunyuanDiTPatchFunctor,
 )
+from cache_dit.cache_factory.patch_functors.functor_qwen_image_controlnet import (
+    QwenImageControlNetPatchFunctor,
+)

cache_dit/cache_factory/patch_functors/functor_qwen_image_controlnet.py
ADDED

@@ -0,0 +1,263 @@
+import torch
+import numpy as np
+from typing import Tuple, Optional, Dict, Any, Union, List
+from diffusers import QwenImageTransformer2DModel
+from diffusers.models.transformers.transformer_qwenimage import (
+    QwenImageTransformerBlock,
+    Transformer2DModelOutput,
+)
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from cache_dit.cache_factory.patch_functors.functor_base import (
+    PatchFunctor,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class QwenImageControlNetPatchFunctor(PatchFunctor):
+
+    def apply(
+        self,
+        transformer: QwenImageTransformer2DModel,
+        **kwargs,
+    ) -> QwenImageTransformer2DModel:
+        if hasattr(transformer, "_is_patched"):
+            return transformer
+
+        is_patched = False
+
+        _index_block = 0
+        _num_blocks = len(transformer.transformer_blocks)
+        for block in transformer.transformer_blocks:
+            assert isinstance(block, QwenImageTransformerBlock)
+            block._index_block = _index_block
+            block._num_blocks = _num_blocks
+            block.forward = __patch_block_forward__.__get__(block)
+            _index_block += 1
+
+        is_patched = True
+        cls_name = transformer.__class__.__name__
+
+        if is_patched:
+            logger.warning(f"Patched {cls_name} for cache-dit.")
+            assert not getattr(transformer, "_is_parallelized", False), (
+                "Please call `cache_dit.enable_cache` before Parallelize, "
+                "the __patch_transformer_forward__ will overwrite the "
+                "parallized forward and cause a downgrade of performance."
+            )
+            transformer.forward = __patch_transformer_forward__.__get__(
+                transformer
+            )
+
+        transformer._is_patched = is_patched  # True or False
+
+        logger.info(
+            f"Applied {self.__class__.__name__} for {cls_name}, "
+            f"Patch: {is_patched}."
+        )
+
+        return transformer
+
+
+def __patch_block_forward__(
+    self: QwenImageTransformerBlock,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    encoder_hidden_states_mask: torch.Tensor,
+    temb: torch.Tensor,
+    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    controlnet_block_samples: Optional[List[torch.Tensor]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Get modulation parameters for both streams
+    img_mod_params = self.img_mod(temb)  # [B, 6*dim]
+    txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]
+
+    # Split modulation parameters for norm1 and norm2
+    img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]
+    txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]
+
+    # Process image stream - norm1 + modulation
+    img_normed = self.img_norm1(hidden_states)
+    img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
+
+    # Process text stream - norm1 + modulation
+    txt_normed = self.txt_norm1(encoder_hidden_states)
+    txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
+
+    # Use QwenAttnProcessor2_0 for joint attention computation
+    # This directly implements the DoubleStreamLayerMegatron logic:
+    # 1. Computes QKV for both streams
+    # 2. Applies QK normalization and RoPE
+    # 3. Concatenates and runs joint attention
+    # 4. Splits results back to separate streams
+    joint_attention_kwargs = joint_attention_kwargs or {}
+    attn_output = self.attn(
+        hidden_states=img_modulated,  # Image stream (will be processed as "sample")
+        encoder_hidden_states=txt_modulated,  # Text stream (will be processed as "context")
+        encoder_hidden_states_mask=encoder_hidden_states_mask,
+        image_rotary_emb=image_rotary_emb,
+        **joint_attention_kwargs,
+    )
+
+    # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
+    img_attn_output, txt_attn_output = attn_output
+
+    # Apply attention gates and add residual (like in Megatron)
+    hidden_states = hidden_states + img_gate1 * img_attn_output
+    encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
+
+    # Process image stream - norm2 + MLP
+    img_normed2 = self.img_norm2(hidden_states)
+    img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
+    img_mlp_output = self.img_mlp(img_modulated2)
+    hidden_states = hidden_states + img_gate2 * img_mlp_output
+
+    # Process text stream - norm2 + MLP
+    txt_normed2 = self.txt_norm2(encoder_hidden_states)
+    txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
+    txt_mlp_output = self.txt_mlp(txt_modulated2)
+    encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
+
+    # Clip to prevent overflow for fp16
+    if encoder_hidden_states.dtype == torch.float16:
+        encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+    if hidden_states.dtype == torch.float16:
+        hidden_states = hidden_states.clip(-65504, 65504)
+
+    if controlnet_block_samples is not None:
+        # Add ControlNet conditioning
+        num_blocks = self._num_blocks
+        index_block = self._index_block
+        interval_control = num_blocks / len(controlnet_block_samples)
+        interval_control = int(np.ceil(interval_control))
+        hidden_states = (
+            hidden_states
+            + controlnet_block_samples[index_block // interval_control]
+        )
+
+    return encoder_hidden_states, hidden_states
+
+
+def __patch_transformer_forward__(
+    self: QwenImageTransformer2DModel,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor = None,
+    encoder_hidden_states_mask: torch.Tensor = None,
+    timestep: torch.LongTensor = None,
+    img_shapes: Optional[List[Tuple[int, int, int]]] = None,
+    txt_seq_lens: Optional[List[int]] = None,
+    guidance: torch.Tensor = None,  # TODO: this should probably be removed
+    attention_kwargs: Optional[Dict[str, Any]] = None,
+    controlnet_block_samples=None,
+    return_dict: bool = True,
+) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    """
+    The [`QwenTransformer2DModel`] forward method.
+
+    Args:
+        hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+            Input `hidden_states`.
+        encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+            Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+        encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
+            Mask of the input conditions.
+        timestep ( `torch.LongTensor`):
+            Used to indicate denoising step.
+        attention_kwargs (`dict`, *optional*):
+            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+            `self.processor` in
+            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+        return_dict (`bool`, *optional*, defaults to `True`):
+            Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+            tuple.
+
+    Returns:
+        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+        `tuple` where the first element is the sample tensor.
+    """
+    if attention_kwargs is not None:
+        attention_kwargs = attention_kwargs.copy()
+        lora_scale = attention_kwargs.pop("scale", 1.0)
+    else:
+        lora_scale = 1.0
+
+    if USE_PEFT_BACKEND:
+        # weight the lora layers by setting `lora_scale` for each PEFT layer
+        scale_lora_layers(self, lora_scale)
+    else:
+        if (
+            attention_kwargs is not None
+            and attention_kwargs.get("scale", None) is not None
+        ):
+            logger.warning(
+                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+            )
+
+    hidden_states = self.img_in(hidden_states)
+
+    timestep = timestep.to(hidden_states.dtype)
+    encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+    encoder_hidden_states = self.txt_in(encoder_hidden_states)
+
+    if guidance is not None:
+        guidance = guidance.to(hidden_states.dtype) * 1000
+
+    temb = (
+        self.time_text_embed(timestep, hidden_states)
+        if guidance is None
+        else self.time_text_embed(timestep, guidance, hidden_states)
+    )
+
+    image_rotary_emb = self.pos_embed(
+        img_shapes, txt_seq_lens, device=hidden_states.device
+    )
+
+    for index_block, block in enumerate(self.transformer_blocks):
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            encoder_hidden_states, hidden_states = (
+                self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    encoder_hidden_states_mask,
+                    temb,
+                    image_rotary_emb,
+                    controlnet_block_samples,
+                )
+            )
+
+        else:
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_hidden_states_mask=encoder_hidden_states_mask,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                controlnet_block_samples=controlnet_block_samples,
+                joint_attention_kwargs=attention_kwargs,
+            )
+
+    # # controlnet residual
+    # if controlnet_block_samples is not None:
+    #     interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+    #     interval_control = int(np.ceil(interval_control))
+    #     hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+
+    # Use only the image part (hidden_states) from the dual-stream blocks
+    hidden_states = self.norm_out(hidden_states, temb)
+    output = self.proj_out(hidden_states)
+
+    if USE_PEFT_BACKEND:
+        # remove `lora_scale` from each PEFT layer
+        unscale_lora_layers(self, lora_scale)
+
+    if not return_dict:
+        return (output,)
+
+    return Transformer2DModelOutput(sample=output)
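
This functor is normally applied for you by the ControlNet branch of `qwenimage_adapter` shown earlier; a hypothetical sketch of invoking it directly, assuming `pipe` is an existing QwenImage ControlNet pipeline:

```python
from cache_dit.cache_factory.patch_functors import QwenImageControlNetPatchFunctor

functor = QwenImageControlNetPatchFunctor()
# apply() rebinds each block's forward and the transformer's forward so that
# controlnet_block_samples are added per block, then marks the module with
# `_is_patched`; a second call returns early and is effectively a no-op.
pipe.transformer = functor.apply(pipe.transformer)
```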

cache_dit/quantize/quantize_ao.py
CHANGED

@@ -182,12 +182,16 @@ def quantize_ao(
     force_empty_cache()
 
     logger.info(
+        f"Quantized Module: {module.__class__.__name__:>5}\n"
         f"Quantized Method: {quant_type:>5}\n"
         f"Quantized Linear Layers: {num_quant_linear:>5}\n"
         f"Skipped Linear Layers: {num_skip_linear:>5}\n"
         f"Total Linear Layers: {num_linear_layers:>5}\n"
         f"Total (all) Layers: {num_layers:>5}"
     )
+
+    module._quantize_type = quant_type
+    module._is_quantized = True
     return module
 
 
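
The two new attributes give callers a cheap way to tell whether a module has already gone through `quantize_ao` and with which method. A small, hypothetical check built on them (`transformer` is assumed to be a module previously returned by `quantize_ao`):

```python
# Markers set by quantize_ao() in cache-dit 1.0.2.
if getattr(transformer, "_is_quantized", False):
    print(f"already quantized via: {transformer._quantize_type}")
else:
    print("not quantized yet")
```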
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cache_dit
|
|
3
|
-
Version: 1.0.
|
|
3
|
+
Version: 1.0.2
|
|
4
4
|
Summary: A Unified, Flexible and Training-free Cache Acceleration Framework for 🤗Diffusers.
|
|
5
5
|
Author: DefTruth, vipshop.com, etc.
|
|
6
6
|
Maintainer: DefTruth, vipshop.com, etc
|
|
@@ -63,73 +63,17 @@ Dynamic: requires-python
|
|
|
63
63
|
</div>
|
|
64
64
|
<p align="center">
|
|
65
65
|
🎉Now, <b>cache-dit</b> covers almost <b>All</b> Diffusers' <b>DiT</b> Pipelines🎉<br>
|
|
66
|
-
🔥<a href="./examples/">Qwen-Image</a> | <a href="./examples/">
|
|
67
|
-
🔥<a href="./examples/">
|
|
68
|
-
🔥<a href="./examples/">
|
|
69
|
-
🔥<a href="./examples/">
|
|
70
|
-
🔥<a href="./examples/">
|
|
66
|
+
🔥<a href="./examples/pipeline">Qwen-Image</a> | <a href="./examples/pipeline">Qwen-Image-Edit</a> | <a href="./examples/pipeline">Qwen-Image-Edit-Plus </a> 🔥<br>
|
|
67
|
+
🔥<a href="./examples/pipeline">FLUX.1</a> | <a href="./examples/pipeline">Qwen-Image-Lightning 4/8 Steps</a> | <a href="./examples/pipeline"> Wan 2.1 </a> | <a href="./examples/pipeline"> Wan 2.2 </a>🔥<br>
|
|
68
|
+
🔥<a href="./examples/pipeline">HunyuanImage-2.1</a> | <a href="./examples/pipeline">HunyuanVideo</a> | <a href="./examples/pipeline">HunyuanDiT</a> | <a href="./examples/pipeline">HiDream</a> | <a href="./examples/pipeline">AuraFlow</a>🔥<br>
|
|
69
|
+
🔥<a href="./examples/pipeline">CogView3Plus</a> | <a href="./examples/pipeline">CogView4</a> | <a href="./examples/pipeline">LTXVideo</a> | <a href="./examples/pipeline">CogVideoX</a> | <a href="./examples/">CogVideoX 1.5</a> | <a href="./examples/">ConsisID</a>🔥<br>
|
|
70
|
+
🔥<a href="./examples/pipeline">Cosmos</a> | <a href="./examples/pipeline">SkyReelsV2</a> | <a href="./examples/pipeline">VisualCloze</a> | <a href="./examples/pipeline">OmniGen 1/2</a> | <a href="./examples/pipeline">Lumina 1/2</a> | <a href="./examples/pipeline">PixArt</a>🔥<br>
|
|
71
|
+
🔥<a href="./examples/pipeline">Chroma</a> | <a href="./examples/pipeline">Sana</a> | <a href="./examples/pipeline">Allegro</a> | <a href="./examples/pipeline">Mochi</a> | <a href="./examples/pipeline">SD 3/3.5</a> | <a href="./examples/pipeline">Amused</a> | <a href="./examples/pipeline"> ... </a> | <a href="./examples/pipeline">DiT-XL</a>🔥
|
|
71
72
|
<br>♥️ Please consider to leave a <b>⭐️ Star</b> to support us ~ ♥️
|
|
72
73
|
</p>
|
|
73
74
|
</div>
|
|
74
75
|
|
|
75
|
-
## 🔥Hightlight <a href="https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit"><img src=https://img.shields.io/badge/🤗Diffusers-ecosystem-yellow.svg ></a>
|
|
76
|
-
|
|
77
76
|
<div align='center'>
|
|
78
|
-
<details>
|
|
79
|
-
<summary> 🔥<b>Click</b> here to show <b>Important News</b>: First API-stable (v1.0.0) Release🔥 </summary>
|
|
80
|
-
|
|
81
|
-
2025.09.25: 🎉The **first API-stable version (v1.0.0)** of cache-dit has finally been released!<br>
|
|
82
|
-
2025.09.25: 🔥**cache-dit** has joined the Diffusers community ecosystem: <a href="https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit"><img src=https://img.shields.io/badge/🤗Diffusers-ecosystem-yellow.svg ></a><br>
|
|
83
|
-
2025.09.10: 🎉Day 1 support [**HunyuanImage-2.1**](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) with **1.7x↑🎉** speedup! Check this [example](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_hunyuan_image_2.1.py).<br>
|
|
84
|
-
2025.09.08: 🔥[**Qwen-Image-Lightning**](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_qwen_image_lightning.py) **7.1/3.5 steps🎉** inference with **[DBCache: F16B16](https://github.com/vipshop/cache-dit)**.<br>
|
|
85
|
-
2025.09.03: 🎉[**Wan2.2-MoE**](https://github.com/Wan-Video) **2.4x↑🎉** speedup! Please refer to [run_wan_2.2.py](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_wan_2.2.py) as an example.<br>
|
|
86
|
-
2025.08.12: 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check this [PR](https://github.com/QwenLM/Qwen-Image/pull/61).<br>
|
|
87
|
-
2025.08.11: 🔥[**Qwen-Image**](https://github.com/QwenLM/Qwen-Image) **1.8x↑🎉** speedup! Please refer to [run_qwen_image.py](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_qwen_image.py) as an example.<br>
|
|
88
|
-
2025.09.08: 🎉First caching mechanism in [Wan2.2](https://github.com/Wan-Video/Wan2.2) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check this [PR](https://github.com/Wan-Video/Wan2.2/pull/127) for more details.<br>
|
|
89
|
-
2025.09.08: 🎉First caching mechanism in [Qwen-Image-Lightning](https://github.com/ModelTC/Qwen-Image-Lightning) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check this [PR](https://github.com/ModelTC/Qwen-Image-Lightning/pull/35).<br>
|
|
90
|
-
2025.08.19: 🔥[**Qwen-Image-Edit**](https://github.com/QwenLM/Qwen-Image) **2x↑🎉** speedup! Check the example: [run_qwen_image_edit.py](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_qwen_image_edit.py).<br>
|
|
91
|
-
2025.09.01: 📚[**Hybird Forward Pattern**](#unified) is supported! Please check [FLUX.1-dev](https://github.com/vipshop/cache-dit/blob/main/examples/run_flux_adapter.py) as an example.<br>
|
|
92
|
-
2025.08.10: 🔥[**FLUX.1-Kontext-dev**](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) is supported! Please refer [run_flux_kontext.py](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_flux_kontext.py) as an example.<br>
|
|
93
|
-
2025.07.18: 🎉First caching mechanism in [🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/huggingface/flux-fast/pull/13).<br>
|
|
94
|
-
2025.07.13: 🎉[**FLUX.1-dev**](https://github.com/xlite-dev/flux-faster) **3.3x↑🎉** speedup! NVIDIA L20 with **[cache-dit](https://github.com/vipshop/cache-dit)** + **compile + FP8 DQ**.<br>
|
|
95
|
-
|
|
96
|
-
</details>
|
|
97
|
-
</div>
|
|
98
|
-
|
|
99
|
-
We are excited to announce that the **first API-stable version (v1.0.0)** of cache-dit has finally been released!
|
|
100
|
-
**[cache-dit](https://github.com/vipshop/cache-dit)** is a **Unified**, **Flexible**, and **Training-free** cache acceleration framework for 🤗 Diffusers, enabling cache acceleration with just **one line** of code. Key features include **Unified Cache APIs**, **Forward Pattern Matching**, **Automatic Block Adapter**, **Hybrid Forward Pattern**, **DBCache**, **TaylorSeer Calibrator**, and **Cache CFG**.
|
|
101
|
-
|
|
102
|
-
```bash
|
|
103
|
-
pip3 install -U cache-dit # pip3 install git+https://github.com/vipshop/cache-dit.git
|
|
104
|
-
```
|
|
105
|
-
You can install the stable release of cache-dit from PyPI, or the latest development version from GitHub. Then try ♥️ Cache Acceleration with just **one line** of code ~ ♥️
|
|
106
|
-
```python
|
|
107
|
-
>>> import cache_dit
|
|
108
|
-
>>> from diffusers import DiffusionPipeline
|
|
109
|
-
>>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image") # Can be any diffusion pipeline
|
|
110
|
-
>>> cache_dit.enable_cache(pipe) # One-line code with default cache options.
|
|
111
|
-
>>> output = pipe(...) # Just call the pipe as normal.
|
|
112
|
-
>>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
|
|
113
|
-
>>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
|
|
114
|
-
```
|
|
115
|
-
|
|
116
|
-
### 📚Core Features
|
|
117
|
-
|
|
118
|
-
- **[🎉Full 🤗Diffusers Support](./docs/User_Guide.md#supported-pipelines)**: Notably, **[cache-dit](https://github.com/vipshop/cache-dit)** now supports nearly **all** of Diffusers' **DiT-based** pipelines, such as Qwen-Image, FLUX.1, Qwen-Image-Lightning, HunyuanImage-2.1, HunyuanVideo, HunyuanDiT, Wan 2.1/2.2, HiDream, AuraFlow, CogView3Plus, CogView4, LTXVideo, CogVideoX 1.5, ConsisID, SkyReelsV2, VisualCloze, OmniGen, Lumina, PixArt, Chroma, Sana, Allegro, Mochi, SD 3.5, Amused, and DiT-XL.
|
|
119
|
-
- **[🎉Extremely Easy to Use](./docs/User_Guide.md#unified-cache-apis)**: In most cases, you only need **one line** of code: `cache_dit.enable_cache(...)`. After calling this API, just use the pipeline as normal.
|
|
120
|
-
- **[🎉Easy New Model Integration](./docs/User_Guide.md#automatic-block-adapter)**: Features like **Unified Cache APIs**, **Forward Pattern Matching**, **Automatic Block Adapter**, **Hybrid Forward Pattern**, and **Patch Functor** make it highly functional and flexible. For example, we achieved 🎉 Day 1 support for [HunyuanImage-2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) with 1.7x speedup w/o precision loss—even before it was available in the Diffusers library.
|
|
121
|
-
- **[🎉State-of-the-Art Performance](./bench/)**: Compared with algorithms including Δ-DiT, Chipmunk, FORA, DuCa, TaylorSeer and FoCa, cache-dit achieves the best accuracy when the speedup ratio is below 4x.
|
|
122
|
-
- **[🎉Support for 4/8-Steps Distilled Models](./bench/)**: Surprisingly, cache-dit's **DBCache** works for extremely few-step distilled models—something many other methods fail to do.
|
|
123
|
-
- **[🎉Compatibility with Other Optimizations](./docs/User_Guide.md#️torch-compile)**: Designed to work seamlessly with torch.compile, model CPU offload, sequential CPU offload, group offloading, etc.
|
|
124
|
-
- **[🎉Hybrid Cache Acceleration](./docs/User_Guide.md#taylorseer-calibrator)**: Now supports hybrid **DBCache + Calibrator** schemes (e.g., DBCache + TaylorSeerCalibrator). DBCache acts as the **Indicator** to decide *when* to cache, while the Calibrator decides *how* to cache. More mainstream cache acceleration algorithms (e.g., FoCa) will be supported in the future, along with additional benchmarks—stay tuned for updates!
|
|
125
|
-
- **[🤗Diffusers Ecosystem Integration](https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit)**: 🔥**cache-dit** has joined the Diffusers community ecosystem as the **first** DiT-specific cache acceleration framework! Check out the documentation here: <a href="https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit"><img src=https://img.shields.io/badge/🤗Diffusers-ecosystem-yellow.svg ></a>
|
|
126
|
-
|
|
127
|
-

|
|
128
|
-
|
|
129
|
-
<details align='center'>
|
|
130
|
-
<summary>🔥<b>Click</b> here to show many <b>Image/Video</b> cases🔥</summary>
|
|
131
|
-
|
|
132
|
-
<div align='center'>
|
|
133
77
|
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C0_Q0_NONE.gif width=124px>
|
|
134
78
|
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C1_Q0_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=124px>
|
|
135
79
|
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/hunyuan_video.C0_L0_Q0_NONE.gif width=126px>
|
|
@@ -150,6 +94,12 @@ You can install the stable release of cache-dit from PyPI, or the latest develop
|
|
|
150
94
|
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image-edit.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S18.png width=125px>
|
|
151
95
|
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image-edit.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.12_S24.png width=125px>
|
|
152
96
|
<p><b>🔥Qwen-Image-Edit</b> | Input w/o Edit | Baseline | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.6x↑🎉 | 1.9x↑🎉 </p>
|
|
97
|
+
</div>
|
|
98
|
+
|
|
99
|
+
<details align='center'>
|
|
100
|
+
<summary>🔥<b>Click</b> here to show many <b>Image/Video</b> cases🔥</summary>
|
|
101
|
+
|
|
102
|
+
<div align='center'>
|
|
153
103
|
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/flux-kontext-cat.C0_L0_Q0_NONE.png width=100px>
|
|
154
104
|
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/flux-kontext.C0_L0_Q0_NONE.png width=100px>
|
|
155
105
|
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/flux-kontext.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S10.png width=100px>
|
|
@@ -218,6 +168,62 @@ You can install the stable release of cache-dit from PyPI, or the latest develop
|
|
|
218
168
|
|
|
219
169
|
</details>
|
|
220
170
|
|
|
171
|
+
## 🔥Hightlight <a href="https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit"><img src=https://img.shields.io/badge/🤗Diffusers-ecosystem-yellow.svg ></a>
|
|
172
|
+
|
|
173
|
+
We are excited to announce that the **first API-stable version (v1.0.0)** of cache-dit has finally been released!
|
|
174
|
+
**[cache-dit](https://github.com/vipshop/cache-dit)** is a **Unified**, **Flexible**, and **Training-free** cache acceleration framework for 🤗 Diffusers, enabling cache acceleration with just **one line** of code. Key features include **Unified Cache APIs**, **Forward Pattern Matching**, **Automatic Block Adapter**, **Hybrid Forward Pattern**, **DBCache**, **TaylorSeer Calibrator**, and **Cache CFG**.
|
|
175
|
+
|
|
176
|
+
```bash
|
|
177
|
+
pip3 install -U cache-dit # pip3 install git+https://github.com/vipshop/cache-dit.git
|
|
178
|
+
```
|
|
179
|
+
You can install the stable release of cache-dit from PyPI, or the latest development version from GitHub. Then try ♥️ Cache Acceleration with just **one line** of code ~ ♥️
|
|
180
|
+
```python
|
|
181
|
+
>>> import cache_dit
|
|
182
|
+
>>> from diffusers import DiffusionPipeline
|
|
183
|
+
>>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image") # Can be any diffusion pipeline
|
|
184
|
+
>>> cache_dit.enable_cache(pipe) # One-line code with default cache options.
|
|
185
|
+
>>> output = pipe(...) # Just call the pipe as normal.
|
|
186
|
+
>>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
|
|
187
|
+
>>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
|
|
188
|
+
```
|
|
189
|
+
|
|
190
|
+
### 📚Core Features
|
|
191
|
+
|
|
192
|
+
- **[🎉Full 🤗Diffusers Support](./docs/User_Guide.md#supported-pipelines)**: Notably, **[cache-dit](https://github.com/vipshop/cache-dit)** now supports nearly **all** of Diffusers' **DiT-based** pipelines, such as Qwen-Image, FLUX.1, Qwen-Image-Lightning, HunyuanImage-2.1, HunyuanVideo, HunyuanDiT, Wan 2.1/2.2, HiDream, AuraFlow, CogView3Plus, CogView4, LTXVideo, CogVideoX 1.5, ConsisID, SkyReelsV2, VisualCloze, OmniGen, Lumina, PixArt, Chroma, Sana, Allegro, Mochi, SD 3.5, Amused, and DiT-XL.
|
|
193
|
+
- **[🎉Extremely Easy to Use](./docs/User_Guide.md#unified-cache-apis)**: In most cases, you only need **one line** of code: `cache_dit.enable_cache(...)`. After calling this API, just use the pipeline as normal.
|
|
194
|
+
- **[🎉Easy New Model Integration](./docs/User_Guide.md#automatic-block-adapter)**: Features like **Unified Cache APIs**, **Forward Pattern Matching**, **Automatic Block Adapter**, **Hybrid Forward Pattern**, and **Patch Functor** make it highly functional and flexible. For example, we achieved 🎉 Day 1 support for [HunyuanImage-2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) with 1.7x speedup w/o precision loss—even before it was available in the Diffusers library.
|
|
195
|
+
- **[🎉State-of-the-Art Performance](./bench/)**: Compared with algorithms including Δ-DiT, Chipmunk, FORA, DuCa, TaylorSeer and FoCa, cache-dit achieves the best accuracy when the speedup ratio is below 4x.
|
|
196
|
+
- **[🎉Support for 4/8-Steps Distilled Models](./bench/)**: Surprisingly, cache-dit's **DBCache** works for extremely few-step distilled models—something many other methods fail to do.
|
|
197
|
+
- **[🎉Compatibility with Other Optimizations](./docs/User_Guide.md#️torch-compile)**: Designed to work seamlessly with torch.compile, model CPU offload, sequential CPU offload, group offloading, etc.
|
|
198
|
+
- **[🎉Hybrid Cache Acceleration](./docs/User_Guide.md#taylorseer-calibrator)**: Now supports hybrid **DBCache + Calibrator** schemes (e.g., DBCache + TaylorSeerCalibrator). DBCache acts as the **Indicator** to decide *when* to cache, while the Calibrator decides *how* to cache. More mainstream cache acceleration algorithms (e.g., FoCa) will be supported in the future, along with additional benchmarks—stay tuned for updates!
|
|
199
|
+
- **[🤗Diffusers Ecosystem Integration](https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit)**: 🔥**cache-dit** has joined the Diffusers community ecosystem as the **first** DiT-specific cache acceleration framework! Check out the documentation here: <a href="https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit"><img src=https://img.shields.io/badge/🤗Diffusers-ecosystem-yellow.svg ></a>
|
|
200
|
+
|
|
201
|
+

|
|
202
|
+
|
|
203
|
+
## 🔥Important News

- 2025.10.10: 🔥[**Qwen-Image-ControlNet-Inpainting**](https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting) **2.3x↑🎉** speedup! Check the [example](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_qwen_image_controlnet_inpaint.py).
- 2025.09.26: 🔥[**Qwen-Image-Edit-Plus (2509)**](https://github.com/QwenLM/Qwen-Image) **2.1x↑🎉** speedup! Check the [example](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_qwen_image_edit_plus.py).
- 2025.09.25: 🎉The **first API-stable version (v1.0.0)** of cache-dit has finally been released!
- 2025.09.25: 🔥**cache-dit** has joined the Diffusers community ecosystem: <a href="https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit"><img src=https://img.shields.io/badge/🤗Diffusers-ecosystem-yellow.svg ></a>
- 2025.09.10: 🎉Day-1 support for [**HunyuanImage-2.1**](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) with a **1.7x↑🎉** speedup! Check this [example](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_hunyuan_image_2.1.py).
- 2025.09.08: 🔥[**Qwen-Image-Lightning**](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_qwen_image_lightning.py) **7.1/3.5 steps🎉** inference with **[DBCache: F16B16](https://github.com/vipshop/cache-dit)**.
- 2025.09.03: 🎉[**Wan2.2-MoE**](https://github.com/Wan-Video) **2.4x↑🎉** speedup! See [run_wan_2.2.py](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_wan_2.2.py) for an example.
- 2025.08.19: 🔥[**Qwen-Image-Edit**](https://github.com/QwenLM/Qwen-Image) **2x↑🎉** speedup! Check the example: [run_qwen_image_edit.py](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_qwen_image_edit.py).
- 2025.08.11: 🔥[**Qwen-Image**](https://github.com/QwenLM/Qwen-Image) **1.8x↑🎉** speedup! See [run_qwen_image.py](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_qwen_image.py) for an example.

<details>
<summary>Previous News</summary>

- 2025.09.08: 🎉First caching mechanism in [Wan2.2](https://github.com/Wan-Video/Wan2.2) with **[cache-dit](https://github.com/vipshop/cache-dit)**; check this [PR](https://github.com/Wan-Video/Wan2.2/pull/127) for details.
- 2025.09.08: 🎉First caching mechanism in [Qwen-Image-Lightning](https://github.com/ModelTC/Qwen-Image-Lightning) with **[cache-dit](https://github.com/vipshop/cache-dit)**; check this [PR](https://github.com/ModelTC/Qwen-Image-Lightning/pull/35).
- 2025.08.12: 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**; check this [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
- 2025.08.10: 🔥[**FLUX.1-Kontext-dev**](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) is supported! See [run_flux_kontext.py](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_flux_kontext.py) for an example.
- 2025.07.18: 🎉First caching mechanism in [🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast) with **[cache-dit](https://github.com/vipshop/cache-dit)**; check this [PR](https://github.com/huggingface/flux-fast/pull/13).
- 2025.07.13: 🎉[**FLUX.1-dev**](https://github.com/xlite-dev/flux-faster) **3.3x↑🎉** speedup on an NVIDIA L20 with **[cache-dit](https://github.com/vipshop/cache-dit)** + **compile + FP8 DQ**.

</details>

## 📚User Guide
<div id="user-guide"></div>
@@ -1,39 +1,40 @@
cache_dit/__init__.py,sha256=sHRg0swXZZiw6lvSQ53fcVtN9JRayx0az2lXAz5OOGI,1510
-cache_dit/_version.py,sha256=
+cache_dit/_version.py,sha256=ZTgKq8LPNy3l9uR2ke-VtLhvvl5l71frQ9wO76n1L5k,704
cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
cache_dit/utils.py,sha256=AyYRwi5XBxYBH4GaXxOxv9-X24Te_IYOYwh54t_1d3A,10674
cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
cache_dit/cache_factory/__init__.py,sha256=vy9I6Ofkj9jWeUoOvh-cY5a9QlDDKfj2FVPlVTf7BeA,1390
-cache_dit/cache_factory/cache_interface.py,sha256=
+cache_dit/cache_factory/cache_interface.py,sha256=KseSPyZ9D3m6pmpE7k-uYr0wfBI-hhscG1Nw54GCHxk,12316
cache_dit/cache_factory/cache_types.py,sha256=ooukxQRG55uTLmaZ0SKw6gIeY6SQHhMxkbv55uj2Sqk,991
cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
cache_dit/cache_factory/params_modifier.py,sha256=zYJJsInTYCaYHBZ7mZJOP-PZnkSg3iN1WPewNOayXos,3628
cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
-cache_dit/cache_factory/block_adapters/__init__.py,sha256=
+cache_dit/cache_factory/block_adapters/__init__.py,sha256=vM3aDMzPY79Tw4L0hlV2PdA3MFYomnf0eo0BGBo9P78,18087
cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=2TVK_KqiYXC7AKZ2s07fzdOzUoeUBc9P1SzQtLVzhf4,22249
cache_dit/cache_factory/block_adapters/block_registers.py,sha256=2L7QeM4ygnaKQpC9PoJod0QRYyxidUKU2AYpysDCUwE,2572
cache_dit/cache_factory/cache_adapters/__init__.py,sha256=py71WGD3JztQ1uk6qdLVbzYcQ1rvqFidNNaQYo7tqTo,79
-cache_dit/cache_factory/cache_adapters/cache_adapter.py,sha256=
+cache_dit/cache_factory/cache_adapters/cache_adapter.py,sha256=HTyZdspd34G6QiJ2qPNoLmGwcxmAnCwpAf91NTIQtl4,21442
cache_dit/cache_factory/cache_blocks/__init__.py,sha256=mivvm8YOfqT7YHs8y_MzGOGztPw8LxAqKGXuSRXxCv0,3032
cache_dit/cache_factory/cache_blocks/offload_utils.py,sha256=wusgcqaCrwEjvv7Guy-6VXhNOgPPUrBV2sSVuRmGuvo,3513
cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=ElMps6_7uI74tSF9GDR_dEI0bZEhdzcepM29xFWnYo8,428
-cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=
-cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=
+cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=rfq5-WEt-ErY28vcB4ur9E-uCb6BKP0S8v5lTw61ROk,10555
+cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=StNW2PyDiXEIxZd30byPUrZZ8jgSiuC_yrly2w7X2LQ,16176
cache_dit/cache_factory/cache_blocks/pattern_utils.py,sha256=dGOC1tMMOvcbvEgx44eTESKn_jsv-0RZ3tRHPa3wmQ4,1315
-cache_dit/cache_factory/cache_contexts/__init__.py,sha256=
+cache_dit/cache_factory/cache_contexts/__init__.py,sha256=N3SxFnluXk5q09nhSqKIJCVzEGWzySJWm-vic6dH79E,412
cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=3EhaMCz3VUQ_NF81VgYwWoSEGIvhScPxPYhjL1OcgxE,15240
-cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=
+cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=X99XnmiY-Us8D2pqJGPKxWcXAhQQpk3xdEWOOOYXIZ4,30465
cache_dit/cache_factory/cache_contexts/calibrators/__init__.py,sha256=mzYXO8tbytGpJJ9rpPu20kMoj1Iu_7Ym9tjfzV8rA98,5574
cache_dit/cache_factory/cache_contexts/calibrators/base.py,sha256=mn6ZBkChGpGwN5csrHTUGMoX6BBPvqHXSLbIExiW-EU,748
cache_dit/cache_factory/cache_contexts/calibrators/foca.py,sha256=nhHGs_hxwW1M942BQDMJb9-9IuHdnOxp774Jrna1bJI,891
cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py,sha256=aGxr9SpytYznTepDWGPAxWDnuVMSuNyn6uNXnLh2acQ,4001
-cache_dit/cache_factory/patch_functors/__init__.py,sha256=
+cache_dit/cache_factory/patch_functors/__init__.py,sha256=IJZrvSkeHbR_xW-6IzY7sqEhApBsOfPyorQGJutvWH0,652
cache_dit/cache_factory/patch_functors/functor_base.py,sha256=Ahk0fTfrHgNdEl-9JSkACvfyyv9G-Ei5OSz7XBIlX5o,357
cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=xD0Q96VArp1vYBLQ0pcjRIyFB1i_Y7muZ2q07Hz8Oqs,13430
cache_dit/cache_factory/patch_functors/functor_dit.py,sha256=SDjhzCWa6PoFNN4_upoQEf6DHvW1yJ7zuXMS2VvyJco,3904
cache_dit/cache_factory/patch_functors/functor_flux.py,sha256=UMkyuEYjO7UO_zmXi9Djd-nD-XMgCUgE-qkYA3plWSM,9559
cache_dit/cache_factory/patch_functors/functor_hidream.py,sha256=inf4T5UcIa06zVsoLWCNJbb1bEDmGeBGSyC7OL1zpuc,15309
cache_dit/cache_factory/patch_functors/functor_hunyuan_dit.py,sha256=iSo5dD5uKnjQQeysDUIkKt0wdnK5bzXTc_F_lfHG70w,6401
+cache_dit/cache_factory/patch_functors/functor_qwen_image_controlnet.py,sha256=D5i1Rrq1FQ49liupLcV2DW04moBqLnW9TICzfnMMzIU,10519
cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0,56
cache_dit/compile/utils.py,sha256=nN2OIrSdwRR5zGxJinKDqb07pXpvTNTF3g_OgLkeeBU,3858
cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -47,11 +48,11 @@ cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR
cache_dit/metrics/lpips.py,sha256=hrHrmdM-f2B4TKDs0xLqJO5JFaYcCjq2qNIR8oCrVkc,811
cache_dit/metrics/metrics.py,sha256=AZbQyoavE-djvyRUZ_EfCIrWSQbiWQFo7n2dhn7XptE,40466
cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70A,59
-cache_dit/quantize/quantize_ao.py,sha256=
+cache_dit/quantize/quantize_ao.py,sha256=Pr3u3Qr6qLvFkd8k-_rfcz4Mkjlg36U9BHG2t6Bl-6M,6301
cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
-cache_dit-1.0.
-cache_dit-1.0.
-cache_dit-1.0.
-cache_dit-1.0.
-cache_dit-1.0.
-cache_dit-1.0.
+cache_dit-1.0.2.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-1.0.2.dist-info/METADATA,sha256=E6MkP_T9cwJEbqWE1DIRVkQLI7wLWr5zryY2poWgkyw,26766
+cache_dit-1.0.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-1.0.2.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-1.0.2.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-1.0.2.dist-info/RECORD,,