diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -35,16 +35,16 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
35
35
  Output class for the scheduler's `step` function output.
36
36
 
37
37
  Args:
38
- prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
38
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
39
39
  Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
40
40
  denoising loop.
41
- pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
42
42
  The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
43
43
  `pred_original_sample` can be used to preview progress or for guidance.
44
44
  """
45
45
 
46
- prev_sample: torch.FloatTensor
47
- pred_original_sample: Optional[torch.FloatTensor] = None
46
+ prev_sample: torch.Tensor
47
+ pred_original_sample: Optional[torch.Tensor] = None
48
48
 
49
49
 
50
50
  # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -82,7 +82,7 @@ def betas_for_alpha_bar(
82
82
  return math.exp(t * -12.0)
83
83
 
84
84
  else:
85
- raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
85
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
86
86
 
87
87
  betas = []
88
88
  for i in range(num_diffusion_timesteps):
@@ -99,11 +99,11 @@ def rescale_zero_terminal_snr(betas):
99
99
 
100
100
 
101
101
  Args:
102
- betas (`torch.FloatTensor`):
102
+ betas (`torch.Tensor`):
103
103
  the betas that the scheduler is being initialized with.
104
104
 
105
105
  Returns:
106
- `torch.FloatTensor`: rescaled betas with zero terminal SNR
106
+ `torch.Tensor`: rescaled betas with zero terminal SNR
107
107
  """
108
108
  # Convert betas to alphas_bar_sqrt
109
109
  alphas = 1.0 - betas
@@ -190,7 +190,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
190
190
  # Glide cosine schedule
191
191
  self.betas = betas_for_alpha_bar(num_train_timesteps)
192
192
  else:
193
- raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
193
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
194
194
 
195
195
  if rescale_betas_zero_snr:
196
196
  self.betas = rescale_zero_terminal_snr(self.betas)
@@ -228,7 +228,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
228
228
  @property
229
229
  def step_index(self):
230
230
  """
231
- The index counter for current timestep. It will increae 1 after each scheduler step.
231
+ The index counter for current timestep. It will increase 1 after each scheduler step.
232
232
  """
233
233
  return self._step_index
234
234
 
@@ -250,21 +250,19 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
250
250
  """
251
251
  self._begin_index = begin_index
252
252
 
253
- def scale_model_input(
254
- self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
255
- ) -> torch.FloatTensor:
253
+ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
256
254
  """
257
255
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
258
256
  current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
259
257
 
260
258
  Args:
261
- sample (`torch.FloatTensor`):
259
+ sample (`torch.Tensor`):
262
260
  The input sample.
263
261
  timestep (`int`, *optional*):
264
262
  The current timestep in the diffusion chain.
265
263
 
266
264
  Returns:
267
- `torch.FloatTensor`:
265
+ `torch.Tensor`:
268
266
  A scaled input sample.
269
267
  """
270
268
 
@@ -346,9 +344,9 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
346
344
 
347
345
  def step(
348
346
  self,
349
- model_output: torch.FloatTensor,
350
- timestep: Union[float, torch.FloatTensor],
351
- sample: torch.FloatTensor,
347
+ model_output: torch.Tensor,
348
+ timestep: Union[float, torch.Tensor],
349
+ sample: torch.Tensor,
352
350
  generator: Optional[torch.Generator] = None,
353
351
  return_dict: bool = True,
354
352
  ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
@@ -357,11 +355,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
357
355
  process from the learned model outputs (most often the predicted noise).
358
356
 
359
357
  Args:
360
- model_output (`torch.FloatTensor`):
358
+ model_output (`torch.Tensor`):
361
359
  The direct output from learned diffusion model.
362
360
  timestep (`float`):
363
361
  The current discrete timestep in the diffusion chain.
364
- sample (`torch.FloatTensor`):
362
+ sample (`torch.Tensor`):
365
363
  A current instance of a sample created by the diffusion process.
366
364
  generator (`torch.Generator`, *optional*):
367
365
  A random number generator.
@@ -377,11 +375,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
377
375
 
378
376
  """
379
377
 
