cache-dit 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of cache-dit might be problematic.

@@ -69,11 +69,11 @@ class DBCacheContext:
     taylorseer: Optional[TaylorSeer] = None
     encoder_tarlorseer: Optional[TaylorSeer] = None

-    # Support do_separate_classifier_free_guidance, such as Wan 2.1,
+    # Support do_separate_cfg, such as Wan 2.1,
     # Qwen-Image. For model that fused CFG and non-CFG into single
-    # forward step, should set do_separate_classifier_free_guidance
-    # as False. For example: CogVideoX, HunyuanVideo, Mochi.
-    do_separate_classifier_free_guidance: bool = False
+    # forward step, should set do_separate_cfg as False.
+    # For example: CogVideoX, HunyuanVideo, Mochi.
+    do_separate_cfg: bool = False

     # Compute cfg forward first or not, default False, namely,
     # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
     cfg_compute_first: bool = False
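The rename above only shortens the name; the semantics are unchanged: pipelines such as Wan 2.1 and Qwen-Image run the CFG and non-CFG branches as two separate transformer forwards and need the flag enabled, while models that fuse both branches into one forward (CogVideoX, HunyuanVideo, Mochi) keep it off. A minimal, hypothetical sketch of that decision (the pipeline class names below are illustrative, not an API of cache-dit):

```py
# Hypothetical helper, not part of the diff: pick the renamed flag per model family.
SEPARATE_CFG = {"WanPipeline", "QwenImagePipeline"}  # CFG runs as its own forward
FUSED_CFG = {"CogVideoXPipeline", "HunyuanVideoPipeline", "MochiPipeline"}


def guess_do_separate_cfg(pipe_cls_name: str) -> bool:
    # Separate-CFG models need do_separate_cfg=True so cached residuals are
    # tracked per branch; fused-CFG models leave it False (the default).
    return pipe_cls_name in SEPARATE_CFG
```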
@@ -97,10 +97,10 @@ class DBCacheContext:
     @torch.compiler.disable
     def __post_init__(self):
         # Some checks for settings
-        if self.do_separate_classifier_free_guidance:
+        if self.do_separate_cfg:
             assert self.enable_alter_cache is False, (
                 "enable_alter_cache must set as False if "
-                "do_separate_classifier_free_guidance is enabled."
+                "do_separate_cfg is enabled."
             )
         if self.cfg_diff_compute_separate:
             assert self.cfg_compute_first is False, (
@@ -123,12 +123,12 @@ class DBCacheContext:

         if self.enable_taylorseer:
             self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
-            if self.do_separate_classifier_free_guidance:
+            if self.do_separate_cfg:
                 self.cfg_taylorseer = TaylorSeer(**self.taylorseer_kwargs)

         if self.enable_encoder_taylorseer:
             self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
-            if self.do_separate_classifier_free_guidance:
+            if self.do_separate_cfg:
                 self.cfg_encoder_taylorseer = TaylorSeer(
                     **self.taylorseer_kwargs
                 )
@@ -175,16 +175,16 @@ class DBCacheContext:
         # incr step: prev 0 -> 1; prev 1 -> 2
         # current step: incr step - 1
         self.transformer_executed_steps += 1
-        if not self.do_separate_classifier_free_guidance:
+        if not self.do_separate_cfg:
             self.executed_steps += 1
         else:
             # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
             if not self.cfg_compute_first:
-                if not self.is_separate_classifier_free_guidance_step():
+                if not self.is_separate_cfg_step():
                     # transformer step: 0,2,4,...
                     self.executed_steps += 1
             else:
-                if self.is_separate_classifier_free_guidance_step():
+                if self.is_separate_cfg_step():
                     # transformer step: 0,2,4,...
                     self.executed_steps += 1

@@ -217,9 +217,9 @@ class DBCacheContext:

         # mark_step_begin of TaylorSeer must be called after the cache is reset.
         if self.enable_taylorseer or self.enable_encoder_taylorseer:
-            if self.do_separate_classifier_free_guidance:
+            if self.do_separate_cfg:
                 # Assume non-CFG steps: 0, 2, 4, 6, ...
-                if not self.is_separate_classifier_free_guidance_step():
+                if not self.is_separate_cfg_step():
                     taylorseer, encoder_taylorseer = self.get_taylorseers()
                     if taylorseer is not None:
                         taylorseer.mark_step_begin()
@@ -251,7 +251,7 @@ class DBCacheContext:
         # step: executed_steps - 1, not transformer_steps - 1
         step = str(self.get_current_step())
         # Only add the diff if it is not already recorded for this step
-        if not self.is_separate_classifier_free_guidance_step():
+        if not self.is_separate_cfg_step():
             if step not in self.residual_diffs:
                 self.residual_diffs[step] = diff
         else:
@@ -268,7 +268,7 @@ class DBCacheContext:

     @torch.compiler.disable
     def add_cached_step(self):
-        if not self.is_separate_classifier_free_guidance_step():
+        if not self.is_separate_cfg_step():
             self.cached_steps.append(self.get_current_step())
         else:
             self.cfg_cached_steps.append(self.get_current_step())
@@ -290,8 +290,8 @@ class DBCacheContext:
         return self.transformer_executed_steps - 1

     @torch.compiler.disable
-    def is_separate_classifier_free_guidance_step(self):
-        if not self.do_separate_classifier_free_guidance:
+    def is_separate_cfg_step(self):
+        if not self.do_separate_cfg:
             return False
         if self.cfg_compute_first:
             # CFG steps: 0, 2, 4, 6, ...
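With separate CFG enabled, the transformer runs twice per denoising step, so the bookkeeping above only advances `executed_steps` on the non-CFG half (or on the CFG half when `cfg_compute_first` is set), and `is_separate_cfg_step()` reduces to a parity test on the transformer call index. A small sketch of the resulting counting rule (a hypothetical helper, not part of the package):

```py
# Sketch of the bookkeeping above: two transformer calls map to one denoising step.
def executed_steps_after(transformer_calls: int, do_separate_cfg: bool) -> int:
    if not do_separate_cfg:
        return transformer_calls
    # Calls 0,1 belong to denoising step 0; calls 2,3 to step 1; and so on,
    # regardless of whether the CFG or the non-CFG half runs first.
    return (transformer_calls + 1) // 2


assert executed_steps_after(4, do_separate_cfg=False) == 4
assert executed_steps_after(4, do_separate_cfg=True) == 2
```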
@@ -589,17 +589,17 @@ def Bn_compute_blocks_ids():


 @torch.compiler.disable
-def do_separate_classifier_free_guidance():
+def do_separate_cfg():
     cache_context = get_current_cache_context()
     assert cache_context is not None, "cache_context must be set before"
-    return cache_context.do_separate_classifier_free_guidance
+    return cache_context.do_separate_cfg


 @torch.compiler.disable
-def is_separate_classifier_free_guidance_step():
+def is_separate_cfg_step():
     cache_context = get_current_cache_context()
     assert cache_context is not None, "cache_context must be set before"
-    return cache_context.is_separate_classifier_free_guidance_step()
+    return cache_context.is_separate_cfg_step()


 @torch.compiler.disable
@@ -710,8 +710,8 @@ def are_two_tensors_similar(

     if all(
         (
-            do_separate_classifier_free_guidance(),
-            is_separate_classifier_free_guidance_step(),
+            do_separate_cfg(),
+            is_separate_cfg_step(),
             not cfg_diff_compute_separate(),
             get_current_step_residual_diff() is not None,
         )
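The `all(...)` guard above is where the two renamed helpers interact: when CFG runs as a separate step and `cfg_diff_compute_separate` is disabled, the CFG step reuses the residual diff already recorded by the matching non-CFG step instead of recomputing it. Restated as a standalone predicate (a sketch that takes plain arguments instead of reading the cache context):

```py
# Sketch of the reuse condition inside are_two_tensors_similar above.
def reuse_non_cfg_diff(
    do_separate_cfg: bool,
    is_cfg_step: bool,
    cfg_diff_compute_separate: bool,
    recorded_diff,
) -> bool:
    return all(
        (
            do_separate_cfg,
            is_cfg_step,
            not cfg_diff_compute_separate,
            recorded_diff is not None,
        )
    )
```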
@@ -789,7 +789,7 @@ def set_Fn_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
     if downsample_factor > 1:
         buffer = buffer[..., ::downsample_factor]
         buffer = buffer.contiguous()
-    if is_separate_classifier_free_guidance_step():
+    if is_separate_cfg_step():
         _debugging_set_buffer(f"{prefix}_buffer_cfg")
         set_buffer(f"{prefix}_buffer_cfg", buffer)
     else:
@@ -799,7 +799,7 @@ def set_Fn_buffer(buffer: torch.Tensor, prefix: str = "Fn"):

 @torch.compiler.disable
 def get_Fn_buffer(prefix: str = "Fn"):
-    if is_separate_classifier_free_guidance_step():
+    if is_separate_cfg_step():
         _debugging_get_buffer(f"{prefix}_buffer_cfg")
         return get_buffer(f"{prefix}_buffer_cfg")
     _debugging_get_buffer(f"{prefix}_buffer")
@@ -808,7 +808,7 @@ def get_Fn_buffer(prefix: str = "Fn"):

 @torch.compiler.disable
 def set_Fn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
-    if is_separate_classifier_free_guidance_step():
+    if is_separate_cfg_step():
         _debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
         set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
     else:
@@ -818,7 +818,7 @@ def set_Fn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Fn"):

 @torch.compiler.disable
 def get_Fn_encoder_buffer(prefix: str = "Fn"):
-    if is_separate_classifier_free_guidance_step():
+    if is_separate_cfg_step():
         _debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
         return get_buffer(f"{prefix}_encoder_buffer_cfg")
     _debugging_get_buffer(f"{prefix}_encoder_buffer")
@@ -832,7 +832,7 @@ def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
     # This buffer is use for hidden states approximation.
     if is_taylorseer_enabled():
         # taylorseer, encoder_taylorseer
-        if is_separate_classifier_free_guidance_step():
+        if is_separate_cfg_step():
             taylorseer, _ = get_cfg_taylorseers()
         else:
             taylorseer, _ = get_taylorseers()
@@ -846,14 +846,14 @@ def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
                 "TaylorSeer is enabled but not set in the cache context. "
                 "Falling back to default buffer retrieval."
             )
-            if is_separate_classifier_free_guidance_step():
+            if is_separate_cfg_step():
                 _debugging_set_buffer(f"{prefix}_buffer_cfg")
                 set_buffer(f"{prefix}_buffer_cfg", buffer)
             else:
                 _debugging_set_buffer(f"{prefix}_buffer")
                 set_buffer(f"{prefix}_buffer", buffer)
     else:
-        if is_separate_classifier_free_guidance_step():
+        if is_separate_cfg_step():
             _debugging_set_buffer(f"{prefix}_buffer_cfg")
             set_buffer(f"{prefix}_buffer_cfg", buffer)
         else:
@@ -865,7 +865,7 @@ def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
 def get_Bn_buffer(prefix: str = "Bn"):
     if is_taylorseer_enabled():
         # taylorseer, encoder_taylorseer
-        if is_separate_classifier_free_guidance_step():
+        if is_separate_cfg_step():
             taylorseer, _ = get_cfg_taylorseers()
         else:
             taylorseer, _ = get_taylorseers()
@@ -879,13 +879,13 @@ def get_Bn_buffer(prefix: str = "Bn"):
                 "Falling back to default buffer retrieval."
             )
             # Fallback to default buffer retrieval
-            if is_separate_classifier_free_guidance_step():
+            if is_separate_cfg_step():
                 _debugging_get_buffer(f"{prefix}_buffer_cfg")
                 return get_buffer(f"{prefix}_buffer_cfg")
             _debugging_get_buffer(f"{prefix}_buffer")
             return get_buffer(f"{prefix}_buffer")
     else:
-        if is_separate_classifier_free_guidance_step():
+        if is_separate_cfg_step():
             _debugging_get_buffer(f"{prefix}_buffer_cfg")
             return get_buffer(f"{prefix}_buffer_cfg")
         _debugging_get_buffer(f"{prefix}_buffer")
@@ -897,7 +897,7 @@ def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
     # This buffer is use for encoder hidden states approximation.
     if is_encoder_taylorseer_enabled():
         # taylorseer, encoder_taylorseer
-        if is_separate_classifier_free_guidance_step():
+        if is_separate_cfg_step():
             _, encoder_taylorseer = get_cfg_taylorseers()
         else:
             _, encoder_taylorseer = get_taylorseers()
@@ -911,14 +911,14 @@ def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
                 "TaylorSeer is enabled but not set in the cache context. "
                 "Falling back to default buffer retrieval."
             )
-            if is_separate_classifier_free_guidance_step():
+            if is_separate_cfg_step():
                 _debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
                 set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
             else:
                 _debugging_set_buffer(f"{prefix}_encoder_buffer")
                 set_buffer(f"{prefix}_encoder_buffer", buffer)
     else:
-        if is_separate_classifier_free_guidance_step():
+        if is_separate_cfg_step():
             _debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
             set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
         else:
@@ -929,7 +929,7 @@ def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
 @torch.compiler.disable
 def get_Bn_encoder_buffer(prefix: str = "Bn"):
     if is_encoder_taylorseer_enabled():
-        if is_separate_classifier_free_guidance_step():
+        if is_separate_cfg_step():
             _, encoder_taylorseer = get_cfg_taylorseers()
         else:
             _, encoder_taylorseer = get_taylorseers()
@@ -944,13 +944,13 @@ def get_Bn_encoder_buffer(prefix: str = "Bn"):
                 "Falling back to default buffer retrieval."
             )
             # Fallback to default buffer retrieval
-            if is_separate_classifier_free_guidance_step():
+            if is_separate_cfg_step():
                 _debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
                 return get_buffer(f"{prefix}_encoder_buffer_cfg")
             _debugging_get_buffer(f"{prefix}_encoder_buffer")
             return get_buffer(f"{prefix}_encoder_buffer")
     else:
-        if is_separate_classifier_free_guidance_step():
+        if is_separate_cfg_step():
             _debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
             return get_buffer(f"{prefix}_encoder_buffer_cfg")
         _debugging_get_buffer(f"{prefix}_encoder_buffer")
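All of the `Fn`/`Bn` getter and setter changes above follow one pattern: on a separate CFG step, the helpers read and write buffers under a key with a `_cfg` suffix, so the CFG and non-CFG branches never overwrite each other's cached states. The naming scheme, condensed (a hypothetical helper; the real code calls `set_buffer`/`get_buffer` directly with these keys):

```py
# Condensed view of the buffer keys used by the helpers above.
def buffer_key(prefix: str, encoder: bool, is_cfg_step: bool) -> str:
    key = f"{prefix}_encoder_buffer" if encoder else f"{prefix}_buffer"
    return f"{key}_cfg" if is_cfg_step else key


assert buffer_key("Fn", encoder=False, is_cfg_step=True) == "Fn_buffer_cfg"
assert buffer_key("Bn", encoder=True, is_cfg_step=False) == "Bn_encoder_buffer"
```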
@@ -1021,7 +1021,7 @@ def get_can_use_cache(
         return False

     max_cached_steps = get_max_cached_steps()
-    if not is_separate_classifier_free_guidance_step():
+    if not is_separate_cfg_step():
         cached_steps = get_cached_steps()
     else:
         cached_steps = get_cfg_cached_steps()
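Because the rename reaches the public cache options, configurations written against 0.2.21 need a one-key update. A hypothetical migration helper (not shipped with the package):

```py
# Hypothetical migration for options dicts written against 0.2.21.
def migrate_cache_options(options: dict) -> dict:
    options = dict(options)
    if "do_separate_classifier_free_guidance" in options:
        options["do_separate_cfg"] = options.pop(
            "do_separate_classifier_free_guidance"
        )
    return options


assert migrate_cache_options(
    {"do_separate_classifier_free_guidance": True}
) == {"do_separate_cfg": True}
```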
@@ -0,0 +1,149 @@
+from diffusers import DiffusionPipeline
+from cache_dit.cache_factory.forward_pattern import ForwardPattern
+from cache_dit.cache_factory.cache_types import CacheType
+from cache_dit.cache_factory.cache_adapters import BlockAdapter
+from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def enable_cache(
+    # BlockAdapter & forward pattern
+    pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+    forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+    # Cache context kwargs
+    Fn_compute_blocks: int = 8,
+    Bn_compute_blocks: int = 0,
+    warmup_steps: int = 8,
+    max_cached_steps: int = -1,
+    residual_diff_threshold: float = 0.08,
+    # Cache CFG or not
+    do_separate_cfg: bool = False,
+    cfg_compute_first: bool = False,
+    cfg_diff_compute_separate: bool = False,
+    # Hybird TaylorSeer
+    enable_taylorseer: bool = False,
+    enable_encoder_taylorseer: bool = False,
+    taylorseer_cache_type: str = "residual",
+    taylorseer_order: int = 2,
+    **other_cache_kwargs,
+) -> DiffusionPipeline:
+    r"""
+    Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks
+    that match the specific Input and Output patterns).
+
+    For a good balance between performance and precision, DBCache is configured by default
+    with F8B0, 8 warmup steps, and unlimited cached steps.
+
+    Args:
+        pipe_or_adapter (`DiffusionPipeline` or `BlockAdapter`, *required*):
+            The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
+            For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
+            for the usgae of BlockAdapter.
+        forward_pattern (`ForwardPattern`, *required*, defaults to `ForwardPattern.Pattern_0`):
+            The forward pattern of Transformer block, please check https://github.com/vipshop/cache-dit/tree/main?tab=readme-ov-file#forward-pattern-matching
+            for more details.
+        Fn_compute_blocks (`int`, *required*, defaults to 8):
+            Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
+            at time step t, enabling the calculation of a more stable L1 diff and delivering more
+            accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
+            for more details of DBCache.
+        Bn_compute_blocks: (`int`, *required*, defaults to 0):
+            Further fuses approximate information in the **last n** Transformer blocks to enhance
+            prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
+            that use residual cache.
+        warmup_steps (`int`, *required*, defaults to 8):
+            DBCache does not apply the caching strategy when the number of running steps is less than
+            or equal to this value, ensuring the model sufficiently learns basic features during warmup.
+        max_cached_steps (`int`, *required*, defaults to -1):
+            DBCache disables the caching strategy when the previous cached steps exceed this value to
+            prevent precision degradation.
+        residual_diff_threshold (`float`, *required*, defaults to 0.08):
+            he value of residual diff threshold, a higher value leads to faster performance at the
+            cost of lower precision.
+        do_separate_cfg (`bool`, *required*, defaults to False):
+            Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
+            and non-CFG into single forward step, should set do_separate_cfg as False, for example:
+            CogVideoX, HunyuanVideo, Mochi, etc.
+        cfg_compute_first (`bool`, *required*, defaults to False):
+            Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
+            1, 3, 5, ... -> CFG step.
+        cfg_diff_compute_separate (`bool`, *required*, defaults to True):
+            Compute spearate diff values for CFG and non-CFG step, default True. If False, we will
+            use the computed diff from current non-CFG transformer step for current CFG step.
+        enable_taylorseer (`bool`, *required*, defaults to False):
+            Enable the hybird TaylorSeer for hidden_states or not. We have supported the
+            [TaylorSeers: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers](https://arxiv.org/pdf/2503.06923) algorithm
+            to further improve the precision of DBCache in cases where the cached steps are large,
+            namely, **Hybrid TaylorSeer + DBCache**. At timesteps with significant intervals,
+            the feature similarity in diffusion models decreases substantially, significantly
+            harming the generation quality.
+        enable_encoder_taylorseer (`bool`, *required*, defaults to False):
+            Enable the hybird TaylorSeer for encoder_hidden_states or not.
+        taylorseer_cache_type (`str`, *required*, defaults to `residual`):
+            The TaylorSeer implemented in cache-dit supports both `hidden_states` and `residual` as cache type.
+        taylorseer_order (`int`, *required*, defaults to 2):
+            The order of taylorseer, higher values of n_derivatives will lead to longer computation time,
+            but may improve precision significantly.
+        other_cache_kwargs: (`dict`, *optional*, defaults to {})
+            Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_context.py
+            for more details.
+
+    Examples:
+    ```py
+    >>> import cache_dit
+    >>> from diffusers import DiffusionPipeline
+    >>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image") # Can be any diffusion pipeline
+    >>> cache_dit.enable_cache(pipe) # One-line code with default cache options.
+    >>> output = pipe(...) # Just call the pipe as normal.
+    >>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
+    """
+
+    # Collect cache context kwargs
+    cache_context_kwargs = other_cache_kwargs.copy()
+    cache_context_kwargs["cache_type"] = CacheType.DBCache
+    cache_context_kwargs["Fn_compute_blocks"] = Fn_compute_blocks
+    cache_context_kwargs["Bn_compute_blocks"] = Bn_compute_blocks
+    cache_context_kwargs["warmup_steps"] = warmup_steps
+    cache_context_kwargs["max_cached_steps"] = max_cached_steps
+    cache_context_kwargs["residual_diff_threshold"] = residual_diff_threshold
+    cache_context_kwargs["do_separate_cfg"] = do_separate_cfg
+    cache_context_kwargs["cfg_compute_first"] = cfg_compute_first
+    cache_context_kwargs["cfg_diff_compute_separate"] = (
+        cfg_diff_compute_separate
+    )
+    cache_context_kwargs["enable_taylorseer"] = enable_taylorseer
+    cache_context_kwargs["enable_encoder_taylorseer"] = (
+        enable_encoder_taylorseer
+    )
+    cache_context_kwargs["taylorseer_cache_type"] = taylorseer_cache_type
+    if "taylorseer_kwargs" in cache_context_kwargs:
+        cache_context_kwargs["taylorseer_kwargs"][
+            "n_derivatives"
+        ] = taylorseer_order
+    else:
+        cache_context_kwargs["taylorseer_kwargs"] = {
+            "n_derivatives": taylorseer_order
+        }
+
+    if isinstance(pipe_or_adapter, BlockAdapter):
+        return UnifiedCacheAdapter.apply(
+            pipe=None,
+            block_adapter=pipe_or_adapter,
+            forward_pattern=forward_pattern,
+            **cache_context_kwargs,
+        )
+    elif isinstance(pipe_or_adapter, DiffusionPipeline):
+        return UnifiedCacheAdapter.apply(
+            pipe=pipe_or_adapter,
+            block_adapter=None,
+            forward_pattern=forward_pattern,
+            **cache_context_kwargs,
+        )
+    else:
+        raise ValueError(
+            "Please pass DiffusionPipeline or BlockAdapter"
+            "for the 1's position param: pipe_or_adapter"
+        )
@@ -9,62 +9,31 @@ class CacheType(Enum):
     DBCache = "Dual_Block_Cache"

     @staticmethod
-    def type(cache_type: "CacheType | str") -> "CacheType":
-        if isinstance(cache_type, CacheType):
-            return cache_type
-        return CacheType.cache_type(cache_type)
+    def type(type_hint: "CacheType | str") -> "CacheType":
+        if isinstance(type_hint, CacheType):
+            return type_hint
+        return cache_type(type_hint)

-    @staticmethod
-    def cache_type(cache_type: "CacheType | str") -> "CacheType":
-        if cache_type is None:
-            return CacheType.NONE
-
-        if isinstance(cache_type, CacheType):
-            return cache_type

-        elif cache_type.lower() in (
-            "dual_block_cache",
-            "db_cache",
-            "dbcache",
-            "db",
-        ):
-            return CacheType.DBCache
+def cache_type(type_hint: "CacheType | str") -> "CacheType":
+    if type_hint is None:
         return CacheType.NONE

-    @staticmethod
-    def block_range(start: int, end: int, step: int = 1) -> list[int]:
-        if start > end or end <= 0 or step <= 1:
-            return []
-        # Always compute 0 and end - 1 blocks for DB Cache
-        return list(
-            sorted(set([0] + list(range(start, end, step)) + [end - 1]))
-        )
-
-    @staticmethod
-    def default_options(cache_type: "CacheType | str") -> dict:
-        _no_options = {
-            "cache_type": CacheType.NONE,
-        }
+    if isinstance(type_hint, CacheType):
+        return type_hint

-        _Fn_compute_blocks = 8
-        _Bn_compute_blocks = 0
+    elif type_hint.lower() in (
+        "dual_block_cache",
+        "db_cache",
+        "dbcache",
+        "db",
+    ):
+        return CacheType.DBCache
+    return CacheType.NONE

-        _db_options = {
-            "cache_type": CacheType.DBCache,
-            "residual_diff_threshold": 0.12,
-            "warmup_steps": 8,
-            "max_cached_steps": -1,  # -1 means no limit
-            "Fn_compute_blocks": _Fn_compute_blocks,
-            "Bn_compute_blocks": _Bn_compute_blocks,
-            "max_Fn_compute_blocks": 16,
-            "max_Bn_compute_blocks": 16,
-            "Fn_compute_blocks_ids": [],  # 0, 1, 2, ..., 7, etc.
-            "Bn_compute_blocks_ids": [],  # 0, 1, 2, ..., 7, etc.
-        }

-        if cache_type == CacheType.DBCache:
-            return _db_options
-        elif cache_type == CacheType.NONE:
-            return _no_options
-        else:
-            raise ValueError(f"Unknown cache type: {cache_type}")
+def block_range(start: int, end: int, step: int = 1) -> list[int]:
+    if start > end or end <= 0 or step <= 1:
+        return []
+    # Always compute 0 and end - 1 blocks for DB Cache
+    return list(sorted(set([0] + list(range(start, end, step)) + [end - 1])))
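The refactor drops `default_options` and, judging by the new `return cache_type(type_hint)` call inside `CacheType.type`, moves `cache_type` and `block_range` from static methods to module-level functions. Assuming that reading, a usage sketch (importing the two module-level helpers is an assumption):

```py
from cache_dit.cache_factory.cache_types import CacheType, block_range, cache_type

assert cache_type("dbcache") is CacheType.DBCache  # aliases: dual_block_cache, db_cache, dbcache, db
assert cache_type(None) is CacheType.NONE
# Block 0 and the last block are always kept for DBCache.
assert block_range(2, 8, 2) == [0, 2, 4, 6, 7]
```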
@@ -1,5 +1,3 @@
-# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/taylorseer.py
-# Reference: https://github.com/Shenyi-Z/TaylorSeer/TaylorSeer-FLUX/src/flux/taylor_utils/__init__.py
 import math
 import torch

@@ -51,3 +51,7 @@ def load_cache_options_from_yaml(yaml_file_path):
         )
     except yaml.YAMLError as e:
         raise yaml.YAMLError(f"YAML file parsing error: {str(e)}")
+
+
+def load_options(path: str):
+    return load_cache_options_from_yaml(path)
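`load_options` is a thin alias for `load_cache_options_from_yaml`. A hypothetical call site (the import path and the YAML keys below are assumptions, shown only to illustrate the intent):

```py
from cache_dit.cache_factory import load_options  # assumed export location

options = load_options("dbcache.yaml")
# dbcache.yaml might contain keys mirroring the cache context kwargs, e.g.:
#   cache_type: DBCache
#   Fn_compute_blocks: 8
#   residual_diff_threshold: 0.08
```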
@@ -24,14 +24,15 @@ def epilogue_prologue_fusion_enabled(**kwargs) -> bool:


 def set_compile_configs(
+    descent_tuning: bool = True,
     cuda_graphs: bool = False,
     force_disable_compile_caches: bool = False,
     use_fast_math: bool = False,
     **kwargs,  # other kwargs
 ):
     # Alway increase recompile_limit for dynamic shape compilation
-    torch._dynamo.config.recompile_limit = 96  # default is 8
-    torch._dynamo.config.accumulated_recompile_limit = 2048  # default is 256
+    torch._dynamo.config.recompile_limit = 1024  # default is 8
+    torch._dynamo.config.accumulated_recompile_limit = 8192  # default is 256
     # Handle compiler caches
     # https://github.com/vllm-project/vllm/blob/23baa2180b0ebba5ae94073ba9b8e93f88b75486/vllm/compilation/compiler_interface.py#L270
     torch._inductor.config.fx_graph_cache = True
@@ -47,6 +48,9 @@ def set_compile_configs(
         64 if "L20" in torch.cuda.get_device_name() else 300
     )

+    if not descent_tuning:
+        return
+
     FORCE_DISABLE_CUSTOM_COMPILE_CONFIG = (
         os.environ.get("CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG", "0")
         == "1"
cache_dit/utils.py CHANGED
@@ -26,14 +26,17 @@ class CacheStats:
     )


-def summary(pipe: DiffusionPipeline, details: bool = False):
+def summary(
+    pipe: DiffusionPipeline, details: bool = False, logging: bool = True
+):
     cache_stats = CacheStats()
     pipe_cls_name = pipe.__class__.__name__

     if hasattr(pipe, "_cache_options"):
         cache_options = pipe._cache_options
         cache_stats.cache_options = cache_options
-        print(f"\n🤗Cache Options: {pipe_cls_name}\n\n{cache_options}")
+        if logging:
+            print(f"\n🤗Cache Options: {pipe_cls_name}\n\n{cache_options}")

     if hasattr(pipe.transformer, "_cached_steps"):
         cached_steps: list[int] = pipe.transformer._cached_steps
@@ -43,7 +46,7 @@ def summary(pipe: DiffusionPipeline, details: bool = False):
         cache_stats.cached_steps = cached_steps
         cache_stats.residual_diffs = residual_diffs

-        if residual_diffs:
+        if residual_diffs and logging:
             diffs_values = list(residual_diffs.values())
             qmin = np.min(diffs_values)
             q0 = np.percentile(diffs_values, 0)
@@ -90,7 +93,7 @@ def summary(pipe: DiffusionPipeline, details: bool = False):
         cache_stats.cfg_cached_steps = cfg_cached_steps
         cache_stats.cfg_residual_diffs = cfg_residual_diffs

-        if cfg_residual_diffs:
+        if cfg_residual_diffs and logging:
             cfg_diffs_values = list(cfg_residual_diffs.values())
             qmin = np.min(cfg_diffs_values)
             q0 = np.percentile(cfg_diffs_values, 0)
@@ -130,3 +133,29 @@ def summary(pipe: DiffusionPipeline, details: bool = False):
     )

     return cache_stats
+
+
+def strify(pipe_or_stats: DiffusionPipeline | CacheStats):
+    if not isinstance(pipe_or_stats, CacheStats):
+        stats = summary(pipe_or_stats, logging=False)
+    else:
+        stats = pipe_or_stats
+
+    cache_options = stats.cache_options
+    cached_steps = len(stats.cached_steps)
+
+    if not cache_options:
+        return "NONE"
+
+    cache_type_str = (
+        f"DBCACHE_F{cache_options['Fn_compute_blocks']}"
+        f"B{cache_options['Bn_compute_blocks']}"
+        f"W{cache_options['warmup_steps']}"
+        f"M{max(0, cache_options['max_cached_steps'])}"
+        f"T{int(cache_options['enable_taylorseer'])}"
+        f"O{cache_options['taylorseer_kwargs']['n_derivatives']}_"
+        f"R{cache_options['residual_diff_threshold']}_"
+        f"S{cached_steps}"  # skiped steps
+    )
+
+    return cache_type_str
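`strify` condenses the active cache options and the number of cached (skipped) steps into a short tag, calling `summary(..., logging=False)` internally so it can gather stats without printing. An illustrative call (strify is assumed to be re-exported at the package root alongside summary; the tag values below follow from the code above, not from a measured run):

```py
tag = cache_dit.strify(pipe)  # also accepts a CacheStats returned by summary()
# With the F8B0 defaults, 8 warmup steps, unlimited cached steps (-1 -> "M0"),
# TaylorSeer disabled (order 2), threshold 0.08 and, say, 12 cached steps,
# the tag has the shape:
#   DBCACHE_F8B0W8M0T0O2_R0.08_S12
print(tag)
```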