diffusers 0.26.2__py3-none-any.whl → 0.27.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (299) hide show
  1. diffusers/__init__.py +20 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +7 -3
  7. diffusers/dependency_versions_check.py +1 -1
  8. diffusers/dependency_versions_table.py +2 -2
  9. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  10. diffusers/image_processor.py +110 -4
  11. diffusers/loaders/autoencoder.py +28 -8
  12. diffusers/loaders/controlnet.py +17 -8
  13. diffusers/loaders/ip_adapter.py +86 -23
  14. diffusers/loaders/lora.py +105 -310
  15. diffusers/loaders/lora_conversion_utils.py +1 -1
  16. diffusers/loaders/peft.py +1 -1
  17. diffusers/loaders/single_file.py +51 -12
  18. diffusers/loaders/single_file_utils.py +278 -49
  19. diffusers/loaders/textual_inversion.py +23 -4
  20. diffusers/loaders/unet.py +195 -41
  21. diffusers/loaders/utils.py +1 -1
  22. diffusers/models/__init__.py +3 -1
  23. diffusers/models/activations.py +9 -9
  24. diffusers/models/attention.py +26 -36
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +171 -114
  27. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  28. diffusers/models/autoencoders/autoencoder_kl.py +3 -1
  29. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  30. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  31. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  32. diffusers/models/autoencoders/vae.py +1 -1
  33. diffusers/models/controlnet.py +1 -1
  34. diffusers/models/controlnet_flax.py +1 -1
  35. diffusers/models/downsampling.py +8 -12
  36. diffusers/models/dual_transformer_2d.py +1 -1
  37. diffusers/models/embeddings.py +3 -4
  38. diffusers/models/embeddings_flax.py +1 -1
  39. diffusers/models/lora.py +33 -10
  40. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  41. diffusers/models/modeling_flax_utils.py +1 -1
  42. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  43. diffusers/models/modeling_utils.py +4 -6
  44. diffusers/models/normalization.py +1 -1
  45. diffusers/models/resnet.py +31 -58
  46. diffusers/models/resnet_flax.py +1 -1
  47. diffusers/models/t5_film_transformer.py +1 -1
  48. diffusers/models/transformer_2d.py +1 -1
  49. diffusers/models/transformer_temporal.py +1 -1
  50. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  51. diffusers/models/transformers/t5_film_transformer.py +1 -1
  52. diffusers/models/transformers/transformer_2d.py +29 -31
  53. diffusers/models/transformers/transformer_temporal.py +1 -1
  54. diffusers/models/unet_1d.py +1 -1
  55. diffusers/models/unet_1d_blocks.py +1 -1
  56. diffusers/models/unet_2d.py +1 -1
  57. diffusers/models/unet_2d_blocks.py +1 -1
  58. diffusers/models/unet_2d_condition.py +1 -1
  59. diffusers/models/unets/__init__.py +1 -0
  60. diffusers/models/unets/unet_1d.py +1 -1
  61. diffusers/models/unets/unet_1d_blocks.py +1 -1
  62. diffusers/models/unets/unet_2d.py +4 -4
  63. diffusers/models/unets/unet_2d_blocks.py +238 -98
  64. diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
  65. diffusers/models/unets/unet_2d_condition.py +420 -323
  66. diffusers/models/unets/unet_2d_condition_flax.py +21 -12
  67. diffusers/models/unets/unet_3d_blocks.py +50 -40
  68. diffusers/models/unets/unet_3d_condition.py +47 -8
  69. diffusers/models/unets/unet_i2vgen_xl.py +75 -30
  70. diffusers/models/unets/unet_kandinsky3.py +1 -1
  71. diffusers/models/unets/unet_motion_model.py +48 -8
  72. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  73. diffusers/models/unets/unet_stable_cascade.py +610 -0
  74. diffusers/models/unets/uvit_2d.py +1 -1
  75. diffusers/models/upsampling.py +10 -16
  76. diffusers/models/vae_flax.py +1 -1
  77. diffusers/models/vq_model.py +1 -1
  78. diffusers/optimization.py +1 -1
  79. diffusers/pipelines/__init__.py +26 -0
  80. diffusers/pipelines/amused/pipeline_amused.py +1 -1
  81. diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
  82. diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
  83. diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
  84. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
  85. diffusers/pipelines/animatediff/pipeline_output.py +7 -6
  86. diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
  87. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  88. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
  89. diffusers/pipelines/auto_pipeline.py +7 -16
  90. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  93. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  94. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  95. diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
  96. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  97. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
  98. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
  99. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
  100. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
  101. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
  102. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  103. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
  104. diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
  105. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  106. diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
  107. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
  108. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
  109. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
  110. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
  111. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
  112. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
  113. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
  114. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  115. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
  116. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  117. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
  118. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  119. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  120. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  121. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  122. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  123. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  124. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
  125. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
  126. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
  127. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
  128. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
  129. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  130. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
  131. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  132. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  133. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
  134. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  135. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  136. diffusers/pipelines/free_init_utils.py +184 -0
  137. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
  138. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
  139. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  140. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
  141. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
  142. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
  143. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
  145. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
  146. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
  147. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  148. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/ledits_pp/__init__.py +55 -0
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
  155. diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
  156. diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
  157. diffusers/pipelines/onnx_utils.py +1 -1
  158. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  159. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
  160. diffusers/pipelines/pia/pipeline_pia.py +168 -327
  161. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  162. diffusers/pipelines/pipeline_loading_utils.py +508 -0
  163. diffusers/pipelines/pipeline_utils.py +188 -534
  164. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
  165. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
  166. diffusers/pipelines/shap_e/camera.py +1 -1
  167. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  168. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  169. diffusers/pipelines/shap_e/renderer.py +1 -1
  170. diffusers/pipelines/stable_cascade/__init__.py +50 -0
  171. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
  172. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
  173. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
  174. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  175. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
  176. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  177. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
  178. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  179. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
  180. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
  181. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  182. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  183. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
  184. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  185. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
  186. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
  187. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
  188. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
  189. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
  190. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
  191. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
  192. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
  193. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  194. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  195. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  196. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
  197. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
  198. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
  199. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
  200. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
  201. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
  202. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
  203. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
  204. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
  205. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  206. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  208. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
  209. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
  210. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
  211. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
  212. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
  213. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
  214. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
  215. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
  216. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
  217. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
  218. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
  219. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
  220. diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
  221. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
  222. diffusers/pipelines/unclip/text_proj.py +1 -1
  223. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
  224. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  225. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
  226. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
  227. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
  228. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  229. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
  230. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
  231. diffusers/schedulers/__init__.py +7 -1
  232. diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
  233. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  234. diffusers/schedulers/scheduling_consistency_models.py +42 -19
  235. diffusers/schedulers/scheduling_ddim.py +2 -4
  236. diffusers/schedulers/scheduling_ddim_flax.py +13 -5
  237. diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
  238. diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
  239. diffusers/schedulers/scheduling_ddpm.py +2 -4
  240. diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
  241. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
  242. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
  243. diffusers/schedulers/scheduling_deis_multistep.py +46 -19
  244. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
  245. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
  246. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
  247. diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
  248. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +52 -21
  249. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
  250. diffusers/schedulers/scheduling_edm_euler.py +381 -0
  251. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
  252. diffusers/schedulers/scheduling_euler_discrete.py +42 -17
  253. diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
  254. diffusers/schedulers/scheduling_heun_discrete.py +35 -35
  255. diffusers/schedulers/scheduling_ipndm.py +37 -11
  256. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
  257. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
  258. diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
  259. diffusers/schedulers/scheduling_lcm.py +38 -14
  260. diffusers/schedulers/scheduling_lms_discrete.py +43 -15
  261. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  262. diffusers/schedulers/scheduling_pndm.py +2 -4
  263. diffusers/schedulers/scheduling_pndm_flax.py +2 -4
  264. diffusers/schedulers/scheduling_repaint.py +1 -1
  265. diffusers/schedulers/scheduling_sasolver.py +41 -9
  266. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  267. diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
  268. diffusers/schedulers/scheduling_tcd.py +686 -0
  269. diffusers/schedulers/scheduling_unclip.py +1 -1
  270. diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
  271. diffusers/schedulers/scheduling_utils.py +2 -1
  272. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  273. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  274. diffusers/training_utils.py +9 -2
  275. diffusers/utils/__init__.py +2 -1
  276. diffusers/utils/accelerate_utils.py +1 -1
  277. diffusers/utils/constants.py +1 -1
  278. diffusers/utils/doc_utils.py +1 -1
  279. diffusers/utils/dummy_pt_objects.py +60 -0
  280. diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
  281. diffusers/utils/dynamic_modules_utils.py +1 -1
  282. diffusers/utils/export_utils.py +3 -3
  283. diffusers/utils/hub_utils.py +60 -16
  284. diffusers/utils/import_utils.py +15 -1
  285. diffusers/utils/loading_utils.py +2 -0
  286. diffusers/utils/logging.py +1 -1
  287. diffusers/utils/model_card_template.md +24 -0
  288. diffusers/utils/outputs.py +14 -7
  289. diffusers/utils/peft_utils.py +1 -1
  290. diffusers/utils/state_dict_utils.py +1 -1
  291. diffusers/utils/testing_utils.py +2 -0
  292. diffusers/utils/torch_utils.py +1 -1
  293. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/METADATA +5 -5
  294. diffusers-0.27.0.dist-info/RECORD +399 -0
  295. diffusers-0.26.2.dist-info/RECORD +0 -384
  296. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
  297. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/WHEEL +0 -0
  298. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
  299. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team.
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
3
  # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,7 +19,6 @@ import inspect
