diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270) hide show
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -110,7 +110,7 @@ def betas_for_alpha_bar(
110
110
  return math.exp(t * -12.0)
111
111
 
112
112
  else:
113
- raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
113
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
114
114
 
115
115
  betas = []
116
116
  for i in range(num_diffusion_timesteps):
@@ -184,7 +184,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
184
184
  # Glide cosine schedule
185
185
  self.betas = betas_for_alpha_bar(num_train_timesteps)
186
186
  else:
187
- raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
187
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
188
188
 
189
189
  self.alphas = 1.0 - self.betas
190
190
  self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -233,7 +233,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
233
233
  @property
234
234
  def step_index(self):
235
235
  """
236
- The index counter for current timestep. It will increae 1 after each scheduler step.
236
+ The index counter for current timestep. It will increase 1 after each scheduler step.
237
237
  """
238
238
  return self._step_index
239
239
 
@@ -257,21 +257,21 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
257
257
 
258
258
  def scale_model_input(
259
259
  self,
260
- sample: torch.FloatTensor,
261
- timestep: Union[float, torch.FloatTensor],
262
- ) -> torch.FloatTensor:
260
+ sample: torch.Tensor,
261
+ timestep: Union[float, torch.Tensor],
262
+ ) -> torch.Tensor:
263
263
  """
264
264
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
265
265
  current timestep.
266
266
 
267
267
  Args:
268
- sample (`torch.FloatTensor`):
268
+ sample (`torch.Tensor`):
269
269
  The input sample.
270
270
  timestep (`int`, *optional*):
271
271
  The current timestep in the diffusion chain.
272
272
 
273
273
  Returns:
274
- `torch.FloatTensor`:
274
+ `torch.Tensor`:
275
275
  A scaled input sample.
276
276
  """
277
277
  if self.step_index is None:
@@ -325,7 +325,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
325
325
  log_sigmas = np.log(sigmas)
326
326
  sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
327
327
 
328
- if self.use_karras_sigmas:
328
+ if self.config.use_karras_sigmas:
329
329
  sigmas = self._convert_to_karras(in_sigmas=sigmas)
330
330
  timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
331
331
 
@@ -395,7 +395,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
395
395
  return t
396
396
 
397
397
  # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras
398
- def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
398
+ def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
399
399
  """Constructs the noise schedule of Karras et al. (2022)."""
400
400
 
401
401
  sigma_min: float = in_sigmas[-1].item()
@@ -414,9 +414,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
414
414
 
415
415
  def step(
416
416
  self,
417
- model_output: Union[torch.FloatTensor, np.ndarray],
418
- timestep: Union[float, torch.FloatTensor],
419
- sample: Union[torch.FloatTensor, np.ndarray],
417
+ model_output: Union[torch.Tensor, np.ndarray],
418
+ timestep: Union[float, torch.Tensor],
419
+ sample: Union[torch.Tensor, np.ndarray],
420
420
  return_dict: bool = True,
421
421
  s_noise: float = 1.0,
422
422
  ) -> Union[SchedulerOutput, Tuple]:
@@ -425,11 +425,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
425
425
  process from the learned model outputs (most often the predicted noise).
426
426
 
427
427
  Args:
428
- model_output (`torch.FloatTensor` or `np.ndarray`):
428
+ model_output (`torch.Tensor` or `np.ndarray`):
429
429
  The direct output from learned diffusion model.
430
- timestep (`float` or `torch.FloatTensor`):
430
+ timestep (`float` or `torch.Tensor`):
431
431
  The current discrete timestep in the diffusion chain.
432
- sample (`torch.FloatTensor` or `np.ndarray`):
432
+ sample (`torch.Tensor` or `np.ndarray`):
433
433
  A current instance of a sample created by the diffusion process.
434
434
  return_dict (`bool`, *optional*, defaults to `True`):
435
435
  Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
@@ -450,10 +450,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
450
450
  self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed)
451
451
 
452
452
  # Define functions to compute sigma and t from each other