380
- if (
381
- isinstance(timestep, int)
382
- or isinstance(timestep, torch.IntTensor)
383
- or isinstance(timestep, torch.LongTensor)
384
- ):
378
+ if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
385
379
  raise ValueError(
386
380
  (
387
381
  "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
@@ -450,10 +444,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
450
444
  # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
451
445
  def add_noise(
452
446
  self,
453
- original_samples: torch.FloatTensor,
454
- noise: torch.FloatTensor,
455
- timesteps: torch.FloatTensor,
456
- ) -> torch.FloatTensor:
447
+ original_samples: torch.Tensor,
448
+ noise: torch.Tensor,
449
+ timesteps: torch.Tensor,
450
+ ) -> torch.Tensor:
457
451
  # Make sure sigmas and timesteps have the same device and dtype as original_samples
458
452
  sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
459
453
  if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -467,7 +461,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
467
461
  # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
468
462
  if self.begin_index is None:
469
463
  step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
464
+ elif self.step_index is not None:
465
+ # add_noise is called after first denoising step (for inpainting)
466
+ step_indices = [self.step_index] * timesteps.shape[0]
470
467
  else:
468
+ # add noise is called before first denoising step to create initial latent(img2img)
471
469
  step_indices = [self.begin_index] * timesteps.shape[0]
472
470
 
473
471
  sigma = sigmas[step_indices].flatten()
@@ -35,16 +35,16 @@ class EulerDiscreteSchedulerOutput(BaseOutput):
35
35
  Output class for the scheduler's `step` function output.
36
36
 
37
37
  Args:
38
- prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
38
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
39
39
  Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
40
40
  denoising loop.
41
- pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
42
42
  The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
43
43
  `pred_original_sample` can be used to preview progress or for guidance.
44
44
  """
45
45
 
46
- prev_sample: torch.FloatTensor
47
- pred_original_sample: Optional[torch.FloatTensor] = None
46
+ prev_sample: torch.Tensor
47
+ pred_original_sample: Optional[torch.Tensor] = None
48
48
 
49
49
 
50
50
  # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -82,7 +82,7 @@ def betas_for_alpha_bar(
82
82
  return math.exp(t * -12.0)
83
83
 
84
84
  else:
85
- raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
85
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
86
86
 
87
87
  betas = []
88
88
  for i in range(num_diffusion_timesteps):
@@ -99,11 +99,11 @@ def rescale_zero_terminal_snr(betas):
99
99
 
100
100
 
101
101
  Args:
102
- betas (`torch.FloatTensor`):
102
+ betas (`torch.Tensor`):
103
103
  the betas that the scheduler is being initialized with.
104
104
 
105
105
  Returns:
106
- `torch.FloatTensor`: rescaled betas with zero terminal SNR
106
+ `torch.Tensor`: rescaled betas with zero terminal SNR
107
107
  """
108
108
  # Convert betas to alphas_bar_sqrt
109
109
  alphas = 1.0 - betas
@@ -167,6 +167,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
167
167
  Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
168
168
  dark samples instead of limiting it to samples with medium brightness. Loosely related to
169
169
  [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
170
+ final_sigmas_type (`str`, defaults to `"zero"`):
171
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
172
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
170
173
  """
171
174
 
172
175
  _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -189,6 +192,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
189
192
  timestep_type: str = "discrete", # can be "discrete" or "continuous"
190
193
  steps_offset: int = 0,
191
194
  rescale_betas_zero_snr: bool = False,
195
+ final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
192
196
  ):
193
197
  if trained_betas is not None:
194
198
  self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -201,7 +205,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
201
205
  # Glide cosine schedule
202
206
  self.betas = betas_for_alpha_bar(num_train_timesteps)
203
207
  else:
204
- raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
208
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
205
209
 
206
210
  if rescale_betas_zero_snr:
207
211
  self.betas = rescale_zero_terminal_snr(self.betas)
@@ -248,7 +252,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
248
252
  @property
249
253
  def step_index(self):
250
254
  """
251
- The index counter for current timestep. It will increae 1 after each scheduler step.
255
+ The index counter for current timestep. It will increase 1 after each scheduler step.
252
256
  """
253
257
  return self._step_index
254
258
 
@@ -270,21 +274,19 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
270
274
  """
271
275
  self._begin_index = begin_index
272
276
 
273
- def scale_model_input(
274
- self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
275
- ) -> torch.FloatTensor:
277
+ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
276
278
  """
277
279
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
278
280
  current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
279
281
 
280
282
  Args:
281
- sample (`torch.FloatTensor`):
283
+ sample (`torch.Tensor`):
282
284
  The input sample.
283
285
  timestep (`int`, *optional*):
284
286
  The current timestep in the diffusion chain.
285
287
 
286
288
  Returns:
287
- `torch.FloatTensor`:
289
+ `torch.Tensor`:
288
290
  A scaled input sample.
289
291
  """
290
292
  if self.step_index is None:
@@ -296,7 +298,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
296
298
  self.is_scale_input_called = True
297
299
  return sample
298
300
 
299
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
301
+ def set_timesteps(
302
+ self,
303
+ num_inference_steps: int = None,
304
+ device: Union[str, torch.device] = None,
305
+ timesteps: Optional[List[int]] = None,
306
+ sigmas: Optional[List[float]] = None,
307
+ ):
300
308
  """
301
309
  Sets the discrete timesteps used for the diffusion chain (to be run before inference).
302
310
 
@@ -305,60 +313,111 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
305
313
  The number of diffusion steps used when generating samples with a pre-trained model.
306
314
  device (`str` or `torch.device`, *optional*):
307
315
  The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
316
+ timesteps (`List[int]`, *optional*):
317
+ Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
318
+ based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
319
+ must be `None`, and `timestep_spacing` attribute will be ignored.
320
+ sigmas (`List[float]`, *optional*):
321
+ Custom sigmas used to support arbitrary timesteps schedule. If `None`, timesteps and sigmas
322
+ will be generated based on the relevant scheduler attributes. If `sigmas` is passed,
323
+ `num_inference_steps` and `timesteps` must be `None`, and the timesteps will be generated based on the
324
+ custom sigmas schedule.
308
325
  """
309
- self.num_inference_steps = num_inference_steps
310
326
 
311
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
312
- if self.config.timestep_spacing == "linspace":
313
- timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
314
- ::-1
315
- ].copy()
316
- elif self.config.timestep_spacing == "leading":
317
- step_ratio = self.config.num_train_timesteps // self.num_inference_steps
318
- # creates integer timesteps by multiplying by ratio
319
- # casting to int to avoid issues when num_inference_step is power of 3
320
- timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
321
- timesteps += self.config.steps_offset
322
- elif self.config.timestep_spacing == "trailing":
323
- step_ratio = self.config.num_train_timesteps / self.num_inference_steps
324
- # creates integer timesteps by multiplying by ratio
325
- # casting to int to avoid issues when num_inference_step is power of 3
326
- timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
327
- timesteps -= 1
328
- else:
327
+ if timesteps is not None and sigmas is not None:
328
+ raise ValueError("Only one of `timesteps` or `sigmas` should be set.")
329
+ if num_inference_steps is None and timesteps is None and sigmas is None:
330
+ raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps` or `sigmas.")
331
+ if num_inference_steps is not None and (timesteps is not None or sigmas is not None):
332
+ raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.")
333
+ if timesteps is not None and self.config.use_karras_sigmas:
334
+ raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
335
+ if (
336
+ timesteps is not None
337
+ and self.config.timestep_type == "continuous"
338
+ and self.config.prediction_type == "v_prediction"
339
+ ):
329
340
  raise ValueError(
330
- f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
341
+ "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
331
342
  )
332
343
 
333
- sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
334
- log_sigmas = np.log(sigmas)
344
+ if num_inference_steps is None:
345
+ num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1
346
+ self.num_inference_steps = num_inference_steps
347
+
348
+ if sigmas is not None:
349
+ log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
350
+ sigmas = np.array(sigmas).astype(np.float32)
351
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])
335
352
 
