diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -61,7 +61,7 @@ def betas_for_alpha_bar(
61
61
  return math.exp(t * -12.0)
62
62
 
63
63
  else:
64
- raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
64
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
65
65
 
66
66
  betas = []
67
67
  for i in range(num_diffusion_timesteps):
@@ -78,11 +78,11 @@ def rescale_zero_terminal_snr(betas):
78
78
 
79
79
 
80
80
  Args:
81
- betas (`torch.FloatTensor`):
81
+ betas (`torch.Tensor`):
82
82
  the betas that the scheduler is being initialized with.
83
83
 
84
84
  Returns:
85
- `torch.FloatTensor`: rescaled betas with zero terminal SNR
85
+ `torch.Tensor`: rescaled betas with zero terminal SNR
86
86
  """
87
87
  # Convert betas to alphas_bar_sqrt
88
88
  alphas = 1.0 - betas
@@ -166,8 +166,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
166
166
  the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
167
167
  `lambda(t)`.
168
168
  final_sigmas_type (`str`, defaults to `"zero"`):
169
- The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
170
- is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
169
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
170
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
171
171
  lambda_min_clipped (`float`, defaults to `-inf`):
172
172
  Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
173
173
  cosine (`squaredcos_cap_v2`) noise schedule.
@@ -229,7 +229,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
229
229
  # Glide cosine schedule
230
230
  self.betas = betas_for_alpha_bar(num_train_timesteps)
231
231
  else:
232
- raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
232
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
233
233
 
234
234
  if rescale_betas_zero_snr:
235
235
  self.betas = rescale_zero_terminal_snr(self.betas)
@@ -256,13 +256,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
256
256
  if algorithm_type == "deis":
257
257
  self.register_to_config(algorithm_type="dpmsolver++")
258
258
  else:
259
- raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
259
+ raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
260
260
 
261
261
  if solver_type not in ["midpoint", "heun"]:
262
262
  if solver_type in ["logrho", "bh1", "bh2"]:
263
263
  self.register_to_config(solver_type="midpoint")
264
264
  else:
265
- raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
265
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
266
266
 
267
267
  if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
268
268
  raise ValueError(
@@ -282,7 +282,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
282
282
  @property
283
283
  def step_index(self):
284
284
  """
285
- The index counter for current timestep. It will increae 1 after each scheduler step.
285
+ The index counter for current timestep. It will increase 1 after each scheduler step.
286
286
  """
287
287
  return self._step_index
288
288
 
@@ -303,7 +303,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
303
303
  """
304
304
  self._begin_index = begin_index
305
305
 
306
- def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
306
+ def set_timesteps(
307
+ self,
308
+ num_inference_steps: int = None,
309
+ device: Union[str, torch.device] = None,
310
+ timesteps: Optional[List[int]] = None,
311
+ ):
307
312
  """
308
313
  Sets the discrete timesteps used for the diffusion chain (to be run before inference).
309
314
 
@@ -312,33 +317,54 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
312
317
  The number of diffusion steps used when generating samples with a pre-trained model.
313
318
  device (`str` or `torch.device`, *optional*):
314
319
  The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
320
+ timesteps (`List[int]`, *optional*):
321
+ Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
322
+ based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
323
+ must be `None`, and `timestep_spacing` attribute will be ignored.
315
324
  """
316
- # Clipping the minimum of all lambda(t) for numerical stability.
317
- # This is critical for cosine (squaredcos_cap_v2) noise schedule.
318
- clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
319
- last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
320
-
321
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
322
- if self.config.timestep_spacing == "linspace":
323
- timesteps = (
324
- np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
325
- )
326
- elif self.config.timestep_spacing == "leading":
327
- step_ratio = last_timestep // (num_inference_steps + 1)
328
- # creates integer timesteps by multiplying by ratio
329
- # casting to int to avoid issues when num_inference_step is power of 3
330
- timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
331
- timesteps += self.config.steps_offset
332
- elif self.config.timestep_spacing == "trailing":
333
- step_ratio = self.config.num_train_timesteps / num_inference_steps
334
- # creates integer timesteps by multiplying by ratio
335
- # casting to int to avoid issues when num_inference_step is power of 3
336
- timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
337
- timesteps -= 1
325
+ if num_inference_steps is None and timesteps is None:
326
+ raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
327
+ if num_inference_steps is not None and timesteps is not None:
328
+ raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
329
+ if timesteps is not None and self.config.use_karras_sigmas:
330
+ raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
331
+ if timesteps is not None and self.config.use_lu_lambdas:
332
+ raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
333
+
334
+ if timesteps is not None:
335
+ timesteps = np.array(timesteps).astype(np.int64)
338
336
  else:
339
- raise ValueError(
340
- f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
341
- )
337
+ # Clipping the minimum of all lambda(t) for numerical stability.
338
+ # This is critical for cosine (squaredcos_cap_v2) noise schedule.
339
+ clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
340
+ last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
341
+
342
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
343
+ if self.config.timestep_spacing == "linspace":
344
+ timesteps = (
345
+ np.linspace(0, last_timestep - 1, num_inference_steps + 1)
346
+ .round()[::-1][:-1]
347
+ .copy()
348
+ .astype(np.int64)
349
+ )
350
+ elif self.config.timestep_spacing == "leading":
351
+ step_ratio = last_timestep // (num_inference_steps + 1)
352
+ # creates integer timesteps by multiplying by ratio
353
+ # casting to int to avoid issues when num_inference_step is power of 3
354
+ timesteps = (
355
+ (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
356
+ )
357
+ timesteps += self.config.steps_offset
358
+ elif self.config.timestep_spacing == "trailing":
359
+ step_ratio = self.config.num_train_timesteps / num_inference_steps
360
+ # creates integer timesteps by multiplying by ratio
361
+ # casting to int to avoid issues when num_inference_step is power of 3
362
+ timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
363
+ timesteps -= 1
364
+ else:
365
+ raise ValueError(
366
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
367
+ )
342
368
 
343
369
  sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
344
370
  log_sigmas = np.log(sigmas)
@@ -382,7 +408,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
382
408
  self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
383
409
 
384
410
  # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
385
- def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
411
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
386
412
  """
387
413
  "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
388
414
  prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -446,7 +472,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
446
472
  return alpha_t, sigma_t
447
473
 
448
474
  # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
449
- def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
475
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
450
476
  """Constructs the noise schedule of Karras et al. (2022)."""
451
477
 
452
478
  # Hack to make sure that other schedulers which copy this function don't break
@@ -471,7 +497,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
471
497
  sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
472
498
  return sigmas
473
499
 
474
- def _convert_to_lu(self, in_lambdas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
500
+ def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
475
501
  """Constructs the noise schedule of Lu et al. (2022)."""
476
502
 
477
503
  lambda_min: float = in_lambdas[-1].item()
@@ -486,11 +512,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
486
512
 
487
513
  def convert_model_output(
488
514
  self,
489
- model_output: torch.FloatTensor,
515
+ model_output: torch.Tensor,
490
516
  *args,
491
- sample: torch.FloatTensor = None,
517
+ sample: torch.Tensor = None,
492
518
  **kwargs,
493
- ) -> torch.FloatTensor:
519
+ ) -> torch.Tensor:
494
520
  """
495
521
  Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
496
522
  designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
@@ -504,13 +530,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
504
530
  </Tip>
505
531
 
506
532
  Args:
507
- model_output (`torch.FloatTensor`):
533
+ model_output (`torch.Tensor`):
508
534
  The direct output from the learned diffusion model.
509
- sample (`torch.FloatTensor`):
535
+ sample (`torch.Tensor`):
510
536
  A current instance of a sample created by the diffusion process.
511
537
 
512
538
  Returns:
513
- `torch.FloatTensor`:
539
+ `torch.Tensor`:
514
540
  The converted model output.
515
541
  """
516
542
  timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
@@ -585,23 +611,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
585
611
 
586
612
  def dpm_solver_first_order_update(
587
613
  self,
588
- model_output: torch.FloatTensor,
614
+ model_output: torch.Tensor,
589
615
  *args,
590
- sample: torch.FloatTensor = None,
591
- noise: Optional[torch.FloatTensor] = None,
616
+ sample: torch.Tensor = None,
617
+ noise: Optional[torch.Tensor] = None,
592
618
  **kwargs,
593
- ) -> torch.FloatTensor:
619
+ ) -> torch.Tensor:
594
620
  """
595
621
  One step for the first-order DPMSolver (equivalent to DDIM).
596
622
 
597
623
  Args:
598
- model_output (`torch.FloatTensor`):
624
+ model_output (`torch.Tensor`):
599
625
  The direct output from the learned diffusion model.
600
- sample (`torch.FloatTensor`):
626
+ sample (`torch.Tensor`):
601
627
  A current instance of a sample created by the diffusion process.
602
628
 
603
629
  Returns:
604
- `torch.FloatTensor`:
630
+ `torch.Tensor`:
605
631
  The sample tensor at the previous timestep.
606
632
  """
607
633
  timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
@@ -654,23 +680,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
654
680
 
655
681
  def multistep_dpm_solver_second_order_update(
656
682
  self,
657
- model_output_list: List[torch.FloatTensor],
683
+ model_output_list: List[torch.Tensor],
658
684
  *args,
659
- sample: torch.FloatTensor = None,
660
- noise: Optional[torch.FloatTensor] = None,
685
+ sample: torch.Tensor = None,
686
+ noise: Optional[torch.Tensor] = None,
661
687
  **kwargs,
662
- ) -> torch.FloatTensor:
688
+ ) -> torch.Tensor:
663
689
  """
664
690
  One step for the second-order multistep DPMSolver.
665
691
 
666
692
  Args:
667
- model_output_list (`List[torch.FloatTensor]`):
693
+ model_output_list (`List[torch.Tensor]`):
668
694
  The direct outputs from learned diffusion model at current and latter timesteps.
669
- sample (`torch.FloatTensor`):
695
+ sample (`torch.Tensor`):
670
696
  A current instance of a sample created by the diffusion process.
671
697
 
672
698
  Returns:
673
- `torch.FloatTensor`:
699
+ `torch.Tensor`:
674
700
  The sample tensor at the previous timestep.
675
701
  """
676
702
  timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
@@ -777,22 +803,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
777
803
 
778
804
  def multistep_dpm_solver_third_order_update(
779
805
  self,
780
- model_output_list: List[torch.FloatTensor],
806
+ model_output_list: List[torch.Tensor],
781
807
  *args,
782
- sample: torch.FloatTensor = None,
808
+ sample: torch.Tensor = None,
783
809
  **kwargs,
784
- ) -> torch.FloatTensor:
810
+ ) -> torch.Tensor:
785
811
  """
786
812
  One step for the third-order multistep DPMSolver.
787
813
 
788
814
  Args:
789
- model_output_list (`List[torch.FloatTensor]`):
815
+ model_output_list (`List[torch.Tensor]`):
790
816
  The direct outputs from learned diffusion model at current and latter timesteps.
791
- sample (`torch.FloatTensor`):
817
+ sample (`torch.Tensor`):
792
818
  A current instance of a sample created by diffusion process.
793
819
 
794
820
  Returns:
795
- `torch.FloatTensor`:
821
+ `torch.Tensor`:
796
822
  The sample tensor at the previous timestep.
797
823
  """
798
824
 
@@ -893,11 +919,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
893
919
 
894
920
  def step(
895
921
  self,
896
- model_output: torch.FloatTensor,
922
+ model_output: torch.Tensor,
897
923
  timestep: int,
898
- sample: torch.FloatTensor,
924
+ sample: torch.Tensor,
899
925
  generator=None,
900
- variance_noise: Optional[torch.FloatTensor] = None,
926
+ variance_noise: Optional[torch.Tensor] = None,
901
927
  return_dict: bool = True,
902
928
  ) -> Union[SchedulerOutput, Tuple]:
903
929
  """
@@ -905,15 +931,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
905
931
  the multistep DPMSolver.
906
932
 
907
933
  Args:
908
- model_output (`torch.FloatTensor`):
934
+ model_output (`torch.Tensor`):
909
935
  The direct output from learned diffusion model.
910
936
  timestep (`int`):
911
937
  The current discrete timestep in the diffusion chain.
912
- sample (`torch.FloatTensor`):
938
+ sample (`torch.Tensor`):
913
939
  A current instance of a sample created by the diffusion process.
914
940
  generator (`torch.Generator`, *optional*):
915
941
  A random number generator.
916
- variance_noise (`torch.FloatTensor`):
942
+ variance_noise (`torch.Tensor`):
917
943
  Alternative to generating noise with `generator` by directly providing the noise for the variance
918
944
  itself. Useful for methods such as [`LEdits++`].
919
945
  return_dict (`bool`):
@@ -980,27 +1006,27 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
980
1006
 
981
1007
  return SchedulerOutput(prev_sample=prev_sample)
982
1008
 
983
- def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
1009
+ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
984
1010
  """
985
1011
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
986
1012
  current timestep.
987
1013
 
988
1014
  Args:
989
- sample (`torch.FloatTensor`):
1015
+ sample (`torch.Tensor`):
990
1016
  The input sample.
991
1017
 
992
1018
  Returns:
993
- `torch.FloatTensor`:
1019
+ `torch.Tensor`:
994
1020
  A scaled input sample.
995
1021
  """
996
1022
  return sample
997
1023
 
998
1024
  def add_noise(
999
1025
  self,
1000
- original_samples: torch.FloatTensor,
1001
- noise: torch.FloatTensor,
1026
+ original_samples: torch.Tensor,
1027
+ noise: torch.Tensor,
1002
1028
  timesteps: torch.IntTensor,
1003
- ) -> torch.FloatTensor:
1029
+ ) -> torch.Tensor:
1004
1030
  # Make sure sigmas and timesteps have the same device and dtype as original_samples
1005
1031
  sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
1006
1032
  if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -1011,10 +1037,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
1011
1037
  schedule_timesteps = self.timesteps.to(original_samples.device)
1012
1038
  timesteps = timesteps.to(original_samples.device)
1013
1039
 
1014
- # begin_index is None when the scheduler is used for training
1040
+ # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
1015
1041
  if self.begin_index is None:
1016
1042
  step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
1043
+ elif self.step_index is not None:
1044
+ # add_noise is called after first denoising step (for inpainting)
1045
+ step_indices = [self.step_index] * timesteps.shape[0]
1017
1046
  else:
1047
+ # add noise is called before first denoising step to create initial latent(img2img)
1018
1048
  step_indices = [self.begin_index] * timesteps.shape[0]
1019
1049
 
1020
1050
  sigma = sigmas[step_indices].flatten()
@@ -182,9 +182,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
182
182
 
183
183
  # settings for DPM-Solver
184
184
  if self.config.algorithm_type not in ["dpmsolver", "dpmsolver++"]:
185
- raise NotImplementedError(f"{self.config.algorithm_type} does is not implemented for {self.__class__}")
185
+ raise NotImplementedError(f"{self.config.algorithm_type} is not implemented for {self.__class__}")
186
186
  if self.config.solver_type not in ["midpoint", "heun"]:
187
- raise NotImplementedError(f"{self.config.solver_type} does is not implemented for {self.__class__}")
187
+ raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}")
188
188
 
