diffusers 0.26.2__py3-none-any.whl → 0.27.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (299) hide show
  1. diffusers/__init__.py +20 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +7 -3
  7. diffusers/dependency_versions_check.py +1 -1
  8. diffusers/dependency_versions_table.py +2 -2
  9. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  10. diffusers/image_processor.py +110 -4
  11. diffusers/loaders/autoencoder.py +28 -8
  12. diffusers/loaders/controlnet.py +17 -8
  13. diffusers/loaders/ip_adapter.py +86 -23
  14. diffusers/loaders/lora.py +105 -310
  15. diffusers/loaders/lora_conversion_utils.py +1 -1
  16. diffusers/loaders/peft.py +1 -1
  17. diffusers/loaders/single_file.py +51 -12
  18. diffusers/loaders/single_file_utils.py +278 -49
  19. diffusers/loaders/textual_inversion.py +23 -4
  20. diffusers/loaders/unet.py +195 -41
  21. diffusers/loaders/utils.py +1 -1
  22. diffusers/models/__init__.py +3 -1
  23. diffusers/models/activations.py +9 -9
  24. diffusers/models/attention.py +26 -36
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +171 -114
  27. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  28. diffusers/models/autoencoders/autoencoder_kl.py +3 -1
  29. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  30. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  31. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  32. diffusers/models/autoencoders/vae.py +1 -1
  33. diffusers/models/controlnet.py +1 -1
  34. diffusers/models/controlnet_flax.py +1 -1
  35. diffusers/models/downsampling.py +8 -12
  36. diffusers/models/dual_transformer_2d.py +1 -1
  37. diffusers/models/embeddings.py +3 -4
  38. diffusers/models/embeddings_flax.py +1 -1
  39. diffusers/models/lora.py +33 -10
  40. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  41. diffusers/models/modeling_flax_utils.py +1 -1
  42. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  43. diffusers/models/modeling_utils.py +4 -6
  44. diffusers/models/normalization.py +1 -1
  45. diffusers/models/resnet.py +31 -58
  46. diffusers/models/resnet_flax.py +1 -1
  47. diffusers/models/t5_film_transformer.py +1 -1
  48. diffusers/models/transformer_2d.py +1 -1
  49. diffusers/models/transformer_temporal.py +1 -1
  50. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  51. diffusers/models/transformers/t5_film_transformer.py +1 -1
  52. diffusers/models/transformers/transformer_2d.py +29 -31
  53. diffusers/models/transformers/transformer_temporal.py +1 -1
  54. diffusers/models/unet_1d.py +1 -1
  55. diffusers/models/unet_1d_blocks.py +1 -1
  56. diffusers/models/unet_2d.py +1 -1
  57. diffusers/models/unet_2d_blocks.py +1 -1
  58. diffusers/models/unet_2d_condition.py +1 -1
  59. diffusers/models/unets/__init__.py +1 -0
  60. diffusers/models/unets/unet_1d.py +1 -1
  61. diffusers/models/unets/unet_1d_blocks.py +1 -1
  62. diffusers/models/unets/unet_2d.py +4 -4
  63. diffusers/models/unets/unet_2d_blocks.py +238 -98
  64. diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
  65. diffusers/models/unets/unet_2d_condition.py +420 -323
  66. diffusers/models/unets/unet_2d_condition_flax.py +21 -12
  67. diffusers/models/unets/unet_3d_blocks.py +50 -40
  68. diffusers/models/unets/unet_3d_condition.py +47 -8
  69. diffusers/models/unets/unet_i2vgen_xl.py +75 -30
  70. diffusers/models/unets/unet_kandinsky3.py +1 -1
  71. diffusers/models/unets/unet_motion_model.py +48 -8
  72. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  73. diffusers/models/unets/unet_stable_cascade.py +610 -0
  74. diffusers/models/unets/uvit_2d.py +1 -1
  75. diffusers/models/upsampling.py +10 -16
  76. diffusers/models/vae_flax.py +1 -1
  77. diffusers/models/vq_model.py +1 -1
  78. diffusers/optimization.py +1 -1
  79. diffusers/pipelines/__init__.py +26 -0
  80. diffusers/pipelines/amused/pipeline_amused.py +1 -1
  81. diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
  82. diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
  83. diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
  84. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
  85. diffusers/pipelines/animatediff/pipeline_output.py +7 -6
  86. diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
  87. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  88. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
  89. diffusers/pipelines/auto_pipeline.py +7 -16
  90. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  93. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  94. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  95. diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
  96. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  97. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
  98. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
  99. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
  100. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
  101. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
  102. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  103. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
  104. diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
  105. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  106. diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
  107. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
  108. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
  109. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
  110. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
  111. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
  112. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
  113. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
  114. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  115. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
  116. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  117. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
  118. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  119. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  120. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  121. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  122. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  123. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  124. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
  125. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
  126. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
  127. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
  128. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
  129. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  130. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
  131. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  132. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  133. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
  134. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  135. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  136. diffusers/pipelines/free_init_utils.py +184 -0
  137. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
  138. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
  139. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  140. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
  141. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
  142. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
  143. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
  145. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
  146. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
  147. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  148. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/ledits_pp/__init__.py +55 -0
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
  155. diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
  156. diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
  157. diffusers/pipelines/onnx_utils.py +1 -1
  158. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  159. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
  160. diffusers/pipelines/pia/pipeline_pia.py +168 -327
  161. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  162. diffusers/pipelines/pipeline_loading_utils.py +508 -0
  163. diffusers/pipelines/pipeline_utils.py +188 -534
  164. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
  165. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
  166. diffusers/pipelines/shap_e/camera.py +1 -1
  167. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  168. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  169. diffusers/pipelines/shap_e/renderer.py +1 -1
  170. diffusers/pipelines/stable_cascade/__init__.py +50 -0
  171. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
  172. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
  173. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
  174. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  175. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
  176. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  177. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
  178. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  179. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
  180. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
  181. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  182. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  183. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
  184. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  185. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
  186. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
  187. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
  188. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
  189. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
  190. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
  191. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
  192. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
  193. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  194. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  195. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  196. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
  197. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
  198. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
  199. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
  200. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
  201. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
  202. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
  203. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
  204. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
  205. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  206. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  208. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
  209. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
  210. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
  211. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
  212. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
  213. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
  214. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
  215. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
  216. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
  217. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
  218. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
  219. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
  220. diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
  221. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
  222. diffusers/pipelines/unclip/text_proj.py +1 -1
  223. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
  224. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  225. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
  226. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
  227. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
  228. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  229. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
  230. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
  231. diffusers/schedulers/__init__.py +7 -1
  232. diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
  233. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  234. diffusers/schedulers/scheduling_consistency_models.py +42 -19
  235. diffusers/schedulers/scheduling_ddim.py +2 -4
  236. diffusers/schedulers/scheduling_ddim_flax.py +13 -5
  237. diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
  238. diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
  239. diffusers/schedulers/scheduling_ddpm.py +2 -4
  240. diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
  241. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
  242. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
  243. diffusers/schedulers/scheduling_deis_multistep.py +46 -19
  244. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
  245. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
  246. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
  247. diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
  248. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +52 -21
  249. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
  250. diffusers/schedulers/scheduling_edm_euler.py +381 -0
  251. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
  252. diffusers/schedulers/scheduling_euler_discrete.py +42 -17
  253. diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
  254. diffusers/schedulers/scheduling_heun_discrete.py +35 -35
  255. diffusers/schedulers/scheduling_ipndm.py +37 -11
  256. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
  257. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
  258. diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
  259. diffusers/schedulers/scheduling_lcm.py +38 -14
  260. diffusers/schedulers/scheduling_lms_discrete.py +43 -15
  261. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  262. diffusers/schedulers/scheduling_pndm.py +2 -4
  263. diffusers/schedulers/scheduling_pndm_flax.py +2 -4
  264. diffusers/schedulers/scheduling_repaint.py +1 -1
  265. diffusers/schedulers/scheduling_sasolver.py +41 -9
  266. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  267. diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
  268. diffusers/schedulers/scheduling_tcd.py +686 -0
  269. diffusers/schedulers/scheduling_unclip.py +1 -1
  270. diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
  271. diffusers/schedulers/scheduling_utils.py +2 -1
  272. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  273. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  274. diffusers/training_utils.py +9 -2
  275. diffusers/utils/__init__.py +2 -1
  276. diffusers/utils/accelerate_utils.py +1 -1
  277. diffusers/utils/constants.py +1 -1
  278. diffusers/utils/doc_utils.py +1 -1
  279. diffusers/utils/dummy_pt_objects.py +60 -0
  280. diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
  281. diffusers/utils/dynamic_modules_utils.py +1 -1
  282. diffusers/utils/export_utils.py +3 -3
  283. diffusers/utils/hub_utils.py +60 -16
  284. diffusers/utils/import_utils.py +15 -1
  285. diffusers/utils/loading_utils.py +2 -0
  286. diffusers/utils/logging.py +1 -1
  287. diffusers/utils/model_card_template.md +24 -0
  288. diffusers/utils/outputs.py +14 -7
  289. diffusers/utils/peft_utils.py +1 -1
  290. diffusers/utils/state_dict_utils.py +1 -1
  291. diffusers/utils/testing_utils.py +2 -0
  292. diffusers/utils/torch_utils.py +1 -1
  293. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/METADATA +5 -5
  294. diffusers-0.27.0.dist-info/RECORD +399 -0
  295. diffusers-0.26.2.dist-info/RECORD +0 -384
  296. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
  297. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/WHEEL +0 -0
  298. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
  299. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,381 @@
