diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
@@ -13,7 +13,7 @@
 
 import copy
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
@@ -59,6 +59,81 @@ EXAMPLE_DOC_STRING = """
 """
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
 class StableDiffusionPanoramaPipeline(
     DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin
 ):
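The `retrieve_timesteps` helper added above dispatches on which of `num_inference_steps`, `timesteps`, or `sigmas` the caller supplies, and it only accepts the custom paths when the scheduler's `set_timesteps` signature exposes the matching parameter. A minimal sketch of that contract, using a hypothetical `ToyScheduler` stand-in rather than a real diffusers scheduler:

```python
# Sketch of the `retrieve_timesteps` contract. `ToyScheduler` is a hypothetical
# stand-in (not a diffusers class); `retrieve_timesteps` is the module-level
# helper added in the hunk above.
import torch
from diffusers.pipelines.stable_diffusion_panorama.pipeline_stable_diffusion_panorama import retrieve_timesteps


class ToyScheduler:
    def set_timesteps(self, num_inference_steps=None, device=None, timesteps=None, sigmas=None):
        if timesteps is not None:
            self.timesteps = torch.tensor(timesteps, device=device)
        elif sigmas is not None:
            # real schedulers derive timesteps from sigmas; kept 1:1 here for brevity
            self.timesteps = torch.arange(len(sigmas), device=device)
        else:
            self.timesteps = torch.linspace(999, 0, num_inference_steps, device=device).long()


scheduler = ToyScheduler()

# Default path: spacing comes from the scheduler.
ts, n = retrieve_timesteps(scheduler, num_inference_steps=4)
assert n == 4

# Custom path: explicit timesteps override the spacing; the step count is derived.
ts, n = retrieve_timesteps(scheduler, timesteps=[999, 749, 499, 249, 0])
assert n == 5

# Passing both `timesteps` and `sigmas` raises the ValueError shown in the hunk above.
```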
@@ -97,6 +172,7 @@ class StableDiffusionPanoramaPipeline(
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
     def __init__(
         self,
@@ -150,8 +226,8 @@ class StableDiffusionPanoramaPipeline(
         num_images_per_prompt,
         do_classifier_free_guidance,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         **kwargs,
     ):
@@ -183,8 +259,8 @@ class StableDiffusionPanoramaPipeline(
         num_images_per_prompt,
         do_classifier_free_guidance,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         clip_skip: Optional[int] = None,
     ):
@@ -204,10 +280,10 @@ class StableDiffusionPanoramaPipeline(
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
@@ -461,10 +537,23 @@ class StableDiffusionPanoramaPipeline(
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
-    def decode_latents_with_padding(self, latents, padding=8):
-        # Add padding to latents for circular inference
-        # padding is the number of latents to add on each side
-        # it would slightly increase the memory usage, but remove the boundary artifacts
+    def decode_latents_with_padding(self, latents: torch.Tensor, padding: int = 8) -> torch.Tensor:
+        """
+        Decode the given latents with padding for circular inference.
+
+        Args:
+            latents (torch.Tensor): The input latents to decode.
+            padding (int, optional): The number of latents to add on each side for padding. Defaults to 8.
+
+        Returns:
+            torch.Tensor: The decoded image with padding removed.
+
+        Notes:
+            - The padding is added to remove boundary artifacts and improve the output quality.
+            - This would slightly increase the memory usage.
+            - The padding pixels are then removed from the decoded image.
+
+        """
         latents = 1 / self.vae.config.scaling_factor * latents
         latents_left = latents[..., :padding]
         latents_right = latents[..., -padding:]
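The circular-inference decode above works by wrapping `padding` latent columns from each edge onto the opposite side before the VAE decode, then cropping the decoded padding away. A toy replay of just the wrap-and-crop bookkeeping on a plain tensor (shapes are illustrative; the `self.vae.decode` call is elided):

```python
# Toy wrap-and-crop bookkeeping mirroring `decode_latents_with_padding`;
# the actual method runs the padded latents through the VAE before cropping.
import torch

latents = torch.arange(16.0).reshape(1, 1, 1, 16)  # (B, C, H, W)
padding = 4
latents_left, latents_right = latents[..., :padding], latents[..., -padding:]
padded = torch.cat((latents_right, latents, latents_left), dim=-1)  # wrap the edges
assert padded.shape[-1] == 16 + 2 * padding

# The pipeline crops `padding * vae_scale_factor` pixels per side after decoding;
# applied directly to the latents, the crop recovers the original tensor.
assert torch.equal(padded[..., padding:-padding], latents)
```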
@@ -564,7 +653,12 @@
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        shape = (
+            batch_size,
+            num_channels_latents,
+            int(height) // self.vae_scale_factor,
+            int(width) // self.vae_scale_factor,
+        )
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -580,9 +674,62 @@
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
-    def get_views(self, panorama_height, panorama_width, window_size=64, stride=8, circular_padding=False):
-        # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
-        # if panorama's height/width < window_size, num_blocks of height/width should return 1
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(
+        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
+    ) -> torch.Tensor:
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
+            embedding_dim (`int`, *optional*, defaults to 512):
+                Dimension of the embeddings to generate.
+            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+                Data type of the generated embeddings.
+
+        Returns:
+            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    def get_views(
+        self,
+        panorama_height: int,
+        panorama_width: int,
+        window_size: int = 64,
+        stride: int = 8,
+        circular_padding: bool = False,
+    ) -> List[Tuple[int, int, int, int]]:
+        """
+        Generates a list of views based on the given parameters. Here, we define the mappings F_i (see Eq. 7 in the
+        MultiDiffusion paper https://arxiv.org/abs/2302.08113). If panorama's height/width < window_size, num_blocks of
+        height/width should return 1.
+
+        Args:
+            panorama_height (int): The height of the panorama.
+            panorama_width (int): The width of the panorama.
+            window_size (int, optional): The size of the window. Defaults to 64.
+            stride (int, optional): The stride value. Defaults to 8.
+            circular_padding (bool, optional): Whether to apply circular padding. Defaults to False.
+
+        Returns:
+            List[Tuple[int, int, int, int]]: A list of tuples representing the views. Each tuple contains four integers
+            representing the start and end coordinates of the window in the panorama.
+
+        """
         panorama_height /= 8
         panorama_width /= 8
         num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1
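To make the window arithmetic above concrete: the default 512x2048 call maps to a 64x256 latent grid, so with `window_size=64` and `stride=8` the formula yields a single row of 25 horizontally overlapping views. A quick replay of the block-count computation in plain Python:

```python
# Replaying the block-count formula from `get_views` for the default 512x2048 panorama.
panorama_height, panorama_width = 512 / 8, 2048 / 8  # latent grid: 64 x 256
window_size, stride = 64, 8
num_blocks_height = int((panorama_height - window_size) // stride + 1) if panorama_height > window_size else 1
num_blocks_width = int((panorama_width - window_size) // stride + 1) if panorama_width > window_size else 1
assert (num_blocks_height, num_blocks_width) == (1, 25)  # 25 overlapping views in total
```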
@@ -600,6 +747,34 @@
             views.append((h_start, h_end, w_start, w_end))
         return views
 
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    @property
+    def do_classifier_free_guidance(self):
+        return False
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -608,24 +783,27 @@
         height: Optional[int] = 512,
         width: Optional[int] = 2048,
         num_inference_steps: int = 50,
+        timesteps: List[int] = None,
         guidance_scale: float = 7.5,
         view_batch_size: int = 1,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
-        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: Optional[int] = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
         circular_padding: bool = False,
         clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs: Any,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -641,6 +819,9 @@
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                The timesteps at which to generate the images. If not specified, then the default timestep spacing
+                strategy of the scheduler is used.
             guidance_scale (`float`, *optional*, defaults to 7.5):
                 A higher guidance scale value encourages the model to generate images closely linked to the text
                 `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -658,38 +839,34 @@
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor is generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                 provided, text embeddings are generated from the `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*):
                 Optional image input to work with IP Adapters.
-            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
-                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
-                Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
-                if `do_classifier_free_guidance` is set to `True`.
-                If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
-            callback (`Callable`, *optional*):
-                A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function is called. If not specified, the callback is called at
-                every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                A rescaling factor for the guidance embeddings. A value of 0.0 means no rescaling is applied.
             circular_padding (`bool`, *optional*, defaults to `False`):
                 If set to `True`, circular padding is applied to ensure there are no stitching artifacts. Circular
                 padding allows the model to seamlessly generate a transition from the rightmost part of the image to
@@ -697,6 +874,15 @@
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that calls at the end of each denoising steps during the inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
 
         Examples:
 
         Returns:
@@ -706,6 +892,22 @@
             second element is a list of `bool`s indicating whether the corresponding generated image contains
             "not-safe-for-work" (nsfw) content.
         """
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
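Together with the new `callback_on_step_end` parameter, the shim above keeps old `callback`/`callback_steps` call sites working while steering users toward the new hook. A hedged sketch of the migration (the checkpoint id is an assumption; the hook body just reads back a tensor named in `callback_on_step_end_tensor_inputs`):

```python
# Migrating a per-step hook from `callback`/`callback_steps` to `callback_on_step_end`.
# Illustrative sketch: the checkpoint id below is an assumption.
import torch
from diffusers import StableDiffusionPanoramaPipeline

pipe = StableDiffusionPanoramaPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-base", torch_dtype=torch.float16
).to("cuda")


def on_step_end(pipeline, step, timestep, callback_kwargs):
    # Tensors listed in `callback_on_step_end_tensor_inputs` arrive here; returning
    # them (possibly modified) overrides the pipeline's working values.
    latents = callback_kwargs["latents"]
    return {"latents": latents}


image = pipe(
    "a photo of the dolomites",
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```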
@@ -721,8 +923,15 @@
             negative_prompt_embeds,
             ip_adapter_image,
             ip_adapter_image_embeds,
+            callback_on_step_end_tensor_inputs,
         )
 
+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -768,8 +977,7 @@
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         # 4. Prepare timesteps
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.scheduler.timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
 
         # 5. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
@@ -802,12 +1010,23 @@
             else None
         )
 
+        # 7.2 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 8. Denoising loop
         # Each denoising step also includes refinement of the latents with respect to the
         # views.
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
                 count.zero_()
                 value.zero_()
 
@@ -863,6 +1082,7 @@
                         latent_model_input,
                         t,
                         encoder_hidden_states=prompt_embeds_input,
+                        timestep_cond=timestep_cond,
                         cross_attention_kwargs=cross_attention_kwargs,
                         added_cond_kwargs=added_cond_kwargs,
                     ).sample
@@ -872,6 +1092,12 @@
                         noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
                         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
+                    if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                        noise_pred = rescale_noise_cfg(
+                            noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
+                        )
+
                     # compute the previous noisy sample x_t -> x_t-1
                     latents_denoised_batch = self.scheduler.step(
                         noise_pred, t, latents_for_view, **extra_step_kwargs
@@ -901,6 +1127,16 @@
                 # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
                 latents = torch.where(count > 0, value / count, value)
 
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
@@ -908,7 +1144,7 @@
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
-        if not output_type == "latent":
+        if output_type != "latent":
             if circular_padding:
                 image = self.decode_latents_with_padding(latents)
             else:
diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
@@ -416,7 +416,12 @@ class StableDiffusionPipelineSafe(DiffusionPipeline, StableDiffusionMixin, IPAda
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        shape = (
+            batch_size,
+            num_channels_latents,
+            int(height) // self.vae_scale_factor,
+            int(width) // self.vae_scale_factor,
+        )
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -511,11 +516,11 @@ class StableDiffusionPipelineSafe(DiffusionPipeline, StableDiffusionMixin, IPAda
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: int = 1,
         sld_guidance_scale: Optional[float] = 1000,
         sld_warmup_steps: Optional[int] = 10,
@@ -550,7 +555,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline, StableDiffusionMixin, IPAda
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor is generated by sampling using the supplied random `generator`.
@@ -563,7 +568,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline, StableDiffusionMixin, IPAda
                 plain tuple.
             callback (`Callable`, *optional*):
                 A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function is called. If not specified, the callback is called at
                 every step.

diffusers/pipelines/stable_diffusion_safe/safety_checker.py
@@ -85,7 +85,7 @@ class SafeStableDiffusionSafetyChecker(PreTrainedModel):
         return images, has_nsfw_concepts
 
     @torch.no_grad()
-    def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
+    def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor):
         pooled_output = self.vision_model(clip_input)[1]  # pooled_output
         image_embeds = self.visual_projection(pooled_output)