diffusers-0.30.3-py3-none-any.whl → diffusers-0.32.0-py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
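The listing above also introduces a new `diffusers/quantizers/` package with bitsandbytes, GGUF and torchao backends. A minimal sketch of how the bitsandbytes backend might be used to quantize a model at load time (assumes `bitsandbytes` is installed and access to the `black-forest-labs/FLUX.1-dev` checkpoint; the snippet is illustrative and not taken from this diff):

```python
import torch

from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# transformers-style 4-bit NF4 config exposed by the new diffusers.quantizers package
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Quantize only the transformer sub-model at load time; other pipeline
# components can stay in bf16/fp16.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```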
`diffusers/pipelines/ddpm/pipeline_ddpm.py` — the initial noise is now drawn in the UNet's dtype:

```diff
@@ -101,10 +101,10 @@ class DDPMPipeline(DiffusionPipeline):

         if self.device.type == "mps":
             # randn does not work reproducibly on mps
-            image = randn_tensor(image_shape, generator=generator)
+            image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype)
             image = image.to(self.device)
         else:
-            image = randn_tensor(image_shape, generator=generator, device=self.device)
+            image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)

         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
```
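The practical effect is that fp16/bf16 DDPM pipelines now sample their starting noise directly in the UNet's dtype instead of float32. A minimal usage sketch (checkpoint name and step count are only illustrative):

```python
import torch

from diffusers import DDPMPipeline

# The initial noise is created with dtype=self.unet.dtype, so a half-precision
# pipeline no longer mixes float32 noise with fp16 UNet weights.
pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256", torch_dtype=torch.float16).to("cuda")
image = pipe(num_inference_steps=50).images[0]
image.save("ddpm_sample.png")
```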
`diffusers/pipelines/deepfloyd_if/pipeline_output.py` — the `IFPipelineOutput` docstring is reordered and fixed:

```diff
@@ -9,16 +9,17 @@ from ...utils import BaseOutput

 @dataclass
 class IFPipelineOutput(BaseOutput):
-    """
-    Args:
+    r"""
     Output class for Stable Diffusion pipelines.
-        images (`List[PIL.Image.Image]` or `np.ndarray`)
+
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`):
             List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
             num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
-        nsfw_detected (`List[bool]`)
+        nsfw_detected (`List[bool]`):
             List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content or a watermark. `None` if safety checking could not be performed.
-        watermark_detected (`List[bool]`)
+        watermark_detected (`List[bool]`):
             List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety
             checking could not be performed.
     """
```
`rescale_noise_cfg`, which is copied into several pipeline modules, gains a full docstring:

```diff
@@ -65,9 +65,21 @@ EXAMPLE_DOC_STRING = """

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
-    """
-    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
-    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    r"""
+    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+    Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+    Args:
+        noise_cfg (`torch.Tensor`):
+            The predicted noise tensor for the guided diffusion process.
+        noise_pred_text (`torch.Tensor`):
+            The predicted noise tensor for the text-guided diffusion process.
+        guidance_rescale (`float`, *optional*, defaults to 0.0):
+            A rescale factor applied to the noise predictions.
+
+    Returns:
+        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
     """
     std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
```
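Only the first two lines of the function body are visible above; for reference, a self-contained sketch of the computation the new docstring describes (it mirrors the two `std` lines shown and the formulation from the cited paper):

```python
import torch


def rescale_noise_cfg(noise_cfg: torch.Tensor, noise_pred_text: torch.Tensor, guidance_rescale: float = 0.0):
    # Per-sample standard deviation over all non-batch dimensions
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescale the CFG output so its statistics match the text-conditioned prediction (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Interpolate between rescaled and original predictions to avoid overly "plain" images
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
```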
Two pipeline modules get the same raw-string prefix fix on the `retrieve_timesteps` docstring:

```diff
@@ -87,7 +99,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

```
```diff
@@ -127,7 +127,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

```
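For context, a simplified sketch of what `retrieve_timesteps` does (the real helper also validates that the scheduler's `set_timesteps` accepts custom `timesteps`/`sigmas`; that check is omitted here):

```python
from typing import List, Optional

import torch


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[torch.device] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    # Prefer explicit timesteps, then explicit sigmas, then a plain step count.
    if timesteps is not None:
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)
```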
The next seven hunks are from `diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py`. The first tidies an error message (the removed line is truncated in the source extract):

```diff
@@ -546,7 +546,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             )
         elif encoder_hid_dim_type is not None:
             raise ValueError(
-                f"encoder_hid_dim_type
+                f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
             )
         else:
             self.encoder_hid_proj = None
```
Gradient checkpointing in the Flat UNet blocks is now gated on `torch.is_grad_enabled()` instead of `self.training`:

```diff
@@ -1595,7 +1595,7 @@ class DownBlockFlat(nn.Module):
         output_states = ()

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
```
```diff
@@ -1732,7 +1732,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         blocks = list(zip(self.resnets, self.attentions))

         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
```
```diff
@@ -1874,7 +1874,7 @@ class UpBlockFlat(nn.Module):

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
```
```diff
@@ -2033,7 +2033,7 @@ class CrossAttnUpBlockFlat(nn.Module):

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
```
`UNetMidBlockFlat` additionally gains a checkpointed forward path:

```diff
@@ -2223,12 +2223,35 @@ class UNetMidBlockFlat(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states, temb=temb)
-            hidden_states = resnet(hidden_states, temb)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = resnet(hidden_states, temb)

         return hidden_states

```
```diff
@@ -2352,7 +2375,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):

         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
```
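All of these blocks now gate gradient checkpointing on the autograd state rather than the module's `training` flag. A small sketch of the behavioral difference (the helper name is illustrative, not from the diff):

```python
import torch
import torch.nn as nn


def should_checkpoint(block: nn.Module) -> bool:
    # New condition: checkpointing engages whenever autograd is recording, even for
    # a block in eval() mode (e.g. a frozen sub-model inside a training loop), and is
    # skipped entirely under torch.no_grad() / torch.inference_mode().
    return torch.is_grad_enabled() and getattr(block, "gradient_checkpointing", False)


block = nn.Linear(4, 4)  # stand-in for a UNet block
block.gradient_checkpointing = True
block.eval()

print(should_checkpoint(block))  # True: autograd is on even though block.training is False
with torch.no_grad():
    print(should_checkpoint(block))  # False: nothing is recorded, so no recompute overhead
```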
`diffusers/pipelines/flux/__init__.py` — the new Flux pipelines and the Redux image encoder are registered for lazy import:

```diff
@@ -12,7 +12,7 @@ from ...utils import (

 _dummy_objects = {}
 _additional_imports = {}
-_import_structure = {"pipeline_output": ["FluxPipelineOutput"]}
+_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]}

 try:
     if not (is_transformers_available() and is_torch_available()):
@@ -22,7 +22,18 @@ except OptionalDependencyNotAvailable:

     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
+    _import_structure["modeling_flux"] = ["ReduxImageEncoder"]
     _import_structure["pipeline_flux"] = ["FluxPipeline"]
+    _import_structure["pipeline_flux_control"] = ["FluxControlPipeline"]
+    _import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"]
+    _import_structure["pipeline_flux_control_inpaint"] = ["FluxControlInpaintPipeline"]
+    _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
+    _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
+    _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
+    _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
+    _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
+    _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
+    _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
         if not (is_transformers_available() and is_torch_available()):
@@ -30,7 +41,18 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     except OptionalDependencyNotAvailable:
         from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
+        from .modeling_flux import ReduxImageEncoder
         from .pipeline_flux import FluxPipeline
+        from .pipeline_flux_control import FluxControlPipeline
+        from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline
+        from .pipeline_flux_control_inpaint import FluxControlInpaintPipeline
+        from .pipeline_flux_controlnet import FluxControlNetPipeline
+        from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
+        from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
+        from .pipeline_flux_fill import FluxFillPipeline
+        from .pipeline_flux_img2img import FluxImg2ImgPipeline
+        from .pipeline_flux_inpaint import FluxInpaintPipeline
+        from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
 else:
     import sys
```
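A hedged sketch of how the newly registered Redux pipelines are meant to be combined (repository ids, the dropped text encoders, and the prompt-embedding hand-off follow the pattern documented for FLUX.1 Redux; they are assumptions, not part of this diff):

```python
import torch

from diffusers import FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image

prior_redux = FluxPriorReduxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,  # prompt embeddings come from the Redux prior instead
    text_encoder_2=None,
    torch_dtype=torch.bfloat16,
).to("cuda")

init_image = load_image("https://example.com/reference.png")  # placeholder URL
prior_output = prior_redux(init_image)  # prompt_embeds / pooled_prompt_embeds
image = pipe(guidance_scale=2.5, num_inference_steps=50, **prior_output).images[0]
image.save("flux_redux_variation.png")
```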
`diffusers/pipelines/flux/modeling_flux.py` — a new module adding the `ReduxImageEncoder`:

```diff
@@ -0,0 +1,47 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.modeling_utils import ModelMixin
+from ...utils import BaseOutput
+
+
+@dataclass
+class ReduxImageEncoderOutput(BaseOutput):
+    image_embeds: Optional[torch.Tensor] = None
+
+
+class ReduxImageEncoder(ModelMixin, ConfigMixin):
+    @register_to_config
+    def __init__(
+        self,
+        redux_dim: int = 1152,
+        txt_in_features: int = 4096,
+    ) -> None:
+        super().__init__()
+
+        self.redux_up = nn.Linear(redux_dim, txt_in_features * 3)
+        self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features)
+
+    def forward(self, x: torch.Tensor) -> ReduxImageEncoderOutput:
+        projected_x = self.redux_down(nn.functional.silu(self.redux_up(x)))
+
+        return ReduxImageEncoderOutput(image_embeds=projected_x)
```