diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,7 @@
14
14
 
15
15
  # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
16
16
 
17
+ import math
17
18
  from typing import List, Optional, Tuple, Union
18
19
 
19
20
  import numpy as np
@@ -44,6 +45,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
44
45
  range is [0.2, 80.0].
45
46
  sigma_data (`float`, *optional*, defaults to 0.5):
46
47
  The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
48
+ sigma_schedule (`str`, *optional*, defaults to `karras`):
49
+ Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
50
+ (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
51
+ incorporated in this model: https://huggingface.co/stabilityai/cosxl.
47
52
  num_train_timesteps (`int`, defaults to 1000):
48
53
  The number of diffusion steps to train the model.
49
54
  solver_order (`int`, defaults to 2):
@@ -62,10 +67,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
62
67
  The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
63
68
  `algorithm_type="dpmsolver++"`.
64
69
  algorithm_type (`str`, defaults to `dpmsolver++`):
65
- Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. The
66
- `dpmsolver++` type implements the algorithms in the
67
- [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
68
- `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
70
+ Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver++` type implements
71
+ the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to
72
+ use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
69
73
  solver_type (`str`, defaults to `midpoint`):
70
74
  Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
71
75
  sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
@@ -77,8 +81,8 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
77
81
  richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
78
82
  steps, but sometimes may result in blurring.
79
83
  final_sigmas_type (`str`, defaults to `"zero"`):
80
- The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
81
- is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
84
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
85
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
82
86
  """
83
87
 
84
88
  _compatibles = []
@@ -90,6 +94,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
90
94
  sigma_min: float = 0.002,
91
95
  sigma_max: float = 80.0,
92
96
  sigma_data: float = 0.5,
97
+ sigma_schedule: str = "karras",
93
98
  num_train_timesteps: int = 1000,
94
99
  prediction_type: str = "epsilon",
95
100
  rho: float = 7.0,
@@ -114,7 +119,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
114
119
  if solver_type in ["logrho", "bh1", "bh2"]:
115
120
  self.register_to_config(solver_type="midpoint")
116
121
  else:
117
- raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
122
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
118
123
 
119
124
  if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
120
125
  raise ValueError(
@@ -122,7 +127,11 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
122
127
  )
123
128
 
124
129
  ramp = torch.linspace(0, 1, num_train_timesteps)
125
- sigmas = self._compute_sigmas(ramp)
130
+ if sigma_schedule == "karras":
131
+ sigmas = self._compute_karras_sigmas(ramp)
132
+ elif sigma_schedule == "exponential":
133
+ sigmas = self._compute_exponential_sigmas(ramp)
134
+
126
135
  self.timesteps = self.precondition_noise(sigmas)
127
136
 
128
137
  self.sigmas = self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
@@ -143,7 +152,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
143
152
  @property
144
153
  def step_index(self):
145
154
  """
146
- The index counter for current timestep. It will increae 1 after each scheduler step.
155
+ The index counter for current timestep. It will increase 1 after each scheduler step.
147
156
  """
148
157
  return self._step_index
149
158
 
@@ -197,21 +206,19 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
197
206
  return denoised
198
207
 
199
208
  # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
200
- def scale_model_input(
201
- self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
202
- ) -> torch.FloatTensor:
209
+ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
203
210
  """
204
211
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
205
212
  current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
206
213
 
207
214
  Args:
208
- sample (`torch.FloatTensor`):
215
+ sample (`torch.Tensor`):
209
216
  The input sample.
210
217
  timestep (`int`, *optional*):
211
218
  The current timestep in the diffusion chain.
212
219
 
213
220
  Returns:
214
- `torch.FloatTensor`:
221
+ `torch.Tensor`:
215
222
  A scaled input sample.
216
223
  """
217
224
  if self.step_index is None:
@@ -237,7 +244,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
237
244
  self.num_inference_steps = num_inference_steps
238
245
 
239
246
  ramp = np.linspace(0, 1, self.num_inference_steps)
240
- sigmas = self._compute_sigmas(ramp)
247
+ if self.config.sigma_schedule == "karras":
248
+ sigmas = self._compute_karras_sigmas(ramp)
249
+ elif self.config.sigma_schedule == "exponential":
250
+ sigmas = self._compute_exponential_sigmas(ramp)
241
251
 
242
252
  sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
243
253
  self.timesteps = self.precondition_noise(sigmas)
@@ -263,10 +273,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
263
273
  self._begin_index = None
264
274
  self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
265
275
 
266
- # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
267
- def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
276
+ # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
277
+ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
268
278
  """Constructs the noise schedule of Karras et al. (2022)."""
269
-
270
279
  sigma_min = sigma_min or self.config.sigma_min
271
280
  sigma_max = sigma_max or self.config.sigma_max
272
281
 
@@ -274,10 +283,22 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
274
283
  min_inv_rho = sigma_min ** (1 / rho)
275
284
  max_inv_rho = sigma_max ** (1 / rho)
276
285
  sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
286
+
287
+ return sigmas
288
+
289
+ # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
290
+ def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
291
+ """Implementation closely follows k-diffusion.
292
+
293
+ https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
294
+ """
295
+ sigma_min = sigma_min or self.config.sigma_min
296
+ sigma_max = sigma_max or self.config.sigma_max
297
+ sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
277
298
  return sigmas
278
299
 
279
300
  # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
280
- def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
301
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
281
302
  """
282
303
  "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
283
304
  prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -342,9 +363,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
342
363
 
343
364
  def convert_model_output(
344
365
  self,
345
- model_output: torch.FloatTensor,
346
- sample: torch.FloatTensor = None,
347
- ) -> torch.FloatTensor:
366
+ model_output: torch.Tensor,
367
+ sample: torch.Tensor = None,
368
+ ) -> torch.Tensor:
348
369
  """
349
370
  Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
350
371
  designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
@@ -358,13 +379,13 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
358
379
  </Tip>
359
380
 
360
381
  Args:
361
- model_output (`torch.FloatTensor`):
382
+ model_output (`torch.Tensor`):
362
383
  The direct output from the learned diffusion model.
363
- sample (`torch.FloatTensor`):
384
+ sample (`torch.Tensor`):
364
385
  A current instance of a sample created by the diffusion process.
365
386
 
366
387
  Returns:
367
- `torch.FloatTensor`:
388
+ `torch.Tensor`:
368
389
  The converted model output.
369
390
  """
370
391
  sigma = self.sigmas[self.step_index]
@@ -377,21 +398,21 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
377
398
 
378
399
  def dpm_solver_first_order_update(
379
400
  self,
380
- model_output: torch.FloatTensor,
381
- sample: torch.FloatTensor = None,
382
- noise: Optional[torch.FloatTensor] = None,
383
- ) -> torch.FloatTensor:
401
+ model_output: torch.Tensor,
402
+ sample: torch.Tensor = None,
403
+ noise: Optional[torch.Tensor] = None,
404
+ ) -> torch.Tensor:
384
405
  """
385
406
  One step for the first-order DPMSolver (equivalent to DDIM).
386
407
 
387
408
  Args:
388
- model_output (`torch.FloatTensor`):
409
+ model_output (`torch.Tensor`):
389
410
  The direct output from the learned diffusion model.
390
- sample (`torch.FloatTensor`):
411
+ sample (`torch.Tensor`):
391
412
  A current instance of a sample created by the diffusion process.
392
413
 
393
414
  Returns:
394
- `torch.FloatTensor`:
415
+ `torch.Tensor`:
395
416
  The sample tensor at the previous timestep.
396
417
  """
