diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl
This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
- diffusers/__init__.py +94 -3
- diffusers/commands/env.py +1 -5
- diffusers/configuration_utils.py +4 -9
- diffusers/dependency_versions_table.py +2 -2
- diffusers/image_processor.py +1 -2
- diffusers/loaders/__init__.py +17 -2
- diffusers/loaders/ip_adapter.py +10 -7
- diffusers/loaders/lora_base.py +752 -0
- diffusers/loaders/lora_pipeline.py +2252 -0
- diffusers/loaders/peft.py +213 -5
- diffusers/loaders/single_file.py +3 -14
- diffusers/loaders/single_file_model.py +31 -10
- diffusers/loaders/single_file_utils.py +293 -8
- diffusers/loaders/textual_inversion.py +1 -6
- diffusers/loaders/unet.py +23 -208
- diffusers/models/__init__.py +20 -0
- diffusers/models/activations.py +22 -0
- diffusers/models/attention.py +386 -7
- diffusers/models/attention_processor.py +1937 -629
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_kl.py +14 -3
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vq_model.py +4 -4
- diffusers/models/controlnet.py +2 -3
- diffusers/models/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnet_sd3.py +11 -11
- diffusers/models/controlnet_sparsectrl.py +789 -0
- diffusers/models/controlnet_xs.py +40 -10
- diffusers/models/downsampling.py +68 -0
- diffusers/models/embeddings.py +403 -36
- diffusers/models/model_loading_utils.py +1 -3
- diffusers/models/modeling_flax_utils.py +1 -6
- diffusers/models/modeling_utils.py +4 -16
- diffusers/models/normalization.py +203 -12
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +102 -1
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/transformer_flux.py +455 -0
- diffusers/models/transformers/transformer_sd3.py +18 -4
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +8 -1
- diffusers/models/unets/unet_3d_blocks.py +51 -920
- diffusers/models/unets/unet_3d_condition.py +4 -1
- diffusers/models/unets/unet_i2vgen_xl.py +4 -1
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +1330 -84
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +1 -3
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +64 -0
- diffusers/models/vq_model.py +8 -4
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +100 -3
- diffusers/pipelines/animatediff/__init__.py +4 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
- diffusers/pipelines/auto_pipeline.py +97 -19
- diffusers/pipelines/cogvideo/__init__.py +48 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +749 -0
- diffusers/pipelines/flux/pipeline_output.py +21 -0
- diffusers/pipelines/free_init_utils.py +2 -0
- diffusers/pipelines/free_noise_utils.py +236 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +334 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
- diffusers/pipelines/pag/__init__.py +67 -0
- diffusers/pipelines/pag/pag_utils.py +237 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
- diffusers/pipelines/pia/pipeline_pia.py +30 -37
- diffusers/pipelines/pipeline_flax_utils.py +4 -9
- diffusers/pipelines/pipeline_loading_utils.py +0 -3
- diffusers/pipelines/pipeline_utils.py +2 -14
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
- diffusers/schedulers/__init__.py +8 -0
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +1 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
- diffusers/schedulers/scheduling_ddpm.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +2 -2
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
- diffusers/schedulers/scheduling_ipndm.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
- diffusers/schedulers/scheduling_utils.py +1 -3
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/training_utils.py +99 -14
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +210 -0
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
- diffusers/utils/dynamic_modules_utils.py +1 -11
- diffusers/utils/export_utils.py +50 -6
- diffusers/utils/hub_utils.py +45 -42
- diffusers/utils/import_utils.py +37 -15
- diffusers/utils/loading_utils.py +80 -3
- diffusers/utils/testing_utils.py +11 -8
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1728
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
@@ -434,9 +434,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
434
434
|
force_download (`bool`, *optional*, defaults to `False`):
|
435
435
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
436
436
|
cached versions if they exist.
|
437
|
-
resume_download:
|
438
|
-
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
439
|
-
of Diffusers.
|
440
437
|
proxies (`Dict[str, str]`, *optional*):
|
441
438
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
442
439
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
@@ -518,7 +515,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
518
515
|
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
519
516
|
force_download = kwargs.pop("force_download", False)
|
520
517
|
from_flax = kwargs.pop("from_flax", False)
|
521
|
-
resume_download = kwargs.pop("resume_download", None)
|
522
518
|
proxies = kwargs.pop("proxies", None)
|
523
519
|
output_loading_info = kwargs.pop("output_loading_info", False)
|
524
520
|
local_files_only = kwargs.pop("local_files_only", None)
|
@@ -619,7 +615,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
619
615
|
return_unused_kwargs=True,
|
620
616
|
return_commit_hash=True,
|
621
617
|
force_download=force_download,
|
622
|
-
resume_download=resume_download,
|
623
618
|
proxies=proxies,
|
624
619
|
local_files_only=local_files_only,
|
625
620
|
token=token,
|
@@ -641,7 +636,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
641
636
|
cache_dir=cache_dir,
|
642
637
|
variant=variant,
|
643
638
|
force_download=force_download,
|
644
|
-
resume_download=resume_download,
|
645
639
|
proxies=proxies,
|
646
640
|
local_files_only=local_files_only,
|
647
641
|
token=token,
|
@@ -663,7 +657,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
663
657
|
weights_name=FLAX_WEIGHTS_NAME,
|
664
658
|
cache_dir=cache_dir,
|
665
659
|
force_download=force_download,
|
666
|
-
resume_download=resume_download,
|
667
660
|
proxies=proxies,
|
668
661
|
local_files_only=local_files_only,
|
669
662
|
token=token,
|
@@ -685,7 +678,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
685
678
|
index_file,
|
686
679
|
cache_dir=cache_dir,
|
687
680
|
proxies=proxies,
|
688
|
-
resume_download=resume_download,
|
689
681
|
local_files_only=local_files_only,
|
690
682
|
token=token,
|
691
683
|
user_agent=user_agent,
|
@@ -700,7 +692,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
700
692
|
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
701
693
|
cache_dir=cache_dir,
|
702
694
|
force_download=force_download,
|
703
|
-
resume_download=resume_download,
|
704
695
|
proxies=proxies,
|
705
696
|
local_files_only=local_files_only,
|
706
697
|
token=token,
|
@@ -724,7 +715,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
724
715
|
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
725
716
|
cache_dir=cache_dir,
|
726
717
|
force_download=force_download,
|
727
|
-
resume_download=resume_download,
|
728
718
|
proxies=proxies,
|
729
719
|
local_files_only=local_files_only,
|
730
720
|
token=token,
|
@@ -783,7 +773,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
783
773
|
try:
|
784
774
|
accelerate.load_checkpoint_and_dispatch(
|
785
775
|
model,
|
786
|
-
model_file if not is_sharded else
|
776
|
+
model_file if not is_sharded else index_file,
|
787
777
|
device_map,
|
788
778
|
max_memory=max_memory,
|
789
779
|
offload_folder=offload_folder,
|
@@ -813,13 +803,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
813
803
|
model._temp_convert_self_to_deprecated_attention_blocks()
|
814
804
|
accelerate.load_checkpoint_and_dispatch(
|
815
805
|
model,
|
816
|
-
model_file if not is_sharded else
|
806
|
+
model_file if not is_sharded else index_file,
|
817
807
|
device_map,
|
818
808
|
max_memory=max_memory,
|
819
809
|
offload_folder=offload_folder,
|
820
810
|
offload_state_dict=offload_state_dict,
|
821
811
|
dtype=torch_dtype,
|
822
|
-
|
812
|
+
force_hooks=force_hook,
|
823
813
|
strict=True,
|
824
814
|
)
|
825
815
|
model._undo_temp_convert_self_to_deprecated_attention_blocks()
|
@@ -1169,7 +1159,7 @@ class LegacyModelMixin(ModelMixin):
|
|
1169
1159
|
@classmethod
|
1170
1160
|
@validate_hf_hub_args
|
1171
1161
|
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
1172
|
-
# To prevent
|
1162
|
+
# To prevent dependency import problem.
|
1173
1163
|
from .model_loading_utils import _fetch_remapped_cls_from_config
|
1174
1164
|
|
1175
1165
|
# Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
|
@@ -1177,7 +1167,6 @@ class LegacyModelMixin(ModelMixin):
|
|
1177
1167
|
|
1178
1168
|
cache_dir = kwargs.pop("cache_dir", None)
|
1179
1169
|
force_download = kwargs.pop("force_download", False)
|
1180
|
-
resume_download = kwargs.pop("resume_download", None)
|
1181
1170
|
proxies = kwargs.pop("proxies", None)
|
1182
1171
|
local_files_only = kwargs.pop("local_files_only", None)
|
1183
1172
|
token = kwargs.pop("token", None)
|
@@ -1200,7 +1189,6 @@ class LegacyModelMixin(ModelMixin):
|
|
1200
1189
|
return_unused_kwargs=True,
|
1201
1190
|
return_commit_hash=True,
|
1202
1191
|
force_download=force_download,
|
1203
|
-
resume_download=resume_download,
|
1204
1192
|
proxies=proxies,
|
1205
1193
|
local_files_only=local_files_only,
|
1206
1194
|
token=token,
|
@@ -22,7 +22,10 @@ import torch.nn.functional as F
|
|
22
22
|
|
23
23
|
from ..utils import is_torch_version
|
24
24
|
from .activations import get_activation
|
25
|
-
from .embeddings import
|
25
|
+
from .embeddings import (
|
26
|
+
CombinedTimestepLabelEmbeddings,
|
27
|
+
PixArtAlphaCombinedTimestepSizeEmbeddings,
|
28
|
+
)
|
26
29
|
|
27
30
|
|
28
31
|
class AdaLayerNorm(nn.Module):
|
@@ -31,23 +34,69 @@ class AdaLayerNorm(nn.Module):
|
|
31
34
|
|
32
35
|
Parameters:
|
33
36
|
embedding_dim (`int`): The size of each embedding vector.
|
34
|
-
num_embeddings (`int
|
37
|
+
num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
|
38
|
+
output_dim (`int`, *optional*):
|
39
|
+
norm_elementwise_affine (`bool`, defaults to `False):
|
40
|
+
norm_eps (`bool`, defaults to `False`):
|
41
|
+
chunk_dim (`int`, defaults to `0`):
|
35
42
|
"""
|
36
43
|
|
37
|
-
def __init__(
|
44
|
+
def __init__(
|
45
|
+
self,
|
46
|
+
embedding_dim: int,
|
47
|
+
num_embeddings: Optional[int] = None,
|
48
|
+
output_dim: Optional[int] = None,
|
49
|
+
norm_elementwise_affine: bool = False,
|
50
|
+
norm_eps: float = 1e-5,
|
51
|
+
chunk_dim: int = 0,
|
52
|
+
):
|
38
53
|
super().__init__()
|
39
|
-
|
54
|
+
|
55
|
+
self.chunk_dim = chunk_dim
|
56
|
+
output_dim = output_dim or embedding_dim * 2
|
57
|
+
|
58
|
+
if num_embeddings is not None:
|
59
|
+
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
60
|
+
else:
|
61
|
+
self.emb = None
|
62
|
+
|
40
63
|
self.silu = nn.SiLU()
|
41
|
-
self.linear = nn.Linear(embedding_dim,
|
42
|
-
self.norm = nn.LayerNorm(
|
64
|
+
self.linear = nn.Linear(embedding_dim, output_dim)
|
65
|
+
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
66
|
+
|
67
|
+
def forward(
|
68
|
+
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
|
69
|
+
) -> torch.Tensor:
|
70
|
+
if self.emb is not None:
|
71
|
+
temb = self.emb(timestep)
|
72
|
+
|
73
|
+
temb = self.linear(self.silu(temb))
|
74
|
+
|
75
|
+
if self.chunk_dim == 1:
|
76
|
+
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
|
77
|
+
# other if-branch. This branch is specific to CogVideoX for now.
|
78
|
+
shift, scale = temb.chunk(2, dim=1)
|
79
|
+
shift = shift[:, None, :]
|
80
|
+
scale = scale[:, None, :]
|
81
|
+
else:
|
82
|
+
scale, shift = temb.chunk(2, dim=0)
|
43
83
|
|
44
|
-
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
|
45
|
-
emb = self.linear(self.silu(self.emb(timestep)))
|
46
|
-
scale, shift = torch.chunk(emb, 2)
|
47
84
|
x = self.norm(x) * (1 + scale) + shift
|
48
85
|
return x
|
49
86
|
|
50
87
|
|
88
|
+
class FP32LayerNorm(nn.LayerNorm):
|
89
|
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
90
|
+
origin_dtype = inputs.dtype
|
91
|
+
return F.layer_norm(
|
92
|
+
inputs.float(),
|
93
|
+
self.normalized_shape,
|
94
|
+
self.weight.float() if self.weight is not None else None,
|
95
|
+
self.bias.float() if self.bias is not None else None,
|
96
|
+
self.eps,
|
97
|
+
).to(origin_dtype)
|
98
|
+
|
99
|
+
|
51
100
|
class AdaLayerNormZero(nn.Module):
|
52
101
|
r"""
|
53
102
|
Norm layer adaptive layer norm zero (adaLN-Zero).
|
@@ -57,7 +106,7 @@ class AdaLayerNormZero(nn.Module):
|
|
57
106
|
num_embeddings (`int`): The size of the embeddings dictionary.
|
58
107
|
"""
|
59
108
|
|
60
|
-
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None):
|
109
|
+
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
|
61
110
|
super().__init__()
|
62
111
|
if num_embeddings is not None:
|
63
112
|
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
|
@@ -65,8 +114,15 @@ class AdaLayerNormZero(nn.Module):
|
|
65
114
|
self.emb = None
|
66
115
|
|
67
116
|
self.silu = nn.SiLU()
|
68
|
-
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=
|
69
|
-
|
117
|
+
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
|
118
|
+
if norm_type == "layer_norm":
|
119
|
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
120
|
+
elif norm_type == "fp32_layer_norm":
|
121
|
+
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
|
122
|
+
else:
|
123
|
+
raise ValueError(
|
124
|
+
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
125
|
+
)
|
70
126
|
|
71
127
|
def forward(
|
72
128
|
self,
|
@@ -84,6 +140,69 @@ class AdaLayerNormZero(nn.Module):
|
|
84
140
|
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
85
141
|
|
86
142
|
|
143
|
+
class AdaLayerNormZeroSingle(nn.Module):
|
144
|
+
r"""
|
145
|
+
Norm layer adaptive layer norm zero (adaLN-Zero).
|
146
|
+
|
147
|
+
Parameters:
|
148
|
+
embedding_dim (`int`): The size of each embedding vector.
|
149
|
+
num_embeddings (`int`): The size of the embeddings dictionary.
|
150
|
+
"""
|
151
|
+
|
152
|
+
def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
|
153
|
+
super().__init__()
|
154
|
+
|
155
|
+
self.silu = nn.SiLU()
|
156
|
+
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
|
157
|
+
if norm_type == "layer_norm":
|
158
|
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
159
|
+
else:
|
160
|
+
raise ValueError(
|
161
|
+
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
162
|
+
)
|
163
|
+
|
164
|
+
def forward(
|
165
|
+
self,
|
166
|
+
x: torch.Tensor,
|
167
|
+
emb: Optional[torch.Tensor] = None,
|
168
|
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
169
|
+
emb = self.linear(self.silu(emb))
|
170
|
+
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
|
171
|
+
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
172
|
+
return x, gate_msa
|
173
|
+
|
174
|
+
|
175
|
+
class LuminaRMSNormZero(nn.Module):
|
176
|
+
"""
|
177
|
+
Norm layer adaptive RMS normalization zero.
|
178
|
+
|
179
|
+
Parameters:
|
180
|
+
embedding_dim (`int`): The size of each embedding vector.
|
181
|
+
"""
|
182
|
+
|
183
|
+
def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool):
|
184
|
+
super().__init__()
|
185
|
+
self.silu = nn.SiLU()
|
186
|
+
self.linear = nn.Linear(
|
187
|
+
min(embedding_dim, 1024),
|
188
|
+
4 * embedding_dim,
|
189
|
+
bias=True,
|
190
|
+
)
|
191
|
+
self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
192
|
+
|
193
|
+
def forward(
|
194
|
+
self,
|
195
|
+
x: torch.Tensor,
|
196
|
+
emb: Optional[torch.Tensor] = None,
|
197
|
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
198
|
+
# emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
|
199
|
+
emb = self.linear(self.silu(emb))
|
200
|
+
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
201
|
+
x = self.norm(x) * (1 + scale_msa[:, None])
|
202
|
+
|
203
|
+
return x, gate_msa, scale_mlp, gate_mlp
|
204
|
+
|
205
|
+
|
87
206
|
class AdaLayerNormSingle(nn.Module):
|
88
207
|
r"""
|
89
208
|
Norm layer adaptive layer norm single (adaLN-single).
|
@@ -188,6 +307,78 @@ class AdaLayerNormContinuous(nn.Module):
|
|
188
307
|
return x
|
189
308
|
|
190
309
|
|
310
|
+
class LuminaLayerNormContinuous(nn.Module):
|
311
|
+
def __init__(
|
312
|
+
self,
|
313
|
+
embedding_dim: int,
|
314
|
+
conditioning_embedding_dim: int,
|
315
|
+
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
|
316
|
+
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
|
317
|
+
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
|
318
|
+
# However, this is how it was implemented in the original code, and it's rather likely you should
|
319
|
+
# set `elementwise_affine` to False.
|
320
|
+
elementwise_affine=True,
|
321
|
+
eps=1e-5,
|
322
|
+
bias=True,
|
323
|
+
norm_type="layer_norm",
|
324
|
+
out_dim: Optional[int] = None,
|
325
|
+
):
|
326
|
+
super().__init__()
|
327
|
+
# AdaLN
|
328
|
+
self.silu = nn.SiLU()
|
329
|
+
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
|
330
|
+
if norm_type == "layer_norm":
|
331
|
+
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
332
|
+
else:
|
333
|
+
raise ValueError(f"unknown norm_type {norm_type}")
|
334
|
+
# linear_2
|
335
|
+
if out_dim is not None:
|
336
|
+
self.linear_2 = nn.Linear(
|
337
|
+
embedding_dim,
|
338
|
+
out_dim,
|
339
|
+
bias=bias,
|
340
|
+
)
|
341
|
+
|
342
|
+
def forward(
|
343
|
+
self,
|
344
|
+
x: torch.Tensor,
|
345
|
+
conditioning_embedding: torch.Tensor,
|
346
|
+
) -> torch.Tensor:
|
347
|
+
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
|
348
|
+
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
|
349
|
+
scale = emb
|
350
|
+
x = self.norm(x) * (1 + scale)[:, None, :]
|
351
|
+
|
352
|
+
if self.linear_2 is not None:
|
353
|
+
x = self.linear_2(x)
|
354
|
+
|
355
|
+
return x
|
356
|
+
|
357
|
+
|
358
|
+
class CogVideoXLayerNormZero(nn.Module):
|
359
|
+
def __init__(
|
360
|
+
self,
|
361
|
+
conditioning_dim: int,
|
362
|
+
embedding_dim: int,
|
363
|
+
elementwise_affine: bool = True,
|
364
|
+
eps: float = 1e-5,
|
365
|
+
bias: bool = True,
|
366
|
+
) -> None:
|
367
|
+
super().__init__()
|
368
|
+
|
369
|
+
self.silu = nn.SiLU()
|
370
|
+
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
|
371
|
+
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
372
|
+
|
373
|
+
def forward(
|
374
|
+
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
|
375
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
376
|
+
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
|
377
|
+
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
|
378
|
+
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
|
379
|
+
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
|
380
|
+
|
381
|
+
|
191
382
|
if is_torch_version(">=", "2.1.0"):
|
192
383
|
LayerNorm = nn.LayerNorm
|
193
384
|
else:
|
@@ -2,12 +2,18 @@ from ...utils import is_torch_available
|
|
2
2
|
|
3
3
|
|
4
4
|
if is_torch_available():
|
5
|
+
from .auraflow_transformer_2d import AuraFlowTransformer2DModel
|
6
|
+
from .cogvideox_transformer_3d import CogVideoXTransformer3DModel
|
5
7
|
from .dit_transformer_2d import DiTTransformer2DModel
|
6
8
|
from .dual_transformer_2d import DualTransformer2DModel
|
7
9
|
from .hunyuan_transformer_2d import HunyuanDiT2DModel
|
10
|
+
from .latte_transformer_3d import LatteTransformer3DModel
|
11
|
+
from .lumina_nextdit2d import LuminaNextDiT2DModel
|
8
12
|
from .pixart_transformer_2d import PixArtTransformer2DModel
|
9
13
|
from .prior_transformer import PriorTransformer
|
14
|
+
from .stable_audio_transformer import StableAudioDiTModel
|
10
15
|
from .t5_film_transformer import T5FilmDecoder
|
11
16
|
from .transformer_2d import Transformer2DModel
|
17
|
+
from .transformer_flux import FluxTransformer2DModel
|
12
18
|
from .transformer_sd3 import SD3Transformer2DModel
|
13
19
|
from .transformer_temporal import TransformerTemporalModel
|