diffusers 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +74 -28
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
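
The headline structural change is that the ControlNet model implementations moved from `diffusers/models/` into the new `diffusers/models/controlnets/` package (see the renames above, and shrunken originals such as `diffusers/models/controlnet.py`, +47 -802, whose small residue suggests the old modules now mostly re-export from the new location). A hedged sketch of what that means for downstream imports; the continued availability of the old paths is an assumption inferred from those residual line counts, not verified here:

```python
# New canonical locations in 0.32:
from diffusers.models.controlnets.controlnet_sd3 import (
    SD3ControlNetModel,
    SD3MultiControlNetModel,
)

# Top-level exports are unaffected by the move and remain the safest choice:
from diffusers import SD3ControlNetModel  # noqa: F811

# Assumption: the old path (diffusers.models.controlnet_sd3) likely still resolves
# via a re-export shim for backwards compatibility; prefer the new paths regardless.
```

The hunks below show the most notable source-level changes in this release.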
diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py:

```diff
@@ -925,7 +925,11 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
         base_size = 512 // 8 // self.transformer.config.patch_size
         grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
         image_rotary_emb = get_2d_rotary_pos_embed(
-            self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
+            self.transformer.inner_dim // self.transformer.num_heads,
+            grid_crops_coords,
+            (grid_height, grid_width),
+            device=device,
+            output_type="pt",
         )
 
         style = torch.tensor([0], device=device)
```
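
The new `device=` and `output_type="pt"` arguments let the 2D rotary position embedding be built directly on the target device as torch tensors instead of going through a CPU/NumPy round trip. As a rough, generic illustration of the split-axis 2D RoPE computation involved (this helper is a sketch, not diffusers' actual `get_2d_rotary_pos_embed`; the head-dim value is an assumption for HunyuanDiT):

```python
import torch

def rope_angles_1d(positions: torch.Tensor, dim: int, theta: float = 10000.0):
    # Classic rotary embedding: one angle per (position, frequency) pair.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, device=positions.device).float() / dim))
    angles = torch.outer(positions.float(), inv_freq)
    return angles.cos(), angles.sin()

# A 2D variant splits the per-head dim between the height and width axes; building
# it on `device` as "pt" tensors is what the new arguments select.
device = torch.device("cpu")
head_dim, grid_h, grid_w = 88, 64, 64  # 88 = inner_dim // num_heads for HunyuanDiT (assumed)
cos_h, sin_h = rope_angles_1d(torch.arange(grid_h, device=device), head_dim // 2)
cos_w, sin_w = rope_angles_1d(torch.arange(grid_w, device=device), head_dim // 2)
print(cos_h.shape, cos_w.shape)  # torch.Size([64, 22]) torch.Size([64, 22])
```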
diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py:

````diff
@@ -26,7 +26,7 @@ from transformers import (
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
-from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
+from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
@@ -66,9 +66,13 @@ EXAMPLE_DOC_STRING = """
         ...     "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
         ... )
         >>> pipe.to("cuda")
-        >>> control_image = load_image(
-
-
+        >>> control_image = load_image(
+        ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+        ... )
+        >>> prompt = "A bird in space"
+        >>> image = pipe(
+        ...     prompt, control_image=control_image, height=1024, width=768, controlnet_conditioning_scale=0.7
+        ... ).images[0]
         >>> image.save("sd3.png")
         ```
 """
@@ -194,6 +198,19 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
         super().__init__()
         if isinstance(controlnet, (list, tuple)):
             controlnet = SD3MultiControlNetModel(controlnet)
+        if isinstance(controlnet, SD3MultiControlNetModel):
+            for controlnet_model in controlnet.nets:
+                # for SD3.5 8b controlnet, it shares the pos_embed with the transformer
+                if (
+                    hasattr(controlnet_model.config, "use_pos_embed")
+                    and controlnet_model.config.use_pos_embed is False
+                ):
+                    pos_embed = controlnet_model._get_pos_embed_from_transformer(transformer)
+                    controlnet_model.pos_embed = pos_embed.to(controlnet_model.dtype).to(controlnet_model.device)
+        elif isinstance(controlnet, SD3ControlNetModel):
+            if hasattr(controlnet.config, "use_pos_embed") and controlnet.config.use_pos_embed is False:
+                pos_embed = controlnet._get_pos_embed_from_transformer(transformer)
+                controlnet.pos_embed = pos_embed.to(controlnet.dtype).to(controlnet.device)
 
         self.register_modules(
             vae=vae,
@@ -720,7 +737,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 28,
-        timesteps: List[int] = None,
+        sigmas: Optional[List[float]] = None,
         guidance_scale: float = 7.0,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
@@ -765,10 +782,10 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             guidance_scale (`float`, *optional*, defaults to 5.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
````
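
A sketch of calling the pipeline with the new `sigmas` argument. The schedule values are arbitrary placeholders, and the ControlNet repo ID is an assumption (the base checkpoint ID comes from the docstring above):

```python
import torch
from diffusers import StableDiffusion3ControlNetPipeline
from diffusers.models import SD3ControlNetModel
from diffusers.utils import load_image

controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

control_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
# Placeholder 10-step schedule from high to low noise; FlowMatchEulerDiscreteScheduler
# accepts `sigmas` in set_timesteps, so retrieve_timesteps forwards it through.
custom_sigmas = torch.linspace(1.0, 0.05, 10).tolist()
image = pipe("A bird in space", control_image=control_image, sigmas=custom_sigmas).images[0]
```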
diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py (continued):

```diff
@@ -858,6 +875,12 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
 
+        controlnet_config = (
+            self.controlnet.config
+            if isinstance(self.controlnet, SD3ControlNetModel)
+            else self.controlnet.nets[0].config
+        )
+
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -932,6 +955,11 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
 
         # 3. Prepare control image
+        if controlnet_config.force_zeros_for_pooled_projection:
+            # instantx sd3 controlnet does not apply shift factor
+            vae_shift_factor = 0
+        else:
+            vae_shift_factor = self.vae.config.shift_factor
         if isinstance(self.controlnet, SD3ControlNetModel):
             control_image = self.prepare_image(
                 image=control_image,
@@ -947,8 +975,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             height, width = control_image.shape[-2:]
 
             control_image = self.vae.encode(control_image).latent_dist.sample()
-            control_image = control_image * self.vae.config.scaling_factor
-
+            control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor
         elif isinstance(self.controlnet, SD3MultiControlNetModel):
             control_images = []
 
@@ -966,7 +993,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
                 )
 
                 control_image_ = self.vae.encode(control_image_).latent_dist.sample()
-                control_image_ = control_image_ * self.vae.config.scaling_factor
+                control_image_ = (control_image_ - vae_shift_factor) * self.vae.config.scaling_factor
 
                 control_images.append(control_image_)
 
```
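
For context, the shift-and-scale applied above is SD3's standard VAE latent normalization; per the inline comment, these hunks only skip the shift for InstantX-style ControlNets trained without it. A minimal sketch of the forward and inverse mappings using the diffusers config field names (the numeric values are SD3's published VAE config, worth verifying against the checkpoint):

```python
import torch

def to_latent_space(z: torch.Tensor, scaling_factor: float, shift_factor: float = 0.0) -> torch.Tensor:
    # Applied after vae.encode(...).latent_dist.sample()
    return (z - shift_factor) * scaling_factor

def from_latent_space(z: torch.Tensor, scaling_factor: float, shift_factor: float = 0.0) -> torch.Tensor:
    # Inverse, applied before vae.decode(...)
    return z / scaling_factor + shift_factor

z = torch.randn(1, 16, 64, 64)
z_norm = to_latent_space(z, scaling_factor=1.5305, shift_factor=0.0609)
assert torch.allclose(from_latent_space(z_norm, 1.5305, 0.0609), z, atol=1e-5)
```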
diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py (continued):

```diff
@@ -974,13 +1001,8 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
         else:
             assert False
 
-        if controlnet_pooled_projections is None:
-            controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
-        else:
-            controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
-
         # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
```
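
The `retrieve_timesteps` helper at these call sites dispatches on whether a custom schedule was supplied; a simplified sketch of that dispatch (diffusers' real helper also validates that the scheduler's `set_timesteps` actually accepts `sigmas`/`timesteps`):

```python
def retrieve_timesteps_sketch(scheduler, num_inference_steps=None, device=None, sigmas=None):
    # Simplified: a custom sigma schedule wins over num_inference_steps, and the
    # effective step count is derived from the timesteps the scheduler produces.
    if sigmas is not None:
        scheduler.set_timesteps(sigmas=sigmas, device=device)
        num_inference_steps = len(scheduler.timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device)
    return scheduler.timesteps, num_inference_steps
```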
diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py (continued):

```diff
@@ -1006,6 +1028,18 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             ]
             controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps)
 
+        if controlnet_config.force_zeros_for_pooled_projection:
+            # instantx sd3 controlnet used zero pooled projection
+            controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
+        else:
+            controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
+
+        if controlnet_config.joint_attention_dim is not None:
+            controlnet_encoder_hidden_states = prompt_embeds
+        else:
+            # SD35 official 8b controlnet does not use encoder_hidden_states
+            controlnet_encoder_hidden_states = None
+
         # 7. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -1029,7 +1063,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
                 control_block_samples = self.controlnet(
                     hidden_states=latent_model_input,
                     timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states=controlnet_encoder_hidden_states,
                     pooled_projections=controlnet_pooled_projections,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     controlnet_cond=control_image,
```
diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py:

```diff
@@ -26,7 +26,7 @@ from transformers import (
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
-from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
+from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
@@ -787,7 +787,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 28,
-        timesteps: List[int] = None,
+        sigmas: Optional[List[float]] = None,
         guidance_scale: float = 7.0,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
@@ -833,10 +833,10 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             guidance_scale (`float`, *optional*, defaults to 5.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -1033,7 +1033,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
         controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
 
         # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
```
diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py:

```diff
@@ -1595,7 +1595,7 @@ class DownBlockFlat(nn.Module):
         output_states = ()
 
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1732,7 +1732,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         blocks = list(zip(self.resnets, self.attentions))
 
         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1874,7 +1874,7 @@ class UpBlockFlat(nn.Module):
 
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -2033,7 +2033,7 @@ class CrossAttnUpBlockFlat(nn.Module):
 
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2223,12 +2223,35 @@ class UNetMidBlockFlat(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states, temb=temb)
-            hidden_states = resnet(hidden_states, temb)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = resnet(hidden_states, temb)
 
         return hidden_states
 
@@ -2352,7 +2375,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
 
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
```
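
Each hunk above swaps `self.training` for `torch.is_grad_enabled()` as the activation-checkpointing gate. The two are independent toggles in PyTorch, which is the point of the change; a small demonstration:

```python
import torch

lin = torch.nn.Linear(2, 2).eval()
# eval mode with gradients enabled, e.g. guidance-style optimization through a
# frozen model: the old `self.training` gate would skip checkpointing here.
print(lin.training, torch.is_grad_enabled())  # False True

with torch.no_grad():
    # grad mode off: checkpointing saves nothing, so the new gate skips it.
    print(torch.is_grad_enabled())  # False
```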
diffusers/pipelines/flux/__init__.py:

```diff
@@ -12,7 +12,7 @@ from ...utils import (
 
 _dummy_objects = {}
 _additional_imports = {}
-_import_structure = {"pipeline_output": ["FluxPipelineOutput"]}
+_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]}
 
 try:
     if not (is_transformers_available() and is_torch_available()):
@@ -22,12 +22,18 @@ except OptionalDependencyNotAvailable:
 
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
+    _import_structure["modeling_flux"] = ["ReduxImageEncoder"]
     _import_structure["pipeline_flux"] = ["FluxPipeline"]
+    _import_structure["pipeline_flux_control"] = ["FluxControlPipeline"]
+    _import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"]
+    _import_structure["pipeline_flux_control_inpaint"] = ["FluxControlInpaintPipeline"]
     _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
     _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
     _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
+    _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
     _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
     _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
+    _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
         if not (is_transformers_available() and is_torch_available()):
@@ -35,12 +41,18 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     except OptionalDependencyNotAvailable:
         from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
+        from .modeling_flux import ReduxImageEncoder
         from .pipeline_flux import FluxPipeline
+        from .pipeline_flux_control import FluxControlPipeline
+        from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline
+        from .pipeline_flux_control_inpaint import FluxControlInpaintPipeline
         from .pipeline_flux_controlnet import FluxControlNetPipeline
         from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
         from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
+        from .pipeline_flux_fill import FluxFillPipeline
         from .pipeline_flux_img2img import FluxImg2ImgPipeline
         from .pipeline_flux_inpaint import FluxInpaintPipeline
+        from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
 else:
     import sys
 
```
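
A sketch of the two-stage Redux workflow these registrations enable, modeled on the diffusers documentation; the repo IDs are the official Black Forest Labs checkpoints and the image URL is a placeholder:

```python
import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image

# Stage 1: the Redux prior maps a reference image into Flux's text-embedding space.
pipe_prior = FluxPriorReduxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")
# Stage 2: the base pipeline consumes those embeddings; its text encoders can be dropped.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=torch.bfloat16,
).to("cuda")

image = load_image("https://example.com/style_reference.png")  # placeholder URL
prior_output = pipe_prior(image)  # yields prompt_embeds / pooled_prompt_embeds
result = pipe(**prior_output, guidance_scale=2.5, num_inference_steps=50).images[0]
```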
diffusers/pipelines/flux/modeling_flux.py (new file):

```diff
@@ -0,0 +1,47 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.modeling_utils import ModelMixin
+from ...utils import BaseOutput
+
+
+@dataclass
+class ReduxImageEncoderOutput(BaseOutput):
+    image_embeds: Optional[torch.Tensor] = None
+
+
+class ReduxImageEncoder(ModelMixin, ConfigMixin):
+    @register_to_config
+    def __init__(
+        self,
+        redux_dim: int = 1152,
+        txt_in_features: int = 4096,
+    ) -> None:
+        super().__init__()
+
+        self.redux_up = nn.Linear(redux_dim, txt_in_features * 3)
+        self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features)
+
+    def forward(self, x: torch.Tensor) -> ReduxImageEncoderOutput:
+        projected_x = self.redux_down(nn.functional.silu(self.redux_up(x)))
+
+        return ReduxImageEncoderOutput(image_embeds=projected_x)
```
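
Shape-wise, `ReduxImageEncoder` is a plain expand-SiLU-contract MLP. The defaults line up with SigLIP image features (1152) on the input side and T5-XXL text features (4096) on the output side, which is presumably how Redux injects image content through Flux's text stream; that interpretation and the token count below are assumptions. A standalone shape check:

```python
import torch
import torch.nn as nn

redux_dim, txt_in_features = 1152, 4096  # defaults from the class above
redux_up = nn.Linear(redux_dim, txt_in_features * 3)
redux_down = nn.Linear(txt_in_features * 3, txt_in_features)

x = torch.randn(1, 729, redux_dim)  # e.g. a 27x27 grid of image tokens (assumed)
image_embeds = redux_down(nn.functional.silu(redux_up(x)))
print(image_embeds.shape)  # torch.Size([1, 729, 4096]) -- same width as T5 text tokens
```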