397
418
  sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
@@ -415,21 +436,21 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
415
436
 
416
437
  def multistep_dpm_solver_second_order_update(
417
438
  self,
418
- model_output_list: List[torch.FloatTensor],
419
- sample: torch.FloatTensor = None,
420
- noise: Optional[torch.FloatTensor] = None,
421
- ) -> torch.FloatTensor:
439
+ model_output_list: List[torch.Tensor],
440
+ sample: torch.Tensor = None,
441
+ noise: Optional[torch.Tensor] = None,
442
+ ) -> torch.Tensor:
422
443
  """
423
444
  One step for the second-order multistep DPMSolver.
424
445
 
425
446
  Args:
426
- model_output_list (`List[torch.FloatTensor]`):
447
+ model_output_list (`List[torch.Tensor]`):
427
448
  The direct outputs from learned diffusion model at current and latter timesteps.
428
- sample (`torch.FloatTensor`):
449
+ sample (`torch.Tensor`):
429
450
  A current instance of a sample created by the diffusion process.
430
451
 
431
452
  Returns:
432
- `torch.FloatTensor`:
453
+ `torch.Tensor`:
433
454
  The sample tensor at the previous timestep.
434
455
  """
435
456
  sigma_t, sigma_s0, sigma_s1 = (
@@ -486,20 +507,20 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
486
507
 
487
508
  def multistep_dpm_solver_third_order_update(
488
509
  self,
489
- model_output_list: List[torch.FloatTensor],
490
- sample: torch.FloatTensor = None,
491
- ) -> torch.FloatTensor:
510
+ model_output_list: List[torch.Tensor],
511
+ sample: torch.Tensor = None,
512
+ ) -> torch.Tensor:
492
513
  """
493
514
  One step for the third-order multistep DPMSolver.
494
515
 
495
516
  Args:
496
- model_output_list (`List[torch.FloatTensor]`):
517
+ model_output_list (`List[torch.Tensor]`):
497
518
  The direct outputs from learned diffusion model at current and latter timesteps.
498
- sample (`torch.FloatTensor`):
519
+ sample (`torch.Tensor`):
499
520
  A current instance of a sample created by diffusion process.
500
521
 
501
522
  Returns:
502
- `torch.FloatTensor`:
523
+ `torch.Tensor`:
503
524
  The sample tensor at the previous timestep.
504
525
  """
