cache-dit 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +3 -6
- cache_dit/cache_factory/block_adapters/block_adapters.py +21 -64
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +82 -21
- cache_dit/cache_factory/cache_blocks/__init__.py +4 -0
- cache_dit/cache_factory/cache_blocks/offload_utils.py +115 -0
- cache_dit/cache_factory/cache_blocks/pattern_base.py +3 -0
- cache_dit/cache_factory/cache_contexts/__init__.py +10 -8
- cache_dit/cache_factory/cache_contexts/cache_context.py +186 -117
- cache_dit/cache_factory/cache_contexts/cache_manager.py +63 -131
- cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +132 -0
- cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/foca.py +1 -1
- cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/taylorseer.py +7 -2
- cache_dit/cache_factory/cache_interface.py +128 -111
- cache_dit/cache_factory/params_modifier.py +87 -0
- cache_dit/metrics/__init__.py +3 -1
- cache_dit/utils.py +12 -21
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/METADATA +200 -434
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/RECORD +27 -31
- cache_dit/cache_factory/cache_adapters/v2/__init__.py +0 -3
- cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +0 -524
- cache_dit/cache_factory/cache_contexts/taylorseer.py +0 -102
- cache_dit/cache_factory/cache_contexts/v2/__init__.py +0 -13
- cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +0 -288
- cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +0 -799
- cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +0 -81
- /cache_dit/cache_factory/cache_blocks/{utils.py → pattern_utils.py} +0 -0
- /cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/base.py +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/top_level.txt +0 -0
|
@@ -4,8 +4,9 @@ from cache_dit.cache_factory.cache_types import CacheType
|
|
|
4
4
|
from cache_dit.cache_factory.block_adapters import BlockAdapter
|
|
5
5
|
from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
|
|
6
6
|
from cache_dit.cache_factory.cache_adapters import CachedAdapter
|
|
7
|
-
from cache_dit.cache_factory.
|
|
7
|
+
from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
|
|
8
8
|
from cache_dit.cache_factory.cache_contexts import CalibratorConfig
|
|
9
|
+
from cache_dit.cache_factory.params_modifier import ParamsModifier
|
|
9
10
|
|
|
10
11
|
from cache_dit.logger import init_logger
|
|
11
12
|
|
|
@@ -18,25 +19,20 @@ def enable_cache(
|
|
|
18
19
|
DiffusionPipeline,
|
|
19
20
|
BlockAdapter,
|
|
20
21
|
],
|
|
21
|
-
#
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
max_warmup_steps: int = 8,
|
|
25
|
-
max_cached_steps: int = -1,
|
|
26
|
-
max_continuous_cached_steps: int = -1,
|
|
27
|
-
residual_diff_threshold: float = 0.08,
|
|
28
|
-
# Cache CFG or not
|
|
29
|
-
enable_separate_cfg: bool = None,
|
|
30
|
-
cfg_compute_first: bool = False,
|
|
31
|
-
cfg_diff_compute_separate: bool = True,
|
|
32
|
-
# Hybird TaylorSeer
|
|
33
|
-
enable_taylorseer: bool = False,
|
|
34
|
-
enable_encoder_taylorseer: bool = False,
|
|
35
|
-
taylorseer_cache_type: str = "residual",
|
|
36
|
-
taylorseer_order: int = 1,
|
|
37
|
-
# New param only for v2 API
|
|
22
|
+
# Basic DBCache config: BasicCacheConfig
|
|
23
|
+
cache_config: BasicCacheConfig = BasicCacheConfig(),
|
|
24
|
+
# Calibrator config: TaylorSeerCalibratorConfig, etc.
|
|
38
25
|
calibrator_config: Optional[CalibratorConfig] = None,
|
|
39
|
-
|
|
26
|
+
# Modify cache context params for specific blocks.
|
|
27
|
+
params_modifiers: Optional[
|
|
28
|
+
Union[
|
|
29
|
+
ParamsModifier,
|
|
30
|
+
List[ParamsModifier],
|
|
31
|
+
List[List[ParamsModifier]],
|
|
32
|
+
]
|
|
33
|
+
] = None,
|
|
34
|
+
# Other cache context kwargs: Deprecated cache kwargs
|
|
35
|
+
**kwargs,
|
|
40
36
|
) -> Union[
|
|
41
37
|
DiffusionPipeline,
|
|
42
38
|
BlockAdapter,
|
|
@@ -53,55 +49,51 @@ def enable_cache(
|
|
|
53
49
|
The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
|
|
54
50
|
For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
|
|
55
51
|
for the usage of BlockAdapter.
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
[TaylorSeers: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers](https://arxiv.org/pdf/2503.06923) algorithm
|
|
90
|
-
to further improve the precision of DBCache in cases where the cached steps are large,
|
|
91
|
-
namely, **Hybrid TaylorSeer + DBCache**. At timesteps with significant intervals,
|
|
92
|
-
the feature similarity in diffusion models decreases substantially, significantly
|
|
93
|
-
harming the generation quality.
|
|
94
|
-
enable_encoder_taylorseer (`bool`, *required*, defaults to False):
|
|
95
|
-
Enable the hybird TaylorSeer for encoder_hidden_states or not.
|
|
96
|
-
taylorseer_cache_type (`str`, *required*, defaults to `residual`):
|
|
97
|
-
The TaylorSeer implemented in cache-dit supports both `hidden_states` and `residual` as cache type.
|
|
98
|
-
taylorseer_order (`int`, *required*, defaults to 1):
|
|
99
|
-
The order of taylorseer, higher values of n_derivatives will lead to longer computation time,
|
|
100
|
-
the recommended value is 1 or 2.
|
|
52
|
+
cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
|
|
53
|
+
Basic DBCache config for cache context, defaults to BasicCacheConfig(). The configurable params are listed below:
|
|
54
|
+
Fn_compute_blocks: (`int`, *required*, defaults to 8):
|
|
55
|
+
Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
|
|
56
|
+
at time step t, enabling the calculation of a more stable L1 diff and delivering more
|
|
57
|
+
accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
|
|
58
|
+
for more details of DBCache.
|
|
59
|
+
Bn_compute_blocks: (`int`, *required*, defaults to 0):
|
|
60
|
+
Further fuses approximate information in the **last n** Transformer blocks to enhance
|
|
61
|
+
prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
|
|
62
|
+
that use residual cache.
|
|
63
|
+
residual_diff_threshold (`float`, *required*, defaults to 0.08):
|
|
64
|
+
the value of residual diff threshold, a higher value leads to faster performance at the
|
|
65
|
+
cost of lower precision.
|
|
66
|
+
max_warmup_steps (`int`, *required*, defaults to 8):
|
|
67
|
+
DBCache does not apply the caching strategy when the number of running steps is less than
|
|
68
|
+
or equal to this value, ensuring the model sufficiently learns basic features during warmup.
|
|
69
|
+
max_cached_steps (`int`, *required*, defaults to -1):
|
|
70
|
+
DBCache disables the caching strategy when the previous cached steps exceed this value to
|
|
71
|
+
prevent precision degradation.
|
|
72
|
+
max_continuous_cached_steps (`int`, *required*, defaults to -1):
|
|
73
|
+
DBCache disables the caching strategy when the previous continuous cached steps exceed this value to
|
|
74
|
+
prevent precision degradation.
|
|
75
|
+
enable_separate_cfg (`bool`, *required*, defaults to None):
|
|
76
|
+
Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For models that fuse CFG
|
|
77
|
+
and non-CFG into a single forward step, enable_separate_cfg should be set to False, for example:
|
|
78
|
+
CogVideoX, HunyuanVideo, Mochi, etc.
|
|
79
|
+
cfg_compute_first (`bool`, *required*, defaults to False):
|
|
80
|
+
Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
|
|
81
|
+
1, 3, 5, ... -> CFG step.
|
|
82
|
+
cfg_diff_compute_separate (`bool`, *required*, defaults to True):
|
|
83
|
+
Compute separate diff values for CFG and non-CFG step, default True. If False, we will
|
|
84
|
+
use the computed diff from current non-CFG transformer step for current CFG step.
|
|
101
85
|
calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
86
|
+
Config for the calibrator. If calibrator_config is not None, the user wants to use DBCache
|
|
87
|
+
with specific calibrator, such as taylorseer, foca, and so on.
|
|
88
|
+
params_modifiers ('ParamsModifier', *optional*, defaults to None):
|
|
89
|
+
Modify cache context params for specific blocks. The configurable params are listed below:
|
|
90
|
+
cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
|
|
91
|
+
The same as 'cache_config' param in cache_dit.enable_cache() interface.
|
|
92
|
+
calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
|
|
93
|
+
The same as 'calibrator_config' param in cache_dit.enable_cache() interface.
|
|
94
|
+
**kwargs: (`dict`, *optional*, defaults to {}):
|
|
95
|
+
The same as 'kwargs' param in cache_dit.enable_cache() interface.
|
|
96
|
+
kwargs (`dict`, *optional*, defaults to {})
|
|
105
97
|
Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
|
|
106
98
|
for more details.
|
|
107
99
|
|
|
@@ -116,51 +108,82 @@ def enable_cache(
|
|
|
116
108
|
>>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
|
|
117
109
|
"""
|
|
118
110
|
# Collect cache context kwargs
|
|
119
|
-
cache_context_kwargs =
|
|
120
|
-
if (cache_type := cache_context_kwargs.
|
|
111
|
+
cache_context_kwargs = {}
|
|
112
|
+
if (cache_type := cache_context_kwargs.pop("cache_type", None)) is not None:
|
|
121
113
|
if cache_type == CacheType.NONE:
|
|
122
114
|
return pipe_or_adapter
|
|
123
115
|
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
116
|
+
# WARNING: Deprecated cache config params. These parameters are now retained
|
|
117
|
+
# for backward compatibility but will be removed in the future.
|
|
118
|
+
deprecated_cache_kwargs = {
|
|
119
|
+
"Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
|
|
120
|
+
"Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
|
|
121
|
+
"max_warmup_steps": kwargs.get("max_warmup_steps", None),
|
|
122
|
+
"max_cached_steps": kwargs.get("max_cached_steps", None),
|
|
123
|
+
"max_continuous_cached_steps": kwargs.get(
|
|
124
|
+
"max_continuous_cached_steps", None
|
|
125
|
+
),
|
|
126
|
+
"residual_diff_threshold": kwargs.get("residual_diff_threshold", None),
|
|
127
|
+
"enable_separate_cfg": kwargs.get("enable_separate_cfg", None),
|
|
128
|
+
"cfg_compute_first": kwargs.get("cfg_compute_first", None),
|
|
129
|
+
"cfg_diff_compute_separate": kwargs.get(
|
|
130
|
+
"cfg_diff_compute_separate", None
|
|
131
|
+
),
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
deprecated_cache_kwargs = {
|
|
135
|
+
k: v for k, v in deprecated_cache_kwargs.items() if v is not None
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
if deprecated_cache_kwargs:
|
|
139
|
+
logger.warning(
|
|
140
|
+
"Manually settup DBCache context without BasicCacheConfig is "
|
|
141
|
+
"deprecated and will be removed in the future, please use "
|
|
142
|
+
"`cache_config` parameter instead!"
|
|
143
|
+
)
|
|
144
|
+
if cache_config is not None:
|
|
145
|
+
cache_config.update(**deprecated_cache_kwargs)
|
|
146
|
+
else:
|
|
147
|
+
cache_config = BasicCacheConfig(**deprecated_cache_kwargs)
|
|
148
|
+
|
|
149
|
+
if cache_config is not None:
|
|
150
|
+
cache_context_kwargs["cache_config"] = cache_config
|
|
151
|
+
|
|
152
|
+
# WARNING: Deprecated taylorseer params. These parameters are now retained
|
|
153
|
+
# for backward compatibility but will be removed in the future.
|
|
154
|
+
if (
|
|
155
|
+
kwargs.get("enable_taylorseer", None) is not None
|
|
156
|
+
or kwargs.get("enable_encoder_taylorseer", None) is not None
|
|
157
|
+
):
|
|
158
|
+
logger.warning(
|
|
159
|
+
"Manually settup TaylorSeer calibrator without TaylorSeerCalibratorConfig is "
|
|
160
|
+
"deprecated and will be removed in the future, please use "
|
|
161
|
+
"`calibrator_config` parameter instead!"
|
|
162
|
+
)
|
|
163
|
+
from cache_dit.cache_factory.cache_contexts.calibrators import (
|
|
164
|
+
TaylorSeerCalibratorConfig,
|
|
165
|
+
)
|
|
138
166
|
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
167
|
+
calibrator_config = TaylorSeerCalibratorConfig(
|
|
168
|
+
enable_calibrator=kwargs.get("enable_taylorseer"),
|
|
169
|
+
enable_encoder_calibrator=kwargs.get("enable_encoder_taylorseer"),
|
|
170
|
+
calibrator_cache_type=kwargs.get(
|
|
171
|
+
"taylorseer_cache_type", "residual"
|
|
172
|
+
),
|
|
173
|
+
taylorseer_order=kwargs.get("taylorseer_order", 1),
|
|
145
174
|
)
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
else:
|
|
175
|
+
|
|
176
|
+
if calibrator_config is not None:
|
|
149
177
|
cache_context_kwargs["calibrator_config"] = calibrator_config
|
|
150
178
|
|
|
179
|
+
if params_modifiers is not None:
|
|
180
|
+
cache_context_kwargs["params_modifiers"] = params_modifiers
|
|
181
|
+
|
|
151
182
|
if isinstance(pipe_or_adapter, (DiffusionPipeline, BlockAdapter)):
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
)
|
|
157
|
-
else:
|
|
158
|
-
logger.warning("You are using the un-stable V2 API!")
|
|
159
|
-
pipe_or_adapter._is_v2_api = True
|
|
160
|
-
return CachedAdapterV2.apply(
|
|
161
|
-
pipe_or_adapter,
|
|
162
|
-
**cache_context_kwargs,
|
|
163
|
-
)
|
|
183
|
+
return CachedAdapter.apply(
|
|
184
|
+
pipe_or_adapter,
|
|
185
|
+
**cache_context_kwargs,
|
|
186
|
+
)
|
|
164
187
|
else:
|
|
165
188
|
raise ValueError(
|
|
166
189
|
f"type: {type(pipe_or_adapter)} is not valid, "
|
|
@@ -175,13 +198,7 @@ def disable_cache(
|
|
|
175
198
|
BlockAdapter,
|
|
176
199
|
],
|
|
177
200
|
):
|
|
178
|
-
|
|
179
|
-
logger.warning("You are using the un-stable V2 API!")
|
|
180
|
-
CachedAdapterV2.maybe_release_hooks(pipe_or_adapter)
|
|
181
|
-
del pipe_or_adapter._is_v2_api
|
|
182
|
-
else:
|
|
183
|
-
CachedAdapter.maybe_release_hooks(pipe_or_adapter)
|
|
184
|
-
|
|
201
|
+
CachedAdapter.maybe_release_hooks(pipe_or_adapter)
|
|
185
202
|
logger.warning(
|
|
186
203
|
f"Cache Acceleration is disabled for: "
|
|
187
204
|
f"{pipe_or_adapter.__class__.__name__}."
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
|
|
4
|
+
from cache_dit.cache_factory.cache_contexts import CalibratorConfig
|
|
5
|
+
|
|
6
|
+
from cache_dit.logger import init_logger
|
|
7
|
+
|
|
8
|
+
logger = init_logger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ParamsModifier:
|
|
12
|
+
def __init__(
|
|
13
|
+
self,
|
|
14
|
+
# Basic DBCache config: BasicCacheConfig
|
|
15
|
+
cache_config: BasicCacheConfig = None,
|
|
16
|
+
# Calibrator config: TaylorSeerCalibratorConfig, etc.
|
|
17
|
+
calibrator_config: Optional[CalibratorConfig] = None,
|
|
18
|
+
# Other cache context kwargs: Deprecated cache kwargs
|
|
19
|
+
**kwargs,
|
|
20
|
+
):
|
|
21
|
+
self._context_kwargs = {}
|
|
22
|
+
|
|
23
|
+
# WARNING: Deprecated cache config params. These parameters are now retained
|
|
24
|
+
# for backward compatibility but will be removed in the future.
|
|
25
|
+
deprecated_cache_kwargs = {
|
|
26
|
+
"Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
|
|
27
|
+
"Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
|
|
28
|
+
"max_warmup_steps": kwargs.get("max_warmup_steps", None),
|
|
29
|
+
"max_cached_steps": kwargs.get("max_cached_steps", None),
|
|
30
|
+
"max_continuous_cached_steps": kwargs.get(
|
|
31
|
+
"max_continuous_cached_steps", None
|
|
32
|
+
),
|
|
33
|
+
"residual_diff_threshold": kwargs.get(
|
|
34
|
+
"residual_diff_threshold", None
|
|
35
|
+
),
|
|
36
|
+
"enable_separate_cfg": kwargs.get("enable_separate_cfg", None),
|
|
37
|
+
"cfg_compute_first": kwargs.get("cfg_compute_first", None),
|
|
38
|
+
"cfg_diff_compute_separate": kwargs.get(
|
|
39
|
+
"cfg_diff_compute_separate", None
|
|
40
|
+
),
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
deprecated_cache_kwargs = {
|
|
44
|
+
k: v for k, v in deprecated_cache_kwargs.items() if v is not None
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
if deprecated_cache_kwargs:
|
|
48
|
+
logger.warning(
|
|
49
|
+
"Manually settup DBCache context without BasicCacheConfig is "
|
|
50
|
+
"deprecated and will be removed in the future, please use "
|
|
51
|
+
"`cache_config` parameter instead!"
|
|
52
|
+
)
|
|
53
|
+
if cache_config is not None:
|
|
54
|
+
cache_config.update(**deprecated_cache_kwargs)
|
|
55
|
+
else:
|
|
56
|
+
cache_config = BasicCacheConfig(**deprecated_cache_kwargs)
|
|
57
|
+
|
|
58
|
+
if cache_config is not None:
|
|
59
|
+
self._context_kwargs["cache_config"] = cache_config
|
|
60
|
+
# WARNING: Deprecated taylorseer params. These parameters are now retained
|
|
61
|
+
# for backward compatibility but will be removed in the future.
|
|
62
|
+
if (
|
|
63
|
+
kwargs.get("enable_taylorseer", None) is not None
|
|
64
|
+
or kwargs.get("enable_encoder_taylorseer", None) is not None
|
|
65
|
+
):
|
|
66
|
+
logger.warning(
|
|
67
|
+
"Manually settup TaylorSeer calibrator without TaylorSeerCalibratorConfig is "
|
|
68
|
+
"deprecated and will be removed in the future, please use "
|
|
69
|
+
"`calibrator_config` parameter instead!"
|
|
70
|
+
)
|
|
71
|
+
from cache_dit.cache_factory.cache_contexts.calibrators import (
|
|
72
|
+
TaylorSeerCalibratorConfig,
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
calibrator_config = TaylorSeerCalibratorConfig(
|
|
76
|
+
enable_calibrator=kwargs.get("enable_taylorseer"),
|
|
77
|
+
enable_encoder_calibrator=kwargs.get(
|
|
78
|
+
"enable_encoder_taylorseer"
|
|
79
|
+
),
|
|
80
|
+
calibrator_cache_type=kwargs.get(
|
|
81
|
+
"taylorseer_cache_type", "residual"
|
|
82
|
+
),
|
|
83
|
+
taylorseer_order=kwargs.get("taylorseer_order", 1),
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
if calibrator_config is not None:
|
|
87
|
+
self._context_kwargs["calibrator_config"] = calibrator_config
|
cache_dit/metrics/__init__.py
CHANGED
|
@@ -4,10 +4,12 @@ from cache_dit.metrics.metrics import compute_mse
|
|
|
4
4
|
from cache_dit.metrics.metrics import compute_video_psnr
|
|
5
5
|
from cache_dit.metrics.metrics import compute_video_ssim
|
|
6
6
|
from cache_dit.metrics.metrics import compute_video_mse
|
|
7
|
-
from cache_dit.metrics.metrics import entrypoint
|
|
8
7
|
from cache_dit.metrics.fid import FrechetInceptionDistance
|
|
8
|
+
from cache_dit.metrics.fid import compute_fid
|
|
9
|
+
from cache_dit.metrics.fid import compute_video_fid
|
|
9
10
|
from cache_dit.metrics.config import set_metrics_verbose
|
|
10
11
|
from cache_dit.metrics.config import get_metrics_verbose
|
|
12
|
+
from cache_dit.metrics.metrics import entrypoint
|
|
11
13
|
|
|
12
14
|
|
|
13
15
|
def main():
|
cache_dit/utils.py
CHANGED
|
@@ -9,7 +9,9 @@ from pprint import pprint
|
|
|
9
9
|
from diffusers import DiffusionPipeline
|
|
10
10
|
|
|
11
11
|
from typing import Dict, Any, List, Union
|
|
12
|
+
from cache_dit.cache_factory import CacheType
|
|
12
13
|
from cache_dit.cache_factory import BlockAdapter
|
|
14
|
+
from cache_dit.cache_factory import BasicCacheConfig
|
|
13
15
|
from cache_dit.cache_factory import CalibratorConfig
|
|
14
16
|
from cache_dit.logger import init_logger
|
|
15
17
|
|
|
@@ -162,7 +164,6 @@ def strify(
|
|
|
162
164
|
cache_options = stats.cache_options
|
|
163
165
|
cached_steps = len(stats.cached_steps)
|
|
164
166
|
elif isinstance(adapter_or_others, dict):
|
|
165
|
-
from cache_dit.cache_factory import CacheType
|
|
166
167
|
|
|
167
168
|
# Assume cache_context_kwargs
|
|
168
169
|
cache_options = adapter_or_others
|
|
@@ -180,31 +181,21 @@ def strify(
|
|
|
180
181
|
if not cache_options:
|
|
181
182
|
return "NONE"
|
|
182
183
|
|
|
183
|
-
def
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
return (
|
|
189
|
-
f"T{int(cache_options.get('enable_taylorseer', False))}"
|
|
190
|
-
f"O{taylorseer_order}"
|
|
191
|
-
)
|
|
184
|
+
def basic_cache_str():
|
|
185
|
+
cache_config: BasicCacheConfig = cache_options.get("cache_config", None)
|
|
186
|
+
if cache_config is not None:
|
|
187
|
+
return cache_config.strify()
|
|
188
|
+
return "NONE"
|
|
192
189
|
|
|
190
|
+
def calibrator_str():
|
|
193
191
|
calibrator_config: CalibratorConfig = cache_options.get(
|
|
194
192
|
"calibrator_config", None
|
|
195
193
|
)
|
|
194
|
+
if calibrator_config is not None:
|
|
195
|
+
return calibrator_config.strify()
|
|
196
|
+
return "T0O0"
|
|
196
197
|
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
cache_type_str = (
|
|
200
|
-
f"DBCACHE_F{cache_options.get('Fn_compute_blocks', 1)}"
|
|
201
|
-
f"B{cache_options.get('Bn_compute_blocks', 0)}_"
|
|
202
|
-
f"W{cache_options.get('max_warmup_steps', 0)}"
|
|
203
|
-
f"M{max(0, cache_options.get('max_cached_steps', -1))}"
|
|
204
|
-
f"MC{max(0, cache_options.get('max_continuous_cached_steps', -1))}_"
|
|
205
|
-
f"{calibrator_str()}_"
|
|
206
|
-
f"R{cache_options.get('residual_diff_threshold', 0.08)}"
|
|
207
|
-
)
|
|
198
|
+
cache_type_str = f"{basic_cache_str()}_{calibrator_str()}"
|
|
208
199
|
|
|
209
200
|
if cached_steps:
|
|
210
201
|
cache_type_str += f"_S{cached_steps}"
|