diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +34 -2
- diffusers/configuration_utils.py +12 -0
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +257 -54
- diffusers/loaders/__init__.py +2 -0
- diffusers/loaders/ip_adapter.py +5 -1
- diffusers/loaders/lora_base.py +14 -7
- diffusers/loaders/lora_conversion_utils.py +332 -0
- diffusers/loaders/lora_pipeline.py +707 -41
- diffusers/loaders/peft.py +1 -0
- diffusers/loaders/single_file_utils.py +81 -4
- diffusers/loaders/textual_inversion.py +2 -0
- diffusers/loaders/unet.py +39 -8
- diffusers/models/__init__.py +4 -0
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +86 -10
- diffusers/models/attention_processor.py +169 -133
- diffusers/models/autoencoders/autoencoder_kl.py +71 -11
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
- diffusers/models/controlnet_flux.py +536 -0
- diffusers/models/controlnet_sd3.py +7 -3
- diffusers/models/controlnet_sparsectrl.py +0 -1
- diffusers/models/embeddings.py +170 -61
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +182 -14
- diffusers/models/modeling_utils.py +283 -46
- diffusers/models/normalization.py +79 -0
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
- diffusers/models/transformers/pixart_transformer_2d.py +9 -1
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +161 -44
- diffusers/models/transformers/transformer_sd3.py +7 -1
- diffusers/models/unets/unet_2d_condition.py +8 -8
- diffusers/models/unets/unet_motion_model.py +41 -63
- diffusers/models/upsampling.py +6 -6
- diffusers/pipelines/__init__.py +35 -6
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
- diffusers/pipelines/auto_pipeline.py +39 -8
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +10 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -20
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
- diffusers/pipelines/pag/__init__.py +6 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_loading_utils.py +225 -27
- diffusers/pipelines/pipeline_utils.py +123 -180
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +126 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/quantization_config.py +391 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +4 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
- diffusers/schedulers/scheduling_deis_multistep.py +78 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_sasolver.py +78 -1
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
- diffusers/training_utils.py +48 -18
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
- diffusers/utils/hub_utils.py +16 -4
- diffusers/utils/import_utils.py +31 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +3 -3
- diffusers/utils/testing_utils.py +59 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,8 @@
|
|
14
14
|
|
15
15
|
import re
|
16
16
|
|
17
|
+
import torch
|
18
|
+
|
17
19
|
from ..utils import is_peft_version, logging
|
18
20
|
|
19
21
|
|
@@ -326,3 +328,333 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
|
|
326
328
|
prefix = "text_encoder_2."
|
327
329
|
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
|
328
330
|
return {new_name: alpha}
|
331
|
+
|
332
|
+
|
333
|
+
# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
|
334
|
+
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
|
335
|
+
# All credits go to `kohya-ss`.
|
336
|
+
def _convert_kohya_flux_lora_to_diffusers(state_dict):
    """Convert a kohya-ss (sd-scripts) Flux LoRA state dict to the diffusers key layout.

    Transformer entries named `lora_unet_{double,single}_blocks_<i>_<layer>.{lora_down,lora_up}.weight`
    (plus an optional `.alpha`) are renamed to `transformer.<module>.lora_{A,B}.weight`. The
    `alpha / rank` factor is folded into the weights, so the result carries no alpha entries.
    Fused projections (`*_attn_qkv`, `linear1`) are split into one LoRA pair per target module.
    Any keys left afterwards must start with `lora_te1`; they are converted with
    `_convert_text_encoder_lora_key` and returned under the `text_encoder.` prefix.

    Args:
        state_dict: Mapping of kohya-style key -> tensor. It is consumed: every converted entry is
            popped from it. The tensors themselves are never modified in place.

    Returns:
        A new dict holding the converted transformer and text-encoder LoRA weights.

    Raises:
        ValueError: If keys remain after the transformer conversion that do not start with `lora_te1`.
        KeyError: If a `lora_down` weight has no matching `lora_up` weight.
    """

    def _pop_scale_factors(sds_sd, sds_key, rank):
        # Pops `<sds_key>.alpha` and returns `(scale_down, scale_up)` such that
        # `scale_down * scale_up == alpha / rank`.
        # A missing alpha is treated as `alpha == rank` (overall scale 1.0), which matches what the
        # text-encoder branch below has always done when no alpha entry is present.
        alpha = sds_sd.pop(sds_key + ".alpha", None)
        if alpha is None:
            alpha = rank
        elif isinstance(alpha, torch.Tensor):
            alpha = alpha.item()  # alpha is scalar

        # LoRA is scaled by `alpha / rank` in the forward pass, so that factor is baked into the
        # weights. While the scale is below 0.5 it is balanced between both matrices (for example
        # 0.25 -> 0.5 and 0.5); for a scale of 0.5 or more the loop does not run and the down
        # matrix carries the whole factor. Doubling/halving keeps the product unchanged.
        scale_down = alpha / rank
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2
        return scale_down, scale_up

    def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
        # One-to-one conversion of a single LoRA pair; silently skips modules absent from `sds_sd`.
        if sds_key + ".lora_down.weight" not in sds_sd:
            return
        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
        up_weight = sds_sd.pop(sds_key + ".lora_up.weight")

        # The rank is the leading dimension of the down projection.
        scale_down, scale_up = _pop_scale_factors(sds_sd, sds_key, down_weight.shape[0])

        ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
        ait_sd[ait_key + ".lora_B.weight"] = up_weight * scale_up

    def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
        # Converts one fused LoRA pair into one pair per entry of `ait_keys`. `dims` gives the
        # number of `lora_up` rows belonging to each target; equal shares are used when omitted.
        if sds_key + ".lora_down.weight" not in sds_sd:
            return
        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
        up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
        sd_lora_rank = down_weight.shape[0]

        # scale weight by alpha and dim
        scale_down, scale_up = _pop_scale_factors(sds_sd, sds_key, sd_lora_rank)
        down_weight = down_weight * scale_down
        up_weight = up_weight * scale_up

        # calculate dims if not provided
        num_splits = len(ait_keys)
        if dims is None:
            dims = [up_weight.shape[0] // num_splits] * num_splits
        else:
            assert sum(dims) == up_weight.shape[0], "`dims` must add up to the number of rows of `lora_up`"

        # The up weight is "sparse" when it is block-diagonal: the rows of split j only use the
        # j-th slice of the rank dimension and every off-diagonal block is exactly zero.
        is_sparse = False
        if sd_lora_rank % num_splits == 0:
            ait_rank = sd_lora_rank // num_splits
            is_sparse = True
            row = 0
            for j in range(len(dims)):
                for k in range(len(dims)):
                    if j == k:
                        continue
                    is_sparse = is_sparse and bool(
                        torch.all(up_weight[row : row + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0)
                    )
                row += dims[j]
            if is_sparse:
                logger.info(f"weight is sparse: {sds_key}")

        # make ai-toolkit weight
        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
        if not is_sparse:
            # down_weight is copied to each split
            ait_sd.update({k: down_weight for k in ait_down_keys})

            # up_weight is split to each split
            ait_sd.update(dict(zip(ait_up_keys, torch.split(up_weight, dims, dim=0))))
        else:
            # down_weight is chunked to each split
            ait_sd.update(dict(zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))))

            # up_weight is sparse: only the diagonal (non-zero) block is copied to each split
            row = 0
            for j in range(len(dims)):
                ait_sd[ait_up_keys[j]] = up_weight[row : row + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
                row += dims[j]

    def _convert_sd_scripts_to_ai_toolkit(sds_sd):
        ait_sd = {}

        # (kohya layer suffix, diffusers module) for every double block, in output order.
        # A tuple of targets marks a fused projection that is split across those modules.
        double_block_layers = (
            ("img_attn_proj", "attn.to_out.0"),
            ("img_attn_qkv", ("attn.to_q", "attn.to_k", "attn.to_v")),
            ("img_mlp_0", "ff.net.0.proj"),
            ("img_mlp_2", "ff.net.2"),
            ("img_mod_lin", "norm1.linear"),
            ("txt_attn_proj", "attn.to_add_out"),
            ("txt_attn_qkv", ("attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj")),
            ("txt_mlp_0", "ff_context.net.0.proj"),
            ("txt_mlp_2", "ff_context.net.2"),
            ("txt_mod_lin", "norm1_context.linear"),
        )
        for i in range(19):
            src_prefix = f"lora_unet_double_blocks_{i}_"
            dst_prefix = f"transformer.transformer_blocks.{i}."
            for src_name, dst_name in double_block_layers:
                if isinstance(dst_name, tuple):
                    _convert_to_ai_toolkit_cat(
                        sds_sd, ait_sd, src_prefix + src_name, [dst_prefix + name for name in dst_name]
                    )
                else:
                    _convert_to_ai_toolkit(sds_sd, ait_sd, src_prefix + src_name, dst_prefix + dst_name)

        for i in range(38):
            _convert_to_ai_toolkit_cat(
                sds_sd,
                ait_sd,
                f"lora_unet_single_blocks_{i}_linear1",
                [
                    f"transformer.single_transformer_blocks.{i}.attn.to_q",
                    f"transformer.single_transformer_blocks.{i}.attn.to_k",
                    f"transformer.single_transformer_blocks.{i}.attn.to_v",
                    f"transformer.single_transformer_blocks.{i}.proj_mlp",
                ],
                dims=[3072, 3072, 3072, 12288],
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_single_blocks_{i}_linear2",
                f"transformer.single_transformer_blocks.{i}.proj_out",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_single_blocks_{i}_modulation_lin",
                f"transformer.single_transformer_blocks.{i}.norm.linear",
            )

        # Whatever is left must belong to the first text encoder.
        remaining_keys = list(sds_sd.keys())
        te_state_dict = {}
        if remaining_keys:
            if not all(k.startswith("lora_te1") for k in remaining_keys):
                raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
            for key in remaining_keys:
                # Each module is handled once, when its down weight is visited.
                if not key.endswith("lora_down.weight"):
                    continue

                lora_name = key.split(".")[0]
                diffusers_name = _convert_text_encoder_lora_key(key, lora_name)

                if not lora_name.startswith(("lora_te_", "lora_te1_")):
                    # Left in `sds_sd`; reported by the "Unsupported keys" warning below.
                    continue

                down_weight = sds_sd.pop(key)
                up_weight = sds_sd.pop(f"{lora_name}.lora_up.weight")

                if f"{lora_name}.alpha" in sds_sd:
                    scale_down, scale_up = _pop_scale_factors(sds_sd, lora_name, down_weight.shape[0])
                    # Out-of-place on purpose: the caller's tensors must not be altered.
                    down_weight = down_weight * scale_down
                    up_weight = up_weight * scale_up

                te_state_dict[diffusers_name] = down_weight
                te_state_dict[diffusers_name.replace(".down.", ".up.")] = up_weight

        if len(sds_sd) > 0:
            logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")

        if te_state_dict:
            te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}

        new_state_dict = {**ait_sd, **te_state_dict}
        return new_state_dict

    return _convert_sd_scripts_to_ai_toolkit(state_dict)
