diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
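Beyond the per-file counts above, the headline additions in 0.35.x are the new `guiders`, `hooks`, `attention_dispatch`, and `modular_pipelines` packages plus the QwenImage and SkyReels-V2 pipelines. As a quick, non-authoritative sanity check that the upgraded wheel is the one being imported — the exported pipeline class names below are inferred from the new module paths listed above and should be treated as assumptions:

    # Minimal sketch: verify the installed version and probe (not assert) for new pipelines.
    import diffusers

    print(diffusers.__version__)  # expected: 0.35.1 after installing the new wheel
    # Class names inferred from diffusers/pipelines/qwenimage/ and /skyreels_v2/ (assumptions).
    print(hasattr(diffusers, "QwenImagePipeline"))
    print(hasattr(diffusers, "SkyReelsV2Pipeline"))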
@@ -2272,558 +2272,6 @@ class FusedAuraFlowAttnProcessor2_0:
         return hidden_states


-class FluxAttnProcessor2_0:
-    """Attention processor used typically in processing the SD3-like self-attention projections."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
-class FluxAttnProcessor2_0_NPU:
-    """Attention processor used typically in processing the SD3-like self-attention projections."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
-            )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        if query.dtype in (torch.float16, torch.bfloat16):
-            hidden_states = torch_npu.npu_fusion_attention(
-                query,
-                key,
-                value,
-                attn.heads,
-                input_layout="BNSD",
-                pse=None,
-                scale=1.0 / math.sqrt(query.shape[-1]),
-                pre_tockens=65536,
-                next_tockens=65536,
-                keep_prob=1.0,
-                sync=False,
-                inner_precise=0,
-            )[0]
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
-class FusedFluxAttnProcessor2_0:
-    """Attention processor used typically in processing the SD3-like self-attention projections."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        qkv = attn.to_qkv(hidden_states)
-        split_size = qkv.shape[-1] // 3
-        query, key, value = torch.split(qkv, split_size, dim=-1)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        # `context` projections.
-        if encoder_hidden_states is not None:
-            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
-            split_size = encoder_qkv.shape[-1] // 3
-            (
-                encoder_hidden_states_query_proj,
-                encoder_hidden_states_key_proj,
-                encoder_hidden_states_value_proj,
-            ) = torch.split(encoder_qkv, split_size, dim=-1)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
-class FusedFluxAttnProcessor2_0_NPU:
-    """Attention processor used typically in processing the SD3-like self-attention projections."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
-            )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        qkv = attn.to_qkv(hidden_states)
-        split_size = qkv.shape[-1] // 3
-        query, key, value = torch.split(qkv, split_size, dim=-1)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        # `context` projections.
-        if encoder_hidden_states is not None:
-            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
-            split_size = encoder_qkv.shape[-1] // 3
-            (
-                encoder_hidden_states_query_proj,
-                encoder_hidden_states_key_proj,
-                encoder_hidden_states_value_proj,
-            ) = torch.split(encoder_qkv, split_size, dim=-1)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        if query.dtype in (torch.float16, torch.bfloat16):
-            hidden_states = torch_npu.npu_fusion_attention(
-                query,
-                key,
-                value,
-                attn.heads,
-                input_layout="BNSD",
-                pse=None,
-                scale=1.0 / math.sqrt(query.shape[-1]),
-                pre_tockens=65536,
-                next_tockens=65536,
-                keep_prob=1.0,
-                sync=False,
-                inner_precise=0,
-            )[0]
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
-class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
-    """Flux Attention processor for IP-Adapter."""
-
-    def __init__(
-        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
-    ):
-        super().__init__()
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-
-        if not isinstance(num_tokens, (tuple, list)):
-            num_tokens = [num_tokens]
-
-        if not isinstance(scale, list):
-            scale = [scale] * len(num_tokens)
-        if len(scale) != len(num_tokens):
-            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
-        self.scale = scale
-
-        self.to_k_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-        self.to_v_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-        ip_hidden_states: Optional[List[torch.Tensor]] = None,
-        ip_adapter_masks: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        hidden_states_query_proj = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            # IP-adapter
-            ip_query = hidden_states_query_proj
-            ip_attn_output = torch.zeros_like(hidden_states)
-
-            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
-                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
-            ):
-                ip_key = to_k_ip(current_ip_hidden_states)
-                ip_value = to_v_ip(current_ip_hidden_states)
-
-                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                # TODO: add support for attn.scale when we move to Torch 2.1
-                current_ip_hidden_states = F.scaled_dot_product_attention(
-                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                )
-                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                    batch_size, -1, attn.heads * head_dim
-                )
-                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
-                ip_attn_output += scale * current_ip_hidden_states
-
-            return hidden_states, encoder_hidden_states, ip_attn_output
-        else:
-            return hidden_states
-
-
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -3453,106 +2901,6 @@ class XLAFlashAttnProcessor2_0:
         return hidden_states


-class XLAFluxFlashAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
-    """
-
-    def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-        if is_torch_xla_version("<", "2.3"):
-            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
-        if is_spmd() and is_torch_xla_version("<", "2.4"):
-            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
-        self.partition_spec = partition_spec
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        query /= math.sqrt(head_dim)
-        hidden_states = flash_attention(query, key, value, causal=False)
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
 class MochiVaeAttnProcessor2_0:
     r"""
     Attention processor used in Mochi VAE.
@@ -5992,17 +5340,6 @@ class LoRAAttnAddedKVProcessor:
     pass


-class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-
-    def __init__(self):
-        deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
-        deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
-        super().__init__()
-
-
 class SanaLinearAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product linear attention.
@@ -6167,6 +5504,111 @@ class PAGIdentitySanaLinearAttnProcessor2_0:
         return hidden_states


+class FluxAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
+        deprecate("FluxAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        return FluxAttnProcessor(*args, **kwargs)
+
+
+class FluxSingleAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FluxSingleAttnProcessor` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead."
+        deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        return FluxAttnProcessor(*args, **kwargs)
+
+
+class FusedFluxAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FusedFluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
+        deprecate("FusedFluxAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        return FluxAttnProcessor(*args, **kwargs)
+
+
+class FluxIPAdapterJointAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxIPAdapterAttnProcessor`"
+        deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxIPAdapterAttnProcessor
+
+        return FluxIPAdapterAttnProcessor(*args, **kwargs)
+
+
+class FluxAttnProcessor2_0_NPU:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+            "alternative solution to use NPU Flash Attention will be provided in the future."
+        )
+        deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        processor = FluxAttnProcessor()
+        processor._attention_backend = "_native_npu"
+        return processor
+
+
+class FusedFluxAttnProcessor2_0_NPU:
+    def __new__(self):
+        deprecation_message = (
+            "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+            "alternative solution to use NPU Flash Attention will be provided in the future."
+        )
+        deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        processor = FluxAttnProcessor()
+        processor._attention_backend = "_fused_npu"
+        return processor
+
+
+class XLAFluxFlashAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An "
+            "alternative solution to using XLA Flash Attention will be provided in the future."
+        )
+        deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+
+        if is_torch_xla_version("<", "2.3"):
+            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+        if is_spmd() and is_torch_xla_version("<", "2.4"):
+            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        if len(args) > 0 or kwargs.get("partition_spec", None) is not None:
+            deprecation_message = (
+                "partition_spec was not used in the processor implementation when it was added. Passing it "
+                "is a no-op and support for it will be removed."
+            )
+            deprecate("partition_spec", "1.0.0", deprecation_message)
+
+        processor = FluxAttnProcessor(*args, **kwargs)
+        processor._attention_backend = "_native_xla"
+        return processor
+
+
 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,