diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,911 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import PIL.Image
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
23
+
24
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
25
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
26
+ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
27
+ from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
28
+ from ...models.lora import adjust_lora_scale_text_encoder
29
+ from ...schedulers import KarrasDiffusionSchedulers
30
+ from ...utils import (
31
+ USE_PEFT_BACKEND,
32
+ deprecate,
33
+ logging,
34
+ replace_example_docstring,
35
+ scale_lora_layers,
36
+ unscale_lora_layers,
37
+ )
38
+ from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
39
+ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
40
+ from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
41
+ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
42
+
43
+
44
# Module-level logger obtained through diffusers' logging utilities (`...utils.logging`).
logger = logging.get_logger(name=__name__)  # pylint: disable=invalid-name
45
+
46
+
47
# Usage example for the pipeline; presumably injected into `__call__`'s docstring through
# `replace_example_docstring` (imported above) — confirm at the decorator's use site.
# The example passes `negative_prompt` to the pipeline so the variable it defines is actually used.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> # !pip install opencv-python transformers accelerate
        >>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSAdapter
        >>> from diffusers.utils import load_image
        >>> import numpy as np
        >>> import torch

        >>> import cv2
        >>> from PIL import Image

        >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
        >>> negative_prompt = "low quality, bad quality, sketches"

        >>> # download an image
        >>> image = load_image(
        ...     "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
        ... )

        >>> # initialize the models and pipeline
        >>> controlnet_conditioning_scale = 0.5

        >>> controlnet = ControlNetXSAdapter.from_pretrained(
        ...     "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
        ... )
        >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
        ...     "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
        ... )
        >>> pipe.enable_model_cpu_offload()

        >>> # get canny image
        >>> image = np.array(image)
        >>> image = cv2.Canny(image, 100, 200)
        >>> image = image[:, :, None]
        >>> image = np.concatenate([image, image, image], axis=2)
        >>> canny_image = Image.fromarray(image)
        >>> # generate image
        >>> image = pipe(
        ...     prompt,
        ...     negative_prompt=negative_prompt,
        ...     controlnet_conditioning_scale=controlnet_conditioning_scale,
        ...     image=canny_image,
        ... ).images[0]
        ```
"""
90
+
91
+
92
+ class StableDiffusionControlNetXSPipeline(
93
+ DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
94
+ ):
95
+ r"""
96
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet-XS guidance.
97
+
98
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
99
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
100
+
101
+ The pipeline also inherits the following loading methods:
102
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
103
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
104
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
105
+ - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
106
+
107
+ Args:
108
+ vae ([`AutoencoderKL`]):
109
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
110
+ text_encoder ([`~transformers.CLIPTextModel`]):
111
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
112
+ tokenizer ([`~transformers.CLIPTokenizer`]):
113
+ A `CLIPTokenizer` to tokenize text.
114
+ unet ([`UNet2DConditionModel`]):
115
+ A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents.
116
+ controlnet ([`ControlNetXSAdapter`]):
117
+ A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents.
118
+ scheduler ([`SchedulerMixin`]):
119
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
120
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
121
+ safety_checker ([`StableDiffusionSafetyChecker`]):
122
+ Classification module that estimates whether generated images could be considered offensive or harmful.
123
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
124
+ about a model's potential harms.
125
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
126
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
127
+ """
128
+
129
+ model_cpu_offload_seq = "text_encoder->unet->vae"
130
+ _optional_components = ["safety_checker", "feature_extractor"]
131
+ _exclude_from_cpu_offload = ["safety_checker"]
132
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
133
+
134
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: Union[UNet2DConditionModel, UNetControlNetXSModel],
    controlnet: ControlNetXSAdapter,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and build the image processors.

    If `unet` is a plain `UNet2DConditionModel`, it is combined with `controlnet` into a
    `UNetControlNetXSModel` via `UNetControlNetXSModel.from_unet`. A `UNetControlNetXSModel` passed in
    directly is registered unchanged.

    Args:
        requires_safety_checker (`bool`, defaults to `True`):
            When `True` and `safety_checker` is `None`, a warning is logged. The value is also recorded in
            the pipeline config.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    # Wrap a plain base UNet into the combined UNet + ControlNet-XS model; an already-combined
    # `UNetControlNetXSModel` is kept as-is.
    if isinstance(unet, UNet2DConditionModel):
        unet = UNetControlNetXSModel.from_unet(unet, controlnet)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # f-string so the offending pipeline class is actually interpolated into the message.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # Presumably the VAE's spatial down-sampling factor (one halving per block after the first) — confirm
    # against the VAE architecture.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
    # Control images are converted to RGB like regular images, but normalization is disabled for them.
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
    )
    # Record the flag in the pipeline config.
    self.register_to_config(requires_safety_checker=requires_safety_checker)
183
+
184
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
185
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    **kwargs,
):
    """Legacy entry point kept for backwards compatibility; prefer `encode_prompt`.

    Issues a deprecation notice (tagged with version "1.0.0"), forwards every argument unchanged to
    `encode_prompt`, and merges the returned pair into a single tensor.

    Returns:
        `torch.Tensor`: the second element of `encode_prompt`'s result concatenated along dim 0 in front
        of the first element.
    """
    notice = (
        "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()`"
        " instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
    )
    deprecate("_encode_prompt()", "1.0.0", notice, standard_warn=False)

    encoded = self.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=lora_scale,
        **kwargs,
    )

    # Legacy callers expect one tensor with the second entry stacked first (presumably the negative /
    # unconditional embeddings — confirm against `encode_prompt`'s return statement).
    leading, trailing = encoded[1], encoded[0]
    return torch.cat([leading, trailing])
