diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +72 -26
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
@@ -925,7 +925,11 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
         base_size = 512 // 8 // self.transformer.config.patch_size
         grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
         image_rotary_emb = get_2d_rotary_pos_embed(
-            self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
+            self.transformer.inner_dim // self.transformer.num_heads,
+            grid_crops_coords,
+            (grid_height, grid_width),
+            device=device,
+            output_type="pt",
         )

         style = torch.tensor([0], device=device)
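In 0.32, `get_2d_rotary_pos_embed` gained `device` and `output_type` arguments, so the rotary tables can be built as torch tensors directly on the target device instead of round-tripping through numpy. A minimal sketch of the new call shape (the head dim and grid size below are illustrative, not taken from this pipeline):

    import torch

    from diffusers.models.embeddings import get_2d_rotary_pos_embed

    grid_height, grid_width = 64, 64
    crops_coords = ((0, 0), (grid_height, grid_width))  # ((top, left), (bottom, right))
    cos, sin = get_2d_rotary_pos_embed(
        88,  # head dim, i.e. inner_dim // num_heads; illustrative value
        crops_coords,
        (grid_height, grid_width),
        device=torch.device("cpu"),
        output_type="pt",  # return torch tensors rather than numpy arrays
    )
    print(cos.shape, sin.shape)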
diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -26,7 +26,7 @@ from transformers import (
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
-from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
+from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
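The ControlNet model implementations moved into the new `diffusers/models/controlnets/` package (files 38-47 in the list above), so deep imports change accordingly; the top-level export is untouched, and the old module appears to be kept as a thin deprecation shim in 0.32. A quick orientation sketch:

    # The stable public import path is unchanged:
    from diffusers import SD3ControlNetModel

    # New internal location (what this file now imports from):
    from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetModel as _SD3CN

    # Both names resolve to the same class object.
    assert SD3ControlNetModel is _SD3CN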
@@ -66,9 +66,13 @@ EXAMPLE_DOC_STRING = """
         ...     "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
         ... )
         >>> pipe.to("cuda")
-        >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
-        >>> prompt = "A girl holding a sign that says InstantX"
-        >>> image = pipe(prompt, control_image=control_image, controlnet_conditioning_scale=0.7).images[0]
+        >>> control_image = load_image(
+        ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+        ... )
+        >>> prompt = "A bird in space"
+        >>> image = pipe(
+        ...     prompt, control_image=control_image, height=1024, width=768, controlnet_conditioning_scale=0.7
+        ... ).images[0]
         >>> image.save("sd3.png")
         ```
 """
@@ -194,6 +198,19 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
         super().__init__()
         if isinstance(controlnet, (list, tuple)):
             controlnet = SD3MultiControlNetModel(controlnet)
+        if isinstance(controlnet, SD3MultiControlNetModel):
+            for controlnet_model in controlnet.nets:
+                # for SD3.5 8b controlnet, it shares the pos_embed with the transformer
+                if (
+                    hasattr(controlnet_model.config, "use_pos_embed")
+                    and controlnet_model.config.use_pos_embed is False
+                ):
+                    pos_embed = controlnet_model._get_pos_embed_from_transformer(transformer)
+                    controlnet_model.pos_embed = pos_embed.to(controlnet_model.dtype).to(controlnet_model.device)
+        elif isinstance(controlnet, SD3ControlNetModel):
+            if hasattr(controlnet.config, "use_pos_embed") and controlnet.config.use_pos_embed is False:
+                pos_embed = controlnet._get_pos_embed_from_transformer(transformer)
+                controlnet.pos_embed = pos_embed.to(controlnet.dtype).to(controlnet.device)

         self.register_modules(
             vae=vae,
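This addition supports the SD3.5 Large ControlNets, which ship without their own patch embedding (`use_pos_embed=False` in their config) and instead borrow `pos_embed` from the base transformer at pipeline construction time. A hedged loading sketch; the repo ids are assumptions for illustration:

    import torch

    from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline

    controlnet = SD3ControlNetModel.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large-controlnet-canny",  # assumed repo id
        torch_dtype=torch.float16,
    )
    pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large",  # assumed repo id
        controlnet=controlnet,
        torch_dtype=torch.float16,
    )
    # __init__ above detects use_pos_embed=False and copies the transformer's
    # pos_embed into the controlnet (dtype/device matched).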
@@ -720,7 +737,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 28,
-        timesteps: List[int] = None,
+        sigmas: Optional[List[float]] = None,
         guidance_scale: float = 7.0,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
@@ -765,10 +782,10 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             guidance_scale (`float`, *optional*, defaults to 5.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
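The `timesteps` argument is replaced by `sigmas`, matching the sigma-space parameterization of `FlowMatchEulerDiscreteScheduler`; the list is forwarded to `scheduler.set_timesteps(sigmas=...)` via `retrieve_timesteps`. A hedged sketch of a custom schedule, reusing the `pipe` and `control_image` setup from the docstring example above:

    import numpy as np

    num_steps = 28
    # Linear sigma schedule from 1.0 down to 1/num_steps, the same shape of
    # schedule the flow-match pipelines default to.
    sigmas = np.linspace(1.0, 1 / num_steps, num_steps).tolist()
    image = pipe(
        prompt="A bird in space",
        control_image=control_image,
        sigmas=sigmas,
        controlnet_conditioning_scale=0.7,
    ).images[0]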
@@ -858,6 +875,12 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor

+        controlnet_config = (
+            self.controlnet.config
+            if isinstance(self.controlnet, SD3ControlNetModel)
+            else self.controlnet.nets[0].config
+        )
+
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -932,6 +955,11 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

         # 3. Prepare control image
+        if controlnet_config.force_zeros_for_pooled_projection:
+            # instantx sd3 controlnet does not apply shift factor
+            vae_shift_factor = 0
+        else:
+            vae_shift_factor = self.vae.config.shift_factor
         if isinstance(self.controlnet, SD3ControlNetModel):
             control_image = self.prepare_image(
                 image=control_image,
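Background for the `vae_shift_factor` branch: SD3's VAE normalizes latents as `(z - shift_factor) * scaling_factor`, but the InstantX SD3 ControlNets were trained on latents scaled without the shift, which their configs signal via `force_zeros_for_pooled_projection`. A self-contained sketch of the two conventions (the constants are the SD3 VAE config values as commonly quoted, an assumption here):

    import torch

    scaling_factor, shift_factor = 1.5305, 0.0609  # SD3 AutoencoderKL config (assumed)
    z = torch.randn(1, 16, 128, 128)  # a sampled 16-channel SD3 VAE latent

    latents_official = (z - shift_factor) * scaling_factor  # SD3 / SD3.5 controlnets
    latents_instantx = (z - 0) * scaling_factor             # InstantX controlnets skip the shift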
@@ -947,8 +975,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             height, width = control_image.shape[-2:]

             control_image = self.vae.encode(control_image).latent_dist.sample()
-            control_image = control_image * self.vae.config.scaling_factor
-
+            control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor
         elif isinstance(self.controlnet, SD3MultiControlNetModel):
             control_images = []

@@ -966,7 +993,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
                 )

                 control_image_ = self.vae.encode(control_image_).latent_dist.sample()
-                control_image_ = control_image_ * self.vae.config.scaling_factor
+                control_image_ = (control_image_ - vae_shift_factor) * self.vae.config.scaling_factor

                 control_images.append(control_image_)

@@ -974,13 +1001,8 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
         else:
             assert False

-        if controlnet_pooled_projections is None:
-            controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
-        else:
-            controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
-
         # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)

@@ -1006,6 +1028,18 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             ]
             controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps)

+        if controlnet_config.force_zeros_for_pooled_projection:
+            # instantx sd3 controlnet used zero pooled projection
+            controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
+        else:
+            controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
+
+        if controlnet_config.joint_attention_dim is not None:
+            controlnet_encoder_hidden_states = prompt_embeds
+        else:
+            # SD35 official 8b controlnet does not use encoder_hidden_states
+            controlnet_encoder_hidden_states = None
+
         # 7. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -1029,7 +1063,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
                 control_block_samples = self.controlnet(
                     hidden_states=latent_model_input,
                     timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states=controlnet_encoder_hidden_states,
                     pooled_projections=controlnet_pooled_projections,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     controlnet_cond=control_image,
diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
@@ -26,7 +26,7 @@ from transformers import (
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
-from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
+from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
@@ -787,7 +787,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 28,
-        timesteps: List[int] = None,
+        sigmas: Optional[List[float]] = None,
         guidance_scale: float = 7.0,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
@@ -833,10 +833,10 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             guidance_scale (`float`, *optional*, defaults to 5.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -1033,7 +1033,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
             controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds

         # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)

diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -1595,7 +1595,7 @@ class DownBlockFlat(nn.Module):
         output_states = ()

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1732,7 +1732,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         blocks = list(zip(self.resnets, self.attentions))

         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1874,7 +1874,7 @@ class UpBlockFlat(nn.Module):

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -2033,7 +2033,7 @@ class CrossAttnUpBlockFlat(nn.Module):

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2223,12 +2223,35 @@ class UNetMidBlockFlat(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states, temb=temb)
-            hidden_states = resnet(hidden_states, temb)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = resnet(hidden_states, temb)

         return hidden_states

@@ -2352,7 +2375,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):

         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
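The recurring one-line change across these Flat blocks swaps the checkpointing gate from `self.training` to `torch.is_grad_enabled()`: activation checkpointing now engages whenever autograd is recording, even for a module in `eval()` mode, and is skipped entirely under `torch.no_grad()`. A self-contained sketch of the gating behavior:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint


    class CheckpointedBlock(nn.Module):
        # Toy block using the same gate as the updated *Flat blocks above.
        def __init__(self):
            super().__init__()
            self.inner = nn.Linear(8, 8)
            self.gradient_checkpointing = True

        def forward(self, x):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Recompute activations during backward instead of storing them.
                return checkpoint(self.inner, x, use_reentrant=False)
            return self.inner(x)


    block = CheckpointedBlock().eval()  # eval() no longer disables checkpointing
    x = torch.randn(2, 8, requires_grad=True)
    block(x).sum().backward()  # grad enabled -> checkpointed path, backward works
    with torch.no_grad():
        _ = block(x)  # no autograd recording -> plain forward, no checkpoint overhead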
diffusers/pipelines/flux/__init__.py
@@ -12,7 +12,7 @@ from ...utils import (
 _dummy_objects = {}
 _additional_imports = {}

-_import_structure = {"pipeline_output": ["FluxPipelineOutput"]}
+_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]}

 try:
     if not (is_transformers_available() and is_torch_available()):
@@ -22,12 +22,18 @@ except OptionalDependencyNotAvailable:

     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
+    _import_structure["modeling_flux"] = ["ReduxImageEncoder"]
     _import_structure["pipeline_flux"] = ["FluxPipeline"]
+    _import_structure["pipeline_flux_control"] = ["FluxControlPipeline"]
+    _import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"]
+    _import_structure["pipeline_flux_control_inpaint"] = ["FluxControlInpaintPipeline"]
     _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
     _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
     _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
+    _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
     _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
     _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
+    _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
         if not (is_transformers_available() and is_torch_available()):
@@ -35,12 +41,18 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     except OptionalDependencyNotAvailable:
         from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
+        from .modeling_flux import ReduxImageEncoder
         from .pipeline_flux import FluxPipeline
+        from .pipeline_flux_control import FluxControlPipeline
+        from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline
+        from .pipeline_flux_control_inpaint import FluxControlInpaintPipeline
         from .pipeline_flux_controlnet import FluxControlNetPipeline
         from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
         from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
+        from .pipeline_flux_fill import FluxFillPipeline
         from .pipeline_flux_img2img import FluxImg2ImgPipeline
         from .pipeline_flux_inpaint import FluxInpaintPipeline
+        from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
 else:
     import sys

diffusers/pipelines/flux/modeling_flux.py
@@ -0,0 +1,47 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.modeling_utils import ModelMixin
+from ...utils import BaseOutput
+
+
+@dataclass
+class ReduxImageEncoderOutput(BaseOutput):
+    image_embeds: Optional[torch.Tensor] = None
+
+
+class ReduxImageEncoder(ModelMixin, ConfigMixin):
+    @register_to_config
+    def __init__(
+        self,
+        redux_dim: int = 1152,
+        txt_in_features: int = 4096,
+    ) -> None:
+        super().__init__()
+
+        self.redux_up = nn.Linear(redux_dim, txt_in_features * 3)
+        self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features)
+
+    def forward(self, x: torch.Tensor) -> ReduxImageEncoderOutput:
+        projected_x = self.redux_down(nn.functional.silu(self.redux_up(x)))
+
+        return ReduxImageEncoderOutput(image_embeds=projected_x)
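`ReduxImageEncoder` is a two-layer SiLU MLP that projects image-encoder features (`redux_dim=1152`, the SigLIP hidden size) up 3x and back down to the T5 embedding width (`txt_in_features=4096`), so `FluxPriorReduxPipeline` can append image tokens to the prompt embeddings. A shape check against the class defined above; the 729-token sequence length is an illustrative assumption (a 27x27 patch grid):

    import torch

    from diffusers.pipelines.flux.modeling_flux import ReduxImageEncoder

    encoder = ReduxImageEncoder()               # redux_dim=1152, txt_in_features=4096
    image_features = torch.randn(1, 729, 1152)  # e.g. SigLIP patch features
    out = encoder(image_features)
    print(out.image_embeds.shape)               # torch.Size([1, 729, 4096])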