diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -22,6 +22,7 @@ import torch
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
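The `diffusers.callbacks` module pulled in here is new in this release (see `diffusers/callbacks.py`, +156 lines, in the file list above). A minimal sketch of the import surface these hunks rely on, runnable against 0.28.0:

# Import-level check only; both names are referenced by the hunks below.
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback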
@@ -97,6 +98,7 @@ def retrieve_timesteps(
     num_inference_steps: Optional[int] = None,
     device: Optional[Union[str, torch.device]] = None,
     timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
     """
@@ -107,19 +109,23 @@ def retrieve_timesteps(
         scheduler (`SchedulerMixin`):
             The scheduler to get timesteps from.
         num_inference_steps (`int`):
-            The number of diffusion steps used when generating samples with a pre-trained model. If used,
-            `timesteps` must be `None`.
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
         device (`str` or `torch.device`, *optional*):
             The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         timesteps (`List[int]`, *optional*):
-            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
-            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
-            must be `None`.
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
 
     Returns:
         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
         second element is the number of inference steps.
     """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
     if timesteps is not None:
         accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
         if not accepts_timesteps:
@@ -130,6 +136,16 @@ def retrieve_timesteps(
         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
         timesteps = scheduler.timesteps
         num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
     else:
         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
         timesteps = scheduler.timesteps
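The new branch makes `sigmas` a third, mutually exclusive way to drive scheduling, alongside `num_inference_steps` and `timesteps`. A minimal sketch at the scheduler level, assuming a scheduler whose `set_timesteps` accepts `sigmas` in this release (e.g. `EulerDiscreteScheduler`, whose own diff appears in the file list); the sigma values are illustrative only:

from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler()  # default config, no download required

# Either let the scheduler derive its own schedule...
scheduler.set_timesteps(num_inference_steps=10)

# ...or supply a custom descending sigma schedule instead -- never both,
# which is exactly what the new guard in retrieve_timesteps enforces:
scheduler.set_timesteps(sigmas=[14.6, 6.3, 3.8, 2.2, 1.3, 0.9, 0.6, 0.4, 0.2, 0.1, 0.0])
print(scheduler.timesteps)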
@@ -246,8 +262,8 @@ class StableDiffusionControlNetPipeline(
         num_images_per_prompt,
         do_classifier_free_guidance,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         **kwargs,
     ):
@@ -279,8 +295,8 @@ class StableDiffusionControlNetPipeline(
         num_images_per_prompt,
         do_classifier_free_guidance,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         clip_skip: Optional[int] = None,
     ):
@@ -300,10 +316,10 @@ class StableDiffusionControlNetPipeline(
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
@@ -661,9 +677,9 @@ class StableDiffusionControlNetPipeline(
                 raise ValueError(
                     f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
                 )
-
-            for image_ in image:
-                self.check_image(image_, prompt, prompt_embeds)
+            else:
+                for image_ in image:
+                    self.check_image(image_, prompt, prompt_embeds)
         else:
             assert False
 
@@ -807,7 +823,12 @@ class StableDiffusionControlNetPipeline(
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        shape = (
+            batch_size,
+            num_channels_latents,
+            int(height) // self.vae_scale_factor,
+            int(width) // self.vae_scale_factor,
+        )
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
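The motivation for the `int()` casts: floor division of a Python float yields a float, and a float is not a valid tensor dimension for the tensor creation later in this method, so a `height` or `width` arriving as e.g. `512.0` used to fail. In plain Python:

vae_scale_factor = 8
height = 512.0                           # e.g. produced by upstream float math
print(height // vae_scale_factor)        # 64.0 -- float, invalid as a tensor dim
print(int(height) // vae_scale_factor)   # 64   -- what tensor creation expects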
@@ -824,20 +845,22 @@ class StableDiffusionControlNetPipeline(
         return latents
 
     # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
-    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+    def get_guidance_scale_embedding(
+        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
+    ) -> torch.Tensor:
         """
         See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
 
         Args:
-            timesteps (`torch.Tensor`):
-                generate embedding vectors at these timesteps
+            w (`torch.Tensor`):
+                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
             embedding_dim (`int`, *optional*, defaults to 512):
-                dimension of the embeddings to generate
-            dtype:
-                data type of the generated embeddings
+                Dimension of the embeddings to generate.
+            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+                Data type of the generated embeddings.
 
         Returns:
-            `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
         """
         assert len(w.shape) == 1
         w = w * 1000.0
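For reference, a standalone sketch of what this method computes (mirroring the body, whose first two lines are visible above): a sinusoidal embedding of the guidance weight `w`, scaled by 1000, following the vdm reference linked in the docstring.

import torch

def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 512,
                             dtype: torch.dtype = torch.float32) -> torch.Tensor:
    assert len(w.shape) == 1
    w = w * 1000.0
    half_dim = embedding_dim // 2
    # log-spaced frequencies, as in standard sinusoidal timestep embeddings
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero-pad odd widths
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb  # shape (len(w), embedding_dim)

print(guidance_scale_embedding(torch.tensor([7.5]), embedding_dim=8))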
@@ -885,16 +908,17 @@ class StableDiffusionControlNetPipeline(
         width: Optional[int] = None,
         num_inference_steps: int = 50,
         timesteps: List[int] = None,
+        sigmas: List[float] = None,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
-        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -903,7 +927,9 @@ class StableDiffusionControlNetPipeline(
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
         clip_skip: Optional[int] = None,
-        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         **kwargs,
     ):
@@ -913,16 +939,16 @@ class StableDiffusionControlNetPipeline(
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
-            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
-                `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                 The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
-                specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
-                accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
-                and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
-                `init`, images must be passed as a list such that each element of the list can be correctly batched for
-                input to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single ControlNet,
-                each will be paired with each prompt in the `prompt` list. This also applies to multiple ControlNets,
-                where a list of image lists can be passed to batch for each prompt and each ControlNet.
+                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+                as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+                images must be passed as a list such that each element of the list can be correctly batched for input
+                to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single
+                ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple
+                ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet.
             height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
@@ -934,6 +960,10 @@ class StableDiffusionControlNetPipeline(
                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                 passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             guidance_scale (`float`, *optional*, defaults to 7.5):
                 A higher guidance scale value encourages the model to generate images closely linked to the text
                 `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -948,22 +978,22 @@ class StableDiffusionControlNetPipeline(
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor is generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                 provided, text embeddings are generated from the `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
-            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
-                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
-                Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
-                if `do_classifier_free_guidance` is set to `True`.
-                If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -971,7 +1001,7 @@ class StableDiffusionControlNetPipeline(
                 plain tuple.
             callback (`Callable`, *optional*):
                 A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function is called. If not specified, the callback is called at
                 every step.
@@ -992,15 +1022,15 @@ class StableDiffusionControlNetPipeline(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
-            callback_on_step_end (`Callable`, *optional*):
-                A function that calls at the end of each denoising steps during the inference. The function is called
-                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
-                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
-                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+                each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
             callback_on_step_end_tensor_inputs (`List`, *optional*):
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
-                `._callback_tensor_inputs` attribute of your pipeine class.
+                `._callback_tensor_inputs` attribute of your pipeline class.
 
         Examples:
 
@@ -1028,6 +1058,9 @@ class StableDiffusionControlNetPipeline(
             "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
         )
 
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
         # align format for control guidance
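With this block, a `PipelineCallback` or `MultiPipelineCallbacks` instance brings its own `tensor_inputs` list, while plain callables keep working unchanged. A sketch of the plain-callable form, adapted from the callback example in the diffusers docs (`pipe` and `control_image` are assumed to exist; ControlNet-specific batching may need extra care):

def disable_cfg_after_40pct(pipe, step_index, timestep, callback_kwargs):
    # Past 40% of the steps, drop the negative half of the CFG batch and
    # zero the guidance scale for the remaining steps.
    if step_index == int(pipe.num_timesteps * 0.4):
        prompt_embeds = callback_kwargs["prompt_embeds"]
        callback_kwargs["prompt_embeds"] = prompt_embeds.chunk(2)[-1]
        pipe._guidance_scale = 0.0
    return callback_kwargs

image = pipe(
    "a photo of an astronaut",
    image=control_image,
    callback_on_step_end=disable_cfg_after_40pct,
    callback_on_step_end_tensor_inputs=["prompt_embeds"],
).images[0]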
@@ -1155,7 +1188,9 @@ class StableDiffusionControlNetPipeline(
             assert False
 
         # 5. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler, num_inference_steps, device, timesteps, sigmas
+        )
         self._num_timesteps = len(timesteps)
 
         # 6. Prepare latent variables
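With the widened call above, a custom sigma schedule passed to `__call__` now flows straight through to the scheduler. Illustrative usage only (`pipe` and `canny_image` are assumed to exist, the values are illustrative, and the installed scheduler must accept `sigmas`):

image = pipe(
    "a futuristic city at night",
    image=canny_image,
    sigmas=[14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0],
).images[0]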
diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
@@ -240,7 +240,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
         condtioning_image: PIL.Image.Image,
         source_subject_category: List[str],
         target_subject_category: List[str],
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         guidance_scale: float = 7.5,
         height: int = 512,
         width: int = 512,
@@ -266,7 +266,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
                 The source subject category.
             target_subject_category (`List[str]`):
                 The target subject category.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will ge generated by random sampling.
diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -21,6 +21,7 @@ import torch
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
@@ -239,8 +240,8 @@ class StableDiffusionControlNetImg2ImgPipeline(
         num_images_per_prompt,
         do_classifier_free_guidance,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         **kwargs,
     ):
@@ -272,8 +273,8 @@ class StableDiffusionControlNetImg2ImgPipeline(
         num_images_per_prompt,
         do_classifier_free_guidance,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         clip_skip: Optional[int] = None,
     ):
@@ -293,10 +294,10 @@ class StableDiffusionControlNetImg2ImgPipeline(
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
@@ -904,11 +905,11 @@ class StableDiffusionControlNetImg2ImgPipeline(
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
-        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -917,7 +918,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
         clip_skip: Optional[int] = None,
-        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         **kwargs,
     ):
@@ -927,18 +930,18 @@ class StableDiffusionControlNetImg2ImgPipeline(
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
-            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
-                `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                 The initial image to be used as the starting point for the image generation process. Can also accept
                 image latents as `image`, and if passing latents directly they are not encoded again.
-            control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
-                `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                 The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
-                specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
-                accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
-                and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
-                `init`, images must be passed as a list such that each element of the list can be correctly batched for
-                input to a single ControlNet.
+                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+                as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+                images must be passed as a list such that each element of the list can be correctly batched for input
+                to a single ControlNet.
             height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
@@ -966,22 +969,22 @@ class StableDiffusionControlNetImg2ImgPipeline(
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor is generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                 provided, text embeddings are generated from the `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
-            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
-                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
-                Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
-                if `do_classifier_free_guidance` is set to `True`.
-                If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -1004,15 +1007,15 @@ class StableDiffusionControlNetImg2ImgPipeline(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
-            callback_on_step_end (`Callable`, *optional*):
-                A function that calls at the end of each denoising steps during the inference. The function is called
-                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
-                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
-                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+                each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
             callback_on_step_end_tensor_inputs (`List`, *optional*):
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
-                `._callback_tensor_inputs` attribute of your pipeine class.
+                `._callback_tensor_inputs` attribute of your pipeline class.
 
         Examples:
 
@@ -1040,6 +1043,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
             "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
         )
 
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
         # align format for control guidance
@@ -1169,15 +1175,16 @@ class StableDiffusionControlNetImg2ImgPipeline(
         self._num_timesteps = len(timesteps)
 
         # 6. Prepare latent variables
-        latents = self.prepare_latents(
-            image,
-            latent_timestep,
-            batch_size,
-            num_images_per_prompt,
-            prompt_embeds.dtype,
-            device,
-            generator,
-        )
+        if latents is None:
+            latents = self.prepare_latents(
+                image,
+                latent_timestep,
+                batch_size,
+                num_images_per_prompt,
+                prompt_embeds.dtype,
+                device,
+                generator,
+            )
 
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
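The new `if latents is None` guard means a caller-supplied `latents` tensor is respected instead of being unconditionally rebuilt from `image`. A sketch of what that enables (all names assumed; note that for img2img these latents normally come from noising the encoded init image, so supplying your own changes the semantics accordingly):

import torch

# Replay the exact same starting latents across calls for reproducibility.
latents = torch.randn(
    (1, pipe.unet.config.in_channels, 64, 64),
    generator=torch.Generator("cpu").manual_seed(0),
)
out = pipe("a castle", image=init_image, control_image=control_image, latents=latents)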