336
- if self.config.interpolation_type == "linear":
337
- sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
338
- elif self.config.interpolation_type == "log_linear":
339
- sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
340
353
  else:
341
- raise ValueError(
342
- f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
343
- " 'linear' or 'log_linear'"
344
- )
345
-
346
- if self.use_karras_sigmas:
347
- sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
348
- timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
354
+ if timesteps is not None:
355
+ timesteps = np.array(timesteps).astype(np.float32)
356
+ else:
357
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
358
+ if self.config.timestep_spacing == "linspace":
359
+ timesteps = np.linspace(
360
+ 0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
361
+ )[::-1].copy()
362
+ elif self.config.timestep_spacing == "leading":
363
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
364
+ # creates integer timesteps by multiplying by ratio
365
+ # casting to int to avoid issues when num_inference_step is power of 3
366
+ timesteps = (
367
+ (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
368
+ )
369
+ timesteps += self.config.steps_offset
370
+ elif self.config.timestep_spacing == "trailing":
371
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
372
+ # creates integer timesteps by multiplying by ratio
373
+ # casting to int to avoid issues when num_inference_step is power of 3
374
+ timesteps = (
375
+ (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
376
+ )
377
+ timesteps -= 1
378
+ else:
379
+ raise ValueError(
380
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
381
+ )
382
+
383
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
384
+ log_sigmas = np.log(sigmas)
385
+ if self.config.interpolation_type == "linear":
386
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
387
+ elif self.config.interpolation_type == "log_linear":
388
+ sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
389
+ else:
390
+ raise ValueError(
391
+ f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
392
+ " 'linear' or 'log_linear'"
393
+ )
394
+
395
+ if self.config.use_karras_sigmas:
396
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
397
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
398
+
399
+ if self.config.final_sigmas_type == "sigma_min":
400
+ sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
401
+ elif self.config.final_sigmas_type == "zero":
402
+ sigma_last = 0
403
+ else:
404
+ raise ValueError(
405
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
406
+ )
407
+
408
+ sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
349
409
 
350
410
  sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
351
411
 
352
412
  # TODO: Support the full EDM scalings for all prediction types and timestep types
353
413
  if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
354
- self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device)
414
+ self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device)
355
415
  else:
