cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/offload_utils.py +115 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/caching/cache_interface.py +358 -0
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
- cache_dit/caching/utils.py +68 -0
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/metrics/metrics.py +3 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- cache_dit-1.0.14.dist-info/METADATA +301 -0
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
- cache_dit/cache_factory/cache_blocks/utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
- cache_dit/cache_factory/cache_interface.py +0 -217
- cache_dit/cache_factory/patch_functors/__init__.py +0 -12
- cache_dit/cache_factory/utils.py +0 -57
- cache_dit-0.3.2.dist-info/METADATA +0 -753
- cache_dit-0.3.2.dist-info/RECORD +0 -56
- cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,729 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import functools
|
|
3
|
+
from typing import Optional, Dict, Any, Union, Tuple
|
|
4
|
+
from diffusers.models.modeling_utils import ModelMixin
|
|
5
|
+
from diffusers.utils import (
|
|
6
|
+
USE_PEFT_BACKEND,
|
|
7
|
+
scale_lora_layers,
|
|
8
|
+
unscale_lora_layers,
|
|
9
|
+
)
|
|
10
|
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
|
11
|
+
from diffusers.models.transformers.transformer_hunyuan_video import (
|
|
12
|
+
HunyuanVideoTransformer3DModel,
|
|
13
|
+
HunyuanVideoAttnProcessor2_0,
|
|
14
|
+
)
|
|
15
|
+
from diffusers.models.attention_processor import Attention
|
|
16
|
+
from diffusers.models.attention_dispatch import dispatch_attention_fn
|
|
17
|
+
|
|
18
|
+
try:
    from diffusers import HunyuanImageTransformer2DModel
    from diffusers.models._modeling_parallel import (
        ContextParallelInput,
        ContextParallelOutput,
        ContextParallelModelPlan,
    )
except ImportError as e:
    # HunyuanImageTransformer2DModel and the CP plan primitives only exist in
    # recent diffusers builds; surface an actionable message and keep the
    # original import error as the cause for debugging.
    raise ImportError(
        "Context parallelism requires 'diffusers>=0.36.dev0'. "
        "Please install latest version of diffusers from source: \n"
        "pip3 install git+https://github.com/huggingface/diffusers.git"
    ) from e
from .cp_plan_registers import (
    ContextParallelismPlanner,
    ContextParallelismPlannerRegister,
)

from cache_dit.logger import init_logger

