diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
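
Worth calling out in the list above: entries 214 to 225 add a new diffusers/quantizers package (bitsandbytes, GGUF, and TorchAO backends). A hedged sketch of the bitsandbytes path, assuming the `BitsAndBytesConfig` top-level export and the black-forest-labs/FLUX.1-dev transformer used in the library's quantization docs:

import torch

from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# Load a transformer backbone in 4-bit NF4 via the new quantizers package.
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)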
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py

@@ -33,6 +33,20 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess
 def preprocess(image):
     warnings.warn(
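
Note on the hunk above: `retrieve_latents` normalizes VAE encoder outputs, whether they expose a `latent_dist` (sampled, or taken at its mode) or plain `latents`. A minimal sketch of the behavior, assuming the standard stabilityai/sd-vae-ft-mse AutoencoderKL checkpoint:

import torch

from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents

# Encode a dummy image; retrieve_latents hides whether the encoder output
# carries a latent_dist (AutoencoderKL) or raw latents (e.g. AutoencoderTiny).
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
enc = vae.encode(torch.randn(1, 3, 64, 64))

stochastic = retrieve_latents(enc, generator=torch.Generator().manual_seed(0))  # posterior sample
deterministic = retrieve_latents(enc, sample_mode="argmax")  # posterior mode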
@@ -105,7 +119,54 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
 
-    def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_prompt):
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            **kwargs,
+        )
+
+        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
+
+        return prompt_embeds, pooled_prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt,
+        device,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+    ):
         r"""
         Encodes the prompt into text encoder hidden states.
 
@@ -119,81 +180,100 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin
             negative_prompt (`str` or `List[str]`):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
         """
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
-
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_length=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-
-        text_encoder_out = self.text_encoder(
-            text_input_ids.to(device),
-            output_hidden_states=True,
-        )
-        text_embeddings = text_encoder_out.hidden_states[-1]
-        text_pooler_out = text_encoder_out.pooler_output
-
-        # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance:
-            uncond_tokens: List[str]
-            if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt]
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-            else:
-                uncond_tokens = negative_prompt
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
 
-            max_length = text_input_ids.shape[-1]
-            uncond_input = self.tokenizer(
-                uncond_tokens,
+        if prompt_embeds is None or pooled_prompt_embeds is None:
+            text_inputs = self.tokenizer(
+                prompt,
                 padding="max_length",
-                max_length=max_length,
+                max_length=self.tokenizer.model_max_length,
                 truncation=True,
                 return_length=True,
                 return_tensors="pt",
             )
+            text_input_ids = text_inputs.input_ids
 
-            uncond_encoder_out = self.text_encoder(
-                uncond_input.input_ids.to(device),
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            text_encoder_out = self.text_encoder(
+                text_input_ids.to(device),
                 output_hidden_states=True,
             )
+            prompt_embeds = text_encoder_out.hidden_states[-1]
+            pooled_prompt_embeds = text_encoder_out.pooler_output
 
-            uncond_embeddings = uncond_encoder_out.hidden_states[-1]
-            uncond_pooler_out = uncond_encoder_out.pooler_output
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance:
+            if negative_prompt_embeds is None or negative_pooled_prompt_embeds is None:
+                uncond_tokens: List[str]
+                if negative_prompt is None:
+                    uncond_tokens = [""] * batch_size
+                elif type(prompt) is not type(negative_prompt):
+                    raise TypeError(
+                        f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                        f" {type(prompt)}."
+                    )
+                elif isinstance(negative_prompt, str):
+                    uncond_tokens = [negative_prompt]
+                elif batch_size != len(negative_prompt):
+                    raise ValueError(
+                        f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                        f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                        " the batch size of `prompt`."
+                    )
+                else:
+                    uncond_tokens = negative_prompt
+
+                max_length = text_input_ids.shape[-1]
+                uncond_input = self.tokenizer(
+                    uncond_tokens,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_length=True,
+                    return_tensors="pt",
+                )
+
+                uncond_encoder_out = self.text_encoder(
+                    uncond_input.input_ids.to(device),
+                    output_hidden_states=True,
+                )
 
-            # For classifier free guidance, we need to do two forward passes.
-            # Here we concatenate the unconditional and text embeddings into a single batch
-            # to avoid doing two forward passes
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-            text_pooler_out = torch.cat([uncond_pooler_out, text_pooler_out])
+                negative_prompt_embeds = uncond_encoder_out.hidden_states[-1]
+                negative_pooled_prompt_embeds = uncond_encoder_out.pooler_output
 
-        return text_embeddings, text_pooler_out
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
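
Taken together, the two hunks above turn `_encode_prompt` into a deprecation shim and make `encode_prompt` return the four embedding branches separately instead of CFG-concatenated tensors. A sketch of the new call, assuming the stabilityai/sd-x2-latent-upscaler checkpoint:

import torch

from diffusers import StableDiffusionLatentUpscalePipeline

pipe = StableDiffusionLatentUpscalePipeline.from_pretrained(
    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
).to("cuda")

# encode_prompt returns a 4-tuple; the deprecated _encode_prompt returned
# the negative and positive branches already concatenated for CFG.
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
    prompt="a high-resolution photo of an astronaut",
    device="cuda",
    do_classifier_free_guidance=True,
    negative_prompt="blurry, low quality",
)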
@@ -207,12 +287,56 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
-    def check_inputs(self, prompt, image, callback_steps):
-        if not isinstance(prompt, str) and not isinstance(prompt, list):
+    def check_inputs(
+        self,
+        prompt,
+        image,
+        callback_steps,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
+    ):
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and not isinstance(prompt, str) and not isinstance(prompt, list):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
+
         if (
             not isinstance(image, torch.Tensor)
+            and not isinstance(image, np.ndarray)
             and not isinstance(image, PIL.Image.Image)
             and not isinstance(image, list)
         ):
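
These validations fail fast. Continuing the sketch above (`pipe` and `prompt_embeds` come from the encode_prompt example), mixing a raw prompt with precomputed embeddings is rejected before any model work happens:

import torch

# check_inputs runs at the top of __call__; invoking it directly here only
# demonstrates the new contract. Prompt plus embeddings raises immediately.
try:
    pipe.check_inputs(
        "a photo",                    # prompt ...
        torch.randn(1, 3, 128, 128),  # image
        1,                            # callback_steps
        prompt_embeds=prompt_embeds,  # ... plus embeddings -> ValueError
    )
except ValueError as err:
    print(err)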
@@ -222,10 +346,14 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin
 
         # verify batch size of prompt and image are same if image is a list or tensor
         if isinstance(image, (list, torch.Tensor)):
-            if isinstance(prompt, str):
-                batch_size = 1
+            if prompt is not None:
+                if isinstance(prompt, str):
+                    batch_size = 1
+                else:
+                    batch_size = len(prompt)
             else:
-                batch_size = len(prompt)
+                batch_size = prompt_embeds.shape[0]
+
             if isinstance(image, list):
                 image_batch_size = len(image)
             else:
@@ -261,13 +389,17 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str]] = None,
         image: PipelineImageInput = None,
         num_inference_steps: int = 75,
         guidance_scale: float = 9.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
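
With the widened signature above, `prompt` becomes optional and the upscaler can be driven entirely by precomputed embeddings. Continuing the sketch, where `low_res_latents` is a stand-in for latents produced by a base Stable Diffusion pipeline called with `output_type="latent"`:

import torch

# Stand-in latents for illustration; in practice these would come from a base
# pipeline, e.g. StableDiffusionPipeline(..., output_type="latent").images
low_res_latents = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")

upscaled = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    image=low_res_latents,
    num_inference_steps=20,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]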
@@ -359,10 +491,22 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin
         """
 
         # 1. Check inputs
-        self.check_inputs(prompt, image, callback_steps)
+        self.check_inputs(
+            prompt,
+            image,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        )
 
         # 2. Define call parameters
-        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        if prompt is not None:
+            batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
         device = self._execution_device
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -373,16 +517,32 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin
             prompt = [""] * batch_size
 
         # 3. Encode input prompt
-        text_embeddings, text_pooler_out = self._encode_prompt(
-            prompt, device, do_classifier_free_guidance, negative_prompt
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt,
+            device,
+            do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
         )
 
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
+
         # 4. Preprocess image
         image = self.image_processor.preprocess(image)
-        image = image.to(dtype=text_embeddings.dtype, device=device)
+        image = image.to(dtype=prompt_embeds.dtype, device=device)
         if image.shape[1] == 3:
             # encode image if not in latent-space yet
-            image = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor
+            image = retrieve_latents(self.vae.encode(image), generator=generator) * self.vae.config.scaling_factor
 
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -400,17 +560,17 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin
         inv_noise_level = (noise_level**2 + 1) ** (-0.5)
 
         image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None]
-        image_cond = image_cond.to(text_embeddings.dtype)
+        image_cond = image_cond.to(prompt_embeds.dtype)
 
         noise_level_embed = torch.cat(
             [
-                torch.ones(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
-                torch.zeros(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
+                torch.ones(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device),
+                torch.zeros(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device),
             ],
             dim=1,
         )
 
-        timestep_condition = torch.cat([noise_level_embed, text_pooler_out], dim=1)
+        timestep_condition = torch.cat([noise_level_embed, pooled_prompt_embeds], dim=1)
 
         # 6. Prepare latent variables
         height, width = image.shape[2:]
@@ -420,7 +580,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin
             num_channels_latents,
             height * 2,  # 2x upscale
             width * 2,
-            text_embeddings.dtype,
+            prompt_embeds.dtype,
             device,
             generator,
             latents,
@@ -454,7 +614,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin
             noise_pred = self.unet(
                 scaled_model_input,
                 timestep,
-                encoder_hidden_states=text_embeddings,
+                encoder_hidden_states=prompt_embeds,
                 timestep_cond=timestep_condition,
             ).sample