19
19
  import os
20
20
  import re
21
21
  import sys
22
- import warnings
23
22
  from dataclasses import dataclass
24
23
  from pathlib import Path
25
24
  from typing import Any, Callable, Dict, List, Optional, Union
@@ -42,21 +41,20 @@ from tqdm.auto import tqdm
42
41
 
43
42
  from .. import __version__
44
43
  from ..configuration_utils import ConfigMixin
44
+ from ..models import AutoencoderKL
45
+ from ..models.attention_processor import FusedAttnProcessor2_0
45
46
  from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
46
47
  from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
47
48
  from ..utils import (
48
49
  CONFIG_NAME,
49
50
  DEPRECATED_REVISION_ARGS,
50
- SAFETENSORS_WEIGHTS_NAME,
51
- WEIGHTS_NAME,
52
51
  BaseOutput,
52
+ PushToHubMixin,
53
53
  deprecate,
54
- get_class_from_dynamic_module,
55
54
  is_accelerate_available,
56
55
  is_accelerate_version,
57
- is_peft_available,
56
+ is_torch_npu_available,
58
57
  is_torch_version,
59
- is_transformers_available,
60
58
  logging,
61
59
  numpy_to_pil,
62
60
  )
@@ -64,55 +62,37 @@ from ..utils.hub_utils import load_or_create_model_card, populate_model_card
64
62
  from ..utils.torch_utils import is_compiled_module