logger = init_logger(__name__)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@ContextParallelismPlannerRegister.register("HunyuanImage")
class HunyuanImageContextParallelismPlanner(ContextParallelismPlanner):
    """Context-parallelism (CP) planner registered under ``"HunyuanImage"``.

    Diffusers does not provide a native CP plan for HunyuanImage yet. This
    planner therefore returns the custom plan defined in ``apply`` and
    monkey-patches ``HunyuanImageTransformer2DModel.forward``, so that the
    attention mask is laid out to match the sharded sequence.
    """

    def apply(
        self,
        transformer: Optional[torch.nn.Module | ModelMixin] = None,
        **kwargs,
    ) -> ContextParallelModelPlan:
        """Install the CP-aware forward and return the CP plan.

        Args:
            transformer: The model that will be parallelized. If given, it
                must be a ``HunyuanImageTransformer2DModel``. It may be
                omitted: the returned plan does not depend on the instance,
                and the forward patch is applied to the class.
            **kwargs: Ignored; accepted for planner-interface compatibility.

        Returns:
            The ``ContextParallelModelPlan`` that maps submodule names to
            their input-split / output-gather specifications.

        Raises:
            AssertionError: If ``transformer`` is given but is not a
                ``HunyuanImageTransformer2DModel``.
        """
        # NOTE: The diffusers native CP plan is not supported for
        # HunyuanImage yet, so always fall through to the custom plan below.
        self._cp_planner_preferred_native_diffusers = False

        # Kept for when native support lands: prefer the model's own plan.
        if (
            transformer is not None
            and self._cp_planner_preferred_native_diffusers
        ):
            assert isinstance(
                transformer, HunyuanImageTransformer2DModel
            ), "Transformer must be an instance of HunyuanImageTransformer2DModel"
            if getattr(transformer, "_cp_plan", None) is not None:
                return transformer._cp_plan

        # Validate only when an instance is actually supplied; the default
        # `transformer=None` must not trip the type check.
        if transformer is not None:
            assert isinstance(
                transformer, HunyuanImageTransformer2DModel
            ), "Transformer must be an instance of HunyuanImageTransformer2DModel"

        # Monkey-patch forward to fix attention-mask preparation under CP.
        # This replaces the method on the class, i.e. for every instance.
        HunyuanImageTransformer2DModel.forward = (
            __patch__HunyuanImageTransformer2DModel_forward__
        )

        # Custom CP plan; it may differ slightly from what a native diffusers
        # plan would look like for this model.
        _cp_plan = {
            # rope, split_output=True (split the output rather than the input):
            #   un-split input -> rope -> split output
            "rope": {
                0: ContextParallelInput(
                    split_dim=0, expected_dims=2, split_output=True
                ),
                1: ContextParallelInput(
                    split_dim=0, expected_dims=2, split_output=True
                ),
            },
            # transformer_blocks.0, split_output=False:
            #   un-split input -> split -> to_qkv/... -> all2all
            #   -> attn (local heads, full seqlen) -> all2all -> split output
            # Remaining transformer_blocks and single_transformer_blocks:
            #   split input (previous split output) -> to_qkv/... -> all2all
            #   -> attn (local heads, full seqlen) -> all2all -> split output
            # `encoder_hidden_states` changes after every block, so it is
            # split once at the first block and then stays split: the local
            # attention output is divided back into `hidden_states` and
            # `encoder_hidden_states`, both already sharded by the all2all.
            "transformer_blocks.0": {
                "hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
                "encoder_hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
            },
            # NOTE: `attention_mask` is NOT split by this plan; the patched
            # transformer forward re-orders it to match the sharded layout.
            # proj_out gathers the split output:
            #   split input -> all gather -> un-split output
            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
        }
        return _cp_plan
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hunyuanimage.py#L806
#
# CP-aware replacement for ``HunyuanImageTransformer2DModel.forward``, which is
# installed on the class by ``HunyuanImageContextParallelismPlanner.apply``.
# It appears to follow the upstream forward linked above, except for the
# NOTE(DefTruth) block. When context parallelism is active (world size > 1),
# that block re-orders the joint [hidden, encoder] attention mask so it lines
# up with the per-rank [H_i, E_i] sequence layout seen by attention.
#
# Documentation lives in comments here: ``functools.wraps`` copies the
# upstream ``__doc__`` and ``__annotations__`` onto this function.
@functools.wraps(HunyuanImageTransformer2DModel.forward)
def __patch__HunyuanImageTransformer2DModel_forward__(
    self: HunyuanImageTransformer2DModel,
    hidden_states: torch.Tensor,
    timestep: torch.LongTensor,
    encoder_hidden_states: torch.Tensor,
    encoder_attention_mask: torch.Tensor,
    timestep_r: Optional[torch.LongTensor] = None,
    encoder_hidden_states_2: Optional[torch.Tensor] = None,
    encoder_attention_mask_2: Optional[torch.Tensor] = None,
    guidance: Optional[torch.Tensor] = None,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
    # Copy before popping so the caller's dict is not mutated.
    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        if (
            attention_kwargs is not None
            and attention_kwargs.get("scale", None) is not None
        ):
            logger.warning(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

    # Latents are either 4D (B, C, H, W) or 5D (B, C, F, H, W).
    if hidden_states.ndim == 4:
        batch_size, channels, height, width = hidden_states.shape
        sizes = (height, width)
    elif hidden_states.ndim == 5:
        batch_size, channels, frame, height, width = hidden_states.shape
        sizes = (frame, height, width)
    else:
        raise ValueError(
            f"hidden_states must be a 4D or 5D tensor, got {hidden_states.shape}"
        )

    # Spatial (and temporal) grid size after patchification.
    post_patch_sizes = tuple(
        d // p for d, p in zip(sizes, self.config.patch_size)
    )

    # 1. RoPE
    image_rotary_emb = self.rope(hidden_states)

    # 2. Conditional embeddings
    encoder_attention_mask = encoder_attention_mask.bool()
    temb = self.time_guidance_embed(
        timestep, guidance=guidance, timestep_r=timestep_r
    )
    hidden_states = self.x_embedder(hidden_states)
    encoder_hidden_states = self.context_embedder(
        encoder_hidden_states, timestep, encoder_attention_mask
    )

    if (
        self.context_embedder_2 is not None
        and encoder_hidden_states_2 is not None
    ):
        encoder_hidden_states_2 = self.context_embedder_2(
            encoder_hidden_states_2
        )

        # NOTE(review): assumes encoder_attention_mask_2 is provided whenever
        # encoder_hidden_states_2 is; a None mask would fail here.
        encoder_attention_mask_2 = encoder_attention_mask_2.bool()

        # reorder and combine text tokens: combine valid tokens first, then padding
        new_encoder_hidden_states = []
        new_encoder_attention_mask = []

        # Iterates per sample along the first (batch) dimension.
        for text, text_mask, text_2, text_mask_2 in zip(
            encoder_hidden_states,
            encoder_attention_mask,
            encoder_hidden_states_2,
            encoder_attention_mask_2,
        ):
            # Concatenate: [valid_byt5, valid_mllm, invalid_byt5, invalid_mllm]
            new_encoder_hidden_states.append(
                torch.cat(
                    [
                        text_2[text_mask_2],  # valid byt5
                        text[text_mask],  # valid mllm
                        text_2[~text_mask_2],  # invalid byt5
                        text[~text_mask],  # invalid mllm
                    ],
                    dim=0,
                )
            )

            # Apply same reordering to attention masks
            new_encoder_attention_mask.append(
                torch.cat(
                    [
                        text_mask_2[text_mask_2],
                        text_mask[text_mask],
                        text_mask_2[~text_mask_2],
                        text_mask[~text_mask],
                    ],
                    dim=0,
                )
            )

        encoder_hidden_states = torch.stack(new_encoder_hidden_states)
        encoder_attention_mask = torch.stack(new_encoder_attention_mask)

    # Joint mask laid out as [H, E]: left-pad the text mask with True for
    # every latent token (latent tokens are never masked).
    attention_mask = torch.nn.functional.pad(
        encoder_attention_mask, (hidden_states.shape[1], 0), value=True
    )
    # NOTE(DefTruth): Permute attention_mask if context parallel is used.
    # For example, if world size = 2: [H, E] -> [H_0, E_0, H_1, E_1]
    # Each rank holds shards H_i / E_i (split at transformer_blocks.0), so
    # after the all2all the attention sequence is ordered rank by rank.
    # NOTE(review): torch.chunk yields fewer than world_size chunks when the
    # length is smaller; this presumably relies on the CP sharder requiring
    # lengths divisible by the world size. Confirm.
    if self._parallel_config is not None:
        cp_config = getattr(
            self._parallel_config, "context_parallel_config", None
        )
        if cp_config is not None and cp_config._world_size > 1:
            hidden_mask = attention_mask[:, : hidden_states.shape[1]]
            encoder_mask = attention_mask[:, hidden_states.shape[1] :]
            hidden_mask_splits = torch.chunk(
                hidden_mask, cp_config._world_size, dim=1
            )
            encoder_mask_splits = torch.chunk(
                encoder_mask, cp_config._world_size, dim=1
            )
            new_attention_mask_splits = []
            for i in range(cp_config._world_size):
                new_attention_mask_splits.append(hidden_mask_splits[i])
                new_attention_mask_splits.append(encoder_mask_splits[i])
            attention_mask = torch.cat(new_attention_mask_splits, dim=1)

    attention_mask = attention_mask.unsqueeze(1).unsqueeze(
        2
    )  # [B, N] -> [B, 1, 1, N]

    # 3. Transformer blocks
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        for block in self.transformer_blocks:
            hidden_states, encoder_hidden_states = (
                self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    attention_mask=attention_mask,
                    image_rotary_emb=image_rotary_emb,
                )
            )

        for block in self.single_transformer_blocks:
            hidden_states, encoder_hidden_states = (
                self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    attention_mask=attention_mask,
                    image_rotary_emb=image_rotary_emb,
                )
            )

    else:
        for block in self.transformer_blocks:
            hidden_states, encoder_hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                temb,
                attention_mask=attention_mask,
                image_rotary_emb=image_rotary_emb,
            )

        for block in self.single_transformer_blocks:
            hidden_states, encoder_hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                temb,
                attention_mask=attention_mask,
                image_rotary_emb=image_rotary_emb,
            )

    # 4. Output projection (under CP, the plan's proj_out hook gathers the
    # sequence back, so the unpatchify below sees the full sequence).
    hidden_states = self.norm_out(hidden_states, temb)
    hidden_states = self.proj_out(hidden_states)

    # 5. unpatchify
    # reshape: [batch_size, *post_patch_dims, channels, *patch_size]
    out_channels = self.config.out_channels
    reshape_dims = (
        [batch_size]
        + list(post_patch_sizes)
        + [out_channels]
        + list(self.config.patch_size)
    )
    hidden_states = hidden_states.reshape(*reshape_dims)

    # create permutation pattern: batch, channels, then interleave post_patch and patch dims
    # For 4D: [0, 3, 1, 4, 2, 5] -> batch, channels, post_patch_height, patch_size_height, post_patch_width, patch_size_width
    # For 5D: [0, 4, 1, 5, 2, 6, 3, 7] -> batch, channels, post_patch_frame, patch_size_frame, post_patch_height, patch_size_height, post_patch_width, patch_size_width
    ndim = len(post_patch_sizes)
    permute_pattern = [0, ndim + 1]  # batch, channels
    for i in range(ndim):
        permute_pattern.extend(
            [i + 1, ndim + 2 + i]
        )  # post_patch_sizes[i], patch_sizes[i]
    hidden_states = hidden_states.permute(*permute_pattern)

    # flatten patch dimensions: flatten each (post_patch_size, patch_size) pair
    # batch_size, channels, post_patch_sizes[0] * patch_sizes[0], post_patch_sizes[1] * patch_sizes[1], ...
    final_dims = [batch_size, out_channels] + [
        post_patch * patch
        for post_patch, patch in zip(post_patch_sizes, self.config.patch_size)
    ]
    hidden_states = hidden_states.reshape(*final_dims)

    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (hidden_states,)

    return Transformer2DModelOutput(sample=hidden_states)
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
@ContextParallelismPlannerRegister.register("HunyuanVideo")
class HunyuanVideoContextParallelismPlanner(ContextParallelismPlanner):
    """Context-parallelism (CP) planner registered under ``"HunyuanVideo"``.

    Diffusers does not provide a native CP plan for HunyuanVideo yet. This
    planner therefore returns the custom plan defined in ``apply`` and
    monkey-patches two methods:

    * ``HunyuanVideoTransformer3DModel.forward``, so the attention mask
      matches the sharded sequence layout;
    * ``HunyuanVideoAttnProcessor2_0.__call__``.
    """

    def apply(
        self,
        transformer: Optional[torch.nn.Module | ModelMixin] = None,
        **kwargs,
    ) -> ContextParallelModelPlan:
        """Install the CP-aware forward/attention patches and return the plan.

        Args:
            transformer: The model that will be parallelized. If given, it
                must be a ``HunyuanVideoTransformer3DModel``. It may be
                omitted: the returned plan does not depend on the instance,
                and all patches are applied to classes.
            **kwargs: Ignored; accepted for planner-interface compatibility.

        Returns:
            The ``ContextParallelModelPlan`` that maps submodule names to
            their input-split / output-gather specifications.

        Raises:
            AssertionError: If ``transformer`` is given but is not a
                ``HunyuanVideoTransformer3DModel``.
        """
        # NOTE: The diffusers native CP plan is not supported for
        # HunyuanVideo yet, so always fall through to the custom plan below.
        self._cp_planner_preferred_native_diffusers = False

        # Kept for when native support lands: prefer the model's own plan.
        if (
            transformer is not None
            and self._cp_planner_preferred_native_diffusers
        ):
            assert isinstance(
                transformer, HunyuanVideoTransformer3DModel
            ), "Transformer must be an instance of HunyuanVideoTransformer3DModel"
            if getattr(transformer, "_cp_plan", None) is not None:
                return transformer._cp_plan

        # Validate only when an instance is actually supplied; the default
        # `transformer=None` must not trip the type check.
        if transformer is not None:
            assert isinstance(
                transformer, HunyuanVideoTransformer3DModel
            ), "Transformer must be an instance of HunyuanVideoTransformer3DModel"

        # Monkey-patch forward (attention-mask preparation under CP) and the
        # attention processor. Both replace methods on the class, i.e. for
        # every instance.
        HunyuanVideoTransformer3DModel.forward = (
            __patch__HunyuanVideoTransformer3DModel_forward__
        )
        HunyuanVideoAttnProcessor2_0.__call__ = (
            __patch_HunyuanVideoAttnProcessor2_0__call__
        )
        # Ensure the processor class exposes the parallel-config and
        # attention-backend attributes (default None) when it lacks them.
        if not hasattr(HunyuanVideoAttnProcessor2_0, "_parallel_config"):
            HunyuanVideoAttnProcessor2_0._parallel_config = None
        if not hasattr(HunyuanVideoAttnProcessor2_0, "_attention_backend"):
            HunyuanVideoAttnProcessor2_0._attention_backend = None

        # Custom CP plan; it may differ slightly from what a native diffusers
        # plan would look like for this model.
        _cp_plan = {
            # rope, split_output=True (split the output rather than the input):
            #   un-split input -> rope -> split output
            "rope": {
                0: ContextParallelInput(
                    split_dim=0, expected_dims=2, split_output=True
                ),
                1: ContextParallelInput(
                    split_dim=0, expected_dims=2, split_output=True
                ),
            },
            # transformer_blocks.0, split_output=False:
            #   un-split input -> split -> to_qkv/... -> all2all
            #   -> attn (local heads, full seqlen) -> all2all -> split output
            # Remaining transformer_blocks and single_transformer_blocks:
            #   split input (previous split output) -> to_qkv/... -> all2all
            #   -> attn (local heads, full seqlen) -> all2all -> split output
            # `encoder_hidden_states` changes after every block, so it is
            # split once at the first block and then stays split: the local
            # attention output is divided back into `hidden_states` and
            # `encoder_hidden_states`, both already sharded by the all2all.
            "transformer_blocks.0": {
                "hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
                "encoder_hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
            },
            # NOTE: `attention_mask` is NOT split by this plan; the patched
            # transformer forward re-orders it to match the sharded layout.
            # proj_out gathers the split output:
            #   split input -> all gather -> un-split output
            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
        }
        return _cp_plan
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hunyuan_video.py#L1032
#
# CP-aware replacement for ``HunyuanVideoTransformer3DModel.forward``, which
# is installed on the class by ``HunyuanVideoContextParallelismPlanner.apply``.
# It appears to follow the upstream forward linked above, except for the
# NOTE(DefTruth) block. When context parallelism is active (world size > 1),
# that block re-orders the joint [latent, text] attention mask so it lines up
# with the per-rank [H_i, E_i] sequence layout seen by attention.
#
# Documentation lives in comments here: ``functools.wraps`` copies the
# upstream ``__doc__`` and ``__annotations__`` onto this function.
@functools.wraps(HunyuanVideoTransformer3DModel.forward)
def __patch__HunyuanVideoTransformer3DModel_forward__(
    self: HunyuanVideoTransformer3DModel,
    hidden_states: torch.Tensor,
    timestep: torch.LongTensor,
    encoder_hidden_states: torch.Tensor,
    encoder_attention_mask: torch.Tensor,
    pooled_projections: torch.Tensor,
    guidance: Optional[torch.Tensor] = None,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
    # Copy before popping so the caller's dict is not mutated.
    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        if (
            attention_kwargs is not None
            and attention_kwargs.get("scale", None) is not None
        ):
            logger.warning(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

    # Latents are 5D: (B, C, F, H, W).
    batch_size, num_channels, num_frames, height, width = hidden_states.shape
    p, p_t = self.config.patch_size, self.config.patch_size_t
    post_patch_num_frames = num_frames // p_t
    post_patch_height = height // p
    post_patch_width = width // p
    # Number of latent tokens in the first (post-patch) frame.
    # NOTE(review): computed from the unsharded shape; under CP each rank's
    # blocks only see a shard of the latent tokens, so any block logic that
    # uses this to address first-frame tokens may be wrong on ranks != 0.
    # Verify for variants that produce `token_replace_emb`.
    first_frame_num_tokens = 1 * post_patch_height * post_patch_width

    # 1. RoPE
    image_rotary_emb = self.rope(hidden_states)

    # 2. Conditional embeddings
    temb, token_replace_emb = self.time_text_embed(
        timestep, pooled_projections, guidance
    )

    hidden_states = self.x_embedder(hidden_states)
    encoder_hidden_states = self.context_embedder(
        encoder_hidden_states, timestep, encoder_attention_mask
    )

    # 3. Attention mask preparation
    # The joint sequence is [latent tokens, text tokens]. Positions at or
    # beyond latent_len + (number of valid text tokens) are masked out. This
    # presumably assumes valid text tokens form a prefix of the text sequence.
    latent_sequence_length = hidden_states.shape[1]
    condition_sequence_length = encoder_hidden_states.shape[1]
    sequence_length = latent_sequence_length + condition_sequence_length
    attention_mask = torch.ones(
        batch_size,
        sequence_length,
        device=hidden_states.device,
        dtype=torch.bool,
    )  # [B, N]
    effective_condition_sequence_length = encoder_attention_mask.sum(
        dim=1, dtype=torch.int
    )  # [B,]
    effective_sequence_length = (
        latent_sequence_length + effective_condition_sequence_length
    )
    indices = torch.arange(
        sequence_length, device=hidden_states.device
    ).unsqueeze(
        0
    )  # [1, N]
    mask_indices = indices >= effective_sequence_length.unsqueeze(1)  # [B, N]
    attention_mask = attention_mask.masked_fill(mask_indices, False)
    # NOTE(DefTruth): Permute attention_mask if context parallel is used.
    # For example, if world size = 2: [H, E] -> [H_0, E_0, H_1, E_1]
    # Each rank holds shards H_i / E_i (split at transformer_blocks.0), so
    # after the all2all the attention sequence is ordered rank by rank.
    # NOTE(review): torch.chunk yields fewer than world_size chunks when the
    # length is smaller; this presumably relies on the CP sharder requiring
    # lengths divisible by the world size. Confirm.
    if self._parallel_config is not None:
        cp_config = getattr(
            self._parallel_config, "context_parallel_config", None
        )
        if cp_config is not None and cp_config._world_size > 1:
            hidden_mask = attention_mask[:, :latent_sequence_length]
            encoder_mask = attention_mask[:, latent_sequence_length:]
            hidden_mask_splits = torch.chunk(
                hidden_mask, cp_config._world_size, dim=1
            )
            encoder_mask_splits = torch.chunk(
                encoder_mask, cp_config._world_size, dim=1
            )
            new_attention_mask_splits = []
            for i in range(cp_config._world_size):
                new_attention_mask_splits.append(hidden_mask_splits[i])
                new_attention_mask_splits.append(encoder_mask_splits[i])
            attention_mask = torch.cat(new_attention_mask_splits, dim=1)

    attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, N]

    # 4. Transformer blocks
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        for block in self.transformer_blocks:
            hidden_states, encoder_hidden_states = (
                self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    attention_mask,
                    image_rotary_emb,
                    token_replace_emb,
                    first_frame_num_tokens,
                )
            )

        for block in self.single_transformer_blocks:
            hidden_states, encoder_hidden_states = (
                self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    attention_mask,
                    image_rotary_emb,
                    token_replace_emb,
                    first_frame_num_tokens,
                )
            )

    else:
        for block in self.transformer_blocks:
            hidden_states, encoder_hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                temb,
                attention_mask,
                image_rotary_emb,
                token_replace_emb,
                first_frame_num_tokens,
            )

        for block in self.single_transformer_blocks:
            hidden_states, encoder_hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                temb,
                attention_mask,
                image_rotary_emb,
                token_replace_emb,
                first_frame_num_tokens,
            )

    # 5. Output projection (under CP, the plan's proj_out hook gathers the
    # sequence back, so the unpatchify below sees the full sequence).
    hidden_states = self.norm_out(hidden_states, temb)
    hidden_states = self.proj_out(hidden_states)

    # Unpatchify: (B, F', H', W', C, p_t, p, p) -> (B, C, F'*p_t, H'*p, W'*p)
    hidden_states = hidden_states.reshape(
        batch_size,
        post_patch_num_frames,
        post_patch_height,
        post_patch_width,
        -1,
        p_t,
        p,
        p,
    )
    hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
    hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (hidden_states,)

    return Transformer2DModelOutput(sample=hidden_states)
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
# Patched replacement for diffusers' HunyuanVideoAttnProcessor2_0.__call__.
#
# What this version does differently (as visible in the body below):
#   * q/k/v are only unflattened into heads (no transpose), so they stay in a
#     sequence-major layout -- presumably (batch, seq, heads, head_dim) given
#     a (batch, seq, dim) input; TODO confirm -- and RoPE is applied with
#     ``sequence_dim=1``.
#   * attention runs through ``dispatch_attention_fn`` using the processor's
#     optional ``_attention_backend`` / ``_parallel_config`` attributes.
#
# The split point between latent tokens and trailing encoder tokens is
# computed as a positive index (``total - encoder_seq_len``) instead of by
# slicing with ``-encoder_seq_len``: for an empty encoder sequence,
# ``x[:, :-0]`` is empty and ``x[:, -0:]`` is the whole tensor, which would
# silently drop every latent token.
#
# Returns a ``(hidden_states, encoder_hidden_states)`` tuple; the
# ``-> torch.Tensor`` annotation is kept to match the upstream signature.
#
# NOTE: ``functools.wraps`` copies ``__doc__`` (and annotations) from the
# wrapped method, so this documentation lives in comments, not a docstring.
@functools.wraps(HunyuanVideoAttnProcessor2_0.__call__)
def __patch_HunyuanVideoAttnProcessor2_0__call__(
    self: HunyuanVideoAttnProcessor2_0,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Single-stream case: no dedicated encoder projections, so the encoder
    # tokens are appended after the latent tokens and share to_q/to_k/to_v.
    is_single_stream = (
        attn.add_q_proj is None and encoder_hidden_states is not None
    )
    encoder_seq_len = (
        encoder_hidden_states.shape[1]
        if encoder_hidden_states is not None
        else 0
    )

    if is_single_stream:
        hidden_states = torch.cat(
            [hidden_states, encoder_hidden_states], dim=1
        )

    # 1. QKV projections
    query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)

    # NOTE(DefTruth): no transpose -- heads are split out of dim 2 only.
    query = query.unflatten(2, (attn.heads, -1))
    key = key.unflatten(2, (attn.heads, -1))
    value = value.unflatten(2, (attn.heads, -1))

    # 2. QK normalization
    if attn.norm_q is not None:
        query = attn.norm_q(query)
    if attn.norm_k is not None:
        key = attn.norm_k(key)

    # 3. Rotational positional embeddings applied to latent stream
    if image_rotary_emb is not None:
        from diffusers.models.embeddings import apply_rotary_emb

        if is_single_stream:
            # NOTE(DefTruth): Monkey patch for encoder conditional RoPE.
            # Rotate only the leading latent tokens; the trailing encoder
            # tokens are passed through unchanged.
            def _rope_latent_only(x: torch.Tensor) -> torch.Tensor:
                num_latent_tokens = x.shape[1] - encoder_seq_len
                return torch.cat(
                    [
                        apply_rotary_emb(
                            x[:, :num_latent_tokens],
                            image_rotary_emb,
                            sequence_dim=1,
                        ),
                        x[:, num_latent_tokens:],
                    ],
                    dim=1,
                )

            query = _rope_latent_only(query)
            key = _rope_latent_only(key)
        else:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

    # 4. Encoder condition QKV projection and normalization (dual-stream:
    #    separate add_*_proj weights, results appended after latent tokens).
    if attn.add_q_proj is not None and encoder_hidden_states is not None:
        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

        # NOTE(DefTruth): no transpose
        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))

        if attn.norm_added_q is not None:
            encoder_query = attn.norm_added_q(encoder_query)
        if attn.norm_added_k is not None:
            encoder_key = attn.norm_added_k(encoder_key)

        query = torch.cat([query, encoder_query], dim=1)
        key = torch.cat([key, encoder_key], dim=1)
        value = torch.cat([value, encoder_value], dim=1)

    # 5. Attention
    # NOTE(DefTruth): use dispatch_attention_fn so a configured backend /
    # parallel config on the processor (if any) is honoured.
    hidden_states = dispatch_attention_fn(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=0.0,
        is_causal=False,
        backend=getattr(self, "_attention_backend", None),
        parallel_config=getattr(self, "_parallel_config", None),
    )
    # NOTE(DefTruth): no transpose -- merge heads back into one dim.
    hidden_states = hidden_states.flatten(2, 3)
    hidden_states = hidden_states.to(query.dtype)

    # 6. Output projection
    if encoder_hidden_states is not None:
        # Split the joint sequence back into latent and encoder streams.
        num_latent_tokens = hidden_states.shape[1] - encoder_seq_len
        hidden_states, encoder_hidden_states = (
            hidden_states[:, :num_latent_tokens],
            hidden_states[:, num_latent_tokens:],
        )

        if getattr(attn, "to_out", None) is not None:
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)

        if getattr(attn, "to_add_out", None) is not None:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

    return hidden_states, encoder_hidden_states
|