505
526
  sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
@@ -573,9 +594,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
573
594
 
574
595
  def step(
575
596
  self,
576
- model_output: torch.FloatTensor,
597
+ model_output: torch.Tensor,
577
598
  timestep: int,
578
- sample: torch.FloatTensor,
599
+ sample: torch.Tensor,
579
600
  generator=None,
580
601
  return_dict: bool = True,
581
602
  ) -> Union[SchedulerOutput, Tuple]:
@@ -584,11 +605,11 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
584
605
  the multistep DPMSolver.
585
606
 
586
607
  Args:
587
- model_output (`torch.FloatTensor`):
608
+ model_output (`torch.Tensor`):
588
609
  The direct output from learned diffusion model.
589
610
  timestep (`int`):
590
611
  The current discrete timestep in the diffusion chain.
591
- sample (`torch.FloatTensor`):
612
+ sample (`torch.Tensor`):
592
613
  A current instance of a sample created by the diffusion process.
593
614
  generator (`torch.Generator`, *optional*):
594
615
  A random number generator.
@@ -652,10 +673,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
652
673
  # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
653
674
  def add_noise(
654
675
  self,
655
- original_samples: torch.FloatTensor,
656
- noise: torch.FloatTensor,
657
- timesteps: torch.FloatTensor,
658
- ) -> torch.FloatTensor:
676
+ original_samples: torch.Tensor,
677
+ noise: torch.Tensor,
678
+ timesteps: torch.Tensor,
679
+ ) -> torch.Tensor:
659
680
  # Make sure sigmas and timesteps have the same device and dtype as original_samples
660
681
  sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
661
682
  if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -669,7 +690,11 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
669
690
  # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
670
691
  if self.begin_index is None:
671
692
  step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
693
+ elif self.step_index is not None:
694
+ # add_noise is called after first denoising step (for inpainting)
695
+ step_indices = [self.step_index] * timesteps.shape[0]
672
696
  else:
697
+ # add noise is called before first denoising step to create initial latent(img2img)
673
698
  step_indices = [self.begin_index] * timesteps.shape[0]