65
63
 
66
64
 
67
- if is_transformers_available():
68
- import transformers
69
- from transformers import PreTrainedModel
70
- from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
71
- from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
72
- from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
65
+ if is_torch_npu_available():
66
+ import torch_npu # noqa: F401
73
67
 
74
- from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, PushToHubMixin
68
+
69
+ from .pipeline_loading_utils import (
70
+ ALL_IMPORTABLE_CLASSES,
71
+ CONNECTED_PIPES_KEYS,
72
+ CUSTOM_PIPELINE_FILE_NAME,
73
+ LOADABLE_CLASSES,
74
+ _fetch_class_library_tuple,
75
+ _get_pipeline_class,
76
+ _unwrap_model,
77
+ is_safetensors_compatible,
78
+ load_sub_model,
79
+ maybe_raise_or_warn,
80
+ variant_compatible_siblings,
81
+ warn_deprecated_model_variant,
82
+ )
75
83
 
76
84
 
77
85
  if is_accelerate_available():
78
86
  import accelerate
79
87
 
80
88
 
81
- INDEX_FILE = "diffusion_pytorch_model.bin"
82
- CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
83
- DUMMY_MODULES_FOLDER = "diffusers.utils"
84
- TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
85
- CONNECTED_PIPES_KEYS = ["prior"]
86
-
89
+ LIBRARIES = []
90
+ for library in LOADABLE_CLASSES:
91
+ LIBRARIES.append(library)
87
92
 
88
93
  logger = logging.get_logger(__name__)
89
94
 
90
95
 
91
- LOADABLE_CLASSES = {
92
- "diffusers": {
93
- "ModelMixin": ["save_pretrained", "from_pretrained"],
94
- "SchedulerMixin": ["save_pretrained", "from_pretrained"],
95
- "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
96
- "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
97
- },
98
- "transformers": {
99
- "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
100
- "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
101
- "PreTrainedModel": ["save_pretrained", "from_pretrained"],
102
- "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
103
- "ProcessorMixin": ["save_pretrained", "from_pretrained"],
104
- "ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
105
- },
106
- "onnxruntime.training": {
107
- "ORTModule": ["save_pretrained", "from_pretrained"],
108
- },
109
- }
110
-
111
- ALL_IMPORTABLE_CLASSES = {}
112
- for library in LOADABLE_CLASSES:
113
- ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
114
-
115
-
116
96
  @dataclass
117
97
  class ImagePipelineOutput(BaseOutput):