216
+
217
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
218
+ def encode_prompt(
219
+ self,
220
+ prompt,
221
+ device,
222
+ num_images_per_prompt,
223
+ do_classifier_free_guidance,
224
+ negative_prompt=None,
225
+ prompt_embeds: Optional[torch.Tensor] = None,
226
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
227
+ lora_scale: Optional[float] = None,
228
+ clip_skip: Optional[int] = None,
229
+ ):
230
+ r"""
231
+ Encodes the prompt into text encoder hidden states.
232
+
233
+ Args:
234
+ prompt (`str` or `List[str]`, *optional*):
235
+ prompt to be encoded
236
+ device: (`torch.device`):
237
+ torch device
238
+ num_images_per_prompt (`int`):
239
+ number of images that should be generated per prompt
240
+ do_classifier_free_guidance (`bool`):
241
+ whether to use classifier free guidance or not
242
+ negative_prompt (`str` or `List[str]`, *optional*):
243
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
244
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
245
+ less than `1`).
246
+ prompt_embeds (`torch.Tensor`, *optional*):
247
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
248
+ provided, text embeddings will be generated from `prompt` input argument.
249
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
250
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
251
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
252
+ argument.
253
+ lora_scale (`float`, *optional*):
254
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
255
+ clip_skip (`int`, *optional*):
256
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
257
+ the output of the pre-final layer will be used for computing the prompt embeddings.
258
+ """
259
+ # set lora scale so that monkey patched LoRA
260
+ # function of text encoder can correctly access it
261
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
262
+ self._lora_scale = lora_scale
263
+
264
+ # dynamically adjust the LoRA scale
265
+ if not USE_PEFT_BACKEND:
266
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
267
+ else:
268
+ scale_lora_layers(self.text_encoder, lora_scale)
269
+
270
+ if prompt is not None and isinstance(prompt, str):
271
+ batch_size = 1
272
+ elif prompt is not None and isinstance(prompt, list):
273
+ batch_size = len(prompt)
274
+ else:
275
+ batch_size = prompt_embeds.shape[0]
276
+
277
+ if prompt_embeds is None:
278
+ # textual inversion: process multi-vector tokens if necessary
279
+ if isinstance(self, TextualInversionLoaderMixin):
280
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
281
+
282
+ text_inputs = self.tokenizer(
283
+ prompt,
284
+ padding="max_length",
285
+ max_length=self.tokenizer.model_max_length,
286
+ truncation=True,
287
+ return_tensors="pt",
288
+ )
289
+ text_input_ids = text_inputs.input_ids
290
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
291
+
292
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
293
+ text_input_ids, untruncated_ids
294
+ ):
295
+ removed_text = self.tokenizer.batch_decode(
296
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
297
+ )
298
+ logger.warning(
299
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
300
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
301
+ )
302
+
303
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
304
+ attention_mask = text_inputs.attention_mask.to(device)
305
+ else:
306
+ attention_mask = None
307
+
308
+ if clip_skip is None:
309
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
310
+ prompt_embeds = prompt_embeds[0]
311
+ else:
312
+ prompt_embeds = self.text_encoder(
313
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
314
+ )
315
+ # Access the `hidden_states` first, that contains a tuple of
316
+ # all the hidden states from the encoder layers. Then index into
317
+ # the tuple to access the hidden states from the desired layer.
318
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
319
+ # We also need to apply the final LayerNorm here to not mess with the
320
+ # representations. The `last_hidden_states` that we typically use for
321
+ # obtaining the final prompt representations passes through the LayerNorm
322
+ # layer.
323
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
324
+
325
+ if self.text_encoder is not None:
326
+ prompt_embeds_dtype = self.text_encoder.dtype
327
+ elif self.unet is not None:
328
+ prompt_embeds_dtype = self.unet.dtype
329
+ else:
330
+ prompt_embeds_dtype = prompt_embeds.dtype
331
+
332
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
333
+
334
+ bs_embed, seq_len, _ = prompt_embeds.shape
335
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
336
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
337
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
338
+
339
+ # get unconditional embeddings for classifier free guidance
340
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
341
+ uncond_tokens: List[str]
342
+ if negative_prompt is None:
343
+ uncond_tokens = [""] * batch_size
344
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
345
+ raise TypeError(
346
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
347
+ f" {type(prompt)}."
348
+ )
349
+ elif isinstance(negative_prompt, str):
350
+ uncond_tokens = [negative_prompt]
351
+ elif batch_size != len(negative_prompt):
352
+ raise ValueError(
353
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
354
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
355
+ " the batch size of `prompt`."
356
+ )
357
+ else:
358
+ uncond_tokens = negative_prompt
359
+
360
+ # textual inversion: process multi-vector tokens if necessary
361
+ if isinstance(self, TextualInversionLoaderMixin):
362
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
363
+
364
+ max_length = prompt_embeds.shape[1]
365
+ uncond_input = self.tokenizer(
366
+ uncond_tokens,
367
+ padding="max_length",
368
+ max_length=max_length,
369
+ truncation=True,
370
+ return_tensors="pt",
371
+ )
372
+
373
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
374
+ attention_mask = uncond_input.attention_mask.to(device)
375
+ else:
376
+ attention_mask = None
377
+
378
+ negative_prompt_embeds = self.text_encoder(
379
+ uncond_input.input_ids.to(device),
380
+ attention_mask=attention_mask,
381
+ )
382
+ negative_prompt_embeds = negative_prompt_embeds[0]
383
+
384
+ if do_classifier_free_guidance:
385
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
386
+ seq_len = negative_prompt_embeds.shape[1]
387
+
388
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
389
+
390
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
391
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
392
+
393
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
394
+ # Retrieve the original scale by scaling back the LoRA layers
395
+ unscale_lora_layers(self.text_encoder, lora_scale)
396
+
397
+ return prompt_embeds, negative_prompt_embeds
398
+
399
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    """Run the optional safety checker over decoded images.

    Args:
        image: Decoded images, either a torch tensor or an array accepted by `image_processor.numpy_to_pil`.
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the CLIP pixel values are cast to before being fed to the safety checker.

    Returns:
        A `(image, has_nsfw_concept)` tuple. When no safety checker is configured, `image` is returned untouched
        and `has_nsfw_concept` is `None`.
    """
    if self.safety_checker is None:
        return image, None

    # The feature extractor works on PIL images, so convert from whichever representation we were given.
    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    clip_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    checked_image, nsfw_flags = self.safety_checker(
        images=image, clip_input=clip_features.pixel_values.to(dtype)
    )
    return checked_image, nsfw_flags