674
699
 
675
700
  sigma = sigmas[step_indices].flatten()
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ import math
15
16
  from dataclasses import dataclass
16
17
  from typing import Optional, Tuple, Union
17
18
 
@@ -34,16 +35,16 @@ class EDMEulerSchedulerOutput(BaseOutput):
34
35
  Output class for the scheduler's `step` function output.
35
36
 
36
37
  Args:
37
- prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
38
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
38
39
  Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
39
40
  denoising loop.
40
- pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
41
42
  The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
42
43
  `pred_original_sample` can be used to preview progress or for guidance.
43
44
  """
44
45
 
45
- prev_sample: torch.FloatTensor
46
- pred_original_sample: Optional[torch.FloatTensor] = None
46
+ prev_sample: torch.Tensor
47
+ pred_original_sample: Optional[torch.Tensor] = None
47
48
 
48
49
 
49
50
  class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
@@ -65,6 +66,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
65
66
  range is [0.2, 80.0].
66
67
  sigma_data (`float`, *optional*, defaults to 0.5):
67
68
  The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
69
+ sigma_schedule (`str`, *optional*, defaults to `karras`):
70
+ Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
71
+ (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
72
+ incorporated in this model: https://huggingface.co/stabilityai/cosxl.
68
73
  num_train_timesteps (`int`, defaults to 1000):
69
74
  The number of diffusion steps to train the model.
70
75
  prediction_type (`str`, defaults to `epsilon`, *optional*):
@@ -84,15 +89,23 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
84
89
  sigma_min: float = 0.002,
85
90
  sigma_max: float = 80.0,
86
91
  sigma_data: float = 0.5,
92
+ sigma_schedule: str = "karras",
87
93
  num_train_timesteps: int = 1000,
88
94
  prediction_type: str = "epsilon",
89
95
  rho: float = 7.0,
90
96
  ):
97
+ if sigma_schedule not in ["karras", "exponential"]:
98
+ raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
99
+
91
100
  # setable values
92
101
  self.num_inference_steps = None
93
102
 
94
103
  ramp = torch.linspace(0, 1, num_train_timesteps)
95
- sigmas = self._compute_sigmas(ramp)
104
+ if sigma_schedule == "karras":
105
+ sigmas = self._compute_karras_sigmas(ramp)
106
+ elif sigma_schedule == "exponential":
107
+ sigmas = self._compute_exponential_sigmas(ramp)
108
+
96
109
  self.timesteps = self.precondition_noise(sigmas)
97
110
 
98
111
  self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
@@ -111,7 +124,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
111
124
  @property
112
125
  def step_index(self):
113
126
  """
114
- The index counter for current timestep. It will increae 1 after each scheduler step.
127
+ The index counter for current timestep. It will increase 1 after each scheduler step.
115
128
  """
116
129
  return self._step_index
117
130
 
@@ -161,21 +174,19 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
161
174
 
162
175
  return denoised
163
176
 
164
- def scale_model_input(
165
- self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
166
- ) -> torch.FloatTensor:
177
+ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
167
178
  """
168
179
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
169
180
  current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
170
181
 
171
182
  Args:
172
- sample (`torch.FloatTensor`):
183
+ sample (`torch.Tensor`):
173
184
  The input sample.
174
185
  timestep (`int`, *optional*):
175
186
  The current timestep in the diffusion chain.
176
187
 
177
188
  Returns:
178
- `torch.FloatTensor`:
189
+ `torch.Tensor`:
179
190
  A scaled input sample.