118
98
  """
@@ -140,432 +120,6 @@ class AudioPipelineOutput(BaseOutput):
140
120
  audios: np.ndarray
141
121
 
142
122
 
143
- def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
144
- """
145
- Checking for safetensors compatibility:
146
- - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
147
- files to know which safetensors files are needed.
148
- - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
149
-
150
- Converting default pytorch serialized filenames to safetensors serialized filenames:
151
- - For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
152
- - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
153
- extension is replaced with ".safetensors"
154
- """
155
- pt_filenames = []
156
-
157
- sf_filenames = set()
158
-
159
- passed_components = passed_components or []
160
-
161
- for filename in filenames:
162
- _, extension = os.path.splitext(filename)
163
-
164
- if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
165
- continue
166
-
167
- if extension == ".bin":
168
- pt_filenames.append(os.path.normpath(filename))
169
- elif extension == ".safetensors":
170
- sf_filenames.add(os.path.normpath(filename))
171
-
172
- for filename in pt_filenames:
173
- # filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extention = '.bam'
174
- path, filename = os.path.split(filename)
175
- filename, extension = os.path.splitext(filename)
176
-
177
- if filename.startswith("pytorch_model"):
178
- filename = filename.replace("pytorch_model", "model")
179
- else:
180
- filename = filename
181
-
182
- expected_sf_filename = os.path.normpath(os.path.join(path, filename))
183
- expected_sf_filename = f"{expected_sf_filename}.safetensors"
184
- if expected_sf_filename not in sf_filenames:
185
- logger.warning(f"{expected_sf_filename} not found")
186
- return False
187
-
188
- return True
189
-
190
-
191
- def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
192
- weight_names = [
193
- WEIGHTS_NAME,
194
- SAFETENSORS_WEIGHTS_NAME,
195
- FLAX_WEIGHTS_NAME,
196
- ONNX_WEIGHTS_NAME,
197
- ONNX_EXTERNAL_WEIGHTS_NAME,
198
- ]
199
-
200
- if is_transformers_available():
201
- weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
202
-
203
- # model_pytorch, diffusion_model_pytorch, ...
204
- weight_prefixes = [w.split(".")[0] for w in weight_names]
205
- # .bin, .safetensors, ...
206
- weight_suffixs = [w.split(".")[-1] for w in weight_names]
207
- # -00001-of-00002
208
- transformers_index_format = r"\d{5}-of-\d{5}"
209
-
210
- if variant is not None:
211
- # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
212
- variant_file_re = re.compile(
213
- rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
214
- )
215
- # `text_encoder/pytorch_model.bin.index.fp16.json`
216
- variant_index_re = re.compile(
217
- rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
218
- )
219
-
220
- # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
221
- non_variant_file_re = re.compile(
222
- rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
223
- )
224
- # `text_encoder/pytorch_model.bin.index.json`
225
- non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
226
-
227
- if variant is not None:
228
- variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
229
- variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
230
- variant_filenames = variant_weights | variant_indexes
231
- else:
232
- variant_filenames = set()
233
-
234
- non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
235
- non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
236
- non_variant_filenames = non_variant_weights | non_variant_indexes
237
-
238
- # all variant filenames will be used by default
239
- usable_filenames = set(variant_filenames)
240
-
241
- def convert_to_variant(filename):
242
- if "index" in filename:
243
- variant_filename = filename.replace("index", f"index.{variant}")
244
- elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
245
- variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
246
- else:
247
- variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
248
- return variant_filename
249
-
250
- for f in non_variant_filenames:
251
- variant_filename = convert_to_variant(f)
252
- if variant_filename not in usable_filenames:
253
- usable_filenames.add(f)
254
-
255
- return usable_filenames, variant_filenames
256
-
257
-
258
- @validate_hf_hub_args
259
- def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
260
- info = model_info(
261
- pretrained_model_name_or_path,
262
- token=token,
263
- revision=None,
264
- )
265
- filenames = {sibling.rfilename for sibling in info.siblings}
266
- comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
267
- comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]
268
-
269
- if set(model_filenames).issubset(set(comp_model_filenames)):
270
- warnings.warn(
271
- f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
272
- FutureWarning,
273
- )
274
- else:
275
- warnings.warn(
276
- f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
277
- FutureWarning,
278
- )
279
-
280
-
281
- def _unwrap_model(model):
282
- """Unwraps a model."""
283
- if is_compiled_module(model):
284
- model = model._orig_mod
285
-
286
- if is_peft_available():
287
- from peft import PeftModel
288
-
289
- if isinstance(model, PeftModel):
290
- model = model.base_model.model
291
-
292
- return model
293
-
294
-
295
- def maybe_raise_or_warn(
296
- library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
297
- ):
298
- """Simple helper method to raise or warn in case incorrect module has been passed"""
299
- if not is_pipeline_module:
300
- library = importlib.import_module(library_name)
301
- class_obj = getattr(library, class_name)
302
- class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
303
-
304
- expected_class_obj = None
305
- for class_name, class_candidate in class_candidates.items():
306
- if class_candidate is not None and issubclass(class_obj, class_candidate):
307
- expected_class_obj = class_candidate
308
-
309
- # Dynamo wraps the original model in a private class.
310
- # I didn't find a public API to get the original class.
311
- sub_model = passed_class_obj[name]
312
- unwrapped_sub_model = _unwrap_model(sub_model)
313
- model_cls = unwrapped_sub_model.__class__
314
-
315
- if not issubclass(model_cls, expected_class_obj):
316
- raise ValueError(
317
- f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
318
- )
319
- else:
320
- logger.warning(
321
- f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
322
- " has the correct type"
323
- )
324
-
325
-
326
- def get_class_obj_and_candidates(
327
- library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
328
- ):
329
- """Simple helper method to retrieve class object of module as well as potential parent class objects"""
330
- component_folder = os.path.join(cache_dir, component_name)
331
-
332
- if is_pipeline_module:
333
- pipeline_module = getattr(pipelines, library_name)
334
-
335
- class_obj = getattr(pipeline_module, class_name)
336
- class_candidates = {c: class_obj for c in importable_classes.keys()}
337
- elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
338
- # load custom component
339
- class_obj = get_class_from_dynamic_module(
340
- component_folder, module_file=library_name + ".py", class_name=class_name
341
- )
342
- class_candidates = {c: class_obj for c in importable_classes.keys()}
343
- else:
344
- # else we just import it from the library.
345
- library = importlib.import_module(library_name)
346
-
347
- class_obj = getattr(library, class_name)
348
- class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
349
-
350
- return class_obj, class_candidates
351
-
352
-
353
- def _get_pipeline_class(
354
- class_obj,
355
- config=None,
356
- load_connected_pipeline=False,
357
- custom_pipeline=None,
358
- repo_id=None,
359
- hub_revision=None,
360
- class_name=None,
361
- cache_dir=None,
362
- revision=None,
363
- ):
364
- if custom_pipeline is not None:
365
- if custom_pipeline.endswith(".py"):
366
- path = Path(custom_pipeline)
367
- # decompose into folder & file
368
- file_name = path.name
369
- custom_pipeline = path.parent.absolute()
370
- elif repo_id is not None:
371
- file_name = f"{custom_pipeline}.py"
372
- custom_pipeline = repo_id
373
- else:
374
- file_name = CUSTOM_PIPELINE_FILE_NAME
375
-
376
- if repo_id is not None and hub_revision is not None:
377
- # if we load the pipeline code from the Hub
378
- # make sure to overwrite the `revison`
379
- revision = hub_revision
380
-
381
- return get_class_from_dynamic_module(
382
- custom_pipeline,
383
- module_file=file_name,
384
- class_name=class_name,
385
- cache_dir=cache_dir,
386
- revision=revision,
387
- )
388
-
389
- if class_obj != DiffusionPipeline:
390
- return class_obj
391
-
392
- diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
393
- class_name = class_name or config["_class_name"]
394
- if not class_name:
395
- raise ValueError(
396
- "The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`."
397
- )
398
-
399
- class_name = class_name[4:] if class_name.startswith("Flax") else class_name
400
-
401
- pipeline_cls = getattr(diffusers_module, class_name)
402
-
403
- if load_connected_pipeline:
404
- from .auto_pipeline import _get_connected_pipeline
405
-
406
- connected_pipeline_cls = _get_connected_pipeline(pipeline_cls)
407
- if connected_pipeline_cls is not None:
408
- logger.info(
409
- f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`"
410
- )
411
- else:
412
- logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.")
413
-
414
- pipeline_cls = connected_pipeline_cls or pipeline_cls
415
-
416
- return pipeline_cls
417
-
418
-
419
- def load_sub_model(
420
- library_name: str,
421
- class_name: str,
422
- importable_classes: List[Any],
423
- pipelines: Any,
424
- is_pipeline_module: bool,
425
- pipeline_class: Any,
426
- torch_dtype: torch.dtype,
427
- provider: Any,
428
- sess_options: Any,
429
- device_map: Optional[Union[Dict[str, torch.device], str]],
430
- max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
431
- offload_folder: Optional[Union[str, os.PathLike]],
432
- offload_state_dict: bool,
433
- model_variants: Dict[str, str],
434
- name: str,
435
- from_flax: bool,
436
- variant: str,
437
- low_cpu_mem_usage: bool,
438
- cached_folder: Union[str, os.PathLike],
439
- revision: str = None,
440
- ):
441
- """Helper method to load the module `name` from `library_name` and `class_name`"""
442
- # retrieve class candidates
443
- class_obj, class_candidates = get_class_obj_and_candidates(
444
- library_name,
445
- class_name,
446
- importable_classes,
447
- pipelines,
448
- is_pipeline_module,
449
- component_name=name,
450
- cache_dir=cached_folder,
451
- )
452
-
453
- load_method_name = None
454
- # retrive load method name
455
- for class_name, class_candidate in class_candidates.items():
456
- if class_candidate is not None and issubclass(class_obj, class_candidate):
457
- load_method_name = importable_classes[class_name][1]
458
-
459
- # if load method name is None, then we have a dummy module -> raise Error
460
- if load_method_name is None:
461
- none_module = class_obj.__module__
462
- is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
463
- TRANSFORMERS_DUMMY_MODULES_FOLDER
464
- )
465
- if is_dummy_path and "dummy" in none_module:
466
- # call class_obj for nice error message of missing requirements
467
- class_obj()
468
-
469
- raise ValueError(
470
- f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
471
- f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
472
- )
473
-
474
- load_method = getattr(class_obj, load_method_name)
475
-
476
- # add kwargs to loading method
477
- diffusers_module = importlib.import_module(__name__.split(".")[0])
478
- loading_kwargs = {}
479
- if issubclass(class_obj, torch.nn.Module):
480
- loading_kwargs["torch_dtype"] = torch_dtype
481
- if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
482
- loading_kwargs["provider"] = provider
483
- loading_kwargs["sess_options"] = sess_options
484
-
485
- is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
486
-
487
- if is_transformers_available():
488
- transformers_version = version.parse(version.parse(transformers.__version__).base_version)
489
- else:
490
- transformers_version = "N/A"
491
-
492
- is_transformers_model = (
493
- is_transformers_available()
494
- and issubclass(class_obj, PreTrainedModel)
495
- and transformers_version >= version.parse("4.20.0")
496
- )
497
-
498
- # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
499
- # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
500
- # This makes sure that the weights won't be initialized which significantly speeds up loading.
501
- if is_diffusers_model or is_transformers_model:
502
- loading_kwargs["device_map"] = device_map
503
- loading_kwargs["max_memory"] = max_memory
504
- loading_kwargs["offload_folder"] = offload_folder
505
- loading_kwargs["offload_state_dict"] = offload_state_dict
506
- loading_kwargs["variant"] = model_variants.pop(name, None)
507
- if from_flax:
508
- loading_kwargs["from_flax"] = True
509
-
510
- # the following can be deleted once the minimum required `transformers` version
511
- # is higher than 4.27
512
- if (
513
- is_transformers_model
514
- and loading_kwargs["variant"] is not None
515
- and transformers_version < version.parse("4.27.0")
516
- ):
517
- raise ImportError(
518
- f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
519
- )
520
- elif is_transformers_model and loading_kwargs["variant"] is None:
521
- loading_kwargs.pop("variant")
522
-
523
- # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
524
- if not (from_flax and is_transformers_model):
525
- loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
526
- else:
527
- loading_kwargs["low_cpu_mem_usage"] = False
528
-
529
- # check if the module is in a subdirectory
530
- if os.path.isdir(os.path.join(cached_folder, name)):
531
- loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
532
- else:
533
- # else load from the root directory
534
- loaded_sub_model = load_method(cached_folder, **loading_kwargs)
535
-
536
- return loaded_sub_model
537
-
538
-
539
- def _fetch_class_library_tuple(module):
540
- # import it here to avoid circular import
541
- diffusers_module = importlib.import_module(__name__.split(".")[0])
542
- pipelines = getattr(diffusers_module, "pipelines")
543
-
544
- # register the config from the original module, not the dynamo compiled one
545
- not_compiled_module = _unwrap_model(module)
546
- library = not_compiled_module.__module__.split(".")[0]
547
-
548
- # check if the module is a pipeline module
549
- module_path_items = not_compiled_module.__module__.split(".")
550
- pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
551
-
552
- path = not_compiled_module.__module__.split(".")
553
- is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
554
-
555
- # if library is not in LOADABLE_CLASSES, then it is a custom module.
556
- # Or if it's a pipeline module, then the module is inside the pipeline
557
- # folder so we set the library to module name.
558
- if is_pipeline_module:
559
- library = pipeline_dir
560
- elif library not in LOADABLE_CLASSES:
561
- library = not_compiled_module.__module__
562
-
563
- # retrieve class_name
564
- class_name = not_compiled_module.__class__.__name__
565
-
566
- return (library, class_name)
567
-
568
-
569
123
  class DiffusionPipeline(ConfigMixin, PushToHubMixin):
570
124
  r"""
