cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they were published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
- cache_dit/{cache_factory → caching}/utils.py +19 -8
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
- cache_dit/cache_factory/patch_functors/__init__.py +0 -15
- cache_dit-1.0.3.dist-info/RECORD +0 -58
- cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
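The most sweeping structural change in the listing is the rename of the cache_dit/cache_factory package to cache_dit/caching (plus custom_ops to kernels). For downstream code that imported the old module paths directly, the adjustment is a path change only; a minimal sketch, assuming the imported symbols already existed under the 1.0.3 layout:

    # 1.0.3 layout (now removed):
    #   from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
    #   from cache_dit.cache_factory.cache_types import CacheType
    # 1.0.14 layout, after the cache_factory -> caching rename:
    from cache_dit.caching.cache_contexts.cache_context import CachedContext
    from cache_dit.caching.cache_types import CacheType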
@@ -0,0 +1,86 @@
+import torch
+
+from typing import Any
+from cache_dit.caching import CachedContext
+from cache_dit.caching import CachedContextManager
+from cache_dit.caching import PrunedContextManager
+
+
+def apply_stats(
+    module: torch.nn.Module | Any,
+    cache_context: CachedContext | str = None,
+    context_manager: CachedContextManager | PrunedContextManager = None,
+):
+    # Patch the cached stats onto the module; the cached stats
+    # will be reset on each call of pipe.__call__(**kwargs).
+    if module is None or context_manager is None:
+        return
+
+    if cache_context is not None:
+        context_manager.set_context(cache_context)
+
+    # Cache stats for Dual Block Cache
+    module._cached_steps = context_manager.get_cached_steps()
+    module._residual_diffs = context_manager.get_residual_diffs()
+    module._cfg_cached_steps = context_manager.get_cfg_cached_steps()
+    module._cfg_residual_diffs = context_manager.get_cfg_residual_diffs()
+    # Pruned stats for Dynamic Block Prune
+    if not isinstance(context_manager, PrunedContextManager):
+        return
+    module._pruned_steps = context_manager.get_pruned_steps()
+    module._cfg_pruned_steps = context_manager.get_cfg_pruned_steps()
+    module._pruned_blocks = context_manager.get_pruned_blocks()
+    module._cfg_pruned_blocks = context_manager.get_cfg_pruned_blocks()
+    module._actual_blocks = context_manager.get_actual_blocks()
+    module._cfg_actual_blocks = context_manager.get_cfg_actual_blocks()
+    # Calculate pruned ratio
+    if len(module._pruned_blocks) > 0 and sum(module._actual_blocks) > 0:
+        module._pruned_ratio = sum(module._pruned_blocks) / sum(
+            module._actual_blocks
+        )
+    else:
+        module._pruned_ratio = None
+    if (
+        len(module._cfg_pruned_blocks) > 0
+        and sum(module._cfg_actual_blocks) > 0
+    ):
+        module._cfg_pruned_ratio = sum(module._cfg_pruned_blocks) / sum(
+            module._cfg_actual_blocks
+        )
+    else:
+        module._cfg_pruned_ratio = None
+
+
+def remove_stats(
+    module: torch.nn.Module | Any,
+):
+    if module is None:
+        return
+
+    # Dual Block Cache
+    if hasattr(module, "_cached_steps"):
+        del module._cached_steps
+    if hasattr(module, "_residual_diffs"):
+        del module._residual_diffs
+    if hasattr(module, "_cfg_cached_steps"):
+        del module._cfg_cached_steps
+    if hasattr(module, "_cfg_residual_diffs"):
+        del module._cfg_residual_diffs
+
+    # Dynamic Block Prune
+    if hasattr(module, "_pruned_steps"):
+        del module._pruned_steps
+    if hasattr(module, "_cfg_pruned_steps"):
+        del module._cfg_pruned_steps
+    if hasattr(module, "_pruned_blocks"):
+        del module._pruned_blocks
+    if hasattr(module, "_cfg_pruned_blocks"):
+        del module._cfg_pruned_blocks
+    if hasattr(module, "_actual_blocks"):
+        del module._actual_blocks
+    if hasattr(module, "_cfg_actual_blocks"):
+        del module._cfg_actual_blocks
+    if hasattr(module, "_pruned_ratio"):
+        del module._pruned_ratio
+    if hasattr(module, "_cfg_pruned_ratio"):
+        del module._cfg_pruned_ratio
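A minimal usage sketch for the two helpers above, assuming the imports shown in this and the following listings; the Linear module and the context name "default" are placeholders, not part of the package:

    import torch

    from cache_dit.caching import CachedContextManager
    from cache_dit.caching.cache_contexts import BasicCacheConfig

    # apply_stats / remove_stats are the functions from the new module above.
    manager = CachedContextManager(name="demo")
    manager.new_context(name="default", cache_config=BasicCacheConfig())

    module = torch.nn.Linear(8, 8)  # stand-in for a cache-enabled transformer
    apply_stats(module, cache_context="default", context_manager=manager)
    print(module._cached_steps)    # steps where the cached residuals were reused
    print(module._residual_diffs)  # per-step residual diffs
    remove_stats(module)           # drop the patched attributes again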
@@ -0,0 +1,28 @@
+from cache_dit.caching.cache_contexts.calibrators import (
+    Calibrator,
+    CalibratorBase,
+    CalibratorConfig,
+    TaylorSeerCalibratorConfig,
+    FoCaCalibratorConfig,
+)
+from cache_dit.caching.cache_contexts.cache_config import (
+    BasicCacheConfig,
+    DBCacheConfig,
+)
+from cache_dit.caching.cache_contexts.cache_context import (
+    CachedContext,
+)
+from cache_dit.caching.cache_contexts.cache_manager import (
+    CachedContextManager,
+    ContextNotExistError,
+)
+from cache_dit.caching.cache_contexts.prune_config import DBPruneConfig
+from cache_dit.caching.cache_contexts.prune_context import (
+    PrunedContext,
+)
+from cache_dit.caching.cache_contexts.prune_manager import (
+    PrunedContextManager,
+)
+from cache_dit.caching.cache_contexts.context_manager import (
+    ContextManager,
+)
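This listing appears to be the new cache_dit/caching/cache_contexts/__init__.py (it matches the +28 entry in the file list), so the cache and prune context machinery can be imported from one flat path instead of the individual submodules:

    # With the re-exports above, callers no longer need the deep submodule paths:
    from cache_dit.caching.cache_contexts import (
        CachedContextManager,
        ContextNotExistError,
        DBCacheConfig,
        DBPruneConfig,
    )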
@@ -0,0 +1,120 @@
+import torch
+import dataclasses
+from typing import Optional, Union
+from cache_dit.caching.cache_types import CacheType
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class BasicCacheConfig:
+    # Default: Dual Block Cache with Flexible FnBn configuration.
+    cache_type: CacheType = CacheType.DBCache  # DBCache, DBPrune, NONE
+
+    # Fn_compute_blocks: (`int`, *required*, defaults to 8):
+    # Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
+    # at time step t, enabling the calculation of a more stable L1 diff and delivering more
+    # accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
+    # for more details of DBCache.
+    Fn_compute_blocks: int = 8
+    # Bn_compute_blocks: (`int`, *required*, defaults to 0):
+    # Further fuses approximate information in the **last n** Transformer blocks to enhance
+    # prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
+    # that use residual cache.
+    Bn_compute_blocks: int = 0
+    # residual_diff_threshold (`float`, *required*, defaults to 0.08):
+    # The value of the residual diff threshold; a higher value leads to faster performance at the
+    # cost of lower precision.
+    residual_diff_threshold: Union[torch.Tensor, float] = 0.08
+    # max_warmup_steps (`int`, *required*, defaults to 8):
+    # DBCache does not apply the caching strategy when the number of running steps is less than
+    # or equal to this value, ensuring the model sufficiently learns basic features during warmup.
+    max_warmup_steps: int = 8  # DON'T Cache in warmup steps
+    # warmup_interval (`int`, *required*, defaults to 1):
+    # Skip interval in warmup steps, e.g., when warmup_interval is 2, only steps 0, 2, 4, ...
+    # of the warmup steps will be computed; the others will use the dynamic cache.
+    warmup_interval: int = 1  # skip interval in warmup steps
+    # max_cached_steps (`int`, *required*, defaults to -1):
+    # DBCache disables the caching strategy when the previous cached steps exceed this value to
+    # prevent precision degradation.
+    max_cached_steps: int = -1  # for both CFG and non-CFG
+    # max_continuous_cached_steps (`int`, *required*, defaults to -1):
+    # DBCache disables the caching strategy when the previous continuous cached steps exceed this value to
+    # prevent precision degradation.
+    max_continuous_cached_steps: int = -1  # the max continuous cached steps
+    # enable_separate_cfg (`bool`, *required*, defaults to None):
+    # Whether to do separate CFG or not, e.g., for Wan 2.1 and Qwen-Image. For models that fuse CFG
+    # and non-CFG into a single forward step, enable_separate_cfg should be set to False, for example:
+    # CogVideoX, HunyuanVideo, Mochi, etc.
+    enable_separate_cfg: Optional[bool] = None
+    # cfg_compute_first (`bool`, *required*, defaults to False):
+    # Whether to compute the CFG forward pass first; default False, namely 0, 2, 4, ... -> non-CFG step;
+    # 1, 3, 5, ... -> CFG step.
+    cfg_compute_first: bool = False
+    # cfg_diff_compute_separate (`bool`, *required*, defaults to True):
+    # Compute separate diff values for CFG and non-CFG steps, default True. If False, we will
+    # use the computed diff from the current non-CFG transformer step for the current CFG step.
+    cfg_diff_compute_separate: bool = True
+    # num_inference_steps (`int`, *optional*, defaults to None):
+    # num_inference_steps for DiffusionPipeline, used to adjust some internal settings
+    # for better caching performance. For example, we will refresh the cache once the
+    # executed steps exceed num_inference_steps if num_inference_steps is provided.
+    num_inference_steps: Optional[int] = None
+
+    def update(self, **kwargs) -> "BasicCacheConfig":
+        for key, value in kwargs.items():
+            if hasattr(self, key):
+                if value is not None:
+                    setattr(self, key, value)
+        return self
+
+    def empty(self, **kwargs) -> "BasicCacheConfig":
+        # Set all fields to None
+        for field in dataclasses.fields(self):
+            if hasattr(self, field.name):
+                setattr(self, field.name, None)
+        if kwargs:
+            self.update(**kwargs)
+        return self
+
+    def reset(self, **kwargs) -> "BasicCacheConfig":
+        return self.empty(**kwargs)
+
+    def as_dict(self) -> dict:
+        return dataclasses.asdict(self)
+
+    def strify(self) -> str:
+        return (
+            f"{self.cache_type}_"
+            f"F{self.Fn_compute_blocks}"
+            f"B{self.Bn_compute_blocks}_"
+            f"W{self.max_warmup_steps}"
+            f"I{self.warmup_interval}"
+            f"M{max(0, self.max_cached_steps)}"
+            f"MC{max(0, self.max_continuous_cached_steps)}_"
+            f"R{self.residual_diff_threshold}"
+        )
+
+
+@dataclasses.dataclass
+class ExtraCacheConfig:
+    # Some other, less important settings for Dual Block Cache.
+    # NOTE: These flags may be deprecated in the future and users
+    # should never use these extra configurations in their cases.
+
+    # l1_hidden_states_diff_threshold (`float`, *optional*, defaults to None):
+    # The hidden states diff threshold for DBCache if hidden_states are used as
+    # the cache (not the residual).
+    l1_hidden_states_diff_threshold: float = None
+    # important_condition_threshold (`float`, *optional*, defaults to 0.0):
+    # Only select the most important tokens while calculating the L1 diff.
+    important_condition_threshold: float = 0.0
+    # downsample_factor (`int`, *optional*, defaults to 1):
+    # Downsample factor for the Fn buffer, in order to save GPU memory.
+    downsample_factor: int = 1
+
+
+@dataclasses.dataclass
+class DBCacheConfig(BasicCacheConfig):
+    pass  # Just an alias for BasicCacheConfig
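A minimal usage sketch for the new config module; the field values below are arbitrary examples, and the strify() output shape is inferred from the f-string above:

    from cache_dit.caching.cache_contexts.cache_config import DBCacheConfig

    # Start from the defaults (F8B0, warmup 8, threshold 0.08) and override a few
    # fields; update() skips unknown keys and None values.
    config = DBCacheConfig().update(
        Fn_compute_blocks=4,
        residual_diff_threshold=0.12,
        num_inference_steps=50,
    )
    print(config.strify())   # compact tag, roughly "<cache_type>_F4B0_W8I1M0MC0_R0.12"
    print(config.as_dict())  # plain dict view via dataclasses.asdict
    config.reset()           # empty(): every field becomes None until update() is called again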
@@ -1,3 +1,5 @@
+# The cache context codebase is adapted from FBCache. Over time its codebase
+# diverged a lot, and context API is no longer compatible with FBCache.
 import logging
 import dataclasses
 from collections import defaultdict

@@ -5,7 +7,12 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple

 import torch

-from cache_dit.
+from cache_dit.caching.cache_contexts.cache_config import (
+    BasicCacheConfig,
+    ExtraCacheConfig,
+    DBCacheConfig,
+)
+from cache_dit.caching.cache_contexts.calibrators import (
     Calibrator,
     CalibratorBase,
     CalibratorConfig,

@@ -15,101 +22,16 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)


-@dataclasses.dataclass
-class BasicCacheConfig:
-    # Dual Block Cache with Flexible FnBn configuration.
-
-    # Fn_compute_blocks: (`int`, *required*, defaults to 8):
-    # Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
-    # at time step t, enabling the calculation of a more stable L1 diff and delivering more
-    # accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
-    # for more details of DBCache.
-    Fn_compute_blocks: int = 8
-    # Bn_compute_blocks: (`int`, *required*, defaults to 0):
-    # Further fuses approximate information in the **last n** Transformer blocks to enhance
-    # prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
-    # that use residual cache.
-    Bn_compute_blocks: int = 0
-    # residual_diff_threshold (`float`, *required*, defaults to 0.08):
-    # the value of residual diff threshold, a higher value leads to faster performance at the
-    # cost of lower precision.
-    residual_diff_threshold: Union[torch.Tensor, float] = 0.08
-    # max_warmup_steps (`int`, *required*, defaults to 8):
-    # DBCache does not apply the caching strategy when the number of running steps is less than
-    # or equal to this value, ensuring the model sufficiently learns basic features during warmup.
-    max_warmup_steps: int = 8  # DON'T Cache in warmup steps
-    # warmup_interval (`int`, *required*, defaults to 1):
-    # Skip interval in warmup steps, e.g., when warmup_interval is 2, only 0, 2, 4, ... steps
-    # in warmup steps will be computed, others will use dynamic cache.
-    warmup_interval: int = 1  # skip interval in warmup steps
-    # max_cached_steps (`int`, *required*, defaults to -1):
-    # DBCache disables the caching strategy when the previous cached steps exceed this value to
-    # prevent precision degradation.
-    max_cached_steps: int = -1  # for both CFG and non-CFG
-    # max_continuous_cached_steps (`int`, *required*, defaults to -1):
-    # DBCache disables the caching strategy when the previous continous cached steps exceed this value to
-    # prevent precision degradation.
-    max_continuous_cached_steps: int = -1  # the max continuous cached steps
-    # enable_separate_cfg (`bool`, *required*, defaults to None):
-    # Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
-    # and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
-    # CogVideoX, HunyuanVideo, Mochi, etc.
-    enable_separate_cfg: Optional[bool] = None
-    # cfg_compute_first (`bool`, *required*, defaults to False):
-    # Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
-    # 1, 3, 5, ... -> CFG step.
-    cfg_compute_first: bool = False
-    # cfg_diff_compute_separate (`bool`, *required*, defaults to True):
-    # Compute separate diff values for CFG and non-CFG step, default True. If False, we will
-    # use the computed diff from current non-CFG transformer step for current CFG step.
-    cfg_diff_compute_separate: bool = True
-
-    def update(self, **kwargs) -> "BasicCacheConfig":
-        for key, value in kwargs.items():
-            if hasattr(self, key):
-                setattr(self, key, value)
-        return self
-
-    def strify(self) -> str:
-        return (
-            f"DBCACHE_F{self.Fn_compute_blocks}"
-            f"B{self.Bn_compute_blocks}_"
-            f"W{self.max_warmup_steps}"
-            f"I{self.warmup_interval}"
-            f"M{max(0, self.max_cached_steps)}"
-            f"MC{max(0, self.max_continuous_cached_steps)}_"
-            f"R{self.residual_diff_threshold}"
-        )
-
-
-@dataclasses.dataclass
-class ExtraCacheConfig:
-    # Some other not very important settings for Dual Block Cache.
-    # NOTE: These flags maybe deprecated in the future and users
-    # should never use these extra configurations in their cases.
-
-    # l1_hidden_states_diff_threshold (`float`, *optional*, defaults to None):
-    # The hidden states diff threshold for DBCache if use hidden_states as
-    # cache (not residual).
-    l1_hidden_states_diff_threshold: float = None
-    # important_condition_threshold (`float`, *optional*, defaults to 0.0):
-    # Only select the most important tokens while calculating the l1 diff.
-    important_condition_threshold: float = 0.0
-    # downsample_factor (`int`, *optional*, defaults to 1):
-    # Downsample factor for Fn buffer, in order the save GPU memory.
-    downsample_factor: int = 1
-    # num_inference_steps (`int`, *optional*, defaults to -1):
-    # num_inference_steps for DiffusionPipeline, for future use.
-    num_inference_steps: int = -1
-
-
 @dataclasses.dataclass
 class CachedContext:
     name: str = "default"
     # Buffer for storing the residuals and other tensors
     buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
     # Basic Dual Block Cache Config
-    cache_config:
+    cache_config: Union[
+        BasicCacheConfig,
+        DBCacheConfig,
+    ] = dataclasses.field(
         default_factory=BasicCacheConfig,
     )
     # NOTE: Users should never use these extra configurations.

@@ -131,14 +53,14 @@ class CachedContext:
     # be double of executed_steps.
     transformer_executed_steps: int = 0

-    # CFG & non-CFG cached steps
+    # CFG & non-CFG cached/pruned steps
     cached_steps: List[int] = dataclasses.field(default_factory=list)
-    residual_diffs: DefaultDict[str, float] = dataclasses.field(
+    residual_diffs: DefaultDict[str, float | list] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
     continuous_cached_steps: int = 0
     cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
-    cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
+    cfg_residual_diffs: DefaultDict[str, float | list] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
     cfg_continuous_cached_steps: int = 0

@@ -286,7 +208,9 @@ class CachedContext:
     def get_cfg_calibrators(self) -> Tuple[CalibratorBase, CalibratorBase]:
         return self.cfg_calibrator, self.cfg_encoder_calibrator

-    def add_residual_diff(self, diff):
+    def add_residual_diff(self, diff: float | torch.Tensor):
+        if isinstance(diff, torch.Tensor):
+            diff = diff.item()
         # step: executed_steps - 1, not transformer_steps - 1
         step = str(self.get_current_step())
         # Only add the diff if it is not already recorded for this step
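The add_residual_diff change above means callers can now pass either a Python float or a 0-dim tensor. A standalone sketch of the normalization it performs (the helper below is illustrative only, not part of the package):

    import torch

    def normalize_diff(diff: float | torch.Tensor) -> float:
        # Mirrors the new add_residual_diff behavior: a 0-dim tensor is
        # converted to a plain float via .item() before being recorded.
        return diff.item() if isinstance(diff, torch.Tensor) else diff

    print(normalize_diff(torch.tensor(0.0123)))  # ~0.0123
    print(normalize_diff(0.0456))                # 0.0456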
@@ -5,8 +5,9 @@ from typing import Dict, Optional, Tuple, Union, List
 import torch
 import torch.distributed as dist

-from cache_dit.
-from cache_dit.
+from cache_dit.caching.cache_contexts.calibrators import CalibratorBase
+from cache_dit.caching.cache_contexts.cache_context import (
+    BasicCacheConfig,
     CachedContext,
 )
 from cache_dit.logger import init_logger

@@ -14,36 +15,156 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)


-class
+class ContextNotExistError(Exception):
     pass


 class CachedContextManager:
     # Each Pipeline should have it's own context manager instance.

-    def __init__(self, name: str = None):
+    def __init__(self, name: str = None, persistent_context: bool = False):
         self.name = name
         self._current_context: CachedContext = None
         self._cached_context_manager: Dict[str, CachedContext] = {}
+        # Whether to create new context automatically when setting
+        # a non-exist context name. Persistent context is useful when
+        # the pipeline class is not provided and users want to use
+        # cache-dit in a transformer-only way.
+        self._persistent_context = persistent_context
+        self._current_step_refreshed: bool = False
+
+    @property
+    def persistent_context(self) -> bool:
+        return self._persistent_context
+
+    @property
+    def current_context(self) -> CachedContext:
+        return self._current_context
+
+    @property
+    @torch.compiler.disable
+    def current_step_refreshed(self) -> bool:
+        return self._current_step_refreshed

+    @torch.compiler.disable
+    def is_pre_refreshed(self) -> bool:
+        _context = self._current_context
+        if _context is None:
+            return False
+
+        num_inference_steps = _context.cache_config.num_inference_steps
+        if num_inference_steps is not None:
+            current_step = _context.get_current_step()  # e.g, 0~49,50~99,...
+            return current_step == num_inference_steps - 1
+        return False
+
+    @torch.compiler.disable
     def new_context(self, *args, **kwargs) -> CachedContext:
+        if self._persistent_context:
+            cache_config: BasicCacheConfig = kwargs.get("cache_config", None)
+            assert (
+                cache_config is not None
+                and cache_config.num_inference_steps is not None
+            ), (
+                "When persistent_context is True, num_inference_steps "
+                "must be set in cache_config for proper cache refreshing."
+                f"\nkwargs: {kwargs}"
+            )
         _context = CachedContext(*args, **kwargs)
+        # NOTE: Patch args and kwargs for implicit refresh.
+        _context._init_args = args  # maybe empty tuple: ()
+        _context._init_kwargs = kwargs  # maybe empty dict: {}
         self._cached_context_manager[_context.name] = _context
         return _context

-
+    @torch.compiler.disable
+    def maybe_refresh(
+        self,
+        cached_context: Optional[CachedContext | str] = None,
+    ) -> bool:
+        if cached_context is None:
+            _context = self._current_context
+            assert _context is not None, "Current context is not set!"
+
+        if isinstance(cached_context, CachedContext):
+            _context = cached_context
+        else:
+            if cached_context not in self._cached_context_manager:
+                raise ContextNotExistError("Context not exist!")
+            _context = self._cached_context_manager[cached_context]
+
+        if self._persistent_context:
+            assert _context.cache_config.num_inference_steps is not None, (
+                "When persistent_context is True, num_inference_steps must be set "
+                "in cache_config for proper cache refreshing."
+            )
+
+        num_inference_steps = _context.cache_config.num_inference_steps
+        if num_inference_steps is not None:
+            current_step = _context.get_current_step()  # e.g, 0~49,50~99,...
+            # Another round of inference, need to refresh cache context.
+            if current_step >= num_inference_steps:
+                if logger.isEnabledFor(logging.DEBUG):
+                    logger.debug(
+                        f"Refreshing cache context '{_context.name}' "
+                        f"as current step: {current_step} >= "
+                        f"num_inference_steps: {num_inference_steps}."
+                    )
+                return True
+        return False
+
+    @torch.compiler.disable
+    def set_context(
+        self,
+        cached_context: CachedContext | str,
+        *args,
+        **kwargs,
+    ) -> CachedContext:
         if isinstance(cached_context, CachedContext):
             self._current_context = cached_context
         else:
             if cached_context not in self._cached_context_manager:
-
-
+                if not self._persistent_context:
+                    raise ContextNotExistError(
+                        "Context not exist and persistent_context is False. Please "
+                        "create new context first or set persistent_context=True."
+                    )
+                else:
+                    # Create new context if not exist
+                    if any((bool(args), bool(kwargs))):
+                        kwargs["name"] = cached_context
+                        self._current_context = self.new_context(
+                            *args, **kwargs
+                        )
+                    else:
+                        raise ValueError(
+                            "To create new context, please provide args and kwargs."
+                        )
+            else:
+                self._current_context = self._cached_context_manager[
+                    cached_context
+                ]
+
+        if self.maybe_refresh(self._current_context):
+            if not any((bool(args), bool(kwargs))):
+                assert hasattr(self._current_context, "_init_args")
+                assert hasattr(self._current_context, "_init_kwargs")
+                args = self._current_context._init_args
+                kwargs = self._current_context._init_kwargs
+
+            self._current_context = self.reset_context(
+                self._current_context, *args, **kwargs
+            )
+            self._current_step_refreshed = True
+        else:
+            self._current_step_refreshed = False
+
         return self._current_context

     def get_context(self, name: str = None) -> CachedContext:
         if name is not None:
             if name not in self._cached_context_manager:
-                raise
+                raise ContextNotExistError("Context not exist!")
             return self._cached_context_manager[name]
         return self._current_context

@@ -482,7 +603,7 @@ class CachedContextManager:

         if calibrator is not None:
             # Use calibrator to update the buffer
-            calibrator.update(buffer)
+            calibrator.update(buffer, name=prefix)
         else:
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(

@@ -513,7 +634,7 @@ class CachedContextManager:
         calibrator, _ = self.get_calibrator()

         if calibrator is not None:
-            return calibrator.approximate()
+            return calibrator.approximate(name=prefix)
         else:
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(

@@ -551,7 +672,7 @@ class CachedContextManager:

         if encoder_calibrator is not None:
             # Use CalibratorBase to update the buffer
-            encoder_calibrator.update(buffer)
+            encoder_calibrator.update(buffer, name=prefix)
         else:
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(

@@ -582,7 +703,7 @@ class CachedContextManager:

         if encoder_calibrator is not None:
             # Use calibrator to approximate the value
-            return encoder_calibrator.approximate()
+            return encoder_calibrator.approximate(name=prefix)
         else:
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(
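A minimal sketch of the new persistent-context flow shown above; the manager and context names and the 50-step config are illustrative placeholders:

    from cache_dit.caching.cache_contexts import CachedContextManager, DBCacheConfig

    # One manager per pipeline; persistent_context lets set_context() create the
    # context on first use (transformer-only usage, no pipeline class required).
    manager = CachedContextManager(name="transformer_ctx", persistent_context=True)

    # First call creates the context because kwargs are provided; num_inference_steps
    # is mandatory here so the manager knows when a new round of inference starts.
    ctx = manager.set_context(
        "default",
        cache_config=DBCacheConfig(num_inference_steps=50),
    )

    # Later calls with just the name reuse the same context; once get_current_step()
    # reaches num_inference_steps, maybe_refresh() triggers reset_context() with the
    # stored _init_args/_init_kwargs.
    ctx = manager.set_context("default")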
@@ -1,10 +1,10 @@
-from cache_dit.
+from cache_dit.caching.cache_contexts.calibrators.base import (
     CalibratorBase,
 )
-from cache_dit.
+from cache_dit.caching.cache_contexts.calibrators.taylorseer import (
     TaylorSeerCalibrator,
 )
-from cache_dit.
+from cache_dit.caching.cache_contexts.calibrators.foca import (
     FoCaCalibrator,
 )

@@ -45,6 +45,28 @@ class CalibratorConfig:
     def to_kwargs(self) -> Dict:
         return self.calibrator_kwargs.copy()

+    def as_dict(self) -> dict:
+        return dataclasses.asdict(self)
+
+    def update(self, **kwargs) -> "CalibratorConfig":
+        for key, value in kwargs.items():
+            if hasattr(self, key):
+                if value is not None:
+                    setattr(self, key, value)
+        return self
+
+    def empty(self, **kwargs) -> "CalibratorConfig":
+        # Set all fields to None
+        for field in dataclasses.fields(self):
+            if hasattr(self, field.name):
+                setattr(self, field.name, None)
+        if kwargs:
+            self.update(**kwargs)
+        return self
+
+    def reset(self, **kwargs) -> "CalibratorConfig":
+        return self.empty(**kwargs)
+

 @dataclasses.dataclass
 class TaylorSeerCalibratorConfig(CalibratorConfig):