|
562
|
+
|
563
|
+
|
564
|
+
# Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6
|
565
|
+
# Some utilities were reused from
|
566
|
+
# https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
|
567
|
+
def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
    """Convert an XLabs-style Flux LoRA state dict to the diffusers key layout.

    Keys under `double_blocks.<n>` / `single_blocks.<n>` (optionally prefixed with
    `diffusion_model.`) are renamed to `transformer.transformer_blocks.<n>` /
    `transformer.single_transformer_blocks.<n>`, with `down` -> `lora_A.weight` and
    `up` -> `lora_B.weight`. Fused double-block `qkv_lora1` / `qkv_lora2` pairs are split into
    separate q/k/v LoRA pairs. Keys matching neither block prefix are copied under their
    original name.

    Args:
        old_state_dict: Mapping of XLabs-style key -> tensor. It is consumed: converted entries
            are popped from it.

    Returns:
        A new dict with the converted keys.

    Raises:
        ValueError: If any key is left unconverted in `old_state_dict`.
        KeyError: If a fused qkv `down` weight has no matching `up` weight.
    """
    new_state_dict = {}
    # Snapshot the keys: the dict is mutated while it is being walked.
    orig_keys = list(old_state_dict.keys())

    def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
        # `sds_key` is the `.down.weight` key of a fused projection; its `.up.weight` partner is
        # derived from it and both entries are removed from `sds_sd`.
        down_weight = sds_sd.pop(sds_key)
        up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))

        # calculate dims if not provided
        num_splits = len(ait_keys)
        if dims is None:
            dims = [up_weight.shape[0] // num_splits] * num_splits
        else:
            assert sum(dims) == up_weight.shape[0]

        # make ai-toolkit weight
        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]

        # down_weight is copied to each split
        ait_sd.update({k: down_weight for k in ait_down_keys})

        # up_weight is split to each split
        ait_sd.update(dict(zip(ait_up_keys, torch.split(up_weight, dims, dim=0))))

    for old_key in orig_keys:
        # Handle double_blocks
        if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")):
            block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1)
            new_key = f"transformer.transformer_blocks.{block_num}"

            if "processor.proj_lora1" in old_key:
                new_key += ".attn.to_out.0"
            elif "processor.proj_lora2" in old_key:
                new_key += ".attn.to_add_out"
            # Handle text latents. Only the down key triggers the split; the matching up key is
            # consumed by `handle_qkv` at the same time.
            elif "processor.qkv_lora2" in old_key and "up" not in old_key:
                handle_qkv(
                    old_state_dict,
                    new_state_dict,
                    old_key,
                    [
                        f"transformer.transformer_blocks.{block_num}.attn.add_q_proj",
                        f"transformer.transformer_blocks.{block_num}.attn.add_k_proj",
                        f"transformer.transformer_blocks.{block_num}.attn.add_v_proj",
                    ],
                )
            # Handle image latents.
            elif "processor.qkv_lora1" in old_key and "up" not in old_key:
                handle_qkv(
                    old_state_dict,
                    new_state_dict,
                    old_key,
                    [
                        f"transformer.transformer_blocks.{block_num}.attn.to_q",
                        f"transformer.transformer_blocks.{block_num}.attn.to_k",
                        f"transformer.transformer_blocks.{block_num}.attn.to_v",
                    ],
                )

            if "down" in old_key:
                new_key += ".lora_A.weight"
            elif "up" in old_key:
                new_key += ".lora_B.weight"

        # Handle single_blocks
        elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")):
            block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
            new_key = f"transformer.single_transformer_blocks.{block_num}"

            if "proj_lora1" in old_key or "proj_lora2" in old_key:
                new_key += ".proj_out"
            elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
                # NOTE(review): a name is computed here, but keys containing "qkv" are skipped by
                # the copy step below and nothing else pops them, so single-block qkv entries
                # always end in the ValueError at the bottom. Confirm the intended mapping.
                new_key += ".norm.linear"

            if "down" in old_key:
                new_key += ".lora_A.weight"
            elif "up" in old_key:
                new_key += ".lora_B.weight"

        else:
            # Handle other potential key patterns here
            new_key = old_key

        # Since we already handle qkv above.
        if "qkv" not in old_key:
            new_state_dict[new_key] = old_state_dict.pop(old_key)

    if old_state_dict:
        raise ValueError(
            f"`old_state_dict` should be empty at this point but has: {list(old_state_dict.keys())}."
        )

    return new_state_dict
|