189
189
  # standard deviation of the initial noise distribution
190
190
  init_noise_sigma = jnp.array(1.0, dtype=self.dtype)
@@ -61,7 +61,7 @@ def betas_for_alpha_bar(
61
61
  return math.exp(t * -12.0)
62
62
 
63
63
  else:
64
- raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
64
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
65
65
 
66
66
  betas = []
67
67
  for i in range(num_diffusion_timesteps):
@@ -178,7 +178,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
178
178
  # Glide cosine schedule
179
179
  self.betas = betas_for_alpha_bar(num_train_timesteps)
180
180
  else:
181
- raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
181
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
182
182
 
183
183
  self.alphas = 1.0 - self.betas
184
184
  self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -196,13 +196,13 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
196
196
  if algorithm_type == "deis":
197
197
  self.register_to_config(algorithm_type="dpmsolver++")
198
198
  else:
199
- raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
199
+ raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
200
200
 
201
201
  if solver_type not in ["midpoint", "heun"]:
202
202
  if solver_type in ["logrho", "bh1", "bh2"]:
203
203
  self.register_to_config(solver_type="midpoint")
204
204
  else:
205
- raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
205
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
206
206
 
207
207
  # setable values
208
208
  self.num_inference_steps = None
@@ -217,7 +217,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
217
217
  @property
218
218
  def step_index(self):
219
219
  """
220
- The index counter for current timestep. It will increae 1 after each scheduler step.
220
+ The index counter for current timestep. It will increase 1 after each scheduler step.
221
221
  """
222
222
  return self._step_index
223
223
 
@@ -233,7 +233,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
233
233
  """
234
234
  # Clipping the minimum of all lambda(t) for numerical stability.
235
235
  # This is critical for cosine (squaredcos_cap_v2) noise schedule.
236
- clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped).item()
236
+ clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped).item()
237
237
  self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx
238
238
 
239
239
  # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
@@ -295,7 +295,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
295
295
  self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
296
296
 
297
297
  # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
298
- def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
298
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
299
299
  """
300
300
  "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
301
301
  prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -360,7 +360,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
360
360
  return alpha_t, sigma_t
361
361
 
362
362
  # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
363
- def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
363
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
364
364
  """Constructs the noise schedule of Karras et al. (2022)."""
365
365
 
366
366
  # Hack to make sure that other schedulers which copy this function don't break
@@ -388,11 +388,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
388
388
  # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
389
389
  def convert_model_output(
390
390
  self,
391
- model_output: torch.FloatTensor,
391
+ model_output: torch.Tensor,
392
392
  *args,
393
- sample: torch.FloatTensor = None,
393
+ sample: torch.Tensor = None,
394
394
  **kwargs,
395
- ) -> torch.FloatTensor:
395
+ ) -> torch.Tensor:
396
396
  """
397
397
  Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
398
398
  designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
@@ -406,13 +406,13 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
406
406
  </Tip>
407
407
 
408
408
  Args:
409
- model_output (`torch.FloatTensor`):
409
+ model_output (`torch.Tensor`):
410
410
  The direct output from the learned diffusion model.
411
- sample (`torch.FloatTensor`):
411
+ sample (`torch.Tensor`):
412
412
  A current instance of a sample created by the diffusion process.
413
413
 
414
414
  Returns:
415
- `torch.FloatTensor`:
415
+ `torch.Tensor`:
416
416
  The converted model output.
417
417
  """
