diffusers 0.34.0__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -817,7 +817,11 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
     # has both `peft` and non-peft state dict.
     has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
     if has_peft_state_dict:
-        state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
+        state_dict = {
+            k.replace("lora_down.weight", "lora_A.weight").replace("lora_up.weight", "lora_B.weight"): v
+            for k, v in state_dict.items()
+            if k.startswith("transformer.")
+        }
         return state_dict
 
     # Another weird one.
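The new comprehension does two things at once: it filters to `transformer.`-prefixed keys and renames kohya-style `lora_down`/`lora_up` weight suffixes to PEFT's `lora_A`/`lora_B`. A minimal standalone sketch of that renaming; the keys below are made-up placeholders, not from a real checkpoint:

```python
# Toy state dict with kohya-style suffixes; strings stand in for tensors.
state_dict = {
    "transformer.blocks.0.attn.lora_down.weight": "down",
    "transformer.blocks.0.attn.lora_up.weight": "up",
    "text_encoder.blocks.0.lora_down.weight": "dropped",  # not transformer-prefixed
}

converted = {
    k.replace("lora_down.weight", "lora_A.weight").replace("lora_up.weight", "lora_B.weight"): v
    for k, v in state_dict.items()
    if k.startswith("transformer.")
}

assert converted == {
    "transformer.blocks.0.attn.lora_A.weight": "down",
    "transformer.blocks.0.attn.lora_B.weight": "up",
}
```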
@@ -1346,6 +1350,228 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     return converted_state_dict
 
 
+def _convert_fal_kontext_lora_to_diffusers(original_state_dict):
+    converted_state_dict = {}
+    original_state_dict_keys = list(original_state_dict.keys())
+    num_layers = 19
+    num_single_layers = 38
+    inner_dim = 3072
+    mlp_ratio = 4.0
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        original_block_prefix = "base_model.model."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norms
+            converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
+            )
+
+            # Q, K, V
+            if lora_key == "lora_A":
+                sample_lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+
+                context_lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+            else:
+                sample_q, sample_k, sample_v = torch.chunk(
+                    original_state_dict.pop(
+                        f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+                    ),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+
+                context_q, context_k, context_v = torch.chunk(
+                    original_state_dict.pop(
+                        f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+                    ),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
+
+            if f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+                    original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias"),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
+
+            if f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+                    original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
+
+            # ff img_mlp
+            converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
+                )
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias"
+                )
+            converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
+                )
+
+    # single transformer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norm.linear <- single_blocks.0.modulation.lin
+            converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias"
+                )
+
+            # Q, K, V, mlp
+            mlp_hidden_dim = int(inner_dim * mlp_ratio)
+            split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+
+            if lora_key == "lora_A":
+                lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
+
+                if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    lora_bias = original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias")
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
+            else:
+                q, k, v, mlp = torch.split(
+                    original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"),
+                    split_size,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
+
+                if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    q_bias, k_bias, v_bias, mlp_bias = torch.split(
+                        original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"),
+                        split_size,
+                        dim=0,
+                    )
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias"
+                )
+
+    for lora_key in ["lora_A", "lora_B"]:
+        converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
+            f"{original_block_prefix}final_layer.linear.{lora_key}.weight"
+        )
+        if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                f"{original_block_prefix}final_layer.linear.{lora_key}.bias"
+            )
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
+
+
 def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
     converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
 
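A detail worth noting in `_convert_fal_kontext_lora_to_diffusers` above: when unfusing the fused `qkv`/`linear1` LoRA into separate diffusers projections, `lora_A` is copied to every target while only `lora_B` is split along dim 0. The sketch below shows why that preserves the LoRA delta exactly; the `rank` value is made up, and only `inner_dim = 3072` comes from the code above:

```python
import torch

rank, in_dim, inner_dim = 16, 3072, 3072

lora_A = torch.randn(rank, in_dim)         # "down" projection, shared by q, k, v
lora_B = torch.randn(3 * inner_dim, rank)  # "up" projection into the fused q|k|v output

# The fused LoRA delta is lora_B @ lora_A. Chunking lora_B along dim 0 yields
# per-projection deltas while lora_A can be reused unchanged for each of q, k, v.
b_q, b_k, b_v = torch.chunk(lora_B, 3, dim=0)
q_delta, k_delta, v_delta = torch.chunk(lora_B @ lora_A, 3, dim=0)

assert torch.allclose(q_delta, b_q @ lora_A)
assert torch.allclose(k_delta, b_k @ lora_A)
assert torch.allclose(v_delta, b_v @ lora_A)
```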
@@ -1603,24 +1829,33 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
     lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
     lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
+    has_time_projection_weight = any(
+        k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
+    )
 
-    diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
-    if diff_keys:
-        for diff_k in diff_keys:
-            param = original_state_dict[diff_k]
-            # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are upto 1e-3,
-            # and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond
-            # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
-            # is okay to ignore because they do not affect the model output in a significant manner.
-            threshold = 1.6e-2
-            absdiff = param.abs().max() - param.abs().min()
-            all_zero = torch.all(param == 0).item()
-            all_absdiff_lower_than_threshold = absdiff < threshold
-            if all_zero or all_absdiff_lower_than_threshold:
-                logger.debug(
-                    f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
-                )
-                original_state_dict.pop(diff_k)
+    def get_alpha_scales(down_weight, alpha_key):
+        rank = down_weight.shape[0]
+        alpha = original_state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
+    for key in list(original_state_dict.keys()):
+        if key.endswith((".diff", ".diff_b")) and "norm" in key:
+            # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
+            # in future if needed and they are not zeroed.
+            original_state_dict.pop(key)
+            logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
+
+        if "time_projection" in key and not has_time_projection_weight:
+            # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
+            # our lora config adds the time proj lora layers, but we don't have the weights for them.
+            # CausVid lora has the weight keys and the bias keys.
+            original_state_dict.pop(key)
 
     # For the `diff_b` keys, we treat them as lora_bias.
     # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
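`get_alpha_scales` folds the kohya `alpha` hyperparameter into the stored weights: the effective LoRA scale `alpha / rank` is factored into `scale_down * scale_up`, with the two factors walked toward each other in powers of two so neither matrix absorbs an extreme multiplier. A standalone sketch with the state-dict plumbing removed; `rank=64, alpha=8` are made-up values:

```python
def get_alpha_scales(rank: int, alpha: float) -> tuple[float, float]:
    scale = alpha / rank  # effective LoRA scale applied in the forward pass
    scale_down, scale_up = scale, 1.0
    # Rebalance in powers of two; the product scale_down * scale_up stays invariant.
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up

scale_down, scale_up = get_alpha_scales(rank=64, alpha=8)
assert scale_down * scale_up == 8 / 64
print(scale_down, scale_up)  # 0.25 0.5
```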
@@ -1628,15 +1863,26 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     for i in range(min_block, max_block + 1):
         # Self-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"
 
-            original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"
+
+            if has_alpha:
+                down_weight = original_state_dict.pop(original_key_A)
+                up_weight = original_state_dict.pop(original_key_B)
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] = down_weight * scale_down
+                converted_state_dict[converted_key_B] = up_weight * scale_up
+
+            else:
+                if original_key_A in original_state_dict:
+                    converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
+                if original_key_B in original_state_dict:
+                    converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
 
             original_key = f"blocks.{i}.self_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
@@ -1645,15 +1891,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         # Cross-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up
 
             original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1662,15 +1917,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         if is_i2v_lora:
             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+                alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+                has_alpha = alpha_key in original_state_dict
+                original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+                converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+                original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+                converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+                if original_key_A in original_state_dict:
+                    down_weight = original_state_dict.pop(original_key_A)
+                    converted_state_dict[converted_key_A] = down_weight
+                if original_key_B in original_state_dict:
+                    up_weight = original_state_dict.pop(original_key_B)
+                    converted_state_dict[converted_key_B] = up_weight
+                if has_alpha:
+                    scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                    converted_state_dict[converted_key_A] *= scale_down
+                    converted_state_dict[converted_key_B] *= scale_up
 
                 original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
                 converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1679,15 +1943,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         # FFN
        for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
-            original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-            original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up
 
             original_key = f"blocks.{i}.{o}.diff_b"
             converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
@@ -1754,6 +2027,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
             converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight"
             if original_key in original_state_dict:
                 converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            bias_key_theirs = original_key.removesuffix(f".{lora_up_key}.weight") + ".diff_b"
+            if bias_key_theirs in original_state_dict:
+                bias_key = converted_key.removesuffix(".weight") + ".bias"
+                converted_state_dict[bias_key] = original_state_dict.pop(bias_key_theirs)
 
     if len(original_state_dict) > 0:
         diff = all(".diff" in k for k in original_state_dict)
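The four added lines above reuse kohya `.diff_b` tensors as PEFT `lora_B.bias` entries for the image embedder, deriving both key names by suffix surgery. A tiny sketch of that string manipulation; the key names are illustrative placeholders, not actual Wan checkpoint keys:

```python
lora_up_key = "lora_up"
original_key = f"img_emb.proj.1.{lora_up_key}.weight"                 # hypothetical source key
converted_key = "condition_embedder.image_embedder.ff.lora_B.weight"  # hypothetical target key

bias_key_theirs = original_key.removesuffix(f".{lora_up_key}.weight") + ".diff_b"
bias_key = converted_key.removesuffix(".weight") + ".bias"

assert bias_key_theirs == "img_emb.proj.1.diff_b"
assert bias_key == "condition_embedder.image_embedder.ff.lora_B.bias"
```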
@@ -1849,3 +2126,107 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_prefix):
     converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
     converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
     return converted_state_dict
+
+
+def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
+    has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+    if has_lora_unet:
+        state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
+
+    def convert_key(key: str) -> str:
+        prefix = "transformer_blocks"
+        if "." in key:
+            base, suffix = key.rsplit(".", 1)
+        else:
+            base, suffix = key, ""
+
+        start = f"{prefix}_"
+        rest = base[len(start) :]
+
+        if "." in rest:
+            head, tail = rest.split(".", 1)
+            tail = "." + tail
+        else:
+            head, tail = rest, ""
+
+        # Protected n-grams that must keep their internal underscores
+        protected = {
+            # pairs
+            ("to", "q"),
+            ("to", "k"),
+            ("to", "v"),
+            ("to", "out"),
+            ("add", "q"),
+            ("add", "k"),
+            ("add", "v"),
+            ("txt", "mlp"),
+            ("img", "mlp"),
+            ("txt", "mod"),
+            ("img", "mod"),
+            # triplets
+            ("add", "q", "proj"),
+            ("add", "k", "proj"),
+            ("add", "v", "proj"),
+            ("to", "add", "out"),
+        }
+
+        prot_by_len = {}
+        for ng in protected:
+            prot_by_len.setdefault(len(ng), set()).add(ng)
+
+        parts = head.split("_")
+        merged = []
+        i = 0
+        lengths_desc = sorted(prot_by_len.keys(), reverse=True)
+
+        while i < len(parts):
+            matched = False
+            for L in lengths_desc:
+                if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
+                    merged.append("_".join(parts[i : i + L]))
+                    i += L
+                    matched = True
+                    break
+            if not matched:
+                merged.append(parts[i])
+                i += 1
+
+        head_converted = ".".join(merged)
+        converted_base = f"{prefix}.{head_converted}{tail}"
+        return converted_base + (("." + suffix) if suffix else "")
+
+    state_dict = {convert_key(k): v for k, v in state_dict.items()}
+
+    converted_state_dict = {}
+    all_keys = list(state_dict.keys())
+    down_key = ".lora_down.weight"
+    up_key = ".lora_up.weight"
+
+    def get_alpha_scales(down_weight, alpha_key):
+        rank = down_weight.shape[0]
+        alpha = state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
+    for k in all_keys:
+        if k.endswith(down_key):
+            diffusers_down_key = k.replace(down_key, ".lora_A.weight")
+            diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
+            alpha_key = k.replace(down_key, ".alpha")
+
+            down_weight = state_dict.pop(k)
+            up_weight = state_dict.pop(k.replace(down_key, up_key))
+            scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+            converted_state_dict[diffusers_down_key] = down_weight * scale_down
+            converted_state_dict[diffusers_up_key] = up_weight * scale_up
+
+    if len(state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+    converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+    return converted_state_dict
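The trickiest part of the Qwen converter is `convert_key`: checkpoint keys arrive with module paths flattened into underscores, so the function turns separators back into dots while keeping underscores inside known module names like `to_q` or `to_add_out`. A compressed, self-contained sketch of just that merge step, using a trimmed-down protected set (the full set is in the diff above):

```python
protected = {("to", "q"), ("to", "out"), ("to", "add", "out"), ("add", "q", "proj")}
prot_by_len: dict[int, set[tuple[str, ...]]] = {}
for ng in protected:
    prot_by_len.setdefault(len(ng), set()).add(ng)

def merge(head: str) -> str:
    """Rejoin a flattened segment, preserving protected n-grams."""
    parts = head.split("_")
    merged, i = [], 0
    while i < len(parts):
        for n in sorted(prot_by_len, reverse=True):  # longest match wins
            if i + n <= len(parts) and tuple(parts[i : i + n]) in prot_by_len[n]:
                merged.append("_".join(parts[i : i + n]))
                i += n
                break
        else:
            merged.append(parts[i])
            i += 1
    return ".".join(merged)

assert merge("0_attn_to_q") == "0.attn.to_q"
assert merge("1_attn_to_add_out") == "1.attn.to_add_out"
assert merge("2_attn_to_out_0") == "2.attn.to_out.0"
```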