diffusers 0.26.2__py3-none-any.whl → 0.27.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (299) hide show
  1. diffusers/__init__.py +20 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +7 -3
  7. diffusers/dependency_versions_check.py +1 -1
  8. diffusers/dependency_versions_table.py +2 -2
  9. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  10. diffusers/image_processor.py +110 -4
  11. diffusers/loaders/autoencoder.py +28 -8
  12. diffusers/loaders/controlnet.py +17 -8
  13. diffusers/loaders/ip_adapter.py +86 -23
  14. diffusers/loaders/lora.py +105 -310
  15. diffusers/loaders/lora_conversion_utils.py +1 -1
  16. diffusers/loaders/peft.py +1 -1
  17. diffusers/loaders/single_file.py +51 -12
  18. diffusers/loaders/single_file_utils.py +278 -49
  19. diffusers/loaders/textual_inversion.py +23 -4
  20. diffusers/loaders/unet.py +195 -41
  21. diffusers/loaders/utils.py +1 -1
  22. diffusers/models/__init__.py +3 -1
  23. diffusers/models/activations.py +9 -9
  24. diffusers/models/attention.py +26 -36
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +171 -114
  27. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  28. diffusers/models/autoencoders/autoencoder_kl.py +3 -1
  29. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  30. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  31. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  32. diffusers/models/autoencoders/vae.py +1 -1
  33. diffusers/models/controlnet.py +1 -1
  34. diffusers/models/controlnet_flax.py +1 -1
  35. diffusers/models/downsampling.py +8 -12
  36. diffusers/models/dual_transformer_2d.py +1 -1
  37. diffusers/models/embeddings.py +3 -4
  38. diffusers/models/embeddings_flax.py +1 -1
  39. diffusers/models/lora.py +33 -10
  40. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  41. diffusers/models/modeling_flax_utils.py +1 -1
  42. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  43. diffusers/models/modeling_utils.py +4 -6
  44. diffusers/models/normalization.py +1 -1
  45. diffusers/models/resnet.py +31 -58
  46. diffusers/models/resnet_flax.py +1 -1
  47. diffusers/models/t5_film_transformer.py +1 -1
  48. diffusers/models/transformer_2d.py +1 -1
  49. diffusers/models/transformer_temporal.py +1 -1
  50. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  51. diffusers/models/transformers/t5_film_transformer.py +1 -1
  52. diffusers/models/transformers/transformer_2d.py +29 -31
  53. diffusers/models/transformers/transformer_temporal.py +1 -1
  54. diffusers/models/unet_1d.py +1 -1
  55. diffusers/models/unet_1d_blocks.py +1 -1
  56. diffusers/models/unet_2d.py +1 -1
  57. diffusers/models/unet_2d_blocks.py +1 -1
  58. diffusers/models/unet_2d_condition.py +1 -1
  59. diffusers/models/unets/__init__.py +1 -0
  60. diffusers/models/unets/unet_1d.py +1 -1
  61. diffusers/models/unets/unet_1d_blocks.py +1 -1
  62. diffusers/models/unets/unet_2d.py +4 -4
  63. diffusers/models/unets/unet_2d_blocks.py +238 -98
  64. diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
  65. diffusers/models/unets/unet_2d_condition.py +420 -323
  66. diffusers/models/unets/unet_2d_condition_flax.py +21 -12
  67. diffusers/models/unets/unet_3d_blocks.py +50 -40
  68. diffusers/models/unets/unet_3d_condition.py +47 -8
  69. diffusers/models/unets/unet_i2vgen_xl.py +75 -30
  70. diffusers/models/unets/unet_kandinsky3.py +1 -1
  71. diffusers/models/unets/unet_motion_model.py +48 -8
  72. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  73. diffusers/models/unets/unet_stable_cascade.py +610 -0
  74. diffusers/models/unets/uvit_2d.py +1 -1
  75. diffusers/models/upsampling.py +10 -16
  76. diffusers/models/vae_flax.py +1 -1
  77. diffusers/models/vq_model.py +1 -1
  78. diffusers/optimization.py +1 -1
  79. diffusers/pipelines/__init__.py +26 -0
  80. diffusers/pipelines/amused/pipeline_amused.py +1 -1
  81. diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
  82. diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
  83. diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
  84. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
  85. diffusers/pipelines/animatediff/pipeline_output.py +7 -6
  86. diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
  87. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  88. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
  89. diffusers/pipelines/auto_pipeline.py +7 -16
  90. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  93. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  94. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  95. diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
  96. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  97. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
  98. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
  99. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
  100. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
  101. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
  102. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  103. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
  104. diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
  105. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  106. diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
  107. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
  108. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
  109. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
  110. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
  111. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
  112. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
  113. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
  114. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  115. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
  116. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  117. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
  118. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  119. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  120. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  121. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  122. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  123. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  124. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
  125. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
  126. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
  127. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
  128. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
  129. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  130. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
  131. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  132. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  133. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
  134. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  135. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  136. diffusers/pipelines/free_init_utils.py +184 -0
  137. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
  138. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
  139. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  140. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
  141. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
  142. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
  143. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
  145. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
  146. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
  147. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  148. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/ledits_pp/__init__.py +55 -0
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
  155. diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
  156. diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
  157. diffusers/pipelines/onnx_utils.py +1 -1
  158. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  159. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
  160. diffusers/pipelines/pia/pipeline_pia.py +168 -327
  161. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  162. diffusers/pipelines/pipeline_loading_utils.py +508 -0
  163. diffusers/pipelines/pipeline_utils.py +188 -534
  164. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
  165. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
  166. diffusers/pipelines/shap_e/camera.py +1 -1
  167. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  168. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  169. diffusers/pipelines/shap_e/renderer.py +1 -1
  170. diffusers/pipelines/stable_cascade/__init__.py +50 -0
  171. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
  172. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
  173. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
  174. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  175. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
  176. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  177. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
  178. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  179. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
  180. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
  181. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  182. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  183. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
  184. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  185. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
  186. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
  187. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
  188. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
  189. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
  190. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
  191. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
  192. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
  193. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  194. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  195. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  196. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
  197. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
  198. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
  199. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
  200. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
  201. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
  202. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
  203. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
  204. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
  205. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  206. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  208. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
  209. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
  210. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
  211. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
  212. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
  213. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
  214. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
  215. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
  216. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
  217. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
  218. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
  219. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
  220. diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
  221. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
  222. diffusers/pipelines/unclip/text_proj.py +1 -1
  223. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
  224. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  225. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
  226. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
  227. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
  228. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  229. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
  230. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
  231. diffusers/schedulers/__init__.py +7 -1
  232. diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
  233. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  234. diffusers/schedulers/scheduling_consistency_models.py +42 -19
  235. diffusers/schedulers/scheduling_ddim.py +2 -4
  236. diffusers/schedulers/scheduling_ddim_flax.py +13 -5
  237. diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
  238. diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
  239. diffusers/schedulers/scheduling_ddpm.py +2 -4
  240. diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
  241. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
  242. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
  243. diffusers/schedulers/scheduling_deis_multistep.py +46 -19
  244. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
  245. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
  246. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
  247. diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
  248. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +52 -21
  249. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
  250. diffusers/schedulers/scheduling_edm_euler.py +381 -0
  251. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
  252. diffusers/schedulers/scheduling_euler_discrete.py +42 -17
  253. diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
  254. diffusers/schedulers/scheduling_heun_discrete.py +35 -35
  255. diffusers/schedulers/scheduling_ipndm.py +37 -11
  256. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
  257. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
  258. diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
  259. diffusers/schedulers/scheduling_lcm.py +38 -14
  260. diffusers/schedulers/scheduling_lms_discrete.py +43 -15
  261. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  262. diffusers/schedulers/scheduling_pndm.py +2 -4
  263. diffusers/schedulers/scheduling_pndm_flax.py +2 -4
  264. diffusers/schedulers/scheduling_repaint.py +1 -1
  265. diffusers/schedulers/scheduling_sasolver.py +41 -9
  266. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  267. diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
  268. diffusers/schedulers/scheduling_tcd.py +686 -0
  269. diffusers/schedulers/scheduling_unclip.py +1 -1
  270. diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
  271. diffusers/schedulers/scheduling_utils.py +2 -1
  272. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  273. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  274. diffusers/training_utils.py +9 -2
  275. diffusers/utils/__init__.py +2 -1
  276. diffusers/utils/accelerate_utils.py +1 -1
  277. diffusers/utils/constants.py +1 -1
  278. diffusers/utils/doc_utils.py +1 -1
  279. diffusers/utils/dummy_pt_objects.py +60 -0
  280. diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
  281. diffusers/utils/dynamic_modules_utils.py +1 -1
  282. diffusers/utils/export_utils.py +3 -3
  283. diffusers/utils/hub_utils.py +60 -16
  284. diffusers/utils/import_utils.py +15 -1
  285. diffusers/utils/loading_utils.py +2 -0
  286. diffusers/utils/logging.py +1 -1
  287. diffusers/utils/model_card_template.md +24 -0
  288. diffusers/utils/outputs.py +14 -7
  289. diffusers/utils/peft_utils.py +1 -1
  290. diffusers/utils/state_dict_utils.py +1 -1
  291. diffusers/utils/testing_utils.py +2 -0
  292. diffusers/utils/torch_utils.py +1 -1
  293. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/METADATA +5 -5
  294. diffusers-0.27.0.dist-info/RECORD +399 -0
  295. diffusers-0.26.2.dist-info/RECORD +0 -384
  296. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
  297. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/WHEEL +0 -0
  298. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
  299. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,508 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import importlib