418
418
  timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
@@ -488,23 +488,23 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
488
488
  # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
489
489
  def dpm_solver_first_order_update(
490
490
  self,
491
- model_output: torch.FloatTensor,
491
+ model_output: torch.Tensor,
492
492
  *args,
493
- sample: torch.FloatTensor = None,
494
- noise: Optional[torch.FloatTensor] = None,
493
+ sample: torch.Tensor = None,
494
+ noise: Optional[torch.Tensor] = None,
495
495
  **kwargs,
496
- ) -> torch.FloatTensor:
496
+ ) -> torch.Tensor:
497
497
  """
498
498
  One step for the first-order DPMSolver (equivalent to DDIM).
499
499
 
500
500
  Args:
501
- model_output (`torch.FloatTensor`):
501
+ model_output (`torch.Tensor`):
502
502
  The direct output from the learned diffusion model.
503
- sample (`torch.FloatTensor`):
503
+ sample (`torch.Tensor`):
504
504
  A current instance of a sample created by the diffusion process.
505
505
 
506
506
  Returns:
507
- `torch.FloatTensor`:
507
+ `torch.Tensor`:
508
508
  The sample tensor at the previous timestep.
509
509
  """
510
510
  timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
@@ -558,23 +558,23 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
558
558
  # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
