diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -30,6 +30,7 @@ from transformers import (
30
30
 
31
31
  from diffusers.utils.import_utils import is_invisible_watermark_available
32
32
 
33
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
33
34
  from ...image_processor import PipelineImageInput, VaeImageProcessor
34
35
  from ...loaders import (
35
36
  FromSingleFileMixin,
@@ -114,6 +115,66 @@ EXAMPLE_DOC_STRING = """
114
115
  """
115
116
 
116
117
 
118
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
119
+ def retrieve_timesteps(
120
+ scheduler,
121
+ num_inference_steps: Optional[int] = None,
122
+ device: Optional[Union[str, torch.device]] = None,
123
+ timesteps: Optional[List[int]] = None,
124
+ sigmas: Optional[List[float]] = None,
125
+ **kwargs,
126
+ ):
127
+ """
128
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
129
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
130
+
131
+ Args:
132
+ scheduler (`SchedulerMixin`):
133
+ The scheduler to get timesteps from.
134
+ num_inference_steps (`int`):
135
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
136
+ must be `None`.
137
+ device (`str` or `torch.device`, *optional*):
138
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
139
+ timesteps (`List[int]`, *optional*):
140
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
141
+ `num_inference_steps` and `sigmas` must be `None`.
142
+ sigmas (`List[float]`, *optional*):
143
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
144
+ `num_inference_steps` and `timesteps` must be `None`.
145
+
146
+ Returns:
147
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
148
+ second element is the number of inference steps.
149
+ """
150
+ if timesteps is not None and sigmas is not None:
151
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
152
+ if timesteps is not None:
153
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
154
+ if not accepts_timesteps:
155
+ raise ValueError(
156
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
157
+ f" timestep schedules. Please check whether you are using the correct scheduler."
158
+ )
159
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
160
+ timesteps = scheduler.timesteps
161
+ num_inference_steps = len(timesteps)
162
+ elif sigmas is not None:
163
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
164
+ if not accept_sigmas:
165
+ raise ValueError(
166
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
167
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
168
+ )
169
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
170
+ timesteps = scheduler.timesteps
171
+ num_inference_steps = len(timesteps)
172
+ else:
173
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
174
+ timesteps = scheduler.timesteps
175
+ return timesteps, num_inference_steps
176
+
177
+
117
178
  class StableDiffusionXLControlNetPipeline(
118
179
  DiffusionPipeline,
119
180
  StableDiffusionMixin,
@@ -175,7 +236,15 @@ class StableDiffusionXLControlNetPipeline(
175
236
  "feature_extractor",
176
237
  "image_encoder",
177
238
  ]
178
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
239
+ _callback_tensor_inputs = [
240
+ "latents",
241
+ "prompt_embeds",
242
+ "negative_prompt_embeds",
243
+ "add_text_embeds",
244
+ "add_time_ids",
245
+ "negative_pooled_prompt_embeds",
246
+ "negative_add_time_ids",
247
+ ]
179
248
 
180
249
  def __init__(
181
250
  self,
@@ -233,10 +302,10 @@ class StableDiffusionXLControlNetPipeline(
233
302
  do_classifier_free_guidance: bool = True,
234
303
  negative_prompt: Optional[str] = None,
235
304
  negative_prompt_2: Optional[str] = None,
236
- prompt_embeds: Optional[torch.FloatTensor] = None,
237
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
238
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
239
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
305
+ prompt_embeds: Optional[torch.Tensor] = None,
306
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
307
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
308
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
240
309
  lora_scale: Optional[float] = None,
241
310
  clip_skip: Optional[int] = None,
242
311
  ):
@@ -262,17 +331,17 @@ class StableDiffusionXLControlNetPipeline(
262
331
  negative_prompt_2 (`str` or `List[str]`, *optional*):
263
332
  The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
264
333
  `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
265
- prompt_embeds (`torch.FloatTensor`, *optional*):
334
+ prompt_embeds (`torch.Tensor`, *optional*):
266
335
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
267
336
  provided, text embeddings will be generated from `prompt` input argument.
268
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
337
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
269
338
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
270
339
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
271
340
  argument.
272
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
341
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
273
342
  Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
274
343
  If not provided, pooled text embeddings will be generated from `prompt` input argument.
275
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
344
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
276
345
  Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
277
346
  weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
278
347
  input argument.
@@ -813,7 +882,12 @@ class StableDiffusionXLControlNetPipeline(
813
882
 
814
883
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
815
884
  def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
816
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
885
+ shape = (
886
+ batch_size,
887
+ num_channels_latents,
888
+ int(height) // self.vae_scale_factor,
889
+ int(width) // self.vae_scale_factor,
890
+ )
817
891
  if isinstance(generator, list) and len(generator) != batch_size:
818
892
  raise ValueError(
819
893
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -869,20 +943,22 @@ class StableDiffusionXLControlNetPipeline(
869
943
  self.vae.decoder.mid_block.to(dtype)
870
944
 
871
945
  # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
872
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
946
+ def get_guidance_scale_embedding(
947
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
948
+ ) -> torch.Tensor:
873
949
  """
874
950
  See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
875
951
 
876
952
  Args:
877
- timesteps (`torch.Tensor`):
878
- generate embedding vectors at these timesteps
953
+ w (`torch.Tensor`):
954
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
879
955
  embedding_dim (`int`, *optional*, defaults to 512):
880
- dimension of the embeddings to generate
881
- dtype:
882
- data type of the generated embeddings
956
+ Dimension of the embeddings to generate.
957
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
958
+ Data type of the generated embeddings.
883
959
 
884
960
  Returns:
885
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
961
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
886
962
  """
887
963
  assert len(w.shape) == 1
888
964
  w = w * 1000.0
@@ -934,6 +1010,8 @@ class StableDiffusionXLControlNetPipeline(
934
1010
  height: Optional[int] = None,
935
1011
  width: Optional[int] = None,
936
1012
  num_inference_steps: int = 50,
1013
+ timesteps: List[int] = None,
1014
+ sigmas: List[float] = None,
937
1015
  denoising_end: Optional[float] = None,
938
1016
  guidance_scale: float = 5.0,
939
1017
  negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -941,13 +1019,13 @@ class StableDiffusionXLControlNetPipeline(
941
1019
  num_images_per_prompt: Optional[int] = 1,
942
1020
  eta: float = 0.0,
943
1021
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
944
- latents: Optional[torch.FloatTensor] = None,
945
- prompt_embeds: Optional[torch.FloatTensor] = None,
946
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
947
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
948
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1022
+ latents: Optional[torch.Tensor] = None,
1023
+ prompt_embeds: Optional[torch.Tensor] = None,
1024
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
1025
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
1026
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
949
1027
  ip_adapter_image: Optional[PipelineImageInput] = None,
950
- ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
1028
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
951
1029
  output_type: Optional[str] = "pil",
952
1030
  return_dict: bool = True,
953
1031
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -962,7 +1040,9 @@ class StableDiffusionXLControlNetPipeline(
962
1040
  negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
963
1041
  negative_target_size: Optional[Tuple[int, int]] = None,
964
1042
  clip_skip: Optional[int] = None,
965
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1043
+ callback_on_step_end: Optional[
1044
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
1045
+ ] = None,
966
1046
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
967
1047
  **kwargs,
968
1048
  ):
@@ -975,14 +1055,14 @@ class StableDiffusionXLControlNetPipeline(
975
1055
  prompt_2 (`str` or `List[str]`, *optional*):
976
1056
  The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
977
1057
  used in both text-encoders.
978
- image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
979
- `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
1058
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
1059
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
980
1060
  The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
981
- specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
982
- accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
983
- and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
984
- `init`, images must be passed as a list such that each element of the list can be correctly batched for
985
- input to a single ControlNet.
1061
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
1062
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
1063
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
1064
+ images must be passed as a list such that each element of the list can be correctly batched for input
1065
+ to a single ControlNet.
986
1066
  height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
987
1067
  The height in pixels of the generated image. Anything below 512 pixels won't work well for
988
1068
  [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
@@ -994,6 +1074,14 @@ class StableDiffusionXLControlNetPipeline(
994
1074
  num_inference_steps (`int`, *optional*, defaults to 50):
995
1075
  The number of denoising steps. More denoising steps usually lead to a higher quality image at the
996
1076
  expense of slower inference.
1077
+ timesteps (`List[int]`, *optional*):
1078
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1079
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1080
+ passed will be used. Must be in descending order.
1081
+ sigmas (`List[float]`, *optional*):
1082
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
1083
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
1084
+ will be used.
997
1085
  denoising_end (`float`, *optional*):
998
1086
  When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
999
1087
  completed before it is intentionally prematurely terminated. As a result, the returned sample will
@@ -1018,29 +1106,29 @@ class StableDiffusionXLControlNetPipeline(
1018
1106
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1019
1107
  A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1020
1108
  generation deterministic.
1021
- latents (`torch.FloatTensor`, *optional*):
1109
+ latents (`torch.Tensor`, *optional*):
1022
1110
  Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
1023
1111
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1024
1112
  tensor is generated by sampling using the supplied random `generator`.
1025
- prompt_embeds (`torch.FloatTensor`, *optional*):
1113
+ prompt_embeds (`torch.Tensor`, *optional*):
1026
1114
  Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
1027
1115
  provided, text embeddings are generated from the `prompt` input argument.
1028
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1116
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
1029
1117
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1030
1118
  not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
1031
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1119
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
1032
1120
  Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1033
1121
  not provided, pooled text embeddings are generated from `prompt` input argument.
1034
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1122
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
1035
1123
  Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
1036
1124
  weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
1037
1125
  argument.
1038
1126
  ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1039
- ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
1040
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
1041
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
1042
- if `do_classifier_free_guidance` is set to `True`.
1043
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
1127
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1128
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
1129
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
1130
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1131
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1044
1132
  output_type (`str`, *optional*, defaults to `"pil"`):
1045
1133
  The output format of the generated image. Choose between `PIL.Image` or `np.array`.
1046
1134
  return_dict (`bool`, *optional*, defaults to `True`):
@@ -1092,15 +1180,15 @@ class StableDiffusionXLControlNetPipeline(
1092
1180
  clip_skip (`int`, *optional*):
1093
1181
  Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1094
1182
  the output of the pre-final layer will be used for computing the prompt embeddings.
1095
- callback_on_step_end (`Callable`, *optional*):
1096
- A function that calls at the end of each denoising steps during the inference. The function is called
1097
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1098
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1099
- `callback_on_step_end_tensor_inputs`.
1183
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1184
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1185
+ each denoising step during the inference, with the following arguments: `callback_on_step_end(self:
1186
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1187
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1100
1188
  callback_on_step_end_tensor_inputs (`List`, *optional*):
1101
1189
  The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1102
1190
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1103
- `._callback_tensor_inputs` attribute of your pipeine class.
1191
+ `._callback_tensor_inputs` attribute of your pipeline class.
1104
1192
 
1105
1193
  Examples:
1106
1194
 
@@ -1126,6 +1214,9 @@ class StableDiffusionXLControlNetPipeline(
1126
1214
  "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1127
1215
  )
1128
1216
 
1217
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1218
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1219
+
1129
1220
  controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1130
1221
 
1131
1222
  # align format for control guidance
@@ -1258,8 +1349,9 @@ class StableDiffusionXLControlNetPipeline(
1258
1349
  assert False
1259
1350
 
1260
1351
  # 5. Prepare timesteps
1261
- self.scheduler.set_timesteps(num_inference_steps, device=device)
1262
- timesteps = self.scheduler.timesteps
1352
+ timesteps, num_inference_steps = retrieve_timesteps(
1353
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
1354
+ )
1263
1355
  self._num_timesteps = len(timesteps)
1264
1356
 
1265
1357
  # 6. Prepare latent variables
@@ -1444,6 +1536,12 @@ class StableDiffusionXLControlNetPipeline(
1444
1536
  latents = callback_outputs.pop("latents", latents)
1445
1537
  prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1446
1538
  negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1539
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1540
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1541
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1542
+ )
1543
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1544
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
1447
1545
 
1448
1546
  # call the callback, if provided
1449
1547
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -30,8 +30,10 @@ from transformers import (
30
30
 
31
31
  from diffusers.utils.import_utils import is_invisible_watermark_available
32
32
 
33
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
33
34
  from ...image_processor import PipelineImageInput, VaeImageProcessor
34
35
  from ...loaders import (
36
+ FromSingleFileMixin,
35
37
  IPAdapterMixin,
36
38
  StableDiffusionXLLoraLoaderMixin,
37
39
  TextualInversionLoaderMixin,
@@ -88,8 +90,8 @@ EXAMPLE_DOC_STRING = """
88
90
  ... variant="fp16",
89
91
  ... use_safetensors=True,
90
92
  ... torch_dtype=torch.float16,
91
- ... ).to("cuda")
92
- >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
93
+ ... )
94
+ >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
93
95
  >>> pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
94
96
  ... "stabilityai/stable-diffusion-xl-base-1.0",
95
97
  ... controlnet=controlnet,
@@ -97,7 +99,7 @@ EXAMPLE_DOC_STRING = """
97
99
  ... variant="fp16",
98
100
  ... use_safetensors=True,
99
101
  ... torch_dtype=torch.float16,
100
- ... ).to("cuda")
102
+ ... )
101
103
  >>> pipe.enable_model_cpu_offload()
102
104
 
103
105
 
@@ -161,6 +163,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
161
163
  StableDiffusionMixin,
162
164
  TextualInversionLoaderMixin,
163
165
  StableDiffusionXLLoraLoaderMixin,
166
+ FromSingleFileMixin,
164
167
  IPAdapterMixin,
165
168
  ):
166
169
  r"""
@@ -225,7 +228,15 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
225
228
  "feature_extractor",
226
229
  "image_encoder",
227
230
  ]
228
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
231
+ _callback_tensor_inputs = [
232
+ "latents",
233
+ "prompt_embeds",
234
+ "negative_prompt_embeds",
235
+ "add_text_embeds",
236
+ "add_time_ids",
237
+ "negative_pooled_prompt_embeds",
238
+ "add_neg_time_ids",
239
+ ]
229
240
 
230
241
  def __init__(
231
242
  self,
@@ -285,10 +296,10 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
285
296
  do_classifier_free_guidance: bool = True,
286
297
  negative_prompt: Optional[str] = None,
287
298
  negative_prompt_2: Optional[str] = None,
288
- prompt_embeds: Optional[torch.FloatTensor] = None,
289
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
290
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
291
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
299
+ prompt_embeds: Optional[torch.Tensor] = None,
300
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
301
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
302
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
292
303
  lora_scale: Optional[float] = None,
293
304
  clip_skip: Optional[int] = None,
294
305
  ):
@@ -314,17 +325,17 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
314
325
  negative_prompt_2 (`str` or `List[str]`, *optional*):
315
326
  The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
316
327
  `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
317
- prompt_embeds (`torch.FloatTensor`, *optional*):
328
+ prompt_embeds (`torch.Tensor`, *optional*):
318
329
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
319
330
  provided, text embeddings will be generated from `prompt` input argument.
320
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
331
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
321
332
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
322
333
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
323
334
  argument.
324
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
335
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
325
336
  Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
326
337
  If not provided, pooled text embeddings will be generated from `prompt` input argument.
327
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
338
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
328
339
  Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
329
340
  weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
330
341
  input argument.
@@ -896,6 +907,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
896
907
  f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
897
908
  )
898
909
 
910
+ latents_mean = latents_std = None
911
+ if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
912
+ latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
913
+ if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
914
+ latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
915
+
899
916
  # Offload text encoder if `enable_model_cpu_offload` was enabled
900
917
  if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
901
918
  self.text_encoder_2.to("cpu")
@@ -933,7 +950,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
933
950
  self.vae.to(dtype)
934
951
 
935
952
  init_latents = init_latents.to(dtype)
936
- init_latents = self.vae.config.scaling_factor * init_latents
953
+ if latents_mean is not None and latents_std is not None:
954
+ latents_mean = latents_mean.to(device=self.device, dtype=dtype)
955
+ latents_std = latents_std.to(device=self.device, dtype=dtype)
956
+ init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
957
+ else:
958
+ init_latents = self.vae.config.scaling_factor * init_latents
937
959
 
938
960
  if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
939
961
  # expand init_latents for batch_size
@@ -1069,13 +1091,13 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
1069
1091
  num_images_per_prompt: Optional[int] = 1,
1070
1092
  eta: float = 0.0,
1071
1093
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1072
- latents: Optional[torch.FloatTensor] = None,
1073
- prompt_embeds: Optional[torch.FloatTensor] = None,
1074
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1075
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1076
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1094
+ latents: Optional[torch.Tensor] = None,
1095
+ prompt_embeds: Optional[torch.Tensor] = None,
1096
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
1097
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
1098
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
1077
1099
  ip_adapter_image: Optional[PipelineImageInput] = None,
1078
- ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
1100
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
1079
1101
  output_type: Optional[str] = "pil",
1080
1102
  return_dict: bool = True,
1081
1103
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -1092,7 +1114,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
1092
1114
  aesthetic_score: float = 6.0,
1093
1115
  negative_aesthetic_score: float = 2.5,
1094
1116
  clip_skip: Optional[int] = None,
1095
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1117
+ callback_on_step_end: Optional[
1118
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
1119
+ ] = None,
1096
1120
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1097
1121
  **kwargs,
1098
1122
  ):
@@ -1106,18 +1130,18 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
1106
1130
  prompt_2 (`str` or `List[str]`, *optional*):
1107
1131
  The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
1108
1132
  used in both text-encoders
1109
- image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
1110
- `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
1133
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
1134
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
1111
1135
  The initial image will be used as the starting point for the image generation process. Can also accept
1112
1136
  image latents as `image`, if passing latents directly, it will not be encoded again.
1113
- control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
1114
- `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
1137
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
1138
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
1115
1139
  The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
1116
- the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
1117
- also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
1118
- height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
1119
- specified in init, images must be passed as a list such that each element of the list can be correctly
1120
- batched for input to a single controlnet.
1140
+ the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
1141
+ be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
1142
+ and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in
1143
+ init, images must be passed as a list such that each element of the list can be correctly batched for
1144
+ input to a single controlnet.
1121
1145
  height (`int`, *optional*, defaults to the size of control_image):
1122
1146
  The height in pixels of the generated image. Anything below 512 pixels won't work well for
1123
1147
  [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
@@ -1156,30 +1180,30 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
1156
1180
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1157
1181
  One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1158
1182
  to make generation deterministic.
1159
- latents (`torch.FloatTensor`, *optional*):
1183
+ latents (`torch.Tensor`, *optional*):
1160
1184
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1161
1185
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1162
1186
  tensor will be generated by sampling using the supplied random `generator`.
1163
- prompt_embeds (`torch.FloatTensor`, *optional*):
1187
+ prompt_embeds (`torch.Tensor`, *optional*):
1164
1188
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1165
1189
  provided, text embeddings will be generated from `prompt` input argument.
1166
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1190
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
1167
1191
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1168
1192
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1169
1193
  argument.
1170
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1194
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
1171
1195
  Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1172
1196
  If not provided, pooled text embeddings will be generated from `prompt` input argument.
1173
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1197
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
1174
1198
  Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1175
1199
  weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1176
1200
  input argument.
1177
1201
  ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1178
- ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
1179
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
1180
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
1181
- if `do_classifier_free_guidance` is set to `True`.
1182
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
1202
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1203
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
1204
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
1205
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1206
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1183
1207
  output_type (`str`, *optional*, defaults to `"pil"`):
1184
1208
  The output format of the generated image. Choose between
1185
1209
  [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1241,15 +1265,15 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
1241
1265
  clip_skip (`int`, *optional*):
1242
1266
  Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1243
1267
  the output of the pre-final layer will be used for computing the prompt embeddings.
1244
- callback_on_step_end (`Callable`, *optional*):
1245
- A function that calls at the end of each denoising steps during the inference. The function is called
1246
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1247
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1248
- `callback_on_step_end_tensor_inputs`.
1268
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1269
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1270
+ each denoising step during the inference, with the following arguments: `callback_on_step_end(self:
1271
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1272
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1249
1273
  callback_on_step_end_tensor_inputs (`List`, *optional*):
1250
1274
  The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1251
1275
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1252
- `._callback_tensor_inputs` attribute of your pipeine class.
1276
+ `._callback_tensor_inputs` attribute of your pipeline class.
1253
1277
 
1254
1278
  Examples:
1255
1279
 
@@ -1275,6 +1299,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
1275
1299
  "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1276
1300
  )
1277
1301
 
1302
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1303
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1304
+
1278
1305
  controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1279
1306
 
1280
1307
  # align format for control guidance
@@ -1416,16 +1443,17 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
1416
1443
  self._num_timesteps = len(timesteps)
1417
1444
 
1418
1445
  # 6. Prepare latent variables
1419
- latents = self.prepare_latents(
1420
- image,
1421
- latent_timestep,
1422
- batch_size,
1423
- num_images_per_prompt,
1424
- prompt_embeds.dtype,
1425
- device,
1426
- generator,
1427
- True,
1428
- )
1446
+ if latents is None:
1447
+ latents = self.prepare_latents(
1448
+ image,
1449
+ latent_timestep,
1450
+ batch_size,
1451
+ num_images_per_prompt,
1452
+ prompt_embeds.dtype,
1453
+ device,
1454
+ generator,
1455
+ True,
1456
+ )
1429
1457
 
1430
1458
  # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1431
1459
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -1564,6 +1592,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
1564
1592
  latents = callback_outputs.pop("latents", latents)
1565
1593
  prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1566
1594
  negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1595
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1596
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1597
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1598
+ )
1599
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1600
+ add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
1567
1601
 
1568
1602
  # call the callback, if provided
1569
1603
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):