diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +19 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  229. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  231. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  232. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
  267. diffusers-0.27.2.dist-info/RECORD +0 -399
  268. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  269. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,6 @@ from ...schedulers import DDPMScheduler
16
16
  from ...utils import (
17
17
  BACKENDS_MAPPING,
18
18
  PIL_INTERPOLATION,
19
- is_accelerate_available,
20
19
  is_bs4_available,
21
20
  is_ftfy_available,
22
21
  logging,
@@ -145,6 +144,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
145
144
 
146
145
  model_cpu_offload_seq = "text_encoder->unet"
147
146
  _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
147
+ _exclude_from_cpu_offload = ["watermarker"]
148
148
 
149
149
  def __init__(
150
150
  self,
@@ -193,21 +193,6 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
193
193
  )
194
194
  self.register_to_config(requires_safety_checker=requires_safety_checker)
195
195
 
196
- # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
197
- def remove_all_hooks(self):
198
- if is_accelerate_available():
199
- from accelerate.hooks import remove_hook_from_module
200
- else:
201
- raise ImportError("Please install accelerate via `pip install accelerate`")
202
-
203
- for model in [self.text_encoder, self.unet, self.safety_checker]:
204
- if model is not None:
205
- remove_hook_from_module(model, recurse=True)
206
-
207
- self.unet_offload_hook = None
208
- self.text_encoder_offload_hook = None
209
- self.final_offload_hook = None
210
-
211
196
  # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
212
197
  def _text_preprocessing(self, text, clean_caption=False):
213
198
  if clean_caption and not is_bs4_available():
@@ -357,8 +342,8 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
357
342
  num_images_per_prompt: int = 1,
358
343
  device: Optional[torch.device] = None,
359
344
  negative_prompt: Optional[Union[str, List[str]]] = None,
360
- prompt_embeds: Optional[torch.FloatTensor] = None,
361
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
345
+ prompt_embeds: Optional[torch.Tensor] = None,
346
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
362
347
  clean_caption: bool = False,
363
348
  ):