356
416
  self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
357
417
 
358
- self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
359
418
  self._step_index = None
360
419
  self._begin_index = None
361
- self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
420
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
362
421
 
363
422
  def _sigma_to_t(self, sigma, log_sigmas):
364
423
  # get log sigma
@@ -384,7 +443,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
384
443
  return t
385
444
 
386
445
  # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
387
- def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
446
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
388
447
  """Constructs the noise schedule of Karras et al. (2022)."""
389
448
 
390
449
  # Hack to make sure that other schedulers which copy this function don't break
@@ -433,9 +492,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
433
492
 
434
493
  def step(
435
494
  self,
436
- model_output: torch.FloatTensor,
437
- timestep: Union[float, torch.FloatTensor],
438
- sample: torch.FloatTensor,
495
+ model_output: torch.Tensor,
496
+ timestep: Union[float, torch.Tensor],
497
+ sample: torch.Tensor,
439
498
  s_churn: float = 0.0,
440
499
  s_tmin: float = 0.0,
441
500
  s_tmax: float = float("inf"),
@@ -448,11 +507,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
448
507
  process from the learned model outputs (most often the predicted noise).
449
508
 
450
509
  Args:
451
- model_output (`torch.FloatTensor`):
510
+ model_output (`torch.Tensor`):
452
511
  The direct output from learned diffusion model.
453
512
  timestep (`float`):
454
513
  The current discrete timestep in the diffusion chain.
455
- sample (`torch.FloatTensor`):
514
+ sample (`torch.Tensor`):
456
515
  A current instance of a sample created by the diffusion process.
457
516
  s_churn (`float`):
458
517
  s_tmin (`float`):
@@ -471,11 +530,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
471
530
  returned, otherwise a tuple is returned where the first element is the sample tensor.
472
531
  """
473
532
 
474
- if (
475
- isinstance(timestep, int)
476
- or isinstance(timestep, torch.IntTensor)
477
- or isinstance(timestep, torch.LongTensor)
478
- ):
533
+ if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
479
534
  raise ValueError(
480
535
  (
481
536
  "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
@@ -545,10 +600,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
545
600
 
546
601
  def add_noise(
547
602
  self,
548
- original_samples: torch.FloatTensor,
549
- noise: torch.FloatTensor,
550
- timesteps: torch.FloatTensor,
551
- ) -> torch.FloatTensor:
603
+ original_samples: torch.Tensor,
604
+ noise: torch.Tensor,
605
+ timesteps: torch.Tensor,
606
+ ) -> torch.Tensor:
552
607
  # Make sure sigmas and timesteps have the same device and dtype as original_samples
553
608
  sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
554
609
  if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -562,7 +617,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
562
617
  # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
563
618
  if self.begin_index is None:
564
619
  step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
620
+ elif self.step_index is not None:
621
+ # add_noise is called after first denoising step (for inpainting)
622
+ step_indices = [self.step_index] * timesteps.shape[0]
565
623
  else:
624
+ # add_noise is called before the first denoising step to create the initial latent (img2img)
566
625
  step_indices = [self.begin_index] * timesteps.shape[0]
567
626
 
568
627
  sigma = sigmas[step_indices].flatten()
@@ -572,5 +631,42 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
572
631
  noisy_samples = original_samples + noise * sigma
573
632
  return noisy_samples
574
633
 
634
+ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
635
+ if (
636
+ isinstance(timesteps, int)
637
+ or isinstance(timesteps, torch.IntTensor)
638
+ or isinstance(timesteps, torch.LongTensor)
639
+ ):
640
+ raise ValueError(
641
+ (
642
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
643
+ " `EulerDiscreteScheduler.get_velocity()` is not supported. Make sure to pass"
644
+ " one of the `scheduler.timesteps` as a timestep."
645
+ ),
646
+ )
647
+
648
+ if sample.device.type == "mps" and torch.is_floating_point(timesteps):
649
+ # mps does not support float64
650
+ schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
651
+ timesteps = timesteps.to(sample.device, dtype=torch.float32)
652
+ else:
653
+ schedule_timesteps = self.timesteps.to(sample.device)
654
+ timesteps = timesteps.to(sample.device)
655
+
656
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
657
+ alphas_cumprod = self.alphas_cumprod.to(sample)
658
+ sqrt_alpha_prod = alphas_cumprod[step_indices] ** 0.5
659
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
660
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
661
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
662
+
663
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[step_indices]) ** 0.5
664
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
665
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
666
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
667
+
668
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
669
+ return velocity
670
+
575
671
  def __len__(self):
576
672
  return self.config.num_train_timesteps