571
125
  Base class for all pipelines.
@@ -702,7 +256,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
702
256
  break
703
257
 
704
258
  if save_method_name is None:
705
- logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
259
+ logger.warning(
260
+ f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved."
261
+ )
706
262
  # make sure that unsaveable components are not tried to be loaded afterward
707
263
  self.register_to_config(**{pipeline_component_name: (None, None)})
708
264
  continue
@@ -775,32 +331,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
775
331
  Returns:
776
332
  [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`.
777
333
  """
778
-
779
- torch_dtype = kwargs.pop("torch_dtype", None)
780
- if torch_dtype is not None:
781
- deprecate("torch_dtype", "0.27.0", "")
782
- torch_device = kwargs.pop("torch_device", None)
783
- if torch_device is not None:
784
- deprecate("torch_device", "0.27.0", "")
785
-
786
- dtype_kwarg = kwargs.pop("dtype", None)
787
- device_kwarg = kwargs.pop("device", None)
334
+ dtype = kwargs.pop("dtype", None)
335
+ device = kwargs.pop("device", None)
788
336
  silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)
789
337
 
790
- if torch_dtype is not None and dtype_kwarg is not None:
791
- raise ValueError(
792
- "You have passed both `torch_dtype` and `dtype` as a keyword argument. Please make sure to only pass `dtype`."
793
- )
794
-
795
- dtype = torch_dtype or dtype_kwarg
796
-
797
- if torch_device is not None and device_kwarg is not None:
798
- raise ValueError(
799
- "You have passed both `torch_device` and `device` as a keyword argument. Please make sure to only pass `device`."
800
- )
801
-
802
- device = torch_device or device_kwarg
803
-
804
338
  dtype_arg = None
805
339
  device_arg = None
806
340
  if len(args) == 1:
@@ -873,12 +407,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
873
407
 
874
408
  if is_loaded_in_8bit and dtype is not None:
875
409
  logger.warning(
876
- f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision."
410
+ f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision."
877
411
  )
878
412
 
879
413
  if is_loaded_in_8bit and device is not None:
880
414
  logger.warning(
881
- f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}."
415
+ f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}."
882
416
  )
883
417
  else:
884
418
  module.to(device, dtype)
@@ -1003,10 +537,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1003
537
  revision (`str`, *optional*, defaults to `"main"`):
1004
538
  The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1005
539
  allowed by Git.
1006
- custom_revision (`str`, *optional*, defaults to `"main"`):
540
+ custom_revision (`str`, *optional*):
1007
541
  The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
1008
- `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
1009
- custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
542
+ `revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers version.
1010
543
  mirror (`str`, *optional*):
1011
544
  Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
1012
545
  guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
@@ -1100,6 +633,33 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1100
633
  use_onnx = kwargs.pop("use_onnx", None)
1101
634
  load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
1102
635
 
636
+ if low_cpu_mem_usage and not is_accelerate_available():
637
+ low_cpu_mem_usage = False
638
+ logger.warning(
639
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
640
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
641
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
642
+ " install accelerate\n```\n."
643
+ )
644
+
645
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
646
+ raise NotImplementedError(
647
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
648
+ " `device_map=None`."
649
+ )
650
+
651
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
652
+ raise NotImplementedError(
653
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
654
+ " `low_cpu_mem_usage=False`."
655
+ )
656
+
657
+ if low_cpu_mem_usage is False and device_map is not None:
658
+ raise ValueError(
659
+ f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
660
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
661
+ )
662
+
1103
663
  # 1. Download the checkpoints and configs
