diffusers-0.30.3-py3-none-any.whl → diffusers-0.32.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/cogview3/pipeline_output.py
@@ -0,0 +1,21 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+
7
+ from ...utils import BaseOutput
8
+
9
+
10
+ @dataclass
11
+ class CogView3PipelineOutput(BaseOutput):
12
+ """
13
+ Output class for CogView3 pipelines.
14
+
15
+ Args:
16
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
17
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
18
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
19
+ """
20
+
21
+ images: Union[List[PIL.Image.Image], np.ndarray]
diffusers/pipelines/controlnet/__init__.py
@@ -1,80 +1,86 @@
1
- from typing import TYPE_CHECKING
2
-
3
- from ...utils import (
4
- DIFFUSERS_SLOW_IMPORT,
5
- OptionalDependencyNotAvailable,
6
- _LazyModule,
7
- get_objects_from_module,
8
- is_flax_available,
9
- is_torch_available,
10
- is_transformers_available,
11
- )
12
-
13
-
14
- _dummy_objects = {}
15
- _import_structure = {}
16
-
17
- try:
18
- if not (is_transformers_available() and is_torch_available()):
19
- raise OptionalDependencyNotAvailable()
20
- except OptionalDependencyNotAvailable:
21
- from ...utils import dummy_torch_and_transformers_objects # noqa F403
22
-
23
- _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
24
- else:
25
- _import_structure["multicontrolnet"] = ["MultiControlNetModel"]
26
- _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
27
- _import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
28
- _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
29
- _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
30
- _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
31
- _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
32
- _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
33
- try:
34
- if not (is_transformers_available() and is_flax_available()):
35
- raise OptionalDependencyNotAvailable()
36
- except OptionalDependencyNotAvailable:
37
- from ...utils import dummy_flax_and_transformers_objects # noqa F403
38
-
39
- _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
40
- else:
41
- _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
42
-
43
-
44
- if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
45
- try:
46
- if not (is_transformers_available() and is_torch_available()):
47
- raise OptionalDependencyNotAvailable()
48
-
49
- except OptionalDependencyNotAvailable:
50
- from ...utils.dummy_torch_and_transformers_objects import *
51
- else:
52
- from .multicontrolnet import MultiControlNetModel
53
- from .pipeline_controlnet import StableDiffusionControlNetPipeline
54
- from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
55
- from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
56
- from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
57
- from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
58
- from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
59
- from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
60
-
61
- try:
62
- if not (is_transformers_available() and is_flax_available()):
63
- raise OptionalDependencyNotAvailable()
64
- except OptionalDependencyNotAvailable:
65
- from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
66
- else:
67
- from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
68
-
69
-
70
- else:
71
- import sys
72
-
73
- sys.modules[__name__] = _LazyModule(
74
- __name__,
75
- globals()["__file__"],
76
- _import_structure,
77
- module_spec=__spec__,
78
- )
79
- for name, value in _dummy_objects.items():
80
- setattr(sys.modules[__name__], name, value)
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ...utils import (
4
+ DIFFUSERS_SLOW_IMPORT,
5
+ OptionalDependencyNotAvailable,
6
+ _LazyModule,
7
+ get_objects_from_module,
8
+ is_flax_available,
9
+ is_torch_available,
10
+ is_transformers_available,
11
+ )
12
+
13
+
14
+ _dummy_objects = {}
15
+ _import_structure = {}
16
+
17
+ try:
18
+ if not (is_transformers_available() and is_torch_available()):
19
+ raise OptionalDependencyNotAvailable()
20
+ except OptionalDependencyNotAvailable:
21
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
22
+
23
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
24
+ else:
25
+ _import_structure["multicontrolnet"] = ["MultiControlNetModel"]
26
+ _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
27
+ _import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
28
+ _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
29
+ _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
30
+ _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
31
+ _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
32
+ _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
33
+ _import_structure["pipeline_controlnet_union_inpaint_sd_xl"] = ["StableDiffusionXLControlNetUnionInpaintPipeline"]
34
+ _import_structure["pipeline_controlnet_union_sd_xl"] = ["StableDiffusionXLControlNetUnionPipeline"]
35
+ _import_structure["pipeline_controlnet_union_sd_xl_img2img"] = ["StableDiffusionXLControlNetUnionImg2ImgPipeline"]
36
+ try:
37
+ if not (is_transformers_available() and is_flax_available()):
38
+ raise OptionalDependencyNotAvailable()
39
+ except OptionalDependencyNotAvailable:
40
+ from ...utils import dummy_flax_and_transformers_objects # noqa F403
41
+
42
+ _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
43
+ else:
44
+ _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
45
+
46
+
47
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
48
+ try:
49
+ if not (is_transformers_available() and is_torch_available()):
50
+ raise OptionalDependencyNotAvailable()
51
+
52
+ except OptionalDependencyNotAvailable:
53
+ from ...utils.dummy_torch_and_transformers_objects import *
54
+ else:
55
+ from .multicontrolnet import MultiControlNetModel
56
+ from .pipeline_controlnet import StableDiffusionControlNetPipeline
57
+ from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
58
+ from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
59
+ from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
60
+ from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
61
+ from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
62
+ from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
63
+ from .pipeline_controlnet_union_inpaint_sd_xl import StableDiffusionXLControlNetUnionInpaintPipeline
64
+ from .pipeline_controlnet_union_sd_xl import StableDiffusionXLControlNetUnionPipeline
65
+ from .pipeline_controlnet_union_sd_xl_img2img import StableDiffusionXLControlNetUnionImg2ImgPipeline
66
+
67
+ try:
68
+ if not (is_transformers_available() and is_flax_available()):
69
+ raise OptionalDependencyNotAvailable()
70
+ except OptionalDependencyNotAvailable:
71
+ from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
72
+ else:
73
+ from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
74
+
75
+
76
+ else:
77
+ import sys
78
+
79
+ sys.modules[__name__] = _LazyModule(
80
+ __name__,
81
+ globals()["__file__"],
82
+ _import_structure,
83
+ module_spec=__spec__,
84
+ )
85
+ for name, value in _dummy_objects.items():
86
+ setattr(sys.modules[__name__], name, value)
diffusers/pipelines/controlnet/multicontrolnet.py
@@ -1,183 +1,12 @@
1
- import os
2
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
-
4
- import torch
5
- from torch import nn
6
-
7
- from ...models.controlnet import ControlNetModel, ControlNetOutput
8
- from ...models.modeling_utils import ModelMixin
9
- from ...utils import logging
1
+ from ...models.controlnets.multicontrolnet import MultiControlNetModel
2
+ from ...utils import deprecate, logging
10
3
 
