diffusers 0.30.3-py3-none-any.whl → 0.31.0-py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (172)
  1. diffusers/__init__.py +34 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +170 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +35 -6
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +2 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
  53. diffusers/pipelines/cogview3/__init__.py +47 -0
  54. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  55. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  56. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  57. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  58. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  60. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  62. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  63. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  64. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  66. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  67. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  68. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  70. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  71. diffusers/pipelines/flux/__init__.py +10 -0
  72. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  73. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  74. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  76. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  77. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  78. diffusers/pipelines/free_noise_utils.py +365 -5
  79. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  80. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  81. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  82. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  83. diffusers/pipelines/kolors/tokenizer.py +4 -0
  84. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  86. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  87. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  89. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  90. diffusers/pipelines/pag/__init__.py +6 -0
  91. diffusers/pipelines/pag/pag_utils.py +8 -2
  92. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  96. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  97. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  98. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  100. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  101. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  102. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  103. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  106. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  107. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  108. diffusers/pipelines/pipeline_utils.py +123 -180
  109. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  111. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  117. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  120. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  121. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  122. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  126. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  127. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  130. diffusers/quantizers/__init__.py +16 -0
  131. diffusers/quantizers/auto.py +126 -0
  132. diffusers/quantizers/base.py +233 -0
  133. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  134. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  135. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  136. diffusers/quantizers/quantization_config.py +391 -0 (usage sketch after this file list)
  137. diffusers/schedulers/scheduling_ddim.py +4 -1
  138. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  140. diffusers/schedulers/scheduling_ddpm.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  142. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  143. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  145. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  146. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  147. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  149. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  150. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  151. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  152. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  153. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  154. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  155. diffusers/schedulers/scheduling_sasolver.py +78 -1
  156. diffusers/schedulers/scheduling_unclip.py +4 -1
  157. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  158. diffusers/training_utils.py +48 -18
  159. diffusers/utils/__init__.py +2 -1
  160. diffusers/utils/dummy_pt_objects.py +60 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
  162. diffusers/utils/hub_utils.py +16 -4
  163. diffusers/utils/import_utils.py +31 -8
  164. diffusers/utils/loading_utils.py +28 -4
  165. diffusers/utils/peft_utils.py +3 -3
  166. diffusers/utils/testing_utils.py +59 -0
  167. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  168. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
  169. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  170. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
  171. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  172. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
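
The new diffusers/quantizers/* modules listed above (quantization_config.py, bnb_quantizer.py, and friends) introduce bitsandbytes-backed model quantization. Below is a minimal sketch of how this is presumably used, assuming the API mirrors the transformers bitsandbytes interface; the checkpoint id and 4-bit settings are illustrative and not taken from this diff.

# Hedged sketch: assumes diffusers 0.31.0 exports BitsAndBytesConfig and that
# Model.from_pretrained() accepts `quantization_config`; checkpoint id is illustrative.
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)

A transformer loaded this way would then presumably be handed to the corresponding pipeline through the usual from_pretrained(..., transformer=transformer) pattern.
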
diffusers/loaders/lora_conversion_utils.py (+332 -0)

@@ -14,6 +14,8 @@
 
 import re
 
+import torch
+
 from ..utils import is_peft_version, logging
 
 
@@ -326,3 +328,333 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
         prefix = "text_encoder_2."
     new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
     return {new_name: alpha}
+
+
+# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
+# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
+# All credits go to `kohya-ss`.
+def _convert_kohya_flux_lora_to_diffusers(state_dict):
+    def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
+        if sds_key + ".lora_down.weight" not in sds_sd:
+            return
+        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+
+        # scale weight by alpha and dim
+        rank = down_weight.shape[0]
+        alpha = sds_sd.pop(sds_key + ".alpha").item()  # alpha is scalar
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+
+        # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
+        ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
+
+    def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
+        if sds_key + ".lora_down.weight" not in sds_sd:
+            return
+        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+        up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
+        sd_lora_rank = down_weight.shape[0]
+
+        # scale weight by alpha and dim
+        alpha = sds_sd.pop(sds_key + ".alpha")
+        scale = alpha / sd_lora_rank
+
+        # calculate scale_down and scale_up
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        down_weight = down_weight * scale_down
+        up_weight = up_weight * scale_up
+
+        # calculate dims if not provided
+        num_splits = len(ait_keys)
+        if dims is None:
+            dims = [up_weight.shape[0] // num_splits] * num_splits
+        else:
+            assert sum(dims) == up_weight.shape[0]
+
+        # check upweight is sparse or not
+        is_sparse = False
+        if sd_lora_rank % num_splits == 0:
+            ait_rank = sd_lora_rank // num_splits
+            is_sparse = True
+            i = 0
+            for j in range(len(dims)):
+                for k in range(len(dims)):
+                    if j == k:
+                        continue
+                    is_sparse = is_sparse and torch.all(
+                        up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
+                    )
+                i += dims[j]
+            if is_sparse:
+                logger.info(f"weight is sparse: {sds_key}")
+
+        # make ai-toolkit weight
+        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
+        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
+        if not is_sparse:
+            # down_weight is copied to each split
+            ait_sd.update({k: down_weight for k in ait_down_keys})
+
+            # up_weight is split to each split
+            ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
+        else:
+            # down_weight is chunked to each split
+            ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})  # noqa: C416
+
+            # up_weight is sparse: only non-zero values are copied to each split
+            i = 0
+            for j in range(len(dims)):
+                ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
+                i += dims[j]
+
+    def _convert_sd_scripts_to_ai_toolkit(sds_sd):
+        ait_sd = {}
+        for i in range(19):
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_img_attn_proj",
+                f"transformer.transformer_blocks.{i}.attn.to_out.0",
+            )
+            _convert_to_ai_toolkit_cat(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_img_attn_qkv",
+                [
+                    f"transformer.transformer_blocks.{i}.attn.to_q",
+                    f"transformer.transformer_blocks.{i}.attn.to_k",
+                    f"transformer.transformer_blocks.{i}.attn.to_v",
+                ],
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_img_mlp_0",
+                f"transformer.transformer_blocks.{i}.ff.net.0.proj",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_img_mlp_2",
+                f"transformer.transformer_blocks.{i}.ff.net.2",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_img_mod_lin",
+                f"transformer.transformer_blocks.{i}.norm1.linear",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_txt_attn_proj",
+                f"transformer.transformer_blocks.{i}.attn.to_add_out",
+            )
+            _convert_to_ai_toolkit_cat(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_txt_attn_qkv",
+                [
+                    f"transformer.transformer_blocks.{i}.attn.add_q_proj",
+                    f"transformer.transformer_blocks.{i}.attn.add_k_proj",
+                    f"transformer.transformer_blocks.{i}.attn.add_v_proj",
+                ],
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_txt_mlp_0",
+                f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_txt_mlp_2",
+                f"transformer.transformer_blocks.{i}.ff_context.net.2",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_double_blocks_{i}_txt_mod_lin",
+                f"transformer.transformer_blocks.{i}.norm1_context.linear",
+            )
+
+        for i in range(38):
+            _convert_to_ai_toolkit_cat(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_single_blocks_{i}_linear1",
+                [
+                    f"transformer.single_transformer_blocks.{i}.attn.to_q",
+                    f"transformer.single_transformer_blocks.{i}.attn.to_k",
+                    f"transformer.single_transformer_blocks.{i}.attn.to_v",
+                    f"transformer.single_transformer_blocks.{i}.proj_mlp",
+                ],
+                dims=[3072, 3072, 3072, 12288],
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_single_blocks_{i}_linear2",
+                f"transformer.single_transformer_blocks.{i}.proj_out",
+            )
+            _convert_to_ai_toolkit(
+                sds_sd,
+                ait_sd,
+                f"lora_unet_single_blocks_{i}_modulation_lin",
+                f"transformer.single_transformer_blocks.{i}.norm.linear",
+            )
+
+        remaining_keys = list(sds_sd.keys())
+        te_state_dict = {}
+        if remaining_keys:
+            if not all(k.startswith("lora_te1") for k in remaining_keys):
+                raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
+            for key in remaining_keys:
+                if not key.endswith("lora_down.weight"):
+                    continue
+
+                lora_name = key.split(".")[0]
+                lora_name_up = f"{lora_name}.lora_up.weight"
+                lora_name_alpha = f"{lora_name}.alpha"
+                diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
+
+                if lora_name.startswith(("lora_te_", "lora_te1_")):
+                    down_weight = sds_sd.pop(key)
+                    sd_lora_rank = down_weight.shape[0]
+                    te_state_dict[diffusers_name] = down_weight
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)
+
+                if lora_name_alpha in sds_sd:
+                    alpha = sds_sd.pop(lora_name_alpha).item()
+                    scale = alpha / sd_lora_rank
+
+                    scale_down = scale
+                    scale_up = 1.0
+                    while scale_down * 2 < scale_up:
+                        scale_down *= 2
+                        scale_up /= 2
+
+                    te_state_dict[diffusers_name] *= scale_down
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up
+
+        if len(sds_sd) > 0:
+            logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
+
+        if te_state_dict:
+            te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}
+
+        new_state_dict = {**ait_sd, **te_state_dict}
+        return new_state_dict
+
+    return _convert_sd_scripts_to_ai_toolkit(state_dict)
+
+
+# Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6
+# Some utilities were reused from
+# https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
+def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
+    new_state_dict = {}
+    orig_keys = list(old_state_dict.keys())
+
+    def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
+        down_weight = sds_sd.pop(sds_key)
+        up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))
+
+        # calculate dims if not provided
+        num_splits = len(ait_keys)
+        if dims is None:
+            dims = [up_weight.shape[0] // num_splits] * num_splits
+        else:
+            assert sum(dims) == up_weight.shape[0]
+
+        # make ai-toolkit weight
+        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
+        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
+
+        # down_weight is copied to each split
+        ait_sd.update({k: down_weight for k in ait_down_keys})
+
+        # up_weight is split to each split
+        ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
+
+    for old_key in orig_keys:
+        # Handle double_blocks
+        if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")):
+            block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1)
+            new_key = f"transformer.transformer_blocks.{block_num}"
+
+            if "processor.proj_lora1" in old_key:
+                new_key += ".attn.to_out.0"
+            elif "processor.proj_lora2" in old_key:
+                new_key += ".attn.to_add_out"
+            # Handle text latents.
+            elif "processor.qkv_lora2" in old_key and "up" not in old_key:
+                handle_qkv(
+                    old_state_dict,
+                    new_state_dict,
+                    old_key,
+                    [
+                        f"transformer.transformer_blocks.{block_num}.attn.add_q_proj",
+                        f"transformer.transformer_blocks.{block_num}.attn.add_k_proj",
+                        f"transformer.transformer_blocks.{block_num}.attn.add_v_proj",
+                    ],
+                )
+                # continue
+            # Handle image latents.
+            elif "processor.qkv_lora1" in old_key and "up" not in old_key:
+                handle_qkv(
+                    old_state_dict,
+                    new_state_dict,
+                    old_key,
+                    [
+                        f"transformer.transformer_blocks.{block_num}.attn.to_q",
+                        f"transformer.transformer_blocks.{block_num}.attn.to_k",
+                        f"transformer.transformer_blocks.{block_num}.attn.to_v",
+                    ],
+                )
+                # continue
+
+            if "down" in old_key:
+                new_key += ".lora_A.weight"
+            elif "up" in old_key:
+                new_key += ".lora_B.weight"
+
+        # Handle single_blocks
+        elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")):
+            block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
+            new_key = f"transformer.single_transformer_blocks.{block_num}"
+
+            if "proj_lora1" in old_key or "proj_lora2" in old_key:
+                new_key += ".proj_out"
+            elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
+                new_key += ".norm.linear"
+
+            if "down" in old_key:
+                new_key += ".lora_A.weight"
+            elif "up" in old_key:
+                new_key += ".lora_B.weight"
+
+        else:
+            # Handle other potential key patterns here
+            new_key = old_key
+
+        # Since we already handle qkv above.
+        if "qkv" not in old_key:
+            new_state_dict[new_key] = old_state_dict.pop(old_key)
+
+    if len(old_state_dict) > 0:
+        raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
+
+    return new_state_dict
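
The scale_down / scale_up logic in _convert_to_ai_toolkit above folds the kohya alpha / rank factor directly into the converted weights, so that lora_B @ lora_A equals (alpha / rank) * (lora_up @ lora_down). Below is a minimal sketch, not part of the diff, that exercises the new private helper on a toy one-entry kohya-style state dict and checks that invariant.

# Hedged sketch: calls the private converter added in this diff on a toy
# kohya/sd-scripts-style state dict and verifies the alpha/rank rescaling.
import torch
from diffusers.loaders.lora_conversion_utils import _convert_kohya_flux_lora_to_diffusers

rank, alpha = 4, 8.0
down = torch.randn(rank, 3072)   # lora_down: (rank, in_features)
up = torch.randn(3072, rank)     # lora_up:   (out_features, rank)

kohya_sd = {
    "lora_unet_double_blocks_0_img_attn_proj.lora_down.weight": down.clone(),
    "lora_unet_double_blocks_0_img_attn_proj.lora_up.weight": up.clone(),
    "lora_unet_double_blocks_0_img_attn_proj.alpha": torch.tensor(alpha),
}

diffusers_sd = _convert_kohya_flux_lora_to_diffusers(kohya_sd)
lora_A = diffusers_sd["transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight"]
lora_B = diffusers_sd["transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight"]

# The converted pair already carries the alpha / rank factor, so the effective
# delta matches the original kohya LoRA without extra alpha bookkeeping.
assert torch.allclose(lora_B @ lora_A, (alpha / rank) * (up @ down), atol=1e-4)

In practice this helper is not called directly; the reworked Flux LoRA loader in diffusers/loaders/lora_pipeline.py (also changed in this release) presumably detects kohya- and XLabs-format checkpoints and routes them through these converters before handing the state dict to PEFT.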