diffusers 0.26.2__py3-none-any.whl → 0.27.0__py3-none-any.whl
- diffusers/__init__.py +20 -1
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/diffusers_cli.py +1 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +7 -3
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +2 -2
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +110 -4
- diffusers/loaders/autoencoder.py +28 -8
- diffusers/loaders/controlnet.py +17 -8
- diffusers/loaders/ip_adapter.py +86 -23
- diffusers/loaders/lora.py +105 -310
- diffusers/loaders/lora_conversion_utils.py +1 -1
- diffusers/loaders/peft.py +1 -1
- diffusers/loaders/single_file.py +51 -12
- diffusers/loaders/single_file_utils.py +278 -49
- diffusers/loaders/textual_inversion.py +23 -4
- diffusers/loaders/unet.py +195 -41
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +3 -1
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +26 -36
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +171 -114
- diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +1 -1
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flax.py +1 -1
- diffusers/models/downsampling.py +8 -12
- diffusers/models/dual_transformer_2d.py +1 -1
- diffusers/models/embeddings.py +3 -4
- diffusers/models/embeddings_flax.py +1 -1
- diffusers/models/lora.py +33 -10
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +4 -6
- diffusers/models/normalization.py +1 -1
- diffusers/models/resnet.py +31 -58
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/t5_film_transformer.py +1 -1
- diffusers/models/transformer_2d.py +1 -1
- diffusers/models/transformer_temporal.py +1 -1
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/t5_film_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +29 -31
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unet_1d.py +1 -1
- diffusers/models/unet_1d_blocks.py +1 -1
- diffusers/models/unet_2d.py +1 -1
- diffusers/models/unet_2d_blocks.py +1 -1
- diffusers/models/unet_2d_condition.py +1 -1
- diffusers/models/unets/__init__.py +1 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +4 -4
- diffusers/models/unets/unet_2d_blocks.py +238 -98
- diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +420 -323
- diffusers/models/unets/unet_2d_condition_flax.py +21 -12
- diffusers/models/unets/unet_3d_blocks.py +50 -40
- diffusers/models/unets/unet_3d_condition.py +47 -8
- diffusers/models/unets/unet_i2vgen_xl.py +75 -30
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +48 -8
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +610 -0
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +10 -16
- diffusers/models/vae_flax.py +1 -1
- diffusers/models/vq_model.py +1 -1
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +26 -0
- diffusers/pipelines/amused/pipeline_amused.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
- diffusers/pipelines/animatediff/pipeline_output.py +7 -6
- diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
- diffusers/pipelines/auto_pipeline.py +7 -16
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -1
- diffusers/pipelines/free_init_utils.py +184 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ledits_pp/__init__.py +55 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
- diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
- diffusers/pipelines/onnx_utils.py +1 -1
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
- diffusers/pipelines/pia/pipeline_pia.py +168 -327
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +508 -0
- diffusers/pipelines/pipeline_utils.py +188 -534
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/__init__.py +50 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
- diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
- diffusers/pipelines/unclip/text_proj.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
- diffusers/schedulers/__init__.py +7 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +42 -19
- diffusers/schedulers/scheduling_ddim.py +2 -4
- diffusers/schedulers/scheduling_ddim_flax.py +13 -5
- diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm.py +2 -4
- diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +46 -19
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
- diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +52 -21
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
- diffusers/schedulers/scheduling_edm_euler.py +381 -0
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
- diffusers/schedulers/scheduling_euler_discrete.py +42 -17
- diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_heun_discrete.py +35 -35
- diffusers/schedulers/scheduling_ipndm.py +37 -11
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
- diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +38 -14
- diffusers/schedulers/scheduling_lms_discrete.py +43 -15
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +2 -4
- diffusers/schedulers/scheduling_pndm_flax.py +2 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +41 -9
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_tcd.py +686 -0
- diffusers/schedulers/scheduling_unclip.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +9 -2
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +1 -1
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +3 -3
- diffusers/utils/hub_utils.py +60 -16
- diffusers/utils/import_utils.py +15 -1
- diffusers/utils/loading_utils.py +2 -0
- diffusers/utils/logging.py +1 -1
- diffusers/utils/model_card_template.md +24 -0
- diffusers/utils/outputs.py +14 -7
- diffusers/utils/peft_utils.py +1 -1
- diffusers/utils/state_dict_utils.py +1 -1
- diffusers/utils/testing_utils.py +2 -0
- diffusers/utils/torch_utils.py +1 -1
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/METADATA +5 -5
- diffusers-0.27.0.dist-info/RECORD +399 -0
- diffusers-0.26.2.dist-info/RECORD +0 -384
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/WHEEL +0 -0
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
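
The listing above is dominated by brand-new modules: the Stable Cascade pipelines, LEDITS++, the FreeInit utilities, and the TCD/EDM schedulers. A minimal, hedged sketch of how the new top-level exports can be exercised (model IDs and dtypes follow the usual upstream examples and are illustrative, not taken from this diff):

import torch
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline, TCDScheduler

# Stable Cascade ships as a prior + decoder pair, both new in 0.27.0.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
)
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.float16
)

# The new TCDScheduler drops into an existing pipeline like any other scheduler, e.g.:
# pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)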
diffusers/models/unets/unet_2d_condition_flax.py

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -75,6 +75,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             The tuple of downsample blocks to use.
         up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
             The tuple of upsample blocks to use.
+        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer is skipped.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
             The tuple of output channels for each block.
         layers_per_block (`int`, *optional*, defaults to 2):
@@ -107,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         "DownBlock2D",
     )
     up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
+    mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn"
     only_cross_attention: Union[bool, Tuple[bool]] = False
     block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
     layers_per_block: int = 2
@@ -252,16 +255,21 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         self.down_blocks = down_blocks

         # mid
-        self.mid_block = FlaxUNetMidBlock2DCrossAttn(
-            in_channels=block_out_channels[-1],
-            dropout=self.dropout,
-            num_attention_heads=num_attention_heads[-1],
-            transformer_layers_per_block=transformer_layers_per_block[-1],
-            use_linear_projection=self.use_linear_projection,
-            use_memory_efficient_attention=self.use_memory_efficient_attention,
-            split_head_dim=self.split_head_dim,
-            dtype=self.dtype,
-        )
+        if self.config.mid_block_type == "UNetMidBlock2DCrossAttn":
+            self.mid_block = FlaxUNetMidBlock2DCrossAttn(
+                in_channels=block_out_channels[-1],
+                dropout=self.dropout,
+                num_attention_heads=num_attention_heads[-1],
+                transformer_layers_per_block=transformer_layers_per_block[-1],
+                use_linear_projection=self.use_linear_projection,
+                use_memory_efficient_attention=self.use_memory_efficient_attention,
+                split_head_dim=self.split_head_dim,
+                dtype=self.dtype,
+            )
+        elif self.config.mid_block_type is None:
+            self.mid_block = None
+        else:
+            raise ValueError(f"Unexpected mid_block_type {self.config.mid_block_type}")

         # up
         up_blocks = []
@@ -412,7 +420,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         down_block_res_samples = new_down_block_res_samples

         # 4. mid
-        sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
+        if self.mid_block is not None:
+            sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)

         if mid_block_additional_residual is not None:
             sample += mid_block_additional_residual
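
The hunks above make the Flax UNet's mid block optional via a new `mid_block_type` field. A minimal sketch (assuming `jax`/`flax` are installed; the shapes are illustrative, not from this diff):

from diffusers import FlaxUNet2DConditionModel

# `mid_block_type=None` now skips the mid block entirely; the default keeps the
# previous behaviour of always building FlaxUNetMidBlock2DCrossAttn.
unet = FlaxUNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    mid_block_type=None,
)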
diffusers/models/unets/unet_3d_blocks.py

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Tuple, Union
 import torch
 from torch import nn

-from ...utils import is_torch_version
+from ...utils import deprecate, is_torch_version, logging
 from ...utils.torch_utils import apply_freeu
 from ..attention import Attention
 from ..resnet import (
@@ -35,6 +35,9 @@ from ..transformers.transformer_temporal import (
 )


+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 def get_down_block(
     down_block_type: str,
     num_layers: int,
@@ -1005,9 +1008,14 @@ class DownBlockMotion(nn.Module):
         self,
         hidden_states: torch.FloatTensor,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
         num_frames: int = 1,
+        *args,
+        **kwargs,
     ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         output_states = ()

         blocks = zip(self.resnets, self.motion_modules)
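
The hunk above (repeated for the other motion blocks below) removes `scale` from the block-level forward signatures; passing it now only triggers a deprecation warning, and LoRA strength is expected to flow through `cross_attention_kwargs` at the pipeline call instead. A hedged caller-side sketch (checkpoint and LoRA IDs are the ones commonly used in upstream examples, not taken from this diff):

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")

# LoRA strength is no longer threaded through the blocks' `scale` argument; it is read
# from `cross_attention_kwargs` on the pipeline call instead.
frames = pipe(
    prompt="a rocket launching into space, cinematic",
    num_frames=16,
    cross_attention_kwargs={"scale": 0.8},
).frames[0]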
@@ -1029,24 +1037,18 @@ class DownBlockMotion(nn.Module):
                 )
             else:
                 hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states, temb
+                    create_custom_forward(resnet), hidden_states, temb
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(motion_module),
-                    hidden_states.requires_grad_(),
-                    temb,
-                    num_frames,
-                )

             else:
-                hidden_states = resnet(hidden_states, temb
-
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]

             output_states = output_states + (hidden_states,)

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states
+                hidden_states = downsampler(hidden_states)

             output_states = output_states + (hidden_states,)

@@ -1179,9 +1181,11 @@ class CrossAttnDownBlockMotion(nn.Module):
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         additional_residuals: Optional[torch.FloatTensor] = None,
     ):
-
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")

-
+        output_states = ()

         blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
         for i, (resnet, attn, motion_module) in enumerate(blocks):
@@ -1212,7 +1216,7 @@ class CrossAttnDownBlockMotion(nn.Module):
                     return_dict=False,
                 )[0]
             else:
-                hidden_states = resnet(hidden_states, temb
+                hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1221,10 +1225,10 @@ class CrossAttnDownBlockMotion(nn.Module):
                     encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
-
-
-
-
+                hidden_states = motion_module(
+                    hidden_states,
+                    num_frames=num_frames,
+                )[0]

             # apply additional residuals to the output of the last pair of resnet and attention blocks
             if i == len(blocks) - 1 and additional_residuals is not None:
@@ -1234,7 +1238,7 @@ class CrossAttnDownBlockMotion(nn.Module):

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states
+                hidden_states = downsampler(hidden_states)

             output_states = output_states + (hidden_states,)

@@ -1361,7 +1365,10 @@ class CrossAttnUpBlockMotion(nn.Module):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         num_frames: int = 1,
     ) -> torch.FloatTensor:
-
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -1416,7 +1423,7 @@ class CrossAttnUpBlockMotion(nn.Module):
                     return_dict=False,
                 )[0]
             else:
-                hidden_states = resnet(hidden_states, temb
+                hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1425,14 +1432,14 @@ class CrossAttnUpBlockMotion(nn.Module):
                     encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
-
-
-
-
+                hidden_states = motion_module(
+                    hidden_states,
+                    num_frames=num_frames,
+                )[0]

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size
+                hidden_states = upsampler(hidden_states, upsample_size)

         return hidden_states

@@ -1513,9 +1520,14 @@ class UpBlockMotion(nn.Module):
         res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
         temb: Optional[torch.FloatTensor] = None,
         upsample_size=None,
-        scale: float = 1.0,
         num_frames: int = 1,
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -1563,19 +1575,14 @@ class UpBlockMotion(nn.Module):
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet), hidden_states, temb
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                )

             else:
-                hidden_states = resnet(hidden_states, temb
-
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size
+                hidden_states = upsampler(hidden_states, upsample_size)

         return hidden_states

@@ -1698,8 +1705,11 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         num_frames: int = 1,
     ) -> torch.FloatTensor:
-
-
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+
+        hidden_states = self.resnets[0](hidden_states, temb)

         blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
         for attn, resnet, motion_module in blocks:
@@ -1748,7 +1758,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
                 hidden_states,
                 num_frames=num_frames,
             )[0]
-            hidden_states = resnet(hidden_states, temb
+            hidden_states = resnet(hidden_states, temb)

         return hidden_states

diffusers/models/unets/unet_3d_condition.py

@@ -1,5 +1,5 @@
-# Copyright
-# Copyright
+# Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
+# Copyright 2024 The ModelScope Team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,6 +27,7 @@ from ..activations import get_activation
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
+    Attention,
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
@@ -54,7 +55,7 @@ class UNet3DConditionOutput(BaseOutput):
     The output of [`UNet3DConditionModel`].

     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size,
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
             The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
     """

@@ -74,9 +75,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             Height and width of input/output sample.
         in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
         out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
-        down_block_types (`Tuple[str]`, *optional*, defaults to `("
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
             The tuple of downsample blocks to use.
-        up_block_types (`Tuple[str]`, *optional*, defaults to `("
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
             The tuple of upsample blocks to use.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
             The tuple of output channels for each block.
@@ -87,8 +88,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
             If `None`, normalization and activation layers is skipped in post-processing.
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
-        cross_attention_dim (`int`, *optional*, defaults to
-        attention_head_dim (`int`, *optional*, defaults to
+        cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
+        attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
         num_attention_heads (`int`, *optional*): The number of attention heads.
     """

@@ -503,6 +504,44 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                 setattr(upsample_block, k, None)

+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unload_lora
     def unload_lora(self):
         """Unloads LoRA weights."""
@@ -533,7 +572,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)

         Args:
             sample (`torch.FloatTensor`):
-                The noisy input tensor with the following shape `(batch,
+                The noisy input tensor with the following shape `(batch, num_channels, num_frames, height, width`.
             timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
             encoder_hidden_states (`torch.FloatTensor`):
                 The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
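
The hunks above copy `fuse_qkv_projections()` / `unfuse_qkv_projections()` from the 2D UNet onto the 3D UNet. A hedged usage sketch (the docstrings themselves flag the API as experimental; the checkpoint ID is illustrative):

import torch
from diffusers import TextToVideoSDPipeline

pipe = TextToVideoSDPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16
).to("cuda")

# Fuse query/key/value projections in every Attention module before generation...
pipe.unet.fuse_qkv_projections()
frames = pipe("an astronaut riding a horse", num_frames=16).frames
# ...and restore the original attention processors afterwards.
pipe.unet.unfuse_qkv_projections()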
diffusers/models/unets/unet_i2vgen_xl.py

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -48,29 +48,6 @@ from .unet_3d_condition import UNet3DConditionOutput
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-def _to_tensor(inputs, device):
-    if not torch.is_tensor(inputs):
-        # TODO: this requires sync between CPU and GPU. So try to pass `inputs` as tensors if you can
-        # This would be a good case for the `match` statement (Python 3.10+)
-        is_mps = device.type == "mps"
-        if isinstance(inputs, float):
-            dtype = torch.float32 if is_mps else torch.float64
-        else:
-            dtype = torch.int32 if is_mps else torch.int64
-        inputs = torch.tensor([inputs], dtype=dtype, device=device)
-    elif len(inputs.shape) == 0:
-        inputs = inputs[None].to(device)
-
-    return inputs
-
-
-def _collapse_frames_into_batch(sample: torch.Tensor) -> torch.Tensor:
-    batch_size, channels, num_frames, height, width = sample.shape
-    sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
-
-    return sample
-
-
 class I2VGenXLTransformerTemporalEncoder(nn.Module):
     def __init__(
         self,
@@ -112,7 +89,7 @@ class I2VGenXLTransformerTemporalEncoder(nn.Module):
         if hidden_states.ndim == 4:
             hidden_states = hidden_states.squeeze(1)

-        ff_output = self.ff(hidden_states
+        ff_output = self.ff(hidden_states)
         hidden_states = ff_output + hidden_states
         if hidden_states.ndim == 4:
             hidden_states = hidden_states.squeeze(1)
@@ -143,6 +120,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
             If `None`, normalization and activation layers is skipped in post-processing.
         cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
+        attention_head_dim (`int`, *optional*, defaults to 64): Attention head dim.
         num_attention_heads (`int`, *optional*): The number of attention heads.
     """

@@ -170,11 +148,18 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         layers_per_block: int = 2,
         norm_num_groups: Optional[int] = 32,
         cross_attention_dim: int = 1024,
-
+        attention_head_dim: Union[int, Tuple[int]] = 64,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
     ):
         super().__init__()

-
+        # When we first integrated the UNet into the library, we didn't have `attention_head_dim`. As a consequence
+        # of that, we used `num_attention_heads` for arguments that actually denote attention head dimension. This
+        # is why we ignore `num_attention_heads` and calculate it from `attention_head_dims` below.
+        # This is still an incorrect way of calculating `num_attention_heads` but we need to stick to it
+        # without running proper depcrecation cycles for the {down,mid,up} blocks which are a
+        # part of the public API.
+        num_attention_heads = attention_head_dim

         # Check inputs
         if len(down_block_types) != len(up_block_types):
@@ -489,6 +474,44 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                 setattr(upsample_block, k, None)

+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def forward(
         self,
         sample: torch.FloatTensor,
@@ -543,7 +566,18 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             forward_upsample_size = True

         # 1. time
-        timesteps =
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timesteps, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)

         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timesteps = timesteps.expand(sample.shape[0])
@@ -572,7 +606,13 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         context_emb = sample.new_zeros(batch_size, 0, self.config.cross_attention_dim)
         context_emb = torch.cat([context_emb, encoder_hidden_states], dim=1)

-
+        image_latents_for_context_embds = image_latents[:, :, :1, :]
+        image_latents_context_embs = image_latents_for_context_embds.permute(0, 2, 1, 3, 4).reshape(
+            image_latents_for_context_embds.shape[0] * image_latents_for_context_embds.shape[2],
+            image_latents_for_context_embds.shape[1],
+            image_latents_for_context_embds.shape[3],
+            image_latents_for_context_embds.shape[4],
+        )
         image_latents_context_embs = self.image_latents_context_embedding(image_latents_context_embs)

         _batch_size, _channels, _height, _width = image_latents_context_embs.shape
@@ -586,7 +626,12 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         context_emb = torch.cat([context_emb, image_emb], dim=1)
         context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)

-        image_latents =
+        image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
+            image_latents.shape[0] * image_latents.shape[2],
+            image_latents.shape[1],
+            image_latents.shape[3],
+            image_latents.shape[4],
+        )
         image_latents = self.image_latents_proj_in(image_latents)
         image_latents = (
             image_latents[None, :]
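
The last two hunks inline the removed `_collapse_frames_into_batch` helper: video latents of shape `(batch, channels, frames, height, width)` are folded into a 4D batch before the 2D projection layers. A standalone sketch of that reshape (tensor sizes are illustrative):

import torch

sample = torch.randn(2, 4, 16, 32, 32)  # (batch, channels, num_frames, height, width)
collapsed = sample.permute(0, 2, 1, 3, 4).reshape(
    sample.shape[0] * sample.shape[2], sample.shape[1], sample.shape[3], sample.shape[4]
)
print(collapsed.shape)  # torch.Size([32, 4, 32, 32]) -> frames folded into the batch dimension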
diffusers/models/unets/unet_motion_model.py

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@ from ...utils import logging
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
+    Attention,
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
@@ -217,6 +218,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         use_motion_mid_block: int = True,
         encoder_hid_dim: Optional[int] = None,
         encoder_hid_dim_type: Optional[str] = None,
+        time_cond_proj_dim: Optional[int] = None,
     ):
         super().__init__()

@@ -252,9 +254,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         timestep_input_dim = block_out_channels[0]

         self.time_embedding = TimestepEmbedding(
-            timestep_input_dim,
-            time_embed_dim,
-            act_fn=act_fn,
+            timestep_input_dim, time_embed_dim, act_fn=act_fn, cond_proj_dim=time_cond_proj_dim
         )

         if encoder_hid_dim_type is None:
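
`UNetMotionModel` now accepts `time_cond_proj_dim`, which is forwarded to `TimestepEmbedding(cond_proj_dim=...)` so a conditioning vector (for example a guidance embedding for LCM-style distilled motion checkpoints) can be folded into the timestep embedding. A minimal sketch of that mechanism in isolation (dimensions are illustrative):

import torch
from diffusers.models.embeddings import TimestepEmbedding

# With cond_proj_dim set, a bias-free Linear projects `condition` into the embedding input
# before the usual two-layer MLP runs.
time_embedding = TimestepEmbedding(in_channels=320, time_embed_dim=1280, cond_proj_dim=256)

t_emb = torch.randn(2, 320)          # sinusoidal timestep projection
timestep_cond = torch.randn(2, 256)  # e.g. an embedded guidance scale
emb = time_embedding(t_emb, timestep_cond)
print(emb.shape)  # torch.Size([2, 1280])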
@@ -306,6 +306,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 num_attention_heads=num_attention_heads[-1],
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=False,
+                use_linear_projection=use_linear_projection,
                 temporal_num_attention_heads=motion_num_attention_heads,
                 temporal_max_seq_length=motion_max_seq_length,
             )
@@ -321,6 +322,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 num_attention_heads=num_attention_heads[-1],
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=False,
+                use_linear_projection=use_linear_projection,
             )

         # count how many layers upsample the images
@@ -700,6 +702,44 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                 setattr(upsample_block, k, None)

+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def forward(
         self,
         sample: torch.FloatTensor,
@@ -792,6 +832,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):

         emb = self.time_embedding(t_emb, timestep_cond)
         emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)

         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
             if "image_embeds" not in added_cond_kwargs:
@@ -799,10 +840,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                     f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                 )
             image_embeds = added_cond_kwargs.get("image_embeds")
-            image_embeds = self.encoder_hid_proj(image_embeds)
-
-
-            encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
+            image_embeds = self.encoder_hid_proj(image_embeds)
+            image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
+            encoder_hidden_states = (encoder_hidden_states, image_embeds)

         # 2. pre-process
         sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])