diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. diffusers/__init__.py +38 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +238 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +40 -7
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +6 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
  53. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  54. diffusers/pipelines/cogview3/__init__.py +47 -0
  55. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  56. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  57. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  58. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  60. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  62. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  63. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  64. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  66. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  67. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  68. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  70. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  71. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  72. diffusers/pipelines/flux/__init__.py +10 -0
  73. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  74. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  76. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  77. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  78. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  79. diffusers/pipelines/free_noise_utils.py +365 -5
  80. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  81. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  82. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  83. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  84. diffusers/pipelines/kolors/tokenizer.py +4 -0
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  86. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  87. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  89. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  90. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  91. diffusers/pipelines/pag/__init__.py +6 -0
  92. diffusers/pipelines/pag/pag_utils.py +8 -2
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  96. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  97. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  98. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  100. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  101. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  102. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  103. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  106. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  107. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  108. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  109. diffusers/pipelines/pipeline_utils.py +123 -180
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  111. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  113. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  117. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  120. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  121. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  122. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  123. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  129. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  131. diffusers/quantizers/__init__.py +16 -0
  132. diffusers/quantizers/auto.py +126 -0
  133. diffusers/quantizers/base.py +233 -0
  134. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  135. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  136. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  137. diffusers/quantizers/quantization_config.py +391 -0
  138. diffusers/schedulers/scheduling_ddim.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  140. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm.py +4 -1
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  143. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  148. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  149. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  150. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  151. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  152. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  155. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  156. diffusers/schedulers/scheduling_sasolver.py +78 -1
  157. diffusers/schedulers/scheduling_unclip.py +4 -1
  158. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  159. diffusers/training_utils.py +48 -18
  160. diffusers/utils/__init__.py +2 -1
  161. diffusers/utils/dummy_pt_objects.py +60 -0
  162. diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
  163. diffusers/utils/hub_utils.py +16 -4
  164. diffusers/utils/import_utils.py +31 -8
  165. diffusers/utils/loading_utils.py +28 -4
  166. diffusers/utils/peft_utils.py +3 -3
  167. diffusers/utils/testing_utils.py +59 -0
  168. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  169. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
  170. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  172. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  173. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,21 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+
7
+ from ...utils import BaseOutput
8
+
9
+
10
+ @dataclass
11
+ class CogView3PipelineOutput(BaseOutput):
12
+ """
13
+ Output class for CogView3 pipelines.
14
+
15
+ Args:
16
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
17
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
18
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
19
+ """
20
+
21
+ images: Union[List[PIL.Image.Image], np.ndarray]
@@ -101,7 +101,7 @@ def retrieve_timesteps(
101
101
  sigmas: Optional[List[float]] = None,
102
102
  **kwargs,
103
103
  ):
104
- """
104
+ r"""
105
105
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
106
106
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
107
107
 