364
349
  r"""
@@ -377,10 +362,10 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
377
362
  The prompt or prompts not to guide the image generation. If not defined, one has to pass
378
363
  `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
379
364
  Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
380
- prompt_embeds (`torch.FloatTensor`, *optional*):
365
+ prompt_embeds (`torch.Tensor`, *optional*):
381
366
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
382
367
  provided, text embeddings will be generated from `prompt` input argument.
383
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
368
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
384
369
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
385
370
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
386
371
  argument.
@@ -515,9 +500,6 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
515
500
  nsfw_detected = None
516
501
  watermark_detected = None
517
502
 
518
- if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
519
- self.unet_offload_hook.offload()
520
-
521
503
  return image, nsfw_detected, watermark_detected
522
504
 
523
505
  # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
@@ -597,7 +579,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
597
579
  and not isinstance(check_image_type, np.ndarray)
598
580
  ):
599
581
  raise ValueError(
600
- "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
582
+ "`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
601
583
  f" {type(check_image_type)}"
602
584
  )
603
585
 
@@ -628,7 +610,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
628
610
  and not isinstance(check_image_type, np.ndarray)
629
611
  ):
630
612
  raise ValueError(
631
- "`original_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
613
+ "`original_image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
632
614
  f" {type(check_image_type)}"
633
615
  )
634
616
 
@@ -661,7 +643,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
661
643
  and not isinstance(check_image_type, np.ndarray)
662
644
  ):
663
645
  raise ValueError(
664
- "`mask_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
646
+ "`mask_image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
665
647
  f" {type(check_image_type)}"
666
648
  )
667
649
 
@@ -698,7 +680,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
698
680
 
699
681
  for image_ in image:
700
682
  image_ = image_.convert("RGB")
701
- image_ = resize(image_, self.unet.sample_size)
683
+ image_ = resize(image_, self.unet.config.sample_size)
702
684
  image_ = np.array(image_)
703
685
  image_ = image_.astype(np.float32)
704
686
  image_ = image_ / 127.5 - 1
@@ -778,7 +760,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
778
760
 
779
761
  for mask_image_ in mask_image:
780
762
  mask_image_ = mask_image_.convert("L")
781
- mask_image_ = resize(mask_image_, self.unet.sample_size)
763
+ mask_image_ = resize(mask_image_, self.unet.config.sample_size)
782
764
  mask_image_ = np.array(mask_image_)
783
765
  mask_image_ = mask_image_[None, None, :]
784
766
  new_mask_image.append(mask_image_)
@@ -800,13 +782,15 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
800
782
 
801
783
  return mask_image
802
784
 
803
- # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps
785
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
804
786
  def get_timesteps(self, num_inference_steps, strength):
805
787
  # get the original timestep using init_timestep
806
788
  init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
807
789
 
808
790
  t_start = max(num_inference_steps - init_timestep, 0)
809
- timesteps = self.scheduler.timesteps[t_start:]
791
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
792
+ if hasattr(self.scheduler, "set_begin_index"):
793
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
810
794
 
811
795
  return timesteps, num_inference_steps - t_start
812
796
 
@@ -839,7 +823,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
839
823
  @replace_example_docstring(EXAMPLE_DOC_STRING)
840
824
  def __call__(
841
825
  self,
842
- image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor],
826
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
843
827
  original_image: Union[
844
828
  PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
845
829
  ] = None,
@@ -855,11 +839,11 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
855
839
  num_images_per_prompt: Optional[int] = 1,
856
840
  eta: float = 0.0,
857
841
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
858
- prompt_embeds: Optional[torch.FloatTensor] = None,
859
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
842
+ prompt_embeds: Optional[torch.Tensor] = None,
843
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
860
844
  output_type: Optional[str] = "pil",
861
845
  return_dict: bool = True,
862
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
846
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
863
847
  callback_steps: int = 1,
864
848
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
865
849
  noise_level: int = 0,
@@ -869,10 +853,10 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
869
853
  Function invoked when calling the pipeline for generation.
870
854
 
871
855
  Args:
872
- image (`torch.FloatTensor` or `PIL.Image.Image`):
856
+ image (`torch.Tensor` or `PIL.Image.Image`):
873
857
  `Image`, or tensor representing an image batch, that will be used as the starting point for the
874
858
  process.
875
- original_image (`torch.FloatTensor` or `PIL.Image.Image`):
859
+ original_image (`torch.Tensor` or `PIL.Image.Image`):
876
860
  The original image that `image` was varied from.
877
861
  mask_image (`PIL.Image.Image`):
878
862
  `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
@@ -912,10 +896,10 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
912
896
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
913
897
  One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
914
898
  to make generation deterministic.
915
- prompt_embeds (`torch.FloatTensor`, *optional*):
899
+ prompt_embeds (`torch.Tensor`, *optional*):
916
900
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
917
901
  provided, text embeddings will be generated from `prompt` input argument.
918
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
902
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
919
903
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
920
904
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
921
905
  argument.
@@ -926,7 +910,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
926
910
  Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
927
911
  callback (`Callable`, *optional*):
928
912
  A function that will be called every `callback_steps` steps during inference. The function will be
929
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
913
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
930
914
  callback_steps (`int`, *optional*, defaults to 1):
931
915
  The frequency at which the `callback` function will be called. If not specified, the callback will be
932
916
  called at every step.
@@ -15,7 +15,6 @@ from ...models import UNet2DConditionModel
15
15
  from ...schedulers import DDPMScheduler
16
16
  from ...utils import (
17
17
  BACKENDS_MAPPING,
18
- is_accelerate_available,
19
18
  is_bs4_available,
20
19
  is_ftfy_available,
21
20
  logging,
@@ -101,6 +100,7 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
101
100
 
102
101
  _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
103
102
  model_cpu_offload_seq = "text_encoder->unet"
103
+ _exclude_from_cpu_offload = ["watermarker"]
104
104
 
105
105
  def __init__(
106
106
  self,
@@ -149,21 +149,6 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
149
149
  )
150
150
  self.register_to_config(requires_safety_checker=requires_safety_checker)
151
151
 
152
- # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
153
- def remove_all_hooks(self):
154
- if is_accelerate_available():
155
- from accelerate.hooks import remove_hook_from_module
156
- else:
157
- raise ImportError("Please install accelerate via `pip install accelerate`")
158
-
159
- for model in [self.text_encoder, self.unet, self.safety_checker]:
160
- if model is not None:
161
- remove_hook_from_module(model, recurse=True)
162
-
163
- self.unet_offload_hook = None
164
- self.text_encoder_offload_hook = None
165
- self.final_offload_hook = None
166
-
167
152
  # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
168
153
  def _text_preprocessing(self, text, clean_caption=False):
169
154
  if clean_caption and not is_bs4_available():
@@ -313,8 +298,8 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
313
298
  num_images_per_prompt: int = 1,
314
299
  device: Optional[torch.device] = None,
315
300
  negative_prompt: Optional[Union[str, List[str]]] = None,
316
- prompt_embeds: Optional[torch.FloatTensor] = None,
317
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
301
+ prompt_embeds: Optional[torch.Tensor] = None,
302
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
318
303
  clean_caption: bool = False,
319
304
  ):
320
305
  r"""