559
559
  def multistep_dpm_solver_second_order_update(
560
560
  self,
561
- model_output_list: List[torch.FloatTensor],
561
+ model_output_list: List[torch.Tensor],
562
562
  *args,
563
- sample: torch.FloatTensor = None,
564
- noise: Optional[torch.FloatTensor] = None,
563
+ sample: torch.Tensor = None,
564
+ noise: Optional[torch.Tensor] = None,
565
565
  **kwargs,
566
- ) -> torch.FloatTensor:
566
+ ) -> torch.Tensor:
567
567
  """
568
568
  One step for the second-order multistep DPMSolver.
569
569
 
570
570
  Args:
571
- model_output_list (`List[torch.FloatTensor]`):
571
+ model_output_list (`List[torch.Tensor]`):
572
572
  The direct outputs from learned diffusion model at current and latter timesteps.
573
- sample (`torch.FloatTensor`):
573
+ sample (`torch.Tensor`):
574
574
  A current instance of a sample created by the diffusion process.
575
575
 
576
576
  Returns:
577
- `torch.FloatTensor`:
577
+ `torch.Tensor`:
578
578
  The sample tensor at the previous timestep.
579
579
  """
580
580
  timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
@@ -682,22 +682,22 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
682
682
  # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
683
683
  def multistep_dpm_solver_third_order_update(
684
684
  self,
685
- model_output_list: List[torch.FloatTensor],
685
+ model_output_list: List[torch.Tensor],
686
686
  *args,
687
- sample: torch.FloatTensor = None,
687
+ sample: torch.Tensor = None,
688
688
  **kwargs,
689
- ) -> torch.FloatTensor:
689
+ ) -> torch.Tensor:
690
690
  """
691
691
  One step for the third-order multistep DPMSolver.
692
692
 
693
693
  Args:
694
- model_output_list (`List[torch.FloatTensor]`):
694
+ model_output_list (`List[torch.Tensor]`):
695
695
  The direct outputs from learned diffusion model at current and latter timesteps.
696
- sample (`torch.FloatTensor`):
696
+ sample (`torch.Tensor`):
697
697
  A current instance of a sample created by diffusion process.
698
698
 
699
699
  Returns:
700
- `torch.FloatTensor`:
700
+ `torch.Tensor`:
701
701
  The sample tensor at the previous timestep.
702
702
  """
