cache-dit 0.2.33__py3-none-any.whl → 0.2.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +5 -3
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/block_adapters/__init__.py +14 -20
- cache_dit/cache_factory/block_adapters/block_adapters.py +46 -2
- cache_dit/cache_factory/block_adapters/block_registers.py +3 -2
- cache_dit/cache_factory/cache_adapters.py +8 -8
- cache_dit/cache_factory/cache_contexts/cache_context.py +11 -11
- cache_dit/cache_factory/cache_contexts/cache_manager.py +5 -5
- cache_dit/cache_factory/cache_contexts/taylorseer.py +12 -6
- cache_dit/cache_factory/cache_interface.py +9 -9
- cache_dit/cache_factory/patch_functors/__init__.py +1 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +142 -52
- cache_dit/cache_factory/patch_functors/functor_dit.py +130 -0
- cache_dit/metrics/clip_score.py +135 -0
- cache_dit/metrics/fid.py +42 -0
- cache_dit/metrics/image_reward.py +177 -0
- cache_dit/metrics/lpips.py +2 -14
- cache_dit/metrics/metrics.py +420 -76
- cache_dit/utils.py +15 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/METADATA +261 -52
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/RECORD +25 -22
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
@@ -4,6 +4,10 @@ except ImportError:
     __version__ = "unknown version"
     version_tuple = (0, 0, "unknown version")
 
+from cache_dit.utils import summary
+from cache_dit.utils import strify
+from cache_dit.utils import disable_print
+from cache_dit.logger import init_logger
 from cache_dit.cache_factory import load_options
 from cache_dit.cache_factory import enable_cache
 from cache_dit.cache_factory import disable_cache
@@ -18,9 +22,7 @@ from cache_dit.cache_factory import supported_pipelines
 from cache_dit.cache_factory import get_adapter
 from cache_dit.compile import set_compile_configs
 from cache_dit.quantize import quantize
-
-from cache_dit.utils import strify
-from cache_dit.logger import init_logger
+
 
 NONE = CacheType.NONE
 DBCache = CacheType.DBCache
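Taken together, this hunk promotes `summary`, `strify`, and `disable_print` to top-level exports next to `enable_cache`. A minimal usage sketch of the 0.2.36 surface; the model id is illustrative and the exact `summary`/`strify` signatures are assumptions, not shown in this diff:

```python
import cache_dit
from diffusers import DiffusionPipeline

# Any diffusers pipeline supported by cache-dit; the model id is an example.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

# One-line cache enablement (see cache_interface.py below for the kwargs).
cache_dit.enable_cache(pipe)

# ... run pipe(...) as usual ...

# New top-level re-exports in 0.2.36:
stats = cache_dit.summary(pipe)  # assumed: per-transformer cache statistics
print(cache_dit.strify(stats))   # assumed: compact string tag for logs
```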
cache_dit/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.33'
-__version_tuple__ = version_tuple = (0, 2, 33)
+__version__ = version = '0.2.36'
+__version_tuple__ = version_tuple = (0, 2, 36)
 
 __commit_id__ = commit_id = None
cache_dit/cache_factory/block_adapters/__init__.py
CHANGED
@@ -153,7 +153,7 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("
+@BlockAdapterRegistry.register("LTX")
 def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import LTXVideoTransformer3DModel
 
@@ -248,7 +248,10 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
         pipe=pipe,
         transformer=pipe.transformer,
         blocks=pipe.transformer.blocks,
-        forward_pattern=ForwardPattern.Pattern_2,
+        # NOTE: Use Pattern_3 instead of Pattern_2 because the
+        # encoder_hidden_states will never change in the blocks
+        # forward loop.
+        forward_pattern=ForwardPattern.Pattern_3,
         has_separate_cfg=True,
         **kwargs,
     )
@@ -285,6 +288,7 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
 @BlockAdapterRegistry.register("DiT")
 def dit_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import DiTTransformer2DModel
+    from cache_dit.cache_factory.patch_functors import DiTPatchFunctor
 
     assert isinstance(pipe.transformer, DiTTransformer2DModel)
     return BlockAdapter(
@@ -292,6 +296,7 @@ def dit_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        patch_functor=DiTPatchFunctor(),
         **kwargs,
     )
 
@@ -331,24 +336,13 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
 
 
 @BlockAdapterRegistry.register("Lumina")
-def lumina_adapter(pipe, **kwargs) -> BlockAdapter:
-    from diffusers import LuminaNextDiT2DModel
-
-    assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
-    return BlockAdapter(
-        pipe=pipe,
-        transformer=pipe.transformer,
-        blocks=pipe.transformer.layers,
-        forward_pattern=ForwardPattern.Pattern_3,
-        **kwargs,
-    )
-
-
-@BlockAdapterRegistry.register("Lumina2")
 def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import Lumina2Transformer2DModel
+    from diffusers import LuminaNextDiT2DModel
 
-    assert isinstance(
+    assert isinstance(
+        pipe.transformer, (Lumina2Transformer2DModel, LuminaNextDiT2DModel)
+    )
     return BlockAdapter(
         pipe=pipe,
         transformer=pipe.transformer,
@@ -386,12 +380,10 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("Sana"
+@BlockAdapterRegistry.register("Sana")
 def sana_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import SanaTransformer2DModel
 
-    # TODO: fix -> got multiple values for argument 'encoder_hidden_states'
-
     assert isinstance(pipe.transformer, SanaTransformer2DModel)
     return BlockAdapter(
         pipe=pipe,
@@ -469,6 +461,7 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
 @BlockAdapterRegistry.register("Chroma")
 def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import ChromaTransformer2DModel
+    from cache_dit.cache_factory.patch_functors import ChromaPatchFunctor
 
     assert isinstance(pipe.transformer, ChromaTransformer2DModel)
     return BlockAdapter(
@@ -482,6 +475,7 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
             ForwardPattern.Pattern_1,
             ForwardPattern.Pattern_3,
         ],
+        patch_functor=ChromaPatchFunctor(),
        has_separate_cfg=True,
        **kwargs,
    )
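All of the hunks above follow one registration idiom, so a hypothetical out-of-tree entry makes the shape explicit. A sketch, assuming the import paths implied by this diff; `MyDiT` and the `transformer_blocks` attribute are placeholders, not real diffusers names:

```python
from cache_dit.cache_factory import BlockAdapter, ForwardPattern
from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry


@BlockAdapterRegistry.register("MyDiT")  # matched against the pipeline class-name prefix
def mydit_adapter(pipe, **kwargs) -> BlockAdapter:
    # Real adapters assert the concrete transformer type here, e.g.
    # isinstance(pipe.transformer, DiTTransformer2DModel).
    return BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=pipe.transformer.transformer_blocks,
        # Pattern_3 applies when encoder_hidden_states never changes inside
        # the blocks' forward loop (the same reasoning as the SkyReelsV2 fix).
        forward_pattern=ForwardPattern.Pattern_3,
        **kwargs,
    )
```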
cache_dit/cache_factory/block_adapters/block_adapters.py
CHANGED
@@ -16,8 +16,52 @@ logger = init_logger(__name__)
 
 
 class ParamsModifier:
-    def __init__(
-        self
+    def __init__(
+        self,
+        # Cache context kwargs
+        Fn_compute_blocks: Optional[int] = None,
+        Bn_compute_blocks: Optional[int] = None,
+        max_warmup_steps: Optional[int] = None,
+        max_cached_steps: Optional[int] = None,
+        max_continuous_cached_steps: Optional[int] = None,
+        residual_diff_threshold: Optional[float] = None,
+        # Cache CFG or not
+        enable_separate_cfg: Optional[bool] = None,
+        cfg_compute_first: Optional[bool] = None,
+        cfg_diff_compute_separate: Optional[bool] = None,
+        # Hybrid TaylorSeer
+        enable_taylorseer: Optional[bool] = None,
+        enable_encoder_taylorseer: Optional[bool] = None,
+        taylorseer_cache_type: Optional[str] = None,
+        taylorseer_order: Optional[int] = None,
+        **other_cache_context_kwargs,
+    ):
+        self._context_kwargs = other_cache_context_kwargs.copy()
+        self._maybe_update_param("Fn_compute_blocks", Fn_compute_blocks)
+        self._maybe_update_param("Bn_compute_blocks", Bn_compute_blocks)
+        self._maybe_update_param("max_warmup_steps", max_warmup_steps)
+        self._maybe_update_param("max_cached_steps", max_cached_steps)
+        self._maybe_update_param(
+            "max_continuous_cached_steps", max_continuous_cached_steps
+        )
+        self._maybe_update_param(
+            "residual_diff_threshold", residual_diff_threshold
+        )
+        self._maybe_update_param("enable_separate_cfg", enable_separate_cfg)
+        self._maybe_update_param("cfg_compute_first", cfg_compute_first)
+        self._maybe_update_param(
+            "cfg_diff_compute_separate", cfg_diff_compute_separate
+        )
+        self._maybe_update_param("enable_taylorseer", enable_taylorseer)
+        self._maybe_update_param(
+            "enable_encoder_taylorseer", enable_encoder_taylorseer
+        )
+        self._maybe_update_param("taylorseer_cache_type", taylorseer_cache_type)
+        self._maybe_update_param("taylorseer_order", taylorseer_order)
+
+    def _maybe_update_param(self, key: str, value: Any):
+        if value is not None:
+            self._context_kwargs[key] = value
 
 
 @dataclasses.dataclass
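The new `ParamsModifier` records only parameters that were explicitly set, so unset `None` defaults never clobber the cache context's own defaults. A quick illustration of that filtering; the import path is assumed from this diff:

```python
from cache_dit.cache_factory.block_adapters.block_adapters import ParamsModifier

modifier = ParamsModifier(
    residual_diff_threshold=0.12,
    enable_taylorseer=True,
)
# Only the two non-None arguments were collected; everything else is
# left for the cache context defaults to decide.
assert modifier._context_kwargs == {
    "residual_diff_threshold": 0.12,
    "enable_taylorseer": True,
}
```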
cache_dit/cache_factory/block_adapters/block_registers.py
CHANGED
@@ -10,13 +10,14 @@ logger = init_logger(__name__)
 
 class BlockAdapterRegistry:
     _adapters: Dict[str, Callable[..., BlockAdapter]] = {}
-
+    _predefined_adapters_has_separate_cfg: List[str] = [
         "QwenImage",
         "Wan",
         "CogView4",
         "Cosmos",
         "SkyReelsV2",
         "Chroma",
+        "Lumina2",
     ]
 
     @classmethod
@@ -68,7 +69,7 @@ class BlockAdapterRegistry:
             return True
 
         pipe_cls_name = pipe_or_adapter.__class__.__name__
-        for name in cls.
+        for name in cls._predefined_adapters_has_separate_cfg:
             if pipe_cls_name.startswith(name):
                 return True
 
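`has_separate_cfg` resolves by class-name prefix, which is why the registration strings above matter and why `Lumina2` had to be added to the predefined list. The rule in miniature; the pipeline class names are illustrative:

```python
_predefined = ["QwenImage", "Wan", "CogView4", "Cosmos", "SkyReelsV2", "Chroma", "Lumina2"]

def has_separate_cfg(pipe_cls_name: str) -> bool:
    # Same prefix rule as BlockAdapterRegistry.has_separate_cfg.
    return any(pipe_cls_name.startswith(name) for name in _predefined)

assert has_separate_cfg("Lumina2Pipeline")   # covered by the new entry
assert not has_separate_cfg("FluxPipeline")  # not in the predefined list
```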
cache_dit/cache_factory/cache_adapters.py
CHANGED
@@ -114,27 +114,27 @@ class CachedAdapter:
        **cache_context_kwargs,
    ):
        # Check cache_context_kwargs
-        if cache_context_kwargs["
+        if cache_context_kwargs["enable_separate_cfg"] is None:
            # Check cfg for some specific case if users don't set it as True
            if BlockAdapterRegistry.has_separate_cfg(block_adapter):
-                cache_context_kwargs["
+                cache_context_kwargs["enable_separate_cfg"] = True
                logger.info(
-                    f"Use custom '
+                    f"Use custom 'enable_separate_cfg' from BlockAdapter: True. "
                    f"Pipeline: {block_adapter.pipe.__class__.__name__}."
                )
            else:
-                cache_context_kwargs["
+                cache_context_kwargs["enable_separate_cfg"] = (
                    BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
                )
                logger.info(
-                    f"Use default '
-                    f"register: {cache_context_kwargs['
+                    f"Use default 'enable_separate_cfg' from block adapter "
+                    f"register: {cache_context_kwargs['enable_separate_cfg']}, "
                    f"Pipeline: {block_adapter.pipe.__class__.__name__}."
                )
        else:
            logger.info(
-                f"Use custom '
-                f"kwargs: {cache_context_kwargs['
+                f"Use custom 'enable_separate_cfg' from cache context "
+                f"kwargs: {cache_context_kwargs['enable_separate_cfg']}. "
                f"Pipeline: {block_adapter.pipe.__class__.__name__}."
            )
 
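The three logging branches encode a simple precedence: an explicit user value wins, otherwise the registry decides. Condensed into a sketch:

```python
def resolve_separate_cfg(user_value, registry_says: bool) -> bool:
    # A user kwarg (anything not None) beats the BlockAdapterRegistry default.
    return registry_says if user_value is None else user_value

assert resolve_separate_cfg(None, True) is True    # fall back to the registry
assert resolve_separate_cfg(False, True) is False  # explicit override wins
```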
cache_dit/cache_factory/cache_contexts/cache_context.py
CHANGED
@@ -53,20 +53,20 @@ class CachedContext:  # Internal CachedContext Impl class
     enable_taylorseer: bool = False
     enable_encoder_taylorseer: bool = False
     taylorseer_cache_type: str = "hidden_states"  # residual or hidden_states
-    taylorseer_order: int =
+    taylorseer_order: int = 1  # The order for TaylorSeer
     taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
     taylorseer: Optional[TaylorSeer] = None
     encoder_tarlorseer: Optional[TaylorSeer] = None
 
-    # Support
+    # Support enable_separate_cfg, such as Wan 2.1,
     # Qwen-Image. For model that fused CFG and non-CFG into single
-    # forward step, should set
+    # forward step, should set enable_separate_cfg as False.
     # For example: CogVideoX, HunyuanVideo, Mochi.
-
+    enable_separate_cfg: bool = False
     # Compute cfg forward first or not, default False, namely,
     # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
     cfg_compute_first: bool = False
-    # Compute
+    # Compute separate diff values for CFG and non-CFG step,
     # default True. If False, we will use the computed diff from
     # current non-CFG transformer step for current CFG step.
     cfg_diff_compute_separate: bool = True
@@ -89,7 +89,7 @@ class CachedContext:  # Internal CachedContext Impl class
         if logger.isEnabledFor(logging.DEBUG):
             logger.info(f"Created _CacheContext: {self.name}")
         # Some checks for settings
-        if self.
+        if self.enable_separate_cfg:
             if self.cfg_diff_compute_separate:
                 assert self.cfg_compute_first is False, (
                     "cfg_compute_first must set as False if "
@@ -108,12 +108,12 @@ class CachedContext:  # Internal CachedContext Impl class
 
         if self.enable_taylorseer:
             self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
-            if self.
+            if self.enable_separate_cfg:
                 self.cfg_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
 
         if self.enable_encoder_taylorseer:
             self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
-            if self.
+            if self.enable_separate_cfg:
                 self.cfg_encoder_taylorseer = TaylorSeer(
                     **self.taylorseer_kwargs
                 )
@@ -145,7 +145,7 @@ class CachedContext:  # Internal CachedContext Impl class
         # incr step: prev 0 -> 1; prev 1 -> 2
         # current step: incr step - 1
         self.transformer_executed_steps += 1
-        if not self.
+        if not self.enable_separate_cfg:
             self.executed_steps += 1
         else:
             # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
@@ -183,7 +183,7 @@ class CachedContext:  # Internal CachedContext Impl class
 
         # mark_step_begin of TaylorSeer must be called after the cache is reset.
         if self.enable_taylorseer or self.enable_encoder_taylorseer:
-            if self.
+            if self.enable_separate_cfg:
                 # Assume non-CFG steps: 0, 2, 4, 6, ...
                 if not self.is_separate_cfg_step():
                     taylorseer, encoder_taylorseer = self.get_taylorseers()
@@ -269,7 +269,7 @@ class CachedContext:  # Internal CachedContext Impl class
         return self.transformer_executed_steps - 1
 
     def is_separate_cfg_step(self):
-        if not self.
+        if not self.enable_separate_cfg:
             return False
         if self.cfg_compute_first:
             # CFG steps: 0, 2, 4, 6, ...
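The parity bookkeeping behind `is_separate_cfg_step()` is easy to check in isolation. A standalone sketch of the arithmetic the comments above describe, not the class itself:

```python
def is_separate_cfg_step(transformer_step: int, cfg_compute_first: bool) -> bool:
    # With separate CFG the transformer runs twice per denoising step;
    # the parity of the raw transformer step tells which half we are in.
    if cfg_compute_first:
        return transformer_step % 2 == 0  # CFG steps: 0, 2, 4, ...
    return transformer_step % 2 != 0      # CFG steps: 1, 3, 5, ...

# Denoising-step index from transformer steps: 0,1 -> 0; 2,3 -> 1; ...
assert [is_separate_cfg_step(s, cfg_compute_first=False) for s in range(4)] == [
    False, True, False, True
]
```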
cache_dit/cache_factory/cache_contexts/cache_manager.py
CHANGED
@@ -74,8 +74,8 @@ class CachedContextManager:
             del self._cached_context_manager[cached_context]
 
     def clear_contexts(self):
-        for
-            self.remove_context(
+        for context_name in list(self._cached_context_manager.keys()):
+            self.remove_context(context_name)
 
     @contextlib.contextmanager
     def enter_context(self, cached_context: CachedContext | str):
@@ -364,10 +364,10 @@ class CachedContextManager:
         return cached_context.Bn_compute_blocks
 
     @torch.compiler.disable
-    def
+    def enable_separate_cfg(self) -> bool:
         cached_context = self.get_context()
         assert cached_context is not None, "cached_context must be set before"
-        return cached_context.
+        return cached_context.enable_separate_cfg
 
     @torch.compiler.disable
     def is_separate_cfg_step(self) -> bool:
@@ -410,7 +410,7 @@ class CachedContextManager:
 
         if all(
             (
-                self.
+                self.enable_separate_cfg(),
                 self.is_separate_cfg_step(),
                 not self.cfg_diff_compute_separate(),
                 self.get_current_step_residual_diff() is not None,
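The `clear_contexts` fix is the classic do-not-mutate-a-dict-while-iterating pattern: snapshot the keys first. In miniature:

```python
contexts = {"ctx_a": object(), "ctx_b": object()}

# Iterating the dict directly while deleting would raise
# "RuntimeError: dictionary changed size during iteration".
for name in list(contexts.keys()):  # snapshot, as in the fixed code
    del contexts[name]

assert not contexts
```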
cache_dit/cache_factory/cache_contexts/taylorseer.py
CHANGED
@@ -1,4 +1,6 @@
 import math
+import torch
+from typing import List, Dict
 
 
 class TaylorSeer:
@@ -17,7 +19,7 @@ class TaylorSeer:
         self.reset_cache()
 
     def reset_cache(self):
-        self.state = {
+        self.state: Dict[str, List[torch.Tensor]] = {
             "dY_prev": [None] * self.ORDER,
             "dY_current": [None] * self.ORDER,
         }
@@ -36,15 +38,19 @@ class TaylorSeer:
             return True
         return False
 
-    def approximate_derivative(self, Y):
+    def approximate_derivative(self, Y: torch.Tensor) -> List[torch.Tensor]:
         # n-th order Taylor expansion:
         # Y(t) = Y(0) + dY(0)/dt * t + d^2Y(0)/dt^2 * t^2 / 2!
         # + ... + d^nY(0)/dt^n * t^n / n!
         # TODO: Custom Triton/CUDA kernel for better performance,
         # especially for large n_derivatives.
-        dY_current = [None] * self.ORDER
+        dY_current: List[torch.Tensor] = [None] * self.ORDER
         dY_current[0] = Y
         window = self.current_step - self.last_non_approximated_step
+        if self.state["dY_prev"][0] is not None:
+            if dY_current[0].shape != self.state["dY_prev"][0].shape:
+                self.reset_cache()
+
         for i in range(self.n_derivatives):
             if self.state["dY_prev"][i] is not None and self.current_step > 1:
                 dY_current[i + 1] = (
@@ -54,7 +60,7 @@ class TaylorSeer:
                 break
         return dY_current
 
-    def approximate_value(self):
+    def approximate_value(self) -> torch.Tensor:
         # TODO: Custom Triton/CUDA kernel for better performance,
         # especially for large n_derivatives.
         elapsed = self.current_step - self.last_non_approximated_step
@@ -69,7 +75,7 @@ class TaylorSeer:
     def mark_step_begin(self):
         self.current_step += 1
 
-    def update(self, Y):
+    def update(self, Y: torch.Tensor):
         # Directly calling this method will ignore the warmup
         # policy and force full computation.
         # Assume warmup steps is 3, and n_derivatives is 3.
@@ -87,7 +93,7 @@ class TaylorSeer:
         self.state["dY_current"] = self.approximate_derivative(Y)
         self.last_non_approximated_step = self.current_step
 
-    def step(self, Y):
+    def step(self, Y: torch.Tensor):
         self.mark_step_begin()
         if self.should_compute_full():
             self.update(Y)
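Underneath `approximate_derivative`/`approximate_value` is plain Taylor extrapolation with finite-difference derivative estimates. A scalar sketch of the first-order case (`taylorseer_order=1`); real callers pass tensors:

```python
import math

def taylor_extrapolate(y0: float, derivatives: list, elapsed: int) -> float:
    # Y(t0 + t) ~= sum_i d^iY/dt^i * t^i / i!, with d^0Y = Y(t0).
    return sum(
        d * elapsed**i / math.factorial(i)
        for i, d in enumerate([y0] + derivatives)
    )

# Features fully computed at steps 0 and 2 (window = 2):
y_prev, y_curr, window = 1.00, 1.10, 2
dy1 = (y_curr - y_prev) / window  # first finite difference
# Approximate the skipped step 3 (elapsed = 1 since the last full compute):
assert abs(taylor_extrapolate(y_curr, [dy1], elapsed=1) - 1.15) < 1e-9
```

The new shape guard at the top of `approximate_derivative` simply resets this state whenever the incoming tensor shape changes, since derivatives accumulated for one shape cannot be differenced against another.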
cache_dit/cache_factory/cache_interface.py
CHANGED
@@ -24,14 +24,14 @@ def enable_cache(
     max_continuous_cached_steps: int = -1,
     residual_diff_threshold: float = 0.08,
     # Cache CFG or not
-
+    enable_separate_cfg: bool = None,
     cfg_compute_first: bool = False,
     cfg_diff_compute_separate: bool = True,
     # Hybrid TaylorSeer
     enable_taylorseer: bool = False,
     enable_encoder_taylorseer: bool = False,
     taylorseer_cache_type: str = "residual",
-    taylorseer_order: int =
+    taylorseer_order: int = 1,
     **other_cache_context_kwargs,
 ) -> Union[
     DiffusionPipeline,
@@ -70,15 +70,15 @@ def enable_cache(
     residual_diff_threshold (`float`, *required*, defaults to 0.08):
         The value of residual diff threshold; a higher value leads to faster performance at the
         cost of lower precision.
-
+    enable_separate_cfg (`bool`, *required*, defaults to None):
         Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
-        and non-CFG into single forward step, should set
+        and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
         CogVideoX, HunyuanVideo, Mochi, etc.
     cfg_compute_first (`bool`, *required*, defaults to False):
         Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
         1, 3, 5, ... -> CFG step.
     cfg_diff_compute_separate (`bool`, *required*, defaults to True):
-        Compute
+        Compute separate diff values for CFG and non-CFG step, default True. If False, we will
         use the computed diff from current non-CFG transformer step for current CFG step.
     enable_taylorseer (`bool`, *required*, defaults to False):
         Enable the hybrid TaylorSeer for hidden_states or not. We have supported the
@@ -91,10 +91,10 @@ def enable_cache(
         Enable the hybrid TaylorSeer for encoder_hidden_states or not.
     taylorseer_cache_type (`str`, *required*, defaults to `residual`):
         The TaylorSeer implemented in cache-dit supports both `hidden_states` and `residual` as cache type.
-    taylorseer_order (`int`, *required*, defaults to
+    taylorseer_order (`int`, *required*, defaults to 1):
         The order of taylorseer, higher values of n_derivatives will lead to longer computation time,
-
-
+        the recommended value is 1 or 2.
+    other_cache_context_kwargs: (`dict`, *optional*, defaults to {})
         Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
         for more details.
 
@@ -123,7 +123,7 @@ def enable_cache(
         max_continuous_cached_steps
     )
     cache_context_kwargs["residual_diff_threshold"] = residual_diff_threshold
-    cache_context_kwargs["
+    cache_context_kwargs["enable_separate_cfg"] = enable_separate_cfg
     cache_context_kwargs["cfg_compute_first"] = cfg_compute_first
     cache_context_kwargs["cfg_diff_compute_separate"] = (
         cfg_diff_compute_separate
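Putting the renamed argument and the new `taylorseer_order` default together, a hedged usage sketch of the 0.2.36 `enable_cache` surface; the model id is illustrative:

```python
import cache_dit
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")  # example id

cache_dit.enable_cache(
    pipe,
    residual_diff_threshold=0.08,
    # None defers to BlockAdapterRegistry: QwenImage is in the predefined
    # separate-CFG list, so this resolves to True automatically.
    enable_separate_cfg=None,
    enable_taylorseer=True,
    taylorseer_order=1,  # the 0.2.36 default; 1 or 2 is the recommended range
)
```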
cache_dit/cache_factory/patch_functors/__init__.py
CHANGED
@@ -1,4 +1,5 @@
 from cache_dit.cache_factory.patch_functors.functor_base import PatchFunctor
+from cache_dit.cache_factory.patch_functors.functor_dit import DiTPatchFunctor
 from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
 from cache_dit.cache_factory.patch_functors.functor_chroma import (
     ChromaPatchFunctor,
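The new `DiTPatchFunctor` joins the Flux and Chroma functors under the shared `PatchFunctor` base exported above. A skeletal sketch of that contract; the `apply` method name and signature are inferred from the base-class pattern, not confirmed by this diff:

```python
import torch
from cache_dit.cache_factory.patch_functors import PatchFunctor


class MyPatchFunctor(PatchFunctor):
    # Hypothetical subclass: rewrite a transformer's forward so its blocks
    # match one of the supported ForwardPatterns before caching is attached.
    def apply(self, transformer: torch.nn.Module, **kwargs) -> torch.nn.Module:
        # e.g. monkey-patch transformer.forward / per-block forward here.
        return transformer
```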