@@ -333,10 +318,10 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
333
318
  The prompt or prompts not to guide the image generation. If not defined, one has to pass
334
319
  `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
335
320
  Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
336
- prompt_embeds (`torch.FloatTensor`, *optional*):
321
+ prompt_embeds (`torch.Tensor`, *optional*):
337
322
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
338
323
  provided, text embeddings will be generated from `prompt` input argument.
339
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
324
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
340
325
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
341
326
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
342
327
  argument.
@@ -471,9 +456,6 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
471
456
  nsfw_detected = None
472
457
  watermark_detected = None
473
458
 
474
- if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
475
- self.unet_offload_hook.offload()
476
-
477
459
  return image, nsfw_detected, watermark_detected
478
460
 
479
461
  # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
@@ -555,7 +537,7 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
555
537
  and not isinstance(check_image_type, np.ndarray)
556
538
  ):
557
539
  raise ValueError(
558
- "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
540
+ "`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
559
541
  f" {type(check_image_type)}"
560
542
  )
561
543
 
@@ -626,7 +608,7 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
626
608
  prompt: Union[str, List[str]] = None,
627
609
  height: int = None,
628
610
  width: int = None,
629
- image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
611
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor] = None,
630
612
  num_inference_steps: int = 50,
631
613
  timesteps: List[int] = None,
632
614
  guidance_scale: float = 4.0,
@@ -634,11 +616,11 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
634
616
  num_images_per_prompt: Optional[int] = 1,
635
617
  eta: float = 0.0,
636
618
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
637
- prompt_embeds: Optional[torch.FloatTensor] = None,
638
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
619
+ prompt_embeds: Optional[torch.Tensor] = None,
620
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
639
621
  output_type: Optional[str] = "pil",
640
622
  return_dict: bool = True,
641
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
623
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
642
624
  callback_steps: int = 1,
643
625
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
644
626
  noise_level: int = 250,
@@ -655,7 +637,7 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
655
637
  The height in pixels of the generated image.
656
638
  width (`int`, *optional*, defaults to None):
657
639
  The width in pixels of the generated image.
658
- image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`):
640
+ image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`):
659
641
  The image to be upscaled.
660
642
  num_inference_steps (`int`, *optional*, defaults to 50):
661
643
  The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -681,10 +663,10 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
681
663
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
682
664
  One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
683
665
  to make generation deterministic.
684
- prompt_embeds (`torch.FloatTensor`, *optional*):
666
+ prompt_embeds (`torch.Tensor`, *optional*):
685
667
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
686
668
  provided, text embeddings will be generated from `prompt` input argument.
687
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
669
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
688
670
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
689
671
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
690
672
  argument.
@@ -695,7 +677,7 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
695
677
  Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
696
678
  callback (`Callable`, *optional*):
697
679
  A function that will be called every `callback_steps` steps during inference. The function will be
698
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
680
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
699
681
  callback_steps (`int`, *optional*, defaults to 1):
