cache-dit 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.3'
21
- __version_tuple__ = version_tuple = (0, 2, 3)
20
+ __version__ = version = '0.2.5'
21
+ __version_tuple__ = version_tuple = (0, 2, 5)
@@ -49,18 +49,18 @@ class DBCacheContext:
49
49
 
50
50
  # Other settings
51
51
  downsample_factor: int = 1
52
- num_inference_steps: int = -1
52
+ num_inference_steps: int = -1 # un-used now
53
53
  warmup_steps: int = 0 # DON'T Cache in warmup steps
54
54
  # DON'T Cache if the number of cached steps >= max_cached_steps
55
- max_cached_steps: int = -1
55
+ max_cached_steps: int = -1 # for both CFG and non-CFG
56
56
 
57
57
  # Statistics for botch alter cache and non-alter cache
58
58
  # Record the steps that have been cached, both alter cache and non-alter cache
59
- executed_steps: int = 0 # cache + non-cache steps
60
- cached_steps: List[int] = dataclasses.field(default_factory=list)
61
- residual_diffs: DefaultDict[str, float] = dataclasses.field(
62
- default_factory=lambda: defaultdict(float),
63
- )
59
+ executed_steps: int = 0 # cache + non-cache steps pippeline
60
+ # steps for transformer, for CFG, transformer_executed_steps will
61
+ # be double of executed_steps.
62
+ transformer_executed_steps: int = 0
63
+
64
64
  # Support TaylorSeers in Dual Block Cache
65
65
  # Title: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
66
66
  # Url: https://arxiv.org/pdf/2503.06923
@@ -71,8 +71,31 @@ class DBCacheContext:
71
71
  taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
72
72
  taylorseer: Optional[TaylorSeer] = None
73
73
  encoder_tarlorseer: Optional[TaylorSeer] = None
74
- alter_taylorseer: Optional[TaylorSeer] = None
75
- alter_encoder_taylorseer: Optional[TaylorSeer] = None
74
+
75
+ # Support do_separate_classifier_free_guidance, such as Wan 2.1
76
+ # For model that fused CFG and non-CFG into single forward step,
77
+ # should set do_separate_classifier_free_guidance as False. For
78
+ # example: CogVideoX, HunyuanVideo, Mochi.
79
+ do_separate_classifier_free_guidance: bool = False
80
+ # Compute cfg forward first or not, default False, namely,
81
+ # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
82
+ cfg_compute_first: bool = False
83
+ # Compute spearate diff values for CFG and non-CFG step,
84
+ # default True. If False, we will use the computed diff from
85
+ # current non-CFG transformer step for current CFG step.
86
+ cfg_diff_compute_separate: bool = True
87
+ cfg_taylorseer: Optional[TaylorSeer] = None
88
+ cfg_encoder_taylorseer: Optional[TaylorSeer] = None
89
+
90
+ # CFG & non-CFG cached steps
91
+ cached_steps: List[int] = dataclasses.field(default_factory=list)
92
+ residual_diffs: DefaultDict[str, float] = dataclasses.field(
93
+ default_factory=lambda: defaultdict(float),
94
+ )
95
+ cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
96
+ cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
97
+ default_factory=lambda: defaultdict(float),
98
+ )
76
99
 
77
100
  # TODO: Support SLG in Dual Block Cache
78
101
  # Skip Layer Guidance, SLG
@@ -83,6 +106,17 @@ class DBCacheContext:
83
106
 
84
107
  @torch.compiler.disable
85
108
  def __post_init__(self):
109
+ # Some checks for settings
110
+ if self.do_separate_classifier_free_guidance:
111
+ assert self.enable_alter_cache is False, (
112
+ "enable_alter_cache must set as False if "
113
+ "do_separate_classifier_free_guidance is enabled."
114
+ )
115
+ if self.cfg_diff_compute_separate:
116
+ assert self.cfg_compute_first is False, (
117
+ "cfg_compute_first must set as False if "
118
+ "cfg_diff_compute_separate is enabled."
119
+ )
86
120
 
87
121
  if "warmup_steps" not in self.taylorseer_kwargs:
88
122
  # If warmup_steps is not set in taylorseer_kwargs,
@@ -99,13 +133,13 @@ class DBCacheContext:
99
133
 
100
134
  if self.enable_taylorseer:
101
135
  self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
102
- if self.enable_alter_cache:
103
- self.alter_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
136
+ if self.do_separate_classifier_free_guidance:
137
+ self.cfg_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
104
138
 
105
139
  if self.enable_encoder_taylorseer:
106
140
  self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
