cache-dit 0.1.8__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/dual_block_cache/cache_context.py +46 -29
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +8 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +295 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +99 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +12 -4
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +295 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +99 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +4 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +295 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +2 -2
- {cache_dit-0.1.8.dist-info → cache_dit-0.2.1.dist-info}/METADATA +50 -60
- {cache_dit-0.1.8.dist-info → cache_dit-0.2.1.dist-info}/RECORD +16 -11
- {cache_dit-0.1.8.dist-info → cache_dit-0.2.1.dist-info}/WHEEL +0 -0
- {cache_dit-0.1.8.dist-info → cache_dit-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.1.8.dist-info → cache_dit-0.2.1.dist-info}/top_level.txt +0 -0
cache_dit/_version.py
CHANGED
cache_dit/cache_factory/dual_block_cache/cache_context.py
CHANGED
@@ -9,6 +9,7 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union
 import torch

 import cache_dit.primitives as DP
+from cache_dit.cache_factory.taylorseer import TaylorSeer
 from cache_dit.logger import init_logger

 logger = init_logger(__name__)
@@ -60,6 +61,18 @@ class DBCacheContext:
     residual_diffs: DefaultDict[str, float] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
+    # TODO: Support TaylorSeers and SLG in Dual Block Cache
+    # TaylorSeers:
+    # Title: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
+    # Url: https://arxiv.org/pdf/2503.06923
+    taylorseer: Optional[TaylorSeer] = None
+    alter_taylorseer: Optional[TaylorSeer] = None
+
+    # Skip Layer Guidance, SLG
+    # https://github.com/huggingface/candle/issues/2588
+    slg_layers: Optional[List[int]] = None
+    slg_start: float = 0.0
+    slg_end: float = 0.1

     def get_incremental_name(self, name=None):
         if name is None:
@@ -700,7 +713,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
             encoder_hidden_states,
             hidden_states_residual,
             encoder_hidden_states_residual,
-        ) = self.
+        ) = self.call_Mn_transformer_blocks(  # middle
             hidden_states,
             encoder_hidden_states,
             *args,
@@ -772,32 +785,32 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         return selected_Fn_transformer_blocks

     @torch.compiler.disable
-    def
+    def _Mn_single_transformer_blocks(self):  # middle blocks
         # M(N-2n): transformer_blocks [n,...] + single_transformer_blocks [0,...,N-n]
-
+        selected_Mn_single_transformer_blocks = []
         if self.single_transformer_blocks is not None:
             if Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
-
+                selected_Mn_single_transformer_blocks = (
                     self.single_transformer_blocks
                 )
             else:
-
+                selected_Mn_single_transformer_blocks = (
                     self.single_transformer_blocks[: -Bn_compute_blocks()]
                 )
-        return
+        return selected_Mn_single_transformer_blocks

     @torch.compiler.disable
-    def
+    def _Mn_transformer_blocks(self):  # middle blocks
         # M(N-2n): only transformer_blocks [n,...,N-n], middle
         if Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
-
+            selected_Mn_transformer_blocks = self.transformer_blocks[
                 Fn_compute_blocks() :
             ]
         else:
-
+            selected_Mn_transformer_blocks = self.transformer_blocks[
                 Fn_compute_blocks() : -Bn_compute_blocks()
             ]
-        return
+        return selected_Mn_transformer_blocks

     @torch.compiler.disable
     def _Bn_single_transformer_blocks(self):
@@ -845,7 +858,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):

         return hidden_states, encoder_hidden_states

-    def
+    def call_Mn_transformer_blocks(
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
@@ -873,7 +886,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
             hidden_states = torch.cat(
                 [encoder_hidden_states, hidden_states], dim=1
             )
-            for block in self.
+            for block in self._Mn_single_transformer_blocks():
                 hidden_states = block(
                     hidden_states,
                     *args,
@@ -887,7 +900,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 dim=1,
             )
         else:
-            for block in self.
+            for block in self._Mn_transformer_blocks():
                 hidden_states = block(
                     hidden_states,
                     encoder_hidden_states,
@@ -1016,7 +1029,9 @@ class DBCachedTransformerBlocks(torch.nn.Module):

     def _compute_and_cache_single_transformer_block(
         self,
-
+        # Block index in the transformer blocks
+        # Bn: 8, block_id should be in [0, 8)
+        block_id: int,
         # Helper inputs for hidden states split and reshape
         original_hidden_states: torch.Tensor,
         original_encoder_hidden_states: torch.Tensor,
@@ -1042,7 +1057,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
             )
             # Cache residuals for the non-compute Bn blocks for
             # subsequent cache steps.
-            if
+            if block_id not in Bn_compute_blocks_ids():
                 Bn_i_hidden_states = hidden_states
                 (
                     Bn_i_hidden_states_residual,
@@ -1057,16 +1072,16 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 # Save original_hidden_states for diff calculation.
                 set_Bn_buffer(
                     Bn_i_original_hidden_states,
-                    prefix=f"Bn_{
+                    prefix=f"Bn_{block_id}_single_original",
                 )

                 set_Bn_buffer(
                     Bn_i_hidden_states_residual,
-                    prefix=f"Bn_{
+                    prefix=f"Bn_{block_id}_single_residual",
                 )
                 set_Bn_encoder_buffer(
                     Bn_i_encoder_hidden_states_residual,
-                    prefix=f"Bn_{
+                    prefix=f"Bn_{block_id}_single_residual",
                 )
                 del Bn_i_hidden_states
                 del Bn_i_hidden_states_residual
@@ -1077,7 +1092,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         else:
             # Cache steps: Reuse the cached residuals.
             # Check if the block is in the Bn_compute_blocks_ids.
-            if
+            if block_id in Bn_compute_blocks_ids():
                 hidden_states = block(
                     hidden_states,
                     *args,
@@ -1091,7 +1106,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                     hidden_states,  # curr step
                     parallelized=self._is_parallelized(),
                     threshold=non_compute_blocks_diff_threshold(),
-                    prefix=f"Bn_{
+                    prefix=f"Bn_{block_id}_single_original",  # prev step
                 ):
                     Bn_i_original_hidden_states = hidden_states
                     (
@@ -1106,7 +1121,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                         apply_hidden_states_residual(
                             Bn_i_original_hidden_states,
                             Bn_i_original_encoder_hidden_states,
-                            prefix=f"Bn_{
+                            prefix=f"Bn_{block_id}_single_residual",
                         )
                     )
                     hidden_states = torch.cat(
@@ -1125,7 +1140,9 @@ class DBCachedTransformerBlocks(torch.nn.Module):

     def _compute_and_cache_transformer_block(
         self,
-
+        # Block index in the transformer blocks
+        # Bn: 8, block_id should be in [0, 8)
+        block_id: int,
         # Below are the inputs to the block
         block,  # The transformer block to be executed
         hidden_states: torch.Tensor,
@@ -1158,7 +1175,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
             )
             # Cache residuals for the non-compute Bn blocks for
             # subsequent cache steps.
-            if
+            if block_id not in Bn_compute_blocks_ids():
                 Bn_i_hidden_states_residual = (
                     hidden_states - Bn_i_original_hidden_states
                 )
@@ -1169,16 +1186,16 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 # Save original_hidden_states for diff calculation.
                 set_Bn_buffer(
                     Bn_i_original_hidden_states,
-                    prefix=f"Bn_{
+                    prefix=f"Bn_{block_id}_original",
                 )

                 set_Bn_buffer(
                     Bn_i_hidden_states_residual,
-                    prefix=f"Bn_{
+                    prefix=f"Bn_{block_id}_residual",
                 )
                 set_Bn_encoder_buffer(
                     Bn_i_encoder_hidden_states_residual,
-                    prefix=f"Bn_{
+                    prefix=f"Bn_{block_id}_residual",
                 )
                 del Bn_i_hidden_states_residual
                 del Bn_i_encoder_hidden_states_residual
@@ -1189,7 +1206,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         else:
             # Cache steps: Reuse the cached residuals.
             # Check if the block is in the Bn_compute_blocks_ids.
-            if
+            if block_id in Bn_compute_blocks_ids():
                 hidden_states = block(
                     hidden_states,
                     encoder_hidden_states,
@@ -1211,13 +1228,13 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                     hidden_states,  # curr step
                     parallelized=self._is_parallelized(),
                     threshold=non_compute_blocks_diff_threshold(),
-                    prefix=f"Bn_{
+                    prefix=f"Bn_{block_id}_original",  # prev step
                 ):
                     hidden_states, encoder_hidden_states = (
                         apply_hidden_states_residual(
                             hidden_states,
                             encoder_hidden_states,
-                            prefix=f"Bn_{
+                            prefix=f"Bn_{block_id}_residual",
                         )
                     )
                 else:
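Note: the renamed helpers above (_Mn_single_transformer_blocks, _Mn_transformer_blocks, call_Mn_transformer_blocks) all operate on the middle M(N-2n) slice that sits between the first-n (Fn) compute blocks and the last-n (Bn) compute blocks. A minimal illustrative sketch of that partition, with stand-in values for the Fn_compute_blocks() / Bn_compute_blocks() globals defined elsewhere in cache_context.py:

# Illustrative only; Fn and Bn are stand-ins for Fn_compute_blocks()
# and Bn_compute_blocks().
blocks = list(range(12))  # stand-in for transformer_blocks
Fn, Bn = 2, 2             # first-n / last-n blocks that are always computed

Fn_blocks = blocks[:Fn]
# Per the WARN in the hunk above, x[:-0] == [], so Bn == 0 is special-cased.
Mn_blocks = blocks[Fn:] if Bn == 0 else blocks[Fn:-Bn]  # cached middle slice
Bn_blocks = [] if Bn == 0 else blocks[-Bn:]

assert Fn_blocks + Mn_blocks + Bn_blocks == blocks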
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py
CHANGED
@@ -13,6 +13,10 @@ def apply_db_cache_on_transformer(transformer, *args, **kwargs):
         adapter_name = "mochi"
     elif transformer_cls_name.startswith("CogVideoX"):
         adapter_name = "cogvideox"
+    elif transformer_cls_name.startswith("Wan"):
+        adapter_name = "wan"
+    elif transformer_cls_name.startswith("HunyuanVideo"):
+        adapter_name = "hunyuan_video"
     else:
         raise ValueError(
             f"Unknown transformer class name: {transformer_cls_name}"
@@ -35,6 +39,10 @@ def apply_db_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
         adapter_name = "mochi"
     elif pipe_cls_name.startswith("CogVideoX"):
         adapter_name = "cogvideox"
+    elif pipe_cls_name.startswith("Wan"):
+        adapter_name = "wan"
+    elif pipe_cls_name.startswith("HunyuanVideo"):
+        adapter_name = "hunyuan_video"
     else:
         raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")

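With the branches added above, the dual-block-cache entry points now recognize Wan and HunyuanVideo models by class-name prefix. A hedged usage sketch of the dispatch (the model id is an assumption, not taken from this diff):

import torch
from diffusers import HunyuanVideoPipeline

from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
    apply_db_cache_on_pipe,
)

# Model id is illustrative only.
pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
)

# pipe.__class__.__name__ starts with "HunyuanVideo", so the new
# "hunyuan_video" branch above selects the HunyuanVideo adapter.
apply_db_cache_on_pipe(pipe, residual_diff_threshold=0.06)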
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py
ADDED
@@ -0,0 +1,295 @@
+# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/hunyuan_video.py
+import functools
+import unittest
+from typing import Any, Dict, Optional, Union
+
+import torch
+from diffusers import DiffusionPipeline, HunyuanVideoTransformer3DModel
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.utils import (
+    scale_lora_layers,
+    unscale_lora_layers,
+    USE_PEFT_BACKEND,
+)
+
+from cache_dit.cache_factory.dual_block_cache import cache_context
+from cache_dit.logger import init_logger
+
+try:
+    from para_attn.para_attn_interface import SparseKVAttnMode
+
+    def is_sparse_kv_attn_available():
+        return True
+
+except ImportError:
+
+    class SparseKVAttnMode:
+        def __enter__(self):
+            pass
+
+        def __exit__(self, exc_type, exc_value, traceback):
+            pass
+
+    def is_sparse_kv_attn_available():
+        return False
+
+
+logger = init_logger(__name__)  # pylint: disable=invalid-name
+
+
+def apply_db_cache_on_transformer(
+    transformer: HunyuanVideoTransformer3DModel,
+):
+    if getattr(transformer, "_is_cached", False):
+        return transformer
+
+    cached_transformer_blocks = torch.nn.ModuleList(
+        [
+            cache_context.DBCachedTransformerBlocks(
+                transformer.transformer_blocks
+                + transformer.single_transformer_blocks,
+                transformer=transformer,
+            )
+        ]
+    )
+    dummy_single_transformer_blocks = torch.nn.ModuleList()
+
+    original_forward = transformer.forward
+
+    @functools.wraps(transformer.__class__.forward)
+    def new_forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_attention_mask: torch.Tensor,
+        pooled_projections: torch.Tensor,
+        guidance: torch.Tensor = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+        **kwargs,
+    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+        with (
+            unittest.mock.patch.object(
+                self,
+                "transformer_blocks",
+                cached_transformer_blocks,
+            ),
+            unittest.mock.patch.object(
+                self,
+                "single_transformer_blocks",
+                dummy_single_transformer_blocks,
+            ),
+        ):
+            if getattr(self, "_is_parallelized", False):
+                return original_forward(
+                    hidden_states,
+                    timestep,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    pooled_projections,
+                    guidance=guidance,
+                    attention_kwargs=attention_kwargs,
+                    return_dict=return_dict,
+                    **kwargs,
+                )
+            else:
+                if attention_kwargs is not None:
+                    attention_kwargs = attention_kwargs.copy()
+                    lora_scale = attention_kwargs.pop("scale", 1.0)
+                else:
+                    lora_scale = 1.0
+
+                if USE_PEFT_BACKEND:
+                    # weight the lora layers by setting `lora_scale` for each PEFT layer
+                    scale_lora_layers(self, lora_scale)
+                else:
+                    if (
+                        attention_kwargs is not None
+                        and attention_kwargs.get("scale", None) is not None
+                    ):
+                        logger.warning(
+                            "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                        )
+
+                batch_size, num_channels, num_frames, height, width = (
+                    hidden_states.shape
+                )
+                p, p_t = self.config.patch_size, self.config.patch_size_t
+                post_patch_num_frames = num_frames // p_t
+                post_patch_height = height // p
+                post_patch_width = width // p
+
+                # 1. RoPE
+                image_rotary_emb = self.rope(hidden_states)
+
+                # 2. Conditional embeddings
+                temb = self.time_text_embed(
+                    timestep, guidance, pooled_projections
+                )
+                hidden_states = self.x_embedder(hidden_states)
+                encoder_hidden_states = self.context_embedder(
+                    encoder_hidden_states, timestep, encoder_attention_mask
+                )
+
+                # 3. Attention mask preparation
+                latent_sequence_length = hidden_states.shape[1]
+                latent_attention_mask = torch.ones(
+                    batch_size,
+                    1,
+                    latent_sequence_length,
+                    device=hidden_states.device,
+                    dtype=torch.bool,
+                )  # [B, 1, N]
+                attention_mask = torch.cat(
+                    [
+                        latent_attention_mask,
+                        encoder_attention_mask.unsqueeze(1).to(torch.bool),
+                    ],
+                    dim=-1,
+                )  # [B, 1, N + M]
+
+                with SparseKVAttnMode():
+                    # 4. Transformer blocks
+                    hidden_states, encoder_hidden_states = (
+                        self.call_transformer_blocks(
+                            hidden_states,
+                            encoder_hidden_states,
+                            temb,
+                            attention_mask,
+                            image_rotary_emb,
+                        )
+                    )
+
+                # 5. Output projection
+                hidden_states = self.norm_out(hidden_states, temb)
+                hidden_states = self.proj_out(hidden_states)
+
+                hidden_states = hidden_states.reshape(
+                    batch_size,
+                    post_patch_num_frames,
+                    post_patch_height,
+                    post_patch_width,
+                    -1,
+                    p_t,
+                    p,
+                    p,
+                )
+                hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
+                hidden_states = (
+                    hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+                )
+
+                hidden_states = hidden_states.to(timestep.dtype)
+
+                if USE_PEFT_BACKEND:
+                    # remove `lora_scale` from each PEFT layer
+                    unscale_lora_layers(self, lora_scale)
+
+                if not return_dict:
+                    return (hidden_states,)
+
+                return Transformer2DModelOutput(sample=hidden_states)
+
+    transformer.forward = new_forward.__get__(transformer)
+
+    def call_transformer_blocks(
+        self, hidden_states, encoder_hidden_states, *args, **kwargs
+    ):
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}
+
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = (
+                    torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        hidden_states,
+                        encoder_hidden_states,
+                        *args,
+                        **kwargs,
+                        **ckpt_kwargs,
+                    )
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = (
+                    torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        hidden_states,
+                        encoder_hidden_states,
+                        *args,
+                        **kwargs,
+                        **ckpt_kwargs,
+                    )
+                )
+
+        else:
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, *args, **kwargs
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, *args, **kwargs
+                )
+
+        return hidden_states, encoder_hidden_states
+
+    transformer.call_transformer_blocks = call_transformer_blocks.__get__(
+        transformer
+    )
+
+    transformer._is_cached = True
+
+    return transformer
+
+
+def apply_db_cache_on_pipe(
+    pipe: DiffusionPipeline,
+    *,
+    shallow_patch: bool = False,
+    residual_diff_threshold=0.06,
+    downsample_factor=1,
+    warmup_steps=0,
+    max_cached_steps=-1,
+    **kwargs,
+):
+    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
+        default_attrs={
+            "residual_diff_threshold": residual_diff_threshold,
+            "downsample_factor": downsample_factor,
+            "warmup_steps": warmup_steps,
+            "max_cached_steps": max_cached_steps,
+        },
+        **kwargs,
+    )
+    if not getattr(pipe, "_is_cached", False):
+        original_call = pipe.__class__.__call__
+
+        @functools.wraps(original_call)
+        def new_call(self, *args, **kwargs):
+            with cache_context.cache_context(
+                cache_context.create_cache_context(
+                    **cache_kwargs,
+                )
+            ):
+                return original_call(self, *args, **kwargs)
+
+        pipe.__class__.__call__ = new_call
+        pipe.__class__._is_cached = True
+
+    if not shallow_patch:
+        apply_db_cache_on_transformer(pipe.transformer, **kwargs)
+
+    return pipe
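The adapter above leans on two small Python patterns: unittest.mock.patch.object temporarily swaps the real block lists for the cached wrapper during each forward call, and new_forward.__get__(transformer) binds the replacement function as a method of that single instance. A minimal standalone sketch of the same pattern (the names here are hypothetical, not from cache-dit):

import unittest.mock

class Model:
    def __init__(self):
        self.blocks = ["b0", "b1", "b2"]

    def forward(self):
        return list(self.blocks)

model = Model()
cached_blocks = ["cached"]  # stand-in for DBCachedTransformerBlocks
original_forward = model.forward

def new_forward(self):
    # Swap in the cached blocks only for the duration of this call.
    with unittest.mock.patch.object(self, "blocks", cached_blocks):
        return original_forward()

# Bind new_forward as an instance method, as the adapter does.
model.forward = new_forward.__get__(model)
assert model.forward() == ["cached"]
assert model.blocks == ["b0", "b1", "b2"]  # original blocks restored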
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py
ADDED
@@ -0,0 +1,99 @@
+# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/wan.py
+
+import functools
+import unittest
+
+import torch
+from diffusers import DiffusionPipeline, WanTransformer3DModel
+
+from cache_dit.cache_factory.dual_block_cache import cache_context
+
+
+def apply_db_cache_on_transformer(
+    transformer: WanTransformer3DModel,
+):
+    if getattr(transformer, "_is_cached", False):
+        return transformer
+
+    blocks = torch.nn.ModuleList(
+        [
+            cache_context.DBCachedTransformerBlocks(
+                transformer.blocks,
+                transformer=transformer,
+                return_hidden_states_only=True,
+            )
+        ]
+    )
+
+    original_forward = transformer.forward
+
+    @functools.wraps(transformer.__class__.forward)
+    def new_forward(
+        self,
+        *args,
+        **kwargs,
+    ):
+        with unittest.mock.patch.object(
+            self,
+            "blocks",
+            blocks,
+        ):
+            return original_forward(
+                *args,
+                **kwargs,
+            )
+
+    transformer.forward = new_forward.__get__(transformer)
+
+    transformer._is_cached = True
+
+    return transformer
+
+
+def apply_db_cache_on_pipe(
+    pipe: DiffusionPipeline,
+    *,
+    shallow_patch: bool = False,
+    residual_diff_threshold=0.03,
+    downsample_factor=1,
+    # SLG is not supported in WAN with DBCache yet
+    # slg_layers=None,
+    # slg_start: float = 0.0,
+    # slg_end: float = 0.1,
+    warmup_steps=0,
+    max_cached_steps=-1,
+    **kwargs,
+):
+    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
+        default_attrs={
+            "residual_diff_threshold": residual_diff_threshold,
+            "downsample_factor": downsample_factor,
+            # "enable_alter_cache": True,
+            # "slg_layers": slg_layers,
+            # "slg_start": slg_start,
+            # "slg_end": slg_end,
+            "num_inference_steps": kwargs.get("num_inference_steps", 50),
+            "warmup_steps": warmup_steps,
+            "max_cached_steps": max_cached_steps,
+        },
+        **kwargs,
+    )
+    if not getattr(pipe, "_is_cached", False):
+        original_call = pipe.__class__.__call__
+
+        @functools.wraps(original_call)
+        def new_call(self, *args, **kwargs):
+            with cache_context.cache_context(
+                cache_context.create_cache_context(
+                    **cache_kwargs,
+                )
+            ):
+                return original_call(self, *args, **kwargs)
+
+        pipe.__class__.__call__ = new_call
+        pipe.__class__._is_cached = True
+
+    if not shallow_patch:
+        apply_db_cache_on_transformer(pipe.transformer, **kwargs)
+
+    return pipe
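A hedged usage sketch for this Wan adapter (the checkpoint id is an assumption; the keyword defaults mirror the signature above, and num_inference_steps defaults to 50 unless it is passed to apply_db_cache_on_pipe):

import torch
from diffusers import WanPipeline

from cache_dit.cache_factory.dual_block_cache.diffusers_adapters.wan import (
    apply_db_cache_on_pipe,
)

# Checkpoint id is illustrative only.
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
)

# Defaults mirror the signature above; the Wan adapter uses a lower
# residual_diff_threshold (0.03) than the HunyuanVideo adapter (0.06).
apply_db_cache_on_pipe(
    pipe,
    residual_diff_threshold=0.03,
    warmup_steps=0,
    max_cached_steps=-1,
)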