1
+ # Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..utils import BaseOutput, logging
23
+ from ..utils.torch_utils import randn_tensor
24
+ from .scheduling_utils import SchedulerMixin
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
# Return type of `EDMEulerScheduler.step` when `return_dict=True`.
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EDMEuler
class EDMEulerSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None
47
+
48
+
49
class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
    """
    Implements the Euler scheduler in EDM formulation as presented in Karras et al. 2022 [1].

    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
    https://arxiv.org/abs/2206.00364

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        sigma_min (`float`, *optional*, defaults to 0.002):
            Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable
            range is [0, 10].
        sigma_max (`float`, *optional*, defaults to 80.0):
            Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable
            range is [0.2, 80.0].
        sigma_data (`float`, *optional*, defaults to 0.5):
            The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process)
            or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf)
            paper). Any other value raises a `ValueError` when `step` is called.
        rho (`float`, *optional*, defaults to 7.0):
            The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        sigma_min: float = 0.002,
        sigma_max: float = 80.0,
        sigma_data: float = 0.5,
        num_train_timesteps: int = 1000,
        prediction_type: str = "epsilon",
        rho: float = 7.0,
    ):
        # setable values
        self.num_inference_steps = None

        # Default schedule over the training steps; `set_timesteps` replaces both timesteps and sigmas.
        ramp = torch.linspace(0, 1, num_train_timesteps)
        sigmas = self._compute_sigmas(ramp)
        # The "timesteps" passed to the model are the EDM noise-conditioning values c_noise(sigma) = log(sigma) / 4.
        self.timesteps = self.precondition_noise(sigmas)

        # A trailing 0 is appended so `step` can always read `sigmas[step_index + 1]`.
        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

        self.is_scale_input_called = False

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        return (self.config.sigma_max**2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def precondition_inputs(self, sample, sigma):
        """
        Scales `sample` by the EDM input preconditioning factor `c_in = 1 / sqrt(sigma**2 + sigma_data**2)`.
        """
        c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
        scaled_sample = sample * c_in
        return scaled_sample

    def precondition_noise(self, sigma):
        """
        Maps `sigma` to the EDM noise-conditioning value `c_noise = log(sigma) / 4`.

        Non-tensor inputs (e.g. a Python float) are wrapped in a 1-element tensor first.
        """
        if not isinstance(sigma, torch.Tensor):
            sigma = torch.tensor([sigma])

        c_noise = 0.25 * torch.log(sigma)

        return c_noise

    def precondition_outputs(self, sample, model_output, sigma):
        """
        Combines the noisy `sample` and the raw `model_output` into the denoised estimate
        `c_skip * sample + c_out * model_output` (EDM output preconditioning).

        Raises:
            `ValueError`: If `config.prediction_type` is neither `epsilon` nor `v_prediction`.
        """
        sigma_data = self.config.sigma_data
        c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)

        if self.config.prediction_type == "epsilon":
            c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        elif self.config.prediction_type == "v_prediction":
            c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        else:
            raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")

        denoised = c_skip * sample + c_out * model_output

        return denoised

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Scales the denoising model input by `1 / (sigma**2 + sigma_data**2) ** 0.5` (the EDM `c_in`
        preconditioning), where `sigma` is the noise level at the current step index.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep in the diffusion chain, i.e. one of `scheduler.timesteps`. It is only used to
                locate the step index on the first call when no begin index has been set.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = self.precondition_inputs(sample, sigma)

        self.is_scale_input_called = True
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        ramp = np.linspace(0, 1, self.num_inference_steps)
        sigmas = self._compute_sigmas(ramp)

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        self.timesteps = self.precondition_noise(sigmas)

        # Trailing 0 so the final `step` integrates down to sigma = 0.
        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
    def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> Union[np.ndarray, torch.FloatTensor]:
        """
        Constructs the noise schedule of Karras et al. (2022).

        Args:
            ramp (`np.ndarray` or `torch.FloatTensor`):
                Interpolation positions in [0, 1]; 0 maps to `sigma_max` and 1 maps to `sigma_min`.
            sigma_min (`float`, *optional*):
                Overrides `config.sigma_min` when not `None` (an explicit `0.0` is honored).
            sigma_max (`float`, *optional*):
                Overrides `config.sigma_max` when not `None`.

        Returns:
            The sigmas, as the same array type as `ramp` (numpy in `set_timesteps`, torch in `__init__`).
        """
        # Compare against None instead of using `or`, so an explicit 0.0 override is not silently replaced.
        if sigma_min is None:
            sigma_min = self.config.sigma_min
        if sigma_max is None:
            sigma_max = self.config.sigma_max

        rho = self.config.rho
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EDMEulerSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep in the diffusion chain, i.e. one of `scheduler.timesteps`. Integer indices are
                rejected.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`, defaults to 0.0):
                Amount of added stochasticity. The per-step churn factor is
                `gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)`; with 0.0 no extra noise is added.
            s_tmin (`float`, defaults to 0.0):
                Lower bound (inclusive) of the sigma range in which churn is applied.
            s_tmax (`float`, defaults to `float("inf")`):
                Upper bound (inclusive) of the sigma range in which churn is applied.
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.

        Raises:
            `ValueError`: If `timestep` is an integer (or integer tensor), or if `config.prediction_type` is not
                supported.
        """

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EDMEulerScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]

        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        # Note: noise is drawn on every call, even when gamma == 0 and it goes unused, so the generator
        # advances once per step regardless of the churn settings.
        noise = randn_tensor(
            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
        )

        eps = noise * s_noise
        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            # Raise the noise level from sigma to sigma_hat by adding fresh noise of the missing variance.
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        pred_original_sample = self.precondition_outputs(sample, model_output, sigma_hat)

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma_hat

        dt = self.sigmas[self.step_index + 1] - sigma_hat

        prev_sample = sample + derivative * dt

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        else:
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
1
+ # Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -156,9 +156,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
156
156
  The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