107
- if self.enable_alter_cache:
108
- self.alter_encoder_taylorseer = TaylorSeer(
141
+ if self.do_separate_classifier_free_guidance:
142
+ self.cfg_encoder_taylorseer = TaylorSeer(
109
143
  **self.taylorseer_kwargs
110
144
  )
111
145
 
@@ -159,18 +193,34 @@ class DBCacheContext:
159
193
 
160
194
  @torch.compiler.disable
161
195
  def mark_step_begin(self):
162
- if not self.enable_alter_cache:
196
+ # Always increase transformer executed steps
197
+ # incr step: prev 0 -> 1; prev 1 -> 2
198
+ # current step: incr step - 1
199
+ self.transformer_executed_steps += 1
200
+ if not self.do_separate_classifier_free_guidance:
163
201
  self.executed_steps += 1
164
202
  else:
165
- self.executed_steps += 1
203
+ # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
204
+ if not self.cfg_compute_first:
205
+ if not self.is_separate_classifier_free_guidance_step():
206
+ # transformer step: 0,2,4,...
207
+ self.executed_steps += 1
208
+ else:
209
+ if self.is_separate_classifier_free_guidance_step():
210
+ # transformer step: 0,2,4,...
211
+ self.executed_steps += 1
212
+
213
+ if not self.enable_alter_cache:
166
214
  # 0 F 1 T 2 F 3 T 4 F 5 T ...
167
215
  self.is_alter_cache = not self.is_alter_cache
168
216
 
169
217
  # Reset the cached steps and residual diffs at the beginning
170
218
  # of each inference.
171
- if self.get_current_step() == 0:
219
+ if self.get_current_transformer_step() == 0:
172
220
  self.cached_steps.clear()
173
221
  self.residual_diffs.clear()
222
+ self.cfg_cached_steps.clear()
223
+ self.cfg_residual_diffs.clear()
174
224
  self.reset_incremental_names()
175
225
  # Reset the TaylorSeers cache at the beginning of each inference.
176
226
  # reset_cache will set the current step to -1 for TaylorSeer,
@@ -180,44 +230,100 @@ class DBCacheContext:
180
230
  taylorseer.reset_cache()
181
231
  if encoder_taylorseer is not None:
182
232
  encoder_taylorseer.reset_cache()
233
+ cfg_taylorseer, cfg_encoder_taylorseer = (
234
+ self.get_cfg_taylorseers()
235
+ )
236
+ if cfg_taylorseer is not None:
237
+ cfg_taylorseer.reset_cache()
238
+ if cfg_encoder_taylorseer is not None:
239
+ cfg_encoder_taylorseer.reset_cache()
183
240
 
184
241
  # mark_step_begin of TaylorSeer must be called after the cache is reset.
185
242
  if self.enable_taylorseer or self.enable_encoder_taylorseer:
186
- taylorseer, encoder_taylorseer = self.get_taylorseers()
187
- if taylorseer is not None:
188
- taylorseer.mark_step_begin()
189
- if encoder_taylorseer is not None:
190
- encoder_taylorseer.mark_step_begin()
243
+ if self.do_separate_classifier_free_guidance:
244
+ # Assume non-CFG steps: 0, 2, 4, 6, ...
245
+ if not self.is_separate_classifier_free_guidance_step():
246
+ taylorseer, encoder_taylorseer = self.get_taylorseers()
247
+ if taylorseer is not None:
248
+ taylorseer.mark_step_begin()
249
+ if encoder_taylorseer is not None:
250
+ encoder_taylorseer.mark_step_begin()
251
+ else:
252
+ cfg_taylorseer, cfg_encoder_taylorseer = (
253
+ self.get_cfg_taylorseers()
254
+ )
255
+ if cfg_taylorseer is not None:
256
+ cfg_taylorseer.mark_step_begin()
257
+ if cfg_encoder_taylorseer is not None:
258
+ cfg_encoder_taylorseer.mark_step_begin()
259
+ else:
260
+ taylorseer, encoder_taylorseer = self.get_taylorseers()
261
+ if taylorseer is not None:
262
+ taylorseer.mark_step_begin()
263
+ if encoder_taylorseer is not None:
264
+ encoder_taylorseer.mark_step_begin()
191
265
 
192
266
  @torch.compiler.disable
193
267
  def get_taylorseers(self):
194
- if self.enable_alter_cache and self.is_alter_cache:
195
- return self.alter_taylorseer, self.alter_encoder_taylorseer
196
268
  return self.taylorseer, self.encoder_tarlorseer
197
269
 
270
+ @torch.compiler.disable
271
+ def get_cfg_taylorseers(self):
272
+ return self.cfg_taylorseer, self.cfg_encoder_taylorseer
273
+
198
274
  @torch.compiler.disable
199
275
  def add_residual_diff(self, diff):
276
+ # step: executed_steps - 1, not transformer_steps - 1
200
277
  step = str(self.get_current_step())
201
- if step not in self.residual_diffs:
202
- # Only add the diff if it is not already recorded for this step
203
- self.residual_diffs[step] = diff
278
+ # Only add the diff if it is not already recorded for this step
279
+ if not self.is_separate_classifier_free_guidance_step():
280
+ if step not in self.residual_diffs:
281
+ self.residual_diffs[step] = diff
282
+ else:
283
+ if step not in self.cfg_residual_diffs:
284
+ self.cfg_residual_diffs[step] = diff
204
285
 
205
286
  @torch.compiler.disable
206
287
  def get_residual_diffs(self):
207
288
  return self.residual_diffs.copy()
208
289
 
290
+ @torch.compiler.disable
291
+ def get_cfg_residual_diffs(self):
292
+ return self.cfg_residual_diffs.copy()
293
+
209
294
  @torch.compiler.disable
210
295
  def add_cached_step(self):
211
- self.cached_steps.append(self.get_current_step())
296
+ if not self.is_separate_classifier_free_guidance_step():
297
+ self.cached_steps.append(self.get_current_step())
298
+ else:
299
+ self.cfg_cached_steps.append(self.get_current_step())
212
300
 
213
301
  @torch.compiler.disable
214
302
  def get_cached_steps(self):
215
303
  return self.cached_steps.copy()
216
304
 
305
+ @torch.compiler.disable
306
+ def get_cfg_cached_steps(self):
307
+ return self.cfg_cached_steps.copy()
308
+
217
309
  @torch.compiler.disable
218
310
  def get_current_step(self):
219
311
  return self.executed_steps - 1
220
312
 
313
+ @torch.compiler.disable
314
+ def get_current_transformer_step(self):
315
+ return self.transformer_executed_steps - 1
316
+
317
+ @torch.compiler.disable
318
+ def is_separate_classifier_free_guidance_step(self):
319
+ if not self.do_separate_classifier_free_guidance:
320
+ return False
321
+ if self.cfg_compute_first:
322
+ # CFG steps: 0, 2, 4, 6, ...
323
+ return self.get_current_transformer_step() % 2 == 0
324
+ # CFG steps: 1, 3, 5, 7, ...
325
+ return self.get_current_transformer_step() % 2 != 0
326
+
221
327
  @torch.compiler.disable
222
328
  def is_in_warmup(self):
223
329
  return self.get_current_step() < self.warmup_steps
@@ -265,6 +371,35 @@ def get_current_step():
265
371
  return cache_context.get_current_step()
266
372
 
267
373
 
374
+ @torch.compiler.disable
375
+ def get_current_step_residual_diff():
376
+ cache_context = get_current_cache_context()
377
+ assert cache_context is not None, "cache_context must be set before"
378
+ step = str(get_current_step())
379
+ residual_diffs = get_residual_diffs()
380
+ if step in residual_diffs:
381
+ return residual_diffs[step]
382
+ return None
383
+
384
+
385
+ @torch.compiler.disable
386
+ def get_current_step_cfg_residual_diff():
387
+ cache_context = get_current_cache_context()
388
+ assert cache_context is not None, "cache_context must be set before"
389
+ step = str(get_current_step())
390
+ cfg_residual_diffs = get_cfg_residual_diffs()
391
+ if step in cfg_residual_diffs:
392
+ return cfg_residual_diffs[step]
393
+ return None
394
+
395
+
396
+ @torch.compiler.disable
397
+ def get_current_transformer_step():
398
+ cache_context = get_current_cache_context()
399
+ assert cache_context is not None, "cache_context must be set before"
400
+ return cache_context.get_current_transformer_step()
401
+
402
+
268
403
  @torch.compiler.disable
269
404
  def get_cached_steps():
270
405
  cache_context = get_current_cache_context()
@@ -272,6 +407,13 @@ def get_cached_steps():
272
407
  return cache_context.get_cached_steps()
273
408
 
274
409
 
410
+ @torch.compiler.disable
411
+ def get_cfg_cached_steps():
412
+ cache_context = get_current_cache_context()
413
+ assert cache_context is not None, "cache_context must be set before"
414
+ return cache_context.get_cfg_cached_steps()
415
+
416
+
275
417
  @torch.compiler.disable
276
418
  def get_max_cached_steps():
277
419
  cache_context = get_current_cache_context()
@@ -300,6 +442,13 @@ def get_residual_diffs():
300
442
  return cache_context.get_residual_diffs()
301
443
 
302
444
 
445
+ @torch.compiler.disable
446
+ def get_cfg_residual_diffs():
447
+ cache_context = get_current_cache_context()
448
+ assert cache_context is not None, "cache_context must be set before"
449
+ return cache_context.get_cfg_residual_diffs()
450
+
451
+
303
452
  @torch.compiler.disable
304
453
  def is_taylorseer_enabled():
305
454
  cache_context = get_current_cache_context()
@@ -321,6 +470,13 @@ def get_taylorseers():
321
470
  return cache_context.get_taylorseers()
322
471
 
323
472
 
473
+ @torch.compiler.disable
474
+ def get_cfg_taylorseers():
475
+ cache_context = get_current_cache_context()
476
+ assert cache_context is not None, "cache_context must be set before"
477
+ return cache_context.get_cfg_taylorseers()
478
+
479
+
324
480
  @torch.compiler.disable
325
481
  def is_taylorseer_cache_residual():
326
482
  cache_context = get_current_cache_context()
@@ -459,6 +615,27 @@ def Bn_compute_blocks_ids():
459
615
  return cache_context.Bn_compute_blocks_ids
460
616
 
461
617
 
618
+ @torch.compiler.disable
619
+ def do_separate_classifier_free_guidance():
620
+ cache_context = get_current_cache_context()
621
+ assert cache_context is not None, "cache_context must be set before"
622
+ return cache_context.do_separate_classifier_free_guidance
623
+
624
+
625
+ @torch.compiler.disable
626
+ def is_separate_classifier_free_guidance_step():
627
+ cache_context = get_current_cache_context()
628
+ assert cache_context is not None, "cache_context must be set before"
629
+ return cache_context.is_separate_classifier_free_guidance_step()
630
+
631
+
632
+ @torch.compiler.disable
633
+ def cfg_diff_compute_separate():
634
+ cache_context = get_current_cache_context()
635
+ assert cache_context is not None, "cache_context must be set before"
636
+ return cache_context.cfg_diff_compute_separate
637
+
638
+
462
639
  _current_cache_context: DBCacheContext = None
463
640
 
464
641
 
@@ -559,38 +736,49 @@ def are_two_tensors_similar(
559
736
  add_residual_diff(-2.0)
560
737
  return False
561
738
 
562
- # Find the most significant token through t1 and t2, and
563
- # consider the diff of the significant token. The more significant,
564
- # the more important.
565
- condition_thresh = get_important_condition_threshold()
566
- if condition_thresh > 0.0:
567
- raw_diff = (t1 - t2).abs() # [B, seq_len, d]
568
- token_m_df = raw_diff.mean(dim=-1) # [B, seq_len]
569
- token_m_t1 = t1.abs().mean(dim=-1) # [B, seq_len]
570
- # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
571
- token_diff = token_m_df / token_m_t1 # [B, seq_len]
572
- condition = token_diff > condition_thresh # [B, seq_len]
573
- if condition.sum() > 0:
574
- condition = condition.unsqueeze(-1) # [B, seq_len, 1]
575
- condition = condition.expand_as(raw_diff) # [B, seq_len, d]
576
- mean_diff = raw_diff[condition].mean()
577
- mean_t1 = t1[condition].abs().mean()
739
+ if all(
740
+ (
741
+ do_separate_classifier_free_guidance(),
742
+ is_separate_classifier_free_guidance_step(),
743
+ not cfg_diff_compute_separate(),
744
+ get_current_step_residual_diff() is not None,
745
+ )
746
+ ):
747
+ # Reuse computed diff value from non-CFG step
748
+ diff = get_current_step_residual_diff()
749
+ else:
750
+ # Find the most significant token through t1 and t2, and
751
+ # consider the diff of the significant token. The more significant,
752
+ # the more important.
753
+ condition_thresh = get_important_condition_threshold()
754
+ if condition_thresh > 0.0:
755
+ raw_diff = (t1 - t2).abs() # [B, seq_len, d]
756
+ token_m_df = raw_diff.mean(dim=-1) # [B, seq_len]
757
+ token_m_t1 = t1.abs().mean(dim=-1) # [B, seq_len]
758
+ # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
759
+ token_diff = token_m_df / token_m_t1 # [B, seq_len]
760
+ condition = token_diff > condition_thresh # [B, seq_len]
761
+ if condition.sum() > 0:
762
+ condition = condition.unsqueeze(-1) # [B, seq_len, 1]
763
+ condition = condition.expand_as(raw_diff) # [B, seq_len, d]
764
+ mean_diff = raw_diff[condition].mean()
765
+ mean_t1 = t1[condition].abs().mean()
766
+ else:
767
+ mean_diff = (t1 - t2).abs().mean()
768
+ mean_t1 = t1.abs().mean()
578
769
  else:
770
+ # Use the mean of the absolute difference of the tensors
579
771
  mean_diff = (t1 - t2).abs().mean()
580
772
  mean_t1 = t1.abs().mean()
581
- else:
582
- # Use the mean of the absolute difference of the tensors
583
- mean_diff = (t1 - t2).abs().mean()
584
- mean_t1 = t1.abs().mean()
585
773
 
586
- if parallelized:
587
- mean_diff = DP.all_reduce_sync(mean_diff, "avg")
588
- mean_t1 = DP.all_reduce_sync(mean_t1, "avg")
774
+ if parallelized:
775
+ mean_diff = DP.all_reduce_sync(mean_diff, "avg")
776
+ mean_t1 = DP.all_reduce_sync(mean_t1, "avg")
589
777
 
590
- # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
591
- # Futher, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,
592
- # H(t-1,n) ~ H(t ,n), which means the hidden states are similar.
593
- diff = (mean_diff / mean_t1).item()
778
+ # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
779
+ # Futher, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,
780
+ # H(t-1,n) ~ H(t ,n), which means the hidden states are similar.
781
+ diff = (mean_diff / mean_t1).item()
594
782
 
595
783
  if logger.isEnabledFor(logging.DEBUG):
596
784
  logger.debug(f"{prefix}, diff: {diff:.6f}, threshold: {threshold:.6f}")
@@ -600,6 +788,26 @@ def are_two_tensors_similar(
600
788
  return diff < threshold
601
789
 
602
790
 
791
+ @torch.compiler.disable
792
+ def _debugging_set_buffer(prefix):
793
+ if logger.isEnabledFor(logging.DEBUG):
794
+ logger.debug(
795
+ f"set {prefix}, "
796
+ f"transformer step: {get_current_transformer_step()}, "
797
+ f"executed step: {get_current_step()}"
798
+ )
799
+
800
+
801
+ @torch.compiler.disable
802
+ def _debugging_get_buffer(prefix):
803
+ if logger.isEnabledFor(logging.DEBUG):
804
+ logger.debug(
805
+ f"get {prefix}, "
806
+ f"transformer step: {get_current_transformer_step()}, "
807
+ f"executed step: {get_current_step()}"
808
+ )
809
+
810
+
603
811
  # Fn buffers
604
812
  @torch.compiler.disable
605
813
  def set_Fn_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
@@ -609,21 +817,39 @@ def set_Fn_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
609
817
  if downsample_factor > 1:
610
818
  buffer = buffer[..., ::downsample_factor]
611
819
  buffer = buffer.contiguous()
612
- set_buffer(f"{prefix}_buffer", buffer)
820
+ if is_separate_classifier_free_guidance_step():
821
+ _debugging_set_buffer(f"{prefix}_buffer_cfg")
822
+ set_buffer(f"{prefix}_buffer_cfg", buffer)
823
+ else:
824
+ _debugging_set_buffer(f"{prefix}_buffer")
825
+ set_buffer(f"{prefix}_buffer", buffer)
613
826
 
614
827
 
615
828
  @torch.compiler.disable
616
829
  def get_Fn_buffer(prefix: str = "Fn"):
830
+ if is_separate_classifier_free_guidance_step():
831
+ _debugging_get_buffer(f"{prefix}_buffer_cfg")
832
+ return get_buffer(f"{prefix}_buffer_cfg")
833
+ _debugging_get_buffer(f"{prefix}_buffer")
617
834
  return get_buffer(f"{prefix}_buffer")
618
835
 
619
836
 
620
837
  @torch.compiler.disable
621
838
  def set_Fn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
622
- set_buffer(f"{prefix}_encoder_buffer", buffer)
839
+ if is_separate_classifier_free_guidance_step():
840
+ _debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
841
+ set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
842
+ else:
843
+ _debugging_set_buffer(f"{prefix}_encoder_buffer")
844
+ set_buffer(f"{prefix}_encoder_buffer", buffer)
623
845
 
624
846
 
625
847
  @torch.compiler.disable
626
848
  def get_Fn_encoder_buffer(prefix: str = "Fn"):
849
+ if is_separate_classifier_free_guidance_step():
850
+ _debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
851
+ return get_buffer(f"{prefix}_encoder_buffer_cfg")
852
+ _debugging_get_buffer(f"{prefix}_encoder_buffer")
627
853
  return get_buffer(f"{prefix}_encoder_buffer")
628
854
 
629
855
 
@@ -634,7 +860,11 @@ def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
634
860
  # This buffer is use for hidden states approximation.
635
861
  if is_taylorseer_enabled():
636
862
  # taylorseer, encoder_taylorseer
637
- taylorseer, _ = get_taylorseers()
863
+ if is_separate_classifier_free_guidance_step():
864
+ taylorseer, _ = get_cfg_taylorseers()
865
+ else:
866
+ taylorseer, _ = get_taylorseers()
867
+
638
868
  if taylorseer is not None:
639
869
  # Use TaylorSeer to update the buffer
640
870
  taylorseer.update(buffer)
@@ -644,15 +874,30 @@ def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
644
874
  "TaylorSeer is enabled but not set in the cache context. "
645
875
  "Falling back to default buffer retrieval."
646
876
  )
647
- set_buffer(f"{prefix}_buffer", buffer)
877
+ if is_separate_classifier_free_guidance_step():
878
+ _debugging_set_buffer(f"{prefix}_buffer_cfg")
879
+ set_buffer(f"{prefix}_buffer_cfg", buffer)
880
+ else:
881
+ _debugging_set_buffer(f"{prefix}_buffer")
882
+ set_buffer(f"{prefix}_buffer", buffer)
648
883
  else:
649
- set_buffer(f"{prefix}_buffer", buffer)
884
+ if is_separate_classifier_free_guidance_step():
885
+ _debugging_set_buffer(f"{prefix}_buffer_cfg")
886
+ set_buffer(f"{prefix}_buffer_cfg", buffer)
887
+ else:
888
+ _debugging_set_buffer(f"{prefix}_buffer")
889
+ set_buffer(f"{prefix}_buffer", buffer)
650
890
 
651
891
 
652
892
  @torch.compiler.disable
653
893
  def get_Bn_buffer(prefix: str = "Bn"):
654
894
  if is_taylorseer_enabled():
655
- taylorseer, _ = get_taylorseers()
895
+ # taylorseer, encoder_taylorseer
896
+ if is_separate_classifier_free_guidance_step():
897
+ taylorseer, _ = get_cfg_taylorseers()
898
+ else:
899
+ taylorseer, _ = get_taylorseers()
900
+
656
901
  if taylorseer is not None:
657
902
  return taylorseer.approximate_value()
658
903
  else:
@@ -662,8 +907,16 @@ def get_Bn_buffer(prefix: str = "Bn"):
662
907
  "Falling back to default buffer retrieval."
663
908
  )
664
909
  # Fallback to default buffer retrieval
910
+ if is_separate_classifier_free_guidance_step():
911
+ _debugging_get_buffer(f"{prefix}_buffer_cfg")
912
+ return get_buffer(f"{prefix}_buffer_cfg")
913
+ _debugging_get_buffer(f"{prefix}_buffer")
665
914
  return get_buffer(f"{prefix}_buffer")
666
915
  else:
916
+ if is_separate_classifier_free_guidance_step():
917
+ _debugging_get_buffer(f"{prefix}_buffer_cfg")
918
+ return get_buffer(f"{prefix}_buffer_cfg")
919
+ _debugging_get_buffer(f"{prefix}_buffer")
667
920
  return get_buffer(f"{prefix}_buffer")
668
921
 
669
922
 
@@ -672,7 +925,11 @@ def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
672
925
  # This buffer is use for encoder hidden states approximation.
673
926
  if is_encoder_taylorseer_enabled():
674
927
  # taylorseer, encoder_taylorseer
675
- _, encoder_taylorseer = get_taylorseers()
928
+ if is_separate_classifier_free_guidance_step():
929
+ _, encoder_taylorseer = get_cfg_taylorseers()
930
+ else:
931
+ _, encoder_taylorseer = get_taylorseers()
932
+
676
933
  if encoder_taylorseer is not None:
677
934
  # Use TaylorSeer to update the buffer
678
935
  encoder_taylorseer.update(buffer)
@@ -682,15 +939,29 @@ def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
682
939
  "TaylorSeer is enabled but not set in the cache context. "
683
940
  "Falling back to default buffer retrieval."
684
941
  )