413
+
414
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
    """Decode latents into a float32 NHWC numpy image array in [0, 1] (deprecated).

    Args:
        latents: Latent tensor to decode with `self.vae`.

    Returns:
        A numpy array obtained by permuting the decoded tensor from (N, C, H, W) to (N, H, W, C).
    """
    deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
    deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

    # Undo the VAE scaling applied when the latents were produced.
    inv_scale = 1 / self.vae.config.scaling_factor
    decoded = self.vae.decode(inv_scale * latents, return_dict=False)[0]

    # Map from [-1, 1] to [0, 1].
    decoded = (decoded / 2 + 0.5).clamp(0, 1)

    # Always cast to float32: it is cheap and also works for bfloat16, which numpy cannot represent.
    channels_last = decoded.cpu().permute(0, 2, 3, 1)
    return channels_last.float().numpy()
425
+
426
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional keyword arguments accepted by `self.scheduler.step`.

    Schedulers do not share one `step` signature, so only the arguments the scheduler actually declares are passed.
    `eta` (the η of the DDIM paper, https://arxiv.org/abs/2010.02502, expected in [0, 1]) is used by DDIM-style
    schedulers and dropped for others.

    Args:
        generator: Random generator(s) forwarded when `step` accepts a `generator` parameter.
        eta: Value forwarded when `step` accepts an `eta` parameter.

    Returns:
        A dict containing a subset of `{"eta": eta, "generator": generator}`, in that key order.
    """
    step_parameters = set(inspect.signature(self.scheduler.step).parameters.keys())
    candidates = {"eta": eta, "generator": generator}
    return {name: value for name, value in candidates.items() if name in step_parameters}
443
+
444
def check_inputs(
    self,
    prompt,
    image,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate the arguments of `__call__` before any computation is done.

    Args:
        prompt: Text prompt(s); mutually exclusive with `prompt_embeds`.
        image: ControlNet conditioning image(s); validated by `check_image`.
        negative_prompt: Negative prompt(s); mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds: Pre-computed prompt embeddings.
        negative_prompt_embeds: Pre-computed negative prompt embeddings; must match `prompt_embeds` in shape.
        controlnet_conditioning_scale: Must be a `float` (this pipeline drives a single ControlNet-XS).
        control_guidance_start: Fraction of the schedule at which control starts; must be >= 0 and < end.
        control_guidance_end: Fraction of the schedule at which control ends; must be <= 1.
        callback_on_step_end_tensor_inputs: Names that must all appear in `self._callback_tensor_inputs`.

    Raises:
        ValueError: On conflicting/missing prompt arguments, mismatched embedding shapes, unknown callback tensor
            names, or an invalid control-guidance window.
        TypeError: If `prompt` has an unsupported type per `check_image`, if `controlnet_conditioning_scale` is not
            a `float`, or if `self.unet` (after unwrapping a `torch.compile`d module) is not a
            `UNetControlNetXSModel`.
    """
    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # Check `image` and `controlnet_conditioning_scale`.
    # `hasattr(F, "scaled_dot_product_attention")` is used as a proxy for torch>=2.0, where `torch._dynamo` exists.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.unet, torch._dynamo.eval_frame.OptimizedModule
    )
    is_controlnet_xs_unet = isinstance(self.unet, UNetControlNetXSModel) or (
        is_compiled and isinstance(self.unet._orig_mod, UNetControlNetXSModel)
    )
    if not is_controlnet_xs_unet:
        # Explicit raise instead of `assert False`: asserts are stripped under `python -O`, which would silently
        # let an unsupported unet through.
        raise TypeError(
            "`unet` must be a `UNetControlNetXSModel` (optionally wrapped by `torch.compile`), but is"
            f" {type(self.unet)}."
        )

    self.check_image(image, prompt, prompt_embeds)
    if not isinstance(controlnet_conditioning_scale, float):
        raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")

    start, end = control_guidance_start, control_guidance_end
    if start >= end:
        raise ValueError(
            f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
        )
    if start < 0.0:
        raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
    if end > 1.0:
        raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