703
703
 
@@ -786,11 +786,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
786
786
 
787
787
  def step(
788
788
  self,
789
- model_output: torch.FloatTensor,
789
+ model_output: torch.Tensor,
790
790
  timestep: int,
791
- sample: torch.FloatTensor,
791
+ sample: torch.Tensor,
792
792
  generator=None,
793
- variance_noise: Optional[torch.FloatTensor] = None,
793
+ variance_noise: Optional[torch.Tensor] = None,
794
794
  return_dict: bool = True,
795
795
  ) -> Union[SchedulerOutput, Tuple]:
796
796
  """
@@ -798,15 +798,15 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
798
798
  the multistep DPMSolver.
799
799
 
800
800
  Args:
801
- model_output (`torch.FloatTensor`):
801
+ model_output (`torch.Tensor`):
802
802
  The direct output from learned diffusion model.
803
803
  timestep (`int`):
804
804
  The current discrete timestep in the diffusion chain.
805
- sample (`torch.FloatTensor`):
805
+ sample (`torch.Tensor`):
806
806
  A current instance of a sample created by the diffusion process.
807
807
  generator (`torch.Generator`, *optional*):
808
808
  A random number generator.
809
- variance_noise (`torch.FloatTensor`):
809
+ variance_noise (`torch.Tensor`):
810
810
  Alternative to generating noise with `generator` by directly providing the noise for the variance
811
811
  itself. Useful for methods such as [`CycleDiffusion`].
812
812
  return_dict (`bool`):
@@ -867,27 +867,27 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
867
867
  return SchedulerOutput(prev_sample=prev_sample)
868
868
 
869
869
  # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
870
- def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
870
+ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
871
871
  """
872
872
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
873
873
  current timestep.
874
874
 
875
875
  Args:
876
- sample (`torch.FloatTensor`):
876
+ sample (`torch.Tensor`):
877
877
  The input sample.
878
878
 
879
879
  Returns:
880
- `torch.FloatTensor`:
880
+ `torch.Tensor`:
881
881
  A scaled input sample.
882
882
  """
883
883
  return sample
884
884
 
885
885
  def add_noise(
886
886
  self,
887
- original_samples: torch.FloatTensor,
888
- noise: torch.FloatTensor,
887
+ original_samples: torch.Tensor,
888
+ noise: torch.Tensor,
889
889
  timesteps: torch.IntTensor,
890
- ) -> torch.FloatTensor:
890
+ ) -> torch.Tensor:
891
891
  # Make sure sigmas and timesteps have the same device and dtype as original_samples
892
892
  sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
893
893
  if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):