diffusers-0.30.3-py3-none-any.whl → diffusers-0.32.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -26,6 +26,7 @@ from ...models import AutoencoderOobleck, StableAudioDiTModel
26
26
  from ...models.embeddings import get_1d_rotary_pos_embed
27
27
  from ...schedulers import EDMDPMSolverMultistepScheduler
28
28
  from ...utils import (
29
+ is_torch_xla_available,
29
30
  logging,
30
31
  replace_example_docstring,
31
32
  )
@@ -34,6 +35,13 @@ from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
34
35
  from .modeling_stable_audio import StableAudioProjectionModel
35
36
 
36
37
 
38
+ if is_torch_xla_available():
39
+ import torch_xla.core.xla_model as xm
40
+
41
+ XLA_AVAILABLE = True
42
+ else:
43
+ XLA_AVAILABLE = False
44
+
37
45
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
46
 
39
47
  EXAMPLE_DOC_STRING = """
@@ -438,7 +446,7 @@ class StableAudioPipeline(DiffusionPipeline):
438
446
  f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
439
447
  )
440
448
 
441
- audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
449
+ audio_vae_length = int(self.transformer.config.sample_size) * self.vae.hop_length
442
450
  audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
443
451
 
444
452
  # check num_channels
@@ -726,6 +734,9 @@ class StableAudioPipeline(DiffusionPipeline):
726
734
  step_idx = i // getattr(self.scheduler, "order", 1)
727
735
  callback(step_idx, t, latents)
728
736
 
737
+ if XLA_AVAILABLE:
738
+ xm.mark_step()
739
+
729
740
  # 9. Post-processing
730
741
  if not output_type == "latent":
731
742
  audio = self.vae.decode(latents).sample
@@ -281,6 +281,16 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
281
281
  def num_timesteps(self):
282
282
  return self._num_timesteps
283
283
 
284
+ def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
285
+ s = torch.tensor([0.008])
286
+ clamp_range = [0, 1]
287
+ min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
288
+ var = alphas_cumprod[t]
289
+ var = var.clamp(*clamp_range)
290
+ s, min_var = s.to(var.device), min_var.to(var.device)
291
+ ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
292
+ return ratio
293
+
284
294
  @torch.no_grad()
285
295
  @replace_example_docstring(EXAMPLE_DOC_STRING)
286
296
  def __call__(
@@ -434,10 +444,30 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
434
444
  batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
435
445
  )
436
446
 
447
+ if isinstance(self.scheduler, DDPMWuerstchenScheduler):
448
+ timesteps = timesteps[:-1]
449
+ else:
450
+ if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
451
+ self.scheduler.config.clip_sample = False # disample sample clipping
452
+ logger.warning(" set `clip_sample` to be False")
453
+
437
454
  # 6. Run denoising loop
438
- self._num_timesteps = len(timesteps[:-1])
439
- for i, t in enumerate(self.progress_bar(timesteps[:-1])):
440
- timestep_ratio = t.expand(latents.size(0)).to(dtype)
455
+ if hasattr(self.scheduler, "betas"):
456
+ alphas = 1.0 - self.scheduler.betas
457
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
458
+ else:
459
+ alphas_cumprod = []
460
+
461
+ self._num_timesteps = len(timesteps)
462
+ for i, t in enumerate(self.progress_bar(timesteps)):
463
+ if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
464
+ if len(alphas_cumprod) > 0:
465
+ timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
466
+ timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
467
+ else:
468
+ timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
469
+ else:
470
+ timestep_ratio = t.expand(latents.size(0)).to(dtype)
441
471
 
442
472
  # 7. Denoise latents
443
473
  predicted_latents = self.decoder(
@@ -454,6 +484,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
454
484
  predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)
455
485
 
456
486
  # 9. Renoise latents to next timestep
487
+ if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
488
+ timestep_ratio = t
457
489
  latents = self.scheduler.step(
458
490
  model_output=predicted_latents,
459
491
  timestep=timestep_ratio,
@@ -353,7 +353,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
353
353
  return self._num_timesteps
354
354
 
355
355
  def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
356
- s = torch.tensor([0.003])
356
+ s = torch.tensor([0.008])
357
357
  clamp_range = [0, 1]
358
358
  min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
359
359
  var = alphas_cumprod[t]
@@ -557,7 +557,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
557
557
  if isinstance(self.scheduler, DDPMWuerstchenScheduler):
558
558
  timesteps = timesteps[:-1]
559
559
  else:
560
- if self.scheduler.config.clip_sample:
560
+ if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
561
561
  self.scheduler.config.clip_sample = False # disample sample clipping
562
562
  logger.warning(" set `clip_sample` to be False")
563
563
  # 6. Run denoising loop
@@ -28,6 +28,7 @@ from ...schedulers import KarrasDiffusionSchedulers
28
28
  from ...utils import (
29
29
  USE_PEFT_BACKEND,
30
30
  deprecate,
31
+ is_torch_xla_available,
31
32
  logging,
32
33
  replace_example_docstring,
33
34
  scale_lora_layers,
@@ -39,6 +40,13 @@ from .pipeline_output import StableDiffusionPipelineOutput
39
40
  from .safety_checker import StableDiffusionSafetyChecker
40
41
 
41
42
 
43
+ if is_torch_xla_available():
44
+ import torch_xla.core.xla_model as xm
45
+
46
+ XLA_AVAILABLE = True
47
+ else:
48
+ XLA_AVAILABLE = False
49
+
42
50
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
51
 
44
52
  EXAMPLE_DOC_STRING = """