513
+
514
def check_image(self, image, prompt, prompt_embeds):
    """Validate the control `image` type and that its batch size is compatible with the prompt batch.

    Args:
        image: A PIL image, numpy array, torch tensor, or a non-empty list of one of those types. For lists, only
            the type of the first element is inspected.
        prompt: Text prompt (`str`) or list of prompts, used to derive the prompt batch size.
        prompt_embeds: Pre-computed prompt embeddings; `shape[0]` gives the batch size when `prompt` is not a
            `str`/`list`.

    Raises:
        TypeError: If `image` has an unsupported type. This includes an empty list, which previously surfaced as a
            bare `IndexError`.
        ValueError: If the prompt batch size cannot be determined. It also raises if the image batch size is
            neither 1 nor equal to the prompt batch size.
    """
    # Only peek at `image[0]` for non-empty lists so an empty list falls through to the TypeError below.
    first_item = image[0] if isinstance(image, list) and len(image) > 0 else None

    image_is_pil = isinstance(image, PIL.Image.Image)
    image_is_tensor = isinstance(image, torch.Tensor)
    image_is_np = isinstance(image, np.ndarray)
    image_is_pil_list = isinstance(first_item, PIL.Image.Image)
    image_is_tensor_list = isinstance(first_item, torch.Tensor)
    image_is_np_list = isinstance(first_item, np.ndarray)

    if (
        not image_is_pil
        and not image_is_tensor
        and not image_is_np
        and not image_is_pil_list
        and not image_is_tensor_list
        and not image_is_np_list
    ):
        raise TypeError(
            f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
        )

    # For tensors/arrays, `len` is the size of the leading dimension.
    if image_is_pil:
        image_batch_size = 1
    else:
        image_batch_size = len(image)

    if prompt is not None and isinstance(prompt, str):
        prompt_batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    elif prompt_embeds is not None:
        prompt_batch_size = prompt_embeds.shape[0]
    else:
        # Previously this case left `prompt_batch_size` unbound and crashed with UnboundLocalError below.
        raise ValueError(
            "Provide either `prompt` (a `str` or `list`) or `prompt_embeds` so the prompt batch size can be determined."
        )

    if image_batch_size != 1 and image_batch_size != prompt_batch_size:
        raise ValueError(
            f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
        )
