diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -19,10 +19,9 @@ import urllib.parse as ul
19
19
  from typing import Callable, List, Optional, Tuple, Union
20
20
 
21
21
  import torch
22
- import torch.nn.functional as F
23
22
  from transformers import T5EncoderModel, T5Tokenizer
24
23
 
25
- from ...image_processor import VaeImageProcessor
24
+ from ...image_processor import PixArtImageProcessor
26
25
  from ...models import AutoencoderKL, Transformer2DModel
27
26
  from ...schedulers import DPMSolverMultistepScheduler
28
27
  from ...utils import (
@@ -176,6 +175,7 @@ def retrieve_timesteps(
176
175
  num_inference_steps: Optional[int] = None,
177
176
  device: Optional[Union[str, torch.device]] = None,
178
177
  timesteps: Optional[List[int]] = None,
178
+ sigmas: Optional[List[float]] = None,
179
179
  **kwargs,
180
180
  ):
181
181
  """
@@ -186,19 +186,23 @@ def retrieve_timesteps(
186
186
  scheduler (`SchedulerMixin`):
187
187
  The scheduler to get timesteps from.
188
188
  num_inference_steps (`int`):
189
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
190
- `timesteps` must be `None`.
189
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
190
+ must be `None`.
191
191
  device (`str` or `torch.device`, *optional*):
192
192
  The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
193
193
  timesteps (`List[int]`, *optional*):
194
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
195
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
196
- must be `None`.
194
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
195
+ `num_inference_steps` and `sigmas` must be `None`.
196
+ sigmas (`List[float]`, *optional*):
197
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
198
+ `num_inference_steps` and `timesteps` must be `None`.
197
199
 
198
200
  Returns:
199
201
  `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
200
202
  second element is the number of inference steps.
201
203
  """
204
+ if timesteps is not None and sigmas is not None:
205
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
202
206
  if timesteps is not None:
203
207
  accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
204
208
  if not accepts_timesteps:
@@ -209,6 +213,16 @@ def retrieve_timesteps(
209
213
  scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
210
214
  timesteps = scheduler.timesteps
211
215
  num_inference_steps = len(timesteps)
216
+ elif sigmas is not None:
217
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
218
+ if not accept_sigmas:
219
+ raise ValueError(
220
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
221
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
222
+ )
223
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
224
+ timesteps = scheduler.timesteps
225
+ num_inference_steps = len(timesteps)
212
226
  else:
213
227
  scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
214
228
  timesteps = scheduler.timesteps
@@ -272,16 +286,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
272
286
  )
273
287
 
274
288
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
275
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
276
-
277
- # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
278
- def mask_text_embeddings(self, emb, mask):
279
- if emb.shape[0] == 1:
280
- keep_index = mask.sum().item()
281
- return emb[:, :, :keep_index, :], keep_index
282
- else:
283
- masked_feature = emb * mask[:, None, :, None]
284
- return masked_feature, emb.shape[2]
289
+ self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
285
290
 
286
291
  # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
287
292
  def encode_prompt(
@@ -291,10 +296,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
291
296
  negative_prompt: str = "",
292
297
  num_images_per_prompt: int = 1,
293
298
  device: Optional[torch.device] = None,
294
- prompt_embeds: Optional[torch.FloatTensor] = None,
295
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
296
- prompt_attention_mask: Optional[torch.FloatTensor] = None,
297
- negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
299
+ prompt_embeds: Optional[torch.Tensor] = None,
300
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
301
+ prompt_attention_mask: Optional[torch.Tensor] = None,
302
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
298
303
  clean_caption: bool = False,
299
304
  max_sequence_length: int = 120,
300
305
  **kwargs,
@@ -315,10 +320,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
315
320
  number of images that should be generated per prompt
316
321
  device: (`torch.device`, *optional*):
317
322
  torch device to place the resulting embeddings on
318
- prompt_embeds (`torch.FloatTensor`, *optional*):
323
+ prompt_embeds (`torch.Tensor`, *optional*):
319
324
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
320
325
  provided, text embeddings will be generated from `prompt` input argument.
321
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
326
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
322
327
  Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
323
328
  string.
324
329
  clean_caption (`bool`, defaults to `False`):
@@ -361,7 +366,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
361
366
  ):
362
367
  removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
363
368
  logger.warning(
364
- "The following part of your input was truncated because CLIP can only handle sequences up to"
369
+ "The following part of your input was truncated because T5 can only handle sequences up to"
365
370
  f" {max_length} tokens: {removed_text}"
366
371
  )
367
372
 
@@ -653,7 +658,12 @@ class PixArtAlphaPipeline(DiffusionPipeline):
653
658
 
654
659
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
655
660
  def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
656
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
661
+ shape = (
662
+ batch_size,
663
+ num_channels_latents,
664
+ int(height) // self.vae_scale_factor,
665
+ int(width) // self.vae_scale_factor,
666
+ )
657
667
  if isinstance(generator, list) and len(generator) != batch_size:
658
668
  raise ValueError(
659
669
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -669,38 +679,6 @@ class PixArtAlphaPipeline(DiffusionPipeline):
669
679
  latents = latents * self.scheduler.init_noise_sigma
670
680
  return latents
671
681
 
672
- @staticmethod
673
- def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
674
- """Returns binned height and width."""
675
- ar = float(height / width)
676
- closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
677
- default_hw = ratios[closest_ratio]
678
- return int(default_hw[0]), int(default_hw[1])
679
-
680
- @staticmethod
681
- def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
682
- orig_height, orig_width = samples.shape[2], samples.shape[3]
683
-
684
- # Check if resizing is needed
685
- if orig_height != new_height or orig_width != new_width:
686
- ratio = max(new_height / orig_height, new_width / orig_width)
687
- resized_width = int(orig_width * ratio)
688
- resized_height = int(orig_height * ratio)
689
-
690
- # Resize
691
- samples = F.interpolate(
692
- samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
693
- )
694
-
695
- # Center Crop
696
- start_x = (resized_width - new_width) // 2
697
- end_x = start_x + new_width
698
- start_y = (resized_height - new_height) // 2
699
- end_y = start_y + new_height
700
- samples = samples[:, :, start_y:end_y, start_x:end_x]
701
-
702
- return samples
703
-
704
682
  @torch.no_grad()
705
683
  @replace_example_docstring(EXAMPLE_DOC_STRING)
706
684
  def __call__(
@@ -709,20 +687,21 @@ class PixArtAlphaPipeline(DiffusionPipeline):
709
687
  negative_prompt: str = "",
710
688
  num_inference_steps: int = 20,
711
689
  timesteps: List[int] = None,
690
+ sigmas: List[float] = None,
712
691
  guidance_scale: float = 4.5,
713
692
  num_images_per_prompt: Optional[int] = 1,
714
693
  height: Optional[int] = None,
715
694
  width: Optional[int] = None,
716
695
  eta: float = 0.0,
717
696
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
718
- latents: Optional[torch.FloatTensor] = None,
719
- prompt_embeds: Optional[torch.FloatTensor] = None,
720
- prompt_attention_mask: Optional[torch.FloatTensor] = None,
721
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
722
- negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
697
+ latents: Optional[torch.Tensor] = None,
698
+ prompt_embeds: Optional[torch.Tensor] = None,
699
+ prompt_attention_mask: Optional[torch.Tensor] = None,
700
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
701
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
723
702
  output_type: Optional[str] = "pil",
724
703
  return_dict: bool = True,
725
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
704
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
726
705
  callback_steps: int = 1,
727
706
  clean_caption: bool = True,
728
707
  use_resolution_binning: bool = True,
@@ -744,8 +723,13 @@ class PixArtAlphaPipeline(DiffusionPipeline):
744
723
  The number of denoising steps. More denoising steps usually lead to a higher quality image at the
745
724
  expense of slower inference.
746
725
  timesteps (`List[int]`, *optional*):
747
- Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
748
- timesteps are used. Must be in descending order.
726
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
727
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
728
+ passed will be used. Must be in descending order.
729
+ sigmas (`List[float]`, *optional*):
730
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
731
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
732
+ will be used.
749
733
  guidance_scale (`float`, *optional*, defaults to 4.5):
750
734
  Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
751
735
  `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -764,18 +748,18 @@ class PixArtAlphaPipeline(DiffusionPipeline):
764
748
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
765
749
  One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
766
750
  to make generation deterministic.
767
- latents (`torch.FloatTensor`, *optional*):
751
+ latents (`torch.Tensor`, *optional*):
768
752
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
769
753
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
770
754
  tensor will ge generated by sampling using the supplied random `generator`.
771
- prompt_embeds (`torch.FloatTensor`, *optional*):
755
+ prompt_embeds (`torch.Tensor`, *optional*):
772
756
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
773
757
  provided, text embeddings will be generated from `prompt` input argument.
774
- prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
775
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
758
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
759
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
776
760
  Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
777
761
  provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
778
- negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
762
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
779
763
  Pre-generated attention mask for negative text embeddings.
780
764
  output_type (`str`, *optional*, defaults to `"pil"`):
781
765
  The output format of the generate image. Choose between
@@ -784,7 +768,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
784
768
  Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
785
769
  callback (`Callable`, *optional*):
786
770
  A function that will be called every `callback_steps` steps during inference. The function will be
787
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
771
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
788
772
  callback_steps (`int`, *optional*, defaults to 1):
789
773
  The frequency at which the `callback` function will be called. If not specified, the callback will be
790
774
  called at every step.
@@ -821,7 +805,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
821
805
  else:
822
806
  raise ValueError("Invalid sample size")
823
807
  orig_height, orig_width = height, width
824
- height, width = self.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
808
+ height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
825
809
 
826
810
  self.check_inputs(
827
811
  prompt,
@@ -874,7 +858,9 @@ class PixArtAlphaPipeline(DiffusionPipeline):
874
858
  prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
875
859
 
876
860
  # 4. Prepare timesteps
877
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
861
+ timesteps, num_inference_steps = retrieve_timesteps(
862
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
863
+ )
878
864
 
879
865
  # 5. Prepare latents.
880
866
  latent_channels = self.transformer.config.in_channels
@@ -951,7 +937,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
951
937
  noise_pred = noise_pred
952
938
 
953
939
  # compute previous image: x_t -> x_t-1
954
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
940
+ if num_inference_steps == 1:
941
+ # For DMD one step sampling: https://arxiv.org/abs/2311.18828
942
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
943
+ else:
944
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
955
945
 
956
946
  # call the callback, if provided
957
947
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -963,7 +953,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
963
953
  if not output_type == "latent":
964
954
  image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
965
955
  if use_resolution_binning:
966
- image = self.resize_and_crop_tensor(image, orig_width, orig_height)
956
+ image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
967
957
  else:
968
958
  image = latents
969
959