685
- set_buffer(f"{prefix}_encoder_buffer", buffer)
942
+ if is_separate_classifier_free_guidance_step():
943
+ _debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
944
+ set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
945
+ else:
946
+ _debugging_set_buffer(f"{prefix}_encoder_buffer")
947
+ set_buffer(f"{prefix}_encoder_buffer", buffer)
686
948
  else:
687
- set_buffer(f"{prefix}_encoder_buffer", buffer)
949
+ if is_separate_classifier_free_guidance_step():
950
+ _debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
951
+ set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
952
+ else:
953
+ _debugging_set_buffer(f"{prefix}_encoder_buffer")
954
+ set_buffer(f"{prefix}_encoder_buffer", buffer)
688
955
 
689
956
 
690
957
  @torch.compiler.disable
691
958
  def get_Bn_encoder_buffer(prefix: str = "Bn"):
692
959
  if is_encoder_taylorseer_enabled():
693
- _, encoder_taylorseer = get_taylorseers()
960
+ if is_separate_classifier_free_guidance_step():
961
+ _, encoder_taylorseer = get_cfg_taylorseers()
962
+ else:
963
+ _, encoder_taylorseer = get_taylorseers()
964
+
694
965
  if encoder_taylorseer is not None:
695
966
  # Use TaylorSeer to approximate the value
