cache-dit 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.


Files changed (34)
  1. cache_dit/__init__.py +1 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +3 -6
  4. cache_dit/cache_factory/block_adapters/block_adapters.py +21 -64
  5. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  6. cache_dit/cache_factory/cache_adapters/cache_adapter.py +82 -21
  7. cache_dit/cache_factory/cache_blocks/__init__.py +4 -0
  8. cache_dit/cache_factory/cache_blocks/offload_utils.py +115 -0
  9. cache_dit/cache_factory/cache_blocks/pattern_base.py +3 -0
  10. cache_dit/cache_factory/cache_contexts/__init__.py +10 -8
  11. cache_dit/cache_factory/cache_contexts/cache_context.py +186 -117
  12. cache_dit/cache_factory/cache_contexts/cache_manager.py +63 -131
  13. cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +132 -0
  14. cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/foca.py +1 -1
  15. cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/taylorseer.py +7 -2
  16. cache_dit/cache_factory/cache_interface.py +128 -111
  17. cache_dit/cache_factory/params_modifier.py +87 -0
  18. cache_dit/metrics/__init__.py +3 -1
  19. cache_dit/utils.py +12 -21
  20. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/METADATA +200 -434
  21. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/RECORD +27 -31
  22. cache_dit/cache_factory/cache_adapters/v2/__init__.py +0 -3
  23. cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +0 -524
  24. cache_dit/cache_factory/cache_contexts/taylorseer.py +0 -102
  25. cache_dit/cache_factory/cache_contexts/v2/__init__.py +0 -13
  26. cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +0 -288
  27. cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +0 -799
  28. cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +0 -81
  29. /cache_dit/cache_factory/cache_blocks/{utils.py → pattern_utils.py} +0 -0
  30. /cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/base.py +0 -0
  31. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/WHEEL +0 -0
  32. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/entry_points.txt +0 -0
  33. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/licenses/LICENSE +0 -0
  34. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/cache_contexts/cache_context.py

@@ -5,41 +5,120 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple
 
 import torch
 
-from cache_dit.cache_factory.cache_contexts.taylorseer import TaylorSeer
+from cache_dit.cache_factory.cache_contexts.calibrators import (
+    Calibrator,
+    CalibratorBase,
+    CalibratorConfig,
+)
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
 @dataclasses.dataclass
-class CachedContext:  # Internal CachedContext Impl class
-    name: str = "default"
-    # Dual Block Cache with flexible FnBn configuration.
-    Fn_compute_blocks: int = 1
+class BasicCacheConfig:
+    # Dual Block Cache with Flexible FnBn configuration.
+
+    # Fn_compute_blocks: (`int`, *required*, defaults to 8):
+    #     Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
+    #     at time step t, enabling the calculation of a more stable L1 diff and delivering more
+    #     accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
+    #     for more details of DBCache.
+    Fn_compute_blocks: int = 8
+    # Bn_compute_blocks: (`int`, *required*, defaults to 0):
+    #     Further fuses approximate information in the **last n** Transformer blocks to enhance
+    #     prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
+    #     that use residual cache.
     Bn_compute_blocks: int = 0
-    # non compute blocks diff threshold, we don't skip the non
-    # compute blocks if the diff >= threshold
-    non_compute_blocks_diff_threshold: float = 0.08
-    max_Fn_compute_blocks: int = -1
-    max_Bn_compute_blocks: int = -1
-    # L1 hidden states or residual diff threshold for Fn
-    residual_diff_threshold: Union[torch.Tensor, float] = 0.05
+    # residual_diff_threshold (`float`, *required*, defaults to 0.08):
+    #     the value of residual diff threshold, a higher value leads to faster performance at the
+    #     cost of lower precision.
+    residual_diff_threshold: Union[torch.Tensor, float] = 0.08
+    # max_warmup_steps (`int`, *required*, defaults to 8):
+    #     DBCache does not apply the caching strategy when the number of running steps is less than
+    #     or equal to this value, ensuring the model sufficiently learns basic features during warmup.
+    max_warmup_steps: int = 8  # DON'T Cache in warmup steps
+    # max_cached_steps (`int`, *required*, defaults to -1):
+    #     DBCache disables the caching strategy when the previous cached steps exceed this value to
+    #     prevent precision degradation.
+    max_cached_steps: int = -1  # for both CFG and non-CFG
+    # max_continuous_cached_steps (`int`, *required*, defaults to -1):
+    #     DBCache disables the caching strategy when the previous continous cached steps exceed this value to
+    #     prevent precision degradation.
+    max_continuous_cached_steps: int = -1  # the max continuous cached steps
+    # enable_separate_cfg (`bool`, *required*, defaults to None):
+    #     Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
+    #     and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
+    #     CogVideoX, HunyuanVideo, Mochi, etc.
+    enable_separate_cfg: Optional[bool] = None
+    # cfg_compute_first (`bool`, *required*, defaults to False):
+    #     Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
+    #     1, 3, 5, ... -> CFG step.
+    cfg_compute_first: bool = False
+    # cfg_diff_compute_separate (`bool`, *required*, defaults to True):
+    #     Compute separate diff values for CFG and non-CFG step, default True. If False, we will
+    #     use the computed diff from current non-CFG transformer step for current CFG step.
+    cfg_diff_compute_separate: bool = True
+
+    def update(self, **kwargs) -> "BasicCacheConfig":
+        for key, value in kwargs.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+        return self
+
+    def strify(self) -> str:
+        return (
+            f"DBCACHE_F{self.Fn_compute_blocks}"
+            f"B{self.Bn_compute_blocks}_"
+            f"W{self.max_warmup_steps}"
+            f"M{max(0, self.max_cached_steps)}"
+            f"MC{max(0, self.max_continuous_cached_steps)}_"
+            f"R{self.residual_diff_threshold}"
+        )
+
+
+@dataclasses.dataclass
+class ExtraCacheConfig:
+    # Some other not very important settings for Dual Block Cache.
+    # NOTE: These flags maybe deprecated in the future and users
+    # should never use these extra configurations in their cases.
+
+    # l1_hidden_states_diff_threshold (`float`, *optional*, defaults to None):
+    #     The hidden states diff threshold for DBCache if use hidden_states as
+    #     cache (not residual).
     l1_hidden_states_diff_threshold: float = None
+    # important_condition_threshold (`float`, *optional*, defaults to 0.0):
+    #     Only select the most important tokens while calculating the l1 diff.
     important_condition_threshold: float = 0.0
+    # downsample_factor (`int`, *optional*, defaults to 1):
+    #     Downsample factor for Fn buffer, in order the save GPU memory.
+    downsample_factor: int = 1
+    # num_inference_steps (`int`, *optional*, defaults to -1):
+    #     num_inference_steps for DiffusionPipeline, for future use.
+    num_inference_steps: int = -1
 
+
+@dataclasses.dataclass
+class CachedContext:
+    name: str = "default"
     # Buffer for storing the residuals and other tensors
     buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
-    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
-        default_factory=lambda: defaultdict(int),
+    # Basic Dual Block Cache Config
+    cache_config: BasicCacheConfig = dataclasses.field(
+        default_factory=BasicCacheConfig,
     )
+    # NOTE: Users should never use these extra configurations.
+    extra_cache_config: ExtraCacheConfig = dataclasses.field(
+        default_factory=ExtraCacheConfig,
+    )
+    # Calibrator config for Dual Block Cache: TaylorSeer, FoCa, etc.
+    calibrator_config: Optional[CalibratorConfig] = None
 
-    # Other settings
-    downsample_factor: int = 1
-    num_inference_steps: int = -1  # for future use
-    max_warmup_steps: int = 0  # DON'T Cache in warmup steps
-    # DON'T Cache if the number of cached steps >= max_cached_steps
-    max_cached_steps: int = -1  # for both CFG and non-CFG
-    max_continuous_cached_steps: int = -1  # the max continuous cached steps
+    # Calibrators for both CFG and non-CFG
+    calibrator: Optional[CalibratorBase] = None
+    encoder_calibrator: Optional[CalibratorBase] = None
+    cfg_calibrator: Optional[CalibratorBase] = None
+    cfg_encoder_calibrator: Optional[CalibratorBase] = None
 
     # Record the steps that have been cached, both cached and non-cache
     executed_steps: int = 0  # cache + non-cache steps pippeline
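
The new `BasicCacheConfig.update()` helper only overwrites attributes that already exist, silently ignoring unknown keys, and `strify()` folds the effective FnBn/warmup/threshold settings into a compact tag. A minimal usage sketch, assuming the defaults shown in this hunk and that the hunk belongs to cache_context.py (file 11 above):

    # Sketch only; import path inferred from file 11 in the list above.
    from cache_dit.cache_factory.cache_contexts.cache_context import BasicCacheConfig

    cfg = BasicCacheConfig()
    assert cfg.strify() == "DBCACHE_F8B0_W8M0MC0_R0.08"

    # Known keys are overwritten in place; unknown keys are dropped silently.
    cfg.update(Fn_compute_blocks=4, residual_diff_threshold=0.12, not_a_field=1)
    assert cfg.strify() == "DBCACHE_F4B0_W8M0MC0_R0.12"
    assert not hasattr(cfg, "not_a_field")
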
@@ -47,32 +126,6 @@ class CachedContext:  # Internal CachedContext Impl class
     # be double of executed_steps.
     transformer_executed_steps: int = 0
 
-    # Support TaylorSeers in Dual Block Cache
-    # Title: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
-    # Url: https://arxiv.org/pdf/2503.06923
-    enable_taylorseer: bool = False
-    enable_encoder_taylorseer: bool = False
-    taylorseer_cache_type: str = "hidden_states"  # residual or hidden_states
-    taylorseer_order: int = 1  # The order for TaylorSeer
-    taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
-    taylorseer: Optional[TaylorSeer] = None
-    encoder_tarlorseer: Optional[TaylorSeer] = None
-
-    # Support enable_separate_cfg, such as Wan 2.1,
-    # Qwen-Image. For model that fused CFG and non-CFG into single
-    # forward step, should set enable_separate_cfg as False.
-    # For example: CogVideoX, HunyuanVideo, Mochi.
-    enable_separate_cfg: bool = False
-    # Compute cfg forward first or not, default False, namely,
-    # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
-    cfg_compute_first: bool = False
-    # Compute separate diff values for CFG and non-CFG step,
-    # default True. If False, we will use the computed diff from
-    # current non-CFG transformer step for current CFG step.
-    cfg_diff_compute_separate: bool = True
-    cfg_taylorseer: Optional[TaylorSeer] = None
-    cfg_encoder_taylorseer: Optional[TaylorSeer] = None
-
     # CFG & non-CFG cached steps
     cached_steps: List[int] = dataclasses.field(default_factory=list)
     residual_diffs: DefaultDict[str, float] = dataclasses.field(
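
The per-field TaylorSeer switches deleted here are subsumed by the single `calibrator_config` object introduced in the previous hunk. A rough before/after migration sketch; the `CalibratorConfig` keyword names below are assumptions inferred from the attribute reads later in this diff (`enable_calibrator`, `enable_encoder_calibrator`, `calibrator_cache_type`), with the real definition living in calibrators/__init__.py (file 13 above):

    from cache_dit.cache_factory.cache_contexts.calibrators import CalibratorConfig

    # Hypothetical 0.3.1 -> 0.3.3 mapping, not a confirmed constructor signature:
    calibrator_config = CalibratorConfig(
        enable_calibrator=True,            # was: enable_taylorseer=True
        enable_encoder_calibrator=True,    # was: enable_encoder_taylorseer=True
        calibrator_cache_type="residual",  # was: taylorseer_cache_type="..."
    )
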
@@ -87,42 +140,58 @@ class CachedContext:  # Internal CachedContext Impl class
 
     def __post_init__(self):
         if logger.isEnabledFor(logging.DEBUG):
-            logger.info(f"Created _CacheContext: {self.name}")
+            logger.info(f"Created CachedContext: {self.name}")
         # Some checks for settings
-        if self.enable_separate_cfg:
-            if self.cfg_diff_compute_separate:
-                assert self.cfg_compute_first is False, (
+        if self.cache_config.enable_separate_cfg:
+            if self.cache_config.cfg_diff_compute_separate:
+                assert self.cache_config.cfg_compute_first is False, (
                     "cfg_compute_first must set as False if "
                     "cfg_diff_compute_separate is enabled."
                 )
 
-        if "max_warmup_steps" not in self.taylorseer_kwargs:
-            # If max_warmup_steps is not set in taylorseer_kwargs,
-            # set the same as max_warmup_steps for DBCache
-            self.taylorseer_kwargs["max_warmup_steps"] = (
-                self.max_warmup_steps if self.max_warmup_steps > 0 else 1
-            )
-
-        # Overwrite the 'n_derivatives' by 'taylorseer_order', default: 2.
-        self.taylorseer_kwargs["n_derivatives"] = self.taylorseer_order
-
-        if self.enable_taylorseer:
-            self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
-            if self.enable_separate_cfg:
-                self.cfg_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
+        if self.calibrator_config is not None:
+            if self.calibrator_config.enable_calibrator:
+                self.calibrator = Calibrator(self.calibrator_config)
+                if self.cache_config.enable_separate_cfg:
+                    self.cfg_calibrator = Calibrator(self.calibrator_config)
+
+            if self.calibrator_config.enable_encoder_calibrator:
+                self.encoder_calibrator = Calibrator(self.calibrator_config)
+                if self.cache_config.enable_separate_cfg:
+                    self.cfg_encoder_calibrator = Calibrator(
+                        self.calibrator_config
+                    )
 
-        if self.enable_encoder_taylorseer:
-            self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
-            if self.enable_separate_cfg:
-                self.cfg_encoder_taylorseer = TaylorSeer(
-                    **self.taylorseer_kwargs
-                )
+    def enable_calibrator(self):
+        if self.calibrator_config is not None:
+            return self.calibrator_config.enable_calibrator
+        return False
+
+    def enable_encoder_calibrator(self):
+        if self.calibrator_config is not None:
+            return self.calibrator_config.enable_encoder_calibrator
+        return False
+
+    def calibrator_cache_type(self):
+        if self.calibrator_config is not None:
+            return self.calibrator_config.calibrator_cache_type
+        return "residual"
+
+    def has_calibrators(self) -> bool:
+        if self.calibrator_config is not None:
+            return (
+                self.calibrator_config.enable_calibrator
+                or self.calibrator_config.enable_encoder_calibrator
+            )
+        return False
 
     def get_residual_diff_threshold(self):
-        residual_diff_threshold = self.residual_diff_threshold
-        if self.l1_hidden_states_diff_threshold is not None:
+        residual_diff_threshold = self.cache_config.residual_diff_threshold
+        if self.extra_cache_config.l1_hidden_states_diff_threshold is not None:
             # Use the L1 hidden states diff threshold if set
-            residual_diff_threshold = self.l1_hidden_states_diff_threshold
+            residual_diff_threshold = (
+                self.extra_cache_config.l1_hidden_states_diff_threshold
+            )
         if isinstance(residual_diff_threshold, torch.Tensor):
             residual_diff_threshold = residual_diff_threshold.item()
         return residual_diff_threshold
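
`get_residual_diff_threshold()` resolves the effective threshold with a fixed precedence: the deprecated `ExtraCacheConfig.l1_hidden_states_diff_threshold` overrides `BasicCacheConfig.residual_diff_threshold`, and tensor-valued thresholds are collapsed to Python floats. A standalone sketch of that precedence:

    import torch

    def resolve_threshold(basic_threshold, l1_override=None):
        # Mirrors get_residual_diff_threshold() above.
        threshold = basic_threshold
        if l1_override is not None:
            threshold = l1_override  # extra-config L1 override wins
        if isinstance(threshold, torch.Tensor):
            threshold = threshold.item()  # tensor -> plain float
        return threshold

    assert resolve_threshold(0.08, l1_override=0.05) == 0.05
    assert resolve_threshold(torch.tensor(0.25)) == 0.25  # 0.25 is exact in fp32
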
@@ -145,11 +214,11 @@ class CachedContext:  # Internal CachedContext Impl class
         # incr step: prev 0 -> 1; prev 1 -> 2
         # current step: incr step - 1
         self.transformer_executed_steps += 1
-        if not self.enable_separate_cfg:
+        if not self.cache_config.enable_separate_cfg:
             self.executed_steps += 1
         else:
             # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
-            if not self.cfg_compute_first:
+            if not self.cache_config.cfg_compute_first:
                 if not self.is_separate_cfg_step():
                     # transformer step: 0,2,4,...
                     self.executed_steps += 1
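
With `enable_separate_cfg` on, each pipeline step runs the transformer twice, so `executed_steps` advances at half the rate of `transformer_executed_steps`. A standalone sketch of the default parity (`cfg_compute_first=False`), mirroring the counting in this hunk and `is_separate_cfg_step()` in the last hunk below:

    def simulate(num_pipeline_steps, cfg_compute_first=False):
        # Returns (transformer_step, pipeline_step, is_cfg) tuples for the
        # separate-CFG case, mirroring mark_step_begin() above.
        rows, executed_steps = [], 0
        for t in range(2 * num_pipeline_steps):
            is_cfg = (t % 2 == 0) if cfg_compute_first else (t % 2 != 0)
            if not is_cfg:
                executed_steps += 1
            rows.append((t, executed_steps - 1, is_cfg))
        return rows

    # Transformer steps 0 and 1 both belong to pipeline step 0, and so on.
    assert simulate(2) == [(0, 0, False), (1, 0, True), (2, 1, False), (3, 1, True)]
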
@@ -165,52 +234,52 @@ class CachedContext:  # Internal CachedContext Impl class
         self.residual_diffs.clear()
         self.cfg_cached_steps.clear()
         self.cfg_residual_diffs.clear()
-        # Reset the TaylorSeers cache at the beginning of each inference.
-        # reset_cache will set the current step to -1 for TaylorSeer,
-        if self.enable_taylorseer or self.enable_encoder_taylorseer:
-            taylorseer, encoder_taylorseer = self.get_taylorseers()
-            if taylorseer is not None:
-                taylorseer.reset_cache()
-            if encoder_taylorseer is not None:
-                encoder_taylorseer.reset_cache()
-            cfg_taylorseer, cfg_encoder_taylorseer = (
-                self.get_cfg_taylorseers()
+        # Reset the calibrators cache at the beginning of each inference.
+        # reset_cache will set the current step to -1 for calibrator,
+        if self.has_calibrators():
+            calibrator, encoder_calibrator = self.get_calibrators()
+            if calibrator is not None:
+                calibrator.reset_cache()
+            if encoder_calibrator is not None:
+                encoder_calibrator.reset_cache()
+            cfg_calibrator, cfg_encoder_calibrator = (
+                self.get_cfg_calibrators()
             )
-            if cfg_taylorseer is not None:
-                cfg_taylorseer.reset_cache()
-            if cfg_encoder_taylorseer is not None:
-                cfg_encoder_taylorseer.reset_cache()
-
-        # mark_step_begin of TaylorSeer must be called after the cache is reset.
-        if self.enable_taylorseer or self.enable_encoder_taylorseer:
-            if self.enable_separate_cfg:
+            if cfg_calibrator is not None:
+                cfg_calibrator.reset_cache()
+            if cfg_encoder_calibrator is not None:
+                cfg_encoder_calibrator.reset_cache()
+
+        # mark_step_begin of calibrator must be called after the cache is reset.
+        if self.has_calibrators():
+            if self.cache_config.enable_separate_cfg:
                 # Assume non-CFG steps: 0, 2, 4, 6, ...
                 if not self.is_separate_cfg_step():
-                    taylorseer, encoder_taylorseer = self.get_taylorseers()
-                    if taylorseer is not None:
-                        taylorseer.mark_step_begin()
-                    if encoder_taylorseer is not None:
-                        encoder_taylorseer.mark_step_begin()
+                    calibrator, encoder_calibrator = self.get_calibrators()
+                    if calibrator is not None:
+                        calibrator.mark_step_begin()
+                    if encoder_calibrator is not None:
+                        encoder_calibrator.mark_step_begin()
                 else:
-                    cfg_taylorseer, cfg_encoder_taylorseer = (
-                        self.get_cfg_taylorseers()
+                    cfg_calibrator, cfg_encoder_calibrator = (
+                        self.get_cfg_calibrators()
                     )
-                    if cfg_taylorseer is not None:
-                        cfg_taylorseer.mark_step_begin()
-                    if cfg_encoder_taylorseer is not None:
-                        cfg_encoder_taylorseer.mark_step_begin()
+                    if cfg_calibrator is not None:
+                        cfg_calibrator.mark_step_begin()
+                    if cfg_encoder_calibrator is not None:
+                        cfg_encoder_calibrator.mark_step_begin()
             else:
-                taylorseer, encoder_taylorseer = self.get_taylorseers()
-                if taylorseer is not None:
-                    taylorseer.mark_step_begin()
-                if encoder_taylorseer is not None:
-                    encoder_taylorseer.mark_step_begin()
+                calibrator, encoder_calibrator = self.get_calibrators()
+                if calibrator is not None:
+                    calibrator.mark_step_begin()
+                if encoder_calibrator is not None:
+                    encoder_calibrator.mark_step_begin()
 
-    def get_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
-        return self.taylorseer, self.encoder_tarlorseer
+    def get_calibrators(self) -> Tuple[CalibratorBase, CalibratorBase]:
+        return self.calibrator, self.encoder_calibrator
 
-    def get_cfg_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
-        return self.cfg_taylorseer, self.cfg_encoder_taylorseer
+    def get_cfg_calibrators(self) -> Tuple[CalibratorBase, CalibratorBase]:
+        return self.cfg_calibrator, self.cfg_encoder_calibrator
 
     def add_residual_diff(self, diff):
         # step: executed_steps - 1, not transformer_steps - 1
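
The ordering constraint in this hunk (reset the cache first, then call `mark_step_begin`) is what moves a calibrator from its sentinel step -1 to step 0 at the start of each inference. A minimal sketch with a stub calibrator, assuming only the two methods this diff actually calls:

    class StubCalibrator:
        # Stand-in for CalibratorBase, reduced to the calls used above.
        def __init__(self):
            self.step = -1

        def reset_cache(self):
            self.step = -1  # per the comment above: reset sets the step to -1

        def mark_step_begin(self):
            self.step += 1

    calibrator = StubCalibrator()
    calibrator.reset_cache()      # beginning of a new inference
    calibrator.mark_step_begin()  # must follow the reset: now at step 0
    assert calibrator.step == 0
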
@@ -269,13 +338,13 @@ class CachedContext:  # Internal CachedContext Impl class
         return self.transformer_executed_steps - 1
 
     def is_separate_cfg_step(self):
-        if not self.enable_separate_cfg:
+        if not self.cache_config.enable_separate_cfg:
             return False
-        if self.cfg_compute_first:
+        if self.cache_config.cfg_compute_first:
             # CFG steps: 0, 2, 4, 6, ...
             return self.get_current_transformer_step() % 2 == 0
         # CFG steps: 1, 3, 5, 7, ...
         return self.get_current_transformer_step() % 2 != 0
 
     def is_in_warmup(self):
-        return self.get_current_step() < self.max_warmup_steps
+        return self.get_current_step() < self.cache_config.max_warmup_steps
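
With the new `BasicCacheConfig` default of `max_warmup_steps=8`, `is_in_warmup()` keeps caching disabled for pipeline steps 0 through 7, so caching decisions only begin at step 8. Restating the predicate as a standalone check:

    MAX_WARMUP_STEPS = 8  # new BasicCacheConfig default

    def is_in_warmup(current_step: int) -> bool:
        # Mirrors CachedContext.is_in_warmup() above.
        return current_step < MAX_WARMUP_STEPS

    assert is_in_warmup(7) and not is_in_warmup(8)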