453
- def sigma_fn(_t: torch.FloatTensor) -> torch.FloatTensor:
453
+ def sigma_fn(_t: torch.Tensor) -> torch.Tensor:
454
454
  return _t.neg().exp()
455
455
 
456
- def t_fn(_sigma: torch.FloatTensor) -> torch.FloatTensor:
456
+ def t_fn(_sigma: torch.Tensor) -> torch.Tensor:
457
457
  return _sigma.log().neg()
458
458
 
459
459
  if self.state_in_first_order:
@@ -526,10 +526,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
526
526
  # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
527
527
  def add_noise(
528
528
  self,
529
- original_samples: torch.FloatTensor,
530
- noise: torch.FloatTensor,
531
- timesteps: torch.FloatTensor,
532
- ) -> torch.FloatTensor:
529
+ original_samples: torch.Tensor,
530
+ noise: torch.Tensor,
531
+ timesteps: torch.Tensor,
532
+ ) -> torch.Tensor:
533
533
  # Make sure sigmas and timesteps have the same device and dtype as original_samples
534
534
  sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
535
535
  if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -543,7 +543,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
543
543
  # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
544
544
  if self.begin_index is None:
545
545
  step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
546
+ elif self.step_index is not None:
547
+ # add_noise is called after first denoising step (for inpainting)
548
+ step_indices = [self.step_index] * timesteps.shape[0]
546
549
  else:
550
+ # add noise is called before first denoising step to create initial latent(img2img)
547
551
  step_indices = [self.begin_index] * timesteps.shape[0]
548
552
 
549
553
  sigma = sigmas[step_indices].flatten()