157
157
  Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
158
158
  steps_offset (`int`, defaults to 0):
159
- An offset added to the inference steps. You can use a combination of `offset=1` and
160
- `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
161
- Diffusion.
159
+ An offset added to the inference steps, as required by some model families.
162
160
  rescale_betas_zero_snr (`bool`, defaults to `False`):
163
161
  Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
164
162
  dark samples instead of limiting it to samples with medium brightness. Loosely related to
@@ -216,6 +214,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
216
214
  self.is_scale_input_called = False
217
215
 
218
216
  self._step_index = None
217
+ self._begin_index = None
219
218
  self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
220
219
 
221
220
  @property
@@ -233,6 +232,24 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
233
232
  """
234
233
  return self._step_index
235
234
 
235
+ @property
236
+ def begin_index(self):
237
+ """
238
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
239
+ """
240
+ return self._begin_index
241
+
242
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
243
+ def set_begin_index(self, begin_index: int = 0):
244
+ """
245
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
246
+
247
+ Args:
248
+ begin_index (`int`):
249
+ The begin index for the scheduler.
250
+ """
251
+ self._begin_index = begin_index
252
+
236
253
  def scale_model_input(
237
254
  self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
238
255
  ) -> torch.FloatTensor:
@@ -300,25 +317,32 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
300
317
 
301
318
  self.timesteps = torch.from_numpy(timesteps).to(device=device)
302
319
  self._step_index = None
320
+ self._begin_index = None
303
321
  self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
304
322
 
305
- # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
306
- def _init_step_index(self, timestep):
307
- if isinstance(timestep, torch.Tensor):
308
- timestep = timestep.to(self.timesteps.device)
323
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
324
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
325
+ if schedule_timesteps is None:
326
+ schedule_timesteps = self.timesteps
309
327
 
310
- index_candidates = (self.timesteps == timestep).nonzero()
328
+ indices = (schedule_timesteps == timestep).nonzero()
311
329
 
312
330
  # The sigma index that is taken for the **very** first `step`
313
331
  # is always the second index (or the last index if there is only 1)
314
332
  # This way we can ensure we don't accidentally skip a sigma in
315
333
  # case we start in the middle of the denoising schedule (e.g. for image-to-image)
316
- if len(index_candidates) > 1:
317
- step_index = index_candidates[1]
318
- else:
319
- step_index = index_candidates[0]
334
+ pos = 1 if len(indices) > 1 else 0
335
+
336
+ return indices[pos].item()
320
337
 
321
- self._step_index = step_index.item()
338
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
339
+ def _init_step_index(self, timestep):
340
+ if self.begin_index is None:
341
+ if isinstance(timestep, torch.Tensor):
342
+ timestep = timestep.to(self.timesteps.device)
343
+ self._step_index = self.index_for_timestep(timestep)
344
+ else:
345
+ self._step_index = self._begin_index
322
346
 
323
347
  def step(
324
348
  self,
@@ -440,7 +464,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
440
464
  schedule_timesteps = self.timesteps.to(original_samples.device)
441
465
  timesteps = timesteps.to(original_samples.device)
442
466
 
443
- step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
467
+ # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
468
+ if self.begin_index is None:
469
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
470
+ else:
471
+ step_indices = [self.begin_index] * timesteps.shape[0]
444
472
 
445
473
  sigma = sigmas[step_indices].flatten()
446
474
  while len(sigma.shape) < len(original_samples.shape):
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
1
+ # Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -162,9 +162,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
162
162
  The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
163
163
  Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
164
164
  steps_offset (`int`, defaults to 0):
165
- An offset added to the inference steps. You can use a combination of `offset=1` and
166
- `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
167
- Diffusion.
165
+ An offset added to the inference steps, as required by some model families.
168
166
  rescale_betas_zero_snr (`bool`, defaults to `False`):