@@ -893,6 +893,10 @@ class StableDiffusionControlNetPipeline(
893
893
  def num_timesteps(self):
894
894
  return self._num_timesteps
895
895
 
896
+ @property
897
+ def interrupt(self):
898
+ return self._interrupt
899
+
896
900
  @torch.no_grad()
897
901
  @replace_example_docstring(EXAMPLE_DOC_STRING)
898
902
  def __call__(
@@ -1089,6 +1093,7 @@ class StableDiffusionControlNetPipeline(
1089
1093
  self._guidance_scale = guidance_scale
1090
1094
  self._clip_skip = clip_skip
1091
1095
  self._cross_attention_kwargs = cross_attention_kwargs
1096
+ self._interrupt = False
1092
1097
 
1093
1098
  # 2. Define call parameters
1094
1099
  if prompt is not None and isinstance(prompt, str):
@@ -1235,6 +1240,9 @@ class StableDiffusionControlNetPipeline(
1235
1240
  is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
1236
1241
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1237
1242
  for i, t in enumerate(timesteps):
1243
+ if self.interrupt:
1244
+ continue
1245
+
1238
1246
  # Relevant thread:
1239
1247
  # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1240
1248
  if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
@@ -891,6 +891,10 @@ class StableDiffusionControlNetImg2ImgPipeline(
891
891
  def num_timesteps(self):
892
892
  return self._num_timesteps
893
893
 
894
+ @property
895
+ def interrupt(self):
896
+ return self._interrupt
897
+
894
898
  @torch.no_grad()
895
899
  @replace_example_docstring(EXAMPLE_DOC_STRING)
896
900
  def __call__(
@@ -1081,6 +1085,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
1081
1085
  self._guidance_scale = guidance_scale
1082
1086
  self._clip_skip = clip_skip
1083
1087
  self._cross_attention_kwargs = cross_attention_kwargs
1088
+ self._interrupt = False
1084
1089
 
1085
1090
  # 2. Define call parameters
1086
1091
  if prompt is not None and isinstance(prompt, str):
@@ -1211,6 +1216,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
1211
1216
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1212
1217
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1213
1218
  for i, t in enumerate(timesteps):
1219
+ if self.interrupt:
1220
+ continue
1221
+
1214
1222
  # expand the latents if we are doing classifier free guidance
1215
1223
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1216
1224
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -976,6 +976,10 @@ class StableDiffusionControlNetInpaintPipeline(
976
976
  def num_timesteps(self):
977
977
  return self._num_timesteps
978
978
 
979
+ @property
980
+ def interrupt(self):
981
+ return self._interrupt
982
+
979
983
  @torch.no_grad()
980
984
  @replace_example_docstring(EXAMPLE_DOC_STRING)
981
985
  def __call__(
@@ -1191,6 +1195,7 @@ class StableDiffusionControlNetInpaintPipeline(
1191
1195
  self._guidance_scale = guidance_scale
1192
1196
  self._clip_skip = clip_skip
1193
1197
  self._cross_attention_kwargs = cross_attention_kwargs
1198
+ self._interrupt = False
1194
1199
 
1195
1200
  # 2. Define call parameters
1196
1201
  if prompt is not None and isinstance(prompt, str):
@@ -1375,6 +1380,9 @@ class StableDiffusionControlNetInpaintPipeline(
1375
1380
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1376
1381
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1377
1382
  for i, t in enumerate(timesteps):
1383
+ if self.interrupt:
1384
+ continue
1385
+
1378
1386
  # expand the latents if we are doing classifier free guidance
1379
1387
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1380
1388
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -137,9 +137,21 @@ EXAMPLE_DOC_STRING = """
137
137
 
138
138
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
139
139
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
140
- """
141
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
142
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
140
+ r"""
141
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
142
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
143
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
144
+
145
+ Args:
146
+ noise_cfg (`torch.Tensor`):
147
+ The predicted noise tensor for the guided diffusion process.
148
+ noise_pred_text (`torch.Tensor`):
149
+ The predicted noise tensor for the text-guided diffusion process.
150
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
151
+ A rescale factor applied to the noise predictions.
152
+
153
+ Returns:
154
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
143
155
  """
144
156
  std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
145
157
  std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -1024,14 +1036,16 @@ class StableDiffusionXLControlNetInpaintPipeline(
1024
1036
  if denoising_start is None:
1025
1037
  init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
1026
1038
  t_start = max(num_inference_steps - init_timestep, 0)
1027
- else:
1028
- t_start = 0
1029
1039
 
1030
- timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
1040
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
1041
+ if hasattr(self.scheduler, "set_begin_index"):
1042
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
1043
+
1044
+ return timesteps, num_inference_steps - t_start
1031
1045
 
1032
- # Strength is irrelevant if we directly request a timestep to start at;
1033
- # that is, strength is determined by the denoising_start instead.
1034
- if denoising_start is not None:
1046
+ else:
1047
+ # Strength is irrelevant if we directly request a timestep to start at;
1048
+ # that is, strength is determined by the denoising_start instead.
1035
1049
  discrete_timestep_cutoff = int(
1036
1050
  round(
1037
1051
  self.scheduler.config.num_train_timesteps
@@ -1039,7 +1053,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
1039
1053
  )
1040
1054
  )
1041
1055
 
1042
- num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
1056
+ num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
1043
1057
  if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
1044
1058
  # if the scheduler is a 2nd order scheduler we might have to do +1
1045
1059
  # because `num_inference_steps` might be even given that every timestep
@@ -1050,11 +1064,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
1050
1064
  num_inference_steps = num_inference_steps + 1
1051
1065
 
1052
1066
  # because t_n+1 >= t_n, we slice the timesteps starting from the end
1053
- timesteps = timesteps[-num_inference_steps:]
1067
+ t_start = len(self.scheduler.timesteps) - num_inference_steps
1068
+ timesteps = self.scheduler.timesteps[t_start:]
1069
+ if hasattr(self.scheduler, "set_begin_index"):
1070
+ self.scheduler.set_begin_index(t_start)
1054
1071
  return timesteps, num_inference_steps
1055
1072
 
1056
- return timesteps, num_inference_steps - t_start
1057
-
1058
1073
  def _get_add_time_ids(
1059
1074
  self,
1060
1075
  original_size,
@@ -1142,6 +1157,10 @@ class StableDiffusionXLControlNetInpaintPipeline(
1142
1157
  def num_timesteps(self):
1143
1158
  return self._num_timesteps
1144
1159
 
1160
+ @property
1161
+ def interrupt(self):
1162
+ return self._interrupt
1163
+
1145
1164
  @torch.no_grad()
1146
1165
  @replace_example_docstring(EXAMPLE_DOC_STRING)
1147
1166
  def __call__(
@@ -1424,6 +1443,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
1424
1443
  self._guidance_scale = guidance_scale
1425
1444
  self._clip_skip = clip_skip
1426
1445
  self._cross_attention_kwargs = cross_attention_kwargs
1446
+ self._interrupt = False
1427
1447
 
1428
1448
  # 2. Define call parameters
1429
1449
  if prompt is not None and isinstance(prompt, str):
@@ -1692,6 +1712,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
1692
1712
 
1693
1713
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1694
1714
  for i, t in enumerate(timesteps):
1715
+ if self.interrupt:
1716
+ continue
1717
+
1695
1718
  # expand the latents if we are doing classifier free guidance
1696
1719
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1697
1720
 
@@ -122,7 +122,7 @@ def retrieve_timesteps(
122
122
  sigmas: Optional[List[float]] = None,
123
123
  **kwargs,
124
124
  ):
125
- """
125
+ r"""
126
126
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
127
127
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
128
128
 
@@ -990,6 +990,10 @@ class StableDiffusionXLControlNetPipeline(
990
990
  def num_timesteps(self):
991
991
  return self._num_timesteps
992
992
 
993
+ @property
994
+ def interrupt(self):
995
+ return self._interrupt
996
+
993
997
  @torch.no_grad()
994
998
  @replace_example_docstring(EXAMPLE_DOC_STRING)
995
999
  def __call__(
@@ -1245,6 +1249,7 @@ class StableDiffusionXLControlNetPipeline(
1245
1249
  self._clip_skip = clip_skip
1246
1250
  self._cross_attention_kwargs = cross_attention_kwargs
1247
1251
  self._denoising_end = denoising_end
1252
+ self._interrupt = False
1248
1253
 
1249
1254
  # 2. Define call parameters
1250
1255
  if prompt is not None and isinstance(prompt, str):
@@ -1442,6 +1447,9 @@ class StableDiffusionXLControlNetPipeline(
1442
1447
  is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
1443
1448
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1444
1449
  for i, t in enumerate(timesteps):
1450
+ if self.interrupt:
1451
+ continue
1452
+
1445
1453
  # Relevant thread:
1446
1454
  # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1447
1455
  if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
@@ -1070,6 +1070,10 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
1070
1070
  def num_timesteps(self):
1071
1071
  return self._num_timesteps
1072
1072
 
1073
+ @property
1074
+ def interrupt(self):
1075
+ return self._interrupt
1076
+
1073
1077
  @torch.no_grad()
1074
1078
  @replace_example_docstring(EXAMPLE_DOC_STRING)
1075
1079
  def __call__(
@@ -1338,6 +1342,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
1338
1342
  self._guidance_scale = guidance_scale
1339
1343
  self._clip_skip = clip_skip
1340
1344
  self._cross_attention_kwargs = cross_attention_kwargs
1345
+ self._interrupt = False
1341
1346
 
1342
1347
  # 2. Define call parameters
1343
1348
  if prompt is not None and isinstance(prompt, str):
@@ -1510,6 +1515,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
1510
1515
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1511
1516
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1512
1517
  for i, t in enumerate(timesteps):
1518
+ if self.interrupt:
1519
+ continue
1520
+
1513
1521
  # expand the latents if we are doing classifier free guidance
1514
1522
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1515
1523
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1538,7 +1546,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
1538
1546
  if isinstance(controlnet_cond_scale, list):
1539
1547
  controlnet_cond_scale = controlnet_cond_scale[0]
1540
1548
  cond_scale = controlnet_cond_scale * controlnet_keep[i]
1541
-
1542
1549
  down_block_res_samples, mid_block_res_sample = self.controlnet(
1543
1550
  control_model_input,
1544
1551
  t,
@@ -141,9 +141,21 @@ def get_resize_crop_region_for_grid(src, tgt_size):
141
141
 
142
142
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
143
143
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
144
- """
145
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
146
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
144
+ r"""
145
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
146
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
147
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
148
+
149
+ Args:
150
+ noise_cfg (`torch.Tensor`):
151
+ The predicted noise tensor for the guided diffusion process.
152
+ noise_pred_text (`torch.Tensor`):
153
+ The predicted noise tensor for the text-guided diffusion process.
154
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
155
+ A rescale factor applied to the noise predictions.
156
+
157
+ Returns:
158
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
147
159
  """
148
160
  std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
149
161
  std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -225,6 +237,8 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
225
237
  requires_safety_checker: bool = True,
226
238
  ):
227
239
  super().__init__()
240
+ if isinstance(controlnet, (list, tuple)):
241
+ controlnet = HunyuanDiT2DMultiControlNetModel(controlnet)
228
242
 
229
243
  self.register_modules(
230
244
  vae=vae,
@@ -23,6 +23,9 @@ except OptionalDependencyNotAvailable:
23
23
  _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
24
24
  else:
25
25
  _import_structure["pipeline_stable_diffusion_3_controlnet"] = ["StableDiffusion3ControlNetPipeline"]
26
+ _import_structure["pipeline_stable_diffusion_3_controlnet_inpainting"] = [
27
+ "StableDiffusion3ControlNetInpaintingPipeline"
28
+ ]
26
29
 
27
30
  if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
28
31
  try:
@@ -33,6 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
33
36
  from ...utils.dummy_torch_and_transformers_objects import *
34
37
  else:
35
38
  from .pipeline_stable_diffusion_3_controlnet import StableDiffusion3ControlNetPipeline
39
+ from .pipeline_stable_diffusion_3_controlnet_inpainting import StableDiffusion3ControlNetInpaintingPipeline
36
40
 
37
41
  try:
38
42
  if not (is_transformers_available() and is_flax_available()):
@@ -83,7 +83,7 @@ def retrieve_timesteps(
83
83
  sigmas: Optional[List[float]] = None,
84
84
  **kwargs,
85
85
  ):
86
- """
86
+ r"""
87
87
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
88
88
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
89
89
 
@@ -192,6 +192,8 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
192
192
  ],
193
193
  ):
194
194
  super().__init__()
195
+ if isinstance(controlnet, (list, tuple)):
196
+ controlnet = SD3MultiControlNetModel(controlnet)
195
197
 
196
198
  self.register_modules(
197
199
  vae=vae,