@@ -63,7 +63,7 @@ def betas_for_alpha_bar(
63
63
  return math.exp(t * -12.0)
64
64
 
65
65
  else:
66
- raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
66
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
67
67
 
68
68
  betas = []
69
69
  for i in range(num_diffusion_timesteps):
@@ -108,11 +108,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
108
108
  The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
109
109
  `algorithm_type="dpmsolver++"`.
110
110
  algorithm_type (`str`, defaults to `dpmsolver++`):
111
- Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The
112
- `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
113
- paper, and the `dpmsolver++` type implements the algorithms in the
114
- [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
115
- `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
111
+ Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
112
+ algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type
113
+ implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is
114
+ recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in
115
+ Stable Diffusion.
116
116
  solver_type (`str`, defaults to `midpoint`):
117
117
  Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
118
118
  sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
@@ -123,8 +123,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
123
123
  Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
124
124
  the sigmas are determined according to a sequence of noise levels {σi}.
125
125
  final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
126
- The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
127
- is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
126
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
127
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
128
128
  lambda_min_clipped (`float`, defaults to `-inf`):
129
129
  Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
130
130
  cosine (`squaredcos_cap_v2`) noise schedule.
@@ -172,7 +172,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
172
172
  # Glide cosine schedule
173
173
  self.betas = betas_for_alpha_bar(num_train_timesteps)
174
174
  else:
175
- raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
175
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
176
176
 
177
177
  self.alphas = 1.0 - self.betas
178
178
  self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -190,12 +190,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
190
190
  if algorithm_type == "deis":
191
191
  self.register_to_config(algorithm_type="dpmsolver++")
192
192
  else:
193
- raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
193
+ raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
194
194
  if solver_type not in ["midpoint", "heun"]:
195
195
  if solver_type in ["logrho", "bh1", "bh2"]:
196
196
  self.register_to_config(solver_type="midpoint")
197
197
  else:
198
- raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
198
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
199
199
 
200
200
  if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
201
201
  raise ValueError(
@@ -252,7 +252,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
252
252
  @property
253
253
  def step_index(self):
254
254
  """
255
- The index counter for current timestep. It will increae 1 after each scheduler step.
255
+ The index counter for current timestep. It will increase 1 after each scheduler step.
256
256
  """
257
257
  return self._step_index
258
258
 
@@ -274,7 +274,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
274
274
  """
275
275
  self._begin_index = begin_index
276
276
 
277
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
277
+ def set_timesteps(
278
+ self,
279
+ num_inference_steps: int = None,
280
+ device: Union[str, torch.device] = None,
281
+ timesteps: Optional[List[int]] = None,
282
+ ):
278
283
  """
279
284
  Sets the discrete timesteps used for the diffusion chain (to be run before inference).
280
285
 
@@ -283,17 +288,33 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
283
288
  The number of diffusion steps used when generating samples with a pre-trained model.
284
289
  device (`str` or `torch.device`, *optional*):
285
290
  The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
291
+ timesteps (`List[int]`, *optional*):
292
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
293
+ timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
294
+ passed, `num_inference_steps` must be `None`.
286
295
  """
296
+ if num_inference_steps is None and timesteps is None:
297
+ raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
298
+ if num_inference_steps is not None and timesteps is not None:
299
+ raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
300
+ if timesteps is not None and self.config.use_karras_sigmas:
301
+ raise ValueError("Cannot use `timesteps` when `config.use_karras_sigmas=True`.")
302
+
303
+ num_inference_steps = num_inference_steps or len(timesteps)
287
304
  self.num_inference_steps = num_inference_steps
288
- # Clipping the minimum of all lambda(t) for numerical stability.
289
- # This is critical for cosine (squaredcos_cap_v2) noise schedule.
290
- clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
291
- timesteps = (
292
- np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
293
- .round()[::-1][:-1]
294
- .copy()
295
- .astype(np.int64)
296
- )
305
+
306
+ if timesteps is not None:
307
+ timesteps = np.array(timesteps).astype(np.int64)
308
+ else:
309
+ # Clipping the minimum of all lambda(t) for numerical stability.
310
+ # This is critical for cosine (squaredcos_cap_v2) noise schedule.
311
+ clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
312
+ timesteps = (
313
+ np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
314
+ .round()[::-1][:-1]
315
+ .copy()
316
+ .astype(np.int64)
317
+ )
297
318
 
298
319
  sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
299
320
  if self.config.use_karras_sigmas:
@@ -340,7 +361,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
340
361
  self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
341
362
 
342
363
  # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
343
- def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
364
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
344
365
  """
345
366
  "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
346
367
  prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -405,7 +426,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
405
426
  return alpha_t, sigma_t
406
427
 
407
428
  # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
408
- def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
429
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
409
430
  """Constructs the noise schedule of Karras et al. (2022)."""
410
431
 
411
432
  # Hack to make sure that other schedulers which copy this function don't break
@@ -432,11 +453,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
432
453
 
433
454
  def convert_model_output(
434
455
  self,
435
- model_output: torch.FloatTensor,
456
+ model_output: torch.Tensor,
436
457
  *args,
437
- sample: torch.FloatTensor = None,
458
+ sample: torch.Tensor = None,
438
459
  **kwargs,
439
- ) -> torch.FloatTensor:
460
+ ) -> torch.Tensor:
440
461
  """
441
462
  Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
442
463
  designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
@@ -450,13 +471,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
450
471
  </Tip>
451
472
 
452
473
  Args:
453
- model_output (`torch.FloatTensor`):
474
+ model_output (`torch.Tensor`):
454
475
  The direct output from the learned diffusion model.
455
- sample (`torch.FloatTensor`):
476
+ sample (`torch.Tensor`):
456
477
  A current instance of a sample created by the diffusion process.
457
478
 
458
479
  Returns:
459
- `torch.FloatTensor`:
480
+ `torch.Tensor`:
460
481
  The converted model output.
461
482
  """
462
483
  timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
@@ -521,26 +542,26 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
521
542
 
522
543
  def dpm_solver_first_order_update(
523
544
  self,
524
- model_output: torch.FloatTensor,
545
+ model_output: torch.Tensor,
525
546
  *args,
526
- sample: torch.FloatTensor = None,
547
+ sample: torch.Tensor = None,
527
548
  **kwargs,
528
- ) -> torch.FloatTensor:
549
+ ) -> torch.Tensor:
529
550
  """
530
551
  One step for the first-order DPMSolver (equivalent to DDIM).
531
552
 
532
553
  Args:
533
- model_output (`torch.FloatTensor`):
554
+ model_output (`torch.Tensor`):
534
555
  The direct output from the learned diffusion model.
535
556
  timestep (`int`):
536
557
  The current discrete timestep in the diffusion chain.
537
558
  prev_timestep (`int`):
538
559
  The previous discrete timestep in the diffusion chain.
539
- sample (`torch.FloatTensor`):
560
+ sample (`torch.Tensor`):
540
561
  A current instance of a sample created by the diffusion process.
541
562
 
542
563
  Returns:
543
- `torch.FloatTensor`:
564
+ `torch.Tensor`:
544
565
  The sample tensor at the previous timestep.
545
566
  """