169
167
  Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
170
168
  dark samples instead of limiting it to samples with medium brightness. Loosely related to
@@ -216,10 +214,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
216
214
  # FP16 smallest positive subnormal works well here
217
215
  self.alphas_cumprod[-1] = 2**-24
218
216
 
219
- sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
217
+ sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0)
220
218
  timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
221
-
222
- sigmas = torch.from_numpy(sigmas[::-1].copy()).to(dtype=torch.float32)
223
219
  timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
224
220
 
225
221
  # setable values
@@ -237,6 +233,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
237
233
  self.use_karras_sigmas = use_karras_sigmas
238
234
 
239
235
  self._step_index = None
236
+ self._begin_index = None
240
237
  self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
241
238
 
242
239
  @property
@@ -255,6 +252,24 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
255
252
  """
256
253
  return self._step_index
257
254
 
255
+ @property
256
+ def begin_index(self):
257
+ """
258
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
259
+ """
260
+ return self._begin_index
261
+
262
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
263
+ def set_begin_index(self, begin_index: int = 0):
264
+ """
265
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
266
+
267
+ Args:
268
+ begin_index (`int`):
269
+ The begin index for the scheduler.
270
+ """
271
+ self._begin_index = begin_index
272
+
258
273
  def scale_model_input(
259
274
  self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
260
275
  ) -> torch.FloatTensor:
@@ -342,6 +357,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
342
357
 
343
358
  self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
344
359
  self._step_index = None
360
+ self._begin_index = None
345
361
  self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
346
362
 
347
363
  def _sigma_to_t(self, sigma, log_sigmas):
@@ -393,22 +409,27 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
393
409
  sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
394
410
  return sigmas
395
411
 
396
- def _init_step_index(self, timestep):
397
- if isinstance(timestep, torch.Tensor):
398
- timestep = timestep.to(self.timesteps.device)
412
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
413
+ if schedule_timesteps is None:
414
+ schedule_timesteps = self.timesteps
399
415
 
400
- index_candidates = (self.timesteps == timestep).nonzero()
416
+ indices = (schedule_timesteps == timestep).nonzero()
401
417
 
402
418
  # The sigma index that is taken for the **very** first `step`
403
419
  # is always the second index (or the last index if there is only 1)
404
420
  # This way we can ensure we don't accidentally skip a sigma in
405
421
  # case we start in the middle of the denoising schedule (e.g. for image-to-image)
406
- if len(index_candidates) > 1:
407
- step_index = index_candidates[1]
408
- else:
409
- step_index = index_candidates[0]
422
+ pos = 1 if len(indices) > 1 else 0
410
423
 
411
- self._step_index = step_index.item()
424
+ return indices[pos].item()
425
+
426
+ def _init_step_index(self, timestep):
427
+ if self.begin_index is None:
428
+ if isinstance(timestep, torch.Tensor):
429
+ timestep = timestep.to(self.timesteps.device)
430
+ self._step_index = self.index_for_timestep(timestep)
431
+ else:
432
+ self._step_index = self._begin_index
412
433
 
413
434
  def step(
414
435
  self,
@@ -538,7 +559,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
538
559
  schedule_timesteps = self.timesteps.to(original_samples.device)
539
560
  timesteps = timesteps.to(original_samples.device)
540
561
 
541
- step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
562
+ # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
563
+ if self.begin_index is None:
564
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
565
+ else:
566
+ step_indices = [self.begin_index] * timesteps.shape[0]
542
567
 
543
568
  sigma = sigmas[step_indices].flatten()
544
569
  while len(sigma.shape) < len(original_samples.shape):
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
1
+ # Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.