cache-dit 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,12 @@
- from typing import Any, Tuple, List, Union
+ from typing import Any, Tuple, List, Union, Optional
  from diffusers import DiffusionPipeline
  from cache_dit.cache_factory.cache_types import CacheType
  from cache_dit.cache_factory.block_adapters import BlockAdapter
  from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
  from cache_dit.cache_factory.cache_adapters import CachedAdapter
+ from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
+ from cache_dit.cache_factory.cache_contexts import CalibratorConfig
+ from cache_dit.cache_factory.params_modifier import ParamsModifier

  from cache_dit.logger import init_logger

@@ -16,23 +19,20 @@ def enable_cache(
          DiffusionPipeline,
          BlockAdapter,
      ],
-     # Cache context kwargs
-     Fn_compute_blocks: int = 8,
-     Bn_compute_blocks: int = 0,
-     max_warmup_steps: int = 8,
-     max_cached_steps: int = -1,
-     max_continuous_cached_steps: int = -1,
-     residual_diff_threshold: float = 0.08,
-     # Cache CFG or not
-     enable_separate_cfg: bool = None,
-     cfg_compute_first: bool = False,
-     cfg_diff_compute_separate: bool = True,
-     # Hybird TaylorSeer
-     enable_taylorseer: bool = False,
-     enable_encoder_taylorseer: bool = False,
-     taylorseer_cache_type: str = "residual",
-     taylorseer_order: int = 1,
-     **other_cache_context_kwargs,
+     # Basic DBCache config: BasicCacheConfig
+     cache_config: BasicCacheConfig = BasicCacheConfig(),
+     # Calibrator config: TaylorSeerCalibratorConfig, etc.
+     calibrator_config: Optional[CalibratorConfig] = None,
+     # Modify cache context params for specific blocks.
+     params_modifiers: Optional[
+         Union[
+             ParamsModifier,
+             List[ParamsModifier],
+             List[List[ParamsModifier]],
+         ]
+     ] = None,
+     # Other cache context kwargs: Deprecated cache kwargs
+     **kwargs,
  ) -> Union[
      DiffusionPipeline,
      BlockAdapter,
@@ -49,52 +49,51 @@ def enable_cache(
              The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
              For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
              for the usage of BlockAdapter.
-         Fn_compute_blocks (`int`, *required*, defaults to 8):
-             Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
-             at time step t, enabling the calculation of a more stable L1 diff and delivering more
-             accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
-             for more details of DBCache.
-         Bn_compute_blocks: (`int`, *required*, defaults to 0):
-             Further fuses approximate information in the **last n** Transformer blocks to enhance
-             prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
-             that use residual cache.
-         max_warmup_steps (`int`, *required*, defaults to 8):
-             DBCache does not apply the caching strategy when the number of running steps is less than
-             or equal to this value, ensuring the model sufficiently learns basic features during warmup.
-         max_cached_steps (`int`, *required*, defaults to -1):
-             DBCache disables the caching strategy when the previous cached steps exceed this value to
-             prevent precision degradation.
-         max_continuous_cached_steps (`int`, *required*, defaults to -1):
-             DBCache disables the caching strategy when the previous continous cached steps exceed this value to
-             prevent precision degradation.
-         residual_diff_threshold (`float`, *required*, defaults to 0.08):
-             he value of residual diff threshold, a higher value leads to faster performance at the
-             cost of lower precision.
-         enable_separate_cfg (`bool`, *required*, defaults to None):
-             Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
-             and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
-             CogVideoX, HunyuanVideo, Mochi, etc.
-         cfg_compute_first (`bool`, *required*, defaults to False):
-             Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
-             1, 3, 5, ... -> CFG step.
-         cfg_diff_compute_separate (`bool`, *required*, defaults to True):
-             Compute separate diff values for CFG and non-CFG step, default True. If False, we will
-             use the computed diff from current non-CFG transformer step for current CFG step.
-         enable_taylorseer (`bool`, *required*, defaults to False):
-             Enable the hybird TaylorSeer for hidden_states or not. We have supported the
-             [TaylorSeers: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers](https://arxiv.org/pdf/2503.06923) algorithm
-             to further improve the precision of DBCache in cases where the cached steps are large,
-             namely, **Hybrid TaylorSeer + DBCache**. At timesteps with significant intervals,
-             the feature similarity in diffusion models decreases substantially, significantly
-             harming the generation quality.
-         enable_encoder_taylorseer (`bool`, *required*, defaults to False):
-             Enable the hybird TaylorSeer for encoder_hidden_states or not.
-         taylorseer_cache_type (`str`, *required*, defaults to `residual`):
-             The TaylorSeer implemented in cache-dit supports both `hidden_states` and `residual` as cache type.
-         taylorseer_order (`int`, *required*, defaults to 1):
-             The order of taylorseer, higher values of n_derivatives will lead to longer computation time,
-             the recommended value is 1 or 2.
-         other_cache_context_kwargs: (`dict`, *optional*, defaults to {})
+         cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
+             Basic DBCache config for the cache context, defaults to BasicCacheConfig(). The configurable params are listed below:
+             Fn_compute_blocks (`int`, *required*, defaults to 8):
+                 Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
+                 at time step t, enabling the calculation of a more stable L1 diff and delivering more
+                 accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
+                 for more details of DBCache.
+             Bn_compute_blocks (`int`, *required*, defaults to 0):
+                 Further fuses approximate information in the **last n** Transformer blocks to enhance
+                 prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
+                 that use the residual cache.
+             residual_diff_threshold (`float`, *required*, defaults to 0.08):
+                 The value of the residual diff threshold; a higher value leads to faster performance at the
+                 cost of lower precision.
+             max_warmup_steps (`int`, *required*, defaults to 8):
+                 DBCache does not apply the caching strategy when the number of running steps is less than
+                 or equal to this value, ensuring the model sufficiently learns basic features during warmup.
+             max_cached_steps (`int`, *required*, defaults to -1):
+                 DBCache disables the caching strategy when the previous cached steps exceed this value to
+                 prevent precision degradation.
+             max_continuous_cached_steps (`int`, *required*, defaults to -1):
+                 DBCache disables the caching strategy when the previous continuous cached steps exceed this value to
+                 prevent precision degradation.
+             enable_separate_cfg (`bool`, *required*, defaults to None):
+                 Whether to compute separate CFG or not, as in Wan 2.1 and Qwen-Image. For models that fuse CFG
+                 and non-CFG into a single forward step, set enable_separate_cfg to False, for example:
+                 CogVideoX, HunyuanVideo, Mochi, etc.
+             cfg_compute_first (`bool`, *required*, defaults to False):
+                 Whether to compute the CFG forward pass first; defaults to False, namely 0, 2, 4, ... -> non-CFG step;
+                 1, 3, 5, ... -> CFG step.
+             cfg_diff_compute_separate (`bool`, *required*, defaults to True):
+                 Compute separate diff values for the CFG and non-CFG steps; defaults to True. If False, the diff
+                 computed for the current non-CFG transformer step is reused for the current CFG step.
+         calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
+             Config for the calibrator. If calibrator_config is not None, it means the user wants to use DBCache
+             with a specific calibrator, such as TaylorSeer, FoCa, and so on.
+         params_modifiers (`ParamsModifier`, *optional*, defaults to None):
+             Modify cache context params for specific blocks. The configurable params are listed below:
+             cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
+                 The same as the 'cache_config' param in the cache_dit.enable_cache() interface.
+             calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
+                 The same as the 'calibrator_config' param in the cache_dit.enable_cache() interface.
+             **kwargs (`dict`, *optional*, defaults to {}):
+                 The same as the 'kwargs' param in the cache_dit.enable_cache() interface.
+         kwargs (`dict`, *optional*, defaults to {}):
              Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
              for more details.

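For orientation, here is a minimal usage sketch of the 0.3.2 interface documented above. The checkpoint name and parameter values are illustrative only; the import paths are taken from this diff (BasicCacheConfig from cache_dit.cache_factory, TaylorSeerCalibratorConfig from cache_dit.cache_factory.cache_contexts.calibrators).

    import cache_dit
    from diffusers import FluxPipeline
    from cache_dit.cache_factory import BasicCacheConfig
    from cache_dit.cache_factory.cache_contexts.calibrators import (
        TaylorSeerCalibratorConfig,
    )

    # Illustrative checkpoint; any supported DiffusionPipeline works.
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

    cache_dit.enable_cache(
        pipe,
        # DBCache knobs now live on a single config object.
        cache_config=BasicCacheConfig(
            Fn_compute_blocks=8,
            Bn_compute_blocks=0,
            max_warmup_steps=8,
            residual_diff_threshold=0.08,
        ),
        # Optional: Hybrid TaylorSeer + DBCache via a calibrator config.
        calibrator_config=TaylorSeerCalibratorConfig(
            enable_calibrator=True,
            taylorseer_order=1,
        ),
    )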
@@ -109,31 +108,76 @@ def enable_cache(
          >>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
      """
      # Collect cache context kwargs
-     cache_context_kwargs = other_cache_context_kwargs.copy()
-     if (cache_type := cache_context_kwargs.get("cache_type", None)) is not None:
+     cache_context_kwargs = {}
+     if (cache_type := kwargs.pop("cache_type", None)) is not None:
          if cache_type == CacheType.NONE:
              return pipe_or_adapter

-     cache_context_kwargs["cache_type"] = CacheType.DBCache
-     cache_context_kwargs["Fn_compute_blocks"] = Fn_compute_blocks
-     cache_context_kwargs["Bn_compute_blocks"] = Bn_compute_blocks
-     cache_context_kwargs["max_warmup_steps"] = max_warmup_steps
-     cache_context_kwargs["max_cached_steps"] = max_cached_steps
-     cache_context_kwargs["max_continuous_cached_steps"] = (
-         max_continuous_cached_steps
-     )
-     cache_context_kwargs["residual_diff_threshold"] = residual_diff_threshold
-     cache_context_kwargs["enable_separate_cfg"] = enable_separate_cfg
-     cache_context_kwargs["cfg_compute_first"] = cfg_compute_first
-     cache_context_kwargs["cfg_diff_compute_separate"] = (
-         cfg_diff_compute_separate
-     )
-     cache_context_kwargs["enable_taylorseer"] = enable_taylorseer
-     cache_context_kwargs["enable_encoder_taylorseer"] = (
-         enable_encoder_taylorseer
-     )
-     cache_context_kwargs["taylorseer_cache_type"] = taylorseer_cache_type
-     cache_context_kwargs["taylorseer_order"] = taylorseer_order
+     # WARNING: Deprecated cache config params. These parameters are now retained
+     # for backward compatibility but will be removed in the future.
+     deprecated_cache_kwargs = {
+         "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
+         "Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
+         "max_warmup_steps": kwargs.get("max_warmup_steps", None),
+         "max_cached_steps": kwargs.get("max_cached_steps", None),
+         "max_continuous_cached_steps": kwargs.get(
+             "max_continuous_cached_steps", None
+         ),
+         "residual_diff_threshold": kwargs.get("residual_diff_threshold", None),
+         "enable_separate_cfg": kwargs.get("enable_separate_cfg", None),
+         "cfg_compute_first": kwargs.get("cfg_compute_first", None),
+         "cfg_diff_compute_separate": kwargs.get(
+             "cfg_diff_compute_separate", None
+         ),
+     }
+
+     deprecated_cache_kwargs = {
+         k: v for k, v in deprecated_cache_kwargs.items() if v is not None
+     }
+
+     if deprecated_cache_kwargs:
+         logger.warning(
+             "Manually setting up the DBCache context without BasicCacheConfig is "
+             "deprecated and will be removed in the future; please use the "
+             "`cache_config` parameter instead!"
+         )
+         if cache_config is not None:
+             cache_config.update(**deprecated_cache_kwargs)
+         else:
+             cache_config = BasicCacheConfig(**deprecated_cache_kwargs)
+
+     if cache_config is not None:
+         cache_context_kwargs["cache_config"] = cache_config
+
+     # WARNING: Deprecated taylorseer params. These parameters are now retained
+     # for backward compatibility but will be removed in the future.
+     if (
+         kwargs.get("enable_taylorseer", None) is not None
+         or kwargs.get("enable_encoder_taylorseer", None) is not None
+     ):
+         logger.warning(
+             "Manually setting up the TaylorSeer calibrator without TaylorSeerCalibratorConfig is "
+             "deprecated and will be removed in the future; please use the "
+             "`calibrator_config` parameter instead!"
+         )
+         from cache_dit.cache_factory.cache_contexts.calibrators import (
+             TaylorSeerCalibratorConfig,
+         )
+
+         calibrator_config = TaylorSeerCalibratorConfig(
+             enable_calibrator=kwargs.get("enable_taylorseer"),
+             enable_encoder_calibrator=kwargs.get("enable_encoder_taylorseer"),
+             calibrator_cache_type=kwargs.get(
+                 "taylorseer_cache_type", "residual"
+             ),
+             taylorseer_order=kwargs.get("taylorseer_order", 1),
+         )
+
+     if calibrator_config is not None:
+         cache_context_kwargs["calibrator_config"] = calibrator_config
+
+     if params_modifiers is not None:
+         cache_context_kwargs["params_modifiers"] = params_modifiers

      if isinstance(pipe_or_adapter, (DiffusionPipeline, BlockAdapter)):
          return CachedAdapter.apply(
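To make the backward-compatibility path above concrete, here is a sketch of a legacy 0.3.0-style call and how the shim folds it into the new config objects (pipe is assumed to be an already constructed DiffusionPipeline):

    import cache_dit

    # Old flat kwargs still work in 0.3.2, but trigger the deprecation
    # warnings above and are converted internally:
    cache_dit.enable_cache(
        pipe,
        Fn_compute_blocks=8,           # folded into BasicCacheConfig(...)
        residual_diff_threshold=0.12,  # folded into BasicCacheConfig(...)
        enable_taylorseer=True,        # becomes TaylorSeerCalibratorConfig(enable_calibrator=True)
        taylorseer_order=2,            # becomes TaylorSeerCalibratorConfig(taylorseer_order=2)
    )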
@@ -0,0 +1,87 @@
+ from typing import Optional
+
+ from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
+ from cache_dit.cache_factory.cache_contexts import CalibratorConfig
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class ParamsModifier:
+     def __init__(
+         self,
+         # Basic DBCache config: BasicCacheConfig
+         cache_config: BasicCacheConfig = None,
+         # Calibrator config: TaylorSeerCalibratorConfig, etc.
+         calibrator_config: Optional[CalibratorConfig] = None,
+         # Other cache context kwargs: Deprecated cache kwargs
+         **kwargs,
+     ):
+         self._context_kwargs = {}
+
+         # WARNING: Deprecated cache config params. These parameters are now retained
+         # for backward compatibility but will be removed in the future.
+         deprecated_cache_kwargs = {
+             "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
+             "Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
+             "max_warmup_steps": kwargs.get("max_warmup_steps", None),
+             "max_cached_steps": kwargs.get("max_cached_steps", None),
+             "max_continuous_cached_steps": kwargs.get(
+                 "max_continuous_cached_steps", None
+             ),
+             "residual_diff_threshold": kwargs.get(
+                 "residual_diff_threshold", None
+             ),
+             "enable_separate_cfg": kwargs.get("enable_separate_cfg", None),
+             "cfg_compute_first": kwargs.get("cfg_compute_first", None),
+             "cfg_diff_compute_separate": kwargs.get(
+                 "cfg_diff_compute_separate", None
+             ),
+         }
+
+         deprecated_cache_kwargs = {
+             k: v for k, v in deprecated_cache_kwargs.items() if v is not None
+         }
+
+         if deprecated_cache_kwargs:
+             logger.warning(
+                 "Manually setting up the DBCache context without BasicCacheConfig is "
+                 "deprecated and will be removed in the future; please use the "
+                 "`cache_config` parameter instead!"
+             )
+             if cache_config is not None:
+                 cache_config.update(**deprecated_cache_kwargs)
+             else:
+                 cache_config = BasicCacheConfig(**deprecated_cache_kwargs)
+
+         if cache_config is not None:
+             self._context_kwargs["cache_config"] = cache_config
+         # WARNING: Deprecated taylorseer params. These parameters are now retained
+         # for backward compatibility but will be removed in the future.
+         if (
+             kwargs.get("enable_taylorseer", None) is not None
+             or kwargs.get("enable_encoder_taylorseer", None) is not None
+         ):
+             logger.warning(
+                 "Manually setting up the TaylorSeer calibrator without TaylorSeerCalibratorConfig is "
+                 "deprecated and will be removed in the future; please use the "
+                 "`calibrator_config` parameter instead!"
+             )
+             from cache_dit.cache_factory.cache_contexts.calibrators import (
+                 TaylorSeerCalibratorConfig,
+             )
+
+             calibrator_config = TaylorSeerCalibratorConfig(
+                 enable_calibrator=kwargs.get("enable_taylorseer"),
+                 enable_encoder_calibrator=kwargs.get(
+                     "enable_encoder_taylorseer"
+                 ),
+                 calibrator_cache_type=kwargs.get(
+                     "taylorseer_cache_type", "residual"
+                 ),
+                 taylorseer_order=kwargs.get("taylorseer_order", 1),
+             )
+
+         if calibrator_config is not None:
+             self._context_kwargs["calibrator_config"] = calibrator_config
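A hedged sketch of how this new ParamsModifier class is meant to be combined with enable_cache(); the per-block override semantics are resolved inside the cache context, so treat the values here as illustrative:

    import cache_dit
    from cache_dit.cache_factory import BasicCacheConfig
    from cache_dit.cache_factory.params_modifier import ParamsModifier

    # Global DBCache settings plus a modifier that loosens the threshold for a
    # specific group of blocks (per the enable_cache() docstring, params_modifiers
    # accepts a single modifier, a list, or a nested list).
    cache_dit.enable_cache(
        pipe_or_adapter,  # assumed DiffusionPipeline or BlockAdapter, defined elsewhere
        cache_config=BasicCacheConfig(residual_diff_threshold=0.08),
        params_modifiers=[
            ParamsModifier(
                cache_config=BasicCacheConfig(residual_diff_threshold=0.12),
            ),
        ],
    )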
@@ -4,10 +4,12 @@ from cache_dit.metrics.metrics import compute_mse
  from cache_dit.metrics.metrics import compute_video_psnr
  from cache_dit.metrics.metrics import compute_video_ssim
  from cache_dit.metrics.metrics import compute_video_mse
- from cache_dit.metrics.metrics import entrypoint
  from cache_dit.metrics.fid import FrechetInceptionDistance
+ from cache_dit.metrics.fid import compute_fid
+ from cache_dit.metrics.fid import compute_video_fid
  from cache_dit.metrics.config import set_metrics_verbose
  from cache_dit.metrics.config import get_metrics_verbose
+ from cache_dit.metrics.metrics import entrypoint


  def main():
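The newly exported FID helpers can be imported from the path shown in this hunk; their call convention is not part of this diff, so the arguments below are an assumption modeled on the other compute_* metrics (reference images vs. generated images):

    from cache_dit.metrics.fid import compute_fid, compute_video_fid

    # Assumed signature: compare a directory of reference images against a
    # directory of generated images; check cache_dit/metrics/fid.py for the
    # actual parameters before relying on this.
    fid = compute_fid("outputs/reference/", "outputs/dbcache/")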
cache_dit/utils.py CHANGED
@@ -9,7 +9,10 @@ from pprint import pprint
  from diffusers import DiffusionPipeline

  from typing import Dict, Any, List, Union
+ from cache_dit.cache_factory import CacheType
  from cache_dit.cache_factory import BlockAdapter
+ from cache_dit.cache_factory import BasicCacheConfig
+ from cache_dit.cache_factory import CalibratorConfig
  from cache_dit.logger import init_logger


@@ -161,7 +164,6 @@ def strify(
          cache_options = stats.cache_options
          cached_steps = len(stats.cached_steps)
      elif isinstance(adapter_or_others, dict):
-         from cache_dit.cache_factory import CacheType

          # Assume cache_context_kwargs
          cache_options = adapter_or_others
@@ -179,22 +181,21 @@ def strify(
      if not cache_options:
          return "NONE"

-     def get_taylorseer_order():
-         taylorseer_order = 0
-         if "taylorseer_order" in cache_options:
-             taylorseer_order = cache_options["taylorseer_order"]
-         return taylorseer_order
-
-     cache_type_str = (
-         f"DBCACHE_F{cache_options.get('Fn_compute_blocks', 1)}"
-         f"B{cache_options.get('Bn_compute_blocks', 0)}_"
-         f"W{cache_options.get('max_warmup_steps', 0)}"
-         f"M{max(0, cache_options.get('max_cached_steps', -1))}"
-         f"MC{max(0, cache_options.get('max_continuous_cached_steps', -1))}_"
-         f"T{int(cache_options.get('enable_taylorseer', False))}"
-         f"O{get_taylorseer_order()}_"
-         f"R{cache_options.get('residual_diff_threshold', 0.08)}"
-     )
+     def basic_cache_str():
+         cache_config: BasicCacheConfig = cache_options.get("cache_config", None)
+         if cache_config is not None:
+             return cache_config.strify()
+         return "NONE"
+
+     def calibrator_str():
+         calibrator_config: CalibratorConfig = cache_options.get(
+             "calibrator_config", None
+         )
+         if calibrator_config is not None:
+             return calibrator_config.strify()
+         return "T0O0"
+
+     cache_type_str = f"{basic_cache_str()}_{calibrator_str()}"

      if cached_steps:
          cache_type_str += f"_S{cached_steps}"
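Finally, a sketch of the reworked strify() naming for the dict branch shown above. The exact strings come from BasicCacheConfig.strify() and CalibratorConfig.strify(), which are not part of this diff, so the commented output is only indicative; "T0O0" is the fallback shown above when no calibrator is configured.

    from cache_dit.cache_factory import BasicCacheConfig
    from cache_dit.utils import strify  # strify() is defined in cache_dit/utils.py

    # Dict input is treated as cache_context_kwargs; naming is delegated to the
    # config objects instead of being rebuilt from flat kwargs.
    label = strify({"cache_config": BasicCacheConfig()})
    # e.g. something like "DBCACHE_F8B0_..._T0O0" (calibrator part falls back to "T0O0")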