diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. diffusers/__init__.py +34 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +170 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +35 -6
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +2 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
  53. diffusers/pipelines/cogview3/__init__.py +47 -0
  54. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  55. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  56. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  57. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  58. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  60. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  62. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  63. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  64. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  66. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  67. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  68. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  70. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  71. diffusers/pipelines/flux/__init__.py +10 -0
  72. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  73. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  74. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  76. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  77. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  78. diffusers/pipelines/free_noise_utils.py +365 -5
  79. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  80. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  81. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  82. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  83. diffusers/pipelines/kolors/tokenizer.py +4 -0
  84. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  86. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  87. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  89. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  90. diffusers/pipelines/pag/__init__.py +6 -0
  91. diffusers/pipelines/pag/pag_utils.py +8 -2
  92. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  96. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  97. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  98. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  100. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  101. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  102. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  103. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  106. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  107. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  108. diffusers/pipelines/pipeline_utils.py +123 -180
  109. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  111. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  117. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  120. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  121. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  122. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  126. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  127. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  130. diffusers/quantizers/__init__.py +16 -0
  131. diffusers/quantizers/auto.py +126 -0
  132. diffusers/quantizers/base.py +233 -0
  133. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  134. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  135. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  136. diffusers/quantizers/quantization_config.py +391 -0
  137. diffusers/schedulers/scheduling_ddim.py +4 -1
  138. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  140. diffusers/schedulers/scheduling_ddpm.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  142. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  143. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  145. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  146. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  147. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  149. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  150. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  151. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  152. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  153. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  154. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  155. diffusers/schedulers/scheduling_sasolver.py +78 -1
  156. diffusers/schedulers/scheduling_unclip.py +4 -1
  157. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  158. diffusers/training_utils.py +48 -18
  159. diffusers/utils/__init__.py +2 -1
  160. diffusers/utils/dummy_pt_objects.py +60 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
  162. diffusers/utils/hub_utils.py +16 -4
  163. diffusers/utils/import_utils.py +31 -8
  164. diffusers/utils/loading_utils.py +28 -4
  165. diffusers/utils/peft_utils.py +3 -3
  166. diffusers/utils/testing_utils.py +59 -0
  167. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  168. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
  169. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  170. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
  171. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  172. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/peft.py CHANGED
@@ -33,6 +33,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
     "UNetMotionModel": _maybe_expand_lora_scales,
     "SD3Transformer2DModel": lambda model_cls, weights: weights,
     "FluxTransformer2DModel": lambda model_cls, weights: weights,
+    "CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
 }


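The new `_SET_ADAPTER_SCALE_FN_MAPPING` entry wires `CogVideoXTransformer3DModel` into the generic adapter-scaling machinery, so LoRAs loaded into a CogVideoX pipeline can be weighted like any other diffusers LoRA. A minimal sketch, assuming a CogVideoX LoRA checkpoint is available (the LoRA repo id below is a placeholder):

    import torch
    from diffusers import CogVideoXPipeline

    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    # Placeholder LoRA repo; substitute a real CogVideoX LoRA checkpoint.
    pipe.load_lora_weights("your-org/your-cogvideox-lora", adapter_name="my-lora")
    # The mapping entry above is what makes this scale call valid for CogVideoXTransformer3DModel.
    pipe.set_adapters(["my-lora"], [0.8])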
diffusers/loaders/single_file_utils.py CHANGED
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Conversion script for the Stable Diffusion checkpoints."""
 
+import copy
 import os
 import re
 from contextlib import nullcontext
@@ -74,6 +75,7 @@ CHECKPOINT_KEY_NAMES = {
     "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
     "stable_cascade_stage_c": "clip_txt_mapper.weight",
     "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
+    "sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
     "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
     "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
@@ -91,11 +93,11 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"},
     "playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"},
     "upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"},
-    "inpainting": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8-inpainting"},
+    "inpainting": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-inpainting"},
     "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
     "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
     "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
-    "v1": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8"},
+    "v1": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-v1-5"},
     "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
     "stable_cascade_stage_b_lite": {
         "pretrained_model_name_or_path": "stabilityai/stable-cascade",
@@ -112,6 +114,9 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "sd3": {
         "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
     },
+    "sd35_large": {
+        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large",
+    },
     "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
     "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
     "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
@@ -456,6 +461,8 @@ def infer_diffusers_model_type(checkpoint):
     ):
         if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
             model_type = "inpainting_v2"
+        elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
+            model_type = "xl_inpaint"
         else:
             model_type = "inpainting"
 
@@ -501,9 +508,12 @@ def infer_diffusers_model_type(checkpoint):
     ):
         model_type = "stable_cascade_stage_b"
 
-    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
+    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
        model_type = "sd3"
 
+    elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
+        model_type = "sd35_large"
+
     elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
         if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
             model_type = "animatediff_scribble"
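With these checks, single-file loading can tell SD3 Medium and SD 3.5 Large apart from the checkpoint alone: the `sd3` branch now also requires the adaLN modulation bias to be 9216 wide, while the presence of a 38th joint block (`joint_blocks.37...`) selects the new `sd35_large` default config. A hedged sketch of how this surfaces to users (the checkpoint URL is illustrative, not a confirmed file name):

    import torch
    from diffusers import SD3Transformer2DModel

    # Illustrative single-file checkpoint; infer_diffusers_model_type() maps it to the
    # "sd35_large" entry and pulls the stabilityai/stable-diffusion-3.5-large config.
    transformer = SD3Transformer2DModel.from_single_file(
        "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/sd3.5_large.safetensors",
        torch_dtype=torch.bfloat16,
    )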
@@ -539,6 +549,7 @@ def infer_diffusers_model_type(checkpoint):
 def fetch_diffusers_config(checkpoint):
     model_type = infer_diffusers_model_type(checkpoint)
     model_path = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]
+    model_path = copy.deepcopy(model_path)
 
     return model_path
 
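`fetch_diffusers_config` now returns a deep copy, so callers that mutate the returned dict no longer corrupt the shared `DIFFUSERS_DEFAULT_PIPELINE_PATHS` entries. A toy illustration of the aliasing problem the `copy.deepcopy` call avoids (plain Python, not library code):

    import copy

    DEFAULTS = {"v1": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-v1-5"}}

    aliased = DEFAULTS["v1"]
    aliased["subfolder"] = "unet"            # mutates the shared default in place
    assert "subfolder" in DEFAULTS["v1"]     # later callers now see the stale key

    copied = copy.deepcopy(DEFAULTS["v1"])
    copied.pop("subfolder")                  # safe: the module-level default is untouched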
@@ -1666,6 +1677,22 @@ def swap_scale_shift(weight, dim):
     return new_weight
 
 
+def get_attn2_layers(state_dict):
+    attn2_layers = []
+    for key in state_dict.keys():
+        if "attn2." in key:
+            # Extract the layer number from the key
+            layer_num = int(key.split(".")[1])
+            attn2_layers.append(layer_num)
+
+    return tuple(sorted(set(attn2_layers)))
+
+
+def get_caption_projection_dim(state_dict):
+    caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+    return caption_projection_dim
+
+
 def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
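Instead of hardcoding SD3-Medium values, the converter now reads the architecture from the checkpoint itself: `attn2.*` keys mark the dual-attention blocks that SD 3.5 adds, and the `context_embedder` output width gives the caption projection dim. A toy illustration of the key parsing (assuming the `model.diffusion_model.` prefix has already been stripped):

    # Toy keys mimicking an SD 3.5-style state dict.
    keys = [
        "joint_blocks.0.x_block.attn2.qkv.weight",
        "joint_blocks.0.x_block.attn2.proj.weight",
        "joint_blocks.12.x_block.attn2.qkv.weight",
        "joint_blocks.5.x_block.attn.qkv.weight",   # plain attn, ignored
    ]
    attn2_layers = tuple(sorted({int(k.split(".")[1]) for k in keys if "attn2." in k}))
    print(attn2_layers)  # (0, 12) -> these transformer blocks get a second attention module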
@@ -1674,7 +1701,10 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
 
     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1  # noqa: C401
-    caption_projection_dim = 1536
+    dual_attention_layers = get_attn2_layers(checkpoint)
+
+    caption_projection_dim = get_caption_projection_dim(checkpoint)
+    has_qk_norm = any("ln_q" in key for key in checkpoint.keys())
 
     # Positional and patch embeddings.
     converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
@@ -1731,6 +1761,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
 
+        # qk norm
+        if has_qk_norm:
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_k.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_k.weight"
+            )
+
         # output projections.
         converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
             f"joint_blocks.{i}.x_block.attn.proj.weight"
@@ -1746,6 +1791,38 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
             f"joint_blocks.{i}.context_block.attn.proj.bias"
         )
 
+        if i in dual_attention_layers:
+            # Q, K, V
+            sample_q2, sample_k2, sample_v2 = torch.chunk(
+                checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
+            )
+            sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
+                checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
+
+            # qk norm
+            if has_qk_norm:
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
+                )
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
+                )
+
+            # output projections.
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.bias"
+            )
+
         # norms.
         converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
             f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
diffusers/loaders/textual_inversion.py CHANGED
@@ -561,6 +561,8 @@ class TextualInversionLoaderMixin:
             tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
             key_id += 1
         tokenizer._update_trie()
+        # set correct total vocab size after removing tokens
+        tokenizer._update_total_vocab_size()
 
         # Delete from text encoder
         text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
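`unload_textual_inversion` now also refreshes the tokenizer's cached total vocab size after deleting tokens, so `len(tokenizer)` stays consistent with the actual vocabulary. A minimal sketch of the round trip, assuming the usual textual-inversion example concept:

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    vocab_before = len(pipe.tokenizer)
    pipe.load_textual_inversion("sd-concepts-library/cat-toy")   # adds the <cat-toy> token
    pipe.unload_textual_inversion("<cat-toy>")
    assert len(pipe.tokenizer) == vocab_before                   # size is updated after removal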
diffusers/loaders/unet.py CHANGED
@@ -115,6 +115,9 @@ class UNet2DConditionLoadersMixin:
                 `default_{i}` where i is the total number of adapters being loaded.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
 
         Example:
 
@@ -142,8 +145,14 @@ class UNet2DConditionLoadersMixin:
         adapter_name = kwargs.pop("adapter_name", None)
         _pipeline = kwargs.pop("_pipeline", None)
         network_alphas = kwargs.pop("network_alphas", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
         allow_pickle = False
 
+        if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True
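The same `low_cpu_mem_usage` flag is threaded through the higher-level LoRA loaders in this release; it skips initializing the adapter weights that are about to be overwritten, at the cost of requiring a `peft` newer than 0.13.0. A hedged usage sketch (the LoRA repo id is a placeholder):

    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    )
    # Placeholder LoRA; raises a ValueError if the installed peft is <= 0.13.0.
    pipe.load_lora_weights("your-org/your-sdxl-lora", low_cpu_mem_usage=True)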
@@ -209,6 +218,7 @@ class UNet2DConditionLoadersMixin:
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
                 _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )
         else:
             raise ValueError(
@@ -268,7 +278,9 @@ class UNet2DConditionLoadersMixin:
 
         return attn_processors
 
-    def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
+    def _process_lora(
+        self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage
+    ):
         # This method does the following things:
         # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
         #    format. For legacy format no filtering is applied.
@@ -335,18 +347,37 @@ class UNet2DConditionLoadersMixin:
         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
         # otherwise loading LoRA weights will lead to an error
         is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+        peft_kwargs = {}
+        if is_peft_version(">=", "0.13.1"):
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
 
-        inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
-        incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
+        inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+        incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
 
+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)
 
         return is_model_cpu_offload, is_sequential_cpu_offload
 
diffusers/models/__init__.py CHANGED
@@ -35,6 +35,7 @@ if is_torch_available():
     _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
     _import_structure["autoencoders.vq_model"] = ["VQModel"]
     _import_structure["controlnet"] = ["ControlNetModel"]
+    _import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
     _import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
     _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
     _import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
@@ -53,6 +54,7 @@ if is_torch_available():
     _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+    _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
     _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
@@ -87,6 +89,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         VQModel,
     )
     from .controlnet import ControlNetModel
+    from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
     from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
     from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
     from .controlnet_sparsectrl import SparseControlNetModel
@@ -96,6 +99,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .transformers import (
         AuraFlowTransformer2DModel,
         CogVideoXTransformer3DModel,
+        CogView3PlusTransformer2DModel,
         DiTTransformer2DModel,
         DualTransformer2DModel,
         FluxTransformer2DModel,
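The new classes are regular model exports, so with 0.31.0 installed they can be imported directly, e.g.:

    # Exposed through diffusers.models after this change
    # (and re-exported at the package top level by diffusers/__init__.py).
    from diffusers.models import CogView3PlusTransformer2DModel, FluxControlNetModel, FluxMultiControlNetModel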
diffusers/models/adapter.py CHANGED
@@ -30,10 +30,10 @@ class MultiAdapter(ModelMixin):
     MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
     user-assigned weighting.
 
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the model (such as downloading or saving, etc.)
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for common methods such as downloading
+    or saving.
 
-    Parameters:
+    Args:
         adapters (`List[T2IAdapter]`, *optional*, defaults to None):
             A list of `T2IAdapter` model instances.
     """
@@ -77,11 +77,13 @@ class MultiAdapter(ModelMixin):
         r"""
         Args:
             xs (`torch.Tensor`):
-                (batch, channel, height, width) input images for multiple adapter models concated along dimension 1,
-                `channel` should equal to `num_adapter` * "number of channel of image".
+                A tensor of shape (batch, channel, height, width) representing input images for multiple adapter
+                models, concatenated along dimension 1(channel dimension). The `channel` dimension should be equal to
+                `num_adapter` * number of channel per image.
+
             adapter_weights (`List[float]`, *optional*, defaults to None):
-                List of floats representing the weight which will be multiply to each adapter's output before adding
-                them together.
+                A list of floats representing the weights which will be multiplied by each adapter's output before
+                summing them together. If `None`, equal weights will be used for all adapters.
         """
         if adapter_weights is None:
             adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
@@ -109,24 +111,24 @@ class MultiAdapter(ModelMixin):
         variant: Optional[str] = None,
     ):
         """
-        Save a model and its configuration file to a directory, so that it can be re-loaded using the
+        Save a model and its configuration file to a specified directory, allowing it to be re-loaded with the
         `[`~models.adapter.MultiAdapter.from_pretrained`]` class method.
 
-        Arguments:
+        Args:
             save_directory (`str` or `os.PathLike`):
-                Directory to which to save. Will be created if it doesn't exist.
-            is_main_process (`bool`, *optional*, defaults to `True`):
-                Whether the process calling this is the main process or not. Useful when in distributed training like
-                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
-                the main process to avoid race conditions.
+                The directory where the model will be saved. If the directory does not exist, it will be created.
+            is_main_process (`bool`, optional, defaults=True):
+                Indicates whether current process is the main process or not. Useful for distributed training (e.g.,
+                TPUs) and need to call this function on all processes. In this case, set `is_main_process=True` only
+                for the main process to avoid race conditions.
             save_function (`Callable`):
-                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
-                need to replace `torch.save` by another method. Can be configured with the environment variable
-                `DIFFUSERS_SAVE_MODE`.
-            safe_serialization (`bool`, *optional*, defaults to `True`):
-                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+                Function used to save the state dictionary. Useful for distributed training (e.g., TPUs) to replace
+                `torch.save` with another method. Can also be configured using`DIFFUSERS_SAVE_MODE` environment
+                variable.
+            safe_serialization (`bool`, optional, defaults=True):
+                If `True`, save the model using `safetensors`. If `False`, save the model with `pickle`.
             variant (`str`, *optional*):
-                If specified, weights are saved in the format pytorch_model.<variant>.bin.
+                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
         """
         idx = 0
         model_path_to_save = save_directory
@@ -145,19 +147,17 @@ class MultiAdapter(ModelMixin):
     @classmethod
     def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
         r"""
-        Instantiate a pretrained MultiAdapter model from multiple pre-trained adapter models.
+        Instantiate a pretrained `MultiAdapter` model from multiple pre-trained adapter models.
 
         The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
-        the model, you should first set it back in training mode with `model.train()`.
+        the model, set it back to training mode using `model.train()`.
 
-        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
-        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
-        task.
+        Warnings:
+            *Weights from XXX not initialized from pretrained model* means that the weights of XXX are not pretrained
+            with the rest of the model. It is up to you to train those weights with a downstream fine-tuning. *Weights
+            from XXX not used in YYY* means that the layer XXX is not used by YYY, so those weights are discarded.
 
-        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
-        weights are discarded.
-
-        Parameters:
+        Args:
             pretrained_model_path (`os.PathLike`):
                 A path to a *directory* containing model weights saved using
                 [`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
@@ -175,20 +175,20 @@ class MultiAdapter(ModelMixin):
                 more information about each option see [designing a device
                 map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
             max_memory (`Dict`, *optional*):
-                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
-                GPU and the available CPU RAM if unset.
+                A dictionary mapping device identifiers to their maximum memory. Default to the maximum memory
+                available for each GPU and the available CPU RAM if unset.
             low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                 Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
                 also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                 model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
                 setting this argument to `True` will raise an error.
             variant (`str`, *optional*):
-                If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
-                ignored when using `from_flax`.
+                If specified, load weights from a `variant` file (*e.g.* pytorch_model.<variant>.bin). `variant` will
+                be ignored when using `from_flax`.
             use_safetensors (`bool`, *optional*, defaults to `None`):
-                If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
-                `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
-                `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
+                If `None`, the `safetensors` weights will be downloaded if available **and** if`safetensors` library is
+                installed. If `True`, the model will be forcibly loaded from`safetensors` weights. If `False`,
+                `safetensors` is not used.
         """
         idx = 0
         adapters = []
@@ -223,22 +223,22 @@ class T2IAdapter(ModelMixin, ConfigMixin):
     and
     [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).
 
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the model (such as downloading or saving, etc.)
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for the common methods, such as
+    downloading or saving.
 
-    Parameters:
-        in_channels (`int`, *optional*, defaults to 3):
-            Number of channels of Aapter's input(*control image*). Set this parameter to 1 if you're using gray scale
-            image as *control image*.
+    Args:
+        in_channels (`int`, *optional*, defaults to `3`):
+            The number of channels in the adapter's input (*control image*). Set it to 1 if you're using a gray scale
+            image.
         channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
-            The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will
-            also determine the number of downsample blocks in the Adapter.
-        num_res_blocks (`int`, *optional*, defaults to 2):
+            The number of channels in each downsample block's output hidden state. The `len(block_out_channels)`
+            determines the number of downsample blocks in the adapter.
+        num_res_blocks (`int`, *optional*, defaults to `2`):
             Number of ResNet blocks in each downsample block.
-        downscale_factor (`int`, *optional*, defaults to 8):
+        downscale_factor (`int`, *optional*, defaults to `8`):
             A factor that determines the total downscale factor of the Adapter.
         adapter_type (`str`, *optional*, defaults to `full_adapter`):
-            The type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`.
+            Adapter type (`full_adapter` or `full_adapter_xl` or `light_adapter`) to use.
     """
 
     @register_to_config
@@ -393,7 +393,7 @@ class AdapterBlock(nn.Module):
     An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
     `FullAdapterXL` models.
 
-    Parameters:
+    Args:
         in_channels (`int`):
             Number of channels of AdapterBlock's input.
         out_channels (`int`):
@@ -401,7 +401,7 @@ class AdapterBlock(nn.Module):
         num_res_blocks (`int`):
             Number of ResNet blocks in the AdapterBlock.
         down (`bool`, *optional*, defaults to `False`):
-            Whether to perform downsampling on AdapterBlock's input.
+            If `True`, perform downsampling on AdapterBlock's input.
     """
 
     def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
@@ -440,7 +440,7 @@ class AdapterResnetBlock(nn.Module):
     r"""
     An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.
 
-    Parameters:
+    Args:
         channels (`int`):
             Number of channels of AdapterResnetBlock's input and output.
     """
@@ -518,7 +518,7 @@ class LightAdapterBlock(nn.Module):
     A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
     `LightAdapter` model.
 
-    Parameters:
+    Args:
         in_channels (`int`):
             Number of channels of LightAdapterBlock's input.
         out_channels (`int`):
@@ -526,7 +526,7 @@ class LightAdapterBlock(nn.Module):
         num_res_blocks (`int`):
             Number of LightAdapterResnetBlocks in the LightAdapterBlock.
         down (`bool`, *optional*, defaults to `False`):
-            Whether to perform downsampling on LightAdapterBlock's input.
+            If `True`, perform downsampling on LightAdapterBlock's input.
     """
 
     def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
@@ -561,7 +561,7 @@ class LightAdapterResnetBlock(nn.Module):
     A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
     architecture than `AdapterResnetBlock`.
 
-    Parameters:
+    Args:
         channels (`int`):
             Number of channels of LightAdapterResnetBlock's input and output.
     """