diffusers 0.26.3__py3-none-any.whl → 0.27.0__py3-none-any.whl

Files changed (299)
  1. diffusers/__init__.py +20 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +7 -3
  7. diffusers/dependency_versions_check.py +1 -1
  8. diffusers/dependency_versions_table.py +2 -2
  9. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  10. diffusers/image_processor.py +110 -4
  11. diffusers/loaders/autoencoder.py +7 -8
  12. diffusers/loaders/controlnet.py +17 -8
  13. diffusers/loaders/ip_adapter.py +86 -23
  14. diffusers/loaders/lora.py +105 -310
  15. diffusers/loaders/lora_conversion_utils.py +1 -1
  16. diffusers/loaders/peft.py +1 -1
  17. diffusers/loaders/single_file.py +51 -12
  18. diffusers/loaders/single_file_utils.py +274 -49
  19. diffusers/loaders/textual_inversion.py +23 -4
  20. diffusers/loaders/unet.py +195 -41
  21. diffusers/loaders/utils.py +1 -1
  22. diffusers/models/__init__.py +3 -1
  23. diffusers/models/activations.py +9 -9
  24. diffusers/models/attention.py +26 -36
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +171 -114
  27. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  28. diffusers/models/autoencoders/autoencoder_kl.py +3 -1
  29. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  30. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  31. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  32. diffusers/models/autoencoders/vae.py +1 -1
  33. diffusers/models/controlnet.py +1 -1
  34. diffusers/models/controlnet_flax.py +1 -1
  35. diffusers/models/downsampling.py +8 -12
  36. diffusers/models/dual_transformer_2d.py +1 -1
  37. diffusers/models/embeddings.py +3 -4
  38. diffusers/models/embeddings_flax.py +1 -1
  39. diffusers/models/lora.py +33 -10
  40. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  41. diffusers/models/modeling_flax_utils.py +1 -1
  42. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  43. diffusers/models/modeling_utils.py +4 -6
  44. diffusers/models/normalization.py +1 -1
  45. diffusers/models/resnet.py +31 -58
  46. diffusers/models/resnet_flax.py +1 -1
  47. diffusers/models/t5_film_transformer.py +1 -1
  48. diffusers/models/transformer_2d.py +1 -1
  49. diffusers/models/transformer_temporal.py +1 -1
  50. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  51. diffusers/models/transformers/t5_film_transformer.py +1 -1
  52. diffusers/models/transformers/transformer_2d.py +29 -31
  53. diffusers/models/transformers/transformer_temporal.py +1 -1
  54. diffusers/models/unet_1d.py +1 -1
  55. diffusers/models/unet_1d_blocks.py +1 -1
  56. diffusers/models/unet_2d.py +1 -1
  57. diffusers/models/unet_2d_blocks.py +1 -1
  58. diffusers/models/unet_2d_condition.py +1 -1
  59. diffusers/models/unets/__init__.py +1 -0
  60. diffusers/models/unets/unet_1d.py +1 -1
  61. diffusers/models/unets/unet_1d_blocks.py +1 -1
  62. diffusers/models/unets/unet_2d.py +4 -4
  63. diffusers/models/unets/unet_2d_blocks.py +238 -98
  64. diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
  65. diffusers/models/unets/unet_2d_condition.py +420 -323
  66. diffusers/models/unets/unet_2d_condition_flax.py +21 -12
  67. diffusers/models/unets/unet_3d_blocks.py +50 -40
  68. diffusers/models/unets/unet_3d_condition.py +47 -8
  69. diffusers/models/unets/unet_i2vgen_xl.py +75 -30
  70. diffusers/models/unets/unet_kandinsky3.py +1 -1
  71. diffusers/models/unets/unet_motion_model.py +48 -8
  72. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  73. diffusers/models/unets/unet_stable_cascade.py +610 -0
  74. diffusers/models/unets/uvit_2d.py +1 -1
  75. diffusers/models/upsampling.py +10 -16
  76. diffusers/models/vae_flax.py +1 -1
  77. diffusers/models/vq_model.py +1 -1
  78. diffusers/optimization.py +1 -1
  79. diffusers/pipelines/__init__.py +26 -0
  80. diffusers/pipelines/amused/pipeline_amused.py +1 -1
  81. diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
  82. diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
  83. diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
  84. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
  85. diffusers/pipelines/animatediff/pipeline_output.py +7 -6
  86. diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
  87. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  88. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
  89. diffusers/pipelines/auto_pipeline.py +7 -16
  90. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  93. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  94. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  95. diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
  96. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  97. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
  98. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
  99. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
  100. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
  101. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
  102. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  103. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
  104. diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
  105. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  106. diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
  107. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
  108. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
  109. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
  110. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
  111. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
  112. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
  113. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
  114. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  115. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
  116. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  117. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
  118. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  119. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  120. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  121. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  122. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  123. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  124. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
  125. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
  126. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
  127. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
  128. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
  129. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  130. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
  131. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  132. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  133. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
  134. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  135. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  136. diffusers/pipelines/free_init_utils.py +184 -0
  137. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
  138. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
  139. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  140. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
  141. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
  142. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
  143. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
  145. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
  146. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
  147. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  148. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/ledits_pp/__init__.py +55 -0
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
  155. diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
  156. diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
  157. diffusers/pipelines/onnx_utils.py +1 -1
  158. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  159. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
  160. diffusers/pipelines/pia/pipeline_pia.py +168 -327
  161. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  162. diffusers/pipelines/pipeline_loading_utils.py +508 -0
  163. diffusers/pipelines/pipeline_utils.py +188 -534
  164. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
  165. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
  166. diffusers/pipelines/shap_e/camera.py +1 -1
  167. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  168. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  169. diffusers/pipelines/shap_e/renderer.py +1 -1
  170. diffusers/pipelines/stable_cascade/__init__.py +50 -0
  171. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
  172. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
  173. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
  174. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  175. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
  176. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  177. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
  178. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  179. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
  180. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
  181. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  182. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  183. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
  184. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  185. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
  186. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
  187. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
  188. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
  189. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
  190. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
  191. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
  192. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
  193. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  194. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  195. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  196. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
  197. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
  198. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
  199. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
  200. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
  201. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
  202. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
  203. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
  204. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
  205. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  206. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  208. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
  209. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
  210. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
  211. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
  212. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
  213. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
  214. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
  215. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
  216. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
  217. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
  218. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
  219. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
  220. diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
  221. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
  222. diffusers/pipelines/unclip/text_proj.py +1 -1
  223. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
  224. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  225. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
  226. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
  227. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
  228. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  229. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
  230. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
  231. diffusers/schedulers/__init__.py +7 -1
  232. diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
  233. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  234. diffusers/schedulers/scheduling_consistency_models.py +42 -19
  235. diffusers/schedulers/scheduling_ddim.py +2 -4
  236. diffusers/schedulers/scheduling_ddim_flax.py +13 -5
  237. diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
  238. diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
  239. diffusers/schedulers/scheduling_ddpm.py +2 -4
  240. diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
  241. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
  242. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
  243. diffusers/schedulers/scheduling_deis_multistep.py +46 -19
  244. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
  245. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
  246. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
  247. diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
  248. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +49 -18
  249. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
  250. diffusers/schedulers/scheduling_edm_euler.py +381 -0
  251. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
  252. diffusers/schedulers/scheduling_euler_discrete.py +42 -17
  253. diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
  254. diffusers/schedulers/scheduling_heun_discrete.py +35 -35
  255. diffusers/schedulers/scheduling_ipndm.py +37 -11
  256. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
  257. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
  258. diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
  259. diffusers/schedulers/scheduling_lcm.py +38 -14
  260. diffusers/schedulers/scheduling_lms_discrete.py +43 -15
  261. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  262. diffusers/schedulers/scheduling_pndm.py +2 -4
  263. diffusers/schedulers/scheduling_pndm_flax.py +2 -4
  264. diffusers/schedulers/scheduling_repaint.py +1 -1
  265. diffusers/schedulers/scheduling_sasolver.py +41 -9
  266. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  267. diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
  268. diffusers/schedulers/scheduling_tcd.py +686 -0
  269. diffusers/schedulers/scheduling_unclip.py +1 -1
  270. diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
  271. diffusers/schedulers/scheduling_utils.py +2 -1
  272. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  273. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  274. diffusers/training_utils.py +9 -2
  275. diffusers/utils/__init__.py +2 -1
  276. diffusers/utils/accelerate_utils.py +1 -1
  277. diffusers/utils/constants.py +1 -1
  278. diffusers/utils/doc_utils.py +1 -1
  279. diffusers/utils/dummy_pt_objects.py +60 -0
  280. diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
  281. diffusers/utils/dynamic_modules_utils.py +1 -1
  282. diffusers/utils/export_utils.py +3 -3
  283. diffusers/utils/hub_utils.py +60 -16
  284. diffusers/utils/import_utils.py +15 -1
  285. diffusers/utils/loading_utils.py +2 -0
  286. diffusers/utils/logging.py +1 -1
  287. diffusers/utils/model_card_template.md +24 -0
  288. diffusers/utils/outputs.py +14 -7
  289. diffusers/utils/peft_utils.py +1 -1
  290. diffusers/utils/state_dict_utils.py +1 -1
  291. diffusers/utils/testing_utils.py +2 -0
  292. diffusers/utils/torch_utils.py +1 -1
  293. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/METADATA +46 -46
  294. diffusers-0.27.0.dist-info/RECORD +399 -0
  295. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/WHEEL +1 -1
  296. diffusers-0.26.3.dist-info/RECORD +0 -384
  297. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
  298. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
  299. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
-# Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved.
+# Copyright 2024 Google Brain and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -98,15 +98,9 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         self.custom_timesteps = False
         self.is_scale_input_called = False
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

-    def index_for_timestep(self, timestep, schedule_timesteps=None):
-        if schedule_timesteps is None:
-            schedule_timesteps = self.timesteps
-
-        indices = (schedule_timesteps == timestep).nonzero()
-        return indices.item()
-
     @property
     def step_index(self):
         """
@@ -114,6 +108,24 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def scale_model_input(
         self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
     ) -> torch.FloatTensor:
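
The `begin_index` machinery added above lets a pipeline tell the scheduler up front where in the timestep schedule denoising starts, so the scheduler no longer has to search the schedule on every step. A minimal sketch of the intended call pattern, with illustrative names (`strength`, `num_inference_steps`) that are not part of this diff:

    scheduler.set_timesteps(num_inference_steps)
    # e.g. img2img: skip the first part of the schedule according to strength
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = scheduler.timesteps[t_start:]
    scheduler.set_begin_index(t_start)  # step() and add_noise() now start from this index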
@@ -231,6 +243,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         self.timesteps = torch.from_numpy(timesteps).to(device=device)

         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     # Modified _convert_to_karras implementation that takes in ramp as argument
@@ -280,23 +293,29 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
         return c_skip, c_out

-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps

-        index_candidates = (self.timesteps == timestep).nonzero()
+        indices = (schedule_timesteps == timestep).nonzero()

         # The sigma index that is taken for the **very** first `step`
         # is always the second index (or the last index if there is only 1)
         # This way we can ensure we don't accidentally skip a sigma in
         # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(index_candidates) > 1:
-            step_index = index_candidates[1]
-        else:
-            step_index = index_candidates[0]
+        pos = 1 if len(indices) > 1 else 0
+
+        return indices[pos].item()

-        self._step_index = step_index.item()
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     def step(
         self,
@@ -412,7 +431,11 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         schedule_timesteps = self.timesteps.to(original_samples.device)
         timesteps = timesteps.to(original_samples.device)

-        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
+        if self.begin_index is None:
+            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        else:
+            step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
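
The duplicated-timestep handling above is easiest to see on a concrete schedule. A self-contained check with hypothetical values, mirroring the `index_for_timestep` logic:

    import torch

    # 801 appears twice, as can happen when a truncated schedule repeats a timestep
    schedule_timesteps = torch.tensor([999, 801, 801, 601, 401, 201])
    indices = (schedule_timesteps == 801).nonzero()
    pos = 1 if len(indices) > 1 else 0
    print(indices[pos].item())  # -> 2: the second occurrence, so no sigma is skipped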
@@ -1,4 +1,4 @@
-# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
+# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -157,9 +157,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
             otherwise it uses the alpha value at step 0.
         steps_offset (`int`, defaults to 0):
-            An offset added to the inference steps. You can use a combination of `offset=1` and
-            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
-            Diffusion.
+            An offset added to the inference steps, as required by some model families.
         prediction_type (`str`, defaults to `epsilon`, *optional*):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
             `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
@@ -1,4 +1,4 @@
-# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
+# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -85,15 +85,15 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
         trained_betas (`jnp.ndarray`, optional):
             option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
         clip_sample (`bool`, default `True`):
-            option to clip predicted sample between -1 and 1 for numerical stability.
+            option to clip the predicted sample for numerical stability. The clip range is determined by `clip_sample_range`.
+        clip_sample_range (`float`, default `1.0`):
+            the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
         set_alpha_to_one (`bool`, default `True`):
             each diffusion step uses the value of alphas product at that step and at the previous one. For the final
             step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
             otherwise it uses the value of alpha at step 0.
         steps_offset (`int`, default `0`):
-            an offset added to the inference steps. You can use a combination of `offset=1` and
-            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
-            stable diffusion.
+            An offset added to the inference steps, as required by some model families.
         prediction_type (`str`, default `epsilon`):
             indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
             `v-prediction` is not supported for this scheduler.
@@ -117,6 +117,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[jnp.ndarray] = None,
+        clip_sample: bool = True,
+        clip_sample_range: float = 1.0,
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
         prediction_type: str = "epsilon",
@@ -267,6 +269,12 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
                 " `v_prediction`"
             )

+        # 4. Clip or threshold "predicted x_0"
+        if self.config.clip_sample:
+            pred_original_sample = pred_original_sample.clip(
+                -self.config.clip_sample_range, self.config.clip_sample_range
+            )
+
         # 4. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
         variance = self._get_variance(state, timestep, prev_timestep)
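
A hedged sketch of what the two new Flax config fields do together (requires a Flax install; the range value is illustrative):

    from diffusers import FlaxDDIMScheduler

    # predicted x_0 is now clamped to [-1.5, 1.5] inside step() instead of a fixed [-1, 1]
    scheduler = FlaxDDIMScheduler(clip_sample=True, clip_sample_range=1.5)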
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -155,9 +155,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
             there is no previous alpha. When this option is `True` the previous alpha product is fixed to 0, otherwise
             it uses the alpha value at step `num_train_timesteps - 1`.
         steps_offset (`int`, defaults to 0):
-            An offset added to the inference steps. You can use a combination of `offset=1` and
-            `set_alpha_to_one=False` to make the last step use `num_train_timesteps - 1` for the previous alpha
-            product.
+            An offset added to the inference steps, as required by some model families.
         prediction_type (`str`, defaults to `epsilon`, *optional*):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
             `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
@@ -1,4 +1,4 @@
-# Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
+# Copyright 2024 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -159,9 +159,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
             step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
             otherwise it uses the value of alpha at step 0.
         steps_offset (`int`, default `0`):
-            an offset added to the inference steps. You can use a combination of `offset=1` and
-            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
-            stable diffusion.
+            An offset added to the inference steps, as required by some model families.
         prediction_type (`str`, default `epsilon`, optional):
             prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
             process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
@@ -1,4 +1,4 @@
-# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -167,9 +167,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
         steps_offset (`int`, defaults to 0):
-            An offset added to the inference steps. You can use a combination of `offset=1` and
-            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
-            Diffusion.
+            An offset added to the inference steps, as required by some model families.
         rescale_betas_zero_snr (`bool`, defaults to `False`):
             Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
             dark samples instead of limiting it to samples with medium brightness. Loosely related to
@@ -1,4 +1,4 @@
-# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
-# Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
+# Copyright 2024 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -173,9 +173,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
             The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
             Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
         steps_offset (`int`, default `0`):
-            an offset added to the inference steps. You can use a combination of `offset=1` and
-            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
-            stable diffusion.
+            An offset added to the inference steps, as required by some model families.
         rescale_betas_zero_snr (`bool`, defaults to `False`):
             Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
             dark samples instead of limiting it to samples with medium brightness. Loosely related to
@@ -1,5 +1,5 @@
 # Copyright (c) 2022 Pablo Pernías MIT License
-# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
-# Copyright 2023 FLAIR Lab and The HuggingFace Team. All rights reserved.
+# Copyright 2024 FLAIR Lab and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -115,9 +115,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
         steps_offset (`int`, defaults to 0):
-            An offset added to the inference steps. You can use a combination of `offset=1` and
-            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
-            Diffusion.
+            An offset added to the inference steps, as required by some model families.
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -187,6 +185,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         self.model_outputs = [None] * solver_order
         self.lower_order_nums = 0
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     @property
@@ -196,6 +195,24 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -255,6 +272,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):

         # add an index counter for schedulers that allow duplicated timesteps
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
@@ -620,11 +638,12 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         else:
             raise NotImplementedError("only support log-rho multistep deis now")

-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps

-        index_candidates = (self.timesteps == timestep).nonzero()
+        index_candidates = (schedule_timesteps == timestep).nonzero()

         if len(index_candidates) == 0:
             step_index = len(self.timesteps) - 1
@@ -637,7 +656,20 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         else:
             step_index = index_candidates[0].item()

-        self._step_index = step_index
+        return step_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        """
+        Initialize the step_index counter for the scheduler.
+        """
+
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     def step(
         self,
@@ -736,16 +768,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         schedule_timesteps = self.timesteps.to(original_samples.device)
         timesteps = timesteps.to(original_samples.device)

-        step_indices = []
-        for timestep in timesteps:
-            index_candidates = (schedule_timesteps == timestep).nonzero()
-            if len(index_candidates) == 0:
-                step_index = len(schedule_timesteps) - 1
-            elif len(index_candidates) > 1:
-                step_index = index_candidates[1].item()
-            else:
-                step_index = index_candidates[0].item()
-            step_indices.append(step_index)
+        # begin_index is None when the scheduler is used for training
+        if self.begin_index is None:
+            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        else:
+            step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
@@ -1,4 +1,4 @@
-# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved.
+# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -71,6 +71,43 @@ def betas_for_alpha_bar(
     return torch.tensor(betas, dtype=torch.float32)


+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+    """
+    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+    Args:
+        betas (`torch.FloatTensor`):
+            the betas that the scheduler is being initialized with.
+
+    Returns:
+        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = torch.cat([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+
+    return betas
+
+
 class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     """
     `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
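
The rescaling is easy to verify numerically: after Algorithm 1 the final cumulative alpha is exactly zero, i.e. the last timestep carries pure noise. A quick check, assuming the `rescale_zero_terminal_snr` defined above is in scope:

    import torch

    betas = torch.linspace(0.0001, 0.02, 1000)  # an illustrative linear schedule
    alphas_cumprod = torch.cumprod(1.0 - rescale_zero_terminal_snr(betas), dim=0)
    snr = alphas_cumprod / (1.0 - alphas_cumprod)
    print(snr[-1].item())  # -> 0.0: zero terminal SNR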
@@ -141,9 +178,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
         steps_offset (`int`, defaults to 0):
-            An offset added to the inference steps. You can use a combination of `offset=1` and
-            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
-            Diffusion.
+            An offset added to the inference steps, as required by some model families.
+        rescale_betas_zero_snr (`bool`, defaults to `False`):
+            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+            dark samples instead of limiting it to samples with medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -173,6 +212,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         variance_type: Optional[str] = None,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
+        rescale_betas_zero_snr: bool = False,
     ):
         if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
             deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
@@ -191,8 +231,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

+        if rescale_betas_zero_snr:
+            self.betas = rescale_zero_terminal_snr(self.betas)
+
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+        if rescale_betas_zero_snr:
+            # Close to 0 without being 0 so first sigma is not inf
+            # FP16 smallest positive subnormal works well here
+            self.alphas_cumprod[-1] = 2**-24
+
         # Currently we only support VP-type noise schedule
         self.alpha_t = torch.sqrt(self.alphas_cumprod)
         self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
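
Why `2**-24` rather than zero: this scheduler derives Karras-style sigmas as `((1 - alphas_cumprod) / alphas_cumprod) ** 0.5`, so a terminal alpha-bar of exactly zero would give an infinite sigma. Substituting FP16's smallest positive subnormal keeps it finite:

    # with alpha_bar_T = 0:       sigma_T = sqrt(1 / 0) = inf
    # with alpha_bar_T = 2**-24:  sigma_T stays finite
    sigma_T = ((1 - 2**-24) / 2**-24) ** 0.5
    print(sigma_T)  # ≈ 4096.0 (2**12), large but representable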
@@ -227,6 +276,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         self.model_outputs = [None] * solver_order
         self.lower_order_nums = 0
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     @property
@@ -236,6 +286,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -311,6 +378,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):

         # add an index counter for schedulers that allow duplicated timesteps
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
@@ -792,11 +860,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         )
         return x_t

-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps

-        index_candidates = (self.timesteps == timestep).nonzero()
+        index_candidates = (schedule_timesteps == timestep).nonzero()

         if len(index_candidates) == 0:
             step_index = len(self.timesteps) - 1
@@ -809,7 +877,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         else:
             step_index = index_candidates[0].item()

-        self._step_index = step_index
+        return step_index
+
+    def _init_step_index(self, timestep):
+        """
+        Initialize the step_index counter for the scheduler.
+        """
+
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     def step(
         self,
@@ -817,6 +897,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         timestep: int,
         sample: torch.FloatTensor,
         generator=None,
+        variance_noise: Optional[torch.FloatTensor] = None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -832,6 +913,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
+            variance_noise (`torch.FloatTensor`):
+                Alternative to generating noise with `generator` by directly providing the noise for the variance
+                itself. Useful for methods such as [`LEdits++`].
             return_dict (`bool`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
@@ -864,10 +948,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             self.model_outputs[i] = self.model_outputs[i + 1]
         self.model_outputs[-1] = model_output

-        if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
+        # Upcast to avoid precision issues when computing prev_sample
+        sample = sample.to(torch.float32)
+        if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
             noise = randn_tensor(
-                model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+                model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
             )
+        elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
+            noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
         else:
             noise = None

@@ -881,6 +969,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         if self.lower_order_nums < self.config.solver_order:
             self.lower_order_nums += 1

+        # Cast sample back to expected dtype
+        prev_sample = prev_sample.to(model_output.dtype)
+
         # upon completion increase step index by one
         self._step_index += 1

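The two hunks above bracket the solver math in a standard mixed-precision pattern: upcast the state to fp32 before the update, then cast the result back so an fp16 pipeline keeps its dtype end to end. In miniature, with `solver_update` as a hypothetical stand-in for the multistep update:

    x = sample.to(torch.float32)   # do the numerically sensitive math in fp32
    prev = solver_update(x)        # hypothetical solver step
    prev = prev.to(sample.dtype)   # hand the original dtype back to the pipeline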
@@ -920,16 +1011,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         schedule_timesteps = self.timesteps.to(original_samples.device)
         timesteps = timesteps.to(original_samples.device)

-        step_indices = []
-        for timestep in timesteps:
-            index_candidates = (schedule_timesteps == timestep).nonzero()
-            if len(index_candidates) == 0:
-                step_index = len(schedule_timesteps) - 1
-            elif len(index_candidates) > 1:
-                step_index = index_candidates[1].item()
-            else:
-                step_index = index_candidates[0].item()
-            step_indices.append(step_index)
+        # begin_index is None when the scheduler is used for training
+        if self.begin_index is None:
+            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        else:
+            step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
@@ -1,4 +1,4 @@
-# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved.
+# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.