1104
664
  # use snapshot download here to get it working from from_pretrained
1105
665
  if not os.path.isdir(pretrained_model_name_or_path):
@@ -1232,33 +792,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1232
792
  f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
1233
793
  )
1234
794
 
1235
- if low_cpu_mem_usage and not is_accelerate_available():
1236
- low_cpu_mem_usage = False
1237
- logger.warning(
1238
- "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
1239
- " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
1240
- " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
1241
- " install accelerate\n```\n."
1242
- )
1243
-
1244
- if device_map is not None and not is_torch_version(">=", "1.9.0"):
1245
- raise NotImplementedError(
1246
- "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1247
- " `device_map=None`."
1248
- )
1249
-
1250
- if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
1251
- raise NotImplementedError(
1252
- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
1253
- " `low_cpu_mem_usage=False`."
1254
- )
1255
-
1256
- if low_cpu_mem_usage is False and device_map is not None:
1257
- raise ValueError(
1258
- f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
1259
- " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
1260
- )
1261
-
1262
795
  # import it here to avoid circular import
1263
796
  from diffusers import pipelines
1264
797
 
@@ -1303,7 +836,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1303
836
  variant=variant,
1304
837
  low_cpu_mem_usage=low_cpu_mem_usage,
1305
838
  cached_folder=cached_folder,
