cache-dit 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.


cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE
 
- __version__ = version = '0.2.3'
- __version_tuple__ = version_tuple = (0, 2, 3)
+ __version__ = version = '0.2.4'
+ __version_tuple__ = version_tuple = (0, 2, 4)
cache_dit/cache_factory/dual_block_cache/cache_context.py CHANGED
@@ -49,18 +49,18 @@ class DBCacheContext:
 
  # Other settings
  downsample_factor: int = 1
- num_inference_steps: int = -1
+ num_inference_steps: int = -1  # unused now
  warmup_steps: int = 0  # DON'T Cache in warmup steps
  # DON'T Cache if the number of cached steps >= max_cached_steps
- max_cached_steps: int = -1
+ max_cached_steps: int = -1  # for both CFG and non-CFG
 
  # Statistics for both alter cache and non-alter cache
  # Record the steps that have been cached, both alter cache and non-alter cache
- executed_steps: int = 0  # cache + non-cache steps
- cached_steps: List[int] = dataclasses.field(default_factory=list)
- residual_diffs: DefaultDict[str, float] = dataclasses.field(
-     default_factory=lambda: defaultdict(float),
- )
+ executed_steps: int = 0  # cache + non-cache pipeline steps
+ # Steps for the transformer; with CFG, transformer_executed_steps
+ # will be double executed_steps.
+ transformer_executed_steps: int = 0
+
  # Support TaylorSeers in Dual Block Cache
  # Title: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
  # Url: https://arxiv.org/pdf/2503.06923
@@ -71,8 +71,24 @@ class DBCacheContext:
  taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
  taylorseer: Optional[TaylorSeer] = None
  encoder_tarlorseer: Optional[TaylorSeer] = None
- alter_taylorseer: Optional[TaylorSeer] = None
- alter_encoder_taylorseer: Optional[TaylorSeer] = None
+ # Support do_separate_classifier_free_guidance, e.g. for Wan 2.1.
+ # For models that fuse CFG and non-CFG into a single forward step,
+ # set do_separate_classifier_free_guidance to False. For example:
+ # CogVideoX.
+ do_separate_classifier_free_guidance: bool = False
+ cfg_compute_first: bool = False
+ cfg_taylorseer: Optional[TaylorSeer] = None
+ cfg_encoder_taylorseer: Optional[TaylorSeer] = None
+
+ # CFG & non-CFG cached steps
+ cached_steps: List[int] = dataclasses.field(default_factory=list)
+ residual_diffs: DefaultDict[str, float] = dataclasses.field(
+     default_factory=lambda: defaultdict(float),
+ )
+ cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
+ cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
+     default_factory=lambda: defaultdict(float),
+ )
 
  # TODO: Support SLG in Dual Block Cache
  # Skip Layer Guidance, SLG
@@ -83,6 +99,11 @@ class DBCacheContext:
 
  @torch.compiler.disable
  def __post_init__(self):
+     if self.do_separate_classifier_free_guidance:
+         assert self.enable_alter_cache is False, (
+             "enable_alter_cache must be set to False if "
+             "do_separate_classifier_free_guidance is enabled."
+         )
 
      if "warmup_steps" not in self.taylorseer_kwargs:
          # If warmup_steps is not set in taylorseer_kwargs,
@@ -99,13 +120,13 @@ class DBCacheContext:
 
      if self.enable_taylorseer:
          self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
-         if self.enable_alter_cache:
-             self.alter_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
+         if self.do_separate_classifier_free_guidance:
+             self.cfg_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
 
      if self.enable_encoder_taylorseer:
          self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
-         if self.enable_alter_cache:
-             self.alter_encoder_taylorseer = TaylorSeer(
+         if self.do_separate_classifier_free_guidance:
+             self.cfg_encoder_taylorseer = TaylorSeer(
                  **self.taylorseer_kwargs
              )
 
@@ -159,18 +180,27 @@ class DBCacheContext:
 
  @torch.compiler.disable
  def mark_step_begin(self):
-     if not self.enable_alter_cache:
-         self.executed_steps += 1
+     # Always increase transformer executed steps
+     # incr step: prev 0 -> 1; prev 1 -> 2
+     # current step: incr step - 1
+     self.transformer_executed_steps += 1
+     if not self.do_separate_classifier_free_guidance:
+         self.executed_steps = self.transformer_executed_steps
      else:
-         self.executed_steps += 1
+         # 0,1 -> 0, 2,3 -> 1, ...
+         self.executed_steps = self.transformer_executed_steps // 2
+
+     if not self.enable_alter_cache:
          # 0 F 1 T 2 F 3 T 4 F 5 T ...
          self.is_alter_cache = not self.is_alter_cache
 
      # Reset the cached steps and residual diffs at the beginning
      # of each inference.
-     if self.get_current_step() == 0:
+     if self.get_current_transformer_step() == 0:
          self.cached_steps.clear()
          self.residual_diffs.clear()
+         self.cfg_cached_steps.clear()
+         self.cfg_residual_diffs.clear()
          self.reset_incremental_names()
          # Reset the TaylorSeers cache at the beginning of each inference.
          # reset_cache will set the current step to -1 for TaylorSeer,
@@ -180,44 +210,99 @@ class DBCacheContext:
              taylorseer.reset_cache()
          if encoder_taylorseer is not None:
              encoder_taylorseer.reset_cache()
+         cfg_taylorseer, cfg_encoder_taylorseer = (
+             self.get_cfg_taylorseers()
+         )
+         if cfg_taylorseer is not None:
+             cfg_taylorseer.reset_cache()
+         if cfg_encoder_taylorseer is not None:
+             cfg_encoder_taylorseer.reset_cache()
 
      # mark_step_begin of TaylorSeer must be called after the cache is reset.
      if self.enable_taylorseer or self.enable_encoder_taylorseer:
-         taylorseer, encoder_taylorseer = self.get_taylorseers()
-         if taylorseer is not None:
-             taylorseer.mark_step_begin()
-         if encoder_taylorseer is not None:
-             encoder_taylorseer.mark_step_begin()
+         if self.do_separate_classifier_free_guidance:
+             # Assume non-CFG steps: 0, 2, 4, 6, ...
+             if not self.is_separate_classifier_free_guidance_step():
+                 taylorseer, encoder_taylorseer = self.get_taylorseers()
+                 if taylorseer is not None:
+                     taylorseer.mark_step_begin()
+                 if encoder_taylorseer is not None:
+                     encoder_taylorseer.mark_step_begin()
+             else:
+                 cfg_taylorseer, cfg_encoder_taylorseer = (
+                     self.get_cfg_taylorseers()
+                 )
+                 if cfg_taylorseer is not None:
+                     cfg_taylorseer.mark_step_begin()
+                 if cfg_encoder_taylorseer is not None:
+                     cfg_encoder_taylorseer.mark_step_begin()
+         else:
+             taylorseer, encoder_taylorseer = self.get_taylorseers()
+             if taylorseer is not None:
+                 taylorseer.mark_step_begin()
+             if encoder_taylorseer is not None:
+                 encoder_taylorseer.mark_step_begin()
 
  @torch.compiler.disable
  def get_taylorseers(self):
-     if self.enable_alter_cache and self.is_alter_cache:
-         return self.alter_taylorseer, self.alter_encoder_taylorseer
      return self.taylorseer, self.encoder_tarlorseer
 
+ @torch.compiler.disable
+ def get_cfg_taylorseers(self):
+     return self.cfg_taylorseer, self.cfg_encoder_taylorseer
+
  @torch.compiler.disable
  def add_residual_diff(self, diff):
      step = str(self.get_current_step())
-     if step not in self.residual_diffs:
-         # Only add the diff if it is not already recorded for this step
-         self.residual_diffs[step] = diff
+     # Only add the diff if it is not already recorded for this step
+     if not self.is_separate_classifier_free_guidance_step():
+         if step not in self.residual_diffs:
+             self.residual_diffs[step] = diff
+     else:
+         if step not in self.cfg_residual_diffs:
+             self.cfg_residual_diffs[step] = diff
 
  @torch.compiler.disable
  def get_residual_diffs(self):
      return self.residual_diffs.copy()
 
+ @torch.compiler.disable
+ def get_cfg_residual_diffs(self):
+     return self.cfg_residual_diffs.copy()
+
  @torch.compiler.disable
  def add_cached_step(self):
-     self.cached_steps.append(self.get_current_step())
+     if not self.is_separate_classifier_free_guidance_step():
+         self.cached_steps.append(self.get_current_step())
+     else:
+         self.cfg_cached_steps.append(self.get_current_step())
 
  @torch.compiler.disable
  def get_cached_steps(self):
      return self.cached_steps.copy()
 
+ @torch.compiler.disable
+ def get_cfg_cached_steps(self):
+     return self.cfg_cached_steps.copy()
+
  @torch.compiler.disable
  def get_current_step(self):
      return self.executed_steps - 1
 
+ @torch.compiler.disable
+ def get_current_transformer_step(self):
+     return self.transformer_executed_steps - 1
+
+ @torch.compiler.disable
+ def is_separate_classifier_free_guidance_step(self):
+     if not self.do_separate_classifier_free_guidance:
+         return False
+     if self.cfg_compute_first:
+         # CFG steps: 0, 2, 4, 6, ...
+         return self.get_current_transformer_step() % 2 == 0
+     # CFG steps: 1, 3, 5, 7, ...
+     return self.get_current_transformer_step() % 2 != 0
+
  @torch.compiler.disable
  def is_in_warmup(self):
      return self.get_current_step() < self.warmup_steps
@@ -265,6 +350,13 @@ def get_current_step():
      return cache_context.get_current_step()
 
 
+ @torch.compiler.disable
+ def get_current_transformer_step():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.get_current_transformer_step()
+
+
 
  @torch.compiler.disable
  def get_cached_steps():
@@ -272,6 +364,13 @@ def get_cached_steps():
      return cache_context.get_cached_steps()
 
 
+ @torch.compiler.disable
+ def get_cfg_cached_steps():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.get_cfg_cached_steps()
+
+
  @torch.compiler.disable
  def get_max_cached_steps():
      cache_context = get_current_cache_context()
@@ -300,6 +399,13 @@ def get_residual_diffs():
      return cache_context.get_residual_diffs()
 
 
+ @torch.compiler.disable
+ def get_cfg_residual_diffs():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.get_cfg_residual_diffs()
+
+
  @torch.compiler.disable
  def is_taylorseer_enabled():
      cache_context = get_current_cache_context()
@@ -321,6 +427,13 @@ def get_taylorseers():
      return cache_context.get_taylorseers()
 
 
+ @torch.compiler.disable
+ def get_cfg_taylorseers():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.get_cfg_taylorseers()
+
+
  @torch.compiler.disable
  def is_taylorseer_cache_residual():
      cache_context = get_current_cache_context()
@@ -459,6 +572,20 @@ def Bn_compute_blocks_ids():
      return cache_context.Bn_compute_blocks_ids
 
 
+ @torch.compiler.disable
+ def do_separate_classifier_free_guidance():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.do_separate_classifier_free_guidance
+
+
+ @torch.compiler.disable
+ def is_separate_classifier_free_guidance_step():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.is_separate_classifier_free_guidance_step()
+
+
  _current_cache_context: DBCacheContext = None
 
 
@@ -609,21 +736,31 @@ def set_Fn_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
      if downsample_factor > 1:
          buffer = buffer[..., ::downsample_factor]
          buffer = buffer.contiguous()
-     set_buffer(f"{prefix}_buffer", buffer)
+     if is_separate_classifier_free_guidance_step():
+         set_buffer(f"{prefix}_buffer_cfg", buffer)
+     else:
+         set_buffer(f"{prefix}_buffer", buffer)
 
 
  @torch.compiler.disable
  def get_Fn_buffer(prefix: str = "Fn"):
+     if is_separate_classifier_free_guidance_step():
+         return get_buffer(f"{prefix}_buffer_cfg")
      return get_buffer(f"{prefix}_buffer")
 
 
  @torch.compiler.disable
  def set_Fn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
-     set_buffer(f"{prefix}_encoder_buffer", buffer)
+     if is_separate_classifier_free_guidance_step():
+         set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
+     else:
+         set_buffer(f"{prefix}_encoder_buffer", buffer)
 
 
  @torch.compiler.disable
  def get_Fn_encoder_buffer(prefix: str = "Fn"):
+     if is_separate_classifier_free_guidance_step():
+         return get_buffer(f"{prefix}_encoder_buffer_cfg")
      return get_buffer(f"{prefix}_encoder_buffer")
 
 
@@ -634,7 +771,11 @@ def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
      # This buffer is used for hidden states approximation.
      if is_taylorseer_enabled():
          # taylorseer, encoder_taylorseer
-         taylorseer, _ = get_taylorseers()
+         if is_separate_classifier_free_guidance_step():
+             taylorseer, _ = get_cfg_taylorseers()
+         else:
+             taylorseer, _ = get_taylorseers()
+
          if taylorseer is not None:
              # Use TaylorSeer to update the buffer
              taylorseer.update(buffer)
@@ -644,15 +785,26 @@ def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
              "TaylorSeer is enabled but not set in the cache context. "
              "Falling back to default buffer retrieval."
          )
-         set_buffer(f"{prefix}_buffer", buffer)
+         if is_separate_classifier_free_guidance_step():
+             set_buffer(f"{prefix}_buffer_cfg", buffer)
+         else:
+             set_buffer(f"{prefix}_buffer", buffer)
      else:
-         set_buffer(f"{prefix}_buffer", buffer)
+         if is_separate_classifier_free_guidance_step():
+             set_buffer(f"{prefix}_buffer_cfg", buffer)
+         else:
+             set_buffer(f"{prefix}_buffer", buffer)
 
 
  @torch.compiler.disable
  def get_Bn_buffer(prefix: str = "Bn"):
      if is_taylorseer_enabled():
-         taylorseer, _ = get_taylorseers()
+         # taylorseer, encoder_taylorseer
+         if is_separate_classifier_free_guidance_step():
+             taylorseer, _ = get_cfg_taylorseers()
+         else:
+             taylorseer, _ = get_taylorseers()
+
          if taylorseer is not None:
              return taylorseer.approximate_value()
          else:
@@ -662,8 +814,12 @@ def get_Bn_buffer(prefix: str = "Bn"):
              "Falling back to default buffer retrieval."
          )
          # Fallback to default buffer retrieval
+         if is_separate_classifier_free_guidance_step():
+             return get_buffer(f"{prefix}_buffer_cfg")
          return get_buffer(f"{prefix}_buffer")
      else:
+         if is_separate_classifier_free_guidance_step():
+             return get_buffer(f"{prefix}_buffer_cfg")
          return get_buffer(f"{prefix}_buffer")
 
 
@@ -672,7 +828,11 @@ def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
      # This buffer is used for encoder hidden states approximation.
      if is_encoder_taylorseer_enabled():
          # taylorseer, encoder_taylorseer
-         _, encoder_taylorseer = get_taylorseers()
+         if is_separate_classifier_free_guidance_step():
+             _, encoder_taylorseer = get_cfg_taylorseers()
+         else:
+             _, encoder_taylorseer = get_taylorseers()
+
          if encoder_taylorseer is not None:
              # Use TaylorSeer to update the buffer
              encoder_taylorseer.update(buffer)
@@ -682,15 +842,25 @@ def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
              "TaylorSeer is enabled but not set in the cache context. "
              "Falling back to default buffer retrieval."
          )
-         set_buffer(f"{prefix}_encoder_buffer", buffer)
+         if is_separate_classifier_free_guidance_step():
+             set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
+         else:
+             set_buffer(f"{prefix}_encoder_buffer", buffer)
      else:
-         set_buffer(f"{prefix}_encoder_buffer", buffer)
+         if is_separate_classifier_free_guidance_step():
+             set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
+         else:
+             set_buffer(f"{prefix}_encoder_buffer", buffer)
 
 
  @torch.compiler.disable
  def get_Bn_encoder_buffer(prefix: str = "Bn"):
      if is_encoder_taylorseer_enabled():
-         _, encoder_taylorseer = get_taylorseers()
+         if is_separate_classifier_free_guidance_step():
+             _, encoder_taylorseer = get_cfg_taylorseers()
+         else:
+             _, encoder_taylorseer = get_taylorseers()
+
          if encoder_taylorseer is not None:
              # Use TaylorSeer to approximate the value
              return encoder_taylorseer.approximate_value()
@@ -701,8 +871,12 @@ def get_Bn_encoder_buffer(prefix: str = "Bn"):
              "Falling back to default buffer retrieval."
          )
          # Fallback to default buffer retrieval
+         if is_separate_classifier_free_guidance_step():
+             return get_buffer(f"{prefix}_encoder_buffer_cfg")
          return get_buffer(f"{prefix}_encoder_buffer")
      else:
+         if is_separate_classifier_free_guidance_step():
+             return get_buffer(f"{prefix}_encoder_buffer_cfg")
          return get_buffer(f"{prefix}_encoder_buffer")
 
 
@@ -766,8 +940,13 @@ def get_can_use_cache(
  ):
      if is_in_warmup():
          return False
-     cached_steps = get_cached_steps()
+
      max_cached_steps = get_max_cached_steps()
+     if not is_separate_classifier_free_guidance_step():
+         cached_steps = get_cached_steps()
+     else:
+         cached_steps = get_cfg_cached_steps()
+
      if max_cached_steps >= 0 and (len(cached_steps) >= max_cached_steps):
          if logger.isEnabledFor(logging.DEBUG):
              logger.debug(
@@ -775,10 +954,12 @@ def get_can_use_cache(
                  "cannot use cache."
              )
          return False
+
      if threshold is None or threshold <= 0.0:
          threshold = get_residual_diff_threshold()
      if threshold <= 0.0:
          return False
+
      downsample_factor = get_downsample_factor()
      if downsample_factor > 1 and "Bn" not in prefix:
          states_tensor = states_tensor[..., ::downsample_factor]
@@ -982,7 +1163,9 @@ class DBCachedTransformerBlocks(torch.nn.Module):
      # Check if the current step is in cache steps.
      # If so, we can skip some Bn blocks and directly
      # use the cached values.
-     return get_current_step() in get_cached_steps()
+     return (get_current_step() in get_cached_steps()) or (
+         get_current_step() in get_cfg_cached_steps()
+     )
 
  @torch.compiler.disable
  def _Fn_transformer_blocks(self):
@@ -1601,3 +1784,5 @@ def patch_cached_stats(
      # TODO: Patch more cached stats to the transformer
      transformer._cached_steps = get_cached_steps()
      transformer._residual_diffs = get_residual_diffs()
+     transformer._cfg_cached_steps = get_cfg_cached_steps()
+     transformer._cfg_residual_diffs = get_cfg_residual_diffs()
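
Taken together, the cache_context.py changes above replace the old alter-cache TaylorSeer pair with per-branch state for the CFG and non-CFG passes: separate TaylorSeers, cached-step lists, residual-diff maps, and `_cfg`-suffixed buffers, all selected by the parity of the transformer step. Below is a minimal standalone sketch of that bookkeeping, with simplified names; it is an illustration, not the package's API:

```python
# Standalone sketch of the CFG step bookkeeping added in 0.2.4.
# Names are simplified for illustration; this is not the cache-dit API.


class StepTracker:
    def __init__(self, separate_cfg: bool, cfg_compute_first: bool = False):
        self.separate_cfg = separate_cfg
        self.cfg_compute_first = cfg_compute_first
        self.transformer_executed_steps = 0  # one per transformer forward
        self.executed_steps = 0  # one per pipeline (denoising) step

    def mark_step_begin(self) -> None:
        self.transformer_executed_steps += 1
        if self.separate_cfg:
            # Two transformer forwards per pipeline step: 0,1 -> 0; 2,3 -> 1
            self.executed_steps = self.transformer_executed_steps // 2
        else:
            self.executed_steps = self.transformer_executed_steps

    def is_cfg_step(self) -> bool:
        # Transformer-step parity selects the CFG branch.
        if not self.separate_cfg:
            return False
        step = self.transformer_executed_steps - 1
        if self.cfg_compute_first:
            return step % 2 == 0  # CFG steps: 0, 2, 4, ...
        return step % 2 == 1  # CFG steps: 1, 3, 5, ...


tracker = StepTracker(separate_cfg=True)
for i in range(4):  # two pipeline steps, two transformer forwards each
    tracker.mark_step_begin()
    # In cache_context.py this flag routes to cfg_taylorseer,
    # cfg_cached_steps, cfg_residual_diffs and "_cfg"-suffixed buffers.
    print(i, "CFG" if tracker.is_cfg_step() else "non-CFG")
# -> 0 non-CFG, 1 CFG, 2 non-CFG, 3 CFG
```

Keeping the two branches apart is what lets `max_cached_steps` and the residual-diff threshold apply evenly to both passes, instead of the CFG pass overwriting the non-CFG cache.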
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py CHANGED
@@ -70,7 +70,7 @@ def apply_db_cache_on_pipe(
          # "slg_layers": slg_layers,
          # "slg_start": slg_start,
          # "slg_end": slg_end,
-         "num_inference_steps": kwargs.get("num_inference_steps", 50),
+         # "num_inference_steps": kwargs.get("num_inference_steps", 50),
          "warmup_steps": warmup_steps,
          "max_cached_steps": max_cached_steps,
      },
cache_dit-0.2.3.dist-info/METADATA → cache_dit-0.2.4.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: cache_dit
- Version: 0.2.3
+ Version: 0.2.4
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
  Author: DefTruth, vipshop.com, etc.
  Maintainer: DefTruth, vipshop.com, etc
@@ -154,6 +154,7 @@ The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi
  - [🔥Supported Models](#supported)
  - [⚡️Dual Block Cache](#dbcache)
  - [🔥Hybrid TaylorSeer](#taylorseer)
+ - [⚡️Hybrid Cache CFG](#cfg)
  - [🎉First Block Cache](#fbcache)
  - [⚡️Dynamic Block Prune](#dbprune)
  - [🎉Context Parallelism](#context-parallelism)
@@ -299,6 +300,19 @@ cache_options = {
  |24.85s|12.85s|12.86s|10.27s|10.28s|8.48s|
  |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.12_S14_T12.85s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.12_S14_T12.86s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.15_S17_T10.27s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T10.28s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T8.48s.png width=105px>|
 
+ ## ⚡️Hybrid Cache CFG
+
+ <div id="cfg"></div>
+
+ CacheDiT supports caching for CFG (classifier-free guidance). For models that fuse the CFG and non-CFG branches into a single forward step, or that do not apply CFG in the forward step at all, set the `do_separate_classifier_free_guidance` param to False; otherwise set it to True. Wan 2.1: True. FLUX.1, HunyuanVideo, CogVideoX, Mochi: False.
+
+ ```python
+ cache_options = {
+     "do_separate_classifier_free_guidance": True,  # Wan 2.1
+     "cfg_compute_first": False,
+ }
+ ```
+
  ## 🎉FBCache: First Block Cache
 
  <div id="fbcache"></div>
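
For orientation, here is a hedged end-to-end sketch of how the options above might be wired into a Wan 2.1 pipeline through the `apply_db_cache_on_pipe` adapter changed in this diff. Only the adapter module path is taken from the RECORD listing; the model id and whether the adapter forwards these exact kwargs are assumptions to verify against the installed package:

```python
# Hypothetical usage sketch -- kwarg forwarding and the model id are
# assumptions; only the adapter module path is taken from this diff.
import torch
from diffusers import WanPipeline

from cache_dit.cache_factory.dual_block_cache.diffusers_adapters.wan import (
    apply_db_cache_on_pipe,
)

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # assumed model id
    torch_dtype=torch.bfloat16,
).to("cuda")

# Wan 2.1 runs CFG as its own transformer forward per denoising step,
# so the CFG branch keeps separate cache state (see cache_context.py).
apply_db_cache_on_pipe(
    pipe,
    do_separate_classifier_free_guidance=True,
    cfg_compute_first=False,
)

video = pipe("a cat surfing a wave", num_inference_steps=50).frames[0]
```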
cache_dit-0.2.3.dist-info/RECORD → cache_dit-0.2.4.dist-info/RECORD RENAMED
@@ -1,18 +1,18 @@
  cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cache_dit/_version.py,sha256=wD8hnA5gV5UmPkQnpT3xR6V2csgj9K5NEADogbLK79M,511
+ cache_dit/_version.py,sha256=1LUN_sRKOiFInoB6AlW6TYoQMCh1Z4KutwcHNvHcfB0,511
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
  cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
  cache_dit/cache_factory/__init__.py,sha256=5RNuhWakvvqrOV4vkqrEBA7d-V1LwcNSsjtW14mkqK8,5255
  cache_dit/cache_factory/taylorseer.py,sha256=LKSNo2ode69EVo9xrxjxAMEjz0yDGiGADeDYnEqddA8,3987
  cache_dit/cache_factory/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=EV0REfwLuXTqnkZ9vD2nFFjRceRLrXhl1b1SO5N4os8,59272
+ cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=H7u5zIAdEjiYU0QvWYIMj3lKYI4D8cmDLy7eZ9tyoyU,66848
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=krNAICf-aS3JLmSG8vOB9tpLa04uYRcABsC8PMbVUKY,1870
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=fibkeU-FHa30BNT-uPV2Eqcd5IRli07EKb25tMDp23c,2270
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=fddSpTHXU24COMGAY-Z21EmHHAEArZBv_-XLRFD6ADU,2625
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha256=wcZdBhjUB8WSfz40A268BtSe3nr_hRsIi2BNlg1FHRU,9965
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=Cmy0KHRDgwXqtmqfkrr7kw0CP6CmkSnuz29gDHcD6sQ,2262
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=kNoCZshNUxgwS9Di84Mz8Js1QAcs_U665x6wcSKYE2A,2594
+ cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=EREHM5E1wxnL-uRXRAEege4HXraRp1oD_r1Zx4CsiKk,2596
  cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=so1wGdb8W0ATwrjv7E5IEZLPcobybaY1HJa6hBYlOOQ,34698
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=hVBTXj9MMGFGVezT3j8MntFRBiphSaUL4YhSOd8JtuY,1870
@@ -33,8 +33,8 @@ cache_dit/compile/__init__.py,sha256=DfMdPleFFGADXLsr7zXui8BTz_y9futY6rNmNdh9y7k
  cache_dit/compile/utils.py,sha256=KU60xc474Anbj7Y_FLRFmNxEjVYLLXkhbtCLXO7o_Tc,3699
  cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cache_dit-0.2.3.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
- cache_dit-0.2.3.dist-info/METADATA,sha256=XH7Wn7GdTRth6g0yAY5heArGqXmrrC7CflOJ8ZXH-1k,24867
- cache_dit-0.2.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- cache_dit-0.2.3.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
- cache_dit-0.2.3.dist-info/RECORD,,
+ cache_dit-0.2.4.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+ cache_dit-0.2.4.dist-info/METADATA,sha256=1oDgkkUwGVfwX_jCyU0jHbQTVQDfL59OEbrUb_9SVF4,25442
+ cache_dit-0.2.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ cache_dit-0.2.4.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+ cache_dit-0.2.4.dist-info/RECORD,,