546
567
  timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
@@ -577,27 +598,27 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
577
598
 
578
599
  def singlestep_dpm_solver_second_order_update(
579
600
  self,
580
- model_output_list: List[torch.FloatTensor],
601
+ model_output_list: List[torch.Tensor],
581
602
  *args,
582
- sample: torch.FloatTensor = None,
603
+ sample: torch.Tensor = None,
583
604
  **kwargs,
584
- ) -> torch.FloatTensor:
605
+ ) -> torch.Tensor:
585
606
  """
586
607
  One step for the second-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
587
608
  time `timestep_list[-2]`.
588
609
 
589
610
  Args:
590
- model_output_list (`List[torch.FloatTensor]`):
611
+ model_output_list (`List[torch.Tensor]`):
591
612
  The direct outputs from learned diffusion model at current and latter timesteps.
592
613
  timestep (`int`):
593
614
  The current and latter discrete timestep in the diffusion chain.
594
615
  prev_timestep (`int`):
595
616
  The previous discrete timestep in the diffusion chain.
596
- sample (`torch.FloatTensor`):
617
+ sample (`torch.Tensor`):
597
618
  A current instance of a sample created by the diffusion process.
598
619
 
599
620
  Returns:
600
- `torch.FloatTensor`:
621
+ `torch.Tensor`:
601
622
  The sample tensor at the previous timestep.
602
623
  """
603
624
  timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
@@ -671,27 +692,27 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
671
692
 
672
693
  def singlestep_dpm_solver_third_order_update(
673
694
  self,
674
- model_output_list: List[torch.FloatTensor],
695
+ model_output_list: List[torch.Tensor],
675
696
  *args,
676
- sample: torch.FloatTensor = None,
697
+ sample: torch.Tensor = None,
677
698
  **kwargs,
678
- ) -> torch.FloatTensor:
699
+ ) -> torch.Tensor:
679
700
  """
680
701
  One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
681
702
  time `timestep_list[-3]`.
682
703
 
683
704
  Args:
684
- model_output_list (`List[torch.FloatTensor]`):
705
+ model_output_list (`List[torch.Tensor]`):
685
706
  The direct outputs from learned diffusion model at current and latter timesteps.
686
707
  timestep (`int`):
687
708
  The current and latter discrete timestep in the diffusion chain.
688
709
  prev_timestep (`int`):
689
710
  The previous discrete timestep in the diffusion chain.
690
- sample (`torch.FloatTensor`):
711
+ sample (`torch.Tensor`):
691
712
  A current instance of a sample created by diffusion process.
692
713
 
693
714
  Returns:
694
- `torch.FloatTensor`:
715
+ `torch.Tensor`:
695
716
  The sample tensor at the previous timestep.
696
717
  """
697
718
 
@@ -775,29 +796,29 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
775
796
 
