diffusers-0.34.0-py3-none-any.whl → diffusers-0.35.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
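
The list above is dominated by new top-level modules (guiders, modular_pipelines, the QwenImage and SkyReels-V2 pipelines, and the new attention_dispatch module). As a quick sanity check after upgrading, the hedged Python sketch below probes whether those new packages are importable; the module paths are taken from the file list, and the expected version string is an assumption about a clean 0.35.0 install.

# Hedged sketch: verify the upgrade and probe the new 0.35.0 packages named in the file list.
import importlib
from importlib.metadata import version

print(version("diffusers"))  # expected to print "0.35.0" after upgrading

for module in (
    "diffusers.guiders",                    # new guider implementations
    "diffusers.modular_pipelines",          # new modular pipeline framework
    "diffusers.models.attention_dispatch",  # new attention dispatch layer
    "diffusers.pipelines.qwenimage",        # new QwenImage pipelines
    "diffusers.pipelines.skyreels_v2",      # new SkyReels-V2 pipelines
):
    try:
        importlib.import_module(module)
        print(f"{module}: importable")
    except ImportError as exc:
        print(f"{module}: not importable ({exc})")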
@@ -2272,558 +2272,6 @@ class FusedAuraFlowAttnProcessor2_0:
         return hidden_states


-class FluxAttnProcessor2_0:
-    """Attention processor used typically in processing the SD3-like self-attention projections."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
-class FluxAttnProcessor2_0_NPU:
-    """Attention processor used typically in processing the SD3-like self-attention projections."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
-            )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        if query.dtype in (torch.float16, torch.bfloat16):
-            hidden_states = torch_npu.npu_fusion_attention(
-                query,
-                key,
-                value,
-                attn.heads,
-                input_layout="BNSD",
-                pse=None,
-                scale=1.0 / math.sqrt(query.shape[-1]),
-                pre_tockens=65536,
-                next_tockens=65536,
-                keep_prob=1.0,
-                sync=False,
-                inner_precise=0,
-            )[0]
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
-class FusedFluxAttnProcessor2_0:
-    """Attention processor used typically in processing the SD3-like self-attention projections."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        qkv = attn.to_qkv(hidden_states)
-        split_size = qkv.shape[-1] // 3
-        query, key, value = torch.split(qkv, split_size, dim=-1)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        # `context` projections.
-        if encoder_hidden_states is not None:
-            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
-            split_size = encoder_qkv.shape[-1] // 3
-            (
-                encoder_hidden_states_query_proj,
-                encoder_hidden_states_key_proj,
-                encoder_hidden_states_value_proj,
-            ) = torch.split(encoder_qkv, split_size, dim=-1)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
-class FusedFluxAttnProcessor2_0_NPU:
-    """Attention processor used typically in processing the SD3-like self-attention projections."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
-            )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        qkv = attn.to_qkv(hidden_states)
-        split_size = qkv.shape[-1] // 3
-        query, key, value = torch.split(qkv, split_size, dim=-1)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        # `context` projections.
-        if encoder_hidden_states is not None:
-            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
-            split_size = encoder_qkv.shape[-1] // 3
-            (
-                encoder_hidden_states_query_proj,
-                encoder_hidden_states_key_proj,
-                encoder_hidden_states_value_proj,
-            ) = torch.split(encoder_qkv, split_size, dim=-1)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        if query.dtype in (torch.float16, torch.bfloat16):
-            hidden_states = torch_npu.npu_fusion_attention(
-                query,
-                key,
-                value,
-                attn.heads,
-                input_layout="BNSD",
-                pse=None,
-                scale=1.0 / math.sqrt(query.shape[-1]),
-                pre_tockens=65536,
-                next_tockens=65536,
-                keep_prob=1.0,
-                sync=False,
-                inner_precise=0,
-            )[0]
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
-class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
-    """Flux Attention processor for IP-Adapter."""
-
-    def __init__(
-        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
-    ):
-        super().__init__()
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-
-        if not isinstance(num_tokens, (tuple, list)):
-            num_tokens = [num_tokens]
-
-        if not isinstance(scale, list):
-            scale = [scale] * len(num_tokens)
-        if len(scale) != len(num_tokens):
-            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
-        self.scale = scale
-
-        self.to_k_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-        self.to_v_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-        ip_hidden_states: Optional[List[torch.Tensor]] = None,
-        ip_adapter_masks: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        hidden_states_query_proj = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            # IP-adapter
-            ip_query = hidden_states_query_proj
-            ip_attn_output = torch.zeros_like(hidden_states)
-
-            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
-                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
-            ):
-                ip_key = to_k_ip(current_ip_hidden_states)
-                ip_value = to_v_ip(current_ip_hidden_states)
-
-                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                # TODO: add support for attn.scale when we move to Torch 2.1
-                current_ip_hidden_states = F.scaled_dot_product_attention(
-                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                )
-                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                    batch_size, -1, attn.heads * head_dim
-                )
-                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
-                ip_attn_output += scale * current_ip_hidden_states
-
-            return hidden_states, encoder_hidden_states, ip_attn_output
-        else:
-            return hidden_states
-
-
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
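
The hunk above removes the Flux-specific processors from attention_processor.py without dropping them from the library: the deprecation shims added later in this diff forward to a consolidated implementation in diffusers/models/transformers/transformer_flux.py. A hedged migration sketch, assuming the module paths used by those shims stay importable as written:

# 0.34.0-era import; in 0.35.0 this name still resolves, but only as a deprecation shim.
from diffusers.models.attention_processor import FluxAttnProcessor2_0

# 0.35.0 replacement, the class the shims forward to (path taken from the shim imports below).
from diffusers.models.transformers.transformer_flux import FluxAttnProcessor

processor = FluxAttnProcessor()  # what FluxAttnProcessor2_0() now constructs under the hood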
@@ -3453,106 +2901,6 @@ class XLAFlashAttnProcessor2_0:
         return hidden_states


-class XLAFluxFlashAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
-    """
-
-    def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-        if is_torch_xla_version("<", "2.3"):
-            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
-        if is_spmd() and is_torch_xla_version("<", "2.4"):
-            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
-        self.partition_spec = partition_spec
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        query /= math.sqrt(head_dim)
-        hidden_states = flash_attention(query, key, value, causal=False)
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
 class MochiVaeAttnProcessor2_0:
     r"""
     Attention processor used in Mochi VAE.
@@ -5992,17 +5340,6 @@ class LoRAAttnAddedKVProcessor:
        pass


-class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-
-    def __init__(self):
-        deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
-        deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
-        super().__init__()
-
-
 class SanaLinearAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product linear attention.
@@ -6167,6 +5504,111 @@ class PAGIdentitySanaLinearAttnProcessor2_0:
         return hidden_states


+class FluxAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
+        deprecate("FluxAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        return FluxAttnProcessor(*args, **kwargs)
+
+
+class FluxSingleAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FluxSingleAttnProcessor` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead."
+        deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        return FluxAttnProcessor(*args, **kwargs)
+
+
+class FusedFluxAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FusedFluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
+        deprecate("FusedFluxAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        return FluxAttnProcessor(*args, **kwargs)
+
+
+class FluxIPAdapterJointAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxIPAdapterAttnProcessor`"
+        deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxIPAdapterAttnProcessor
+
+        return FluxIPAdapterAttnProcessor(*args, **kwargs)
+
+
+class FluxAttnProcessor2_0_NPU:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+            "alternative solution to use NPU Flash Attention will be provided in the future."
+        )
+        deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        processor = FluxAttnProcessor()
+        processor._attention_backend = "_native_npu"
+        return processor
+
+
+class FusedFluxAttnProcessor2_0_NPU:
+    def __new__(self):
+        deprecation_message = (
+            "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+            "alternative solution to use NPU Flash Attention will be provided in the future."
+        )
+        deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        processor = FluxAttnProcessor()
+        processor._attention_backend = "_fused_npu"
+        return processor
+
+
+class XLAFluxFlashAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An "
+            "alternative solution to using XLA Flash Attention will be provided in the future."
+        )
+        deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+
+        if is_torch_xla_version("<", "2.3"):
+            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+        if is_spmd() and is_torch_xla_version("<", "2.4"):
+            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        if len(args) > 0 or kwargs.get("partition_spec", None) is not None:
+            deprecation_message = (
+                "partition_spec was not used in the processor implementation when it was added. Passing it "
+                "is a no-op and support for it will be removed."
+            )
+            deprecate("partition_spec", "1.0.0", deprecation_message)
+
+        processor = FluxAttnProcessor(*args, **kwargs)
+        processor._attention_backend = "_native_xla"
+        return processor
+
+
 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
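
Taken together, the classes added in this hunk turn every 2.0-era Flux processor into a thin __new__ shim: constructing one emits a deprecation notice and returns the consolidated FluxAttnProcessor (or FluxIPAdapterAttnProcessor), in the NPU/XLA cases additionally tagging it with a private _attention_backend such as "_native_npu" or "_native_xla". A minimal sketch of the observable behavior, assuming diffusers' deprecate() helper surfaces the message as a standard Python warning:

# Hedged sketch: the deprecated class now returns the new processor type and warns.
import warnings

from diffusers.models.attention_processor import FluxAttnProcessor2_0
from diffusers.models.transformers.transformer_flux import FluxAttnProcessor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    processor = FluxAttnProcessor2_0()  # deprecated entry point defined in this hunk

print(isinstance(processor, FluxAttnProcessor))  # True: __new__ forwards to the new class
print(any("deprecated" in str(w.message) for w in caught))  # a deprecation warning was raised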