18
+ import os
19
+ import re
20
+ import warnings
21
+ from pathlib import Path
22
+ from typing import Any, Dict, List, Optional, Union
23
+
24
+ import torch
25
+ from huggingface_hub import (
26
+ model_info,
27
+ )
28
+ from packaging import version
29
+
30
+ from ..utils import (
31
+ SAFETENSORS_WEIGHTS_NAME,
32
+ WEIGHTS_NAME,
33
+ get_class_from_dynamic_module,
34
+ is_peft_available,
35
+ is_transformers_available,
36
+ logging,
37
+ )
38
+ from ..utils.torch_utils import is_compiled_module
39
+
40
+
41
+ if is_transformers_available():
42
+ import transformers
43
+ from transformers import PreTrainedModel
44
+ from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
45
+ from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
46
+ from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
47
+ from huggingface_hub.utils import validate_hf_hub_args
48
+
49
+ from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
50
+
51
+
52
# NOTE(review): despite the name, this value is a weights file name, not an index file name — confirm intended use.
INDEX_FILE = "diffusion_pytorch_model.bin"
# File name `_get_pipeline_class` falls back to when `custom_pipeline` is neither a `.py` path nor paired with a
# `repo_id`.
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
# Presumably the modules that hold dummy placeholder objects; not referenced in the visible part of this file.
DUMMY_MODULES_FOLDER = "diffusers.utils"
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
# Presumably the keys under which connected pipelines are declared — TODO confirm against the callers.
CONNECTED_PIPES_KEYS = ["prior"]

