cache_dit-0.2.26-py3-none-any.whl → cache_dit-0.2.28-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (29)
  1. cache_dit/__init__.py +8 -6
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +17 -4
  4. cache_dit/cache_factory/block_adapters/__init__.py +555 -0
  5. cache_dit/cache_factory/block_adapters/block_adapters.py +538 -0
  6. cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
  7. cache_dit/cache_factory/cache_adapters.py +262 -938
  8. cache_dit/cache_factory/cache_blocks/__init__.py +60 -11
  9. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
  10. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +45 -41
  11. cache_dit/cache_factory/cache_blocks/pattern_base.py +106 -80
  12. cache_dit/cache_factory/cache_blocks/utils.py +16 -10
  13. cache_dit/cache_factory/cache_contexts/__init__.py +5 -0
  14. cache_dit/cache_factory/cache_contexts/cache_context.py +327 -0
  15. cache_dit/cache_factory/cache_contexts/cache_manager.py +833 -0
  16. cache_dit/cache_factory/cache_interface.py +31 -31
  17. cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
  18. cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
  19. cache_dit/quantize/quantize_ao.py +1 -0
  20. cache_dit/utils.py +26 -26
  21. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/METADATA +59 -23
  22. cache_dit-0.2.28.dist-info/RECORD +47 -0
  23. cache_dit/cache_factory/cache_context.py +0 -1155
  24. cache_dit-0.2.26.dist-info/RECORD +0 -42
  25. cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
  26. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/WHEEL +0 -0
  27. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/entry_points.txt +0 -0
  28. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/licenses/LICENSE +0 -0
  29. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/cache_contexts/cache_context.py (new file)
@@ -0,0 +1,327 @@
+import logging
+import dataclasses
+from collections import defaultdict
+from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple
+
+import torch
+
+from cache_dit.cache_factory.cache_contexts.taylorseer import TaylorSeer
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class CachedContext:  # Internal CachedContext Impl class
+    name: str = "default"
+    # Dual Block Cache
+    # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
+    Fn_compute_blocks: int = 1
+    Bn_compute_blocks: int = 0
+    # We have added residual cache pattern for selected compute blocks
+    Fn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
+    Bn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
+    # non compute blocks diff threshold, we don't skip the non
+    # compute blocks if the diff >= threshold
+    non_compute_blocks_diff_threshold: float = 0.08
+    max_Fn_compute_blocks: int = -1
+    max_Bn_compute_blocks: int = -1
+    # L1 hidden states or residual diff threshold for Fn
+    residual_diff_threshold: Union[torch.Tensor, float] = 0.05
+    l1_hidden_states_diff_threshold: float = None
+    important_condition_threshold: float = 0.0
+
+    # Alter Cache Settings
+    # Pattern: 0 F 1 T 2 F 3 T 4 F 5 T ...
+    enable_alter_cache: bool = False
+    is_alter_cache: bool = True
+    # 1.0 means we always cache the residuals if alter_cache is enabled.
+    alter_residual_diff_threshold: Optional[Union[torch.Tensor, float]] = 1.0
+
+    # Buffer for storing the residuals and other tensors
+    buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
+    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
+        default_factory=lambda: defaultdict(int),
+    )
+
+    # Other settings
+    downsample_factor: int = 1
+    num_inference_steps: int = -1  # for future use
+    max_warmup_steps: int = 0  # DON'T Cache in warmup steps
+    # DON'T Cache if the number of cached steps >= max_cached_steps
+    max_cached_steps: int = -1  # for both CFG and non-CFG
+    max_continuous_cached_steps: int = -1  # the max continuous cached steps
+
+    # Record the steps that have been cached, both cached and non-cache
+    executed_steps: int = 0  # cache + non-cache steps pippeline
+    # steps for transformer, for CFG, transformer_executed_steps will
+    # be double of executed_steps.
+    transformer_executed_steps: int = 0
+
+    # Support TaylorSeers in Dual Block Cache
+    # Title: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
+    # Url: https://arxiv.org/pdf/2503.06923
+    enable_taylorseer: bool = False
+    enable_encoder_taylorseer: bool = False
+    # NOTE: use residual cache for taylorseer may incur precision loss
+    taylorseer_cache_type: str = "hidden_states"  # residual or hidden_states
+    taylorseer_order: int = 2  # The order for TaylorSeer
+    taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+    taylorseer: Optional[TaylorSeer] = None
+    encoder_tarlorseer: Optional[TaylorSeer] = None
+
+    # Support enable_spearate_cfg, such as Wan 2.1,
+    # Qwen-Image. For model that fused CFG and non-CFG into single
+    # forward step, should set enable_spearate_cfg as False.
+    # For example: CogVideoX, HunyuanVideo, Mochi.
+    enable_spearate_cfg: bool = False
+    # Compute cfg forward first or not, default False, namely,
+    # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
+    cfg_compute_first: bool = False
+    # Compute spearate diff values for CFG and non-CFG step,
+    # default True. If False, we will use the computed diff from
+    # current non-CFG transformer step for current CFG step.
+    cfg_diff_compute_separate: bool = True
+    cfg_taylorseer: Optional[TaylorSeer] = None
+    cfg_encoder_taylorseer: Optional[TaylorSeer] = None
+
+    # CFG & non-CFG cached steps
+    cached_steps: List[int] = dataclasses.field(default_factory=list)
+    residual_diffs: DefaultDict[str, float] = dataclasses.field(
+        default_factory=lambda: defaultdict(float),
+    )
+    continuous_cached_steps: int = 0
+    cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
+    cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
+        default_factory=lambda: defaultdict(float),
+    )
+    cfg_continuous_cached_steps: int = 0
+
+    @torch.compiler.disable
+    def __post_init__(self):
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.info(f"Created _CacheContext: {self.name}")
+        # Some checks for settings
+        if self.enable_spearate_cfg:
+            assert self.enable_alter_cache is False, (
+                "enable_alter_cache must set as False if "
+                "enable_spearate_cfg is enabled."
+            )
+            if self.cfg_diff_compute_separate:
+                assert self.cfg_compute_first is False, (
+                    "cfg_compute_first must set as False if "
+                    "cfg_diff_compute_separate is enabled."
+                )
+
+        if "max_warmup_steps" not in self.taylorseer_kwargs:
+            # If max_warmup_steps is not set in taylorseer_kwargs,
+            # set the same as max_warmup_steps for DBCache
+            self.taylorseer_kwargs["max_warmup_steps"] = (
+                self.max_warmup_steps if self.max_warmup_steps > 0 else 1
+            )
+
+        # Overwrite the 'n_derivatives' by 'taylorseer_order', default: 2.
+        self.taylorseer_kwargs["n_derivatives"] = self.taylorseer_order
+
+        if self.enable_taylorseer:
+            self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
+            if self.enable_spearate_cfg:
+                self.cfg_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
+
+        if self.enable_encoder_taylorseer:
+            self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
+            if self.enable_spearate_cfg:
+                self.cfg_encoder_taylorseer = TaylorSeer(
+                    **self.taylorseer_kwargs
+                )
+
+    @torch.compiler.disable
+    def get_residual_diff_threshold(self):
+        if self.enable_alter_cache:
+            residual_diff_threshold = self.alter_residual_diff_threshold
+        else:
+            residual_diff_threshold = self.residual_diff_threshold
+            if self.l1_hidden_states_diff_threshold is not None:
+                # Use the L1 hidden states diff threshold if set
+                residual_diff_threshold = self.l1_hidden_states_diff_threshold
+        if isinstance(residual_diff_threshold, torch.Tensor):
+            residual_diff_threshold = residual_diff_threshold.item()
+        return residual_diff_threshold
+
+    @torch.compiler.disable
+    def get_buffer(self, name):
+        if self.enable_alter_cache and self.is_alter_cache:
+            name = f"{name}_alter"
+        return self.buffers.get(name)
+
+    @torch.compiler.disable
+    def set_buffer(self, name, buffer):
+        if self.enable_alter_cache and self.is_alter_cache:
+            name = f"{name}_alter"
+        self.buffers[name] = buffer
+
+    @torch.compiler.disable
+    def remove_buffer(self, name):
+        if self.enable_alter_cache and self.is_alter_cache:
+            name = f"{name}_alter"
+        if name in self.buffers:
+            del self.buffers[name]
+
+    @torch.compiler.disable
+    def clear_buffers(self):
+        self.buffers.clear()
+
+    @torch.compiler.disable
+    def mark_step_begin(self):
+        # Always increase transformer executed steps
+        # incr step: prev 0 -> 1; prev 1 -> 2
+        # current step: incr step - 1
+        self.transformer_executed_steps += 1
+        if not self.enable_spearate_cfg:
+            self.executed_steps += 1
+        else:
+            # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
+            if not self.cfg_compute_first:
+                if not self.is_separate_cfg_step():
+                    # transformer step: 0,2,4,...
+                    self.executed_steps += 1
+            else:
+                if self.is_separate_cfg_step():
+                    # transformer step: 0,2,4,...
+                    self.executed_steps += 1
+
+        if not self.enable_alter_cache:
+            # 0 F 1 T 2 F 3 T 4 F 5 T ...
+            self.is_alter_cache = not self.is_alter_cache
+
+        # Reset the cached steps and residual diffs at the beginning
+        # of each inference.
+        if self.get_current_transformer_step() == 0:
+            self.cached_steps.clear()
+            self.residual_diffs.clear()
+            self.cfg_cached_steps.clear()
+            self.cfg_residual_diffs.clear()
+            # Reset the TaylorSeers cache at the beginning of each inference.
+            # reset_cache will set the current step to -1 for TaylorSeer,
+            if self.enable_taylorseer or self.enable_encoder_taylorseer:
+                taylorseer, encoder_taylorseer = self.get_taylorseers()
+                if taylorseer is not None:
+                    taylorseer.reset_cache()
+                if encoder_taylorseer is not None:
+                    encoder_taylorseer.reset_cache()
+                cfg_taylorseer, cfg_encoder_taylorseer = (
+                    self.get_cfg_taylorseers()
+                )
+                if cfg_taylorseer is not None:
+                    cfg_taylorseer.reset_cache()
+                if cfg_encoder_taylorseer is not None:
+                    cfg_encoder_taylorseer.reset_cache()
+
+        # mark_step_begin of TaylorSeer must be called after the cache is reset.
+        if self.enable_taylorseer or self.enable_encoder_taylorseer:
+            if self.enable_spearate_cfg:
+                # Assume non-CFG steps: 0, 2, 4, 6, ...
+                if not self.is_separate_cfg_step():
+                    taylorseer, encoder_taylorseer = self.get_taylorseers()
+                    if taylorseer is not None:
+                        taylorseer.mark_step_begin()
+                    if encoder_taylorseer is not None:
+                        encoder_taylorseer.mark_step_begin()
+                else:
+                    cfg_taylorseer, cfg_encoder_taylorseer = (
+                        self.get_cfg_taylorseers()
+                    )
+                    if cfg_taylorseer is not None:
+                        cfg_taylorseer.mark_step_begin()
+                    if cfg_encoder_taylorseer is not None:
+                        cfg_encoder_taylorseer.mark_step_begin()
+            else:
+                taylorseer, encoder_taylorseer = self.get_taylorseers()
+                if taylorseer is not None:
+                    taylorseer.mark_step_begin()
+                if encoder_taylorseer is not None:
+                    encoder_taylorseer.mark_step_begin()
+
+    def get_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
+        return self.taylorseer, self.encoder_tarlorseer
+
+    def get_cfg_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
+        return self.cfg_taylorseer, self.cfg_encoder_taylorseer
+
+    @torch.compiler.disable
+    def add_residual_diff(self, diff):
+        # step: executed_steps - 1, not transformer_steps - 1
+        step = str(self.get_current_step())
+        # Only add the diff if it is not already recorded for this step
+        if not self.is_separate_cfg_step():
+            if step not in self.residual_diffs:
+                self.residual_diffs[step] = diff
+        else:
+            if step not in self.cfg_residual_diffs:
+                self.cfg_residual_diffs[step] = diff
+
+    @torch.compiler.disable
+    def get_residual_diffs(self):
+        return self.residual_diffs.copy()
+
+    @torch.compiler.disable
+    def get_cfg_residual_diffs(self):
+        return self.cfg_residual_diffs.copy()
+
+    @torch.compiler.disable
+    def add_cached_step(self):
+        curr_cached_step = self.get_current_step()
+        if not self.is_separate_cfg_step():
+            if self.cached_steps:
+                prev_cached_step = self.cached_steps[-1]
+                if curr_cached_step - prev_cached_step == 1:
+                    if self.continuous_cached_steps == 0:
+                        self.continuous_cached_steps += 2
+                    else:
+                        self.continuous_cached_steps += 1
+            else:
+                self.continuous_cached_steps += 1
+
+            self.cached_steps.append(curr_cached_step)
+        else:
+            if self.cfg_cached_steps:
+                prev_cfg_cached_step = self.cfg_cached_steps[-1]
+                if curr_cached_step - prev_cfg_cached_step == 1:
+                    if self.cfg_continuous_cached_steps == 0:
+                        self.cfg_continuous_cached_steps += 2
+                    else:
+                        self.cfg_continuous_cached_steps += 1
+            else:
+                self.cfg_continuous_cached_steps += 1
+
+            self.cfg_cached_steps.append(curr_cached_step)
+
+    @torch.compiler.disable
+    def get_cached_steps(self):
+        return self.cached_steps.copy()
+
+    @torch.compiler.disable
+    def get_cfg_cached_steps(self):
+        return self.cfg_cached_steps.copy()
+
+    @torch.compiler.disable
+    def get_current_step(self):
+        return self.executed_steps - 1
+
+    @torch.compiler.disable
+    def get_current_transformer_step(self):
+        return self.transformer_executed_steps - 1
+
+    @torch.compiler.disable
+    def is_separate_cfg_step(self):
+        if not self.enable_spearate_cfg:
+            return False
+        if self.cfg_compute_first:
+            # CFG steps: 0, 2, 4, 6, ...
+            return self.get_current_transformer_step() % 2 == 0
+        # CFG steps: 1, 3, 5, 7, ...
+        return self.get_current_transformer_step() % 2 != 0
+
+    @torch.compiler.disable
+    def is_in_warmup(self):
+        return self.get_current_step() < self.max_warmup_steps
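
For orientation, the sketch below (not part of the diff) shows how the new CachedContext dataclass from cache_dit/cache_factory/cache_contexts/cache_context.py could be exercised by hand. In 0.2.28 this object is expected to be created and stepped by the cached blocks and the new cache_contexts/cache_manager.py rather than by user code, so the direct construction, the chosen field values, and the fixed 0.01 diff are illustrative assumptions; every attribute and method used here appears in the hunk above.

from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext

# Hand-picked, hypothetical settings; only fields defined in the hunk are used.
ctx = CachedContext(
    name="demo",
    Fn_compute_blocks=8,           # DBCache first-n compute blocks
    Bn_compute_blocks=0,           # Fn=1, Bn=0 would reduce to FB Cache
    residual_diff_threshold=0.08,
    max_warmup_steps=4,            # never cache during warmup
    enable_taylorseer=True,        # __post_init__ also fills taylorseer_kwargs
)

# Bookkeeping a cached block would perform once per transformer step:
for _ in range(6):
    ctx.mark_step_begin()
    if ctx.is_in_warmup():
        continue                   # warmup steps are always fully computed
    ctx.add_residual_diff(0.01)    # pretend the measured Fn diff is 0.01
    if 0.01 < ctx.get_residual_diff_threshold():
        ctx.add_cached_step()      # this step reuses cached residuals

print(ctx.get_current_step())      # -> 5
print(ctx.get_cached_steps())      # -> [4, 5] (post-warmup steps were cached)
print(ctx.get_residual_diffs())    # -> per-step diffs keyed by step index

With enable_spearate_cfg=True (and cfg_compute_first=False), the same calls split by parity: even transformer steps update executed_steps, residual_diffs and cached_steps, while odd steps are routed to the cfg_* counterparts via is_separate_cfg_step().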