11
4
 
12
5
  logger = logging.get_logger(__name__)
13
6
 
14
7
 
15
- class MultiControlNetModel(ModelMixin):
16
- r"""
17
- Multiple `ControlNetModel` wrapper class for Multi-ControlNet
18
-
19
- This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
20
- compatible with `ControlNetModel`.
21
-
22
- Args:
23
- controlnets (`List[ControlNetModel]`):
24
- Provides additional conditioning to the unet during the denoising process. You must set multiple
25
- `ControlNetModel` as a list.
26
- """
27
-
28
- def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
29
- super().__init__()
30
- self.nets = nn.ModuleList(controlnets)
31
-
32
- def forward(
33
- self,
34
- sample: torch.Tensor,
35
- timestep: Union[torch.Tensor, float, int],
36
- encoder_hidden_states: torch.Tensor,
37
- controlnet_cond: List[torch.tensor],
38
- conditioning_scale: List[float],
39
- class_labels: Optional[torch.Tensor] = None,
40
- timestep_cond: Optional[torch.Tensor] = None,
41
- attention_mask: Optional[torch.Tensor] = None,
42
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
43
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
44
- guess_mode: bool = False,
45
- return_dict: bool = True,
46
- ) -> Union[ControlNetOutput, Tuple]:
47
- for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
48
- down_samples, mid_sample = controlnet(
49
- sample=sample,
50
- timestep=timestep,
51
- encoder_hidden_states=encoder_hidden_states,
52
- controlnet_cond=image,
53
- conditioning_scale=scale,
54
- class_labels=class_labels,
55
- timestep_cond=timestep_cond,
56
- attention_mask=attention_mask,
57
- added_cond_kwargs=added_cond_kwargs,
58
- cross_attention_kwargs=cross_attention_kwargs,
59
- guess_mode=guess_mode,
60
- return_dict=return_dict,
61
- )
62
-
63
- # merge samples
64
- if i == 0:
65
- down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
66
- else:
67
- down_block_res_samples = [
68
- samples_prev + samples_curr
69
- for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
70
- ]
71
- mid_block_res_sample += mid_sample
72
-
73
- return down_block_res_samples, mid_block_res_sample
74
-
75
- def save_pretrained(
76
- self,
77
- save_directory: Union[str, os.PathLike],
78
- is_main_process: bool = True,
79
- save_function: Callable = None,
80
- safe_serialization: bool = True,
81
- variant: Optional[str] = None,
82
- ):
83
- """
84
- Save a model and its configuration file to a directory, so that it can be re-loaded using the
85
- `[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method.
86
-
87
- Arguments:
88
- save_directory (`str` or `os.PathLike`):
89
- Directory to which to save. Will be created if it doesn't exist.
90
- is_main_process (`bool`, *optional*, defaults to `True`):
91
- Whether the process calling this is the main process or not. Useful when in distributed training like
92
- TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
93
- the main process to avoid race conditions.
94
- save_function (`Callable`):
95
- The function to use to save the state dictionary. Useful on distributed training like TPUs when one
96
- need to replace `torch.save` by another method. Can be configured with the environment variable
97
- `DIFFUSERS_SAVE_MODE`.
98
- safe_serialization (`bool`, *optional*, defaults to `True`):
99
- Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
100
- variant (`str`, *optional*):
101
- If specified, weights are saved in the format pytorch_model.<variant>.bin.
102
- """
103
- for idx, controlnet in enumerate(self.nets):
104
- suffix = "" if idx == 0 else f"_{idx}"
105
- controlnet.save_pretrained(
106
- save_directory + suffix,
107
- is_main_process=is_main_process,
108
- save_function=save_function,
109
- safe_serialization=safe_serialization,
110
- variant=variant,
111
- )
112
-
113
- @classmethod
114
- def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
115
- r"""
116
- Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.
117
-
118
- The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
119
- the model, you should first set it back in training mode with `model.train()`.
120
-
121
- The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
122
- pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
123
- task.
124
-
125
- The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
126
- weights are discarded.
127
-
128
- Parameters:
129
- pretrained_model_path (`os.PathLike`):
130
- A path to a *directory* containing model weights saved using
131
- [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g.,
132
- `./my_model_directory/controlnet`.
133
- torch_dtype (`str` or `torch.dtype`, *optional*):
134
- Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
135
- will be automatically derived from the model's weights.
136
- output_loading_info(`bool`, *optional*, defaults to `False`):
137
- Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
138
- device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
139
- A map that specifies where each submodule should go. It doesn't need to be refined to each
140
- parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
141
- same device.
142
-
143
- To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
144
- more information about each option see [designing a device
145
- map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
146
- max_memory (`Dict`, *optional*):
147
- A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
148
- GPU and the available CPU RAM if unset.
149
- low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
150
- Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
151
- also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
152
- model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
153
- setting this argument to `True` will raise an error.
154
- variant (`str`, *optional*):
155
- If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
156
- ignored when using `from_flax`.
157
- use_safetensors (`bool`, *optional*, defaults to `None`):
158
- If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
159
- `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
160
- `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
161
- """
162
- idx = 0
163
- controlnets = []
164
-
165
- # load controlnet and append to list until no controlnet directory exists anymore
166
- # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained`
167
- # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
168
- model_path_to_load = pretrained_model_path
169
- while os.path.isdir(model_path_to_load):
170
- controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
171
- controlnets.append(controlnet)
172
-
173
- idx += 1
174
- model_path_to_load = pretrained_model_path + f"_{idx}"
175
-
176
- logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
177
-
178
- if len(controlnets) == 0:
179
- raise ValueError(
180
- f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
181
- )
182
-
183
- return cls(controlnets)
8
+ class MultiControlNetModel(MultiControlNetModel):
9
+ def __init__(self, *args, **kwargs):
10
+ deprecation_message = "Importing `MultiControlNetModel` from `diffusers.pipelines.controlnet.multicontrolnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel`, instead."
11
+ deprecate("diffusers.pipelines.controlnet.multicontrolnet.MultiControlNetModel", "0.34", deprecation_message)
12
+ super().__init__(*args, **kwargs)
@@ -25,12 +25,13 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
25
25
  from ...callbacks import MultiPipelineCallbacks, PipelineCallback