@@ -57,9 +65,21 @@ EXAMPLE_DOC_STRING = """
57
65
 
58
66
 
59
67
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
60
- """
61
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
62
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
68
+ r"""
69
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
70
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
71
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
72
+
73
+ Args:
74
+ noise_cfg (`torch.Tensor`):
75
+ The predicted noise tensor for the guided diffusion process.
76
+ noise_pred_text (`torch.Tensor`):
77
+ The predicted noise tensor for the text-guided diffusion process.
78
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
79
+ A rescale factor applied to the noise predictions.
80
+
81
+ Returns:
82
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
63
83
  """
64
84
  std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
65
85
  std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -78,7 +98,7 @@ def retrieve_timesteps(
78
98
  sigmas: Optional[List[float]] = None,
79
99
  **kwargs,
80
100
  ):
81
- """
101
+ r"""
82
102
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
83
103
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
84
104
 
@@ -137,7 +157,7 @@ class StableDiffusionPipeline(
137
157
  IPAdapterMixin,
138
158
  FromSingleFileMixin,
139
159
  ):
140
- r"""
160
+ """
141
161
  Pipeline for text-to-image generation using Stable Diffusion.
142
162
 
143
163
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
@@ -235,7 +255,12 @@ class StableDiffusionPipeline(
235
255
  is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
236
256
  version.parse(unet.config._diffusers_version).base_version
237
257
  ) < version.parse("0.9.0.dev0")
238
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
258
+ self._is_unet_config_sample_size_int = isinstance(unet.config.sample_size, int)
259
+ is_unet_sample_size_less_64 = (
260
+ hasattr(unet.config, "sample_size")
261
+ and self._is_unet_config_sample_size_int
262
+ and unet.config.sample_size < 64
263
+ )
239
264
  if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
240
265
  deprecation_message = (
241
266
  "The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -882,8 +907,18 @@ class StableDiffusionPipeline(
882
907
  callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
883
908
 
884
909
  # 0. Default height and width to unet
885
- height = height or self.unet.config.sample_size * self.vae_scale_factor
886
- width = width or self.unet.config.sample_size * self.vae_scale_factor
910
+ if not height or not width:
911
+ height = (
912
+ self.unet.config.sample_size
913
+ if self._is_unet_config_sample_size_int
914
+ else self.unet.config.sample_size[0]
915
+ )
916
+ width = (
917
+ self.unet.config.sample_size
918
+ if self._is_unet_config_sample_size_int
919
+ else self.unet.config.sample_size[1]
920
+ )
921
+ height, width = height * self.vae_scale_factor, width * self.vae_scale_factor
887
922
  # to deal with lora scaling and other possible forward hooks
888
923
 
889
924
  # 1. Check inputs. Raise error if not correct
@@ -1036,6 +1071,9 @@ class StableDiffusionPipeline(
1036
1071
  step_idx = i // getattr(self.scheduler, "order", 1)
1037
1072
  callback(step_idx, t, latents)
1038
1073
 
1074
+ if XLA_AVAILABLE:
1075
+ xm.mark_step()
1076
+
1039
1077
  if not output_type == "latent":
1040
1078
  image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1041
1079
  0
@@ -1049,7 +1087,6 @@ class StableDiffusionPipeline(
1049
1087
  do_denormalize = [True] * image.shape[0]
1050
1088
  else:
1051
1089
  do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1052
-
1053
1090
  image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1054
1091
 
1055
1092
  # Offload all models
@@ -119,7 +119,7 @@ def retrieve_timesteps(
119
119
  sigmas: Optional[List[float]] = None,
120
120
  **kwargs,
121
121
  ):
122
- """
122
+ r"""
123
123
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
124
124
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
125
125
 
@@ -60,7 +60,7 @@ def retrieve_timesteps(
60
60
  sigmas: Optional[List[float]] = None,
61
61
  **kwargs,
62
62
  ):
63
- """
63
+ r"""
64
64
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
65
65
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
66
66