# Module-level logger, used below to report missing safetensors files and non-standard modules.
logger = logging.get_logger(__name__)
59
+
60
# Maps a library name to the base classes that can be loaded from it. Every base class is associated with its
# [save method name, load method name] pair; each entry gets its own list object.
LOADABLE_CLASSES = {
    "diffusers": {
        base_class: ["save_pretrained", "from_pretrained"]
        for base_class in (
            "ModelMixin",
            "SchedulerMixin",
            "DiffusionPipeline",
            "OnnxRuntimeModel",
        )
    },
    "transformers": {
        base_class: ["save_pretrained", "from_pretrained"]
        for base_class in (
            "PreTrainedTokenizer",
            "PreTrainedTokenizerFast",
            "PreTrainedModel",
            "FeatureExtractionMixin",
            "ProcessorMixin",
            "ImageProcessingMixin",
        )
    },
    "onnxruntime.training": {
        base_class: ["save_pretrained", "from_pretrained"] for base_class in ("ORTModule",)
    },
}

# Flat view over every library: base class name -> [save method name, load method name].
ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES.keys():
    ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
83
+
84
+
85
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
    """
    Checking for safetensors compatibility:
    - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
      files to know which safetensors files are needed.
    - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.

    Converting default pytorch serialized filenames to safetensors serialized filenames:
    - For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
    - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
      extension is replaced with ".safetensors"

    Args:
        filenames: Iterable of "/"-separated file paths to inspect.
        variant: Accepted for interface compatibility; this check does not use it.
        passed_components: Optional collection of component (top-level folder) names. A file that sits directly
            inside one of these folders (`<component>/<file>`) is ignored.

    Returns:
        `True` if every considered ".bin" file has its expected ".safetensors" counterpart (also when there is no
        ".bin" file at all). Otherwise `False`, after logging a warning that names the first missing file.
    """
    passed_components = passed_components or []

    pt_filenames = []
    sf_filenames = set()

    for filename in filenames:
        path_parts = filename.split("/")
        # Files of a component supplied by the caller are not taken into account.
        if len(path_parts) == 2 and path_parts[0] in passed_components:
            continue

        _, extension = os.path.splitext(filename)
        if extension == ".bin":
            pt_filenames.append(os.path.normpath(filename))
        elif extension == ".safetensors":
            sf_filenames.add(os.path.normpath(filename))

    for pt_filename in pt_filenames:
        # pt_filename = 'foo/bar/baz.bam' -> folder = 'foo/bar', stem = 'baz'
        folder, basename = os.path.split(pt_filename)
        stem, _ = os.path.splitext(basename)

        # transformers naming: `pytorch_model*.bin` is stored as `model*.safetensors`
        if stem.startswith("pytorch_model"):
            stem = stem.replace("pytorch_model", "model")

        expected_sf_filename = os.path.normpath(os.path.join(folder, stem))
        expected_sf_filename = f"{expected_sf_filename}.safetensors"
        if expected_sf_filename not in sf_filenames:
            logger.warning(f"{expected_sf_filename} not found")
            return False

    return True
