cache-dit 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +8 -6
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +17 -4
- cache_dit/cache_factory/block_adapters/__init__.py +555 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +538 -0
- cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
- cache_dit/cache_factory/cache_adapters.py +262 -938
- cache_dit/cache_factory/cache_blocks/__init__.py +60 -11
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +45 -41
- cache_dit/cache_factory/cache_blocks/pattern_base.py +106 -80
- cache_dit/cache_factory/cache_blocks/utils.py +16 -10
- cache_dit/cache_factory/cache_contexts/__init__.py +5 -0
- cache_dit/cache_factory/cache_contexts/cache_context.py +327 -0
- cache_dit/cache_factory/cache_contexts/cache_manager.py +833 -0
- cache_dit/cache_factory/cache_interface.py +31 -31
- cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
- cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
- cache_dit/quantize/quantize_ao.py +1 -0
- cache_dit/utils.py +26 -26
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/METADATA +59 -23
- cache_dit-0.2.28.dist-info/RECORD +47 -0
- cache_dit/cache_factory/cache_context.py +0 -1155
- cache_dit-0.2.26.dist-info/RECORD +0 -42
- /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import dataclasses
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from cache_dit.cache_factory.cache_contexts.taylorseer import TaylorSeer
|
|
9
|
+
from cache_dit.logger import init_logger
|
|
10
|
+
|
|
11
|
+
# Module-level logger for this file, created via cache_dit's ``init_logger``.
logger = init_logger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclasses.dataclass
class CachedContext:  # Internal CachedContext Impl class
    """Mutable per-context state for DBCache (Dual Block Cache).

    Bundles the cache configuration (compute-block counts, diff thresholds,
    warmup / cached-step limits, alter-cache and separate-CFG switches), a
    name-keyed buffer store, the step counters, and the optional TaylorSeer
    instances that ``__post_init__`` builds from ``taylorseer_kwargs``.

    When ``enable_spearate_cfg`` is set, the non-CFG and CFG transformer
    forwards keep separate cached-step lists, residual-diff records,
    continuous-cached-step counters and TaylorSeers.

    NOTE: the misspellings ``enable_spearate_cfg`` and ``encoder_tarlorseer``
    are part of the public attribute names and are kept for compatibility.
    """

    name: str = "default"
    # Dual Block Cache
    # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
    Fn_compute_blocks: int = 1
    Bn_compute_blocks: int = 0
    # We have added residual cache pattern for selected compute blocks
    Fn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
    Bn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
    # non compute blocks diff threshold, we don't skip the non
    # compute blocks if the diff >= threshold
    non_compute_blocks_diff_threshold: float = 0.08
    max_Fn_compute_blocks: int = -1
    max_Bn_compute_blocks: int = -1
    # L1 hidden states or residual diff threshold for Fn
    residual_diff_threshold: Union[torch.Tensor, float] = 0.05
    # When not None, replaces residual_diff_threshold (non-alter mode only,
    # see get_residual_diff_threshold).
    l1_hidden_states_diff_threshold: Optional[float] = None
    important_condition_threshold: float = 0.0

    # Alter Cache Settings
    # Pattern: 0 F 1 T 2 F 3 T 4 F 5 T ...
    enable_alter_cache: bool = False
    # Toggled once per transformer step by mark_step_begin while
    # enable_alter_cache is set; selects the "<name>_alter" buffers.
    is_alter_cache: bool = True
    # 1.0 means we always cache the residuals if alter_cache is enabled.
    alter_residual_diff_threshold: Optional[Union[torch.Tensor, float]] = 1.0

    # Buffer for storing the residuals and other tensors
    buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
        default_factory=lambda: defaultdict(int),
    )

    # Other settings
    downsample_factor: int = 1
    num_inference_steps: int = -1  # for future use
    max_warmup_steps: int = 0  # DON'T Cache in warmup steps
    # DON'T Cache if the number of cached steps >= max_cached_steps
    max_cached_steps: int = -1  # for both CFG and non-CFG
    max_continuous_cached_steps: int = -1  # the max continuous cached steps

    # Record the steps that have been cached, both cached and non-cache
    executed_steps: int = 0  # cache + non-cache steps pipeline
    # steps for transformer, for CFG, transformer_executed_steps will
    # be double of executed_steps.
    transformer_executed_steps: int = 0

    # Support TaylorSeers in Dual Block Cache
    # Title: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
    # Url: https://arxiv.org/pdf/2503.06923
    enable_taylorseer: bool = False
    enable_encoder_taylorseer: bool = False
    # NOTE: use residual cache for taylorseer may incur precision loss
    taylorseer_cache_type: str = "hidden_states"  # residual or hidden_states
    taylorseer_order: int = 2  # The order for TaylorSeer
    taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
    taylorseer: Optional[TaylorSeer] = None
    encoder_tarlorseer: Optional[TaylorSeer] = None

    # Support enable_spearate_cfg, such as Wan 2.1,
    # Qwen-Image. For model that fused CFG and non-CFG into single
    # forward step, should set enable_spearate_cfg as False.
    # For example: CogVideoX, HunyuanVideo, Mochi.
    enable_spearate_cfg: bool = False
    # Compute cfg forward first or not, default False, namely,
    # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
    cfg_compute_first: bool = False
    # Compute separate diff values for CFG and non-CFG step,
    # default True. If False, we will use the computed diff from
    # current non-CFG transformer step for current CFG step.
    cfg_diff_compute_separate: bool = True
    cfg_taylorseer: Optional[TaylorSeer] = None
    cfg_encoder_taylorseer: Optional[TaylorSeer] = None

    # CFG & non-CFG cached steps
    cached_steps: List[int] = dataclasses.field(default_factory=list)
    residual_diffs: DefaultDict[str, float] = dataclasses.field(
        default_factory=lambda: defaultdict(float),
    )
    continuous_cached_steps: int = 0
    cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
    cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
        default_factory=lambda: defaultdict(float),
    )
    cfg_continuous_cached_steps: int = 0

    @torch.compiler.disable
    def __post_init__(self):
        """Validate the settings and build the requested TaylorSeers.

        Raises:
            AssertionError: if ``enable_spearate_cfg`` is combined with
                ``enable_alter_cache``, or (with ``cfg_diff_compute_separate``)
                with ``cfg_compute_first``.
        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Created _CacheContext: {self.name}")
        # Some checks for settings
        if self.enable_spearate_cfg:
            assert self.enable_alter_cache is False, (
                "enable_alter_cache must set as False if "
                "enable_spearate_cfg is enabled."
            )
            if self.cfg_diff_compute_separate:
                assert self.cfg_compute_first is False, (
                    "cfg_compute_first must set as False if "
                    "cfg_diff_compute_separate is enabled."
                )

        # Work on a private copy: the dict may come from the caller (and be
        # shared across several contexts); the keys written below must not
        # leak back into it or into other contexts.
        self.taylorseer_kwargs = dict(self.taylorseer_kwargs or {})

        if "max_warmup_steps" not in self.taylorseer_kwargs:
            # If max_warmup_steps is not set in taylorseer_kwargs,
            # set the same as max_warmup_steps for DBCache
            self.taylorseer_kwargs["max_warmup_steps"] = (
                self.max_warmup_steps if self.max_warmup_steps > 0 else 1
            )

        # Overwrite the 'n_derivatives' by 'taylorseer_order', default: 2.
        self.taylorseer_kwargs["n_derivatives"] = self.taylorseer_order

        if self.enable_taylorseer:
            self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
            if self.enable_spearate_cfg:
                self.cfg_taylorseer = TaylorSeer(**self.taylorseer_kwargs)

        if self.enable_encoder_taylorseer:
            self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
            if self.enable_spearate_cfg:
                self.cfg_encoder_taylorseer = TaylorSeer(
                    **self.taylorseer_kwargs
                )

    @torch.compiler.disable
    def get_residual_diff_threshold(self):
        """Return the active diff threshold as a Python scalar.

        With alter cache enabled this is ``alter_residual_diff_threshold``;
        otherwise ``l1_hidden_states_diff_threshold`` when set, else
        ``residual_diff_threshold``. Tensor thresholds are unwrapped with
        ``.item()``.
        """
        if self.enable_alter_cache:
            residual_diff_threshold = self.alter_residual_diff_threshold
        else:
            residual_diff_threshold = self.residual_diff_threshold
            if self.l1_hidden_states_diff_threshold is not None:
                # Use the L1 hidden states diff threshold if set
                residual_diff_threshold = self.l1_hidden_states_diff_threshold
        if isinstance(residual_diff_threshold, torch.Tensor):
            residual_diff_threshold = residual_diff_threshold.item()
        return residual_diff_threshold

    @torch.compiler.disable
    def get_buffer(self, name):
        """Return buffer ``name`` (``<name>_alter`` on alter steps) or None."""
        if self.enable_alter_cache and self.is_alter_cache:
            name = f"{name}_alter"
        return self.buffers.get(name)

    @torch.compiler.disable
    def set_buffer(self, name, buffer):
        """Store ``buffer`` under ``name`` (``<name>_alter`` on alter steps)."""
        if self.enable_alter_cache and self.is_alter_cache:
            name = f"{name}_alter"
        self.buffers[name] = buffer

    @torch.compiler.disable
    def remove_buffer(self, name):
        """Drop buffer ``name`` (``<name>_alter`` on alter steps) if present."""
        if self.enable_alter_cache and self.is_alter_cache:
            name = f"{name}_alter"
        self.buffers.pop(name, None)

    @torch.compiler.disable
    def clear_buffers(self):
        """Remove every stored buffer, regular and alter alike."""
        self.buffers.clear()

    @torch.compiler.disable
    def mark_step_begin(self):
        """Advance the step counters at the start of a transformer forward.

        ``transformer_executed_steps`` always grows by one. ``executed_steps``
        grows on every call without separate CFG, and only on even transformer
        steps (0, 2, 4, ...) with it, so a CFG/non-CFG pair counts as one
        pipeline step. At transformer step 0 the cached-step / residual-diff
        records are cleared and the TaylorSeer caches are reset; then the
        TaylorSeer(s) of the current (CFG or non-CFG) forward are advanced.
        """
        # Always increase transformer executed steps
        # incr step: prev 0 -> 1; prev 1 -> 2
        # current step: incr step - 1
        self.transformer_executed_steps += 1
        if not self.enable_spearate_cfg:
            self.executed_steps += 1
        else:
            # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
            if not self.cfg_compute_first:
                if not self.is_separate_cfg_step():
                    # transformer step: 0,2,4,...
                    self.executed_steps += 1
            else:
                if self.is_separate_cfg_step():
                    # transformer step: 0,2,4,...
                    self.executed_steps += 1

        if self.enable_alter_cache:
            # 0 F 1 T 2 F 3 T 4 F 5 T ...
            # Only meaningful when alter cache is on: the buffer accessors
            # ignore is_alter_cache otherwise. Starting from the default
            # True, the first step flips it to False, matching the pattern.
            self.is_alter_cache = not self.is_alter_cache

        # Reset the cached steps and residual diffs at the beginning
        # of each inference.
        if self.get_current_transformer_step() == 0:
            self.cached_steps.clear()
            self.residual_diffs.clear()
            self.cfg_cached_steps.clear()
            self.cfg_residual_diffs.clear()
            # Reset the TaylorSeers cache at the beginning of each inference.
            # reset_cache will set the current step to -1 for TaylorSeer,
            if self.enable_taylorseer or self.enable_encoder_taylorseer:
                taylorseer, encoder_taylorseer = self.get_taylorseers()
                if taylorseer is not None:
                    taylorseer.reset_cache()
                if encoder_taylorseer is not None:
                    encoder_taylorseer.reset_cache()
                cfg_taylorseer, cfg_encoder_taylorseer = (
                    self.get_cfg_taylorseers()
                )
                if cfg_taylorseer is not None:
                    cfg_taylorseer.reset_cache()
                if cfg_encoder_taylorseer is not None:
                    cfg_encoder_taylorseer.reset_cache()

        # mark_step_begin of TaylorSeer must be called after the cache is reset.
        if self.enable_taylorseer or self.enable_encoder_taylorseer:
            if self.enable_spearate_cfg:
                # Assume non-CFG steps: 0, 2, 4, 6, ...
                if not self.is_separate_cfg_step():
                    taylorseer, encoder_taylorseer = self.get_taylorseers()
                    if taylorseer is not None:
                        taylorseer.mark_step_begin()
                    if encoder_taylorseer is not None:
                        encoder_taylorseer.mark_step_begin()
                else:
                    cfg_taylorseer, cfg_encoder_taylorseer = (
                        self.get_cfg_taylorseers()
                    )
                    if cfg_taylorseer is not None:
                        cfg_taylorseer.mark_step_begin()
                    if cfg_encoder_taylorseer is not None:
                        cfg_encoder_taylorseer.mark_step_begin()
            else:
                taylorseer, encoder_taylorseer = self.get_taylorseers()
                if taylorseer is not None:
                    taylorseer.mark_step_begin()
                if encoder_taylorseer is not None:
                    encoder_taylorseer.mark_step_begin()

    def get_taylorseers(
        self,
    ) -> Tuple[Optional[TaylorSeer], Optional[TaylorSeer]]:
        """Return the non-CFG (hidden-states, encoder) TaylorSeers; may be None."""
        return self.taylorseer, self.encoder_tarlorseer

    def get_cfg_taylorseers(
        self,
    ) -> Tuple[Optional[TaylorSeer], Optional[TaylorSeer]]:
        """Return the CFG (hidden-states, encoder) TaylorSeers; may be None."""
        return self.cfg_taylorseer, self.cfg_encoder_taylorseer

    @torch.compiler.disable
    def add_residual_diff(self, diff):
        """Record ``diff`` for the current pipeline step (first value wins).

        The key is ``str(get_current_step())``; the value goes to
        ``cfg_residual_diffs`` on separate CFG forwards and to
        ``residual_diffs`` otherwise.
        """
        # step: executed_steps - 1, not transformer_steps - 1
        step = str(self.get_current_step())
        # Only add the diff if it is not already recorded for this step
        if not self.is_separate_cfg_step():
            if step not in self.residual_diffs:
                self.residual_diffs[step] = diff
        else:
            if step not in self.cfg_residual_diffs:
                self.cfg_residual_diffs[step] = diff

    @torch.compiler.disable
    def get_residual_diffs(self):
        """Return a shallow copy of the non-CFG residual-diff record."""
        return self.residual_diffs.copy()

    @torch.compiler.disable
    def get_cfg_residual_diffs(self):
        """Return a shallow copy of the CFG residual-diff record."""
        return self.cfg_residual_diffs.copy()

    @torch.compiler.disable
    def add_cached_step(self):
        """Record the current pipeline step as cached.

        Appends to ``cfg_cached_steps`` on separate CFG forwards, otherwise to
        ``cached_steps``, and updates the matching continuous counter: +1 for
        the first recorded step; for a step directly following the previous
        cached one, +2 if the counter is 0 (counting both steps) else +1; a
        non-adjacent step leaves the counter unchanged (resetting it is
        presumably done by the caller — not visible here).
        """
        curr_cached_step = self.get_current_step()
        if not self.is_separate_cfg_step():
            if self.cached_steps:
                prev_cached_step = self.cached_steps[-1]
                if curr_cached_step - prev_cached_step == 1:
                    if self.continuous_cached_steps == 0:
                        self.continuous_cached_steps += 2
                    else:
                        self.continuous_cached_steps += 1
            else:
                self.continuous_cached_steps += 1

            self.cached_steps.append(curr_cached_step)
        else:
            if self.cfg_cached_steps:
                prev_cfg_cached_step = self.cfg_cached_steps[-1]
                if curr_cached_step - prev_cfg_cached_step == 1:
                    if self.cfg_continuous_cached_steps == 0:
                        self.cfg_continuous_cached_steps += 2
                    else:
                        self.cfg_continuous_cached_steps += 1
            else:
                self.cfg_continuous_cached_steps += 1

            self.cfg_cached_steps.append(curr_cached_step)

    @torch.compiler.disable
    def get_cached_steps(self):
        """Return a copy of the non-CFG cached pipeline steps."""
        return self.cached_steps.copy()

    @torch.compiler.disable
    def get_cfg_cached_steps(self):
        """Return a copy of the CFG cached pipeline steps."""
        return self.cfg_cached_steps.copy()

    @torch.compiler.disable
    def get_current_step(self):
        """Return the 0-based pipeline step (``executed_steps - 1``)."""
        return self.executed_steps - 1

    @torch.compiler.disable
    def get_current_transformer_step(self):
        """Return the 0-based transformer step (``transformer_executed_steps - 1``)."""
        return self.transformer_executed_steps - 1

    @torch.compiler.disable
    def is_separate_cfg_step(self):
        """Return True if the current transformer forward is the CFG pass.

        Always False without ``enable_spearate_cfg``. Otherwise CFG passes are
        the even transformer steps when ``cfg_compute_first`` is set, and the
        odd ones when it is not.
        """
        if not self.enable_spearate_cfg:
            return False
        if self.cfg_compute_first:
            # CFG steps: 0, 2, 4, 6, ...
            return self.get_current_transformer_step() % 2 == 0
        # CFG steps: 1, 3, 5, 7, ...
        return self.get_current_transformer_step() % 2 != 0

    @torch.compiler.disable
    def is_in_warmup(self):
        """Return True while the current pipeline step is below max_warmup_steps."""
        return self.get_current_step() < self.max_warmup_steps
|