696
967
  return encoder_taylorseer.approximate_value()
@@ -701,8 +972,16 @@ def get_Bn_encoder_buffer(prefix: str = "Bn"):
701
972
  "Falling back to default buffer retrieval."
702
973
  )
703
974
  # Fallback to default buffer retrieval
975
+ if is_separate_classifier_free_guidance_step():
976
+ _debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
977
+ return get_buffer(f"{prefix}_encoder_buffer_cfg")
978
+ _debugging_get_buffer(f"{prefix}_encoder_buffer")
704
979
  return get_buffer(f"{prefix}_encoder_buffer")
705
980
  else:
981
+ if is_separate_classifier_free_guidance_step():
982
+ _debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
983
+ return get_buffer(f"{prefix}_encoder_buffer_cfg")
984
+ _debugging_get_buffer(f"{prefix}_encoder_buffer")
706
985
  return get_buffer(f"{prefix}_encoder_buffer")
707
986
 
708
987
 
@@ -766,8 +1045,13 @@ def get_can_use_cache(
766
1045
  ):
767
1046
  if is_in_warmup():
768
1047
  return False
769
- cached_steps = get_cached_steps()
1048
+
770
1049
  max_cached_steps = get_max_cached_steps()
1050
+ if not is_separate_classifier_free_guidance_step():
1051
+ cached_steps = get_cached_steps()
1052
+ else:
1053
+ cached_steps = get_cfg_cached_steps()
1054
+
771
1055
  if max_cached_steps >= 0 and (len(cached_steps) >= max_cached_steps):
