cache-dit 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/dual_block_cache/cache_context.py +138 -33
- cache_dit/cache_factory/first_block_cache/cache_context.py +2 -2
- cache_dit/metrics/__init__.py +12 -0
- cache_dit/metrics/fid.py +409 -0
- cache_dit/metrics/inception.py +353 -0
- cache_dit/metrics/metrics.py +356 -0
- {cache_dit-0.2.4.dist-info → cache_dit-0.2.6.dist-info}/METADATA +59 -8
- {cache_dit-0.2.4.dist-info → cache_dit-0.2.6.dist-info}/RECORD +13 -8
- cache_dit-0.2.6.dist-info/entry_points.txt +2 -0
- {cache_dit-0.2.4.dist-info → cache_dit-0.2.6.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.4.dist-info → cache_dit-0.2.6.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.4.dist-info → cache_dit-0.2.6.dist-info}/top_level.txt +0 -0
cache_dit/_version.py
CHANGED
|
@@ -71,12 +71,19 @@ class DBCacheContext:
|
|
|
71
71
|
taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
|
72
72
|
taylorseer: Optional[TaylorSeer] = None
|
|
73
73
|
encoder_tarlorseer: Optional[TaylorSeer] = None
|
|
74
|
+
|
|
74
75
|
# Support do_separate_classifier_free_guidance, such as Wan 2.1
|
|
75
76
|
# For models that fuse CFG and non-CFG into a single forward step,
|
|
76
77
|
# should set do_separate_classifier_free_guidance as False. For
|
|
77
|
-
# example: CogVideoX
|
|
78
|
+
# example: CogVideoX, HunyuanVideo, Mochi.
|
|
78
79
|
do_separate_classifier_free_guidance: bool = False
|
|
80
|
+
# Compute cfg forward first or not, default False, namely,
|
|
81
|
+
# 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
|
|
79
82
|
cfg_compute_first: bool = False
|
|
83
|
+
# Compute separate diff values for CFG and non-CFG step,
|
|
84
|
+
# default True. If False, we will use the computed diff from
|
|
85
|
+
# current non-CFG transformer step for current CFG step.
|
|
86
|
+
cfg_diff_compute_separate: bool = True
|
|
80
87
|
cfg_taylorseer: Optional[TaylorSeer] = None
|
|
81
88
|
cfg_encoder_taylorseer: Optional[TaylorSeer] = None
|
|
82
89
|
|
|
@@ -99,11 +106,17 @@ class DBCacheContext:
|
|
|
99
106
|
|
|
100
107
|
@torch.compiler.disable
|
|
101
108
|
def __post_init__(self):
|
|
109
|
+
# Some checks for settings
|
|
102
110
|
if self.do_separate_classifier_free_guidance:
|
|
103
111
|
assert self.enable_alter_cache is False, (
|
|
104
112
|
"enable_alter_cache must set as False if "
|
|
105
113
|
"do_separate_classifier_free_guidance is enabled."
|
|
106
114
|
)
|
|
115
|
+
if self.cfg_diff_compute_separate:
|
|
116
|
+
assert self.cfg_compute_first is False, (
|
|
117
|
+
"cfg_compute_first must set as False if "
|
|
118
|
+
"cfg_diff_compute_separate is enabled."
|
|
119
|
+
)
|
|
107
120
|
|
|
108
121
|
if "warmup_steps" not in self.taylorseer_kwargs:
|
|
109
122
|
# If warmup_steps is not set in taylorseer_kwargs,
|
|
@@ -185,10 +198,17 @@ class DBCacheContext:
|
|
|
185
198
|
# current step: incr step - 1
|
|
186
199
|
self.transformer_executed_steps += 1
|
|
187
200
|
if not self.do_separate_classifier_free_guidance:
|
|
188
|
-
self.executed_steps
|
|
201
|
+
self.executed_steps += 1
|
|
189
202
|
else:
|
|
190
|
-
# 0,1 -> 0, 2,3 -> 1, ...
|
|
191
|
-
|
|
203
|
+
# 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
|
|
204
|
+
if not self.cfg_compute_first:
|
|
205
|
+
if not self.is_separate_classifier_free_guidance_step():
|
|
206
|
+
# transformer step: 0,2,4,...
|
|
207
|
+
self.executed_steps += 1
|
|
208
|
+
else:
|
|
209
|
+
if self.is_separate_classifier_free_guidance_step():
|
|
210
|
+
# transformer step: 0,2,4,...
|
|
211
|
+
self.executed_steps += 1
|
|
192
212
|
|
|
193
213
|
if not self.enable_alter_cache:
|
|
194
214
|
# 0 F 1 T 2 F 3 T 4 F 5 T ...
|
|
@@ -253,6 +273,7 @@ class DBCacheContext:
|
|
|
253
273
|
|
|
254
274
|
@torch.compiler.disable
|
|
255
275
|
def add_residual_diff(self, diff):
|
|
276
|
+
# step: executed_steps - 1, not transformer_steps - 1
|
|
256
277
|
step = str(self.get_current_step())
|
|
257
278
|
# Only add the diff if it is not already recorded for this step
|
|
258
279
|
if not self.is_separate_classifier_free_guidance_step():
|
|
@@ -299,9 +320,9 @@ class DBCacheContext:
|
|
|
299
320
|
return False
|
|
300
321
|
if self.cfg_compute_first:
|
|
301
322
|
# CFG steps: 0, 2, 4, 6, ...
|
|
302
|
-
return self.get_current_transformer_step() % 2
|
|
323
|
+
return self.get_current_transformer_step() % 2 == 0
|
|
303
324
|
# CFG steps: 1, 3, 5, 7, ...
|
|
304
|
-
return
|
|
325
|
+
return self.get_current_transformer_step() % 2 != 0
|
|
305
326
|
|
|
306
327
|
@torch.compiler.disable
|
|
307
328
|
def is_in_warmup(self):
|
|
@@ -350,6 +371,28 @@ def get_current_step():
|
|
|
350
371
|
return cache_context.get_current_step()
|
|
351
372
|
|
|
352
373
|
|
|
374
|
+
@torch.compiler.disable
|
|
375
|
+
def get_current_step_residual_diff():
|
|
376
|
+
cache_context = get_current_cache_context()
|
|
377
|
+
assert cache_context is not None, "cache_context must be set before"
|
|
378
|
+
step = str(get_current_step())
|
|
379
|
+
residual_diffs = get_residual_diffs()
|
|
380
|
+
if step in residual_diffs:
|
|
381
|
+
return residual_diffs[step]
|
|
382
|
+
return None
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
@torch.compiler.disable
|
|
386
|
+
def get_current_step_cfg_residual_diff():
|
|
387
|
+
cache_context = get_current_cache_context()
|
|
388
|
+
assert cache_context is not None, "cache_context must be set before"
|
|
389
|
+
step = str(get_current_step())
|
|
390
|
+
cfg_residual_diffs = get_cfg_residual_diffs()
|
|
391
|
+
if step in cfg_residual_diffs:
|
|
392
|
+
return cfg_residual_diffs[step]
|
|
393
|
+
return None
|
|
394
|
+
|
|
395
|
+
|
|
353
396
|
@torch.compiler.disable
|
|
354
397
|
def get_current_transformer_step():
|
|
355
398
|
cache_context = get_current_cache_context()
|
|
@@ -586,6 +629,13 @@ def is_separate_classifier_free_guidance_step():
|
|
|
586
629
|
return cache_context.is_separate_classifier_free_guidance_step()
|
|
587
630
|
|
|
588
631
|
|
|
632
|
+
@torch.compiler.disable
|
|
633
|
+
def cfg_diff_compute_separate():
|
|
634
|
+
cache_context = get_current_cache_context()
|
|
635
|
+
assert cache_context is not None, "cache_context must be set before"
|
|
636
|
+
return cache_context.cfg_diff_compute_separate
|
|
637
|
+
|
|
638
|
+
|
|
589
639
|
_current_cache_context: DBCacheContext = None
|
|
590
640
|
|
|
591
641
|
|
|
@@ -686,38 +736,49 @@ def are_two_tensors_similar(
|
|
|
686
736
|
add_residual_diff(-2.0)
|
|
687
737
|
return False
|
|
688
738
|
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
#
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
739
|
+
if all(
|
|
740
|
+
(
|
|
741
|
+
do_separate_classifier_free_guidance(),
|
|
742
|
+
is_separate_classifier_free_guidance_step(),
|
|
743
|
+
not cfg_diff_compute_separate(),
|
|
744
|
+
get_current_step_residual_diff() is not None,
|
|
745
|
+
)
|
|
746
|
+
):
|
|
747
|
+
# Reuse computed diff value from non-CFG step
|
|
748
|
+
diff = get_current_step_residual_diff()
|
|
749
|
+
else:
|
|
750
|
+
# Find the most significant token through t1 and t2, and
|
|
751
|
+
# consider the diff of the significant token. The more significant,
|
|
752
|
+
# the more important.
|
|
753
|
+
condition_thresh = get_important_condition_threshold()
|
|
754
|
+
if condition_thresh > 0.0:
|
|
755
|
+
raw_diff = (t1 - t2).abs() # [B, seq_len, d]
|
|
756
|
+
token_m_df = raw_diff.mean(dim=-1) # [B, seq_len]
|
|
757
|
+
token_m_t1 = t1.abs().mean(dim=-1) # [B, seq_len]
|
|
758
|
+
# D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
|
|
759
|
+
token_diff = token_m_df / token_m_t1 # [B, seq_len]
|
|
760
|
+
condition = token_diff > condition_thresh # [B, seq_len]
|
|
761
|
+
if condition.sum() > 0:
|
|
762
|
+
condition = condition.unsqueeze(-1) # [B, seq_len, 1]
|
|
763
|
+
condition = condition.expand_as(raw_diff) # [B, seq_len, d]
|
|
764
|
+
mean_diff = raw_diff[condition].mean()
|
|
765
|
+
mean_t1 = t1[condition].abs().mean()
|
|
766
|
+
else:
|
|
767
|
+
mean_diff = (t1 - t2).abs().mean()
|
|
768
|
+
mean_t1 = t1.abs().mean()
|
|
705
769
|
else:
|
|
770
|
+
# Use the mean of the absolute difference of the tensors
|
|
706
771
|
mean_diff = (t1 - t2).abs().mean()
|
|
707
772
|
mean_t1 = t1.abs().mean()
|
|
708
|
-
else:
|
|
709
|
-
# Use the mean of the absolute difference of the tensors
|
|
710
|
-
mean_diff = (t1 - t2).abs().mean()
|
|
711
|
-
mean_t1 = t1.abs().mean()
|
|
712
773
|
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
774
|
+
if parallelized:
|
|
775
|
+
mean_diff = DP.all_reduce_sync(mean_diff, "avg")
|
|
776
|
+
mean_t1 = DP.all_reduce_sync(mean_t1, "avg")
|
|
716
777
|
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
778
|
+
# D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
|
|
779
|
+
# Further, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,
|
|
780
|
+
# H(t-1,n) ~ H(t ,n), which means the hidden states are similar.
|
|
781
|
+
diff = (mean_diff / mean_t1).item()
|
|
721
782
|
|
|
722
783
|
if logger.isEnabledFor(logging.DEBUG):
|
|
723
784
|
logger.debug(f"{prefix}, diff: {diff:.6f}, threshold: {threshold:.6f}")
|
|
@@ -727,6 +788,26 @@ def are_two_tensors_similar(
|
|
|
727
788
|
return diff < threshold
|
|
728
789
|
|
|
729
790
|
|
|
791
|
+
@torch.compiler.disable
|
|
792
|
+
def _debugging_set_buffer(prefix):
|
|
793
|
+
if logger.isEnabledFor(logging.DEBUG):
|
|
794
|
+
logger.debug(
|
|
795
|
+
f"set {prefix}, "
|
|
796
|
+
f"transformer step: {get_current_transformer_step()}, "
|
|
797
|
+
f"executed step: {get_current_step()}"
|
|
798
|
+
)
|
|
799
|
+
|
|
800
|
+
|
|
801
|
+
@torch.compiler.disable
|
|
802
|
+
def _debugging_get_buffer(prefix):
|
|
803
|
+
if logger.isEnabledFor(logging.DEBUG):
|
|
804
|
+
logger.debug(
|
|
805
|
+
f"get {prefix}, "
|
|
806
|
+
f"transformer step: {get_current_transformer_step()}, "
|
|
807
|
+
f"executed step: {get_current_step()}"
|
|
808
|
+
)
|
|
809
|
+
|
|
810
|
+
|
|
730
811
|
# Fn buffers
|
|
731
812
|
@torch.compiler.disable
|
|
732
813
|
def set_Fn_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
|
|
@@ -737,30 +818,38 @@ def set_Fn_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
|
|
|
737
818
|
buffer = buffer[..., ::downsample_factor]
|
|
738
819
|
buffer = buffer.contiguous()
|
|
739
820
|
if is_separate_classifier_free_guidance_step():
|
|
821
|
+
_debugging_set_buffer(f"{prefix}_buffer_cfg")
|
|
740
822
|
set_buffer(f"{prefix}_buffer_cfg", buffer)
|
|
741
823
|
else:
|
|
824
|
+
_debugging_set_buffer(f"{prefix}_buffer")
|
|
742
825
|
set_buffer(f"{prefix}_buffer", buffer)
|
|
743
826
|
|
|
744
827
|
|
|
745
828
|
@torch.compiler.disable
|
|
746
829
|
def get_Fn_buffer(prefix: str = "Fn"):
|
|
747
830
|
if is_separate_classifier_free_guidance_step():
|
|
831
|
+
_debugging_get_buffer(f"{prefix}_buffer_cfg")
|
|
748
832
|
return get_buffer(f"{prefix}_buffer_cfg")
|
|
833
|
+
_debugging_get_buffer(f"{prefix}_buffer")
|
|
749
834
|
return get_buffer(f"{prefix}_buffer")
|
|
750
835
|
|
|
751
836
|
|
|
752
837
|
@torch.compiler.disable
|
|
753
838
|
def set_Fn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
|
|
754
839
|
if is_separate_classifier_free_guidance_step():
|
|
840
|
+
_debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
|
|
755
841
|
set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
|
|
756
842
|
else:
|
|
843
|
+
_debugging_set_buffer(f"{prefix}_encoder_buffer")
|
|
757
844
|
set_buffer(f"{prefix}_encoder_buffer", buffer)
|
|
758
845
|
|
|
759
846
|
|
|
760
847
|
@torch.compiler.disable
|
|
761
848
|
def get_Fn_encoder_buffer(prefix: str = "Fn"):
|
|
762
849
|
if is_separate_classifier_free_guidance_step():
|
|
850
|
+
_debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
|
|
763
851
|
return get_buffer(f"{prefix}_encoder_buffer_cfg")
|
|
852
|
+
_debugging_get_buffer(f"{prefix}_encoder_buffer")
|
|
764
853
|
return get_buffer(f"{prefix}_encoder_buffer")
|
|
765
854
|
|
|
766
855
|
|
|
@@ -786,13 +875,17 @@ def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
|
|
|
786
875
|
"Falling back to default buffer retrieval."
|
|
787
876
|
)
|
|
788
877
|
if is_separate_classifier_free_guidance_step():
|
|
878
|
+
_debugging_set_buffer(f"{prefix}_buffer_cfg")
|
|
789
879
|
set_buffer(f"{prefix}_buffer_cfg", buffer)
|
|
790
880
|
else:
|
|
881
|
+
_debugging_set_buffer(f"{prefix}_buffer")
|
|
791
882
|
set_buffer(f"{prefix}_buffer", buffer)
|
|
792
883
|
else:
|
|
793
884
|
if is_separate_classifier_free_guidance_step():
|
|
885
|
+
_debugging_set_buffer(f"{prefix}_buffer_cfg")
|
|
794
886
|
set_buffer(f"{prefix}_buffer_cfg", buffer)
|
|
795
887
|
else:
|
|
888
|
+
_debugging_set_buffer(f"{prefix}_buffer")
|
|
796
889
|
set_buffer(f"{prefix}_buffer", buffer)
|
|
797
890
|
|
|
798
891
|
|
|
@@ -815,11 +908,15 @@ def get_Bn_buffer(prefix: str = "Bn"):
|
|
|
815
908
|
)
|
|
816
909
|
# Fallback to default buffer retrieval
|
|
817
910
|
if is_separate_classifier_free_guidance_step():
|
|
911
|
+
_debugging_get_buffer(f"{prefix}_buffer_cfg")
|
|
818
912
|
return get_buffer(f"{prefix}_buffer_cfg")
|
|
913
|
+
_debugging_get_buffer(f"{prefix}_buffer")
|
|
819
914
|
return get_buffer(f"{prefix}_buffer")
|
|
820
915
|
else:
|
|
821
916
|
if is_separate_classifier_free_guidance_step():
|
|
917
|
+
_debugging_get_buffer(f"{prefix}_buffer_cfg")
|
|
822
918
|
return get_buffer(f"{prefix}_buffer_cfg")
|
|
919
|
+
_debugging_get_buffer(f"{prefix}_buffer")
|
|
823
920
|
return get_buffer(f"{prefix}_buffer")
|
|
824
921
|
|
|
825
922
|
|
|
@@ -843,13 +940,17 @@ def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
|
|
|
843
940
|
"Falling back to default buffer retrieval."
|
|
844
941
|
)
|
|
845
942
|
if is_separate_classifier_free_guidance_step():
|
|
943
|
+
_debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
|
|
846
944
|
set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
|
|
847
945
|
else:
|
|
946
|
+
_debugging_set_buffer(f"{prefix}_encoder_buffer")
|
|
848
947
|
set_buffer(f"{prefix}_encoder_buffer", buffer)
|
|
849
948
|
else:
|
|
850
949
|
if is_separate_classifier_free_guidance_step():
|
|
950
|
+
_debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
|
|
851
951
|
set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
|
|
852
952
|
else:
|
|
953
|
+
_debugging_set_buffer(f"{prefix}_encoder_buffer")
|
|
853
954
|
set_buffer(f"{prefix}_encoder_buffer", buffer)
|
|
854
955
|
|
|
855
956
|
|
|
@@ -872,11 +973,15 @@ def get_Bn_encoder_buffer(prefix: str = "Bn"):
|
|
|
872
973
|
)
|
|
873
974
|
# Fallback to default buffer retrieval
|
|
874
975
|
if is_separate_classifier_free_guidance_step():
|
|
976
|
+
_debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
|
|
875
977
|
return get_buffer(f"{prefix}_encoder_buffer_cfg")
|
|
978
|
+
_debugging_get_buffer(f"{prefix}_encoder_buffer")
|
|
876
979
|
return get_buffer(f"{prefix}_encoder_buffer")
|
|
877
980
|
else:
|
|
878
981
|
if is_separate_classifier_free_guidance_step():
|
|
982
|
+
_debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
|
|
879
983
|
return get_buffer(f"{prefix}_encoder_buffer_cfg")
|
|
984
|
+
_debugging_get_buffer(f"{prefix}_encoder_buffer")
|
|
880
985
|
return get_buffer(f"{prefix}_encoder_buffer")
|
|
881
986
|
|
|
882
987
|
|
|
@@ -370,8 +370,8 @@ def apply_prev_hidden_states_residual(
|
|
|
370
370
|
hidden_states = hidden_states_residual + hidden_states
|
|
371
371
|
|
|
372
372
|
hidden_states = hidden_states.contiguous()
|
|
373
|
-
# NOTE: We should also support taylorseer for
|
|
374
|
-
# encoder_hidden_states approximation. Please
|
|
373
|
+
# NOTE: We should also support taylorseer for
|
|
374
|
+
# encoder_hidden_states approximation. Please
|
|
375
375
|
# use DBCache instead.
|
|
376
376
|
else:
|
|
377
377
|
hidden_states_residual = get_hidden_states_residual()
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from cache_dit.metrics.metrics import compute_psnr
|
|
2
|
+
from cache_dit.metrics.metrics import compute_ssim
|
|
3
|
+
from cache_dit.metrics.metrics import compute_mse
|
|
4
|
+
from cache_dit.metrics.metrics import compute_video_psnr
|
|
5
|
+
from cache_dit.metrics.metrics import compute_video_ssim
|
|
6
|
+
from cache_dit.metrics.metrics import compute_video_mse
|
|
7
|
+
from cache_dit.metrics.metrics import entrypoint
|
|
8
|
+
from cache_dit.metrics.fid import FrechetInceptionDistance
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def main():
|
|
12
|
+
entrypoint()
|