131
+
132
+
133
def variant_compatible_siblings(filenames, variant=None):
    """
    Select the weight and weight-index files that should be used for a given `variant`.

    Only the last path component of every entry (split on "/") is matched against the known weight file names.

    Args:
        filenames: Collection of "/"-separated file paths. It is iterated several times.
        variant: Optional variant tag, e.g. `"fp16"`. The tag is matched literally.

    Returns:
        A tuple `(usable_filenames, variant_filenames)` of two sets. `variant_filenames` contains every weight or
        index file that carries the variant tag (empty when `variant` is `None`). `usable_filenames` contains all
        of those plus each non-variant weight or index file whose variant counterpart is not among them.
    """
    weight_names = [
        WEIGHTS_NAME,
        SAFETENSORS_WEIGHTS_NAME,
        FLAX_WEIGHTS_NAME,
        ONNX_WEIGHTS_NAME,
        ONNX_EXTERNAL_WEIGHTS_NAME,
    ]

    if is_transformers_available():
        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]

    # first dot-separated component of every weight name, e.g. `diffusion_pytorch_model`
    prefix_pattern = "|".join(w.split(".")[0] for w in weight_names)
    # last dot-separated component of every weight name, e.g. `bin`, `safetensors`
    suffix_pattern = "|".join(w.split(".")[-1] for w in weight_names)
    # -00001-of-00002
    transformers_index_format = r"\d{5}-of-\d{5}"
    # Compiled once here instead of on every `convert_to_variant` call.
    sharded_file_re = re.compile(f"^(.*?){transformers_index_format}")

    # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
    non_variant_file_re = re.compile(rf"({prefix_pattern})(-{transformers_index_format})?\.({suffix_pattern})$")
    # `text_encoder/pytorch_model.bin.index.json`; anchored at the end like the other patterns
    non_variant_index_re = re.compile(rf"({prefix_pattern})\.({suffix_pattern})\.index\.json$")

    def matching(pattern):
        # Keep the entries whose last path component matches `pattern`.
        return {f for f in filenames if pattern.match(f.split("/")[-1]) is not None}

    if variant is not None:
        # The variant is user input: escape it so that it is matched literally inside the patterns.
        variant_pattern = re.escape(variant)
        # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
        variant_file_re = re.compile(
            rf"({prefix_pattern})\.({variant_pattern}|{variant_pattern}-{transformers_index_format})\.({suffix_pattern})$"
        )
        # `text_encoder/pytorch_model.bin.index.fp16.json`
        variant_index_re = re.compile(rf"({prefix_pattern})\.({suffix_pattern})\.index\.{variant_pattern}\.json$")
        variant_filenames = matching(variant_file_re) | matching(variant_index_re)
    else:
        variant_filenames = set()

    non_variant_filenames = matching(non_variant_file_re) | matching(non_variant_index_re)

    # all variant filenames will be used by default
    usable_filenames = set(variant_filenames)

    def convert_to_variant(filename):
        # Build the path the variant counterpart of `filename` would have. Only the last path component is
        # rewritten, so an "index", "-" or "." inside a folder name cannot corrupt the result.
        folder, separator, basename = filename.rpartition("/")
        if "index" in basename:
            variant_basename = basename.replace("index", f"index.{variant}")
        elif sharded_file_re.match(basename) is not None:
            stem, _, shard = basename.partition("-")
            variant_basename = f"{stem}.{variant}-{shard}"
        else:
            parts = basename.split(".")
            variant_basename = f"{parts[0]}.{variant}.{parts[1]}"
        return f"{folder}{separator}{variant_basename}"

    # A non-variant file is only needed when its variant counterpart is not available.
    for f in non_variant_filenames:
        if convert_to_variant(f) not in usable_filenames:
            usable_filenames.add(f)

    return usable_filenames, variant_filenames
