cache-dit 0.1.8__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of cache-dit was flagged as potentially problematic; details are available on the registry listing.

cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.1.8'
-__version_tuple__ = version_tuple = (0, 1, 8)
+__version__ = version = '0.2.1'
+__version_tuple__ = version_tuple = (0, 2, 1)
@@ -9,6 +9,7 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union
 import torch
 
 import cache_dit.primitives as DP
+from cache_dit.cache_factory.taylorseer import TaylorSeer
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
@@ -60,6 +61,18 @@ class DBCacheContext:
     residual_diffs: DefaultDict[str, float] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
+    # TODO: Support TaylorSeers and SLG in Dual Block Cache
+    # TaylorSeers:
+    # Title: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
+    # Url: https://arxiv.org/pdf/2503.06923
+    taylorseer: Optional[TaylorSeer] = None
+    alter_taylorseer: Optional[TaylorSeer] = None
+
+    # Skip Layer Guidance, SLG
+    # https://github.com/huggingface/candle/issues/2588
+    slg_layers: Optional[List[int]] = None
+    slg_start: float = 0.0
+    slg_end: float = 0.1
 
     def get_incremental_name(self, name=None):
         if name is None:
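For context on the new taylorseer fields added above: TaylorSeers (https://arxiv.org/pdf/2503.06923) replaces plain feature reuse with a Taylor-series forecast of skipped features from their recent history. Below is a minimal, standalone sketch of the first-order forecasting idea only; it is not the cache_dit.cache_factory.taylorseer.TaylorSeer API, whose internals are not shown in this diff.

import torch

def taylor_forecast(
    prev: torch.Tensor, prev_prev: torch.Tensor, steps_ahead: float = 1.0
) -> torch.Tensor:
    # First-order Taylor (finite-difference) forecast of a cached feature:
    #   f(t + k) ~= f(t) + k * (f(t) - f(t - 1))
    # Higher orders would add (k**2 / 2!) * second differences, and so on.
    return prev + steps_ahead * (prev - prev_prev)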
@@ -700,7 +713,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 encoder_hidden_states,
                 hidden_states_residual,
                 encoder_hidden_states_residual,
-            ) = self.call_MN2n_transformer_blocks(  # middle
+            ) = self.call_Mn_transformer_blocks(  # middle
                 hidden_states,
                 encoder_hidden_states,
                 *args,
@@ -772,32 +785,32 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         return selected_Fn_transformer_blocks
 
     @torch.compiler.disable
-    def _MN2n_single_transformer_blocks(self):  # middle
+    def _Mn_single_transformer_blocks(self):  # middle blocks
         # M(N-2n): transformer_blocks [n,...] + single_transformer_blocks [0,...,N-n]
-        selected_MN2n_single_transformer_blocks = []
+        selected_Mn_single_transformer_blocks = []
         if self.single_transformer_blocks is not None:
             if Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
-                selected_MN2n_single_transformer_blocks = (
+                selected_Mn_single_transformer_blocks = (
                     self.single_transformer_blocks
                 )
             else:
-                selected_MN2n_single_transformer_blocks = (
+                selected_Mn_single_transformer_blocks = (
                     self.single_transformer_blocks[: -Bn_compute_blocks()]
                 )
-        return selected_MN2n_single_transformer_blocks
+        return selected_Mn_single_transformer_blocks
 
     @torch.compiler.disable
-    def _MN2n_transformer_blocks(self):
+    def _Mn_transformer_blocks(self):  # middle blocks
         # M(N-2n): only transformer_blocks [n,...,N-n], middle
         if Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
-            selected_MN2n_transformer_blocks = self.transformer_blocks[
+            selected_Mn_transformer_blocks = self.transformer_blocks[
                 Fn_compute_blocks() :
             ]
         else:
-            selected_MN2n_transformer_blocks = self.transformer_blocks[
+            selected_Mn_transformer_blocks = self.transformer_blocks[
                 Fn_compute_blocks() : -Bn_compute_blocks()
             ]
-        return selected_MN2n_transformer_blocks
+        return selected_Mn_transformer_blocks
 
     @torch.compiler.disable
     def _Bn_single_transformer_blocks(self):
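The renamed helpers above partition the transformer blocks into the first Fn blocks, the middle Mn blocks, and the last Bn blocks. A minimal illustration of that slicing, with hypothetical values standing in for Fn_compute_blocks() and Bn_compute_blocks() (and the same x[:-0] == [] pitfall the code guards against):

blocks = list(range(12))  # stand-in for transformer_blocks
Fn, Bn = 2, 3             # hypothetical Fn_compute_blocks(), Bn_compute_blocks()

Fn_blocks = blocks[:Fn]                                # front blocks, always computed
Mn_blocks = blocks[Fn:-Bn] if Bn > 0 else blocks[Fn:]  # middle blocks, cache-managed
Bn_blocks = blocks[-Bn:] if Bn > 0 else []             # back blocks, always computed

assert Fn_blocks + Mn_blocks + Bn_blocks == blocks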
@@ -845,7 +858,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
 
         return hidden_states, encoder_hidden_states
 
-    def call_MN2n_transformer_blocks(
+    def call_Mn_transformer_blocks(
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
@@ -873,7 +886,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
             hidden_states = torch.cat(
                 [encoder_hidden_states, hidden_states], dim=1
             )
-            for block in self._MN2n_single_transformer_blocks():
+            for block in self._Mn_single_transformer_blocks():
                 hidden_states = block(
                     hidden_states,
                     *args,
@@ -887,7 +900,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 dim=1,
             )
         else:
-            for block in self._MN2n_transformer_blocks():
+            for block in self._Mn_transformer_blocks():
                 hidden_states = block(
                     hidden_states,
                     encoder_hidden_states,
@@ -1016,7 +1029,9 @@ class DBCachedTransformerBlocks(torch.nn.Module):
 
     def _compute_and_cache_single_transformer_block(
         self,
-        i: int,  # Block index in the transformer blocks
+        # Block index in the transformer blocks
+        # Bn: 8, block_id should be in [0, 8)
+        block_id: int,
         # Helper inputs for hidden states split and reshape
         original_hidden_states: torch.Tensor,
         original_encoder_hidden_states: torch.Tensor,
@@ -1042,7 +1057,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
             )
             # Cache residuals for the non-compute Bn blocks for
             # subsequent cache steps.
-            if i not in Bn_compute_blocks_ids():
+            if block_id not in Bn_compute_blocks_ids():
                 Bn_i_hidden_states = hidden_states
                 (
                     Bn_i_hidden_states_residual,
@@ -1057,16 +1072,16 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 # Save original_hidden_states for diff calculation.
                 set_Bn_buffer(
                     Bn_i_original_hidden_states,
-                    prefix=f"Bn_{i}_single_original",
+                    prefix=f"Bn_{block_id}_single_original",
                 )
 
                 set_Bn_buffer(
                     Bn_i_hidden_states_residual,
-                    prefix=f"Bn_{i}_single_residual",
+                    prefix=f"Bn_{block_id}_single_residual",
                 )
                 set_Bn_encoder_buffer(
                     Bn_i_encoder_hidden_states_residual,
-                    prefix=f"Bn_{i}_single_residual",
+                    prefix=f"Bn_{block_id}_single_residual",
                 )
                 del Bn_i_hidden_states
                 del Bn_i_hidden_states_residual
@@ -1077,7 +1092,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         else:
             # Cache steps: Reuse the cached residuals.
             # Check if the block is in the Bn_compute_blocks_ids.
-            if i in Bn_compute_blocks_ids():
+            if block_id in Bn_compute_blocks_ids():
                 hidden_states = block(
                     hidden_states,
                     *args,
@@ -1091,7 +1106,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                     hidden_states,  # curr step
                     parallelized=self._is_parallelized(),
                     threshold=non_compute_blocks_diff_threshold(),
-                    prefix=f"Bn_{i}_single_original",  # prev step
+                    prefix=f"Bn_{block_id}_single_original",  # prev step
                 ):
                     Bn_i_original_hidden_states = hidden_states
                     (
@@ -1106,7 +1121,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                         apply_hidden_states_residual(
                             Bn_i_original_hidden_states,
                             Bn_i_original_encoder_hidden_states,
-                            prefix=f"Bn_{i}_single_residual",
+                            prefix=f"Bn_{block_id}_single_residual",
                         )
                     )
                     hidden_states = torch.cat(
@@ -1125,7 +1140,9 @@ class DBCachedTransformerBlocks(torch.nn.Module):
 
     def _compute_and_cache_transformer_block(
         self,
-        i: int,  # Block index in the transformer blocks
+        # Block index in the transformer blocks
+        # Bn: 8, block_id should be in [0, 8)
+        block_id: int,
         # Below are the inputs to the block
         block,  # The transformer block to be executed
         hidden_states: torch.Tensor,
@@ -1158,7 +1175,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
             )
             # Cache residuals for the non-compute Bn blocks for
             # subsequent cache steps.
-            if i not in Bn_compute_blocks_ids():
+            if block_id not in Bn_compute_blocks_ids():
                 Bn_i_hidden_states_residual = (
                     hidden_states - Bn_i_original_hidden_states
                 )
@@ -1169,16 +1186,16 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 # Save original_hidden_states for diff calculation.
                 set_Bn_buffer(
                     Bn_i_original_hidden_states,
-                    prefix=f"Bn_{i}_original",
+                    prefix=f"Bn_{block_id}_original",
                 )
 
                 set_Bn_buffer(
                     Bn_i_hidden_states_residual,
-                    prefix=f"Bn_{i}_residual",
+                    prefix=f"Bn_{block_id}_residual",
                 )
                 set_Bn_encoder_buffer(
                     Bn_i_encoder_hidden_states_residual,
-                    prefix=f"Bn_{i}_residual",
+                    prefix=f"Bn_{block_id}_residual",
                 )
                 del Bn_i_hidden_states_residual
                 del Bn_i_encoder_hidden_states_residual
@@ -1189,7 +1206,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         else:
             # Cache steps: Reuse the cached residuals.
             # Check if the block is in the Bn_compute_blocks_ids.
-            if i in Bn_compute_blocks_ids():
+            if block_id in Bn_compute_blocks_ids():
                 hidden_states = block(
                     hidden_states,
                     encoder_hidden_states,
@@ -1211,13 +1228,13 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                     hidden_states,  # curr step
                     parallelized=self._is_parallelized(),
                     threshold=non_compute_blocks_diff_threshold(),
-                    prefix=f"Bn_{i}_original",  # prev step
+                    prefix=f"Bn_{block_id}_original",  # prev step
                 ):
                     hidden_states, encoder_hidden_states = (
                         apply_hidden_states_residual(
                             hidden_states,
                             encoder_hidden_states,
-                            prefix=f"Bn_{i}_residual",
+                            prefix=f"Bn_{block_id}_residual",
                         )
                     )
                 else:
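The i to block_id rename above only changes how the per-block cache keys (Bn_{block_id}_original, Bn_{block_id}_residual) are built; the caching decision itself is unchanged. As a rough sketch of that decision, with hypothetical stand-ins for the cache helpers (the real code additionally checks a residual-diff threshold via get_can_use_cache before reusing):

def run_Bn_block_on_cache_step(block_id, block, hidden_states, compute_ids, residuals):
    # compute_ids plays the role of Bn_compute_blocks_ids(),
    # residuals the role of the Bn_{block_id}_residual buffers.
    if block_id in compute_ids:
        # Selected Bn blocks are recomputed even on cache steps.
        return block(hidden_states)
    # Other Bn blocks reuse the residual saved on the last full compute step.
    return hidden_states + residuals[f"Bn_{block_id}_residual"]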
@@ -13,6 +13,10 @@ def apply_db_cache_on_transformer(transformer, *args, **kwargs):
         adapter_name = "mochi"
     elif transformer_cls_name.startswith("CogVideoX"):
         adapter_name = "cogvideox"
+    elif transformer_cls_name.startswith("Wan"):
+        adapter_name = "wan"
+    elif transformer_cls_name.startswith("HunyuanVideo"):
+        adapter_name = "hunyuan_video"
     else:
         raise ValueError(
             f"Unknown transformer class name: {transformer_cls_name}"
@@ -35,6 +39,10 @@ def apply_db_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
         adapter_name = "mochi"
     elif pipe_cls_name.startswith("CogVideoX"):
         adapter_name = "cogvideox"
+    elif pipe_cls_name.startswith("Wan"):
+        adapter_name = "wan"
+    elif pipe_cls_name.startswith("HunyuanVideo"):
+        adapter_name = "hunyuan_video"
     else:
         raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
 
@@ -0,0 +1,295 @@
+# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/hunyuan_video.py
+import functools
+import unittest
+from typing import Any, Dict, Optional, Union
+
+import torch
+from diffusers import DiffusionPipeline, HunyuanVideoTransformer3DModel
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.utils import (
+    scale_lora_layers,
+    unscale_lora_layers,
+    USE_PEFT_BACKEND,
+)
+
+from cache_dit.cache_factory.dual_block_cache import cache_context
+from cache_dit.logger import init_logger
+
+try:
+    from para_attn.para_attn_interface import SparseKVAttnMode
+
+    def is_sparse_kv_attn_available():
+        return True
+
+except ImportError:
+
+    class SparseKVAttnMode:
+        def __enter__(self):
+            pass
+
+        def __exit__(self, exc_type, exc_value, traceback):
+            pass
+
+    def is_sparse_kv_attn_available():
+        return False
+
+
+logger = init_logger(__name__)  # pylint: disable=invalid-name
+
+
+def apply_db_cache_on_transformer(
+    transformer: HunyuanVideoTransformer3DModel,
+):
+    if getattr(transformer, "_is_cached", False):
+        return transformer
+
+    cached_transformer_blocks = torch.nn.ModuleList(
+        [
+            cache_context.DBCachedTransformerBlocks(
+                transformer.transformer_blocks
+                + transformer.single_transformer_blocks,
+                transformer=transformer,
+            )
+        ]
+    )
+    dummy_single_transformer_blocks = torch.nn.ModuleList()
+
+    original_forward = transformer.forward
+
+    @functools.wraps(transformer.__class__.forward)
+    def new_forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_attention_mask: torch.Tensor,
+        pooled_projections: torch.Tensor,
+        guidance: torch.Tensor = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+        **kwargs,
+    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+        with (
+            unittest.mock.patch.object(
+                self,
+                "transformer_blocks",
+                cached_transformer_blocks,
+            ),
+            unittest.mock.patch.object(
+                self,
+                "single_transformer_blocks",
+                dummy_single_transformer_blocks,
+            ),
+        ):
+            if getattr(self, "_is_parallelized", False):
+                return original_forward(
+                    hidden_states,
+                    timestep,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    pooled_projections,
+                    guidance=guidance,
+                    attention_kwargs=attention_kwargs,
+                    return_dict=return_dict,
+                    **kwargs,
+                )
+            else:
+                if attention_kwargs is not None:
+                    attention_kwargs = attention_kwargs.copy()
+                    lora_scale = attention_kwargs.pop("scale", 1.0)
+                else:
+                    lora_scale = 1.0
+
+                if USE_PEFT_BACKEND:
+                    # weight the lora layers by setting `lora_scale` for each PEFT layer
+                    scale_lora_layers(self, lora_scale)
+                else:
+                    if (
+                        attention_kwargs is not None
+                        and attention_kwargs.get("scale", None) is not None
+                    ):
+                        logger.warning(
+                            "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                        )
+
+                batch_size, num_channels, num_frames, height, width = (
+                    hidden_states.shape
+                )
+                p, p_t = self.config.patch_size, self.config.patch_size_t
+                post_patch_num_frames = num_frames // p_t
+                post_patch_height = height // p
+                post_patch_width = width // p
+
+                # 1. RoPE
+                image_rotary_emb = self.rope(hidden_states)
+
+                # 2. Conditional embeddings
+                temb = self.time_text_embed(
+                    timestep, guidance, pooled_projections
+                )
+                hidden_states = self.x_embedder(hidden_states)
+                encoder_hidden_states = self.context_embedder(
+                    encoder_hidden_states, timestep, encoder_attention_mask
+                )
+
+                # 3. Attention mask preparation
+                latent_sequence_length = hidden_states.shape[1]
+                latent_attention_mask = torch.ones(
+                    batch_size,
+                    1,
+                    latent_sequence_length,
+                    device=hidden_states.device,
+                    dtype=torch.bool,
+                )  # [B, 1, N]
+                attention_mask = torch.cat(
+                    [
+                        latent_attention_mask,
+                        encoder_attention_mask.unsqueeze(1).to(torch.bool),
+                    ],
+                    dim=-1,
+                )  # [B, 1, N + M]
+
+                with SparseKVAttnMode():
+                    # 4. Transformer blocks
+                    hidden_states, encoder_hidden_states = (
+                        self.call_transformer_blocks(
+                            hidden_states,
+                            encoder_hidden_states,
+                            temb,
+                            attention_mask,
+                            image_rotary_emb,
+                        )
+                    )
+
+                # 5. Output projection
+                hidden_states = self.norm_out(hidden_states, temb)
+                hidden_states = self.proj_out(hidden_states)
+
+                hidden_states = hidden_states.reshape(
+                    batch_size,
+                    post_patch_num_frames,
+                    post_patch_height,
+                    post_patch_width,
+                    -1,
+                    p_t,
+                    p,
+                    p,
+                )
+                hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
+                hidden_states = (
+                    hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+                )
+
+                hidden_states = hidden_states.to(timestep.dtype)
+
+                if USE_PEFT_BACKEND:
+                    # remove `lora_scale` from each PEFT layer
+                    unscale_lora_layers(self, lora_scale)
+
+                if not return_dict:
+                    return (hidden_states,)
+
+                return Transformer2DModelOutput(sample=hidden_states)
+
+    transformer.forward = new_forward.__get__(transformer)
+
+    def call_transformer_blocks(
+        self, hidden_states, encoder_hidden_states, *args, **kwargs
+    ):
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}
+
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = (
+                    torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        hidden_states,
+                        encoder_hidden_states,
+                        *args,
+                        **kwargs,
+                        **ckpt_kwargs,
+                    )
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = (
+                    torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        hidden_states,
+                        encoder_hidden_states,
+                        *args,
+                        **kwargs,
+                        **ckpt_kwargs,
+                    )
+                )
+
+        else:
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, *args, **kwargs
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, *args, **kwargs
+                )
+
+        return hidden_states, encoder_hidden_states
+
+    transformer.call_transformer_blocks = call_transformer_blocks.__get__(
+        transformer
+    )
+
+    transformer._is_cached = True
+
+    return transformer
+
+
+def apply_db_cache_on_pipe(
+    pipe: DiffusionPipeline,
+    *,
+    shallow_patch: bool = False,
+    residual_diff_threshold=0.06,
+    downsample_factor=1,
+    warmup_steps=0,
+    max_cached_steps=-1,
+    **kwargs,
+):
+    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
+        default_attrs={
+            "residual_diff_threshold": residual_diff_threshold,
+            "downsample_factor": downsample_factor,
+            "warmup_steps": warmup_steps,
+            "max_cached_steps": max_cached_steps,
+        },
+        **kwargs,
+    )
+    if not getattr(pipe, "_is_cached", False):
+        original_call = pipe.__class__.__call__
+
+        @functools.wraps(original_call)
+        def new_call(self, *args, **kwargs):
+            with cache_context.cache_context(
+                cache_context.create_cache_context(
+                    **cache_kwargs,
+                )
+            ):
+                return original_call(self, *args, **kwargs)
+
+        pipe.__class__.__call__ = new_call
+        pipe.__class__._is_cached = True
+
+    if not shallow_patch:
+        apply_db_cache_on_transformer(pipe.transformer, **kwargs)
+
+    return pipe
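A usage sketch for the new HunyuanVideo adapter. The import path and checkpoint id below are assumptions (this diff does not show where the new module is placed inside cache_dit); the keyword defaults mirror the apply_db_cache_on_pipe signature above.

import torch
from diffusers import HunyuanVideoPipeline

# Assumed import location for the dispatching apply_db_cache_on_pipe shown earlier.
from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
    apply_db_cache_on_pipe,
)

pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",  # illustrative checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

# The "HunyuanVideo..." class-name prefix routes to the hunyuan_video adapter,
# which defaults to residual_diff_threshold=0.06 and max_cached_steps=-1.
apply_db_cache_on_pipe(pipe, residual_diff_threshold=0.06)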
@@ -0,0 +1,99 @@
+# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/wan.py
+
+import functools
+import unittest
+
+import torch
+from diffusers import DiffusionPipeline, WanTransformer3DModel
+
+from cache_dit.cache_factory.dual_block_cache import cache_context
+
+
+def apply_db_cache_on_transformer(
+    transformer: WanTransformer3DModel,
+):
+    if getattr(transformer, "_is_cached", False):
+        return transformer
+
+    blocks = torch.nn.ModuleList(
+        [
+            cache_context.DBCachedTransformerBlocks(
+                transformer.blocks,
+                transformer=transformer,
+                return_hidden_states_only=True,
+            )
+        ]
+    )
+
+    original_forward = transformer.forward
+
+    @functools.wraps(transformer.__class__.forward)
+    def new_forward(
+        self,
+        *args,
+        **kwargs,
+    ):
+        with unittest.mock.patch.object(
+            self,
+            "blocks",
+            blocks,
+        ):
+            return original_forward(
+                *args,
+                **kwargs,
+            )
+
+    transformer.forward = new_forward.__get__(transformer)
+
+    transformer._is_cached = True
+
+    return transformer
+
+
+def apply_db_cache_on_pipe(
+    pipe: DiffusionPipeline,
+    *,
+    shallow_patch: bool = False,
+    residual_diff_threshold=0.03,
+    downsample_factor=1,
+    # SLG is not supported in WAN with DBCache yet
+    # slg_layers=None,
+    # slg_start: float = 0.0,
+    # slg_end: float = 0.1,
+    warmup_steps=0,
+    max_cached_steps=-1,
+    **kwargs,
+):
+    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
+        default_attrs={
+            "residual_diff_threshold": residual_diff_threshold,
+            "downsample_factor": downsample_factor,
+            # "enable_alter_cache": True,
+            # "slg_layers": slg_layers,
+            # "slg_start": slg_start,
+            # "slg_end": slg_end,
+            "num_inference_steps": kwargs.get("num_inference_steps", 50),
+            "warmup_steps": warmup_steps,
+            "max_cached_steps": max_cached_steps,
+        },
+        **kwargs,
+    )
+    if not getattr(pipe, "_is_cached", False):
+        original_call = pipe.__class__.__call__
+
+        @functools.wraps(original_call)
+        def new_call(self, *args, **kwargs):
+            with cache_context.cache_context(
+                cache_context.create_cache_context(
+                    **cache_kwargs,
+                )
+            ):
+                return original_call(self, *args, **kwargs)
+
+        pipe.__class__.__call__ = new_call
+        pipe.__class__._is_cached = True
+
+    if not shallow_patch:
+        apply_db_cache_on_transformer(pipe.transformer, **kwargs)
+
+    return pipe
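And an equivalent sketch for the Wan adapter; again the import path and checkpoint id are assumptions. The Wan adapter uses a tighter default residual_diff_threshold (0.03) and forwards num_inference_steps into the cache context (default 50).

import torch
from diffusers import WanPipeline

from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
    apply_db_cache_on_pipe,  # assumed import location, as above
)

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # illustrative checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

apply_db_cache_on_pipe(pipe, residual_diff_threshold=0.03, num_inference_steps=50)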