diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/schedulers/scheduling_lms_discrete.py

@@ -32,16 +32,16 @@ class LMSDiscreteSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
             denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
     """
 
-    prev_sample: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -79,7 +79,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -149,7 +149,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -180,7 +180,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index
 
@@ -202,21 +202,19 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index
 
-    def scale_model_input(
-        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
-    ) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
-            timestep (`float` or `torch.FloatTensor`):
+            timestep (`float` or `torch.Tensor`):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
 
@@ -288,7 +286,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         log_sigmas = np.log(sigmas)
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
 
-        if self.use_karras_sigmas:
+        if self.config.use_karras_sigmas:
             sigmas = self._convert_to_karras(in_sigmas=sigmas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
@@ -351,7 +349,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return t
 
     # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
 
         sigma_min: float = in_sigmas[-1].item()
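
The `use_karras_sigmas` fix above reads the flag from the registered config rather than from a plain instance attribute. A minimal sketch (not part of the diff) of how a constructor argument surfaces on `scheduler.config`, which is what the patched `set_timesteps` now consults:

    from diffusers import LMSDiscreteScheduler

    # Init kwargs registered through ConfigMixin land on `scheduler.config`;
    # the patched set_timesteps reads `self.config.use_karras_sigmas` there.
    scheduler = LMSDiscreteScheduler(use_karras_sigmas=True)
    assert scheduler.config.use_karras_sigmas is True

    scheduler.set_timesteps(num_inference_steps=25)
    print(scheduler.sigmas)  # Karras-spaced sigmas, ending in 0.0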
@@ -366,9 +364,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-        sample: torch.FloatTensor,
+        model_output: torch.Tensor,
+        timestep: Union[float, torch.Tensor],
+        sample: torch.Tensor,
         order: int = 4,
         return_dict: bool = True,
     ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
@@ -377,11 +375,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
-            timestep (`float` or `torch.FloatTensor`):
+            timestep (`float` or `torch.Tensor`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             order (`int`, defaults to 4):
                 The order of the linear multistep method.
@@ -444,10 +442,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -461,7 +459,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
         if self.begin_index is None:
             step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        elif self.step_index is not None:
+            # add_noise is called after first denoising step (for inpainting)
+            step_indices = [self.step_index] * timesteps.shape[0]
         else:
+            # add noise is called before first denoising step to create initial latent(img2img)
             step_indices = [self.begin_index] * timesteps.shape[0]
 
         sigma = sigmas[step_indices].flatten()
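
The new `elif self.step_index is not None` branch above separates inpainting-style calls to `add_noise` (made after denoising has started) from img2img-style calls (made once, before the first step, via `begin_index`). A hedged sketch of the remaining path, training-style usage, where `begin_index` stays `None` and each timestep is looked up against the schedule:

    import torch
    from diffusers import LMSDiscreteScheduler

    scheduler = LMSDiscreteScheduler()
    scheduler.set_timesteps(num_inference_steps=30)

    sample = torch.randn(1, 4, 64, 64)
    noise = torch.randn_like(sample)

    # begin_index is None here, so add_noise takes the first branch and
    # resolves the timestep against the full schedule; pipelines that call
    # set_begin_index() hit the begin_index/step_index branches instead.
    t = scheduler.timesteps[10:11]
    noisy = scheduler.add_noise(sample, noise, t)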
diffusers/schedulers/scheduling_pndm.py

@@ -59,7 +59,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -135,7 +135,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -225,9 +225,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -236,11 +236,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         or [`~PNDMScheduler.step_plms`] depending on the internal variable `counter`.
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
@@ -258,9 +258,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
     def step_prk(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -269,11 +269,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         equation.
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
@@ -318,9 +318,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
     def step_plms(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -328,11 +328,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         the linear multistep method. It performs one forward pass multiple times to approximate the solution.
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
@@ -387,17 +387,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
         return SchedulerOutput(prev_sample=prev_sample)
 
-    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -448,10 +448,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
         # for the subsequent add_noise calls
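
`PNDMScheduler.step` dispatches to `step_prk` or `step_plms` based on the internal `counter`, as the docstring above notes. A minimal sketch of a denoising loop, not from the diff: `skip_prk_steps=True` mirrors the Stable Diffusion configuration so only PLMS steps run, and a random tensor stands in for a real UNet prediction.

    import torch
    from diffusers import PNDMScheduler

    scheduler = PNDMScheduler(skip_prk_steps=True)
    scheduler.set_timesteps(num_inference_steps=50)

    sample = torch.randn(1, 4, 64, 64)
    for t in scheduler.timesteps:
        model_output = torch.randn_like(sample)  # stand-in for unet(sample, t)
        sample = scheduler.step(model_output, t, sample).prev_sample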
diffusers/schedulers/scheduling_repaint.py

@@ -31,16 +31,16 @@ class RePaintSchedulerOutput(BaseOutput):
     Output class for the scheduler's step function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from
            the current timestep. `pred_original_sample` can be used to preview progress or for guidance.
     """
 
-    prev_sample: torch.FloatTensor
-    pred_original_sample: torch.FloatTensor
+    prev_sample: torch.Tensor
+    pred_original_sample: torch.Tensor
 
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -78,7 +78,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -143,7 +143,7 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
             betas = torch.linspace(-6, 6, num_train_timesteps)
             self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -160,19 +160,19 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
 
         self.eta = eta
 
-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -245,11 +245,11 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
-        original_image: torch.FloatTensor,
-        mask: torch.FloatTensor,
+        sample: torch.Tensor,
+        original_image: torch.Tensor,
+        mask: torch.Tensor,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
     ) -> Union[RePaintSchedulerOutput, Tuple]:
@@ -258,15 +258,15 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            original_image (`torch.FloatTensor`):
+            original_image (`torch.Tensor`):
                 The original image to inpaint on.
-            mask (`torch.FloatTensor`):
+            mask (`torch.Tensor`):
                 The mask where a value of 0.0 indicates which part of the original image to inpaint.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
@@ -351,10 +351,10 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
 
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         raise NotImplementedError("Use `DDPMScheduler.add_noise()` to train for sampling with RePaint.")
 
     def __len__(self):
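
`RePaintScheduler.step` takes the `original_image` and `mask` arguments retyped above, and `add_noise` deliberately raises. A rough sketch of the resampling loop as the RePaint pipeline drives it, under the assumption that `undo_step` re-noises one step whenever the schedule jumps back up; the random tensor again stands in for a UNet.

    import torch
    from diffusers import RePaintScheduler

    scheduler = RePaintScheduler()
    scheduler.set_timesteps(num_inference_steps=50, jump_length=10, jump_n_sample=10)

    original_image = torch.randn(1, 3, 256, 256)
    mask = torch.ones_like(original_image)  # 0.0 marks the region to inpaint
    sample = torch.randn_like(original_image)

    t_last = scheduler.timesteps[0] + 1
    for t in scheduler.timesteps:
        if t < t_last:
            model_output = torch.randn_like(sample)  # stand-in for unet(sample, t)
            sample = scheduler.step(model_output, t, sample, original_image, mask).prev_sample
        else:
            # the schedule jumped back up: re-noise one step (RePaint resampling)
            sample = scheduler.undo_step(sample, t_last, generator=None)
        t_last = t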
diffusers/schedulers/scheduling_sasolver.py

@@ -62,7 +62,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -92,19 +92,20 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         trained_betas (`np.ndarray`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
         predictor_order (`int`, defaults to 2):
-            The predictor order which can be `1` or `2` or `3` or '4'. It is recommended to use `predictor_order=2` for guided
-            sampling, and `predictor_order=3` for unconditional sampling.
+            The predictor order which can be `1` or `2` or `3` or '4'. It is recommended to use `predictor_order=2` for
+            guided sampling, and `predictor_order=3` for unconditional sampling.
         corrector_order (`int`, defaults to 2):
-            The corrector order which can be `1` or `2` or `3` or '4'. It is recommended to use `corrector_order=2` for guided
-            sampling, and `corrector_order=3` for unconditional sampling.
+            The corrector order which can be `1` or `2` or `3` or '4'. It is recommended to use `corrector_order=2` for
+            guided sampling, and `corrector_order=3` for unconditional sampling.
         prediction_type (`str`, defaults to `epsilon`, *optional*):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
             `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
             Video](https://imagen.research.google/video/paper.pdf) paper).
         tau_func (`Callable`, *optional*):
-            Stochasticity during the sampling. Default in init is `lambda t: 1 if t >= 200 and t <= 800 else 0`. SA-Solver
-            will sample from vanilla diffusion ODE if tau_func is set to `lambda t: 0`. SA-Solver will sample from vanilla
-            diffusion SDE if tau_func is set to `lambda t: 1`. For more details, please check https://arxiv.org/abs/2309.05019
+            Stochasticity during the sampling. Default in init is `lambda t: 1 if t >= 200 and t <= 800 else 0`.
+            SA-Solver will sample from vanilla diffusion ODE if tau_func is set to `lambda t: 0`. SA-Solver will sample
+            from vanilla diffusion SDE if tau_func is set to `lambda t: 1`. For more details, please check
+            https://arxiv.org/abs/2309.05019
         thresholding (`bool`, defaults to `False`):
             Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
             as Stable Diffusion.
@@ -114,8 +115,8 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
             The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
             `algorithm_type="dpmsolver++"`.
         algorithm_type (`str`, defaults to `data_prediction`):
-            Algorithm type for the solver; can be `data_prediction` or `noise_prediction`. It is recommended to use `data_prediction`
-            with `solver_order=2` for guided sampling like in Stable Diffusion.
+            Algorithm type for the solver; can be `data_prediction` or `noise_prediction`. It is recommended to use
+            `data_prediction` with `solver_order=2` for guided sampling like in Stable Diffusion.
         lower_order_final (`bool`, defaults to `True`):
             Whether to use lower-order solvers in the final steps. Default = True.
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
@@ -179,7 +180,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -193,7 +194,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         self.init_noise_sigma = 1.0
 
         if algorithm_type not in ["data_prediction", "noise_prediction"]:
-            raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
 
         # setable values
         self.num_inference_steps = None
@@ -216,7 +217,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index
 
@@ -304,7 +305,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
-    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -369,7 +370,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         return alpha_t, sigma_t
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
 
         # Hack to make sure that other schedulers which copy this function don't break
@@ -396,31 +397,31 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
 
     def convert_model_output(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         *args,
-        sample: torch.FloatTensor = None,
+        sample: torch.Tensor = None,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
-        Convert the model output to the corresponding type the data_prediction/noise_prediction algorithm needs. Noise_prediction is
-        designed to discretize an integral of the noise prediction model, and data_prediction is designed to discretize an
-        integral of the data prediction model.
+        Convert the model output to the corresponding type the data_prediction/noise_prediction algorithm needs.
+        Noise_prediction is designed to discretize an integral of the noise prediction model, and data_prediction is
+        designed to discretize an integral of the data prediction model.
 
         <Tip>
 
-        The algorithm and model type are decoupled. You can use either data_prediction or noise_prediction for both noise
-        prediction and data prediction models.
+        The algorithm and model type are decoupled. You can use either data_prediction or noise_prediction for both
+        noise prediction and data prediction models.
 
         </Tip>
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The converted model output.
         """
         timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
@@ -685,29 +686,29 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
 
     def stochastic_adams_bashforth_update(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         *args,
-        sample: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        sample: torch.Tensor,
+        noise: torch.Tensor,
         order: int,
-        tau: torch.FloatTensor,
+        tau: torch.Tensor,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         One step for the SA-Predictor.
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model at the current timestep.
             prev_timestep (`int`):
                 The previous discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             order (`int`):
                 The order of SA-Predictor at this timestep.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The sample tensor at the previous timestep.
         """
         prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
@@ -812,32 +813,32 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
 
     def stochastic_adams_moulton_update(
         self,
-        this_model_output: torch.FloatTensor,
+        this_model_output: torch.Tensor,
         *args,
-        last_sample: torch.FloatTensor,
-        last_noise: torch.FloatTensor,
-        this_sample: torch.FloatTensor,
+        last_sample: torch.Tensor,
+        last_noise: torch.Tensor,
+        this_sample: torch.Tensor,
         order: int,
-        tau: torch.FloatTensor,
+        tau: torch.Tensor,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         One step for the SA-Corrector.
 
         Args:
-            this_model_output (`torch.FloatTensor`):
+            this_model_output (`torch.Tensor`):
                 The model outputs at `x_t`.
             this_timestep (`int`):
                 The current timestep `t`.
-            last_sample (`torch.FloatTensor`):
+            last_sample (`torch.Tensor`):
                 The generated sample before the last predictor `x_{t-1}`.
-            this_sample (`torch.FloatTensor`):
+            this_sample (`torch.Tensor`):
                 The generated sample after the last predictor `x_{t}`.
             order (`int`):
                 The order of SA-Corrector at this step.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The corrected sample tensor at the current timestep.
         """
 
@@ -978,9 +979,9 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         generator=None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
@@ -989,11 +990,11 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         the SA-Solver.
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
@@ -1078,17 +1079,17 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
 
         return SchedulerOutput(prev_sample=prev_sample)
 
-    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -1096,10 +1097,10 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
         # for the subsequent add_noise calls
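
Finally, a hedged sketch of the `tau_func` knob documented in the SASolverScheduler docstring above, which that docstring presents as an init argument: `lambda t: 0` yields deterministic ODE-style sampling, `lambda t: 1` the SDE. A random tensor again stands in for the model.

    import torch
    from diffusers import SASolverScheduler

    scheduler = SASolverScheduler(tau_func=lambda t: 0.0)  # ODE path, per the docstring
    scheduler.set_timesteps(num_inference_steps=25)

    sample = torch.randn(1, 4, 64, 64)
    for t in scheduler.timesteps:
        model_output = torch.randn_like(sample)  # stand-in for unet(sample, t)
        sample = scheduler.step(model_output, t, sample).prev_sample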