198
+
199
+
200
@validate_hf_hub_args
def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
    """
    Emit a `FutureWarning` about loading a model variant through `revision` instead of `variant`.

    The file listing of the repository is fetched with `model_info(..., revision=None)` and `revision` is treated
    as a variant tag. Which of the two warnings is emitted depends on whether every entry of `model_filenames` is
    found among the variant-compatible files of that listing once their second dot-separated component is removed.

    Args:
        pretrained_model_name_or_path: Repository id passed to `model_info`.
        token: Token forwarded to `model_info`.
        variant: Not used by this function.
        revision: The revision that was requested; used here as the variant tag.
        model_filenames: File names that are being loaded.
    """
    info = model_info(
        pretrained_model_name_or_path,
        token=token,
        revision=None,
    )
    filenames = {sibling.rfilename for sibling in info.siblings}
    comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
    # Drop the second dot-separated component of every name: `a.fp16.bin` -> `a.bin`
    comp_model_filenames = [
        ".".join(parts[:1] + parts[2:]) for parts in (f.split(".") for f in comp_model_filenames)
    ]

    if set(model_filenames).issubset(set(comp_model_filenames)):
        warnings.warn(
            f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant='{revision}'`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
            FutureWarning,
        )
    else:
        warnings.warn(
            f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
            FutureWarning,
        )
221
+
222
+
223
def _unwrap_model(model):
    """Return the model found underneath a compiled-module wrapper and/or a `PeftModel` wrapper."""
    # A compiled module keeps the original module in `_orig_mod`.
    unwrapped = model._orig_mod if is_compiled_module(model) else model

    if not is_peft_available():
        return unwrapped

    from peft import PeftModel

    # A `PeftModel` exposes the wrapped model as `base_model.model`.
    return unwrapped.base_model.model if isinstance(unwrapped, PeftModel) else unwrapped
235
+
236
+
237
def maybe_raise_or_warn(
    library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
):
    """
    Simple helper method to raise or warn in case incorrect module has been passed

    Args:
        library_name: Name of the module that is imported to look up `class_name`.
        library: Not used; the module is always re-imported from `library_name`.
        class_name: Name of the class expected for the component.
        importable_classes: Mapping whose keys are the names of the candidate base classes.
        passed_class_obj: Mapping from component name to the object supplied by the caller.
        name: Key of the component to check in `passed_class_obj`.
        is_pipeline_module: When true, the type cannot be verified and only a warning is logged.

    Raises:
        ValueError: If the unwrapped object passed for `name` is not an instance of a subclass of the expected
            base class.
    """
    if is_pipeline_module:
        logger.warning(
            f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
            " has the correct type"
        )
        return

    library = importlib.import_module(library_name)
    class_obj = getattr(library, class_name)

    # The last candidate base class (in `importable_classes` order) that `class_obj` derives from wins.
    expected_class_obj = None
    for candidate_name in importable_classes:
        class_candidate = getattr(library, candidate_name, None)
        if class_candidate is not None and issubclass(class_obj, class_candidate):
            expected_class_obj = class_candidate

    # Dynamo wraps the original model in a private class.
    # I didn't find a public API to get the original class.
    sub_model = passed_class_obj[name]
    model_cls = _unwrap_model(sub_model).__class__

    # NOTE(review): if no candidate matched, `expected_class_obj` is still `None` here and `issubclass` raises
    # a `TypeError`; left unchanged so that the exception type callers may see stays the same.
    if not issubclass(model_cls, expected_class_obj):
        raise ValueError(f"{sub_model} is of type: {model_cls}, but should be {expected_class_obj}")
266
+
267
+
268
def get_class_obj_and_candidates(
    library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
    """
    Simple helper method to retrieve class object of module as well as potential parent class objects

    The class is looked up, in this order:
    1. in `getattr(pipelines, library_name)` when `is_pipeline_module` is true;
    2. in the custom module file `<cache_dir>/<component_name>/<library_name>.py`, when both `cache_dir` and
       `component_name` are given and that file exists;
    3. in the module imported from `library_name`.

    Args:
        library_name: Name of the pipeline module, custom module file (without ".py") or importable library.
        class_name: Name of the class to retrieve.
        importable_classes: Mapping whose keys are the names of the candidate base classes.
        pipelines: Object holding the pipeline modules as attributes; only read when `is_pipeline_module` is true.
        is_pipeline_module: Whether `library_name` names an attribute of `pipelines`.
        component_name: Optional name of the component folder inside `cache_dir`.
        cache_dir: Optional directory containing the component folders.

    Returns:
        A tuple `(class_obj, class_candidates)`. In cases 1 and 2 every key of `importable_classes` maps to
        `class_obj`; in case 3 each key maps to the attribute of that name in the library, or `None` if missing.
    """
    if is_pipeline_module:
        pipeline_module = getattr(pipelines, library_name)
        class_obj = getattr(pipeline_module, class_name)
        return class_obj, {c: class_obj for c in importable_classes}

    # The component folder can only be built when both parts are known; both default to `None`, and joining
    # `None` would raise a `TypeError` before the plain library import below gets a chance to run.
    if cache_dir is not None and component_name is not None:
        component_folder = os.path.join(cache_dir, component_name)
        module_file = library_name + ".py"
        if os.path.isfile(os.path.join(component_folder, module_file)):
            # load custom component
            class_obj = get_class_from_dynamic_module(component_folder, module_file=module_file, class_name=class_name)
            return class_obj, {c: class_obj for c in importable_classes}

    # else we just import it from the library.
    library = importlib.import_module(library_name)
    class_obj = getattr(library, class_name)
    class_candidates = {c: getattr(library, c, None) for c in importable_classes}

    return class_obj, class_candidates
293
+
294
+
295
+ def _get_pipeline_class(
296
+ class_obj,
297
+ config=None,
298
+ load_connected_pipeline=False,
299
+ custom_pipeline=None,
300
+ repo_id=None,
301
+ hub_revision=None,
302
+ class_name=None,
303
+ cache_dir=None,
304
+ revision=None,
305
+ ):
306
+ if custom_pipeline is not None:
307
+ if custom_pipeline.endswith(".py"):
308
+ path = Path(custom_pipeline)
309
+ # decompose into folder & file
310
+ file_name = path.name
311
+ custom_pipeline = path.parent.absolute()
312
+ elif repo_id is not None:
313
+ file_name = f"{custom_pipeline}.py"
314
+ custom_pipeline = repo_id
315
+ else:
316
+ file_name = CUSTOM_PIPELINE_FILE_NAME
317
+
318
+ if repo_id is not None and hub_revision is not None:
319
+ # if we load the pipeline code from the Hub
320
+ # make sure to overwrite the `revision`
321
+ revision = hub_revision
322
+
323
+ return get_class_from_dynamic_module(
324
+ custom_pipeline,
325
+ module_file=file_name,
326
+ class_name=class_name,
327
+ cache_dir=cache_dir,
328
+ revision=revision,
329
+ )
330
+
331
+ if class_obj.__name__ != "DiffusionPipeline":
332
+ return class_obj
333
+
334
+ diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
335
+ class_name = class_name or config["_class_name"]
336
+ if not class_name:
337
+ raise ValueError(
338
+ "The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`."
339
+ )
340
+
341
+ class_name = class_name[4:] if class_name.startswith("Flax") else class_name
342
+
343
+ pipeline_cls = getattr(diffusers_module, class_name)
344
+
345
+ if load_connected_pipeline:
346
+ from .auto_pipeline import _get_connected_pipeline
347
+
348
+ connected_pipeline_cls = _get_connected_pipeline(pipeline_cls)
349
+ if connected_pipeline_cls is not None:
350
+ logger.info(
351
+ f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`"
352
+ )
353
+ else:
354
+ logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.")
355
+
356
+ pipeline_cls = connected_pipeline_cls or pipeline_cls
357
+
358
+ return pipeline_cls
359
+
360
+
361
def load_sub_model(
    library_name: str,
    class_name: str,
    importable_classes: List[Any],
    pipelines: Any,
    is_pipeline_module: bool,
    pipeline_class: Any,
    torch_dtype: torch.dtype,
    provider: Any,
    sess_options: Any,
    device_map: Optional[Union[Dict[str, torch.device], str]],
    max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
    offload_folder: Optional[Union[str, os.PathLike]],
    offload_state_dict: bool,
    model_variants: Dict[str, str],
    name: str,
    from_flax: bool,
    variant: str,
    low_cpu_mem_usage: bool,
    cached_folder: Union[str, os.PathLike],
):
    """Load the pipeline component `name`, defined by `class_name` in `library_name`.

    The component class is resolved, its loading method is looked up through
    `importable_classes`, loading keyword arguments are assembled based on the kind of
    class, and the method is called on `cached_folder/name` when that directory exists
    or on `cached_folder` otherwise.

    Note that the entry for `name` is popped from `model_variants` for diffusers and
    transformers models.

    Raises:
        ValueError: If no loading method can be determined for the component class.
        ImportError: If a variant is requested for a transformers model with
            `transformers` older than 4.27.0.
    """
    class_obj, class_candidates = get_class_obj_and_candidates(
        library_name,
        class_name,
        importable_classes,
        pipelines,
        is_pipeline_module,
        component_name=name,
        cache_dir=cached_folder,
    )

    # The last candidate that `class_obj` derives from determines the load method.
    load_method_name = None
    for candidate_name, class_candidate in class_candidates.items():
        if class_candidate is None or not issubclass(class_obj, class_candidate):
            continue
        load_method_name = importable_classes[candidate_name][1]

    # No load method means we are dealing with a dummy module -> raise an error.
    if load_method_name is None:
        none_module = class_obj.__module__
        is_dummy_path = none_module.startswith((DUMMY_MODULES_FOLDER, TRANSFORMERS_DUMMY_MODULES_FOLDER))
        if is_dummy_path and "dummy" in none_module:
            # call class_obj for nice error message of missing requirements
            class_obj()

        raise ValueError(
            f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
            f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
        )

    load_method = getattr(class_obj, load_method_name)

    # Collect the keyword arguments understood by the loading method.
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    loading_kwargs = {}
    if issubclass(class_obj, torch.nn.Module):
        loading_kwargs["torch_dtype"] = torch_dtype
    if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
        loading_kwargs.update(provider=provider, sess_options=sess_options)

    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

    if is_transformers_available():
        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
    else:
        transformers_version = "N/A"

    is_transformers_model = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedModel)
        and transformers_version >= version.parse("4.20.0")
    )

    # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
    # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
    # This makes sure that the weights won't be initialized which significantly speeds up loading.
    if is_diffusers_model or is_transformers_model:
        loading_kwargs.update(
            device_map=device_map,
            max_memory=max_memory,
            offload_folder=offload_folder,
            offload_state_dict=offload_state_dict,
            variant=model_variants.pop(name, None),
        )

        if from_flax:
            loading_kwargs["from_flax"] = True

        # the following can be deleted once the minimum required `transformers` version
        # is higher than 4.27
        if is_transformers_model:
            if loading_kwargs["variant"] is None:
                loading_kwargs.pop("variant")
            elif transformers_version < version.parse("4.27.0"):
                raise ImportError(
                    f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
                )

        # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
        loading_kwargs["low_cpu_mem_usage"] = False if (from_flax and is_transformers_model) else low_cpu_mem_usage

    # Prefer the component's own sub-directory; fall back to the root directory.
    component_path = os.path.join(cached_folder, name)
    if os.path.isdir(component_path):
        return load_method(component_path, **loading_kwargs)
    return load_method(cached_folder, **loading_kwargs)
479
+
480
+
481
def _fetch_class_library_tuple(module):
    """Return the `(library, class_name)` pair describing `module`.

    `library` is the pipeline folder name for pipeline modules, the top-level package
    name when it is listed in `LOADABLE_CLASSES`, and the full module path otherwise.
    `class_name` is the class name of the unwrapped module.
    """
    # import it here to avoid circular import
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    pipelines = getattr(diffusers_module, "pipelines")

    # register the config from the original module, not the dynamo compiled one
    not_compiled_module = _unwrap_model(module)
    module_path = not_compiled_module.__module__
    path_parts = module_path.split(".")

    # check if the module is a pipeline module: its parent package must be an
    # attribute of `pipelines`
    pipeline_dir = path_parts[-2] if len(path_parts) > 2 else None
    is_pipeline_module = pipeline_dir in path_parts and hasattr(pipelines, pipeline_dir)

    if is_pipeline_module:
        # the module lives inside the pipeline folder, so use the folder name
        library = pipeline_dir
    else:
        library = path_parts[0]
        if library not in LOADABLE_CLASSES:
            # custom module: fall back to the full module path
            library = module_path

    return (library, not_compiled_module.__class__.__name__)