550
+
551
def prepare_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
):
    """Preprocess the control image and expand it to the effective batch size.

    Args:
        image: Control image input accepted by `self.control_image_processor.preprocess`.
        width: Target width forwarded to the preprocessor (may be `None`).
        height: Target height forwarded to the preprocessor (may be `None`).
        batch_size: Number of repeats used when the preprocessed image batch has a single entry.
        num_images_per_prompt: Number of repeats used when the image batch already matches the prompt batch.
        device: Device for the returned tensor.
        dtype: Dtype for the returned tensor.
        do_classifier_free_guidance: If `True`, the batch is duplicated so it lines up with the concatenated
            unconditional + conditional prompt embeddings.

    Returns:
        The preprocessed, repeated, and (optionally) duplicated image tensor.
    """
    processed = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

    # A single control image is broadcast over the whole batch; otherwise it is assumed to already match the
    # prompt batch and only needs expanding per prompt.
    repeats = batch_size if processed.shape[0] == 1 else num_images_per_prompt
    processed = processed.repeat_interleave(repeats, dim=0)
    processed = processed.to(device=device, dtype=dtype)

    if not do_classifier_free_guidance:
        return processed
    return torch.cat([processed] * 2)
579
+
580
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """Create (or move) the initial latents and scale them by the scheduler's initial noise sigma.

    Args:
        batch_size: Effective batch size (prompts x images per prompt).
        num_channels_latents: Number of latent channels.
        height: Image height in pixels; divided by `self.vae_scale_factor` for the latent height.
        width: Image width in pixels; divided by `self.vae_scale_factor` for the latent width.
        dtype: Dtype for freshly sampled latents.
        device: Device for the latents.
        generator: A generator or a list of generators, one per batch element.
        latents: Optional pre-generated latents. They are only moved to `device`; their shape and dtype are not
            checked.

    Returns:
        The latents multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    downscale = self.vae_scale_factor
    shape = (batch_size, num_channels_latents, int(height) // downscale, int(width) // downscale)

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        initial = latents.to(device)
    else:
        initial = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # The scheduler expects the starting noise at its own standard deviation.
    return initial * self.scheduler.init_noise_sigma
602
+
603
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
@property
def guidance_scale(self):
    """Classifier-free guidance weight stored by the most recent `__call__`."""
    scale = self._guidance_scale
    return scale


# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
@property
def clip_skip(self):
    """Number of CLIP layers skipped when encoding the prompt, as stored by the most recent `__call__`."""
    skipped = self._clip_skip
    return skipped


# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
@property
def do_classifier_free_guidance(self):
    """Whether CFG applies: guidance scale above 1 and a unet without a guidance-embedding projection.

    The unet config is only consulted when the guidance scale exceeds 1.
    """
    guided = self._guidance_scale > 1
    return guided and self.unet.config.time_cond_proj_dim is None


# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
@property
def cross_attention_kwargs(self):
    """Extra attention-processor kwargs stored by the most recent `__call__`."""
    attn_kwargs = self._cross_attention_kwargs
    return attn_kwargs


# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
@property
def num_timesteps(self):
    """Number of scheduler timesteps used by the most recent denoising loop."""
    count = self._num_timesteps
    return count
627
+
628
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: PipelineImageInput = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    control_guidance_start: float = 0.0,
    control_guidance_end: float = 1.0,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
            The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
            specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
            as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
            width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
            images must be passed as a list such that each element of the list can be correctly batched for input
            to a single ControlNet.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
            The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
            to the residual in the original `unet`. This pipeline drives a single ControlNet-XS, so only a `float`
            is accepted (lists are rejected by `check_inputs`).
        control_guidance_start (`float`, *optional*, defaults to 0.0):
            The fraction of total steps at which the ControlNet starts applying.
        control_guidance_end (`float`, *optional*, defaults to 1.0):
            The fraction of total steps at which the ControlNet stops applying.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during inference, with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.
    """

    # Callback objects declare which tensors they need; that overrides the explicit argument.
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # Unwrap a torch.compile'd unet so attributes such as `dtype` are read from the real module.
    unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        image,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        controlnet_conditioning_scale,
        control_guidance_start,
        control_guidance_end,
        callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=clip_skip,
    )

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 4. Prepare image
    image = self.prepare_image(
        image=image,
        width=width,
        height=height,
        batch_size=batch_size * num_images_per_prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        dtype=unet.dtype,
        do_classifier_free_guidance=do_classifier_free_guidance,
    )
    # The preprocessed control image determines the output resolution when height/width were not given.
    height, width = image.shape[-2:]

    # 5. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 6. Prepare latent variables
    num_channels_latents = self.unet.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    is_controlnet_compiled = is_compiled_module(self.unet)
    is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Relevant thread:
            # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
            if is_controlnet_compiled and is_torch_higher_equal_2_1:
                torch._inductor.cudagraph_mark_step_begin()
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual; control is applied only while this step lies entirely
            # inside the [control_guidance_start, control_guidance_end] window
            apply_control = (
                i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end
            )
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=prompt_embeds,
                controlnet_cond=image,
                conditioning_scale=controlnet_conditioning_scale,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=True,
                apply_control=apply_control,
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                # NOTE: the requested tensors are looked up by local variable name, so the names
                # `latents`, `prompt_embeds` and `negative_prompt_embeds` in this function are part of
                # the callback contract and must not be renamed.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    # If we do sequential model offloading, let's offload unet and controlnet
    # manually for max memory savings
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.unet.to("cpu")
        self.controlnet.to("cpu")
        torch.cuda.empty_cache()

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
            0
        ]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # Images flagged by the safety checker are not denormalized.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)