diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
```diff
@@ -817,7 +817,11 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
     # has both `peft` and non-peft state dict.
     has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
     if has_peft_state_dict:
-        state_dict = {
+        state_dict = {
+            k.replace("lora_down.weight", "lora_A.weight").replace("lora_up.weight", "lora_B.weight"): v
+            for k, v in state_dict.items()
+            if k.startswith("transformer.")
+        }
         return state_dict

     # Another weird one.
```
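With this change, a Kohya-style Flux LoRA whose keys already carry the `transformer.` prefix also gets its `lora_down`/`lora_up` suffixes renamed to the peft-style `lora_A`/`lora_B` before being returned. A minimal sketch of the renaming comprehension on made-up keys (the key names below are illustrative, not taken from a real checkpoint):

```python
# Illustrative only: toy keys demonstrating the renaming comprehension added above.
state_dict = {
    "transformer.single_transformer_blocks.0.proj_out.lora_down.weight": "down",
    "transformer.single_transformer_blocks.0.proj_out.lora_up.weight": "up",
    "text_encoder.some_other_key": "dropped",  # non-"transformer." keys are filtered out
}
renamed = {
    k.replace("lora_down.weight", "lora_A.weight").replace("lora_up.weight", "lora_B.weight"): v
    for k, v in state_dict.items()
    if k.startswith("transformer.")
}
assert "transformer.single_transformer_blocks.0.proj_out.lora_A.weight" in renamed
assert "transformer.single_transformer_blocks.0.proj_out.lora_B.weight" in renamed
assert len(renamed) == 2
```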
```diff
@@ -1346,6 +1350,228 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     return converted_state_dict


+def _convert_fal_kontext_lora_to_diffusers(original_state_dict):
+    converted_state_dict = {}
+    original_state_dict_keys = list(original_state_dict.keys())
+    num_layers = 19
+    num_single_layers = 38
+    inner_dim = 3072
+    mlp_ratio = 4.0
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        original_block_prefix = "base_model.model."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norms
+            converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
+            )
+
+            # Q, K, V
+            if lora_key == "lora_A":
+                sample_lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+
+                context_lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+            else:
+                sample_q, sample_k, sample_v = torch.chunk(
+                    original_state_dict.pop(
+                        f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+                    ),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+
+                context_q, context_k, context_v = torch.chunk(
+                    original_state_dict.pop(
+                        f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+                    ),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
+
+            if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+                    original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias"),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
+
+            if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+                    original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
+
+            # ff img_mlp
+            converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
+                )
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias"
+                )
+            converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
+                )
+
+    # single transformer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norm.linear <- single_blocks.0.modulation.lin
+            converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias"
+                )
+
+            # Q, K, V, mlp
+            mlp_hidden_dim = int(inner_dim * mlp_ratio)
+            split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+
+            if lora_key == "lora_A":
+                lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
+
+                if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
+            else:
+                q, k, v, mlp = torch.split(
+                    original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"),
+                    split_size,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
+
+                if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    q_bias, k_bias, v_bias, mlp_bias = torch.split(
+                        original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"),
+                        split_size,
+                        dim=0,
+                    )
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias"
+                )
+
+    for lora_key in ["lora_A", "lora_B"]:
+        converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
+            f"{original_block_prefix}final_layer.linear.{lora_key}.weight"
+        )
+        if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                f"{original_block_prefix}final_layer.linear.{lora_key}.bias"
+            )
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
+
+
 def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
     converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}

```
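In the new `_convert_fal_kontext_lora_to_diffusers`, the fused `qkv` projections of the double blocks (and `linear1` of the single blocks) are handled asymmetrically: the `lora_A` (down) matrix is copied unchanged to each of `to_q`/`to_k`/`to_v`, while the `lora_B` (up) matrix is chunked row-wise into per-projection pieces. A small sketch with assumed shapes of why that split preserves the low-rank update:

```python
# Sketch with assumed shapes: a rank-r LoRA on a fused qkv projection.
# The update is (lora_B @ lora_A); splitting lora_B row-wise gives each of
# q/k/v its own update while the shared lora_A stays identical.
import torch

inner_dim, rank = 8, 4
qkv_lora_A = torch.randn(rank, inner_dim)       # shared "down" matrix
qkv_lora_B = torch.randn(3 * inner_dim, rank)   # fused "up" matrix
b_q, b_k, b_v = torch.chunk(qkv_lora_B, 3, dim=0)

fused_update = qkv_lora_B @ qkv_lora_A          # (3 * inner_dim, inner_dim)
q_update = b_q @ qkv_lora_A                     # rows belonging to to_q
assert torch.allclose(fused_update[:inner_dim], q_update)
```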
```diff
@@ -1603,24 +1829,33 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
     lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
     lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
+    has_time_projection_weight = any(
+        k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
+    )

-    [17 removed lines from v0.34.0 are not rendered in this view]
+    def get_alpha_scales(down_weight, alpha_key):
+        rank = down_weight.shape[0]
+        alpha = original_state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
+    for key in list(original_state_dict.keys()):
+        if key.endswith((".diff", ".diff_b")) and "norm" in key:
+            # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
+            # in future if needed and they are not zeroed.
+            original_state_dict.pop(key)
+            logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
+
+        if "time_projection" in key and not has_time_projection_weight:
+            # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
+            # our lora config adds the time proj lora layers, but we don't have the weights for them.
+            # CausVid lora has the weight keys and the bias keys.
+            original_state_dict.pop(key)

     # For the `diff_b` keys, we treat them as lora_bias.
     # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
```
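The `get_alpha_scales` helper introduced above folds the LoRA `alpha / rank` factor into the stored weights, splitting it between the down and up matrices (doubling one and halving the other until they are balanced) so that their product stays exactly `alpha / rank`. A standalone sketch of the arithmetic; the helper below is adapted from the diff and takes `alpha` directly instead of popping it from the state dict:

```python
# Sketch adapted from the diff: verify the rescaling keeps scale_down * scale_up == alpha / rank.
import torch

def get_alpha_scales(down_weight, alpha):
    rank = down_weight.shape[0]
    scale = alpha / rank  # LoRA is scaled by alpha / rank in the forward pass
    scale_down, scale_up = scale, 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up

rank, dim, alpha = 16, 64, 8.0
down, up = torch.randn(rank, dim), torch.randn(dim, rank)
scale_down, scale_up = get_alpha_scales(down, alpha)
assert abs(scale_down * scale_up - alpha / rank) < 1e-12
# The effective update is unchanged up to that constant factor:
assert torch.allclose((up * scale_up) @ (down * scale_down), (alpha / rank) * (up @ down))
```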
```diff
@@ -1628,15 +1863,26 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     for i in range(min_block, max_block + 1):
         # Self-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            [4 removed lines from v0.34.0 are not rendered in this view]
+            alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"

-            [4 removed lines from v0.34.0 are not rendered in this view]
+            original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"
+
+            if has_alpha:
+                down_weight = original_state_dict.pop(original_key_A)
+                up_weight = original_state_dict.pop(original_key_B)
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] = down_weight * scale_down
+                converted_state_dict[converted_key_B] = up_weight * scale_up
+
+            else:
+                if original_key_A in original_state_dict:
+                    converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
+                if original_key_B in original_state_dict:
+                    converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)

             original_key = f"blocks.{i}.self_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
```
```diff
@@ -1645,15 +1891,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):

         # Cross-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            [9 removed lines from v0.34.0 are not rendered in this view]
+            alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up

             original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
```
```diff
@@ -1662,15 +1917,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):

         if is_i2v_lora:
             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                [9 removed lines from v0.34.0 are not rendered in this view]
+                alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+                has_alpha = alpha_key in original_state_dict
+                original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+                converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+                original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+                converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+                if original_key_A in original_state_dict:
+                    down_weight = original_state_dict.pop(original_key_A)
+                    converted_state_dict[converted_key_A] = down_weight
+                if original_key_B in original_state_dict:
+                    up_weight = original_state_dict.pop(original_key_B)
+                    converted_state_dict[converted_key_B] = up_weight
+                if has_alpha:
+                    scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                    converted_state_dict[converted_key_A] *= scale_down
+                    converted_state_dict[converted_key_B] *= scale_up

                 original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
                 converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
```
```diff
@@ -1679,15 +1943,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):

         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
-            [9 removed lines from v0.34.0 are not rendered in this view]
+            alpha_key = f"blocks.{i}.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up

             original_key = f"blocks.{i}.{o}.diff_b"
             converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
```
```diff
@@ -1754,6 +2027,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight"
         if original_key in original_state_dict:
             converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            bias_key_theirs = original_key.removesuffix(f".{lora_up_key}.weight") + ".diff_b"
+            if bias_key_theirs in original_state_dict:
+                bias_key = converted_key.removesuffix(".weight") + ".bias"
+                converted_state_dict[bias_key] = original_state_dict.pop(bias_key_theirs)

     if len(original_state_dict) > 0:
         diff = all(".diff" in k for k in original_state_dict)
```
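The four added lines pull a matching `.diff_b` bias, when present, into the converted entry's `lora_B.bias` slot by string surgery on the two key names. A small sketch with hypothetical key names (they are not taken from a real Wan checkpoint; only the `removesuffix` pattern comes from the diff):

```python
# Hypothetical key names, illustrating the string surgery used above.
lora_up_key = "lora_up"
original_key = "img_emb.proj.lora_up.weight"                          # hypothetical "theirs" key
converted_key = "condition_embedder.image_embedder.ff.lora_B.weight"  # hypothetical "ours" key

bias_key_theirs = original_key.removesuffix(f".{lora_up_key}.weight") + ".diff_b"
bias_key = converted_key.removesuffix(".weight") + ".bias"

assert bias_key_theirs == "img_emb.proj.diff_b"
assert bias_key == "condition_embedder.image_embedder.ff.lora_B.bias"
```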
```diff
@@ -1849,3 +2126,107 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_prefix
     converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
     converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
     return converted_state_dict
+
+
+def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
+    has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+    if has_lora_unet:
+        state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
+
+    def convert_key(key: str) -> str:
+        prefix = "transformer_blocks"
+        if "." in key:
+            base, suffix = key.rsplit(".", 1)
+        else:
+            base, suffix = key, ""
+
+        start = f"{prefix}_"
+        rest = base[len(start) :]
+
+        if "." in rest:
+            head, tail = rest.split(".", 1)
+            tail = "." + tail
+        else:
+            head, tail = rest, ""
+
+        # Protected n-grams that must keep their internal underscores
+        protected = {
+            # pairs
+            ("to", "q"),
+            ("to", "k"),
+            ("to", "v"),
+            ("to", "out"),
+            ("add", "q"),
+            ("add", "k"),
+            ("add", "v"),
+            ("txt", "mlp"),
+            ("img", "mlp"),
+            ("txt", "mod"),
+            ("img", "mod"),
+            # triplets
+            ("add", "q", "proj"),
+            ("add", "k", "proj"),
+            ("add", "v", "proj"),
+            ("to", "add", "out"),
+        }
+
+        prot_by_len = {}
+        for ng in protected:
+            prot_by_len.setdefault(len(ng), set()).add(ng)
+
+        parts = head.split("_")
+        merged = []
+        i = 0
+        lengths_desc = sorted(prot_by_len.keys(), reverse=True)
+
+        while i < len(parts):
+            matched = False
+            for L in lengths_desc:
+                if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
+                    merged.append("_".join(parts[i : i + L]))
+                    i += L
+                    matched = True
+                    break
+            if not matched:
+                merged.append(parts[i])
+                i += 1
+
+        head_converted = ".".join(merged)
+        converted_base = f"{prefix}.{head_converted}{tail}"
+        return converted_base + (("." + suffix) if suffix else "")
+
+    state_dict = {convert_key(k): v for k, v in state_dict.items()}
+
+    converted_state_dict = {}
+    all_keys = list(state_dict.keys())
+    down_key = ".lora_down.weight"
+    up_key = ".lora_up.weight"
+
+    def get_alpha_scales(down_weight, alpha_key):
+        rank = down_weight.shape[0]
+        alpha = state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
+    for k in all_keys:
+        if k.endswith(down_key):
+            diffusers_down_key = k.replace(down_key, ".lora_A.weight")
+            diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
+            alpha_key = k.replace(down_key, ".alpha")
+
+            down_weight = state_dict.pop(k)
+            up_weight = state_dict.pop(k.replace(down_key, up_key))
+            scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+            converted_state_dict[diffusers_down_key] = down_weight * scale_down
+            converted_state_dict[diffusers_up_key] = up_weight * scale_up
+
+    if len(state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+    converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+    return converted_state_dict
```
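For orientation, here is a hand trace of one hypothetical key through the new `_convert_non_diffusers_qwen_lora_to_diffusers` (the key itself is made up; only the transformation steps come from the code above):

```python
# Hypothetical input key, traced through the steps of the function above.
key = "lora_unet_transformer_blocks_0_attn_to_q.lora_down.weight"

key = key.removeprefix("lora_unet_")          # "transformer_blocks_0_attn_to_q.lora_down.weight"
base, suffix = key.rsplit(".", 1)             # suffix = "weight"
head, tail = base[len("transformer_blocks_"):].split(".", 1)
assert (head, tail) == ("0_attn_to_q", "lora_down")

# The protected pair ("to", "q") keeps its underscore while the other underscores become
# dots, so convert_key yields "transformer_blocks.0.attn.to_q.lora_down.weight".
# After the lora_down/lora_up -> lora_A/lora_B pass with alpha rescaling and the final
# "transformer." prefixing, the entry ends up as
# "transformer.transformer_blocks.0.attn.to_q.lora_A.weight".
```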