cache-dit 0.2.32__py3-none-any.whl → 0.2.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/block_adapters/__init__.py +14 -20
- cache_dit/cache_factory/block_adapters/block_adapters.py +47 -3
- cache_dit/cache_factory/block_adapters/block_registers.py +3 -2
- cache_dit/cache_factory/cache_adapters.py +8 -8
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +23 -62
- cache_dit/cache_factory/cache_blocks/pattern_base.py +23 -168
- cache_dit/cache_factory/cache_contexts/cache_context.py +18 -64
- cache_dit/cache_factory/cache_contexts/cache_manager.py +23 -71
- cache_dit/cache_factory/cache_contexts/taylorseer.py +11 -13
- cache_dit/cache_factory/cache_interface.py +9 -9
- cache_dit/cache_factory/patch_functors/__init__.py +1 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +142 -52
- cache_dit/cache_factory/patch_functors/functor_dit.py +130 -0
- cache_dit/quantize/quantize_ao.py +3 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/METADATA +184 -39
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/RECORD +21 -21
- cache_dit/quantize/quantize_svdq.py +0 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/top_level.txt +0 -0
@@ -14,13 +14,9 @@ logger = init_logger(__name__)
 @dataclasses.dataclass
 class CachedContext: # Internal CachedContext Impl class
     name: str = "default"
-    # Dual Block Cache
-    # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
+    # Dual Block Cache with flexible FnBn configuration.
     Fn_compute_blocks: int = 1
     Bn_compute_blocks: int = 0
-    # We have added residual cache pattern for selected compute blocks
-    Fn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
-    Bn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
     # non compute blocks diff threshold, we don't skip the non
     # compute blocks if the diff >= threshold
     non_compute_blocks_diff_threshold: float = 0.08
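For orientation (not part of the diff): the FnBn fields above configure Dual Block Cache directly on the internal dataclass in cache_dit/cache_factory/cache_contexts/cache_context.py. Fn_compute_blocks=1 with Bn_compute_blocks=0 degenerates to plain FB Cache; larger values keep more leading/trailing blocks always computed. A minimal sketch with hypothetical values:

    # Sketch only: field names come from the dataclass shown above.
    from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext

    ctx = CachedContext(
        name="fnbn-example",          # hypothetical context name
        Fn_compute_blocks=8,          # first 8 transformer blocks always computed
        Bn_compute_blocks=0,          # no trailing blocks always computed
        residual_diff_threshold=0.08,
    )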
@@ -31,13 +27,6 @@ class CachedContext: # Internal CachedContext Impl class
     l1_hidden_states_diff_threshold: float = None
     important_condition_threshold: float = 0.0

-    # Alter Cache Settings
-    # Pattern: 0 F 1 T 2 F 3 T 4 F 5 T ...
-    enable_alter_cache: bool = False
-    is_alter_cache: bool = True
-    # 1.0 means we always cache the residuals if alter_cache is enabled.
-    alter_residual_diff_threshold: Optional[Union[torch.Tensor, float]] = 1.0
-
     # Buffer for storing the residuals and other tensors
     buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
     incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
@@ -63,22 +52,21 @@ class CachedContext: # Internal CachedContext Impl class
     # Url: https://arxiv.org/pdf/2503.06923
     enable_taylorseer: bool = False
     enable_encoder_taylorseer: bool = False
-    # NOTE: use residual cache for taylorseer may incur precision loss
     taylorseer_cache_type: str = "hidden_states" # residual or hidden_states
-    taylorseer_order: int =
+    taylorseer_order: int = 1 # The order for TaylorSeer
     taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
     taylorseer: Optional[TaylorSeer] = None
     encoder_tarlorseer: Optional[TaylorSeer] = None

-    # Support enable_spearate_cfg, such as Wan 2.1,
+    # Support enable_separate_cfg, such as Wan 2.1,
     # Qwen-Image. For model that fused CFG and non-CFG into single
-    # forward step, should set enable_spearate_cfg as False.
+    # forward step, should set enable_separate_cfg as False.
     # For example: CogVideoX, HunyuanVideo, Mochi.
-    enable_spearate_cfg: bool = False
+    enable_separate_cfg: bool = False
     # Compute cfg forward first or not, default False, namely,
     # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
     cfg_compute_first: bool = False
-    # Compute spearate diff values for CFG and non-CFG step,
+    # Compute separate diff values for CFG and non-CFG step,
     # default True. If False, we will use the computed diff from
     # current non-CFG transformer step for current CFG step.
     cfg_diff_compute_separate: bool = True
@@ -97,16 +85,11 @@ class CachedContext: # Internal CachedContext Impl class
     )
     cfg_continuous_cached_steps: int = 0

-    @torch.compiler.disable
     def __post_init__(self):
         if logger.isEnabledFor(logging.DEBUG):
             logger.info(f"Created _CacheContext: {self.name}")
         # Some checks for settings
-        if self.enable_spearate_cfg:
-            assert self.enable_alter_cache is False, (
-                "enable_alter_cache must set as False if "
-                "enable_spearate_cfg is enabled."
-            )
+        if self.enable_separate_cfg:
             if self.cfg_diff_compute_separate:
                 assert self.cfg_compute_first is False, (
                     "cfg_compute_first must set as False if "
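To make the constraint above concrete (hypothetical values, not from the diff): when separate-CFG mode and separate diff computation are both enabled, the context refuses cfg_compute_first=True at construction time.

    # Sketch only: this combination trips the assert shown in __post_init__ above.
    from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext

    CachedContext(
        name="bad-cfg-order",            # hypothetical
        enable_separate_cfg=True,
        cfg_diff_compute_separate=True,  # diffs computed separately per CFG step
        cfg_compute_first=True,          # conflicts -> AssertionError
    )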
@@ -125,59 +108,44 @@ class CachedContext: # Internal CachedContext Impl class

         if self.enable_taylorseer:
             self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
-            if self.enable_spearate_cfg:
+            if self.enable_separate_cfg:
                 self.cfg_taylorseer = TaylorSeer(**self.taylorseer_kwargs)

         if self.enable_encoder_taylorseer:
             self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
-            if self.enable_spearate_cfg:
+            if self.enable_separate_cfg:
                 self.cfg_encoder_taylorseer = TaylorSeer(
                     **self.taylorseer_kwargs
                 )

-    @torch.compiler.disable
     def get_residual_diff_threshold(self):
-        if self.enable_alter_cache:
-            residual_diff_threshold = self.alter_residual_diff_threshold
-        else:
-            residual_diff_threshold = self.residual_diff_threshold
-            if self.l1_hidden_states_diff_threshold is not None:
-                # Use the L1 hidden states diff threshold if set
-                residual_diff_threshold = self.l1_hidden_states_diff_threshold
+        residual_diff_threshold = self.residual_diff_threshold
+        if self.l1_hidden_states_diff_threshold is not None:
+            # Use the L1 hidden states diff threshold if set
+            residual_diff_threshold = self.l1_hidden_states_diff_threshold
         if isinstance(residual_diff_threshold, torch.Tensor):
             residual_diff_threshold = residual_diff_threshold.item()
         return residual_diff_threshold

-    @torch.compiler.disable
     def get_buffer(self, name):
-        if self.enable_alter_cache and self.is_alter_cache:
-            name = f"{name}_alter"
         return self.buffers.get(name)

-    @torch.compiler.disable
     def set_buffer(self, name, buffer):
-        if self.enable_alter_cache and self.is_alter_cache:
-            name = f"{name}_alter"
         self.buffers[name] = buffer

-    @torch.compiler.disable
     def remove_buffer(self, name):
-        if self.enable_alter_cache and self.is_alter_cache:
-            name = f"{name}_alter"
         if name in self.buffers:
             del self.buffers[name]

-    @torch.compiler.disable
     def clear_buffers(self):
         self.buffers.clear()

-    @torch.compiler.disable
     def mark_step_begin(self):
         # Always increase transformer executed steps
-        # incr
-        # current
+        # incr step: prev 0 -> 1; prev 1 -> 2
+        # current step: incr step - 1
         self.transformer_executed_steps += 1
-        if not self.enable_spearate_cfg:
+        if not self.enable_separate_cfg:
             self.executed_steps += 1
         else:
             # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
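A worked example of the counting above, assuming enable_separate_cfg=True and cfg_compute_first=False: transformer forwards 0 and 1 belong to denoising step 0 (non-CFG then CFG), forwards 2 and 3 to step 1, and so on.

    # Sketch of the bookkeeping in mark_step_begin() / is_separate_cfg_step():
    # transformer step:        0      1      2      3
    # get_current_step():      0      0      1      1
    # is_separate_cfg_step():  False  True   False  True   (cfg_compute_first=False)
    for transformer_step in range(4):
        denoise_step = transformer_step // 2   # two forwards per denoising step
        is_cfg = transformer_step % 2 != 0     # odd forwards are CFG steps
        print(transformer_step, denoise_step, is_cfg)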
@@ -190,10 +158,6 @@ class CachedContext: # Internal CachedContext Impl class
                 # transformer step: 0,2,4,...
                 self.executed_steps += 1

-        if not self.enable_alter_cache:
-            # 0 F 1 T 2 F 3 T 4 F 5 T ...
-            self.is_alter_cache = not self.is_alter_cache
-
         # Reset the cached steps and residual diffs at the beginning
         # of each inference.
         if self.get_current_transformer_step() == 0:
@@ -219,7 +183,7 @@ class CachedContext: # Internal CachedContext Impl class

         # mark_step_begin of TaylorSeer must be called after the cache is reset.
         if self.enable_taylorseer or self.enable_encoder_taylorseer:
-            if self.enable_spearate_cfg:
+            if self.enable_separate_cfg:
                 # Assume non-CFG steps: 0, 2, 4, 6, ...
                 if not self.is_separate_cfg_step():
                     taylorseer, encoder_taylorseer = self.get_taylorseers()
@@ -248,7 +212,6 @@ class CachedContext: # Internal CachedContext Impl class
     def get_cfg_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
         return self.cfg_taylorseer, self.cfg_encoder_taylorseer

-    @torch.compiler.disable
     def add_residual_diff(self, diff):
         # step: executed_steps - 1, not transformer_steps - 1
         step = str(self.get_current_step())
@@ -260,15 +223,12 @@ class CachedContext: # Internal CachedContext Impl class
             if step not in self.cfg_residual_diffs:
                 self.cfg_residual_diffs[step] = diff

-    @torch.compiler.disable
     def get_residual_diffs(self):
         return self.residual_diffs.copy()

-    @torch.compiler.disable
     def get_cfg_residual_diffs(self):
         return self.cfg_residual_diffs.copy()

-    @torch.compiler.disable
     def add_cached_step(self):
         curr_cached_step = self.get_current_step()
         if not self.is_separate_cfg_step():
@@ -296,25 +256,20 @@ class CachedContext: # Internal CachedContext Impl class

             self.cfg_cached_steps.append(curr_cached_step)

-    @torch.compiler.disable
     def get_cached_steps(self):
         return self.cached_steps.copy()

-    @torch.compiler.disable
     def get_cfg_cached_steps(self):
         return self.cfg_cached_steps.copy()

-    @torch.compiler.disable
     def get_current_step(self):
         return self.executed_steps - 1

-    @torch.compiler.disable
     def get_current_transformer_step(self):
         return self.transformer_executed_steps - 1

-    @torch.compiler.disable
     def is_separate_cfg_step(self):
-        if not self.enable_spearate_cfg:
+        if not self.enable_separate_cfg:
             return False
         if self.cfg_compute_first:
             # CFG steps: 0, 2, 4, 6, ...
@@ -322,6 +277,5 @@ class CachedContext: # Internal CachedContext Impl class
             # CFG steps: 1, 3, 5, 7, ...
             return self.get_current_transformer_step() % 2 != 0

-    @torch.compiler.disable
     def is_in_warmup(self):
         return self.get_current_step() < self.max_warmup_steps
@@ -74,8 +74,8 @@ class CachedContextManager:
             del self._cached_context_manager[cached_context]

     def clear_contexts(self):
-        for context_name in self._cached_context_manager:
-            self.remove_context(context_name)
+        for context_name in list(self._cached_context_manager.keys()):
+            self.remove_context(context_name)

     @contextlib.contextmanager
     def enter_context(self, cached_context: CachedContext | str):
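The clear_contexts() change above snapshots the keys before deleting. This matters because deleting from a dict while iterating it raises at runtime; a minimal, generic illustration:

    contexts = {"a": object(), "b": object()}

    # for name in contexts:              # RuntimeError: dictionary changed size during iteration
    #     del contexts[name]

    for name in list(contexts.keys()):   # snapshot first, as the new clear_contexts() does
        del contexts[name]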
@@ -122,10 +122,7 @@ class CachedContextManager:
             default_value,
         )

-    # Manually set sequence fields like Fn_compute_blocks_ids
-    # and Bn_compute_blocks_ids, which are lists or sets.
-    _safe_set_sequence_field("Fn_compute_blocks_ids", [])
-    _safe_set_sequence_field("Bn_compute_blocks_ids", [])
+    # Manually set sequence fields
     _safe_set_sequence_field("taylorseer_kwargs", {})

     for attr in cache_attrs:
@@ -301,18 +298,6 @@ class CachedContextManager:
             return self.is_taylorseer_cache_residual()
         return True

-    @torch.compiler.disable
-    def is_alter_cache_enabled(self) -> bool:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.enable_alter_cache
-
-    @torch.compiler.disable
-    def is_alter_cache(self) -> bool:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.is_alter_cache
-
     @torch.compiler.disable
     def is_in_warmup(self) -> bool:
         cached_context = self.get_context()
@@ -359,20 +344,6 @@ class CachedContextManager:
         )
         return cached_context.Fn_compute_blocks

-    @torch.compiler.disable
-    def Fn_compute_blocks_ids(self) -> List[int]:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        assert (
-            len(cached_context.Fn_compute_blocks_ids)
-            <= cached_context.Fn_compute_blocks
-        ), (
-            "The num of Fn_compute_blocks_ids must be <= Fn_compute_blocks "
-            f"{cached_context.Fn_compute_blocks}, but got "
-            f"{len(cached_context.Fn_compute_blocks_ids)}"
-        )
-        return cached_context.Fn_compute_blocks_ids
-
     @torch.compiler.disable
     def Bn_compute_blocks(self) -> int:
         cached_context = self.get_context()
@@ -393,24 +364,10 @@ class CachedContextManager:
         return cached_context.Bn_compute_blocks

     @torch.compiler.disable
-    def Bn_compute_blocks_ids(self) -> List[int]:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        assert (
-            len(cached_context.Bn_compute_blocks_ids)
-            <= cached_context.Bn_compute_blocks
-        ), (
-            "The num of Bn_compute_blocks_ids must be <= Bn_compute_blocks "
-            f"{cached_context.Bn_compute_blocks}, but got "
-            f"{len(cached_context.Bn_compute_blocks_ids)}"
-        )
-        return cached_context.Bn_compute_blocks_ids
-
-    @torch.compiler.disable
-    def enable_spearate_cfg(self) -> bool:
+    def enable_separate_cfg(self) -> bool:
         cached_context = self.get_context()
         assert cached_context is not None, "cached_context must be set before"
-        return cached_context.enable_spearate_cfg
+        return cached_context.enable_separate_cfg

     @torch.compiler.disable
     def is_separate_cfg_step(self) -> bool:
@@ -453,7 +410,7 @@ class CachedContextManager:

         if all(
             (
-                self.enable_spearate_cfg(),
+                self.enable_separate_cfg(),
                 self.is_separate_cfg_step(),
                 not self.cfg_diff_compute_separate(),
                 self.get_current_step_residual_diff() is not None,
@@ -525,6 +482,9 @@ class CachedContextManager:
     # Fn buffers
     @torch.compiler.disable
     def set_Fn_buffer(self, buffer: torch.Tensor, prefix: str = "Fn"):
+        # DON'T set None Buffer
+        if buffer is None:
+            return
         # Set hidden_states or residual for Fn blocks.
         # This buffer is only use for L1 diff calculation.
         downsample_factor = self.get_downsample_factor()
@@ -548,6 +508,9 @@ class CachedContextManager:

     @torch.compiler.disable
     def set_Fn_encoder_buffer(self, buffer: torch.Tensor, prefix: str = "Fn"):
+        # DON'T set None Buffer
+        if buffer is None:
+            return
         if self.is_separate_cfg_step():
             self._debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
             self.set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
@@ -566,6 +529,9 @@ class CachedContextManager:
     # Bn buffers
     @torch.compiler.disable
     def set_Bn_buffer(self, buffer: torch.Tensor, prefix: str = "Bn"):
+        # DON'T set None Buffer
+        if buffer is None:
+            return
         # Set hidden_states or residual for Bn blocks.
         # This buffer is use for hidden states approximation.
         if self.is_taylorseer_enabled():
@@ -820,26 +786,12 @@ class CachedContextManager:
         else:
             prev_states_tensor = self.get_Fn_buffer(prefix)

-        if not self.is_alter_cache_enabled():
-            # Dynamic cache according to the residual diff
-            can_cache = prev_states_tensor is not None and self.similarity(
-                prev_states_tensor,
-                states_tensor,
-                threshold=threshold,
-                parallelized=parallelized,
-                prefix=prefix,
-            )
-        else:
-            # Only cache in the alter cache steps
-            can_cache = (
-                prev_states_tensor is not None
-                and self.similarity(
-                    prev_states_tensor,
-                    states_tensor,
-                    threshold=threshold,
-                    parallelized=parallelized,
-                    prefix=prefix,
-                )
-                and self.is_alter_cache()
-            )
+        # Dynamic cache according to the residual diff
+        can_cache = prev_states_tensor is not None and self.similarity(
+            prev_states_tensor,
+            states_tensor,
+            threshold=threshold,
+            parallelized=parallelized,
+            prefix=prefix,
+        )
         return can_cache
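The similarity() helper itself is outside this diff; conceptually it compares the cached Fn-block states with the current ones against residual_diff_threshold. A hedged sketch of such a relative L1 check (the exact formula used by cache-dit is an assumption here):

    import torch

    def l1_relative_diff(prev: torch.Tensor, curr: torch.Tensor, threshold: float) -> bool:
        # Mean absolute change, normalized by the magnitude of the previous states.
        diff = (curr - prev).abs().mean() / (prev.abs().mean() + 1e-8)
        return diff.item() < threshold  # small enough -> safe to reuse cached residuals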
@@ -1,5 +1,6 @@
 import math
 import torch
+from typing import List, Dict


 class TaylorSeer:
@@ -17,16 +18,14 @@ class TaylorSeer:
         self.compute_step_map = compute_step_map
         self.reset_cache()

-    @torch.compiler.disable
     def reset_cache(self):
-        self.state = {
+        self.state: Dict[str, List[torch.Tensor]] = {
             "dY_prev": [None] * self.ORDER,
             "dY_current": [None] * self.ORDER,
         }
         self.current_step = -1
         self.last_non_approximated_step = -1

-    @torch.compiler.disable
     def should_compute_full(self, step=None):
         step = self.current_step if step is None else step
         if self.compute_step_map is not None:
@@ -39,16 +38,19 @@ class TaylorSeer:
                 return True
         return False

-
-    def approximate_derivative(self, Y):
+    def approximate_derivative(self, Y: torch.Tensor) -> List[torch.Tensor]:
         # n-th order Taylor expansion:
         # Y(t) = Y(0) + dY(0)/dt * t + d^2Y(0)/dt^2 * t^2 / 2!
         # + ... + d^nY(0)/dt^n * t^n / n!
         # TODO: Custom Triton/CUDA kernel for better performance,
         # especially for large n_derivatives.
-        dY_current = [None] * self.ORDER
+        dY_current: List[torch.Tensor] = [None] * self.ORDER
         dY_current[0] = Y
         window = self.current_step - self.last_non_approximated_step
+        if self.state["dY_prev"][0] is not None:
+            if dY_current[0].shape != self.state["dY_prev"][0].shape:
+                self.reset_cache()
+
         for i in range(self.n_derivatives):
             if self.state["dY_prev"][i] is not None and self.current_step > 1:
                 dY_current[i + 1] = (
@@ -58,8 +60,7 @@ class TaylorSeer:
                 break
         return dY_current

-
-    def approximate_value(self):
+    def approximate_value(self) -> torch.Tensor:
         # TODO: Custom Triton/CUDA kernel for better performance,
         # especially for large n_derivatives.
         elapsed = self.current_step - self.last_non_approximated_step
@@ -71,12 +72,10 @@ class TaylorSeer:
                 break
         return output

-    @torch.compiler.disable
     def mark_step_begin(self):
         self.current_step += 1

-
-    def update(self, Y):
+    def update(self, Y: torch.Tensor):
         # Directly call this method will ingnore the warmup
         # policy and force full computation.
         # Assume warmup steps is 3, and n_derivatives is 3.
@@ -94,8 +93,7 @@ class TaylorSeer:
         self.state["dY_current"] = self.approximate_derivative(Y)
         self.last_non_approximated_step = self.current_step

-
-    def step(self, Y):
+    def step(self, Y: torch.Tensor):
         self.mark_step_begin()
         if self.should_compute_full():
             self.update(Y)
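approximate_derivative() and approximate_value() implement the Taylor expansion quoted in the comments, Y(t) = Y(0) + dY(0)/dt * t + ... + d^nY(0)/dt^n * t^n / n!, using finite differences over the window between fully computed steps. A standalone sketch of the extrapolation step (not the library's exact code):

    import math
    from typing import List, Optional
    import torch

    def taylor_extrapolate(derivatives: List[Optional[torch.Tensor]], elapsed: int) -> torch.Tensor:
        # derivatives[i] is the i-th finite-difference derivative at the last fully
        # computed step; elapsed = current_step - last_non_approximated_step.
        output = torch.zeros_like(derivatives[0])
        for i, dY in enumerate(derivatives):
            if dY is None:
                break
            output = output + dY * (elapsed ** i) / math.factorial(i)
        return output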
@@ -24,14 +24,14 @@ def enable_cache(
     max_continuous_cached_steps: int = -1,
     residual_diff_threshold: float = 0.08,
     # Cache CFG or not
-    enable_spearate_cfg: bool = None,
+    enable_separate_cfg: bool = None,
     cfg_compute_first: bool = False,
     cfg_diff_compute_separate: bool = True,
     # Hybird TaylorSeer
     enable_taylorseer: bool = False,
     enable_encoder_taylorseer: bool = False,
     taylorseer_cache_type: str = "residual",
-    taylorseer_order: int =
+    taylorseer_order: int = 1,
     **other_cache_context_kwargs,
 ) -> Union[
     DiffusionPipeline,
@@ -70,15 +70,15 @@ def enable_cache(
         residual_diff_threshold (`float`, *required*, defaults to 0.08):
             he value of residual diff threshold, a higher value leads to faster performance at the
             cost of lower precision.
-        enable_spearate_cfg (`bool`, *required*, defaults to None):
+        enable_separate_cfg (`bool`, *required*, defaults to None):
             Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
-            and non-CFG into single forward step, should set enable_spearate_cfg as False, for example:
+            and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
             CogVideoX, HunyuanVideo, Mochi, etc.
         cfg_compute_first (`bool`, *required*, defaults to False):
             Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
             1, 3, 5, ... -> CFG step.
         cfg_diff_compute_separate (`bool`, *required*, defaults to True):
-            Compute spearate diff values for CFG and non-CFG step, default True. If False, we will
+            Compute separate diff values for CFG and non-CFG step, default True. If False, we will
             use the computed diff from current non-CFG transformer step for current CFG step.
         enable_taylorseer (`bool`, *required*, defaults to False):
             Enable the hybird TaylorSeer for hidden_states or not. We have supported the
@@ -91,10 +91,10 @@ def enable_cache(
             Enable the hybird TaylorSeer for encoder_hidden_states or not.
         taylorseer_cache_type (`str`, *required*, defaults to `residual`):
             The TaylorSeer implemented in cache-dit supports both `hidden_states` and `residual` as cache type.
-        taylorseer_order (`int`, *required*, defaults to
+        taylorseer_order (`int`, *required*, defaults to 1):
             The order of taylorseer, higher values of n_derivatives will lead to longer computation time,
-
-
+            the recommended value is 1 or 2.
+        other_cache_context_kwargs: (`dict`, *optional*, defaults to {})
             Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
             for more details.

@@ -123,7 +123,7 @@ def enable_cache(
         max_continuous_cached_steps
     )
     cache_context_kwargs["residual_diff_threshold"] = residual_diff_threshold
-    cache_context_kwargs["enable_spearate_cfg"] = enable_spearate_cfg
+    cache_context_kwargs["enable_separate_cfg"] = enable_separate_cfg
    cache_context_kwargs["cfg_compute_first"] = cfg_compute_first
    cache_context_kwargs["cfg_diff_compute_separate"] = (
        cfg_diff_compute_separate
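Putting the renamed keyword and the new taylorseer_order default together, a hedged usage sketch (assuming enable_cache is re-exported at the cache_dit package level; the model id and positional pipeline argument are placeholders):

    import cache_dit
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("some/dit-pipeline")  # placeholder id

    cache_dit.enable_cache(
        pipe,
        residual_diff_threshold=0.08,
        enable_separate_cfg=True,         # True for Wan 2.1 / Qwen-Image style separate-CFG pipelines
        enable_taylorseer=True,
        taylorseer_cache_type="residual",
        taylorseer_order=1,               # new default in 0.2.34
    )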
@@ -1,4 +1,5 @@
 from cache_dit.cache_factory.patch_functors.functor_base import PatchFunctor
+from cache_dit.cache_factory.patch_functors.functor_dit import DiTPatchFunctor
 from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
 from cache_dit.cache_factory.patch_functors.functor_chroma import (
     ChromaPatchFunctor,
|