diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff compares publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
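
Most of the new surface in 0.32.0 is additive: a `diffusers/quantizers/` package (bitsandbytes, GGUF, and torchao backends), a `models/controlnets/` subpackage that absorbs the old top-level controlnet modules, and new pipeline families (Allegro, CogView3, HunyuanVideo, LTX, Mochi, Sana, plus Flux img2img/inpaint/fill/ControlNet variants). A minimal sketch of the new quantization entry point, assuming the top-level `BitsAndBytesConfig` export and the `quantization_config` argument behave as in the 0.32 documentation; the checkpoint name is illustrative and not taken from this diff:

```python
# Sketch only: assumes diffusers 0.32 exports BitsAndBytesConfig and that
# from_pretrained accepts quantization_config, per the new diffusers/quantizers/
# package listed above. Requires the bitsandbytes library to be installed.
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# 4-bit NF4 quantization via the new bitsandbytes backend
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint, not from this diff
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```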
diffusers/models/transformers/transformer_flux.py

```diff
--- a/diffusers/models/transformers/transformer_flux.py
+++ b/diffusers/models/transformers/transformer_flux.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved.
+# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,61 +13,35 @@
 # limitations under the License.
 
 
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
 from ...models.attention import FeedForward
-from ...models.attention_processor import Attention, FluxAttnProcessor2_0
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    FluxAttnProcessor2_0,
+    FluxAttnProcessor2_0_NPU,
+    FusedFluxAttnProcessor2_0,
+)
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
+from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-# YiYi to-do: refactor rope related functions/classes
-def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
-    assert dim % 2 == 0, "The dimension must be even."
-
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
-    omega = 1.0 / (theta**scale)
-
-    batch_size, seq_length = pos.shape
-    out = torch.einsum("...n,d->...nd", pos, omega)
-    cos_out = torch.cos(out)
-    sin_out = torch.sin(out)
-
-    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
-    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
-    return out.float()
-
-
-# YiYi to-do: refactor rope related functions/classes
-class EmbedND(nn.Module):
-    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
-        super().__init__()
-        self.dim = dim
-        self.theta = theta
-        self.axes_dim = axes_dim
-
-    def forward(self, ids: torch.Tensor) -> torch.Tensor:
-        n_axes = ids.shape[-1]
-        emb = torch.cat(
-            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
-            dim=-3,
-        )
-        return emb.unsqueeze(1)
-
-
 @maybe_allow_in_graph
 class FluxSingleTransformerBlock(nn.Module):
     r"""
@@ -92,7 +66,10 @@ class FluxSingleTransformerBlock(nn.Module):
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
-        processor = FluxAttnProcessor2_0()
+        if is_torch_npu_available():
+            processor = FluxAttnProcessor2_0_NPU()
+        else:
+            processor = FluxAttnProcessor2_0()
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -111,14 +88,16 @@ class FluxSingleTransformerBlock(nn.Module):
         hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )
 
         hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
@@ -189,20 +168,27 @@ class FluxTransformerBlock(nn.Module):
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
    ):
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         # Attention.
-        attn_output, context_attn_output = self.attn(
+        attention_outputs = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )
 
+        if len(attention_outputs) == 2:
+            attn_output, context_attn_output = attention_outputs
+        elif len(attention_outputs) == 3:
+            attn_output, context_attn_output, ip_attn_output = attention_outputs
+
         # Process attention outputs for the `hidden_states`.
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
@@ -214,6 +200,8 @@ class FluxTransformerBlock(nn.Module):
         ff_output = gate_mlp.unsqueeze(1) * ff_output
 
         hidden_states = hidden_states + ff_output
+        if len(attention_outputs) == 3:
+            hidden_states = hidden_states + ip_attn_output
 
         # Process attention outputs for the `encoder_hidden_states`.
 
@@ -231,7 +219,9 @@ class FluxTransformerBlock(nn.Module):
         return encoder_hidden_states, hidden_states
 
 
-class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+class FluxTransformer2DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
+):
     """
     The Transformer model introduced in Flux.
 
@@ -250,12 +240,14 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
 
     @register_to_config
     def __init__(
         self,
         patch_size: int = 1,
         in_channels: int = 64,
+        out_channels: Optional[int] = None,
         num_layers: int = 19,
         num_single_layers: int = 38,
         attention_head_dim: int = 128,
@@ -263,13 +255,14 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
         joint_attention_dim: int = 4096,
         pooled_projection_dim: int = 768,
         guidance_embeds: bool = False,
-        axes_dims_rope: List[int] = [16, 56, 56],
+        axes_dims_rope: Tuple[int] = (16, 56, 56),
     ):
         super().__init__()
-        self.out_channels = in_channels
+        self.out_channels = out_channels or in_channels
         self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
 
-        self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
+        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+
         text_time_guidance_cls = (
             CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
         )
@@ -278,7 +271,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
         )
 
         self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
-        self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
+        self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
 
         self.transformer_blocks = nn.ModuleList(
             [
@@ -307,6 +300,106 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
 
         self.gradient_checkpointing = False
 
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedFluxAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
@@ -321,7 +414,10 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
         txt_ids: torch.Tensor = None,
         guidance: torch.Tensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_block_samples=None,
+        controlnet_single_block_samples=None,
         return_dict: bool = True,
+        controlnet_blocks_repeat: bool = False,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
         The [`FluxTransformer2DModel`] forward method.
@@ -363,6 +459,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
                 logger.warning(
                     "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                 )
+
         hidden_states = self.x_embedder(hidden_states)
 
         timestep = timestep.to(hidden_states.dtype) * 1000
@@ -370,6 +467,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
             guidance = guidance.to(hidden_states.dtype) * 1000
         else:
             guidance = None
+
         temb = (
             self.time_text_embed(timestep, pooled_projections)
             if guidance is None
@@ -377,11 +475,29 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
         )
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
-        ids = torch.cat((txt_ids, img_ids), dim=1)
+        if txt_ids.ndim == 3:
+            logger.warning(
+                "Passing `txt_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            )
+            txt_ids = txt_ids[0]
+        if img_ids.ndim == 3:
+            logger.warning(
+                "Passing `img_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            )
+            img_ids = img_ids[0]
+
+        ids = torch.cat((txt_ids, img_ids), dim=0)
         image_rotary_emb = self.pos_embed(ids)
 
+        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -408,12 +524,24 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )
 
+            # controlnet residual
+            if controlnet_block_samples is not None:
+                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+                interval_control = int(np.ceil(interval_control))
+                # For Xlabs ControlNet.
+                if controlnet_blocks_repeat:
+                    hidden_states = (
+                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+                    )
+                else:
+                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -438,6 +566,16 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
                     hidden_states=hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
+                )
+
+            # controlnet residual
+            if controlnet_single_block_samples is not None:
+                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+                interval_control = int(np.ceil(interval_control))
+                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+                    + controlnet_single_block_samples[index_block // interval_control]
                 )
 
         hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
```