700
682
  The frequency at which the `callback` function will be called. If not specified, the callback will be
701
683
  called at every step.
@@ -775,6 +757,9 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
775
757
  self.scheduler.set_timesteps(num_inference_steps, device=device)
776
758
  timesteps = self.scheduler.timesteps
777
759
 
760
+ if hasattr(self.scheduler, "set_begin_index"):
761
+ self.scheduler.set_begin_index(0)
762
+
778
763
  # 5. Prepare intermediate images
779
764
  num_channels = self.unet.config.in_channels // 2
780
765
  intermediate_images = self.prepare_intermediate_images(
@@ -13,27 +13,27 @@ class TransformationModelOutput(ModelOutput):
13
13
  Base class for text model's outputs that also contains a pooling of the last hidden states.
14
14
 
15
15
  Args:
16
- text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
16
+ text_embeds (`torch.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
17
17
  The text embeddings obtained by applying the projection layer to the pooler_output.
18
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
18
+ last_hidden_state (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
19
19
  Sequence of hidden-states at the output of the last layer of the model.
20
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
21
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
22
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
20
+ hidden_states (`tuple(torch.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
21
+ Tuple of `torch.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one
22
+ for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
23
23
 
24
24
  Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
25
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
26
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
25
+ attentions (`tuple(torch.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
26
+ Tuple of `torch.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
27
27
  sequence_length)`.
28
28
 
29
29
  Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
30
30
  heads.
31
31
  """
32
32
 
33
- projection_state: Optional[torch.FloatTensor] = None
34
- last_hidden_state: torch.FloatTensor = None
35
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
36
- attentions: Optional[Tuple[torch.FloatTensor]] = None
33
+ projection_state: Optional[torch.Tensor] = None
34
+ last_hidden_state: torch.Tensor = None
35
+ hidden_states: Optional[Tuple[torch.Tensor]] = None
36
+ attentions: Optional[Tuple[torch.Tensor]] = None
37
37
 
38
38
 
39
39
  class RobertaSeriesConfig(XLMRobertaConfig):
@@ -79,6 +79,7 @@ def retrieve_timesteps(
79
79
  num_inference_steps: Optional[int] = None,
80
80
  device: Optional[Union[str, torch.device]] = None,
81
81
  timesteps: Optional[List[int]] = None,
82
+ sigmas: Optional[List[float]] = None,
82
83
  **kwargs,
83
84
  ):
84
85
  """
@@ -89,19 +90,23 @@ def retrieve_timesteps(
89
90
  scheduler (`SchedulerMixin`):
90
91
  The scheduler to get timesteps from.
91
92
  num_inference_steps (`int`):
92
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
93
- `timesteps` must be `None`.
93
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
94
+ must be `None`.
94
95
  device (`str` or `torch.device`, *optional*):
95
96
  The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
96
97
  timesteps (`List[int]`, *optional*):
97
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
98
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
99
- must be `None`.
98
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
99
+ `num_inference_steps` and `sigmas` must be `None`.
100
+ sigmas (`List[float]`, *optional*):
101
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
102
+ `num_inference_steps` and `timesteps` must be `None`.
100
103
 
101
104
  Returns:
102
105
  `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
103
106
  second element is the number of inference steps.
104
107
  """
108
+ if timesteps is not None and sigmas is not None:
109
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
105
110
  if timesteps is not None:
106
111
  accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
107
112
  if not accepts_timesteps:
@@ -112,6 +117,16 @@ def retrieve_timesteps(
112
117
  scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
113
118
  timesteps = scheduler.timesteps
114
119
  num_inference_steps = len(timesteps)
120
+ elif sigmas is not None:
121
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
122
+ if not accept_sigmas:
123
+ raise ValueError(
124
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
125
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
126
+ )
127
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
128
+ timesteps = scheduler.timesteps
129
+ num_inference_steps = len(timesteps)
115
130
  else:
116
131
  scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
117
132
  timesteps = scheduler.timesteps
@@ -263,8 +278,8 @@ class AltDiffusionPipeline(
263
278
  num_images_per_prompt,
264
279
  do_classifier_free_guidance,
265
280
  negative_prompt=None,
266
- prompt_embeds: Optional[torch.FloatTensor] = None,
267
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
281
+ prompt_embeds: Optional[torch.Tensor] = None,
282
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
268
283
  lora_scale: Optional[float] = None,
269
284
  **kwargs,
270
285
  ):
@@ -295,8 +310,8 @@ class AltDiffusionPipeline(
295
310
  num_images_per_prompt,
296
311
  do_classifier_free_guidance,
297
312
  negative_prompt=None,
298
- prompt_embeds: Optional[torch.FloatTensor] = None,
299
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
313
+ prompt_embeds: Optional[torch.Tensor] = None,
314
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
300
315
  lora_scale: Optional[float] = None,
301
316
  clip_skip: Optional[int] = None,
302
317
  ):
@@ -316,10 +331,10 @@ class AltDiffusionPipeline(
316
331
  The prompt or prompts not to guide the image generation. If not defined, one has to pass
317
332
  `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
318
333
  less than `1`).
319
- prompt_embeds (`torch.FloatTensor`, *optional*):
334
+ prompt_embeds (`torch.Tensor`, *optional*):
320
335
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
321
336
  provided, text embeddings will be generated from `prompt` input argument.
322
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
337
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
323
338
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
324
339
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
325
340
  argument.
@@ -588,7 +603,12 @@ class AltDiffusionPipeline(
588
603
  )
589
604
 
590
605
  def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
591
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
606
+ shape = (
607
+ batch_size,
608
+ num_channels_latents,
609
+ int(height) // self.vae_scale_factor,
610
+ int(width) // self.vae_scale_factor,
611
+ )
592
612
  if isinstance(generator, list) and len(generator) != batch_size:
593
613
  raise ValueError(
594
614
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -617,7 +637,7 @@ class AltDiffusionPipeline(
617
637
  data type of the generated embeddings
618
638
 
619
639
  Returns:
620
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
640
+ `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
621
641
  """
622
642
  assert len(w.shape) == 1
623
643
  w = w * 1000.0
@@ -668,14 +688,15 @@ class AltDiffusionPipeline(
668
688
  width: Optional[int] = None,
669
689
  num_inference_steps: int = 50,
670
690
  timesteps: List[int] = None,
691
+ sigmas: List[float] = None,
671
692
  guidance_scale: float = 7.5,
672
693
  negative_prompt: Optional[Union[str, List[str]]] = None,
673
694
  num_images_per_prompt: Optional[int] = 1,
674
695
  eta: float = 0.0,
675
696
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
676
- latents: Optional[torch.FloatTensor] = None,
677
- prompt_embeds: Optional[torch.FloatTensor] = None,
678
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
697
+ latents: Optional[torch.Tensor] = None,
698
+ prompt_embeds: Optional[torch.Tensor] = None,
699
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
679
700
  ip_adapter_image: Optional[PipelineImageInput] = None,
680
701
  output_type: Optional[str] = "pil",
681
702
  return_dict: bool = True,
@@ -717,14 +738,14 @@ class AltDiffusionPipeline(
717
738
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
718
739
  A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
719
740
  generation deterministic.
720
- latents (`torch.FloatTensor`, *optional*):
741
+ latents (`torch.Tensor`, *optional*):
721
742
  Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
722
743
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
723
744
  tensor is generated by sampling using the supplied random `generator`.
724
- prompt_embeds (`torch.FloatTensor`, *optional*):
745
+ prompt_embeds (`torch.Tensor`, *optional*):
725
746
  Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
726
747
  provided, text embeddings are generated from the `prompt` input argument.
727
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
748
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
728
749
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
729
750
  not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
730
751
  ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
@@ -843,7 +864,9 @@ class AltDiffusionPipeline(
843
864
  image_embeds = torch.cat([negative_image_embeds, image_embeds])
844
865
 
845
866
  # 4. Prepare timesteps
846
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
867
+ timesteps, num_inference_steps = retrieve_timesteps(
868
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
869
+ )
847
870
 
848
871
  # 5. Prepare latent variables
849
872
  num_channels_latents = self.unet.config.in_channels