26
26
  from ...image_processor import PipelineImageInput, VaeImageProcessor
27
27
  from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
28
- from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
28
+ from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
29
29
  from ...models.lora import adjust_lora_scale_text_encoder
30
30
  from ...schedulers import KarrasDiffusionSchedulers
31
31
  from ...utils import (
32
32
  USE_PEFT_BACKEND,
33
33
  deprecate,
34
+ is_torch_xla_available,
34
35
  logging,
35
36
  replace_example_docstring,
36
37
  scale_lora_layers,
@@ -40,9 +41,15 @@ from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_ten
40
41
  from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
41
42
  from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
42
43
  from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
43
- from .multicontrolnet import MultiControlNetModel
44
44
 
45
45
 
46
+ if is_torch_xla_available():
47
+ import torch_xla.core.xla_model as xm
48
+
49
+ XLA_AVAILABLE = True
50
+ else:
51
+ XLA_AVAILABLE = False
52
+
46
53
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
54
 
48
55
 
@@ -101,7 +108,7 @@ def retrieve_timesteps(
101
108
  sigmas: Optional[List[float]] = None,
102
109
  **kwargs,
103
110
  ):
104
- """
111
+ r"""
105
112
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
106
113
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
107
114
 
@@ -893,6 +900,10 @@ class StableDiffusionControlNetPipeline(
893
900
  def num_timesteps(self):
894
901
  return self._num_timesteps
895
902
 
903
+ @property
904
+ def interrupt(self):
905
+ return self._interrupt
906
+
896
907
  @torch.no_grad()
897
908
  @replace_example_docstring(EXAMPLE_DOC_STRING)
898
909
  def __call__(
@@ -1089,6 +1100,7 @@ class StableDiffusionControlNetPipeline(
1089
1100
  self._guidance_scale = guidance_scale
1090
1101
  self._clip_skip = clip_skip
1091
1102
  self._cross_attention_kwargs = cross_attention_kwargs
1103
+ self._interrupt = False
1092
1104
 
1093
1105
  # 2. Define call parameters
1094
1106
  if prompt is not None and isinstance(prompt, str):
@@ -1235,6 +1247,9 @@ class StableDiffusionControlNetPipeline(
1235
1247
  is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
1236
1248
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1237
1249
  for i, t in enumerate(timesteps):
1250
+ if self.interrupt:
1251
+ continue
1252
+
1238
1253
  # Relevant thread:
1239
1254
  # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1240
1255
  if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
@@ -1316,6 +1331,8 @@ class StableDiffusionControlNetPipeline(
1316
1331
  step_idx = i // getattr(self.scheduler, "order", 1)
1317
1332
  callback(step_idx, t, latents)
1318
1333
 
1334
+ if XLA_AVAILABLE:
1335
+ xm.mark_step()
1319
1336
  # If we do sequential model offloading, let's offload unet and controlnet
1320
1337
  # manually for max memory savings
1321
1338
  if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
24
24
  from ...callbacks import MultiPipelineCallbacks, PipelineCallback
25
25
  from ...image_processor import PipelineImageInput, VaeImageProcessor
26
26
  from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
27
- from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
27
+ from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
28
28
  from ...models.lora import adjust_lora_scale_text_encoder
29
29
  from ...schedulers import KarrasDiffusionSchedulers
30
30
  from ...utils import (
@@ -39,7 +39,6 @@ from ...utils.torch_utils import is_compiled_module, randn_tensor
39
39
  from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
40
40
  from ..stable_diffusion import StableDiffusionPipelineOutput
41
41
  from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
42
- from .multicontrolnet import MultiControlNetModel
43
42
 
44
43
 
45
44
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -891,6 +890,10 @@ class StableDiffusionControlNetImg2ImgPipeline(
891
890
  def num_timesteps(self):
892
891
  return self._num_timesteps
893
892
 
893
+ @property
894
+ def interrupt(self):
895
+ return self._interrupt
896
+
894
897
  @torch.no_grad()
895
898
  @replace_example_docstring(EXAMPLE_DOC_STRING)
896
899
  def __call__(
@@ -1081,6 +1084,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
1081
1084
  self._guidance_scale = guidance_scale
1082
1085
  self._clip_skip = clip_skip
1083
1086
  self._cross_attention_kwargs = cross_attention_kwargs
1087
+ self._interrupt = False
1084
1088
 
1085
1089
  # 2. Define call parameters
1086
1090
  if prompt is not None and isinstance(prompt, str):
@@ -1211,6 +1215,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
1211
1215
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1212
1216
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1213
1217
  for i, t in enumerate(timesteps):
1218
+ if self.interrupt:
1219
+ continue
1220
+
1214
1221
  # expand the latents if we are doing classifier free guidance
1215
1222
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1216
1223
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -26,7 +26,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
26
26
  from ...callbacks import MultiPipelineCallbacks, PipelineCallback
27
27
  from ...image_processor import PipelineImageInput, VaeImageProcessor
28
28
  from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
29
- from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
29
+ from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
30
30
  from ...models.lora import adjust_lora_scale_text_encoder
31
31
  from ...schedulers import KarrasDiffusionSchedulers
32
32
  from ...utils import (
@@ -41,7 +41,6 @@ from ...utils.torch_utils import is_compiled_module, randn_tensor
41
41
  from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
42
42
  from ..stable_diffusion import StableDiffusionPipelineOutput
43
43
  from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
44
- from .multicontrolnet import MultiControlNetModel
45
44
 
46
45
 
47
46
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -976,6 +975,10 @@ class StableDiffusionControlNetInpaintPipeline(
976
975
  def num_timesteps(self):
977
976
  return self._num_timesteps
978
977
 
978
+ @property
979
+ def interrupt(self):
980
+ return self._interrupt
981
+
979
982
  @torch.no_grad()
980
983
  @replace_example_docstring(EXAMPLE_DOC_STRING)
981
984
  def __call__(
@@ -1191,6 +1194,7 @@ class StableDiffusionControlNetInpaintPipeline(
1191
1194
  self._guidance_scale = guidance_scale
1192
1195
  self._clip_skip = clip_skip
1193
1196
  self._cross_attention_kwargs = cross_attention_kwargs
1197
+ self._interrupt = False
1194
1198
 
1195
1199
  # 2. Define call parameters
1196
1200
  if prompt is not None and isinstance(prompt, str):
@@ -1375,6 +1379,9 @@ class StableDiffusionControlNetInpaintPipeline(
1375
1379
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1376
1380
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1377
1381
  for i, t in enumerate(timesteps):
1382
+ if self.interrupt:
1383
+ continue
1384
+
1378
1385
  # expand the latents if we are doing classifier free guidance
1379
1386
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1380
1387
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)