diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +94 -3
- diffusers/commands/env.py +1 -5
- diffusers/configuration_utils.py +4 -9
- diffusers/dependency_versions_table.py +2 -2
- diffusers/image_processor.py +1 -2
- diffusers/loaders/__init__.py +17 -2
- diffusers/loaders/ip_adapter.py +10 -7
- diffusers/loaders/lora_base.py +752 -0
- diffusers/loaders/lora_pipeline.py +2222 -0
- diffusers/loaders/peft.py +213 -5
- diffusers/loaders/single_file.py +1 -12
- diffusers/loaders/single_file_model.py +31 -10
- diffusers/loaders/single_file_utils.py +262 -2
- diffusers/loaders/textual_inversion.py +1 -6
- diffusers/loaders/unet.py +23 -208
- diffusers/models/__init__.py +20 -0
- diffusers/models/activations.py +22 -0
- diffusers/models/attention.py +386 -7
- diffusers/models/attention_processor.py +1795 -629
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_kl.py +14 -3
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vq_model.py +4 -4
- diffusers/models/controlnet.py +2 -3
- diffusers/models/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnet_sd3.py +11 -11
- diffusers/models/controlnet_sparsectrl.py +789 -0
- diffusers/models/controlnet_xs.py +40 -10
- diffusers/models/downsampling.py +68 -0
- diffusers/models/embeddings.py +319 -36
- diffusers/models/model_loading_utils.py +1 -3
- diffusers/models/modeling_flax_utils.py +1 -6
- diffusers/models/modeling_utils.py +4 -16
- diffusers/models/normalization.py +203 -12
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +102 -1
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/transformer_flux.py +455 -0
- diffusers/models/transformers/transformer_sd3.py +18 -4
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +8 -1
- diffusers/models/unets/unet_3d_blocks.py +51 -920
- diffusers/models/unets/unet_3d_condition.py +4 -1
- diffusers/models/unets/unet_i2vgen_xl.py +4 -1
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +1330 -84
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +1 -3
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +64 -0
- diffusers/models/vq_model.py +8 -4
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +100 -3
- diffusers/pipelines/animatediff/__init__.py +4 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
- diffusers/pipelines/auto_pipeline.py +97 -19
- diffusers/pipelines/cogvideo/__init__.py +48 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +749 -0
- diffusers/pipelines/flux/pipeline_output.py +21 -0
- diffusers/pipelines/free_init_utils.py +2 -0
- diffusers/pipelines/free_noise_utils.py +236 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +334 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
- diffusers/pipelines/pag/__init__.py +67 -0
- diffusers/pipelines/pag/pag_utils.py +237 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
- diffusers/pipelines/pia/pipeline_pia.py +30 -37
- diffusers/pipelines/pipeline_flax_utils.py +4 -9
- diffusers/pipelines/pipeline_loading_utils.py +0 -3
- diffusers/pipelines/pipeline_utils.py +2 -14
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
- diffusers/schedulers/__init__.py +8 -0
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +1 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
- diffusers/schedulers/scheduling_ddpm.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +2 -2
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
- diffusers/schedulers/scheduling_ipndm.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
- diffusers/schedulers/scheduling_utils.py +1 -3
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/training_utils.py +99 -14
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +210 -0
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
- diffusers/utils/dynamic_modules_utils.py +1 -11
- diffusers/utils/export_utils.py +1 -4
- diffusers/utils/hub_utils.py +45 -42
- diffusers/utils/import_utils.py +19 -16
- diffusers/utils/loading_utils.py +76 -3
- diffusers/utils/testing_utils.py +11 -8
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1728
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
diffusers/models/__init__.py
CHANGED
@@ -28,22 +28,32 @@ if is_torch_available():
|
|
28
28
|
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
|
29
29
|
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
|
30
30
|
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
|
31
|
+
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
|
31
32
|
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
33
|
+
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
|
32
34
|
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
33
35
|
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
34
36
|
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
35
37
|
_import_structure["controlnet"] = ["ControlNetModel"]
|
38
|
+
_import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
|
36
39
|
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
|
40
|
+
_import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
|
37
41
|
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
|
38
42
|
_import_structure["embeddings"] = ["ImageProjection"]
|
39
43
|
_import_structure["modeling_utils"] = ["ModelMixin"]
|
44
|
+
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
|
45
|
+
_import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]
|
40
46
|
_import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
|
41
47
|
_import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
|
42
48
|
_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
|
49
|
+
_import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"]
|
50
|
+
_import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
|
43
51
|
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
|
44
52
|
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
|
53
|
+
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
|
45
54
|
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
|
46
55
|
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
|
56
|
+
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
47
57
|
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
48
58
|
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
49
59
|
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
|
@@ -69,23 +79,33 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|
69
79
|
from .autoencoders import (
|
70
80
|
AsymmetricAutoencoderKL,
|
71
81
|
AutoencoderKL,
|
82
|
+
AutoencoderKLCogVideoX,
|
72
83
|
AutoencoderKLTemporalDecoder,
|
84
|
+
AutoencoderOobleck,
|
73
85
|
AutoencoderTiny,
|
74
86
|
ConsistencyDecoderVAE,
|
75
87
|
VQModel,
|
76
88
|
)
|
77
89
|
from .controlnet import ControlNetModel
|
90
|
+
from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
|
78
91
|
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
|
92
|
+
from .controlnet_sparsectrl import SparseControlNetModel
|
79
93
|
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
|
80
94
|
from .embeddings import ImageProjection
|
81
95
|
from .modeling_utils import ModelMixin
|
82
96
|
from .transformers import (
|
97
|
+
AuraFlowTransformer2DModel,
|
98
|
+
CogVideoXTransformer3DModel,
|
83
99
|
DiTTransformer2DModel,
|
84
100
|
DualTransformer2DModel,
|
101
|
+
FluxTransformer2DModel,
|
85
102
|
HunyuanDiT2DModel,
|
103
|
+
LatteTransformer3DModel,
|
104
|
+
LuminaNextDiT2DModel,
|
86
105
|
PixArtTransformer2DModel,
|
87
106
|
PriorTransformer,
|
88
107
|
SD3Transformer2DModel,
|
108
|
+
StableAudioDiTModel,
|
89
109
|
T5FilmDecoder,
|
90
110
|
Transformer2DModel,
|
91
111
|
TransformerTemporalModel,
|
diffusers/models/activations.py
CHANGED
@@ -123,6 +123,28 @@ class GEGLU(nn.Module):
|
|
123
123
|
return hidden_states * self.gelu(gate)
|
124
124
|
|
125
125
|
|
126
|
+
class SwiGLU(nn.Module):
|
127
|
+
r"""
|
128
|
+
A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU`
|
129
|
+
but uses SiLU / Swish instead of GeLU.
|
130
|
+
|
131
|
+
Parameters:
|
132
|
+
dim_in (`int`): The number of channels in the input.
|
133
|
+
dim_out (`int`): The number of channels in the output.
|
134
|
+
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
135
|
+
"""
|
136
|
+
|
137
|
+
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
|
138
|
+
super().__init__()
|
139
|
+
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
|
140
|
+
self.activation = nn.SiLU()
|
141
|
+
|
142
|
+
def forward(self, hidden_states):
|
143
|
+
hidden_states = self.proj(hidden_states)
|
144
|
+
hidden_states, gate = hidden_states.chunk(2, dim=-1)
|
145
|
+
return hidden_states * self.activation(gate)
|
146
|
+
|
147
|
+
|
126
148
|
class ApproximateGELU(nn.Module):
|
127
149
|
r"""
|
128
150
|
The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
|
diffusers/models/attention.py
CHANGED
@@ -11,7 +11,7 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
from typing import Any, Dict, Optional
|
14
|
+
from typing import Any, Dict, List, Optional, Tuple
|
15
15
|
|
16
16
|
import torch
|
17
17
|
import torch.nn.functional as F
|
@@ -19,7 +19,7 @@ from torch import nn
|
|
19
19
|
|
20
20
|
from ..utils import deprecate, logging
|
21
21
|
from ..utils.torch_utils import maybe_allow_in_graph
|
22
|
-
from .activations import GEGLU, GELU, ApproximateGELU
|
22
|
+
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
|
23
23
|
from .attention_processor import Attention, JointAttnProcessor2_0
|
24
24
|
from .embeddings import SinusoidalPositionalEmbedding
|
25
25
|
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
|
@@ -128,9 +128,9 @@ class JointTransformerBlock(nn.Module):
|
|
128
128
|
query_dim=dim,
|
129
129
|
cross_attention_dim=None,
|
130
130
|
added_kv_proj_dim=dim,
|
131
|
-
dim_head=attention_head_dim
|
131
|
+
dim_head=attention_head_dim,
|
132
132
|
heads=num_attention_heads,
|
133
|
-
out_dim=
|
133
|
+
out_dim=dim,
|
134
134
|
context_pre_only=context_pre_only,
|
135
135
|
bias=True,
|
136
136
|
processor=processor,
|
@@ -272,6 +272,17 @@ class BasicTransformerBlock(nn.Module):
|
|
272
272
|
attention_out_bias: bool = True,
|
273
273
|
):
|
274
274
|
super().__init__()
|
275
|
+
self.dim = dim
|
276
|
+
self.num_attention_heads = num_attention_heads
|
277
|
+
self.attention_head_dim = attention_head_dim
|
278
|
+
self.dropout = dropout
|
279
|
+
self.cross_attention_dim = cross_attention_dim
|
280
|
+
self.activation_fn = activation_fn
|
281
|
+
self.attention_bias = attention_bias
|
282
|
+
self.double_self_attention = double_self_attention
|
283
|
+
self.norm_elementwise_affine = norm_elementwise_affine
|
284
|
+
self.positional_embeddings = positional_embeddings
|
285
|
+
self.num_positional_embeddings = num_positional_embeddings
|
275
286
|
self.only_cross_attention = only_cross_attention
|
276
287
|
|
277
288
|
# We keep these boolean flags for backward-compatibility.
|
@@ -359,7 +370,10 @@ class BasicTransformerBlock(nn.Module):
|
|
359
370
|
out_bias=attention_out_bias,
|
360
371
|
) # is self-attn if encoder_hidden_states is none
|
361
372
|
else:
|
362
|
-
|
373
|
+
if norm_type == "ada_norm_single": # For Latte
|
374
|
+
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
375
|
+
else:
|
376
|
+
self.norm2 = None
|
363
377
|
self.attn2 = None
|
364
378
|
|
365
379
|
# 3. Feed-forward
|
@@ -373,7 +387,7 @@ class BasicTransformerBlock(nn.Module):
|
|
373
387
|
"layer_norm",
|
374
388
|
)
|
375
389
|
|
376
|
-
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"
|
390
|
+
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
|
377
391
|
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
378
392
|
elif norm_type == "layer_norm_i2vgen":
|
379
393
|
self.norm3 = None
|
@@ -439,7 +453,6 @@ class BasicTransformerBlock(nn.Module):
|
|
439
453
|
).chunk(6, dim=1)
|
440
454
|
norm_hidden_states = self.norm1(hidden_states)
|
441
455
|
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
442
|
-
norm_hidden_states = norm_hidden_states.squeeze(1)
|
443
456
|
else:
|
444
457
|
raise ValueError("Incorrect norm used")
|
445
458
|
|
@@ -456,6 +469,7 @@ class BasicTransformerBlock(nn.Module):
|
|
456
469
|
attention_mask=attention_mask,
|
457
470
|
**cross_attention_kwargs,
|
458
471
|
)
|
472
|
+
|
459
473
|
if self.norm_type == "ada_norm_zero":
|
460
474
|
attn_output = gate_msa.unsqueeze(1) * attn_output
|
461
475
|
elif self.norm_type == "ada_norm_single":
|
@@ -527,6 +541,56 @@ class BasicTransformerBlock(nn.Module):
|
|
527
541
|
return hidden_states
|
528
542
|
|
529
543
|
|
544
|
+
class LuminaFeedForward(nn.Module):
|
545
|
+
r"""
|
546
|
+
A feed-forward layer.
|
547
|
+
|
548
|
+
Parameters:
|
549
|
+
hidden_size (`int`):
|
550
|
+
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
|
551
|
+
hidden representations.
|
552
|
+
intermediate_size (`int`): The intermediate dimension of the feedforward layer.
|
553
|
+
multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
|
554
|
+
of this value.
|
555
|
+
ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
|
556
|
+
dimension. Defaults to None.
|
557
|
+
"""
|
558
|
+
|
559
|
+
def __init__(
|
560
|
+
self,
|
561
|
+
dim: int,
|
562
|
+
inner_dim: int,
|
563
|
+
multiple_of: Optional[int] = 256,
|
564
|
+
ffn_dim_multiplier: Optional[float] = None,
|
565
|
+
):
|
566
|
+
super().__init__()
|
567
|
+
inner_dim = int(2 * inner_dim / 3)
|
568
|
+
# custom hidden_size factor multiplier
|
569
|
+
if ffn_dim_multiplier is not None:
|
570
|
+
inner_dim = int(ffn_dim_multiplier * inner_dim)
|
571
|
+
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
|
572
|
+
|
573
|
+
self.linear_1 = nn.Linear(
|
574
|
+
dim,
|
575
|
+
inner_dim,
|
576
|
+
bias=False,
|
577
|
+
)
|
578
|
+
self.linear_2 = nn.Linear(
|
579
|
+
inner_dim,
|
580
|
+
dim,
|
581
|
+
bias=False,
|
582
|
+
)
|
583
|
+
self.linear_3 = nn.Linear(
|
584
|
+
dim,
|
585
|
+
inner_dim,
|
586
|
+
bias=False,
|
587
|
+
)
|
588
|
+
self.silu = FP32SiLU()
|
589
|
+
|
590
|
+
def forward(self, x):
|
591
|
+
return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
|
592
|
+
|
593
|
+
|
530
594
|
@maybe_allow_in_graph
|
531
595
|
class TemporalBasicTransformerBlock(nn.Module):
|
532
596
|
r"""
|
@@ -729,6 +793,319 @@ class SkipFFTransformerBlock(nn.Module):
|
|
729
793
|
return hidden_states
|
730
794
|
|
731
795
|
|
796
|
+
@maybe_allow_in_graph
|
797
|
+
class FreeNoiseTransformerBlock(nn.Module):
|
798
|
+
r"""
|
799
|
+
A FreeNoise Transformer block.
|
800
|
+
|
801
|
+
Parameters:
|
802
|
+
dim (`int`):
|
803
|
+
The number of channels in the input and output.
|
804
|
+
num_attention_heads (`int`):
|
805
|
+
The number of heads to use for multi-head attention.
|
806
|
+
attention_head_dim (`int`):
|
807
|
+
The number of channels in each head.
|
808
|
+
dropout (`float`, *optional*, defaults to 0.0):
|
809
|
+
The dropout probability to use.
|
810
|
+
cross_attention_dim (`int`, *optional*):
|
811
|
+
The size of the encoder_hidden_states vector for cross attention.
|
812
|
+
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
813
|
+
Activation function to be used in feed-forward.
|
814
|
+
num_embeds_ada_norm (`int`, *optional*):
|
815
|
+
The number of diffusion steps used during training. See `Transformer2DModel`.
|
816
|
+
attention_bias (`bool`, defaults to `False`):
|
817
|
+
Configure if the attentions should contain a bias parameter.
|
818
|
+
only_cross_attention (`bool`, defaults to `False`):
|
819
|
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
820
|
+
double_self_attention (`bool`, defaults to `False`):
|
821
|
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
822
|
+
upcast_attention (`bool`, defaults to `False`):
|
823
|
+
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
824
|
+
norm_elementwise_affine (`bool`, defaults to `True`):
|
825
|
+
Whether to use learnable elementwise affine parameters for normalization.
|
826
|
+
norm_type (`str`, defaults to `"layer_norm"`):
|
827
|
+
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
828
|
+
final_dropout (`bool` defaults to `False`):
|
829
|
+
Whether to apply a final dropout after the last feed-forward layer.
|
830
|
+
attention_type (`str`, defaults to `"default"`):
|
831
|
+
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
832
|
+
positional_embeddings (`str`, *optional*):
|
833
|
+
The type of positional embeddings to apply to.
|
834
|
+
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
835
|
+
The maximum number of positional embeddings to apply.
|
836
|
+
ff_inner_dim (`int`, *optional*):
|
837
|
+
Hidden dimension of feed-forward MLP.
|
838
|
+
ff_bias (`bool`, defaults to `True`):
|
839
|
+
Whether or not to use bias in feed-forward MLP.
|
840
|
+
attention_out_bias (`bool`, defaults to `True`):
|
841
|
+
Whether or not to use bias in attention output project layer.
|
842
|
+
context_length (`int`, defaults to `16`):
|
843
|
+
The maximum number of frames that the FreeNoise block processes at once.
|
844
|
+
context_stride (`int`, defaults to `4`):
|
845
|
+
The number of frames to be skipped before starting to process a new batch of `context_length` frames.
|
846
|
+
weighting_scheme (`str`, defaults to `"pyramid"`):
|
847
|
+
The weighting scheme to use for weighting averaging of processed latent frames. As described in the
|
848
|
+
Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
|
849
|
+
used.
|
850
|
+
"""
|
851
|
+
|
852
|
+
def __init__(
|
853
|
+
self,
|
854
|
+
dim: int,
|
855
|
+
num_attention_heads: int,
|
856
|
+
attention_head_dim: int,
|
857
|
+
dropout: float = 0.0,
|
858
|
+
cross_attention_dim: Optional[int] = None,
|
859
|
+
activation_fn: str = "geglu",
|
860
|
+
num_embeds_ada_norm: Optional[int] = None,
|
861
|
+
attention_bias: bool = False,
|
862
|
+
only_cross_attention: bool = False,
|
863
|
+
double_self_attention: bool = False,
|
864
|
+
upcast_attention: bool = False,
|
865
|
+
norm_elementwise_affine: bool = True,
|
866
|
+
norm_type: str = "layer_norm",
|
867
|
+
norm_eps: float = 1e-5,
|
868
|
+
final_dropout: bool = False,
|
869
|
+
positional_embeddings: Optional[str] = None,
|
870
|
+
num_positional_embeddings: Optional[int] = None,
|
871
|
+
ff_inner_dim: Optional[int] = None,
|
872
|
+
ff_bias: bool = True,
|
873
|
+
attention_out_bias: bool = True,
|
874
|
+
context_length: int = 16,
|
875
|
+
context_stride: int = 4,
|
876
|
+
weighting_scheme: str = "pyramid",
|
877
|
+
):
|
878
|
+
super().__init__()
|
879
|
+
self.dim = dim
|
880
|
+
self.num_attention_heads = num_attention_heads
|
881
|
+
self.attention_head_dim = attention_head_dim
|
882
|
+
self.dropout = dropout
|
883
|
+
self.cross_attention_dim = cross_attention_dim
|
884
|
+
self.activation_fn = activation_fn
|
885
|
+
self.attention_bias = attention_bias
|
886
|
+
self.double_self_attention = double_self_attention
|
887
|
+
self.norm_elementwise_affine = norm_elementwise_affine
|
888
|
+
self.positional_embeddings = positional_embeddings
|
889
|
+
self.num_positional_embeddings = num_positional_embeddings
|
890
|
+
self.only_cross_attention = only_cross_attention
|
891
|
+
|
892
|
+
self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
|
893
|
+
|
894
|
+
# We keep these boolean flags for backward-compatibility.
|
895
|
+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
896
|
+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
897
|
+
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
898
|
+
self.use_layer_norm = norm_type == "layer_norm"
|
899
|
+
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
900
|
+
|
901
|
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
902
|
+
raise ValueError(
|
903
|
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
904
|
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
905
|
+
)
|
906
|
+
|
907
|
+
self.norm_type = norm_type
|
908
|
+
self.num_embeds_ada_norm = num_embeds_ada_norm
|
909
|
+
|
910
|
+
if positional_embeddings and (num_positional_embeddings is None):
|
911
|
+
raise ValueError(
|
912
|
+
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
913
|
+
)
|
914
|
+
|
915
|
+
if positional_embeddings == "sinusoidal":
|
916
|
+
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
917
|
+
else:
|
918
|
+
self.pos_embed = None
|
919
|
+
|
920
|
+
# Define 3 blocks. Each block has its own normalization layer.
|
921
|
+
# 1. Self-Attn
|
922
|
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
923
|
+
|
924
|
+
self.attn1 = Attention(
|
925
|
+
query_dim=dim,
|
926
|
+
heads=num_attention_heads,
|
927
|
+
dim_head=attention_head_dim,
|
928
|
+
dropout=dropout,
|
929
|
+
bias=attention_bias,
|
930
|
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
931
|
+
upcast_attention=upcast_attention,
|
932
|
+
out_bias=attention_out_bias,
|
933
|
+
)
|
934
|
+
|
935
|
+
# 2. Cross-Attn
|
936
|
+
if cross_attention_dim is not None or double_self_attention:
|
937
|
+
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
938
|
+
|
939
|
+
self.attn2 = Attention(
|
940
|
+
query_dim=dim,
|
941
|
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
942
|
+
heads=num_attention_heads,
|
943
|
+
dim_head=attention_head_dim,
|
944
|
+
dropout=dropout,
|
945
|
+
bias=attention_bias,
|
946
|
+
upcast_attention=upcast_attention,
|
947
|
+
out_bias=attention_out_bias,
|
948
|
+
) # is self-attn if encoder_hidden_states is none
|
949
|
+
|
950
|
+
# 3. Feed-forward
|
951
|
+
self.ff = FeedForward(
|
952
|
+
dim,
|
953
|
+
dropout=dropout,
|
954
|
+
activation_fn=activation_fn,
|
955
|
+
final_dropout=final_dropout,
|
956
|
+
inner_dim=ff_inner_dim,
|
957
|
+
bias=ff_bias,
|
958
|
+
)
|
959
|
+
|
960
|
+
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
961
|
+
|
962
|
+
# let chunk size default to None
|
963
|
+
self._chunk_size = None
|
964
|
+
self._chunk_dim = 0
|
965
|
+
|
966
|
+
def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
    """Return the `(start, end)` frame windows covered by the sliding context.

    Window starts begin at frame 0 and advance by `self.context_stride`, stopping
    before `num_frames - self.context_length + 1`. Each window end is
    `start + self.context_length`, clamped to `num_frames`.

    Args:
        num_frames: Total number of frames to cover.

    Returns:
        A list of `(window_start, window_end)` tuples; empty when no window start
        falls inside the range described above.
    """
    stop = num_frames - self.context_length + 1
    return [
        (start, min(num_frames, start + self.context_length))
        for start in range(0, stop, self.context_stride)
    ]
|
973
|
+
|
974
|
+
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
    """Build the per-frame blending weights for one context window.

    Only the `"pyramid"` scheme is implemented: weights rise from 1 towards the
    centre of the window and fall back to 1, e.g.

        num_frames = 4 => [1, 2, 2, 1]
        num_frames = 5 => [1, 2, 3, 2, 1]

    Args:
        num_frames: Number of frames in the window.
        weighting_scheme: Name of the weighting scheme.

    Returns:
        A list with one weight per frame.

    Raises:
        ValueError: If `weighting_scheme` is not `"pyramid"`.
    """
    if weighting_scheme != "pyramid":
        raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")

    half = num_frames // 2
    ramp = list(range(1, half + 1))
    # An odd-length window has a single peak value in the middle; an even-length one does not.
    peak = [half + 1] if num_frames % 2 != 0 else []
    return ramp + peak + ramp[::-1]
|
988
|
+
|
989
|
+
def set_free_noise_properties(
    self,
    context_length: int,
    context_stride: int,
    weighting_scheme: str = "pyramid",
) -> None:
    """Store the sliding-window settings that `forward` and the frame helpers read.

    Args:
        context_length: Number of frames in each window.
        context_stride: Step, in frames, between consecutive window starts.
        weighting_scheme: Name of the per-frame weighting scheme.
    """
    settings = (
        ("context_length", context_length),
        ("context_stride", context_stride),
        ("weighting_scheme", weighting_scheme),
    )
    for attr_name, value in settings:
        setattr(self, attr_name, value)
|
995
|
+
|
996
|
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
    """Configure chunking of the feed-forward pass.

    Args:
        chunk_size: Chunk size to record; `forward` only takes the chunked
            feed-forward path when this is not `None`.
        dim: Dimension recorded as the chunking dimension.
    """
    self._chunk_size, self._chunk_dim = chunk_size, dim
|
1000
|
+
|
1001
|
+
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Run windowed attention over the frame axis, blend the windows, then apply the feed-forward.

    Attention (self-attention and, when present, cross-attention) is applied
    independently to each frame window returned by `self._get_frame_indices`.
    Window outputs are combined per frame as a weighted average using the
    weights from `self._get_frame_weights`. The feed-forward block then runs
    once on the blended result.

    Args:
        hidden_states: Input tensor; the number of frames is read from dim 1.
        attention_mask: Mask forwarded to the first attention layer.
        encoder_hidden_states: Conditioning forwarded to the second attention
            layer (and to the first one when `self.only_cross_attention` is set).
        encoder_attention_mask: Mask forwarded to the second attention layer.
        cross_attention_kwargs: Extra keyword arguments forwarded to both
            attention layers. The dict is copied, never mutated.
        *args: Accepted for signature compatibility; unused.
        **kwargs: Accepted for signature compatibility; unused.

    Returns:
        A tensor with the same dtype as `hidden_states`.

    Raises:
        ValueError: If the number of frames is smaller than `self.context_length`.
    """
    if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
        logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

    cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}

    # hidden_states: [B x H x W, F, C]
    device = hidden_states.device
    dtype = hidden_states.dtype
    num_frames = hidden_states.size(1)

    # Validate before indexing the window list: with fewer frames than one context window the
    # list is empty, and `frame_indices[-1]` would raise an IndexError instead of this error.
    if num_frames < self.context_length:
        raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")

    frame_indices = self._get_frame_indices(num_frames)
    frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
    frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
    is_last_frame_batch_complete = frame_indices[-1][1] == num_frames

    # Handle the case where the strided windows stop short of the final frame.
    # For example, num_frames=25, context_length=16, context_stride=4 gives the windows
    # [(0, 16), (4, 20), (8, 24)], so one extra window aligned to the end is appended:
    # [(0, 16), (4, 20), (8, 24), (9, 25)]
    if not is_last_frame_batch_complete:
        # Number of trailing frames that no regular window reached.
        last_frame_batch_length = num_frames - frame_indices[-1][1]
        frame_indices.append((num_frames - self.context_length, num_frames))

    num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
    accumulated_values = torch.zeros_like(hidden_states)

    # `attn2` is only assigned by the constructor when cross-attention (or double self-attention)
    # is requested, so look it up defensively rather than assuming the attribute exists.
    attn2 = getattr(self, "attn2", None)
    last_window_index = len(frame_indices) - 1

    for i, (frame_start, frame_end) in enumerate(frame_indices):
        # Per-frame blending weights for this window, shaped like the matching slice of the
        # accumulation counter so they broadcast over the batch and channel dims.
        weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
        weights *= frame_weights

        hidden_states_chunk = hidden_states[:, frame_start:frame_end]

        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        norm_hidden_states = self.norm1(hidden_states_chunk)

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        hidden_states_chunk = attn_output + hidden_states_chunk
        if hidden_states_chunk.ndim == 4:
            hidden_states_chunk = hidden_states_chunk.squeeze(1)

        # 2. Cross-Attention
        if attn2 is not None:
            norm_hidden_states = self.norm2(hidden_states_chunk)

            if self.pos_embed is not None and self.norm_type != "ada_norm_single":
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states_chunk = attn_output + hidden_states_chunk

        if i == last_window_index and not is_last_frame_batch_complete:
            # Only the trailing frames of the appended window are new; accumulate just those.
            # Numerator and denominator must use the same per-frame weight slice, otherwise the
            # weighted average is wrong whenever the tail is longer than one frame.
            tail_weights = weights[:, -last_frame_batch_length:]
            accumulated_values[:, -last_frame_batch_length:] += (
                hidden_states_chunk[:, -last_frame_batch_length:] * tail_weights
            )
            num_times_accumulated[:, -last_frame_batch_length:] += tail_weights
        else:
            accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
            num_times_accumulated[:, frame_start:frame_end] += weights

    # Weighted average per frame; frames with a zero weight sum keep their raw accumulated value.
    hidden_states = torch.where(
        num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
    ).to(dtype)

    # 3. Feed-forward
    norm_hidden_states = self.norm3(hidden_states)

    if self._chunk_size is not None:
        ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
    else:
        ff_output = self.ff(norm_hidden_states)

    hidden_states = ff_output + hidden_states
    if hidden_states.ndim == 4:
        hidden_states = hidden_states.squeeze(1)

    return hidden_states
|
1107
|
+
|
1108
|
+
|
732
1109
|
class FeedForward(nn.Module):
|
733
1110
|
r"""
|
734
1111
|
A feed-forward layer.
|
@@ -767,6 +1144,8 @@ class FeedForward(nn.Module):
|
|
767
1144
|
act_fn = GEGLU(dim, inner_dim, bias=bias)
|
768
1145
|
elif activation_fn == "geglu-approximate":
|
769
1146
|
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
|
1147
|
+
elif activation_fn == "swiglu":
|
1148
|
+
act_fn = SwiGLU(dim, inner_dim, bias=bias)
|
770
1149
|
|
771
1150
|
self.net = nn.ModuleList([])
|
772
1151
|
# project in
|