180
191
  """
181
192
  if self.step_index is None:
@@ -200,7 +211,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
200
211
  self.num_inference_steps = num_inference_steps
201
212
 
202
213
  ramp = np.linspace(0, 1, self.num_inference_steps)
203
- sigmas = self._compute_sigmas(ramp)
214
+ if self.config.sigma_schedule == "karras":
215
+ sigmas = self._compute_karras_sigmas(ramp)
216
+ elif self.config.sigma_schedule == "exponential":
217
+ sigmas = self._compute_exponential_sigmas(ramp)
204
218
 
205
219
  sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
206
220
  self.timesteps = self.precondition_noise(sigmas)
@@ -211,9 +225,8 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
211
225
  self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
212
226
 
213
227
  # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
214
- def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
228
+ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
215
229
  """Constructs the noise schedule of Karras et al. (2022)."""
216
-
217
230
  sigma_min = sigma_min or self.config.sigma_min
218
231
  sigma_max = sigma_max or self.config.sigma_max
219
232
 
@@ -221,6 +234,17 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
221
234
  min_inv_rho = sigma_min ** (1 / rho)
222
235
  max_inv_rho = sigma_max ** (1 / rho)
223
236
  sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
237
+
238
+ return sigmas
239
+
240
+ def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
241
+ """Implementation closely follows k-diffusion.
242
+
243
+ https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
244
+ """
245
+ sigma_min = sigma_min or self.config.sigma_min
246
+ sigma_max = sigma_max or self.config.sigma_max
247
+ sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
224
248
  return sigmas
225
249
 
226
250
  # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
@@ -249,9 +273,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
249
273
 
250
274
  def step(
251
275
  self,
252
- model_output: torch.FloatTensor,
253
- timestep: Union[float, torch.FloatTensor],
254
- sample: torch.FloatTensor,
276
+ model_output: torch.Tensor,
277
+ timestep: Union[float, torch.Tensor],
278
+ sample: torch.Tensor,
255
279
  s_churn: float = 0.0,
256
280
  s_tmin: float = 0.0,
257
281
  s_tmax: float = float("inf"),
@@ -264,11 +288,11 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
264
288
  process from the learned model outputs (most often the predicted noise).
265
289
 
266
290
  Args:
267
- model_output (`torch.FloatTensor`):
291
+ model_output (`torch.Tensor`):
268
292
  The direct output from learned diffusion model.
269
293
  timestep (`float`):
270
294
  The current discrete timestep in the diffusion chain.
271
- sample (`torch.FloatTensor`):
295
+ sample (`torch.Tensor`):
272
296
  A current instance of a sample created by the diffusion process.
273
297
  s_churn (`float`):
274
298
  s_tmin (`float`):
@@ -278,8 +302,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
278
302
  generator (`torch.Generator`, *optional*):
279
303
  A random number generator.
280
304
  return_dict (`bool`):
281
- Whether or not to return a [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or
282
- tuple.
305
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or tuple.
283
306
 
284
307
  Returns:
285
308
  [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or `tuple`:
@@ -287,11 +310,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
287
310
  returned, otherwise a tuple is returned where the first element is the sample tensor.
288
311
  """
289
312
 
290
- if (
291
- isinstance(timestep, int)
292
- or isinstance(timestep, torch.IntTensor)
293
- or isinstance(timestep, torch.LongTensor)
294
- ):
313
+ if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
295
314
  raise ValueError(
296
315
  (
297
316
  "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
@@ -350,10 +369,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
350
369
  # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
351
370
  def add_noise(
352
371
  self,
353
- original_samples: torch.FloatTensor,
354
- noise: torch.FloatTensor,
355
- timesteps: torch.FloatTensor,
356
- ) -> torch.FloatTensor:
372
+ original_samples: torch.Tensor,
373
+ noise: torch.Tensor,
374
+ timesteps: torch.Tensor,
375
+ ) -> torch.Tensor:
357
376
  # Make sure sigmas and timesteps have the same device and dtype as original_samples
358
377
  sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
359
378
  if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -367,7 +386,11 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
367
386
  # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
368
387
  if self.begin_index is None:
369
388
  step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
389
+ elif self.step_index is not None:
390
+ # add_noise is called after first denoising step (for inpainting)
391
+ step_indices = [self.step_index] * timesteps.shape[0]
370
392
  else:
393
+ # add noise is called before first denoising step to create initial latent(img2img)
371
394
  step_indices = [self.begin_index] * timesteps.shape[0]
372
395
 
373
396
  sigma = sigmas[step_indices].flatten()