772
1056
  if logger.isEnabledFor(logging.DEBUG):
773
1057
  logger.debug(
@@ -775,10 +1059,12 @@ def get_can_use_cache(
775
1059
  "cannot use cache."
776
1060
  )
777
1061
  return False
1062
+
778
1063
  if threshold is None or threshold <= 0.0:
779
1064
  threshold = get_residual_diff_threshold()
780
1065
  if threshold <= 0.0:
781
1066
  return False
1067
+
782
1068
  downsample_factor = get_downsample_factor()
783
1069
  if downsample_factor > 1 and "Bn" not in prefix:
784
1070
  states_tensor = states_tensor[..., ::downsample_factor]
@@ -982,7 +1268,9 @@ class DBCachedTransformerBlocks(torch.nn.Module):
982
1268
  # Check if the current step is in cache steps.
983
1269
  # If so, we can skip some Bn blocks and directly
984
1270
  # use the cached values.
985
- return get_current_step() in get_cached_steps()
1271
+ return (get_current_step() in get_cached_steps()) or (
1272
+ get_current_step() in get_cfg_cached_steps()
1273
+ )
986
1274
 
987
1275
  @torch.compiler.disable
988
1276
  def _Fn_transformer_blocks(self):
@@ -1601,3 +1889,5 @@ def patch_cached_stats(
1601
1889
  # TODO: Patch more cached stats to the transformer
1602
1890
  transformer._cached_steps = get_cached_steps()
1603
1891
  transformer._residual_diffs = get_residual_diffs()
1892
+ transformer._cfg_cached_steps = get_cfg_cached_steps()
1893
+ transformer._cfg_residual_diffs = get_cfg_residual_diffs()
@@ -70,7 +70,7 @@ def apply_db_cache_on_pipe(
70
70
  # "slg_layers": slg_layers,
71
71
  # "slg_start": slg_start,
72
72
  # "slg_end": slg_end,
73
- "num_inference_steps": kwargs.get("num_inference_steps", 50),
73
+ # "num_inference_steps": kwargs.get("num_inference_steps", 50),
74
74
  "warmup_steps": warmup_steps,
75
75
  "max_cached_steps": max_cached_steps,
76
76
  },
@@ -370,8 +370,8 @@ def apply_prev_hidden_states_residual(
370
370
  hidden_states = hidden_states_residual + hidden_states
371
371
 
372
372
  hidden_states = hidden_states.contiguous()
373
- # NOTE: We should also support taylorseer for
374
- # encoder_hidden_states approximation. Please
373
+ # NOTE: We should also support taylorseer for
374
+ # encoder_hidden_states approximation. Please
375
375
  # use DBCache instead.
376
376
  else:
377
377
  hidden_states_residual = get_hidden_states_residual()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cache_dit
3
- Version: 0.2.3
3
+ Version: 0.2.5
4
4
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
5
5
  Author: DefTruth, vipshop.com, etc.
6
6
  Maintainer: DefTruth, vipshop.com, etc
@@ -44,7 +44,7 @@ Dynamic: requires-python
44
44
  <img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
45
45
  <img src=https://static.pepy.tech/badge/cache-dit >
46
46
  <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
47
- <img src=https://img.shields.io/badge/Release-v0.2.2-brightgreen.svg >
47
+ <img src=https://img.shields.io/badge/Release-v0.2-brightgreen.svg >
48
48
  </div>
49
49
  <p align="center">
50
50
  DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT offers <br>a set of training-free cache accelerators for DiT: <b>🔥<a href="#dbcache">DBCache</a>, <a href="#dbprune">DBPrune</a>, <a href="#taylorseer">TaylorSeer</a>, <a href="#fbcache">FBCache</a></b>, etc🔥
@@ -154,6 +154,7 @@ The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi
154
154
  - [🔥Supported Models](#supported)
155
155
  - [⚡️Dual Block Cache](#dbcache)
156
156
  - [🔥Hybrid TaylorSeer](#taylorseer)
157
+ - [⚡️Hybrid Cache CFG](#cfg)
157
158
  - [🎉First Block Cache](#fbcache)
158
159
  - [⚡️Dynamic Block Prune](#dbprune)
159
160
  - [🎉Context Parallelism](#context-parallelism)
@@ -168,7 +169,7 @@ The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi
168
169
  You can install the stable release of `cache-dit` from PyPI:
169
170
 
170
171
  ```bash
171
- pip3 install cache-dit
172
+ pip3 install -U cache-dit
172
173
  ```
173
174
  Or you can install the latest develop version from GitHub:
174
175
 
@@ -180,11 +181,13 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
180
181
 
181
182
  <div id="supported"></div>
182
183
 
183
- - [🚀FLUX.1](https://github.com/vipshop/cache-dit/raw/main/examples)
184
- - [🚀Mochi](https://github.com/vipshop/cache-dit/raw/main/examples)
184
+ - [🚀FLUX.1-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
185
+ - [🚀FLUX.1-Fill-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
186
+ - [🚀mochi-1-preview](https://github.com/vipshop/cache-dit/raw/main/examples)
185
187
  - [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples)
186
188
  - [🚀CogVideoX1.5](https://github.com/vipshop/cache-dit/raw/main/examples)
187
- - [🚀Wan2.1](https://github.com/vipshop/cache-dit/raw/main/examples)
189
+ - [🚀Wan2.1-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
190
+ - [🚀Wan2.1-FLF2V](https://github.com/vipshop/cache-dit/raw/main/examples)
188
191
  - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
189
192
 
190
193
 
@@ -280,7 +283,7 @@ cache_options = {
280
283
  "taylorseer_kwargs": {
281
284
  "n_derivatives": 2, # default is 2.
282
285
  },
283
- "warmup_steps": 3, # n_derivatives + 1
286
+ "warmup_steps": 3, # prefer: >= n_derivatives + 1
284
287
  "residual_diff_threshold": 0.12,
285
288
  }
286
289
  ```
@@ -299,6 +302,30 @@ cache_options = {
299
302
  |24.85s|12.85s|12.86s|10.27s|10.28s|8.48s|
300
303
  |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.12_S14_T12.85s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.12_S14_T12.86s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.15_S17_T10.27s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T10.28s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T8.48s.png width=105px>|
301
304
 
305
+ ## ⚡️Hybrid Cache CFG
306
+
307
+ <div id="cfg"></div>
308
+
309
+ CacheDiT supports caching for **CFG (classifier-free guidance)**. For models that fuse CFG and non-CFG into a single forward step, or models that do not include CFG (classifier-free guidance) in the forward step, please set the `do_separate_classifier_free_guidance` param to **False (default)**. Otherwise, set it to True. For example:
310
+
311
+ ```python
312
+ cache_options = {
313
+ # CFG: classifier free guidance or not
314
+ # For models that fuse CFG and non-CFG into a single forward step,
315
+ # should set do_separate_classifier_free_guidance as False.
316
+ # For example, set it as True for Wan 2.1 and set it as False
317
+ # for FLUX.1, HunyuanVideo, CogVideoX, Mochi.
318
+ "do_separate_classifier_free_guidance": True, # Wan 2.1
319
+ # Compute cfg forward first or not, default False, namely,
320
+ # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
321
+ "cfg_compute_first": False,
322
+ # Compute separate diff values for CFG and non-CFG steps,
323
+ # default True. If False, we will use the computed diff from
324
+ # current non-CFG transformer step for current CFG step.
325
+ "cfg_diff_compute_separate": True,
326
+ }
327
+ ```
328
+
302
329
  ## 🎉FBCache: First Block Cache
303
330
 
304
331
  <div id="fbcache"></div>
@@ -1,18 +1,18 @@
1
1
  cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- cache_dit/_version.py,sha256=wD8hnA5gV5UmPkQnpT3xR6V2csgj9K5NEADogbLK79M,511
2
+ cache_dit/_version.py,sha256=N3oBwJUFmS-AwCjqOcSlRW4GvSq-uJJMaBvoGfv1-hM,511
3
3
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
4
4
  cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
5
5
  cache_dit/cache_factory/__init__.py,sha256=5RNuhWakvvqrOV4vkqrEBA7d-V1LwcNSsjtW14mkqK8,5255
6
6
  cache_dit/cache_factory/taylorseer.py,sha256=LKSNo2ode69EVo9xrxjxAMEjz0yDGiGADeDYnEqddA8,3987
7
7
  cache_dit/cache_factory/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
- cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=EV0REfwLuXTqnkZ9vD2nFFjRceRLrXhl1b1SO5N4os8,59272
9
+ cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=wE_xYp7DRbgB-fD8dpr75o4Cvvl2s-jnT2fRyqWm_RM,71286
10
10
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=krNAICf-aS3JLmSG8vOB9tpLa04uYRcABsC8PMbVUKY,1870
11
11
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=fibkeU-FHa30BNT-uPV2Eqcd5IRli07EKb25tMDp23c,2270
12
12
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=fddSpTHXU24COMGAY-Z21EmHHAEArZBv_-XLRFD6ADU,2625
13
13
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha256=wcZdBhjUB8WSfz40A268BtSe3nr_hRsIi2BNlg1FHRU,9965
14
14
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=Cmy0KHRDgwXqtmqfkrr7kw0CP6CmkSnuz29gDHcD6sQ,2262
15
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=kNoCZshNUxgwS9Di84Mz8Js1QAcs_U665x6wcSKYE2A,2594
15
+ cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=EREHM5E1wxnL-uRXRAEege4HXraRp1oD_r1Zx4CsiKk,2596
16
16
  cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
17
  cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=so1wGdb8W0ATwrjv7E5IEZLPcobybaY1HJa6hBYlOOQ,34698
18
18
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=hVBTXj9MMGFGVezT3j8MntFRBiphSaUL4YhSOd8JtuY,1870
@@ -22,7 +22,7 @@ cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py,
22
22
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py,sha256=zXgoRDDjus3a2WSjtNh4ERtQp20ceb6nzohHMDlo2zY,2265
23
23
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py,sha256=PA7nuLgfAelnaI8usQx0Kxi8XATzMapyR1WndEdFoZA,2604
24
24
  cache_dit/cache_factory/first_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
- cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=NeAfDJlJVVUAL4btax5_iOLTuue1x4qeXwk0pM-QH28,23219
25
+ cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=tTPwhPLEA7LqGupps1Zy2MycCtLzs22wsW0yUhiiF-U,23217
26
26
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py,sha256=-FFgA2MoudEo7uDacg4aWgm1KwfLZFsEDTVxatgbq9M,2146
27
27
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py,sha256=qO5CWyurtwW30mvOe6cxeQPTSXLDlPJcezm72zEjDq8,2375
28
28
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py,sha256=Dcd4OzABCtyQCZNX2KNnUTdVoO1E1ApM7P8gcVYzcK0,2733
@@ -33,8 +33,8 @@ cache_dit/compile/__init__.py,sha256=DfMdPleFFGADXLsr7zXui8BTz_y9futY6rNmNdh9y7k
33
33
  cache_dit/compile/utils.py,sha256=KU60xc474Anbj7Y_FLRFmNxEjVYLLXkhbtCLXO7o_Tc,3699
34
34
  cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
35
35
  cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
36
- cache_dit-0.2.3.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
37
- cache_dit-0.2.3.dist-info/METADATA,sha256=XH7Wn7GdTRth6g0yAY5heArGqXmrrC7CflOJ8ZXH-1k,24867
38
- cache_dit-0.2.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
39
- cache_dit-0.2.3.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
40
- cache_dit-0.2.3.dist-info/RECORD,,
36
+ cache_dit-0.2.5.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
37
+ cache_dit-0.2.5.dist-info/METADATA,sha256=J37Waq-cMbuFfTrngXuxqouXpjHK9qhR_MZHlE2odmY,26249
38
+ cache_dit-0.2.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
39
+ cache_dit-0.2.5.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
40
+ cache_dit-0.2.5.dist-info/RECORD,,