diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
@@ -18,7 +18,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 import PIL.Image
 import torch
-import torch.nn.functional as F
 from transformers import (
     CLIPImageProcessor,
     CLIPTextModel,
@@ -35,7 +34,13 @@ from ...loaders import (
     StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
 )
-from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
+from ...models import (
+    AutoencoderKL,
+    ControlNetUnionModel,
+    ImageProjection,
+    MultiControlNetUnionModel,
+    UNet2DConditionModel,
+)
 from ...models.attention_processor import (
     AttnProcessor2_0,
     XFormersAttnProcessor,
@@ -230,7 +235,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         tokenizer: CLIPTokenizer,
         tokenizer_2: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        controlnet: ControlNetUnionModel,
+        controlnet: Union[
+            ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
+        ],
         scheduler: KarrasDiffusionSchedulers,
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
@@ -240,8 +247,8 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
     ):
         super().__init__()
 
-        if not isinstance(controlnet, ControlNetUnionModel):
-            raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
+        if isinstance(controlnet, (list, tuple)):
+            controlnet = MultiControlNetUnionModel(controlnet)
 
         self.register_modules(
             vae=vae,
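
The constructor now also accepts a plain list or tuple of ControlNetUnionModel instances and wraps it in a MultiControlNetUnionModel. A minimal construction sketch; the two union-controlnet repo IDs below are placeholders (only the SDXL base ID is a real repo):

import torch
from diffusers import ControlNetUnionModel, StableDiffusionXLControlNetUnionInpaintPipeline

# Placeholder checkpoint IDs, not verified repos.
controlnet_depth = ControlNetUnionModel.from_pretrained("your-org/controlnet-union-sdxl-a", torch_dtype=torch.float16)
controlnet_edge = ControlNetUnionModel.from_pretrained("your-org/controlnet-union-sdxl-b", torch_dtype=torch.float16)

pipe = StableDiffusionXLControlNetUnionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=[controlnet_depth, controlnet_edge],  # wrapped into MultiControlNetUnionModel by __init__
    torch_dtype=torch.float16,
)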
@@ -660,6 +667,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
+        control_mode=None,
         callback_on_step_end_tensor_inputs=None,
         padding_mask_crop=None,
     ):
@@ -747,25 +755,34 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
                 "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
             )
 
+        # `prompt` needs more sophisticated handling when there are multiple
+        # conditionings.
+        if isinstance(self.controlnet, MultiControlNetUnionModel):
+            if isinstance(prompt, list):
+                logger.warning(
+                    f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+                    " prompts. The conditionings will be fixed across the prompts."
+                )
+
         # Check `image`
-        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
-            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
-        )
-        if (
-            isinstance(self.controlnet, ControlNetModel)
-            or is_compiled
-            and isinstance(self.controlnet._orig_mod, ControlNetModel)
-        ):
-            self.check_image(image, prompt, prompt_embeds)
-        elif (
-            isinstance(self.controlnet, ControlNetUnionModel)
-            or is_compiled
-            and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
-        ):
-            self.check_image(image, prompt, prompt_embeds)
+        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
-        else:
-            assert False
+        if isinstance(controlnet, ControlNetUnionModel):
+            for image_ in image:
+                self.check_image(image_, prompt, prompt_embeds)
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            if not isinstance(image, list):
+                raise TypeError("For multiple controlnets: `image` must be type `list`")
+            elif not all(isinstance(i, list) for i in image):
+                raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.")
+            elif len(image) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+                )
+
+            for images_ in image:
+                for image_ in images_:
+                    self.check_image(image_, prompt, prompt_embeds)
 
         if not isinstance(control_guidance_start, (tuple, list)):
             control_guidance_start = [control_guidance_start]
@@ -778,6 +795,12 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
                 f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
             )
 
+        if isinstance(controlnet, MultiControlNetUnionModel):
+            if len(control_guidance_start) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+                )
+
         for start, end in zip(control_guidance_start, control_guidance_end):
             if start >= end:
                 raise ValueError(
@@ -788,6 +811,28 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
             if end > 1.0:
                 raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
 
+        # Check `control_mode`
+        if isinstance(controlnet, ControlNetUnionModel):
+            if max(control_mode) >= controlnet.config.num_control_type:
+                raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
+                if max(_control_mode) >= _controlnet.config.num_control_type:
+                    raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
+
+        # Equal number of `image` and `control_mode` elements
+        if isinstance(controlnet, ControlNetUnionModel):
+            if len(image) != len(control_mode):
+                raise ValueError("Expected len(control_image) == len(control_mode)")
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            if not all(isinstance(i, list) for i in control_mode):
+                raise ValueError(
+                    "For multiple controlnets: elements of control_mode must be lists representing conditioning mode."
+                )
+
+            elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
+                raise ValueError("Expected len(control_image) == len(control_mode)")
+
         if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
             raise ValueError(
                 "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
@@ -1117,7 +1162,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         prompt_2: Optional[Union[str, List[str]]] = None,
         image: PipelineImageInput = None,
         mask_image: PipelineImageInput = None,
-        control_image: PipelineImageInput = None,
+        control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         padding_mask_crop: Optional[int] = None,
@@ -1145,7 +1190,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
-        control_mode: Optional[Union[int, List[int]]] = None,
+        control_mode: Optional[Union[int, List[int], List[List[int]]]] = None,
         guidance_rescale: float = 0.0,
         original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
@@ -1177,6 +1222,13 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
                 repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
                 to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
                 instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
+                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+                as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
+                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+                images must be passed as a list such that each element of the list can be correctly batched for input
+                to a single ControlNet.
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -1269,6 +1321,22 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+                the corresponding scale as a list.
+            guess_mode (`bool`, *optional*, defaults to `False`):
+                The ControlNet encoder tries to recognize the content of the input image even if you remove all
+                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+                The percentage of total steps at which the ControlNet starts applying.
+            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The percentage of total steps at which the ControlNet stops applying.
+            control_mode (`int` or `List[int]` or `List[List[int]]`, *optional*):
+                The control condition types for the ControlNet. See the ControlNet's model card for information on
+                the available control modes. If multiple ControlNets are specified in `init`, `control_mode` should
+                be a list where each ControlNet has its corresponding control mode list, reflecting the order of the
+                conditions in `control_image`.
             original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                 If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                 `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
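
Putting the documented arguments together, a hedged call sketch for the two-controlnet case; `pipe` and the condition maps are assumed from the earlier sketches, and `init_image`/`mask` stand for the usual inpainting inputs:

result = pipe(
    prompt="a house by a lake",
    image=init_image,
    mask_image=mask,
    control_image=[[depth_map], [canny_map]],  # one inner list per controlnet
    control_mode=[[1], [3]],                   # same order as control_image
    controlnet_conditioning_scale=[0.8, 0.5],  # per-controlnet residual scales
    control_guidance_start=[0.0, 0.0],
    control_guidance_end=[1.0, 0.8],           # second controlnet stops at 80% of the steps
).images[0]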
@@ -1333,22 +1401,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
 
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
-        # align format for control guidance
-        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
-            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
-        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
-            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
-
-        # # 0.0 Default height and width to unet
-        # height = height or self.unet.config.sample_size * self.vae_scale_factor
-        # width = width or self.unet.config.sample_size * self.vae_scale_factor
-
-        # 0.1 align format for control guidance
-        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
-            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
-        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
-            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
-
         if not isinstance(control_image, list):
             control_image = [control_image]
         else:
@@ -1357,40 +1409,59 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         if not isinstance(control_mode, list):
             control_mode = [control_mode]
 
-        if len(control_image) != len(control_mode):
-            raise ValueError("Expected len(control_image) == len(control_type)")
+        if isinstance(controlnet, MultiControlNetUnionModel):
+            control_image = [[item] for item in control_image]
+            control_mode = [[item] for item in control_mode]
 
-        num_control_type = controlnet.config.num_control_type
+        # align format for control guidance
+        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+            control_guidance_start, control_guidance_end = (
+                mult * [control_guidance_start],
+                mult * [control_guidance_end],
+            )
+
+        if isinstance(controlnet_conditioning_scale, float):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
 
         # 1. Check inputs
-        control_type = [0 for _ in range(num_control_type)]
-        for _image, control_idx in zip(control_image, control_mode):
-            control_type[control_idx] = 1
-            self.check_inputs(
-                prompt,
-                prompt_2,
-                _image,
-                mask_image,
-                strength,
-                num_inference_steps,
-                callback_steps,
-                output_type,
-                negative_prompt,
-                negative_prompt_2,
-                prompt_embeds,
-                negative_prompt_embeds,
-                ip_adapter_image,
-                ip_adapter_image_embeds,
-                pooled_prompt_embeds,
-                negative_pooled_prompt_embeds,
-                controlnet_conditioning_scale,
-                control_guidance_start,
-                control_guidance_end,
-                callback_on_step_end_tensor_inputs,
-                padding_mask_crop,
-            )
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            control_image,
+            mask_image,
+            strength,
+            num_inference_steps,
+            callback_steps,
+            output_type,
+            negative_prompt,
+            negative_prompt_2,
+            prompt_embeds,
+            negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+            controlnet_conditioning_scale,
+            control_guidance_start,
+            control_guidance_end,
+            control_mode,
+            callback_on_step_end_tensor_inputs,
+            padding_mask_crop,
+        )
 
-        control_type = torch.Tensor(control_type)
+        if isinstance(controlnet, ControlNetUnionModel):
+            control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            control_type = [
+                torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
+                for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
+            ]
 
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
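
The rewritten control_type construction replaces the old zero-list-plus-loop with a single scatter_ over the mode indices, yielding a multi-hot vector across the controlnet's supported condition types. A standalone check of that expression:

import torch

control_mode = [1, 3]    # e.g. two active condition slots
num_control_type = 6     # stands in for controlnet.config.num_control_type
control_type = torch.zeros(num_control_type).scatter_(0, torch.tensor(control_mode), 1)
print(control_type)      # tensor([0., 1., 0., 1., 0., 0.])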
@@ -1483,21 +1554,55 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
             init_image = init_image.to(dtype=torch.float32)
 
         # 5.2 Prepare control images
-        for idx, _ in enumerate(control_image):
-            control_image[idx] = self.prepare_control_image(
-                image=control_image[idx],
-                width=width,
-                height=height,
-                batch_size=batch_size * num_images_per_prompt,
-                num_images_per_prompt=num_images_per_prompt,
-                device=device,
-                dtype=controlnet.dtype,
-                crops_coords=crops_coords,
-                resize_mode=resize_mode,
-                do_classifier_free_guidance=self.do_classifier_free_guidance,
-                guess_mode=guess_mode,
-            )
-            height, width = control_image[idx].shape[-2:]
+        if isinstance(controlnet, ControlNetUnionModel):
+            control_images = []
+
+            for image_ in control_image:
+                image_ = self.prepare_control_image(
+                    image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=controlnet.dtype,
+                    crops_coords=crops_coords,
+                    resize_mode=resize_mode,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
+                    guess_mode=guess_mode,
+                )
+
+                control_images.append(image_)
+
+            control_image = control_images
+            height, width = control_image[0].shape[-2:]
+
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            control_images = []
+
+            for control_image_ in control_image:
+                images = []
+
+                for image_ in control_image_:
+                    image_ = self.prepare_control_image(
+                        image=image_,
+                        width=width,
+                        height=height,
+                        batch_size=batch_size * num_images_per_prompt,
+                        num_images_per_prompt=num_images_per_prompt,
+                        device=device,
+                        dtype=controlnet.dtype,
+                        crops_coords=crops_coords,
+                        resize_mode=resize_mode,
+                        do_classifier_free_guidance=self.do_classifier_free_guidance,
+                        guess_mode=guess_mode,
+                    )
+
+                    images.append(image_)
+                control_images.append(images)
+
+            control_image = control_images
+            height, width = control_image[0][0].shape[-2:]
 
         # 5.3 Prepare mask
         mask = self.mask_processor.preprocess(
@@ -1559,10 +1664,11 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         # 8.2 Create tensor stating which controlnets to keep
         controlnet_keep = []
         for i in range(len(timesteps)):
-            controlnet_keep.append(
-                1.0
-                - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
-            )
+            keeps = [
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                for s, e in zip(control_guidance_start, control_guidance_end)
+            ]
+            controlnet_keep.append(keeps)
 
         # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         height, width = latents.shape[-2:]
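
controlnet_keep now holds one keep flag per controlnet per step rather than a single float, so each controlnet can honor its own start/end window. A small trace of the comprehension:

num_steps = 4
control_guidance_start = [0.0, 0.5]  # second controlnet active only in the later half
control_guidance_end = [1.0, 1.0]

controlnet_keep = [
    [
        1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    for i in range(num_steps)
]
print(controlnet_keep)  # [[1.0, 0.0], [1.0, 0.0], [1.0, 1.0], [1.0, 1.0]]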
@@ -1627,11 +1733,24 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
             num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
             timesteps = timesteps[:num_inference_steps]
 
-        control_type = (
-            control_type.reshape(1, -1)
-            .to(device, dtype=prompt_embeds.dtype)
-            .repeat(batch_size * num_images_per_prompt * 2, 1)
+        control_type_repeat_factor = (
+            batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
         )
+
+        if isinstance(controlnet, ControlNetUnionModel):
+            control_type = (
+                control_type.reshape(1, -1)
+                .to(self._execution_device, dtype=prompt_embeds.dtype)
+                .repeat(control_type_repeat_factor, 1)
+            )
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            control_type = [
+                _control_type.reshape(1, -1)
+                .to(self._execution_device, dtype=prompt_embeds.dtype)
+                .repeat(control_type_repeat_factor, 1)
+                for _control_type in control_type
+            ]
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
@@ -1452,17 +1452,21 @@ class StableDiffusionXLControlNetUnionPipeline(
         is_controlnet_compiled = is_compiled_module(self.controlnet)
         is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
 
+        control_type_repeat_factor = (
+            batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
+        )
+
         if isinstance(controlnet, ControlNetUnionModel):
             control_type = (
                 control_type.reshape(1, -1)
                 .to(self._execution_device, dtype=prompt_embeds.dtype)
-                .repeat(batch_size * num_images_per_prompt * 2, 1)
+                .repeat(control_type_repeat_factor, 1)
             )
-        if isinstance(controlnet, MultiControlNetUnionModel):
+        elif isinstance(controlnet, MultiControlNetUnionModel):
             control_type = [
                 _control_type.reshape(1, -1)
                 .to(self._execution_device, dtype=prompt_embeds.dtype)
-                .repeat(batch_size * num_images_per_prompt * 2, 1)
+                .repeat(control_type_repeat_factor, 1)
                 for _control_type in control_type
             ]
 
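
Both pipelines previously hard-coded the repeat factor 2, which is only right when classifier-free guidance doubles the batch; control_type_repeat_factor now folds the CFG flag in. A standalone shape check of the fixed expression:

import torch

batch_size, num_images_per_prompt = 1, 2
control_type = torch.zeros(6).scatter_(0, torch.tensor([1, 3]), 1)

for do_classifier_free_guidance in (True, False):
    repeat_factor = batch_size * num_images_per_prompt * (2 if do_classifier_free_guidance else 1)
    repeated = control_type.reshape(1, -1).repeat(repeat_factor, 1)
    print(do_classifier_free_guidance, repeated.shape)
    # True  -> torch.Size([4, 6])
    # False -> torch.Size([2, 6])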