diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/models/normalization.py

@@ -22,10 +22,7 @@ import torch.nn.functional as F
 
 from ..utils import is_torch_version
 from .activations import get_activation
-from .embeddings import (
-    CombinedTimestepLabelEmbeddings,
-    PixArtAlphaCombinedTimestepSizeEmbeddings,
-)
+from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
 
 
 class AdaLayerNorm(nn.Module):
@@ -97,6 +94,40 @@ class FP32LayerNorm(nn.LayerNorm):
         ).to(origin_dtype)
 
 
+class SD35AdaLayerNormZeroX(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (AdaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None:
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 9 * embedding_dim, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        else:
+            raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, ...]:
+        emb = self.linear(self.silu(emb))
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk(
+            9, dim=1
+        )
+        norm_hidden_states = self.norm(hidden_states)
+        hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None]
+        return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2
+
+
 class AdaLayerNormZero(nn.Module):
     r"""
     Norm layer adaptive layer norm zero (adaLN-Zero).
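For orientation, here is a minimal sketch of exercising the new `SD35AdaLayerNormZeroX` on dummy tensors. It assumes the class is importable from `diffusers.models.normalization` (the file these hunks modify); the sizes are arbitrary and not taken from any SD3.5 checkpoint.

```python
import torch

from diffusers.models.normalization import SD35AdaLayerNormZeroX

batch, seq_len, dim = 2, 16, 64  # illustrative sizes only

norm = SD35AdaLayerNormZeroX(embedding_dim=dim)
hidden_states = torch.randn(batch, seq_len, dim)
emb = torch.randn(batch, dim)  # pooled conditioning embedding, projected to 9 * dim internally

# Returns two modulated copies of the normed hidden states (for the two attention
# branches) plus the gate/shift/scale tensors consumed later in the block.
out, gate_msa, shift_mlp, scale_mlp, gate_mlp, out2, gate_msa2 = norm(hidden_states, emb)
print(out.shape, out2.shape)  # torch.Size([2, 16, 64]) torch.Size([2, 16, 64])
```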
@@ -232,6 +263,7 @@ class AdaLayerNormSingle(nn.Module):
         hidden_dtype: Optional[torch.dtype] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         # No modulation happening here.
+        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
         embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
         return self.linear(self.silu(embedded_timestep)), embedded_timestep
 
@@ -324,20 +356,21 @@ class LuminaLayerNormContinuous(nn.Module):
         out_dim: Optional[int] = None,
     ):
         super().__init__()
+
         # AdaLN
         self.silu = nn.SiLU()
         self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
+
         if norm_type == "layer_norm":
             self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
         else:
             raise ValueError(f"unknown norm_type {norm_type}")
-
+
+        self.linear_2 = None
         if out_dim is not None:
-            self.linear_2 = nn.Linear(
-                embedding_dim,
-                out_dim,
-                bias=bias,
-            )
+            self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
 
     def forward(
         self,
@@ -355,6 +388,51 @@ class LuminaLayerNormContinuous(nn.Module):
         return x
 
 
+class CogView3PlusAdaLayerNormZeroTextImage(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, dim: int):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
+        self.norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        emb = self.linear(self.silu(emb))
+        (
+            shift_msa,
+            scale_msa,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            c_shift_msa,
+            c_scale_msa,
+            c_gate_msa,
+            c_shift_mlp,
+            c_scale_mlp,
+            c_gate_mlp,
+        ) = emb.chunk(12, dim=1)
+        normed_x = self.norm_x(x)
+        normed_context = self.norm_c(context)
+        x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None]
+        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp
+
+
 class CogVideoXLayerNormZero(nn.Module):
     def __init__(
         self,
@@ -407,20 +485,24 @@ else:
 
 
 class RMSNorm(nn.Module):
-    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+    def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
         super().__init__()
 
         self.eps = eps
+        self.elementwise_affine = elementwise_affine
 
         if isinstance(dim, numbers.Integral):
            dim = (dim,)
 
         self.dim = torch.Size(dim)
 
+        self.weight = None
+        self.bias = None
+
         if elementwise_affine:
             self.weight = nn.Parameter(torch.ones(dim))
-        else:
-            self.weight = None
+            if bias:
+                self.bias = nn.Parameter(torch.zeros(dim))
 
     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
@@ -432,12 +514,44 @@ class RMSNorm(nn.Module):
             if self.weight.dtype in [torch.float16, torch.bfloat16]:
                 hidden_states = hidden_states.to(self.weight.dtype)
             hidden_states = hidden_states * self.weight
+            if self.bias is not None:
+                hidden_states = hidden_states + self.bias
         else:
             hidden_states = hidden_states.to(input_dtype)
 
         return hidden_states
 
 
+# TODO: (Dhruv) This can be replaced with regular RMSNorm in Mochi once `_keep_in_fp32_modules` is supported
+# for sharded checkpoints, see: https://github.com/huggingface/diffusers/issues/10013
+class MochiRMSNorm(nn.Module):
+    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+        super().__init__()
+
+        self.eps = eps
+
+        if isinstance(dim, numbers.Integral):
+            dim = (dim,)
+
+        self.dim = torch.Size(dim)
+
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim))
+        else:
+            self.weight = None
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+        if self.weight is not None:
+            hidden_states = hidden_states * self.weight
+        hidden_states = hidden_states.to(input_dtype)
+
+        return hidden_states
+
+
 class GlobalResponseNorm(nn.Module):
     # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
     def __init__(self, dim):
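Two behavioural details in the hunk above are easy to miss: `RMSNorm` gains an optional learnable `bias`, and it casts activations to the weight dtype when that weight is fp16/bf16, whereas `MochiRMSNorm` keeps the reduction in fp32 and only casts back to the input dtype at the end. A small sketch, assuming both classes are importable from `diffusers.models.normalization`:

```python
import torch

from diffusers.models.normalization import MochiRMSNorm, RMSNorm

x = torch.randn(2, 8, 32, dtype=torch.bfloat16)

# New `bias` flag: a learnable offset applied after the RMS scaling.
rms = RMSNorm(32, eps=1e-5, elementwise_affine=True, bias=True).to(torch.bfloat16)
print(rms(x).dtype)  # torch.bfloat16 -- activations follow the (bf16) weight dtype

# MochiRMSNorm accumulates in fp32 and casts back to the input dtype at the end;
# per the TODO above, it exists until `_keep_in_fp32_modules` works with sharded checkpoints.
mochi = MochiRMSNorm(32, eps=1e-5, elementwise_affine=True)
print(mochi(x).dtype)  # torch.bfloat16 -- cast back to the input dtype
```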
@@ -449,3 +563,33 @@ class GlobalResponseNorm(nn.Module):
         gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
         nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
         return self.gamma * (x * nx) + self.beta + x
+
+
+class LpNorm(nn.Module):
+    def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12):
+        super().__init__()
+
+        self.p = p
+        self.dim = dim
+        self.eps = eps
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps)
+
+
+def get_normalization(
+    norm_type: str = "batch_norm",
+    num_features: Optional[int] = None,
+    eps: float = 1e-5,
+    elementwise_affine: bool = True,
+    bias: bool = True,
+) -> nn.Module:
+    if norm_type == "rms_norm":
+        norm = RMSNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
+    elif norm_type == "layer_norm":
+        norm = nn.LayerNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
+    elif norm_type == "batch_norm":
+        norm = nn.BatchNorm2d(num_features, eps=eps, affine=elementwise_affine)
+    else:
+        raise ValueError(f"{norm_type=} is not supported.")
+    return norm
diffusers/models/transformers/__init__.py

@@ -11,9 +11,15 @@ if is_torch_available():
     from .lumina_nextdit2d import LuminaNextDiT2DModel
     from .pixart_transformer_2d import PixArtTransformer2DModel
     from .prior_transformer import PriorTransformer
+    from .sana_transformer import SanaTransformer2DModel
     from .stable_audio_transformer import StableAudioDiTModel
     from .t5_film_transformer import T5FilmDecoder
     from .transformer_2d import Transformer2DModel
+    from .transformer_allegro import AllegroTransformer3DModel
+    from .transformer_cogview3plus import CogView3PlusTransformer2DModel
     from .transformer_flux import FluxTransformer2DModel
+    from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
+    from .transformer_ltx import LTXVideoTransformer3DModel
+    from .transformer_mochi import MochiTransformer3DModel
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_temporal import TransformerTemporalModel
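The newly wired-up classes become importable from the `diffusers.models.transformers` submodule when torch is available; whether they are also re-exported from the package root is not shown in this hunk. A trivial import check:

```python
# Import check for the model classes newly registered in diffusers.models.transformers.
from diffusers.models.transformers import (
    AllegroTransformer3DModel,
    CogView3PlusTransformer2DModel,
    HunyuanVideoTransformer3DModel,
    LTXVideoTransformer3DModel,
    MochiTransformer3DModel,
    SanaTransformer2DModel,
)

for cls in (
    SanaTransformer2DModel,
    AllegroTransformer3DModel,
    CogView3PlusTransformer2DModel,
    HunyuanVideoTransformer3DModel,
    LTXVideoTransformer3DModel,
    MochiTransformer3DModel,
):
    print(cls.__name__)
```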
diffusers/models/transformers/auraflow_transformer_2d.py

@@ -274,6 +274,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
     """
 
+    _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
     _supports_gradient_checkpointing = True
 
     @register_to_config
@@ -465,7 +466,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
 
         # MMDiT blocks.
         for index_block, block in enumerate(self.joint_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -496,7 +497,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diffusers/models/transformers/cogvideox_transformer_3d.py

@@ -19,7 +19,8 @@ import torch
 from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
@@ -152,7 +153,7 @@ class CogVideoXBlock(nn.Module):
         return hidden_states, encoder_hidden_states
 
 
-class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     """
     A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
 
@@ -169,6 +170,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
             Whether to flip the sin to cos in the time embedding.
         time_embed_dim (`int`, defaults to `512`):
             Output dimension of timestep embeddings.
+        ofs_embed_dim (`int`, defaults to `512`):
+            Output dimension of "ofs" embeddings used in CogVideoX-5b-I2B in version 1.5
         text_embed_dim (`int`, defaults to `4096`):
             Input dimension of text embeddings from the text encoder.
         num_layers (`int`, defaults to `30`):
@@ -176,7 +179,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         dropout (`float`, defaults to `0.0`):
             The dropout probability to use.
         attention_bias (`bool`, defaults to `True`):
-            Whether
+            Whether to use bias in the attention projection layers.
         sample_width (`int`, defaults to `90`):
             The width of the input latents.
         sample_height (`int`, defaults to `60`):
@@ -197,7 +200,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         timestep_activation_fn (`str`, defaults to `"silu"`):
             Activation function to use when generating the timestep embeddings.
         norm_elementwise_affine (`bool`, defaults to `True`):
-            Whether
+            Whether to use elementwise affine in normalization layers.
         norm_eps (`float`, defaults to `1e-5`):
             The epsilon value to use in normalization layers.
         spatial_interpolation_scale (`float`, defaults to `1.875`):
@@ -218,6 +221,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         flip_sin_to_cos: bool = True,
         freq_shift: int = 0,
         time_embed_dim: int = 512,
+        ofs_embed_dim: Optional[int] = None,
         text_embed_dim: int = 4096,
         num_layers: int = 30,
         dropout: float = 0.0,
@@ -226,6 +230,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         sample_height: int = 60,
         sample_frames: int = 49,
         patch_size: int = 2,
+        patch_size_t: Optional[int] = None,
         temporal_compression_ratio: int = 4,
         max_text_seq_length: int = 226,
         activation_fn: str = "gelu-approximate",
@@ -236,6 +241,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         temporal_interpolation_scale: float = 1.0,
         use_rotary_positional_embeddings: bool = False,
         use_learned_positional_embeddings: bool = False,
+        patch_bias: bool = True,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -250,10 +256,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         # 1. Patch embedding
         self.patch_embed = CogVideoXPatchEmbed(
             patch_size=patch_size,
+            patch_size_t=patch_size_t,
             in_channels=in_channels,
             embed_dim=inner_dim,
             text_embed_dim=text_embed_dim,
-            bias=True,
+            bias=patch_bias,
             sample_width=sample_width,
             sample_height=sample_height,
             sample_frames=sample_frames,
@@ -266,10 +273,19 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         )
         self.embedding_dropout = nn.Dropout(dropout)
 
-        # 2. Time embeddings
+        # 2. Time embeddings and ofs embedding(Only CogVideoX1.5-5B I2V have)
+
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
 
+        self.ofs_proj = None
+        self.ofs_embedding = None
+        if ofs_embed_dim:
+            self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
+            self.ofs_embedding = TimestepEmbedding(
+                ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
+            )  # same as time embeddings, for ofs
+
         # 3. Define spatio-temporal transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
@@ -297,7 +313,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
                 norm_eps=norm_eps,
                 chunk_dim=1,
             )
-        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+        if patch_size_t is None:
+            # For CogVideox 1.0
+            output_dim = patch_size * patch_size * out_channels
+        else:
+            # For CogVideoX 1.5
+            output_dim = patch_size * patch_size * patch_size_t * out_channels
+
+        self.proj_out = nn.Linear(inner_dim, output_dim)
 
         self.gradient_checkpointing = False
 
@@ -410,9 +434,26 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         encoder_hidden_states: torch.Tensor,
         timestep: Union[int, float, torch.LongTensor],
         timestep_cond: Optional[torch.Tensor] = None,
+        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ):
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
         batch_size, num_frames, channels, height, width = hidden_states.shape
 
         # 1. Time embedding
@@ -425,6 +466,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         t_emb = t_emb.to(dtype=hidden_states.dtype)
         emb = self.time_embedding(t_emb, timestep_cond)
 
+        if self.ofs_embedding is not None:
+            ofs_emb = self.ofs_proj(ofs)
+            ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
+            ofs_emb = self.ofs_embedding(ofs_emb)
+            emb = emb + ofs_emb
+
         # 2. Patch embedding
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
         hidden_states = self.embedding_dropout(hidden_states)
@@ -435,7 +482,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
 
         # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -474,12 +521,21 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         hidden_states = self.proj_out(hidden_states)
 
         # 5. Unpatchify
-        # Note: we use `-1` instead of `channels`:
-        #   - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
-        #   - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
         p = self.config.patch_size
-        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
-        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+        p_t = self.config.patch_size_t
+
+        if p_t is None:
+            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+        else:
+            output = hidden_states.reshape(
+                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
+            )
+            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
 
         if not return_dict:
             return (output,)
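A small shape check of the new `patch_size_t` unpatchify branch, mirroring the reshape/permute sequence above with toy sizes (all numbers are arbitrary assumptions, not a real CogVideoX configuration):

```python
import torch

# Toy sizes, not from any real CogVideoX configuration.
batch_size, out_channels = 1, 4
num_frames, height, width = 5, 8, 8   # latent frames / height / width
p, p_t = 2, 2                         # spatial and temporal patch sizes

num_frame_patches = (num_frames + p_t - 1) // p_t  # frames are padded up to a multiple of p_t
num_tokens = num_frame_patches * (height // p) * (width // p)
hidden_states = torch.randn(batch_size, num_tokens, p * p * p_t * out_channels)  # proj_out output

# The p_t branch of the unpatchify step, copied from the hunk above.
output = hidden_states.reshape(
    batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

print(output.shape)  # torch.Size([1, 6, 4, 8, 8]) -> (batch, frames, channels, height, width)
```

Note that the frame axis comes back as `num_frame_patches * p_t` (6 here), i.e. padded to a multiple of `p_t` rather than the original 5 latent frames.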
diffusers/models/transformers/dit_transformer_2d.py

@@ -184,7 +184,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
 
         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diffusers/models/transformers/latte_transformer_3d.py

@@ -156,9 +156,9 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
 
         # define temporal positional embedding
         temp_pos_embed = get_1d_sincos_pos_embed_from_grid(
-            inner_dim, torch.arange(0, video_length).unsqueeze(1)
+            inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt"
         )  # 1152 hidden size
-        self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
+        self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False)
 
         self.gradient_checkpointing = False
 
@@ -238,7 +238,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
         for i, (spatial_block, temp_block) in enumerate(
             zip(self.transformer_blocks, self.temporal_transformer_blocks)
         ):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     spatial_block,
                     hidden_states,
@@ -271,7 +271,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
             if i == 0 and num_frame > 1:
                 hidden_states = hidden_states + self.temp_pos_embed
 
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     temp_block,
                     hidden_states,
diffusers/models/transformers/pixart_transformer_2d.py

@@ -19,7 +19,7 @@ from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ..attention import BasicTransformerBlock
-from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -247,6 +247,14 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+
+        Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
@@ -378,7 +386,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
 
         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):