776
797
  def singlestep_dpm_solver_update(
777
798
  self,
778
- model_output_list: List[torch.FloatTensor],
799
+ model_output_list: List[torch.Tensor],
779
800
  *args,
780
- sample: torch.FloatTensor = None,
801
+ sample: torch.Tensor = None,
781
802
  order: int = None,
782
803
  **kwargs,
783
- ) -> torch.FloatTensor:
804
+ ) -> torch.Tensor:
784
805
  """
785
806
  One step for the singlestep DPMSolver.
786
807
 
787
808
  Args:
788
- model_output_list (`List[torch.FloatTensor]`):
809
+ model_output_list (`List[torch.Tensor]`):
789
810
  The direct outputs from learned diffusion model at current and latter timesteps.
790
811
  timestep (`int`):
791
812
  The current and latter discrete timestep in the diffusion chain.
792
813
  prev_timestep (`int`):
793
814
  The previous discrete timestep in the diffusion chain.
794
- sample (`torch.FloatTensor`):
815
+ sample (`torch.Tensor`):
795
816
  A current instance of a sample created by diffusion process.
796
817
  order (`int`):
797
818
  The solver order at this step.
798
819
 
799
820
  Returns:
800
- `torch.FloatTensor`:
821
+ `torch.Tensor`:
801
822
  The sample tensor at the previous timestep.
802
823
  """
803
824
  timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
@@ -870,9 +891,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
870
891
 
871
892
  def step(
872
893
  self,
873
- model_output: torch.FloatTensor,
894
+ model_output: torch.Tensor,
874
895
  timestep: int,
875
- sample: torch.FloatTensor,
896
+ sample: torch.Tensor,
876
897
  return_dict: bool = True,
877
898
  ) -> Union[SchedulerOutput, Tuple]:
878
899
  """
@@ -880,11 +901,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
880
901
  the singlestep DPMSolver.
881
902
 
882
903
  Args:
883
- model_output (`torch.FloatTensor`):
904
+ model_output (`torch.Tensor`):
884
905
  The direct output from learned diffusion model.
885
906
  timestep (`int`):
886
907
  The current discrete timestep in the diffusion chain.
887
- sample (`torch.FloatTensor`):
908
+ sample (`torch.Tensor`):
888
909
  A current instance of a sample created by the diffusion process.
889
910
  return_dict (`bool`):
890
911
  Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
@@ -929,17 +950,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
929
950
 
930
951
  return SchedulerOutput(prev_sample=prev_sample)
931
952
 
932
- def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
953
+ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
933
954
  """
934
955
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
935
956
  current timestep.
936
957
 
937
958
  Args:
938
- sample (`torch.FloatTensor`):
959
+ sample (`torch.Tensor`):
939
960
  The input sample.
940
961
 
941
962
  Returns:
942
- `torch.FloatTensor`:
963
+ `torch.Tensor`:
943
964
  A scaled input sample.
944
965
  """
945
966
  return sample
@@ -947,10 +968,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
947
968
  # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
948
969
  def add_noise(
949
970
  self,
950
- original_samples: torch.FloatTensor,
951
- noise: torch.FloatTensor,
971
+ original_samples: torch.Tensor,
972
+ noise: torch.Tensor,
952
973
  timesteps: torch.IntTensor,
953
- ) -> torch.FloatTensor:
974
+ ) -> torch.Tensor:
954
975
  # Make sure sigmas and timesteps have the same device and dtype as original_samples
955
976
  sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
956
977
  if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -961,10 +982,14 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
961
982
  schedule_timesteps = self.timesteps.to(original_samples.device)
962
983
  timesteps = timesteps.to(original_samples.device)
963
984
 
964
- # begin_index is None when the scheduler is used for training
985
+ # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
965
986
  if self.begin_index is None:
966
987
  step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
988
+ elif self.step_index is not None:
989
+ # add_noise is called after first denoising step (for inpainting)
990
+ step_indices = [self.step_index] * timesteps.shape[0]
967
991
  else:
992
+ # add noise is called before first denoising step to create initial latent(img2img)
968
993
  step_indices = [self.begin_index] * timesteps.shape[0]
969
994
 
970
995
  sigma = sigmas[step_indices].flatten()