cache-dit 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/dual_block_cache/cache_context.py +225 -40
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +1 -1
- {cache_dit-0.2.3.dist-info → cache_dit-0.2.4.dist-info}/METADATA +15 -1
- {cache_dit-0.2.3.dist-info → cache_dit-0.2.4.dist-info}/RECORD +8 -8
- {cache_dit-0.2.3.dist-info → cache_dit-0.2.4.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.3.dist-info → cache_dit-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.3.dist-info → cache_dit-0.2.4.dist-info}/top_level.txt +0 -0
cache_dit/_version.py
CHANGED

cache_dit/cache_factory/dual_block_cache/cache_context.py
CHANGED

@@ -49,18 +49,18 @@ class DBCacheContext:
 
     # Other settings
     downsample_factor: int = 1
-    num_inference_steps: int = -1
+    num_inference_steps: int = -1  # un-used now
     warmup_steps: int = 0  # DON'T Cache in warmup steps
     # DON'T Cache if the number of cached steps >= max_cached_steps
-    max_cached_steps: int = -1
+    max_cached_steps: int = -1  # for both CFG and non-CFG
 
     # Statistics for botch alter cache and non-alter cache
    # Record the steps that have been cached, both alter cache and non-alter cache
-    executed_steps: int = 0  # cache + non-cache steps
-
-
-
-
+    executed_steps: int = 0  # cache + non-cache steps pippeline
+    # steps for transformer, for CFG, transformer_executed_steps will
+    # be double of executed_steps.
+    transformer_executed_steps: int = 0
+
     # Support TaylorSeers in Dual Block Cache
     # Title: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
     # Url: https://arxiv.org/pdf/2503.06923
@@ -71,8 +71,24 @@ class DBCacheContext:
     taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
     taylorseer: Optional[TaylorSeer] = None
     encoder_tarlorseer: Optional[TaylorSeer] = None
-
-
+    # Support do_separate_classifier_free_guidance, such as Wan 2.1
+    # For model that fused CFG and non-CFG into single forward step,
+    # should set do_separate_classifier_free_guidance as False. For
+    # example: CogVideoX
+    do_separate_classifier_free_guidance: bool = False
+    cfg_compute_first: bool = False
+    cfg_taylorseer: Optional[TaylorSeer] = None
+    cfg_encoder_taylorseer: Optional[TaylorSeer] = None
+
+    # CFG & non-CFG cached steps
+    cached_steps: List[int] = dataclasses.field(default_factory=list)
+    residual_diffs: DefaultDict[str, float] = dataclasses.field(
+        default_factory=lambda: defaultdict(float),
+    )
+    cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
+    cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
+        default_factory=lambda: defaultdict(float),
+    )
 
     # TODO: Support SLG in Dual Block Cache
     # Skip Layer Guidance, SLG
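
For orientation, the new fields keep two parallel sets of books, one per guidance branch. A minimal sketch of how a caller might configure them (the field names come straight from the hunk above; constructing `DBCacheContext` directly and the import path are illustrative assumptions, since the adapters normally build it internally):

```python
# Illustrative sketch only: field names are from the diff above, but the
# direct construction and import path are assumptions, not the public API.
from cache_dit.cache_factory.dual_block_cache.cache_context import DBCacheContext

ctx = DBCacheContext(
    do_separate_classifier_free_guidance=True,  # e.g. Wan 2.1-style pipelines
    cfg_compute_first=False,  # the non-CFG forward runs first in each step
)
# Per-branch statistics now live in parallel containers:
#   ctx.cached_steps / ctx.residual_diffs          -> non-CFG branch
#   ctx.cfg_cached_steps / ctx.cfg_residual_diffs  -> CFG branch
```
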
@@ -83,6 +99,11 @@ class DBCacheContext:
 
     @torch.compiler.disable
     def __post_init__(self):
+        if self.do_separate_classifier_free_guidance:
+            assert self.enable_alter_cache is False, (
+                "enable_alter_cache must set as False if "
+                "do_separate_classifier_free_guidance is enabled."
+            )
 
         if "warmup_steps" not in self.taylorseer_kwargs:
             # If warmup_steps is not set in taylorseer_kwargs,
@@ -99,13 +120,13 @@ class DBCacheContext:
 
         if self.enable_taylorseer:
             self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
-            if self.
-                self.
+            if self.do_separate_classifier_free_guidance:
+                self.cfg_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
 
         if self.enable_encoder_taylorseer:
             self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
-            if self.
-                self.
+            if self.do_separate_classifier_free_guidance:
+                self.cfg_encoder_taylorseer = TaylorSeer(
                     **self.taylorseer_kwargs
                 )
 
@@ -159,18 +180,27 @@ class DBCacheContext:
 
     @torch.compiler.disable
     def mark_step_begin(self):
-
-
+        # Always increase transformer executed steps
+        # incr step: prev 0 -> 1; prev 1 -> 2
+        # current step: incr step - 1
+        self.transformer_executed_steps += 1
+        if not self.do_separate_classifier_free_guidance:
+            self.executed_steps = self.transformer_executed_steps
         else:
-
+            # 0,1 -> 0, 2,3 -> 1, ...
+            self.executed_steps = self.transformer_executed_steps // 2
+
+        if not self.enable_alter_cache:
             # 0 F 1 T 2 F 3 T 4 F 5 T ...
             self.is_alter_cache = not self.is_alter_cache
 
         # Reset the cached steps and residual diffs at the beginning
         # of each inference.
-        if self.
+        if self.get_current_transformer_step() == 0:
             self.cached_steps.clear()
             self.residual_diffs.clear()
+            self.cfg_cached_steps.clear()
+            self.cfg_residual_diffs.clear()
             self.reset_incremental_names()
             # Reset the TaylorSeers cache at the beginning of each inference.
             # reset_cache will set the current step to -1 for TaylorSeer,
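
The counting convention above is easy to verify in isolation: with separate CFG, every pipeline step issues two transformer forwards, so `executed_steps` is the forward count halved. A standalone sketch of that arithmetic (this mirrors `mark_step_begin`, but is not the library's own class):

```python
# Standalone sketch of the step bookkeeping above; mirrors the
# mark_step_begin() arithmetic but is not the library's own class.
class StepCounter:
    def __init__(self, separate_cfg: bool):
        self.separate_cfg = separate_cfg
        self.transformer_executed_steps = 0  # one per transformer forward
        self.executed_steps = 0              # one per pipeline step

    def mark_step_begin(self):
        self.transformer_executed_steps += 1
        if not self.separate_cfg:
            self.executed_steps = self.transformer_executed_steps
        else:
            # two forwards (non-CFG + CFG) per pipeline step: 0,1 -> 0; 2,3 -> 1
            self.executed_steps = self.transformer_executed_steps // 2

counter = StepCounter(separate_cfg=True)
for _ in range(4):  # 4 transformer forwards == 2 pipeline steps
    counter.mark_step_begin()
assert counter.transformer_executed_steps == 4
assert counter.executed_steps == 2
```
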
@@ -180,44 +210,99 @@ class DBCacheContext:
                 taylorseer.reset_cache()
             if encoder_taylorseer is not None:
                 encoder_taylorseer.reset_cache()
+            cfg_taylorseer, cfg_encoder_taylorseer = (
+                self.get_cfg_taylorseers()
+            )
+            if cfg_taylorseer is not None:
+                cfg_taylorseer.reset_cache()
+            if cfg_encoder_taylorseer is not None:
+                cfg_encoder_taylorseer.reset_cache()
 
         # mark_step_begin of TaylorSeer must be called after the cache is reset.
         if self.enable_taylorseer or self.enable_encoder_taylorseer:
-
-
-
-
-
+            if self.do_separate_classifier_free_guidance:
+                # Assume non-CFG steps: 0, 2, 4, 6, ...
+                if not self.is_separate_classifier_free_guidance_step():
+                    taylorseer, encoder_taylorseer = self.get_taylorseers()
+                    if taylorseer is not None:
+                        taylorseer.mark_step_begin()
+                    if encoder_taylorseer is not None:
+                        encoder_taylorseer.mark_step_begin()
+                else:
+                    cfg_taylorseer, cfg_encoder_taylorseer = (
+                        self.get_cfg_taylorseers()
+                    )
+                    if cfg_taylorseer is not None:
+                        cfg_taylorseer.mark_step_begin()
+                    if cfg_encoder_taylorseer is not None:
+                        cfg_encoder_taylorseer.mark_step_begin()
+            else:
+                taylorseer, encoder_taylorseer = self.get_taylorseers()
+                if taylorseer is not None:
+                    taylorseer.mark_step_begin()
+                if encoder_taylorseer is not None:
+                    encoder_taylorseer.mark_step_begin()
 
     @torch.compiler.disable
     def get_taylorseers(self):
-        if self.enable_alter_cache and self.is_alter_cache:
-            return self.alter_taylorseer, self.alter_encoder_taylorseer
         return self.taylorseer, self.encoder_tarlorseer
 
+    @torch.compiler.disable
+    def get_cfg_taylorseers(self):
+        return self.cfg_taylorseer, self.cfg_encoder_taylorseer
+
     @torch.compiler.disable
     def add_residual_diff(self, diff):
         step = str(self.get_current_step())
-        if
-
-        self.residual_diffs
+        # Only add the diff if it is not already recorded for this step
+        if not self.is_separate_classifier_free_guidance_step():
+            if step not in self.residual_diffs:
+                self.residual_diffs[step] = diff
+        else:
+            if step not in self.cfg_residual_diffs:
+                self.cfg_residual_diffs[step] = diff
 
     @torch.compiler.disable
     def get_residual_diffs(self):
         return self.residual_diffs.copy()
 
+    @torch.compiler.disable
+    def get_cfg_residual_diffs(self):
+        return self.cfg_residual_diffs.copy()
+
     @torch.compiler.disable
     def add_cached_step(self):
-        self.
+        if not self.is_separate_classifier_free_guidance_step():
+            self.cached_steps.append(self.get_current_step())
+        else:
+            self.cfg_cached_steps.append(self.get_current_step())
 
     @torch.compiler.disable
     def get_cached_steps(self):
         return self.cached_steps.copy()
 
+    @torch.compiler.disable
+    def get_cfg_cached_steps(self):
+        return self.cfg_cached_steps.copy()
+
     @torch.compiler.disable
     def get_current_step(self):
         return self.executed_steps - 1
 
+    @torch.compiler.disable
+    def get_current_transformer_step(self):
+        return self.transformer_executed_steps - 1
+
+    @torch.compiler.disable
+    def is_separate_classifier_free_guidance_step(self):
+        if not self.do_separate_classifier_free_guidance:
+            return False
+        if self.cfg_compute_first:
+            # CFG steps: 0, 2, 4, 6, ...
+            return self.get_current_transformer_step() % 2
+        # CFG steps: 1, 3, 5, 7, ...
+        return not self.get_current_transformer_step() % 2
+
     @torch.compiler.disable
     def is_in_warmup(self):
         return self.get_current_step() < self.warmup_steps
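
Per the comments in `is_separate_classifier_free_guidance_step`, the CFG forward is identified purely by transformer-step parity: even steps when `cfg_compute_first` is set, odd steps otherwise. A standalone sketch of that convention as the comments describe it (written with explicit comparisons for clarity; this is not the library's exact return expression):

```python
def is_cfg_step(transformer_step: int, cfg_compute_first: bool) -> bool:
    # Sketch of the parity convention described by the comments above,
    # with explicit comparisons; not the library's exact expression.
    if cfg_compute_first:
        # CFG steps: 0, 2, 4, 6, ...
        return transformer_step % 2 == 0
    # CFG steps: 1, 3, 5, 7, ...
    return transformer_step % 2 == 1

assert [is_cfg_step(i, cfg_compute_first=False) for i in range(4)] == [
    False, True, False, True,
]
```
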
@@ -265,6 +350,13 @@ def get_current_step():
     return cache_context.get_current_step()
 
 
+@torch.compiler.disable
+def get_current_transformer_step():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.get_current_transformer_step()
+
+
 @torch.compiler.disable
 def get_cached_steps():
     cache_context = get_current_cache_context()
@@ -272,6 +364,13 @@ def get_cached_steps():
     return cache_context.get_cached_steps()
 
 
+@torch.compiler.disable
+def get_cfg_cached_steps():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.get_cfg_cached_steps()
+
+
 @torch.compiler.disable
 def get_max_cached_steps():
     cache_context = get_current_cache_context()
@@ -300,6 +399,13 @@ def get_residual_diffs():
     return cache_context.get_residual_diffs()
 
 
+@torch.compiler.disable
+def get_cfg_residual_diffs():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.get_cfg_residual_diffs()
+
+
 @torch.compiler.disable
 def is_taylorseer_enabled():
     cache_context = get_current_cache_context()
@@ -321,6 +427,13 @@ def get_taylorseers():
     return cache_context.get_taylorseers()
 
 
+@torch.compiler.disable
+def get_cfg_taylorseers():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.get_cfg_taylorseers()
+
+
 @torch.compiler.disable
 def is_taylorseer_cache_residual():
     cache_context = get_current_cache_context()
@@ -459,6 +572,20 @@ def Bn_compute_blocks_ids():
     return cache_context.Bn_compute_blocks_ids
 
 
+@torch.compiler.disable
+def do_separate_classifier_free_guidance():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.do_separate_classifier_free_guidance
+
+
+@torch.compiler.disable
+def is_separate_classifier_free_guidance_step():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.is_separate_classifier_free_guidance_step()
+
+
 _current_cache_context: DBCacheContext = None
 
 
@@ -609,21 +736,31 @@ def set_Fn_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
     if downsample_factor > 1:
         buffer = buffer[..., ::downsample_factor]
         buffer = buffer.contiguous()
-
+    if is_separate_classifier_free_guidance_step():
+        set_buffer(f"{prefix}_buffer_cfg", buffer)
+    else:
+        set_buffer(f"{prefix}_buffer", buffer)
 
 
 @torch.compiler.disable
 def get_Fn_buffer(prefix: str = "Fn"):
+    if is_separate_classifier_free_guidance_step():
+        return get_buffer(f"{prefix}_buffer_cfg")
     return get_buffer(f"{prefix}_buffer")
 
 
 @torch.compiler.disable
 def set_Fn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
-
+    if is_separate_classifier_free_guidance_step():
+        set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
+    else:
+        set_buffer(f"{prefix}_encoder_buffer", buffer)
 
 
 @torch.compiler.disable
 def get_Fn_encoder_buffer(prefix: str = "Fn"):
+    if is_separate_classifier_free_guidance_step():
+        return get_buffer(f"{prefix}_encoder_buffer_cfg")
     return get_buffer(f"{prefix}_encoder_buffer")
 
 
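
The pattern in this hunk is plain key namespacing: both guidance branches share one buffer store, and a `_cfg` suffix keeps CFG state from clobbering non-CFG state. A minimal standalone sketch, with a dict standing in for the context's buffer store:

```python
# Standalone sketch: a plain dict stands in for the cache context's
# buffer store; the "_cfg" suffix separates the two guidance branches.
buffers: dict = {}

def set_fn_buffer(buffer, is_cfg_step: bool, prefix: str = "Fn") -> None:
    key = f"{prefix}_buffer_cfg" if is_cfg_step else f"{prefix}_buffer"
    buffers[key] = buffer

def get_fn_buffer(is_cfg_step: bool, prefix: str = "Fn"):
    key = f"{prefix}_buffer_cfg" if is_cfg_step else f"{prefix}_buffer"
    return buffers.get(key)

set_fn_buffer("uncond_residual", is_cfg_step=False)
set_fn_buffer("cond_residual", is_cfg_step=True)
assert get_fn_buffer(is_cfg_step=False) == "uncond_residual"
assert get_fn_buffer(is_cfg_step=True) == "cond_residual"
```
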
@@ -634,7 +771,11 @@ def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
     # This buffer is use for hidden states approximation.
     if is_taylorseer_enabled():
         # taylorseer, encoder_taylorseer
-
+        if is_separate_classifier_free_guidance_step():
+            taylorseer, _ = get_cfg_taylorseers()
+        else:
+            taylorseer, _ = get_taylorseers()
+
         if taylorseer is not None:
             # Use TaylorSeer to update the buffer
             taylorseer.update(buffer)
@@ -644,15 +785,26 @@ def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
                 "TaylorSeer is enabled but not set in the cache context. "
                 "Falling back to default buffer retrieval."
             )
-
+            if is_separate_classifier_free_guidance_step():
+                set_buffer(f"{prefix}_buffer_cfg", buffer)
+            else:
+                set_buffer(f"{prefix}_buffer", buffer)
     else:
-
+        if is_separate_classifier_free_guidance_step():
+            set_buffer(f"{prefix}_buffer_cfg", buffer)
+        else:
+            set_buffer(f"{prefix}_buffer", buffer)
 
 
 @torch.compiler.disable
 def get_Bn_buffer(prefix: str = "Bn"):
     if is_taylorseer_enabled():
-        taylorseer,
+        # taylorseer, encoder_taylorseer
+        if is_separate_classifier_free_guidance_step():
+            taylorseer, _ = get_cfg_taylorseers()
+        else:
+            taylorseer, _ = get_taylorseers()
+
         if taylorseer is not None:
             return taylorseer.approximate_value()
         else:
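
Every TaylorSeer call site in these hunks now goes through the same two-line choice: the CFG pair on CFG steps, the plain pair otherwise. A condensed sketch of that selection (the accessor names come from this diff; the helper itself is illustrative):

```python
def select_taylorseer(ctx, encoder: bool = False):
    # Illustrative helper: picks the branch-appropriate (encoder) TaylorSeer
    # using the accessors added in this diff.
    if ctx.is_separate_classifier_free_guidance_step():
        taylorseer, encoder_taylorseer = ctx.get_cfg_taylorseers()
    else:
        taylorseer, encoder_taylorseer = ctx.get_taylorseers()
    return encoder_taylorseer if encoder else taylorseer
```
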
@@ -662,8 +814,12 @@ def get_Bn_buffer(prefix: str = "Bn"):
                 "Falling back to default buffer retrieval."
             )
             # Fallback to default buffer retrieval
+            if is_separate_classifier_free_guidance_step():
+                return get_buffer(f"{prefix}_buffer_cfg")
             return get_buffer(f"{prefix}_buffer")
     else:
+        if is_separate_classifier_free_guidance_step():
+            return get_buffer(f"{prefix}_buffer_cfg")
         return get_buffer(f"{prefix}_buffer")
 
 
@@ -672,7 +828,11 @@ def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
     # This buffer is use for encoder hidden states approximation.
     if is_encoder_taylorseer_enabled():
         # taylorseer, encoder_taylorseer
-
+        if is_separate_classifier_free_guidance_step():
+            _, encoder_taylorseer = get_cfg_taylorseers()
+        else:
+            _, encoder_taylorseer = get_taylorseers()
+
         if encoder_taylorseer is not None:
             # Use TaylorSeer to update the buffer
             encoder_taylorseer.update(buffer)
@@ -682,15 +842,25 @@ def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
                 "TaylorSeer is enabled but not set in the cache context. "
                 "Falling back to default buffer retrieval."
             )
-
+            if is_separate_classifier_free_guidance_step():
+                set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
+            else:
+                set_buffer(f"{prefix}_encoder_buffer", buffer)
     else:
-
+        if is_separate_classifier_free_guidance_step():
+            set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
+        else:
+            set_buffer(f"{prefix}_encoder_buffer", buffer)
 
 
 @torch.compiler.disable
 def get_Bn_encoder_buffer(prefix: str = "Bn"):
     if is_encoder_taylorseer_enabled():
-
+        if is_separate_classifier_free_guidance_step():
+            _, encoder_taylorseer = get_cfg_taylorseers()
+        else:
+            _, encoder_taylorseer = get_taylorseers()
+
         if encoder_taylorseer is not None:
             # Use TaylorSeer to approximate the value
             return encoder_taylorseer.approximate_value()
@@ -701,8 +871,12 @@ def get_Bn_encoder_buffer(prefix: str = "Bn"):
                 "Falling back to default buffer retrieval."
             )
             # Fallback to default buffer retrieval
+            if is_separate_classifier_free_guidance_step():
+                return get_buffer(f"{prefix}_encoder_buffer_cfg")
             return get_buffer(f"{prefix}_encoder_buffer")
     else:
+        if is_separate_classifier_free_guidance_step():
+            return get_buffer(f"{prefix}_encoder_buffer_cfg")
         return get_buffer(f"{prefix}_encoder_buffer")
 
 
@@ -766,8 +940,13 @@ def get_can_use_cache(
 ):
     if is_in_warmup():
         return False
-
+
     max_cached_steps = get_max_cached_steps()
+    if not is_separate_classifier_free_guidance_step():
+        cached_steps = get_cached_steps()
+    else:
+        cached_steps = get_cfg_cached_steps()
+
     if max_cached_steps >= 0 and (len(cached_steps) >= max_cached_steps):
         if logger.isEnabledFor(logging.DEBUG):
             logger.debug(
@@ -775,10 +954,12 @@ def get_can_use_cache(
                 "cannot use cache."
             )
         return False
+
     if threshold is None or threshold <= 0.0:
         threshold = get_residual_diff_threshold()
     if threshold <= 0.0:
         return False
+
     downsample_factor = get_downsample_factor()
     if downsample_factor > 1 and "Bn" not in prefix:
         states_tensor = states_tensor[..., ::downsample_factor]
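
Taken together, the two hunks make the cache gate consult the branch-appropriate history before the residual check. A condensed, standalone sketch of that gating (inputs are passed explicitly here; the residual-diff threshold check that follows is elided):

```python
def can_use_cache_gate(
    in_warmup: bool,
    max_cached_steps: int,
    cached_steps: list,  # the per-branch history: CFG or non-CFG
) -> bool:
    # Standalone sketch of the gating added in this diff; the residual-diff
    # threshold comparison that follows in the real function is elided.
    if in_warmup:
        return False
    if max_cached_steps >= 0 and len(cached_steps) >= max_cached_steps:
        return False  # this branch has used up its cached-step budget
    return True  # proceed to the residual-diff threshold check

assert can_use_cache_gate(False, max_cached_steps=2, cached_steps=[0, 1]) is False
assert can_use_cache_gate(False, max_cached_steps=-1, cached_steps=[0, 1]) is True
```
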
@@ -982,7 +1163,9 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         # Check if the current step is in cache steps.
         # If so, we can skip some Bn blocks and directly
         # use the cached values.
-        return get_current_step() in get_cached_steps()
+        return (get_current_step() in get_cached_steps()) or (
+            get_current_step() in get_cfg_cached_steps()
+        )
 
     @torch.compiler.disable
     def _Fn_transformer_blocks(self):
@@ -1601,3 +1784,5 @@ def patch_cached_stats(
     # TODO: Patch more cached stats to the transformer
     transformer._cached_steps = get_cached_steps()
     transformer._residual_diffs = get_residual_diffs()
+    transformer._cfg_cached_steps = get_cfg_cached_steps()
+    transformer._cfg_residual_diffs = get_cfg_residual_diffs()
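
With the two extra lines in `patch_cached_stats`, per-branch statistics can be read straight off the patched transformer after a run. A small usage sketch (the attribute names are exactly those set above; `transformer` is any module that has been through `patch_cached_stats`):

```python
def report_cached_stats(transformer) -> None:
    # Attribute names are those patched on by patch_cached_stats above.
    print("non-CFG cached steps:", transformer._cached_steps)
    print("CFG cached steps:", transformer._cfg_cached_steps)
    print("non-CFG residual diffs:", transformer._residual_diffs)
    print("CFG residual diffs:", transformer._cfg_residual_diffs)
```
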
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py
CHANGED

@@ -70,7 +70,7 @@ def apply_db_cache_on_pipe(
             # "slg_layers": slg_layers,
             # "slg_start": slg_start,
             # "slg_end": slg_end,
-            "num_inference_steps": kwargs.get("num_inference_steps", 50),
+            # "num_inference_steps": kwargs.get("num_inference_steps", 50),
             "warmup_steps": warmup_steps,
             "max_cached_steps": max_cached_steps,
         },
{cache_dit-0.2.3.dist-info → cache_dit-0.2.4.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.3
+Version: 0.2.4
 Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -154,6 +154,7 @@ The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi
 - [🔥Supported Models](#supported)
 - [⚡️Dual Block Cache](#dbcache)
 - [🔥Hybrid TaylorSeer](#taylorseer)
+- [⚡️Hybrid Cache CFG](#cfg)
 - [🎉First Block Cache](#fbcache)
 - [⚡️Dynamic Block Prune](#dbprune)
 - [🎉Context Parallelism](#context-parallelism)
@@ -299,6 +300,19 @@ cache_options = {
 |24.85s|12.85s|12.86s|10.27s|10.28s|8.48s|
 |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.12_S14_T12.85s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.12_S14_T12.86s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.15_S17_T10.27s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T10.28s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T8.48s.png width=105px>|
 
+## ⚡️Hybrid Cache CFG
+
+<div id="cfg"></div>
+
+CacheDiT supports caching for CFG (classifier-free guidance). For models that fuse CFG and non-CFG into a single forward step, or models that do not include CFG (classifier-free guidance) in the forward step, please set `do_separate_classifier_free_guidance` param to False. Otherwise, set it to True. Wan 2.1: True. FLUX.1, HunyunVideo, CogVideoX, Mochi: False.
+
+```python
+cache_options = {
+    "do_separate_classifier_free_guidance": True,  # Wan 2.1
+    "cfg_compute_first": False,
+}
+```
+
 ## 🎉FBCache: First Block Cache
 
 <div id="fbcache"></div>
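
For context, these options slot into the same cache-options dict the README's other examples pass to the adapter. A hedged end-to-end sketch (`apply_cache_on_pipe` and `CacheType` are the entry points the README uses elsewhere; the Wan model id and dtype are illustrative):

```python
# Hedged end-to-end sketch: apply_cache_on_pipe / CacheType are the entry
# points used elsewhere in the README; the model id is illustrative.
import torch
from diffusers import WanPipeline
from cache_dit.cache_factory import CacheType, apply_cache_on_pipe

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
)
apply_cache_on_pipe(
    pipe,
    cache_type=CacheType.DBCache,
    # Wan 2.1 runs CFG as a separate transformer forward:
    do_separate_classifier_free_guidance=True,
    cfg_compute_first=False,
)
```
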
{cache_dit-0.2.3.dist-info → cache_dit-0.2.4.dist-info}/RECORD
CHANGED

@@ -1,18 +1,18 @@
 cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/_version.py,sha256=
+cache_dit/_version.py,sha256=1LUN_sRKOiFInoB6AlW6TYoQMCh1Z4KutwcHNvHcfB0,511
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
 cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
 cache_dit/cache_factory/__init__.py,sha256=5RNuhWakvvqrOV4vkqrEBA7d-V1LwcNSsjtW14mkqK8,5255
 cache_dit/cache_factory/taylorseer.py,sha256=LKSNo2ode69EVo9xrxjxAMEjz0yDGiGADeDYnEqddA8,3987
 cache_dit/cache_factory/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=
+cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=H7u5zIAdEjiYU0QvWYIMj3lKYI4D8cmDLy7eZ9tyoyU,66848
 cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=krNAICf-aS3JLmSG8vOB9tpLa04uYRcABsC8PMbVUKY,1870
 cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=fibkeU-FHa30BNT-uPV2Eqcd5IRli07EKb25tMDp23c,2270
 cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=fddSpTHXU24COMGAY-Z21EmHHAEArZBv_-XLRFD6ADU,2625
 cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha256=wcZdBhjUB8WSfz40A268BtSe3nr_hRsIi2BNlg1FHRU,9965
 cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=Cmy0KHRDgwXqtmqfkrr7kw0CP6CmkSnuz29gDHcD6sQ,2262
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=EREHM5E1wxnL-uRXRAEege4HXraRp1oD_r1Zx4CsiKk,2596
 cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=so1wGdb8W0ATwrjv7E5IEZLPcobybaY1HJa6hBYlOOQ,34698
 cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=hVBTXj9MMGFGVezT3j8MntFRBiphSaUL4YhSOd8JtuY,1870
@@ -33,8 +33,8 @@ cache_dit/compile/__init__.py,sha256=DfMdPleFFGADXLsr7zXui8BTz_y9futY6rNmNdh9y7k
 cache_dit/compile/utils.py,sha256=KU60xc474Anbj7Y_FLRFmNxEjVYLLXkhbtCLXO7o_Tc,3699
 cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
+cache_dit-0.2.4.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.4.dist-info/METADATA,sha256=1oDgkkUwGVfwX_jCyU0jHbQTVQDfL59OEbrUb_9SVF4,25442
+cache_dit-0.2.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.4.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.4.dist-info/RECORD,,
{cache_dit-0.2.3.dist-info → cache_dit-0.2.4.dist-info}/WHEEL
File without changes

{cache_dit-0.2.3.dist-info → cache_dit-0.2.4.dist-info}/licenses/LICENSE
File without changes

{cache_dit-0.2.3.dist-info → cache_dit-0.2.4.dist-info}/top_level.txt
File without changes