diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2252 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +3 -14
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +293 -8
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1937 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +403 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +50 -6
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +37 -15
  210. diffusers/utils/loading_utils.py +80 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
diffusers/loaders/single_file_utils.py
@@ -74,6 +74,15 @@ CHECKPOINT_KEY_NAMES = {
     "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
     "stable_cascade_stage_c": "clip_txt_mapper.weight",
     "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
+    "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
+    "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
+    "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
+    "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
+    "animatediff_rgb": "controlnet_cond_embedding.weight",
+    "flux": [
+        "double_blocks.0.img_attn.norm.key_norm.scale",
+        "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
+    ],
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -103,6 +112,14 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "sd3": {
         "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
     },
+    "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
+    "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
+    "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
+    "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
+    "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
+    "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
+    "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
+    "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
 }
 
 # Use to configure model sample size when original config is provided
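Note: the two tables above cooperate in `infer_diffusers_model_type` (see the hunk further down): a fingerprint key found in the checkpoint selects a model-type string, which then selects a default repo. A minimal sketch of that lookup, where `fake_checkpoint` is a hypothetical stand-in for a real single-file state dict (only key membership matters here):

```python
CHECKPOINT_KEY_NAMES = {"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias"}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
    "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
}

# Hypothetical: real checkpoints map these keys to tensors.
fake_checkpoint = {"mid_block.motion_modules.0.temporal_transformer.norm.bias": None}

if CHECKPOINT_KEY_NAMES["animatediff_v2"] in fake_checkpoint:
    repo = DIFFUSERS_DEFAULT_PIPELINE_PATHS["animatediff_v2"]["pretrained_model_name_or_path"]
    print(repo)  # guoyww/animatediff-motion-adapter-v1-5-2
```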
@@ -244,7 +261,7 @@ SCHEDULER_DEFAULT_CONFIG = {
     "timestep_spacing": "leading",
 }
 
-LDM_VAE_KEY = "first_stage_model."
+LDM_VAE_KEYS = ["first_stage_model.", "vae."]
 LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
 PLAYGROUND_VAE_SCALING_FACTOR = 0.5
 LDM_UNET_KEY = "model.diffusion_model."
@@ -253,8 +270,8 @@ LDM_CLIP_PREFIX_TO_REMOVE = [
     "cond_stage_model.transformer.",
     "conditioner.embedders.0.transformer.",
 ]
-OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
 LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
+SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]
 
 VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
 
@@ -304,9 +321,12 @@ def _is_model_weights_in_cached_folder(cached_folder, name):
     return weights_exist
 
 
+def _is_legacy_scheduler_kwargs(kwargs):
+    return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys())
+
+
 def load_single_file_checkpoint(
     pretrained_model_link_or_path,
-    resume_download=False,
     force_download=False,
     proxies=None,
     token=None,
@@ -324,7 +344,6 @@ def load_single_file_checkpoint(
            weights_name=weights_name,
            force_download=force_download,
            cache_dir=cache_dir,
-           resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
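Note: the new `_is_legacy_scheduler_kwargs` helper simply reports whether any of the deprecated scheduler kwargs were passed. A quick sketch of its behavior, copied from the hunk above:

```python
SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]

def _is_legacy_scheduler_kwargs(kwargs):
    return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys())

print(_is_legacy_scheduler_kwargs({"prediction_type": "v_prediction"}))  # True
print(_is_legacy_scheduler_kwargs({"scheduler_type": "ddim"}))           # True
print(_is_legacy_scheduler_kwargs({"num_inference_steps": 30}))          # False
```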
@@ -485,6 +504,32 @@ def infer_diffusers_model_type(checkpoint):
     elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
         model_type = "sd3"
 
+    elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
+        if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
+            model_type = "animatediff_scribble"
+
+        elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint:
+            model_type = "animatediff_rgb"
+
+        elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
+            model_type = "animatediff_v2"
+
+        elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
+            model_type = "animatediff_sdxl_beta"
+
+        elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff"]].shape[1] == 24:
+            model_type = "animatediff_v1"
+
+        else:
+            model_type = "animatediff_v3"
+
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
+        if any(
+            g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
+        ):
+            model_type = "flux-dev"
+        else:
+            model_type = "flux-schnell"
     else:
         model_type = "v1"
 
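Note: the Flux branch distinguishes dev from schnell purely by the presence of guidance-embedder weights. A self-contained sketch of that dispatch (the state dicts are hypothetical; only key membership is checked):

```python
FLUX_KEYS = [
    "double_blocks.0.img_attn.norm.key_norm.scale",
    "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
]
GUIDANCE_KEYS = ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]

def sketch_infer(checkpoint):
    if any(key in checkpoint for key in FLUX_KEYS):
        return "flux-dev" if any(g in checkpoint for g in GUIDANCE_KEYS) else "flux-schnell"
    return "v1"

print(sketch_infer({"double_blocks.0.img_attn.norm.key_norm.scale": None}))  # flux-schnell
print(sketch_infer({
    "double_blocks.0.img_attn.norm.key_norm.scale": None,
    "guidance_in.in_layer.bias": None,
}))  # flux-dev
```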
@@ -1140,7 +1185,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
     # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
     vae_state_dict = {}
     keys = list(checkpoint.keys())
-    vae_key = LDM_VAE_KEY if any(k.startswith(LDM_VAE_KEY) for k in keys) else ""
+    vae_key = ""
+    for ldm_vae_key in LDM_VAE_KEYS:
+        if any(k.startswith(ldm_vae_key) for k in keys):
+            vae_key = ldm_vae_key
+
     for key in keys:
         if key.startswith(vae_key):
             vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
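Note: the rewritten prefix search keeps the last entry of `LDM_VAE_KEYS` that matches, and when nothing matches, `vae_key` stays `""`, which every key starts with, so the state dict passes through unmodified. A sketch of the stripping step in isolation (the helper name is mine, not the library's):

```python
LDM_VAE_KEYS = ["first_stage_model.", "vae."]

def strip_vae_prefix(checkpoint):
    keys = list(checkpoint.keys())
    vae_key = ""
    for ldm_vae_key in LDM_VAE_KEYS:
        if any(k.startswith(ldm_vae_key) for k in keys):
            vae_key = ldm_vae_key
    # startswith("") is always True, so an unprefixed dict is returned as-is.
    return {k.replace(vae_key, ""): v for k, v in checkpoint.items() if k.startswith(vae_key)}

print(strip_vae_prefix({"vae.encoder.conv_in.weight": 0}))  # {'encoder.conv_in.weight': 0}
print(strip_vae_prefix({"encoder.conv_in.weight": 0}))      # {'encoder.conv_in.weight': 0}
```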
@@ -1441,14 +1490,22 @@ def _legacy_load_scheduler(
 
     if scheduler_type is not None:
         deprecation_message = (
-            "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`."
+            "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`\n\n"
+            "Example:\n\n"
+            "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
+            "scheduler = DDIMScheduler()\n"
+            "pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
         )
         deprecate("scheduler_type", "1.0.0", deprecation_message)
 
     if prediction_type is not None:
         deprecation_message = (
-            "Please configure an instance of a Scheduler with the appropriate `prediction_type` "
-            "and pass the object directly to the `scheduler` argument in `from_single_file`."
+            "Please configure an instance of a Scheduler with the appropriate `prediction_type` and "
+            "pass the object directly to the `scheduler` argument in `from_single_file`.\n\n"
+            "Example:\n\n"
+            "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
+            'scheduler = DDIMScheduler(prediction_type="v_prediction")\n'
+            "pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
         )
         deprecate("prediction_type", "1.0.0", deprecation_message)
 
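Note: both expanded messages recommend the same replacement pattern, which the message text itself spells out. As a runnable form (the checkpoint path is a placeholder):

```python
from diffusers import StableDiffusionPipeline, DDIMScheduler

# Configure prediction_type on the scheduler instead of passing the legacy
# kwargs to from_single_file, then hand the instance to the pipeline.
scheduler = DDIMScheduler(prediction_type="v_prediction")
pipe = StableDiffusionPipeline.from_single_file(
    "path/to/checkpoint.safetensors",  # placeholder
    scheduler=scheduler,
)
```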
@@ -1808,4 +1865,232 @@ def create_diffusers_t5_model_from_checkpoint(
 
     else:
         model.load_state_dict(diffusers_format_checkpoint)
+
+    use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16)
+    if use_keep_in_fp32_modules:
+        keep_in_fp32_modules = model._keep_in_fp32_modules
+    else:
+        keep_in_fp32_modules = []
+
+    if keep_in_fp32_modules is not None:
+        for name, param in model.named_parameters():
+            if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
+                # param = param.to(torch.float32) does not work here, as it only applies in the local scope.
+                param.data = param.data.to(torch.float32)
+
     return model
+
+
+def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    for k, v in checkpoint.items():
+        if "pos_encoder" in k:
+            continue
+
+        else:
+            converted_state_dict[
+                k.replace(".norms.0", ".norm1")
+                .replace(".norms.1", ".norm2")
+                .replace(".ff_norm", ".norm3")
+                .replace(".attention_blocks.0", ".attn1")
+                .replace(".attention_blocks.1", ".attn2")
+                .replace(".temporal_transformer", "")
+            ] = v
+
+    return converted_state_dict
+
+
+def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
+    num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
+    mlp_ratio = 4.0
+    inner_dim = 3072
+
+    # In the original SD3 implementation of AdaLayerNormContinuous, the linear projection output is split into
+    # (shift, scale); diffusers splits it into (scale, shift). Swap the linear projection weights so the diffusers
+    # implementation can be used.
+    def swap_scale_shift(weight):
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        return new_weight
+
+    ## time_text_embed.timestep_embedder <- time_in
+    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+        "time_in.in_layer.weight"
+    )
+    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
+    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+        "time_in.out_layer.weight"
+    )
+    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
+
+    ## time_text_embed.text_embedder <- vector_in
+    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
+    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
+    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
+        "vector_in.out_layer.weight"
+    )
+    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
+
+    # guidance
+    has_guidance = any("guidance" in k for k in checkpoint)
+    if has_guidance:
+        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
+            "guidance_in.in_layer.weight"
+        )
+        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
+            "guidance_in.in_layer.bias"
+        )
+        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
+            "guidance_in.out_layer.weight"
+        )
+        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
+            "guidance_in.out_layer.bias"
+        )
+
+    # context_embedder
+    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
+    converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
+
+    # x_embedder
+    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
+    converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        # norms.
+        ## norm1
+        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mod.lin.weight"
+        )
+        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mod.lin.bias"
+        )
+        ## norm1_context
+        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mod.lin.weight"
+        )
+        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mod.lin.bias"
+        )
+        # Q, K, V
+        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
+        context_q, context_k, context_v = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
+        )
+        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
+        )
+        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
+        )
+        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
+        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
+        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
+        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
+        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
+        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
+        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
+        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
+        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
+        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+        # qk_norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+        )
+        # ff img_mlp
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
+        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
+        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.bias"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.bias"
+        )
+        # output projections.
+        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.bias"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.bias"
+        )
+
+    # single transformer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+        # norm.linear <- single_blocks.0.modulation.lin
+        converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.modulation.lin.weight"
+        )
+        converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
+            f"single_blocks.{i}.modulation.lin.bias"
+        )
+        # Q, K, V, mlp
+        mlp_hidden_dim = int(inner_dim * mlp_ratio)
+        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+        q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
+        q_bias, k_bias, v_bias, mlp_bias = torch.split(
+            checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
+        )
+        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
+        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
+        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
+        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
+        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
+        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
+        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
+        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
+        # qk norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.key_norm.scale"
+        )
+        # output projections.
+        converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
+        converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
+
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+        checkpoint.pop("final_layer.adaLN_modulation.1.weight")
+    )
+    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+        checkpoint.pop("final_layer.adaLN_modulation.1.bias")
+    )
+
+    return converted_state_dict
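Note: the heart of the Flux converter is unfusing stacked projections: double blocks store Q/K/V as one qkv tensor chunked in three, single blocks fuse Q, K, V and the MLP input into one linear1 that torch.split takes apart by known widths (inner_dim = 3072, mlp_ratio = 4.0), and the final AdaLayerNorm weights get a shift/scale swap. A self-contained sketch of those three moves with random weights:

```python
import torch

inner_dim = 3072
mlp_hidden_dim = int(inner_dim * 4.0)  # mlp_ratio = 4.0

# Double block: qkv.weight stacks Q, K, V along dim 0.
qkv = torch.randn(3 * inner_dim, inner_dim)
q, k, v = torch.chunk(qkv, 3, dim=0)
assert q.shape == (inner_dim, inner_dim)

# Single block: linear1.weight additionally fuses the MLP input projection.
linear1 = torch.randn(3 * inner_dim + mlp_hidden_dim, inner_dim)
q, k, v, mlp = torch.split(linear1, (inner_dim, inner_dim, inner_dim, mlp_hidden_dim), dim=0)
assert mlp.shape == (mlp_hidden_dim, inner_dim)

# Final layer: the checkpoint orders (shift, scale); diffusers wants (scale, shift).
def swap_scale_shift(weight):
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)

w = torch.arange(4.0)       # [shift0, shift1, scale0, scale1]
print(swap_scale_shift(w))  # tensor([2., 3., 0., 1.])
```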
diffusers/loaders/textual_inversion.py
@@ -38,7 +38,6 @@ TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
 def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
     cache_dir = kwargs.pop("cache_dir", None)
     force_download = kwargs.pop("force_download", False)
-    resume_download = kwargs.pop("resume_download", None)
     proxies = kwargs.pop("proxies", None)
     local_files_only = kwargs.pop("local_files_only", None)
     token = kwargs.pop("token", None)
@@ -72,7 +71,6 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
                     weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
                     cache_dir=cache_dir,
                     force_download=force_download,
-                    resume_download=resume_download,
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
@@ -93,7 +91,6 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
                     weights_name=weight_name or TEXT_INVERSION_NAME,
                     cache_dir=cache_dir,
                     force_download=force_download,
-                    resume_download=resume_download,
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
@@ -308,9 +305,7 @@ class TextualInversionLoaderMixin:
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
+
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
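Note: these textual-inversion hunks all drop the deprecated resume_download plumbing; huggingface_hub resumes interrupted downloads automatically when possible, so callers simply omit the argument. A usage sketch (the model and embedding IDs are illustrative):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# No resume_download kwarg anymore; interrupted downloads resume on retry.
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
```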