cache-dit 0.2.33__py3-none-any.whl → 0.2.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/block_adapters/__init__.py +14 -20
- cache_dit/cache_factory/block_adapters/block_adapters.py +46 -2
- cache_dit/cache_factory/block_adapters/block_registers.py +3 -2
- cache_dit/cache_factory/cache_adapters.py +8 -8
- cache_dit/cache_factory/cache_contexts/cache_context.py +11 -11
- cache_dit/cache_factory/cache_contexts/cache_manager.py +5 -5
- cache_dit/cache_factory/cache_contexts/taylorseer.py +12 -6
- cache_dit/cache_factory/cache_interface.py +9 -9
- cache_dit/cache_factory/patch_functors/__init__.py +1 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +142 -52
- cache_dit/cache_factory/patch_functors/functor_dit.py +130 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.34.dist-info}/METADATA +169 -29
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.34.dist-info}/RECORD +18 -17
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.34.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.34.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.34.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.34.dist-info}/top_level.txt +0 -0
cache_dit/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.33'
-__version_tuple__ = version_tuple = (0, 2, 33)
+__version__ = version = '0.2.34'
+__version_tuple__ = version_tuple = (0, 2, 34)
 
 __commit_id__ = commit_id = None
cache_dit/cache_factory/block_adapters/__init__.py
CHANGED
@@ -153,7 +153,7 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("
+@BlockAdapterRegistry.register("LTX")
 def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import LTXVideoTransformer3DModel
 
@@ -248,7 +248,10 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
         pipe=pipe,
         transformer=pipe.transformer,
         blocks=pipe.transformer.blocks,
-
+        # NOTE: Use Pattern_3 instead of Pattern_2 because the
+        # encoder_hidden_states will never change in the blocks
+        # forward loop.
+        forward_pattern=ForwardPattern.Pattern_3,
         has_separate_cfg=True,
         **kwargs,
     )
@@ -285,6 +288,7 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
 @BlockAdapterRegistry.register("DiT")
 def dit_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import DiTTransformer2DModel
+    from cache_dit.cache_factory.patch_functors import DiTPatchFunctor
 
     assert isinstance(pipe.transformer, DiTTransformer2DModel)
     return BlockAdapter(
@@ -292,6 +296,7 @@ def dit_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        patch_functor=DiTPatchFunctor(),
         **kwargs,
     )
 
@@ -331,24 +336,13 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
 
 
 @BlockAdapterRegistry.register("Lumina")
-def lumina_adapter(pipe, **kwargs) -> BlockAdapter:
-    from diffusers import LuminaNextDiT2DModel
-
-    assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
-    return BlockAdapter(
-        pipe=pipe,
-        transformer=pipe.transformer,
-        blocks=pipe.transformer.layers,
-        forward_pattern=ForwardPattern.Pattern_3,
-        **kwargs,
-    )
-
-
-@BlockAdapterRegistry.register("Lumina2")
 def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import Lumina2Transformer2DModel
+    from diffusers import LuminaNextDiT2DModel
 
-    assert isinstance(
+    assert isinstance(
+        pipe.transformer, (Lumina2Transformer2DModel, LuminaNextDiT2DModel)
+    )
     return BlockAdapter(
         pipe=pipe,
         transformer=pipe.transformer,
@@ -386,12 +380,10 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("Sana"
+@BlockAdapterRegistry.register("Sana")
 def sana_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import SanaTransformer2DModel
 
-    # TODO: fix -> got multiple values for argument 'encoder_hidden_states'
-
     assert isinstance(pipe.transformer, SanaTransformer2DModel)
     return BlockAdapter(
         pipe=pipe,
@@ -469,6 +461,7 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
 @BlockAdapterRegistry.register("Chroma")
 def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import ChromaTransformer2DModel
+    from cache_dit.cache_factory.patch_functors import ChromaPatchFunctor
 
     assert isinstance(pipe.transformer, ChromaTransformer2DModel)
     return BlockAdapter(
@@ -482,6 +475,7 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
             ForwardPattern.Pattern_1,
             ForwardPattern.Pattern_3,
         ],
+        patch_functor=ChromaPatchFunctor(),
        has_separate_cfg=True,
        **kwargs,
    )
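The adapter hunks above only register or adjust per-model `BlockAdapter` entries; once a pipeline class is registered, caching is enabled with a single call. A minimal usage sketch (not part of this diff; the model id is an assumption):

```python
# Minimal sketch: enable cache-dit on a pipeline whose adapter is registered,
# e.g. the "LTX" adapter renamed above. The model id below is only illustrative.
import cache_dit
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("Lightricks/LTX-Video")  # assumed id
cache_dit.enable_cache(pipe)  # adapter is resolved via BlockAdapterRegistry
# Run pipe(...) as usual; cached transformer blocks are handled transparently.
```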
cache_dit/cache_factory/block_adapters/block_adapters.py
CHANGED
@@ -16,8 +16,52 @@ logger = init_logger(__name__)
 
 
 class ParamsModifier:
-    def __init__(
-        self
+    def __init__(
+        self,
+        # Cache context kwargs
+        Fn_compute_blocks: Optional[int] = None,
+        Bn_compute_blocks: Optional[int] = None,
+        max_warmup_steps: Optional[int] = None,
+        max_cached_steps: Optional[int] = None,
+        max_continuous_cached_steps: Optional[int] = None,
+        residual_diff_threshold: Optional[float] = None,
+        # Cache CFG or not
+        enable_separate_cfg: Optional[bool] = None,
+        cfg_compute_first: Optional[bool] = None,
+        cfg_diff_compute_separate: Optional[bool] = None,
+        # Hybird TaylorSeer
+        enable_taylorseer: Optional[bool] = None,
+        enable_encoder_taylorseer: Optional[bool] = None,
+        taylorseer_cache_type: Optional[str] = None,
+        taylorseer_order: Optional[int] = None,
+        **other_cache_context_kwargs,
+    ):
+        self._context_kwargs = other_cache_context_kwargs.copy()
+        self._maybe_update_param("Fn_compute_blocks", Fn_compute_blocks)
+        self._maybe_update_param("Bn_compute_blocks", Bn_compute_blocks)
+        self._maybe_update_param("max_warmup_steps", max_warmup_steps)
+        self._maybe_update_param("max_cached_steps", max_cached_steps)
+        self._maybe_update_param(
+            "max_continuous_cached_steps", max_continuous_cached_steps
+        )
+        self._maybe_update_param(
+            "residual_diff_threshold", residual_diff_threshold
+        )
+        self._maybe_update_param("enable_separate_cfg", enable_separate_cfg)
+        self._maybe_update_param("cfg_compute_first", cfg_compute_first)
+        self._maybe_update_param(
+            "cfg_diff_compute_separate", cfg_diff_compute_separate
+        )
+        self._maybe_update_param("enable_taylorseer", enable_taylorseer)
+        self._maybe_update_param(
+            "enable_encoder_taylorseer", enable_encoder_taylorseer
+        )
+        self._maybe_update_param("taylorseer_cache_type", taylorseer_cache_type)
+        self._maybe_update_param("taylorseer_order", taylorseer_order)
+
+    def _maybe_update_param(self, key: str, value: Any):
+        if value is not None:
+            self._context_kwargs[key] = value
 
 
 @dataclasses.dataclass
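The new `ParamsModifier.__init__` mirrors the cache-context keywords and only records values that are explicitly set. A hedged sketch of the resulting behavior (the import path follows the README example later in this diff):

```python
# Sketch: only non-None parameters are collected into the context kwargs;
# everything left unset falls back to the global cache settings.
from cache_dit import ParamsModifier  # import path as shown in the README

m = ParamsModifier(max_warmup_steps=4, max_cached_steps=8)
# Roughly equivalent to collecting:
#   {"max_warmup_steps": 4, "max_cached_steps": 8}
# residual_diff_threshold, taylorseer_order, etc. remain untouched.
```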
cache_dit/cache_factory/block_adapters/block_registers.py
CHANGED
@@ -10,13 +10,14 @@ logger = init_logger(__name__)
 
 class BlockAdapterRegistry:
     _adapters: Dict[str, Callable[..., BlockAdapter]] = {}
-
+    _predefined_adapters_has_separate_cfg: List[str] = [
         "QwenImage",
         "Wan",
         "CogView4",
         "Cosmos",
         "SkyReelsV2",
         "Chroma",
+        "Lumina2",
     ]
 
     @classmethod
@@ -68,7 +69,7 @@ class BlockAdapterRegistry:
             return True
 
         pipe_cls_name = pipe_or_adapter.__class__.__name__
-        for name in cls.
+        for name in cls._predefined_adapters_has_separate_cfg:
             if pipe_cls_name.startswith(name):
                 return True
 
cache_dit/cache_factory/cache_adapters.py
CHANGED
@@ -114,27 +114,27 @@ class CachedAdapter:
         **cache_context_kwargs,
     ):
         # Check cache_context_kwargs
-        if cache_context_kwargs["
+        if cache_context_kwargs["enable_separate_cfg"] is None:
             # Check cfg for some specific case if users don't set it as True
             if BlockAdapterRegistry.has_separate_cfg(block_adapter):
-                cache_context_kwargs["
+                cache_context_kwargs["enable_separate_cfg"] = True
                 logger.info(
-                    f"Use custom '
+                    f"Use custom 'enable_separate_cfg' from BlockAdapter: True. "
                     f"Pipeline: {block_adapter.pipe.__class__.__name__}."
                 )
             else:
-                cache_context_kwargs["
+                cache_context_kwargs["enable_separate_cfg"] = (
                     BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
                 )
                 logger.info(
-                    f"Use default '
-                    f"register: {cache_context_kwargs['
+                    f"Use default 'enable_separate_cfg' from block adapter "
+                    f"register: {cache_context_kwargs['enable_separate_cfg']}, "
                     f"Pipeline: {block_adapter.pipe.__class__.__name__}."
                 )
         else:
             logger.info(
-                f"Use custom '
-                f"kwargs: {cache_context_kwargs['
+                f"Use custom 'enable_separate_cfg' from cache context "
+                f"kwargs: {cache_context_kwargs['enable_separate_cfg']}. "
                 f"Pipeline: {block_adapter.pipe.__class__.__name__}."
             )
 
cache_dit/cache_factory/cache_contexts/cache_context.py
CHANGED
@@ -53,20 +53,20 @@ class CachedContext: # Internal CachedContext Impl class
     enable_taylorseer: bool = False
     enable_encoder_taylorseer: bool = False
     taylorseer_cache_type: str = "hidden_states"  # residual or hidden_states
-    taylorseer_order: int =
+    taylorseer_order: int = 1  # The order for TaylorSeer
     taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
     taylorseer: Optional[TaylorSeer] = None
     encoder_tarlorseer: Optional[TaylorSeer] = None
 
-    # Support
+    # Support enable_separate_cfg, such as Wan 2.1,
     # Qwen-Image. For model that fused CFG and non-CFG into single
-    # forward step, should set
+    # forward step, should set enable_separate_cfg as False.
     # For example: CogVideoX, HunyuanVideo, Mochi.
-
+    enable_separate_cfg: bool = False
     # Compute cfg forward first or not, default False, namely,
     # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
     cfg_compute_first: bool = False
-    # Compute
+    # Compute separate diff values for CFG and non-CFG step,
     # default True. If False, we will use the computed diff from
     # current non-CFG transformer step for current CFG step.
     cfg_diff_compute_separate: bool = True
@@ -89,7 +89,7 @@ class CachedContext: # Internal CachedContext Impl class
         if logger.isEnabledFor(logging.DEBUG):
             logger.info(f"Created _CacheContext: {self.name}")
         # Some checks for settings
-        if self.
+        if self.enable_separate_cfg:
             if self.cfg_diff_compute_separate:
                 assert self.cfg_compute_first is False, (
                     "cfg_compute_first must set as False if "
@@ -108,12 +108,12 @@ class CachedContext: # Internal CachedContext Impl class
 
         if self.enable_taylorseer:
             self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
-            if self.
+            if self.enable_separate_cfg:
                 self.cfg_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
 
         if self.enable_encoder_taylorseer:
             self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
-            if self.
+            if self.enable_separate_cfg:
                 self.cfg_encoder_taylorseer = TaylorSeer(
                     **self.taylorseer_kwargs
                 )
@@ -145,7 +145,7 @@ class CachedContext: # Internal CachedContext Impl class
         # incr step: prev 0 -> 1; prev 1 -> 2
         # current step: incr step - 1
         self.transformer_executed_steps += 1
-        if not self.
+        if not self.enable_separate_cfg:
             self.executed_steps += 1
         else:
             # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
@@ -183,7 +183,7 @@ class CachedContext: # Internal CachedContext Impl class
 
         # mark_step_begin of TaylorSeer must be called after the cache is reset.
         if self.enable_taylorseer or self.enable_encoder_taylorseer:
-            if self.
+            if self.enable_separate_cfg:
                 # Assume non-CFG steps: 0, 2, 4, 6, ...
                 if not self.is_separate_cfg_step():
                     taylorseer, encoder_taylorseer = self.get_taylorseers()
@@ -269,7 +269,7 @@ class CachedContext: # Internal CachedContext Impl class
         return self.transformer_executed_steps - 1
 
     def is_separate_cfg_step(self):
-        if not self.
+        if not self.enable_separate_cfg:
            return False
        if self.cfg_compute_first:
            # CFG steps: 0, 2, 4, 6, ...
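The renamed `enable_separate_cfg` flag drives the step bookkeeping above: with separate CFG, the transformer runs twice per denoising step, and the parity of the call index decides which call is the CFG pass. A standalone sketch of that rule (not the library's actual code):

```python
# Standalone sketch of the parity rule described in the comments above.
def is_separate_cfg_step(transformer_executed_steps: int,
                         enable_separate_cfg: bool,
                         cfg_compute_first: bool) -> bool:
    if not enable_separate_cfg:
        return False
    current = transformer_executed_steps - 1  # 0-based transformer call index
    if cfg_compute_first:
        return current % 2 == 0  # CFG calls: 0, 2, 4, ...
    return current % 2 == 1      # CFG calls: 1, 3, 5, ...
```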
cache_dit/cache_factory/cache_contexts/cache_manager.py
CHANGED
@@ -74,8 +74,8 @@ class CachedContextManager:
             del self._cached_context_manager[cached_context]
 
     def clear_contexts(self):
-        for
-            self.remove_context(
+        for context_name in list(self._cached_context_manager.keys()):
+            self.remove_context(context_name)
 
     @contextlib.contextmanager
     def enter_context(self, cached_context: CachedContext | str):
@@ -364,10 +364,10 @@ class CachedContextManager:
         return cached_context.Bn_compute_blocks
 
     @torch.compiler.disable
-    def
+    def enable_separate_cfg(self) -> bool:
         cached_context = self.get_context()
         assert cached_context is not None, "cached_context must be set before"
-        return cached_context.
+        return cached_context.enable_separate_cfg
 
     @torch.compiler.disable
     def is_separate_cfg_step(self) -> bool:
@@ -410,7 +410,7 @@ class CachedContextManager:
 
         if all(
             (
-                self.
+                self.enable_separate_cfg(),
                 self.is_separate_cfg_step(),
                 not self.cfg_diff_compute_separate(),
                 self.get_current_step_residual_diff() is not None,
cache_dit/cache_factory/cache_contexts/taylorseer.py
CHANGED
@@ -1,4 +1,6 @@
 import math
+import torch
+from typing import List, Dict
 
 
 class TaylorSeer:
@@ -17,7 +19,7 @@ class TaylorSeer:
         self.reset_cache()
 
     def reset_cache(self):
-        self.state = {
+        self.state: Dict[str, List[torch.Tensor]] = {
             "dY_prev": [None] * self.ORDER,
             "dY_current": [None] * self.ORDER,
         }
@@ -36,15 +38,19 @@ class TaylorSeer:
             return True
         return False
 
-    def approximate_derivative(self, Y):
+    def approximate_derivative(self, Y: torch.Tensor) -> List[torch.Tensor]:
        # n-th order Taylor expansion:
        # Y(t) = Y(0) + dY(0)/dt * t + d^2Y(0)/dt^2 * t^2 / 2!
        # + ... + d^nY(0)/dt^n * t^n / n!
        # TODO: Custom Triton/CUDA kernel for better performance,
        # especially for large n_derivatives.
-        dY_current = [None] * self.ORDER
+        dY_current: List[torch.Tensor] = [None] * self.ORDER
         dY_current[0] = Y
         window = self.current_step - self.last_non_approximated_step
+        if self.state["dY_prev"][0] is not None:
+            if dY_current[0].shape != self.state["dY_prev"][0].shape:
+                self.reset_cache()
+
         for i in range(self.n_derivatives):
             if self.state["dY_prev"][i] is not None and self.current_step > 1:
                 dY_current[i + 1] = (
@@ -54,7 +60,7 @@ class TaylorSeer:
                 break
         return dY_current
 
-    def approximate_value(self):
+    def approximate_value(self) -> torch.Tensor:
         # TODO: Custom Triton/CUDA kernel for better performance,
         # especially for large n_derivatives.
         elapsed = self.current_step - self.last_non_approximated_step
@@ -69,7 +75,7 @@ class TaylorSeer:
     def mark_step_begin(self):
         self.current_step += 1
 
-    def update(self, Y):
+    def update(self, Y: torch.Tensor):
         # Directly call this method will ingnore the warmup
         # policy and force full computation.
         # Assume warmup steps is 3, and n_derivatives is 3.
@@ -87,7 +93,7 @@ class TaylorSeer:
         self.state["dY_current"] = self.approximate_derivative(Y)
         self.last_non_approximated_step = self.current_step
 
-    def step(self, Y):
+    def step(self, Y: torch.Tensor):
         self.mark_step_begin()
         if self.should_compute_full():
             self.update(Y)
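Besides the type annotations, the TaylorSeer hunk above now resets the cached derivatives whenever the incoming tensor shape differs from the previous one, so stale state from a different resolution is never extrapolated. An independent sketch of the underlying Taylor-expansion idea (not the library class itself):

```python
# Independent sketch: estimate a derivative by finite differences over the
# last fully computed features, then extrapolate the skipped steps.
import torch

def taylor_extrapolate(y: torch.Tensor, dy: torch.Tensor, elapsed: int) -> torch.Tensor:
    # First-order expansion: Y(t) ~= Y(0) + dY(0)/dt * t
    return y + dy * elapsed

y0 = torch.randn(2, 16)             # features at step t-1
y1 = y0 + 0.1 * torch.randn(2, 16)  # features at step t (fully computed)
dy = y1 - y0                        # finite-difference derivative estimate
approx = taylor_extrapolate(y1, dy, elapsed=2)  # cheap guess for step t+2
```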
cache_dit/cache_factory/cache_interface.py
CHANGED
@@ -24,14 +24,14 @@ def enable_cache(
     max_continuous_cached_steps: int = -1,
     residual_diff_threshold: float = 0.08,
     # Cache CFG or not
-
+    enable_separate_cfg: bool = None,
     cfg_compute_first: bool = False,
     cfg_diff_compute_separate: bool = True,
     # Hybird TaylorSeer
     enable_taylorseer: bool = False,
     enable_encoder_taylorseer: bool = False,
     taylorseer_cache_type: str = "residual",
-    taylorseer_order: int =
+    taylorseer_order: int = 1,
     **other_cache_context_kwargs,
 ) -> Union[
     DiffusionPipeline,
@@ -70,15 +70,15 @@ def enable_cache(
         residual_diff_threshold (`float`, *required*, defaults to 0.08):
             he value of residual diff threshold, a higher value leads to faster performance at the
             cost of lower precision.
-
+        enable_separate_cfg (`bool`, *required*, defaults to None):
             Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
-            and non-CFG into single forward step, should set
+            and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
             CogVideoX, HunyuanVideo, Mochi, etc.
         cfg_compute_first (`bool`, *required*, defaults to False):
             Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
             1, 3, 5, ... -> CFG step.
         cfg_diff_compute_separate (`bool`, *required*, defaults to True):
-            Compute
+            Compute separate diff values for CFG and non-CFG step, default True. If False, we will
             use the computed diff from current non-CFG transformer step for current CFG step.
         enable_taylorseer (`bool`, *required*, defaults to False):
             Enable the hybird TaylorSeer for hidden_states or not. We have supported the
@@ -91,10 +91,10 @@ def enable_cache(
             Enable the hybird TaylorSeer for encoder_hidden_states or not.
         taylorseer_cache_type (`str`, *required*, defaults to `residual`):
             The TaylorSeer implemented in cache-dit supports both `hidden_states` and `residual` as cache type.
-        taylorseer_order (`int`, *required*, defaults to
+        taylorseer_order (`int`, *required*, defaults to 1):
             The order of taylorseer, higher values of n_derivatives will lead to longer computation time,
-
-
+            the recommended value is 1 or 2.
+        other_cache_context_kwargs: (`dict`, *optional*, defaults to {})
             Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
             for more details.
 
@@ -123,7 +123,7 @@ def enable_cache(
         max_continuous_cached_steps
     )
     cache_context_kwargs["residual_diff_threshold"] = residual_diff_threshold
-    cache_context_kwargs["
+    cache_context_kwargs["enable_separate_cfg"] = enable_separate_cfg
     cache_context_kwargs["cfg_compute_first"] = cfg_compute_first
     cache_context_kwargs["cfg_diff_compute_separate"] = (
         cfg_diff_compute_separate
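With the renamed keyword, a typical call to the public API looks roughly like this (a hedged sketch, not from the diff; the pipeline and model id are assumptions):

```python
# Sketch of enable_cache() with the parameters touched in this release.
import cache_dit
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")  # assumed model id
cache_dit.enable_cache(
    pipe,
    enable_separate_cfg=True,  # renamed keyword in 0.2.34
    enable_taylorseer=True,
    taylorseer_order=1,        # new default; 1 or 2 is the recommended range
)
```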
cache_dit/cache_factory/patch_functors/__init__.py
CHANGED
@@ -1,4 +1,5 @@
 from cache_dit.cache_factory.patch_functors.functor_base import PatchFunctor
+from cache_dit.cache_factory.patch_functors.functor_dit import DiTPatchFunctor
 from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
 from cache_dit.cache_factory.patch_functors.functor_chroma import (
     ChromaPatchFunctor,
cache_dit/cache_factory/patch_functors/functor_chroma.py
CHANGED
@@ -1,10 +1,9 @@
-import inspect
-
 import torch
 import numpy as np
 from typing import Tuple, Optional, Dict, Any, Union
 from diffusers import ChromaTransformer2DModel
 from diffusers.models.transformers.transformer_chroma import (
+    ChromaTransformerBlock,
     ChromaSingleTransformerBlock,
     Transformer2DModelOutput,
 )
@@ -27,24 +26,31 @@ class ChromaPatchFunctor(PatchFunctor):
     def apply(
         self,
         transformer: ChromaTransformer2DModel,
-        blocks: torch.nn.ModuleList = None,
         **kwargs,
     ) -> ChromaTransformer2DModel:
         if hasattr(transformer, "_is_patched"):
             return transformer
 
-        if blocks is None:
-            blocks = transformer.single_transformer_blocks
-
         is_patched = False
-        for block in
-
-
-
-
-
-
-
+        for index_block, block in enumerate(transformer.transformer_blocks):
+            assert isinstance(block, ChromaTransformerBlock)
+            img_offset = 3 * len(transformer.single_transformer_blocks)
+            txt_offset = img_offset + 6 * len(transformer.transformer_blocks)
+            img_modulation = img_offset + 6 * index_block
+            text_modulation = txt_offset + 6 * index_block
+            block._img_modulation = img_modulation
+            block._text_modulation = text_modulation
+            block.forward = __patch_double_forward__.__get__(block)
+
+        for index_block, block in enumerate(
+            transformer.single_transformer_blocks
+        ):
+            assert isinstance(block, ChromaSingleTransformerBlock)
+            start_idx = 3 * index_block
+            block._start_idx = start_idx
+            block.forward = __patch_single_forward__.__get__(block)
+
+        is_patched = True
 
         cls_name = transformer.__class__.__name__
 
@@ -69,25 +75,123 @@ class ChromaPatchFunctor(PatchFunctor):
         return transformer
 
 
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
+def __patch_double_forward__(
+    self: ChromaTransformerBlock,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    pooled_temb: torch.Tensor,
+    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # TODO: Fuse controlnet into block forward
+    img_modulation = self._img_modulation
+    text_modulation = self._text_modulation
+    temb = torch.cat(
+        (
+            pooled_temb[:, img_modulation : img_modulation + 6],
+            pooled_temb[:, text_modulation : text_modulation + 6],
+        ),
+        dim=1,
+    )
+
+    temb_img, temb_txt = temb[:, :6], temb[:, 6:]
+    norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+        hidden_states, emb=temb_img
+    )
+
+    (
+        norm_encoder_hidden_states,
+        c_gate_msa,
+        c_shift_mlp,
+        c_scale_mlp,
+        c_gate_mlp,
+    ) = self.norm1_context(encoder_hidden_states, emb=temb_txt)
+    joint_attention_kwargs = joint_attention_kwargs or {}
+    if attention_mask is not None:
+        attention_mask = (
+            attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+        )
+
+    # Attention.
+    attention_outputs = self.attn(
+        hidden_states=norm_hidden_states,
+        encoder_hidden_states=norm_encoder_hidden_states,
+        image_rotary_emb=image_rotary_emb,
+        attention_mask=attention_mask,
+        **joint_attention_kwargs,
+    )
+
+    if len(attention_outputs) == 2:
+        attn_output, context_attn_output = attention_outputs
+    elif len(attention_outputs) == 3:
+        attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+    # Process attention outputs for the `hidden_states`.
+    attn_output = gate_msa.unsqueeze(1) * attn_output
+    hidden_states = hidden_states + attn_output
+
+    norm_hidden_states = self.norm2(hidden_states)
+    norm_hidden_states = (
+        norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+    )
+
+    ff_output = self.ff(norm_hidden_states)
+    ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+    hidden_states = hidden_states + ff_output
+    if len(attention_outputs) == 3:
+        hidden_states = hidden_states + ip_attn_output
+
+    # Process attention outputs for the `encoder_hidden_states`.
+
+    context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+    encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+    norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+    norm_encoder_hidden_states = (
+        norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
+        + c_shift_mlp[:, None]
+    )
+
+    context_ff_output = self.ff_context(norm_encoder_hidden_states)
+    encoder_hidden_states = (
+        encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+    )
+    if encoder_hidden_states.dtype == torch.float16:
+        encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+    return encoder_hidden_states, hidden_states
+
+
 # adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
 def __patch_single_forward__(
     self: ChromaSingleTransformerBlock,  # Almost same as FluxSingleTransformerBlock
     hidden_states: torch.Tensor,
-
-    temb: torch.Tensor,
+    pooled_temb: torch.Tensor,
     image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    attention_mask: Optional[torch.Tensor] = None,
     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-) ->
-
-
+) -> torch.Tensor:
+    # TODO: Fuse controlnet into block forward
+    start_idx = self._start_idx
+    temb = pooled_temb[:, start_idx : start_idx + 3]
 
     residual = hidden_states
     norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
     mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
     joint_attention_kwargs = joint_attention_kwargs or {}
+
+    if attention_mask is not None:
+        attention_mask = (
+            attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+        )
+
     attn_output = self.attn(
         hidden_states=norm_hidden_states,
         image_rotary_emb=image_rotary_emb,
+        attention_mask=attention_mask,
         **joint_attention_kwargs,
     )
 
@@ -98,11 +202,7 @@ def __patch_single_forward__(
     if hidden_states.dtype == torch.float16:
         hidden_states = hidden_states.clip(-65504, 65504)
 
-
-        hidden_states[:, :text_seq_len],
-        hidden_states[:, text_seq_len:],
-    )
-    return encoder_hidden_states, hidden_states
+    return hidden_states
 
 
 # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
@@ -174,24 +274,13 @@ def __patch_transformer_forward__(
         joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
 
     for index_block, block in enumerate(self.transformer_blocks):
-        img_offset = 3 * len(self.single_transformer_blocks)
-        txt_offset = img_offset + 6 * len(self.transformer_blocks)
-        img_modulation = img_offset + 6 * index_block
-        text_modulation = txt_offset + 6 * index_block
-        temb = torch.cat(
-            (
-                pooled_temb[:, img_modulation : img_modulation + 6],
-                pooled_temb[:, text_modulation : text_modulation + 6],
-            ),
-            dim=1,
-        )
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             encoder_hidden_states, hidden_states = (
                 self._gradient_checkpointing_func(
                     block,
                     hidden_states,
                     encoder_hidden_states,
-
+                    pooled_temb,
                     image_rotary_emb,
                     attention_mask,
                 )
@@ -201,12 +290,13 @@ def __patch_transformer_forward__(
             encoder_hidden_states, hidden_states = block(
                 hidden_states=hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
-
+                pooled_temb=pooled_temb,
                 image_rotary_emb=image_rotary_emb,
                 attention_mask=attention_mask,
                 joint_attention_kwargs=joint_attention_kwargs,
             )
 
+        # TODO: Fuse controlnet into block forward
         # controlnet residual
         if controlnet_block_samples is not None:
             interval_control = len(self.transformer_blocks) / len(
@@ -227,43 +317,43 @@ def __patch_transformer_forward__(
                 + controlnet_block_samples[index_block // interval_control]
             )
 
+    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
     for index_block, block in enumerate(self.single_transformer_blocks):
-        start_idx = 3 * index_block
-        temb = pooled_temb[:, start_idx : start_idx + 3]
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-
-
-
-
-
-
-                )
+            hidden_states = self._gradient_checkpointing_func(
+                block,
+                hidden_states,
+                pooled_temb,
+                image_rotary_emb,
+                attention_mask,
+                joint_attention_kwargs,
             )
 
         else:
-
+            hidden_states = block(
                 hidden_states=hidden_states,
-
-                temb=temb,
+                pooled_temb=pooled_temb,
                 image_rotary_emb=image_rotary_emb,
                 attention_mask=attention_mask,
                 joint_attention_kwargs=joint_attention_kwargs,
             )
 
+        # TODO: Fuse controlnet into block forward
         # controlnet residual
         if controlnet_single_block_samples is not None:
             interval_control = len(self.single_transformer_blocks) / len(
                 controlnet_single_block_samples
             )
             interval_control = int(np.ceil(interval_control))
-            hidden_states = (
-                hidden_states
+            hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+                hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                 + controlnet_single_block_samples[
                     index_block // interval_control
                 ]
             )
 
+    hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
     temb = pooled_temb[:, -2:]
     hidden_states = self.norm_out(hidden_states, temb)
     output = self.proj_out(hidden_states)
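All of the patch functors above swap in their rewritten forwards with `function.__get__(instance)`, which binds a plain function to a single object without touching the class. A generic sketch of that descriptor trick:

```python
# Generic sketch of per-instance method patching via the descriptor protocol.
class Block:
    def forward(self, x):
        return x

def patched_forward(self, x):
    return x * 2  # stand-in for a rewritten block forward

block = Block()
block.forward = patched_forward.__get__(block)  # bind to this instance only
assert block.forward(3) == 6
assert Block().forward(3) == 3  # other instances keep the original forward
```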
cache_dit/cache_factory/patch_functors/functor_dit.py
ADDED
@@ -0,0 +1,130 @@
+import torch
+import torch.nn.functional as F
+
+from typing import Optional, Dict, Any
+from diffusers.models.transformers.dit_transformer_2d import (
+    DiTTransformer2DModel,
+    Transformer2DModelOutput,
+)
+from cache_dit.cache_factory.patch_functors.functor_base import (
+    PatchFunctor,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class DiTPatchFunctor(PatchFunctor):
+
+    def apply(
+        self,
+        transformer: DiTTransformer2DModel,
+        **kwargs,
+    ) -> DiTTransformer2DModel:
+        if hasattr(transformer, "_is_patched"):
+            return transformer
+
+        is_patched = False
+
+        transformer._norm1_emb = transformer.transformer_blocks[0].norm1.emb
+
+        is_patched = True
+
+        cls_name = transformer.__class__.__name__
+
+        if is_patched:
+            logger.warning(f"Patched {cls_name} for cache-dit.")
+            assert not getattr(transformer, "_is_parallelized", False), (
+                "Please call `cache_dit.enable_cache` before Parallelize, "
+                "the __patch_transformer_forward__ will overwrite the "
+                "parallized forward and cause a downgrade of performance."
+            )
+            transformer.forward = __patch_transformer_forward__.__get__(
+                transformer
+            )
+
+        transformer._is_patched = is_patched  # True or False
+
+        logger.info(
+            f"Applied {self.__class__.__name__} for {cls_name}, "
+            f"Patch: {is_patched}."
+        )
+
+        return transformer
+
+
+def __patch_transformer_forward__(
+    self: DiTTransformer2DModel,
+    hidden_states: torch.Tensor,
+    timestep: Optional[torch.LongTensor] = None,
+    class_labels: Optional[torch.LongTensor] = None,
+    cross_attention_kwargs: Dict[str, Any] = None,
+    return_dict: bool = True,
+):
+    height, width = (
+        hidden_states.shape[-2] // self.patch_size,
+        hidden_states.shape[-1] // self.patch_size,
+    )
+    hidden_states = self.pos_embed(hidden_states)
+
+    # 2. Blocks
+    for block in self.transformer_blocks:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            hidden_states = self._gradient_checkpointing_func(
+                block,
+                hidden_states,
+                None,
+                None,
+                None,
+                timestep,
+                cross_attention_kwargs,
+                class_labels,
+            )
+        else:
+            hidden_states = block(
+                hidden_states,
+                attention_mask=None,
+                encoder_hidden_states=None,
+                encoder_attention_mask=None,
+                timestep=timestep,
+                cross_attention_kwargs=cross_attention_kwargs,
+                class_labels=class_labels,
+            )
+
+    # 3. Output
+    # conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype)
+    conditioning = self._norm1_emb(
+        timestep, class_labels, hidden_dtype=hidden_states.dtype
+    )
+    shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+    hidden_states = (
+        self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+    )
+    hidden_states = self.proj_out_2(hidden_states)
+
+    # unpatchify
+    height = width = int(hidden_states.shape[1] ** 0.5)
+    hidden_states = hidden_states.reshape(
+        shape=(
+            -1,
+            height,
+            width,
+            self.patch_size,
+            self.patch_size,
+            self.out_channels,
+        )
+    )
+    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+    output = hidden_states.reshape(
+        shape=(
+            -1,
+            self.out_channels,
+            height * self.patch_size,
+            width * self.patch_size,
+        )
+    )
+
+    if not return_dict:
+        return (output,)
+
+    return Transformer2DModelOutput(sample=output)
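The new `DiTPatchFunctor` is wired in by the `dit_adapter` registration earlier in this diff, so end users only call `enable_cache`. A hedged sketch (pipeline class, model id, and call arguments are assumptions, not taken from the diff):

```python
# Sketch: caching a DiT pipeline; the registered "DiT" adapter is expected to
# apply DiTPatchFunctor before the transformer blocks are wrapped.
import torch
import cache_dit
from diffusers import DiTPipeline

pipe = DiTPipeline.from_pretrained(
    "facebook/DiT-XL-2-256", torch_dtype=torch.float16  # assumed model id
).to("cuda")
cache_dit.enable_cache(pipe)
images = pipe(class_labels=[207], num_inference_steps=25).images
```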
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cache_dit
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.34
|
|
4
4
|
Summary: 🤗 A Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
|
|
5
5
|
Author: DefTruth, vipshop.com, etc.
|
|
6
6
|
Maintainer: DefTruth, vipshop.com, etc
|
|
@@ -60,40 +60,97 @@ Dynamic: requires-python
|
|
|
60
60
|
</p>
|
|
61
61
|
<p align="center">
|
|
62
62
|
🎉Now, <b>cache-dit</b> covers <b>most</b> mainstream Diffusers' <b>DiT</b> Pipelines🎉<br>
|
|
63
|
-
🔥<a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Qwen-Image-Lightning</a> | <a href="#supported"> Wan 2.1
|
|
64
|
-
🔥<a href="#supported">HunyuanImage-2.1</a> | <a href="#supported">HunyuanVideo</a> | <a href="#supported">HunyuanDiT</a> | <a href="#supported">HiDream</a> | <a href="#supported">
|
|
65
|
-
🔥<a href="#supported">CogView3Plus</a> | <a href="#supported">CogView4</a> | <a href="#supported">
|
|
66
|
-
🔥<a href="#supported">Cosmos</a> | <a href="#supported">SkyReelsV2</a> | <a href="#supported">VisualCloze</a> | <a href="#supported">
|
|
67
|
-
🔥<a href="#supported">Allegro</a> | <a href="#supported">
|
|
63
|
+
🔥<a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Qwen-Image-Lightning</a> | <a href="#supported"> Wan 2.1 </a> | <a href="#supported"> Wan 2.2 </a>🔥<br>
|
|
64
|
+
🔥<a href="#supported">HunyuanImage-2.1</a> | <a href="#supported">HunyuanVideo</a> | <a href="#supported">HunyuanDiT</a> | <a href="#supported">HiDream</a> | <a href="#supported">AuraFlow</a>🔥<br>
|
|
65
|
+
🔥<a href="#supported">CogView3Plus</a> | <a href="#supported">CogView4</a> | <a href="#supported">LTXVideo</a> | <a href="#supported">CogVideoX</a> | <a href="#supported">CogVideoX 1.5</a> | <a href="#supported">ConsisID</a>🔥<br>
|
|
66
|
+
🔥<a href="#supported">Cosmos</a> | <a href="#supported">SkyReelsV2</a> | <a href="#supported">VisualCloze</a> | <a href="#supported">OmniGen 1/2</a> | <a href="#supported">Lumina 1/2</a> | <a href="#supported">PixArt</a>🔥<br>
|
|
67
|
+
🔥<a href="#supported">Chroma</a> | <a href="#supported">Sana</a> | <a href="#supported">Allegro</a> | <a href="#supported">Mochi</a> | <a href="#supported">SD 3/3.5</a> | <a href="#supported">Amused</a> | <a href="#supported"> ... </a> | <a href="#supported">DiT-XL</a>🔥
|
|
68
68
|
</p>
|
|
69
69
|
</div>
|
|
70
70
|
<div align='center'>
|
|
71
71
|
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C0_Q0_NONE.gif width=124px>
|
|
72
72
|
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C1_Q0_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=124px>
|
|
73
|
-
<img src
|
|
74
|
-
<img src
|
|
73
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/hunyuan_video.C0_L0_Q0_NONE.gif width=126px>
|
|
74
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/hunyuan_video.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.12_S27.gif width=126px>
|
|
75
75
|
<p><b>🔥Wan2.2 MoE</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.0x↑🎉 | <b>HunyuanVideo</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.1x↑🎉</p>
|
|
76
76
|
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C0_Q0_NONE.png width=160px>
|
|
77
77
|
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C1_Q0_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S23.png width=160px>
|
|
78
|
-
<img src
|
|
79
|
-
<img src
|
|
78
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/flux.C0_Q0_NONE_T23.69s.png width=90px>
|
|
79
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/flux.C0_Q0_DBCACHE_F1B0_W4M0MC0_T1O2_R0.15_S16_T11.39s.png width=90px>
|
|
80
80
|
<p><b>🔥Qwen-Image</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.8x↑🎉 | <b>FLUX.1-dev</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.1x↑🎉</p>
|
|
81
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/flux-kontext-cat.C0_L0_Q0_NONE.png width=100px>
|
|
82
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/flux-kontext.C0_L0_Q0_NONE.png width=100px>
|
|
83
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/flux-kontext.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S10.png width=100px>
|
|
84
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/flux-kontext.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.12_S12.png width=100px>
|
|
85
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/flux-kontext.C0_L0_Q0_DBCACHE_F1B0_W2M0MC2_T0O2_R0.15_S15.png width=100px>
|
|
86
|
+
<p><b>🔥FLUX-Kontext-dev</b> | Baseline | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.3x↑🎉 | 1.7x↑🎉 | 2.0x↑ 🎉</p>
|
|
81
87
|
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image-lightning.4steps.C0_L1_Q0_NONE.png width=160px>
|
|
82
88
|
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image-lightning.4steps.C0_L1_Q0_DBCACHE_F16B16_W2M1MC1_T0O2_R0.9_S1.png width=160px>
|
|
83
|
-
<img src
|
|
84
|
-
<img src
|
|
85
|
-
<p><b>🔥Qwen
|
|
86
|
-
<img src
|
|
87
|
-
<img src
|
|
88
|
-
<img src
|
|
89
|
-
<img src
|
|
90
|
-
<
|
|
89
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/hunyuan-image-2.1.C0_L0_Q1_fp8_w8a16_wo_NONE.png width=90px>
|
|
90
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/hunyuan-image-2.1.C0_L0_Q1_fp8_w8a16_wo_DBCACHE_F8B0_W8M0MC2_T1O2_R0.12_S25.png width=90px>
|
|
91
|
+
<p><b>🔥Qwen...Lightning</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.14x↑🎉 | <b>HunyuanImage</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.7x↑🎉</p>
|
|
92
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/examples/data/bear.png width=125px>
|
|
93
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image-edit.C0_L0_Q0_NONE.png width=125px>
|
|
94
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image-edit.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S18.png width=125px>
|
|
95
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image-edit.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.12_S24.png width=125px>
|
|
96
|
+
<p><b>🔥Qwen-Image-Edit</b> | Input w/o Edit | Baseline | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.6x↑🎉 | 1.9x↑🎉 </p>
|
|
97
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/hidream.C0_L0_Q0_NONE.png width=100px>
|
|
98
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/hidream.C0_L0_Q0_DBCACHE_F1B0_W8M0MC0_T0O2_R0.08_S24.png width=100px>
|
|
99
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cogview4.C0_L0_Q0_NONE.png width=100px>
|
|
100
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cogview4.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S15.png width=100px>
|
|
101
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cogview4.C0_L0_Q0_DBCACHE_F1B0_W4M0MC4_T0O2_R0.2_S22.png width=100px>
|
|
91
102
|
<p><b>🔥HiDream-I1</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.9x↑🎉 | <b>CogView4</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.4x↑🎉 | 1.7x↑🎉</p>
|
|
92
|
-
<img src
|
|
93
|
-
<img src
|
|
94
|
-
<img src
|
|
95
|
-
<img src
|
|
96
|
-
<
|
|
103
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cogview3_plus.C0_L0_Q0_NONE.png width=100px>
|
|
104
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cogview3_plus.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S15.png width=100px>
|
|
105
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cogview3_plus.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.08_S25.png width=100px>
|
|
106
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/chroma1-hd.C0_L0_Q0_NONE.png width=100px>
|
|
107
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/chroma1-hd.C0_L0_Q0_DBCACHE_F1B0_W8M0MC0_T0O2_R0.08_S20.png width=100px>
|
|
108
|
+
<p><b>🔥CogView3</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.5x↑🎉 | 2.0x↑🎉| <b>Chroma1-HD</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.9x↑🎉</p>
|
|
109
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/mochi.C0_L0_Q0_NONE.gif width=125px>
|
|
110
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/mochi.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S34.gif width=125px>
|
|
111
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/skyreels_v2.C0_L0_Q0_NONE.gif width=125px>
|
|
112
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/skyreels_v2.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.12_S17.gif width=125px>
|
|
113
|
+
<p><b>🔥Mochi-1-preview</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.8x↑🎉 | <b>SkyReelsV2</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.6x↑🎉</p>
|
|
114
|
+
<img src=./examples/data/visualcloze/00555_00.jpg width=100px>
|
|
115
|
+
<img src=./examples/data/visualcloze/12265_00.jpg width=100px>
|
|
116
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/visualcloze-512.C0_L0_Q0_NONE.png width=100px>
|
|
117
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/visualcloze-512.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S15.png width=100px>
|
|
118
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/visualcloze-512.C0_L0_Q0_DBCACHE_F1B0_W8M0MC0_T0O2_R0.08_S18.png width=100px>
|
|
119
|
+
<p><b>🔥VisualCloze-512</b> | Model | Cloth | Baseline | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.4x↑🎉 | 1.7x↑🎉 </p>
|
|
120
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/ltx-video.C0_L0_Q0_NONE.gif width=144px>
|
|
121
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/ltx-video.C0_L0_Q0_DBCACHE_F1B0_W8M0MC0_T0O2_R0.15_S13.gif width=144px>
|
|
122
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/cogvideox1.5.C0_L0_Q0_NONE.gif width=105px>
|
|
123
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/cogvideox1.5.C0_L0_Q0_DBCACHE_F1B0_W8M0MC0_T0O2_R0.12_S22.gif width=105px>
|
|
124
|
+
<p><b>🔥LTX-Video-0.9.7</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.7x↑🎉 | <b>CogVideoX1.5</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.0x↑🎉</p>
|
|
125
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/omingen-v1.C0_L0_Q0_NONE.png width=100px>
|
|
126
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/omingen-v1.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S24.png width=100px>
|
|
127
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/omingen-v1.C0_L0_Q0_DBCACHE_F1B0_W8M0MC0_T1O2_R0.08_S38.png width=100px>
|
|
128
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/lumina2.C0_L0_Q0_NONE.png width=100px>
|
|
129
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/lumina2.C0_L0_Q0_DBCACHE_F1B0_W2M0MC2_T0O2_R0.12_S14.png width=100px>
|
|
130
|
+
<p><b>🔥OmniGen-v1</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.5x↑🎉 | 3.3x↑🎉 | <b>Lumina2</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.9x↑🎉</p>
|
|
131
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/allegro.C0_L0_Q0_NONE.gif width=117px>
|
|
132
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/allegro.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.26_S27.gif width=117px>
|
|
133
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/auraflow.C0_L0_Q0_NONE.png width=133px>
|
|
134
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/auraflow.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.08_S28.png width=133px>
|
|
135
|
+
<p><b>🔥Allegro</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.36x↑🎉 | <b>AuraFlow-v0.3</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.27x↑🎉 </p>
|
|
136
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/sana.C0_L0_Q0_NONE.png width=100px>
|
|
137
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/sana.C0_L0_Q0_DBCACHE_F8B0_W8M0MC2_T0O2_R0.25_S6.png width=100px>
|
|
138
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/sana.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.3_S8.png width=100px>
|
|
139
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/pixart-sigma.C0_L0_Q0_NONE.png width=100px>
|
|
140
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/pixart-sigma.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S28.png width=100px>
|
|
141
|
+
<p><b>🔥Sana</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.3x↑🎉 | 1.6x↑🎉| <b>PixArt-Sigma</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.3x↑🎉</p>
|
|
142
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/pixart-alpha.C0_L0_Q0_NONE.png width=100px>
|
|
143
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/pixart-alpha.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.05_S27.png width=100px>
|
|
144
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/pixart-alpha.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S32.png width=100px>
|
|
145
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/sd_3_5.C0_L0_Q0_NONE.png width=100px>
|
|
146
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/sd_3_5.C0_L0_Q0_DBCACHE_F1B0_W8M0MC3_T0O2_R0.12_S30.png width=100px>
|
|
147
|
+
<p><b>🔥PixArt-Alpha</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.6x↑🎉 | 1.8x↑🎉| <b>SD 3.5</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.5x↑🎉</p>
|
|
148
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/amused.C0_L0_Q0_NONE.png width=100px>
|
|
149
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/amused.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.34_S1.png width=100px>
|
|
150
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/amused.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.38_S2.png width=100px>
|
|
151
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/dit-xl.C0_L0_Q0_NONE.png width=100px>
|
|
152
|
+
<img src=https://github.com/vipshop/cache-dit/raw/main/assets/dit-xl.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.15_S11.png width=100px>
|
|
153
|
+
<p><b>🔥Asumed</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.1x↑🎉 | 1.2x↑🎉 | <b>DiT-XL-256</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.8x↑🎉
|
|
97
154
|
<br>♥️ Please consider to leave a <b>⭐️ Star</b> to support us ~ ♥️</p>
|
|
98
155
|
</div>
|
|
99
156
|
|
|
@@ -126,9 +183,10 @@ Dynamic: requires-python
|
|
|
126
183
|
- [🔥Supported Models](#supported)
|
|
127
184
|
- [🎉Unified Cache APIs](#unified)
|
|
128
185
|
- [📚Forward Pattern Matching](#unified)
|
|
129
|
-
- [
|
|
186
|
+
- [♥️Cache with One-line Code](#unified)
|
|
130
187
|
- [🔥Automatic Block Adapter](#unified)
|
|
131
188
|
- [📚Hybird Forward Pattern](#unified)
|
|
189
|
+
- [📚Implement Patch Functor](#unified)
|
|
132
190
|
- [🤖Cache Acceleration Stats](#unified)
|
|
133
191
|
- [⚡️Dual Block Cache](#dbcache)
|
|
134
192
|
- [🔥Hybrid TaylorSeer](#taylorseer)
|
|
@@ -161,9 +219,15 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
|
|
|
161
219
|
- [🚀Qwen-Image-Lightning](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
162
220
|
- [🚀Qwen-Image-Edit](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
163
221
|
- [🚀Qwen-Image](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
222
|
+
- [🚀SkyReelsV2](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
223
|
+
- [🚀LTXVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
224
|
+
- [🚀OmniGen](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
225
|
+
- [🚀Lumina2](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
164
226
|
- [🚀FLUX.1-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
165
227
|
- [🚀FLUX.1-Fill-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
166
228
|
- [🚀FLUX.1-Kontext-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
229
|
+
- [🚀Chroma1-HD](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
230
|
+
- [🚀VisualCloze](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
167
231
|
- [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
168
232
|
- [🚀CogVideoX1.5](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
169
233
|
- [🚀CogView3-Plus](https://github.com/vipshop/cache-dit/raw/main/examples)
|
|
@@ -175,9 +239,16 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
 - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀HunyuanDiT](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀HiDream-I1-Full](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀AuraFlow-v0.3](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀PixArt-Alpha](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀PixArt-Sigma](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀NVIDIA Sana](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀SD-3/3.5](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀ConsisID](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀Allegro](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀Amused](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀DiT-XL](https://github.com/vipshop/cache-dit/raw/main/examples)
+- ...

 </details>

@@ -265,6 +336,75 @@ cache_dit.enable_cache(
 )
 ```

+Sometimes you may face an even more complex case, such as **Wan 2.2 MoE**, which has more than one Transformer (namely `transformer` and `transformer_2`) in its structure. Fortunately, **cache-dit** handles this situation well too. Please refer to [📚Wan 2.2 MoE](./examples/pipeline/run_wan_2.2.py) as an example.
+
+```python
+from cache_dit import ForwardPattern, BlockAdapter, ParamsModifier
+
+cache_dit.enable_cache(
+    BlockAdapter(
+        pipe=pipe,
+        transformer=[
+            pipe.transformer,
+            pipe.transformer_2,
+        ],
+        blocks=[
+            pipe.transformer.blocks,
+            pipe.transformer_2.blocks,
+        ],
+        forward_pattern=[
+            ForwardPattern.Pattern_2,
+            ForwardPattern.Pattern_2,
+        ],
+        # Set up different cache params for each 'blocks'. You can
+        # pass any specific cache params to ParamsModifier; the old
+        # values will be overwritten by the new ones.
+        params_modifiers=[
+            ParamsModifier(
+                max_warmup_steps=4,
+                max_cached_steps=8,
+            ),
+            ParamsModifier(
+                max_warmup_steps=2,
+                max_cached_steps=20,
+            ),
+        ],
+        has_separate_cfg=True,
+    ),
+)
+```
+### 📚Implement Patch Functor
+
+For any pattern not in {0...5}, we introduce the simple abstract concept of a **Patch Functor**. Users can implement a subclass of Patch Functor to convert an unknown pattern into a known PATTERN; for some models, users may also need to fuse the operations within the blocks for-loop into the block forward.
+
+
+
+Some Patch Functors are already provided in cache-dit: [📚HiDreamPatchFunctor](./src/cache_dit/cache_factory/patch_functors/functor_hidream.py), [📚ChromaPatchFunctor](./src/cache_dit/cache_factory/patch_functors/functor_chroma.py), etc. After implementing a Patch Functor, users need to set the `patch_functor` property of **BlockAdapter** (a minimal conceptual sketch of a patch functor follows the example below).
+
+```python
+@BlockAdapterRegistry.register("HiDream")
+def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
+    from diffusers import HiDreamImageTransformer2DModel
+    from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor
+
+    assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
+    return BlockAdapter(
+        pipe=pipe,
+        transformer=pipe.transformer,
+        blocks=[
+            pipe.transformer.double_stream_blocks,
+            pipe.transformer.single_stream_blocks,
+        ],
+        forward_pattern=[
+            ForwardPattern.Pattern_0,
+            ForwardPattern.Pattern_3,
+        ],
+        # NOTE: Set up your custom patch functor here.
+        patch_functor=HiDreamPatchFunctor(),
+        **kwargs,
+    )
+```
+
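As a purely conceptual illustration of what a patch functor does (not the actual cache-dit interface), the sketch below rewrites each block's `forward` so that every block returns the tuple `(hidden_states, encoder_hidden_states)`, i.e. a known forward pattern. The class name `MyPatchFunctor`, the `apply()` hook, and the `transformer_blocks` attribute are assumptions made for this sketch only; the real abstract base class lives in `cache_dit/cache_factory/patch_functors/functor_base.py` and should be consulted for the actual interface.

```python
import torch

# Hypothetical sketch, not the real cache-dit base class: normalizes every
# block so it returns (hidden_states, encoder_hidden_states), a known pattern.
class MyPatchFunctor:
    def apply(self, transformer: torch.nn.Module) -> torch.nn.Module:
        # Assumes the model keeps its blocks in `transformer_blocks` and that
        # each block accepts (hidden_states, encoder_hidden_states, ...).
        for block in transformer.transformer_blocks:
            original_forward = block.forward

            def patched_forward(
                hidden_states,
                encoder_hidden_states=None,
                *args,
                _orig=original_forward,  # bind the original forward per block
                **kwargs,
            ):
                out = _orig(hidden_states, encoder_hidden_states, *args, **kwargs)
                # Normalize the return value to the known 2-tuple pattern.
                if isinstance(out, tuple):
                    return out
                return out, encoder_hidden_states

            block.forward = patched_forward
        return transformer
```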
 ### 🤖Cache Acceleration Stats Summary

 After finishing each inference of `pipe(...)`, you can call the `cache_dit.summary()` API on pipe to get the details of the **Cache Acceleration Stats** for the current inference.
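For instance, a minimal end-to-end sketch of that flow looks like the following; the checkpoint id and prompt are illustrative placeholders, and the default cache settings are assumed.

```python
import torch
import cache_dit
from diffusers import DiffusionPipeline

# Illustrative checkpoint; any pipeline with a registered BlockAdapter works.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# One-line cache: the registered adapter resolves the forward pattern.
cache_dit.enable_cache(pipe)

# Run inference as usual, then query the acceleration stats for this run.
image = pipe("a cute cat playing the piano").images[0]
stats = cache_dit.summary(pipe)
```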
@@ -348,7 +488,7 @@ cache_dit.enable_cache(
     # TaylorSeer cache type can be hidden_states or residual.
     taylorseer_cache_type="residual",
     # Higher values of order will lead to longer computation time.
-    taylorseer_order=
+    taylorseer_order=1,  # default is 1.
     max_warmup_steps=3,  # prefer: >= order + 1
     residual_diff_threshold=0.12
 )
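Putting the options from this hunk together, a complete call might look like the sketch below. It assumes `pipe` is an already-loaded, supported pipeline (see the earlier examples); the `enable_taylorseer` / `enable_encoder_taylorseer` switch names are assumptions and should be checked against the Hybrid TaylorSeer section of this README.

```python
import cache_dit

# Sketch only: `pipe` is assumed to be a loaded, supported pipeline, and the
# enable_taylorseer / enable_encoder_taylorseer flag names are assumptions.
cache_dit.enable_cache(
    pipe,
    enable_taylorseer=True,
    enable_encoder_taylorseer=True,
    taylorseer_cache_type="residual",  # hidden_states or residual
    taylorseer_order=1,                # higher order -> longer compute time
    max_warmup_steps=3,                # prefer: >= taylorseer_order + 1
    residual_diff_threshold=0.12,
)
```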
@@ -372,7 +512,7 @@ cache_dit.enable_cache(

 <div id="cfg"></div>

-cache-dit supports caching for **CFG (classifier-free guidance)**. For models that fuse CFG and non-CFG into a single forward step, or models that do not include CFG (classifier-free guidance) in the forward step, please set `
+cache-dit supports caching for **CFG (classifier-free guidance)**. For models that fuse CFG and non-CFG into a single forward step, or models that do not include CFG in the forward step, please set the `enable_separate_cfg` param to **False (default, None)**. Otherwise, set it to True. For example:

 ```python
 cache_dit.enable_cache(
@@ -380,14 +520,14 @@ cache_dit.enable_cache(
     ...,
     # CFG: classifier-free guidance or not.
     # For models that fuse CFG and non-CFG into a single forward step,
-    # should set
+    # you should set enable_separate_cfg to False. For example, set it to True
     # for Wan 2.1/Qwen-Image and to False for FLUX.1, HunyuanVideo,
     # CogVideoX, Mochi, LTXVideo, Allegro, CogView3Plus, EasyAnimate, SD3, etc.
-
+    enable_separate_cfg=True,  # Wan 2.1, Qwen-Image, CogView4, Cosmos, SkyReelsV2, etc.
     # Compute cfg forward first or not, default False, namely,
     # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
     cfg_compute_first=False,
-    # Compute
+    # Compute separate diff values for CFG and non-CFG step,
     # default True. If False, we will use the computed diff from
     # current non-CFG transformer step for current CFG step.
     cfg_diff_compute_separate=True,
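Conversely, for a model that fuses CFG and non-CFG into a single forward step (FLUX.1, HunyuanVideo, CogVideoX, etc., per the comments above), a minimal sketch simply leaves separate-CFG handling disabled; the checkpoint id below is illustrative.

```python
import torch
import cache_dit
from diffusers import FluxPipeline

# FLUX.1 fuses CFG and non-CFG into a single forward step, so separate-CFG
# caching stays off (which is also the default).
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

cache_dit.enable_cache(
    pipe,
    enable_separate_cfg=False,  # default (False/None); shown explicitly here
)
```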
@@ -1,29 +1,30 @@
 cache_dit/__init__.py,sha256=kX9V-FegZG4c8LMwI4PTmMqH794MEW0pzDArdhC0cJw,1241
-cache_dit/_version.py,sha256=
+cache_dit/_version.py,sha256=CtkelOzOJFXtgJ0APT8pLd5zWrG63eLavWaOD_cX7xo,706
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
 cache_dit/utils.py,sha256=WK7eqgH6gCYNHXNLmWyxBDU0XSHTPg7CfOcyXlGXBqE,10510
 cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
 cache_dit/cache_factory/__init__.py,sha256=Iw6-iJLFbdzCsIDZXXOw371L-HPmoeZO_P9a3sDjP5s,1103
-cache_dit/cache_factory/cache_adapters.py,sha256=
-cache_dit/cache_factory/cache_interface.py,sha256=
+cache_dit/cache_factory/cache_adapters.py,sha256=OFJlxxyODhoZstN4EfPgC7tE8M1ZdQFcE25gDNrW7NA,18212
+cache_dit/cache_factory/cache_interface.py,sha256=tHQv7i8Hp6nfbjZWHwDx3nEvCfxLeBw26aMYjyu6nMw,8541
 cache_dit/cache_factory/cache_types.py,sha256=ooukxQRG55uTLmaZ0SKw6gIeY6SQHhMxkbv55uj2Sqk,991
 cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
 cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
-cache_dit/cache_factory/block_adapters/__init__.py,sha256=
-cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=
-cache_dit/cache_factory/block_adapters/block_registers.py,sha256=
+cache_dit/cache_factory/block_adapters/__init__.py,sha256=33geXMz56TxFWMp0c-H4__MY5SGRzKMKj3TXnUYOMlc,17512
+cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=zZbbsZYWbUClfa6He69w_Wdf8ZLhKwMAb9gURYEUmgQ,23725
+cache_dit/cache_factory/block_adapters/block_registers.py,sha256=2L7QeM4ygnaKQpC9PoJod0QRYyxidUKU2AYpysDCUwE,2572
 cache_dit/cache_factory/cache_blocks/__init__.py,sha256=08Ox7kD05lkRKCOsVTdEZeKAWBheqpxfrAT1Nz7eclI,2916
 cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=ElMps6_7uI74tSF9GDR_dEI0bZEhdzcepM29xFWnYo8,428
 cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=Bv56qETXhsREvCrNvnZpSqDIIHsi6Ze3FJW4Yk2x3uI,8597
 cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=d4H9kEB0AgnVMT8aF0Y54SUMUQUxw5HQ8gRkoCuTQ_A,14577
 cache_dit/cache_factory/cache_blocks/utils.py,sha256=dGOC1tMMOvcbvEgx44eTESKn_jsv-0RZ3tRHPa3wmQ4,1315
 cache_dit/cache_factory/cache_contexts/__init__.py,sha256=rqnJ5__zqnpVHK5A1OqWILpNh5Ss-0ZDTGgtxZMKGGo,250
-cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=
-cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=
-cache_dit/cache_factory/cache_contexts/taylorseer.py,sha256=
-cache_dit/cache_factory/patch_functors/__init__.py,sha256=
+cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=FWdgInClWY8VZBsZIevtYk--rX-RL8c3QfNOJtqR8a4,11855
+cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=Ig5VKoQ46iG3lKmsaMulYxd2vCm__2rY8NBvERwexwM,32719
+cache_dit/cache_factory/cache_contexts/taylorseer.py,sha256=4nxgSEZvDy-w-7XuJYzsyzdtF1_uFrDwlF06XBDFVKQ,3922
+cache_dit/cache_factory/patch_functors/__init__.py,sha256=oI6F3N9ezahRHaFUOZ1GfrAw1qFdKrxFXXmlwwehHj4,530
 cache_dit/cache_factory/patch_functors/functor_base.py,sha256=Ahk0fTfrHgNdEl-9JSkACvfyyv9G-Ei5OSz7XBIlX5o,357
-cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=
+cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=xD0Q96VArp1vYBLQ0pcjRIyFB1i_Y7muZ2q07Hz8Oqs,13430
+cache_dit/cache_factory/patch_functors/functor_dit.py,sha256=SDjhzCWa6PoFNN4_upoQEf6DHvW1yJ7zuXMS2VvyJco,3904
 cache_dit/cache_factory/patch_functors/functor_flux.py,sha256=UMkyuEYjO7UO_zmXi9Djd-nD-XMgCUgE-qkYA3plWSM,9559
 cache_dit/cache_factory/patch_functors/functor_hidream.py,sha256=pi_vvpDy1lsgQHxu3eK3v93rdJL7oNwkt3WakRP8pbw,15375
 cache_dit/cache_factory/patch_functors/functor_hunyuan_dit.py,sha256=iSo5dD5uKnjQQeysDUIkKt0wdnK5bzXTc_F_lfHG70w,6401
@@ -40,9 +41,9 @@ cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,
 cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70A,59
 cache_dit/quantize/quantize_ao.py,sha256=Fx1KW4l3gdEkdrcAYtPoDW7WKBJWrs3glOHiEwW_TgE,6160
 cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
+cache_dit-0.2.34.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.34.dist-info/METADATA,sha256=BvEY08xjrGPcqTEZSHvSDtJP4sGZv1T6jzhGj-jQbvo,38284
+cache_dit-0.2.34.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.34.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-0.2.34.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.34.dist-info/RECORD,,
File without changes
File without changes
File without changes
File without changes