1306
- revision=revision,
1307
839
  )
1308
840
  logger.info(
1309
841
  f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -1445,6 +977,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1445
977
 
1446
978
  device_type = torch_device.type
1447
979
  device = torch.device(f"{device_type}:{self._offload_gpu_id}")
980
+ self._offload_device = device
1448
981
 
1449
982
  if self.device.type != "cpu":
1450
983
  self.to("cpu", silence_dtype_warnings=True)
@@ -1494,7 +1027,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1494
1027
  hook.remove()
1495
1028
 
1496
1029
  # make sure the model is in the same state as before calling it
1497
- self.enable_model_cpu_offload()
1030
+ self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
1498
1031
 
1499
1032
  def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
1500
1033
  r"""
@@ -1530,6 +1063,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1530
1063
 
1531
1064
  device_type = torch_device.type
1532
1065
  device = torch.device(f"{device_type}:{self._offload_gpu_id}")
1066
+ self._offload_device = device
1533
1067
 
1534
1068
  if self.device.type != "cpu":
1535
1069
  self.to("cpu", silence_dtype_warnings=True)
@@ -1670,7 +1204,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1670
1204
  try:
1671
1205
  info = model_info(pretrained_model_name, token=token, revision=revision)
1672
1206
  except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
1673
- logger.warn(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
1207
+ logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
1674
1208
  local_files_only = True
1675
1209
  model_info_call_error = e # save error to reraise it if model is not cached locally
1676
1210
 
@@ -1821,7 +1355,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1821
1355
  len(safetensors_variant_filenames) > 0
1822
1356
  and safetensors_model_filenames != safetensors_variant_filenames
1823
1357
  ):
1824
- logger.warn(
1358
+ logger.warning(
1825
1359
  f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
1826
1360
  )
1827
1361
  else:
@@ -1834,7 +1368,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1834
1368
  bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
1835
1369
  bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
1836
1370
  if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
1837
- logger.warn(
1371
+ logger.warning(
1838
1372
  f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
1839
1373
  )
1840
1374
 
@@ -1918,7 +1452,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1918
1452
  else:
1919
1453
  # 2. we forced `local_files_only=True` when `model_info` failed
1920
1454
  raise EnvironmentError(
1921
- f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occured"
1455
+ f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred"
1922
1456
  " while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace"
1923
1457
  " above."
1924
1458
  ) from model_info_call_error
@@ -2115,3 +1649,123 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
2115
1649
 
2116
1650
  for module in modules:
2117
1651
  module.set_attention_slice(slice_size)
1652
+
1653
+
1654
+ class StableDiffusionMixin:
1655
+ r"""
1656
+ Helper for DiffusionPipeline with vae and unet.(mainly for LDM such as stable diffusion)
1657
+ """
1658
+
1659
+ def enable_vae_slicing(self):
1660
+ r"""
1661
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1662
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1663
+ """
1664
+ self.vae.enable_slicing()
1665
+
1666
+ def disable_vae_slicing(self):
1667
+ r"""
1668
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
1669
+ computing decoding in one step.
1670
+ """
1671
+ self.vae.disable_slicing()
1672
+
1673
+ def enable_vae_tiling(self):
1674
+ r"""
1675
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
1676
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
1677
+ processing larger images.
1678
+ """
1679
+ self.vae.enable_tiling()
1680
+
1681
+ def disable_vae_tiling(self):
1682
+ r"""
1683
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
1684
+ computing decoding in one step.
1685
+ """
1686
+ self.vae.disable_tiling()
1687
+
1688
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
1689
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
1690
+
1691
+ The suffixes after the scaling factors represent the stages where they are being applied.
1692
+
1693
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
1694
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
1695
+
1696
+ Args:
1697
+ s1 (`float`):
1698
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
1699
+ mitigate "oversmoothing effect" in the enhanced denoising process.
1700
+ s2 (`float`):
1701
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
1702
+ mitigate "oversmoothing effect" in the enhanced denoising process.
1703
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
1704
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
1705
+ """
1706
+ if not hasattr(self, "unet"):
1707
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
1708
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
1709
+
1710
+ def disable_freeu(self):
1711
+ """Disables the FreeU mechanism if enabled."""
1712
+ self.unet.disable_freeu()
1713
+
1714
+ def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
1715
+ """
1716
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
1717
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
1718
+
1719
+ <Tip warning={true}>
1720
+
1721
+ This API is 🧪 experimental.
1722
+
1723
+ </Tip>
1724
+
1725
+ Args:
1726
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
1727
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
1728
+ """
1729
+ self.fusing_unet = False
1730
+ self.fusing_vae = False
1731
+
1732
+ if unet:
1733
+ self.fusing_unet = True
1734
+ self.unet.fuse_qkv_projections()
1735
+ self.unet.set_attn_processor(FusedAttnProcessor2_0())
1736
+
1737
+ if vae:
1738
+ if not isinstance(self.vae, AutoencoderKL):
1739
+ raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
1740
+
1741
+ self.fusing_vae = True
1742
+ self.vae.fuse_qkv_projections()
1743
+ self.vae.set_attn_processor(FusedAttnProcessor2_0())
1744
+
1745
+ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
1746
+ """Disable QKV projection fusion if enabled.
1747
+
1748
+ <Tip warning={true}>
1749
+
1750
+ This API is 🧪 experimental.
1751
+
1752
+ </Tip>
1753
+
1754
+ Args:
1755
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
1756
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
1757
+
1758
+ """
1759
+ if unet:
1760
+ if not self.fusing_unet:
1761
+ logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
1762
+ else:
1763
+ self.unet.unfuse_qkv_projections()
1764
+ self.fusing_unet = False
1765
+
1766
+ if vae:
1767
+ if not self.fusing_vae:
1768
+ logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
1769
+ else:
1770
+ self.vae.unfuse_qkv_projections()
1771
+ self.fusing_vae = False