diffusers 0.27.1__py3-none-any.whl → 0.32.2__py3-none-any.whl
- diffusers/__init__.py +233 -6
- diffusers/callbacks.py +209 -0
- diffusers/commands/env.py +102 -6
- diffusers/configuration_utils.py +45 -16
- diffusers/dependency_versions_table.py +4 -3
- diffusers/image_processor.py +434 -110
- diffusers/loaders/__init__.py +42 -9
- diffusers/loaders/ip_adapter.py +626 -36
- diffusers/loaders/lora_base.py +900 -0
- diffusers/loaders/lora_conversion_utils.py +991 -125
- diffusers/loaders/lora_pipeline.py +3812 -0
- diffusers/loaders/peft.py +571 -7
- diffusers/loaders/single_file.py +405 -173
- diffusers/loaders/single_file_model.py +385 -0
- diffusers/loaders/single_file_utils.py +1783 -713
- diffusers/loaders/textual_inversion.py +41 -23
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +464 -540
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +76 -7
- diffusers/models/activations.py +65 -10
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +605 -18
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +4304 -687
- diffusers/models/autoencoders/__init__.py +8 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +110 -28
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
- diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
- diffusers/models/autoencoders/vae.py +41 -29
- diffusers/models/autoencoders/vq_model.py +182 -0
- diffusers/models/controlnet.py +47 -800
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +68 -0
- diffusers/models/controlnet_sparsectrl.py +116 -0
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/controlnets/controlnet_xs.py +1946 -0
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/downsampling.py +85 -18
- diffusers/models/embeddings.py +1856 -158
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +480 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +2 -7
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +611 -146
- diffusers/models/normalization.py +361 -20
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformers/__init__.py +16 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/dual_transformer_2d.py +9 -8
- diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +445 -0
- diffusers/models/transformers/prior_transformer.py +13 -13
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +297 -187
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +593 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +461 -0
- diffusers/models/transformers/transformer_temporal.py +21 -19
- diffusers/models/unets/unet_1d.py +8 -8
- diffusers/models/unets/unet_1d_blocks.py +31 -31
- diffusers/models/unets/unet_2d.py +17 -10
- diffusers/models/unets/unet_2d_blocks.py +225 -149
- diffusers/models/unets/unet_2d_condition.py +41 -40
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +192 -1057
- diffusers/models/unets/unet_3d_condition.py +22 -27
- diffusers/models/unets/unet_i2vgen_xl.py +22 -18
- diffusers/models/unets/unet_kandinsky3.py +2 -2
- diffusers/models/unets/unet_motion_model.py +1413 -89
- diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
- diffusers/models/unets/unet_stable_cascade.py +19 -18
- diffusers/models/unets/uvit_2d.py +2 -2
- diffusers/models/upsampling.py +95 -26
- diffusers/models/vq_model.py +12 -164
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +202 -3
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +8 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
- diffusers/pipelines/auto_pipeline.py +196 -28
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/cogvideo/__init__.py +54 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
- diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
- diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
- diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +7 -4
- diffusers/pipelines/flux/__init__.py +69 -0
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +957 -0
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +37 -0
- diffusers/pipelines/free_init_utils.py +41 -38
- diffusers/pipelines/free_noise_utils.py +596 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +338 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/pag/__init__.py +80 -0
- diffusers/pipelines/pag/pag_utils.py +243 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +74 -164
- diffusers/pipelines/pipeline_flax_utils.py +5 -10
- diffusers/pipelines/pipeline_loading_utils.py +515 -53
- diffusers/pipelines/pipeline_utils.py +411 -222
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
- diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
- diffusers/schedulers/__init__.py +12 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +27 -26
- diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
- diffusers/schedulers/scheduling_ddpm.py +27 -30
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +150 -50
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
- diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
- diffusers/schedulers/scheduling_edm_euler.py +62 -39
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
- diffusers/schedulers/scheduling_euler_discrete.py +255 -74
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
- diffusers/schedulers/scheduling_heun_discrete.py +174 -46
- diffusers/schedulers/scheduling_ipndm.py +9 -9
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +23 -29
- diffusers/schedulers/scheduling_lms_discrete.py +105 -28
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +21 -21
- diffusers/schedulers/scheduling_sasolver.py +157 -60
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +41 -36
- diffusers/schedulers/scheduling_unclip.py +19 -16
- diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
- diffusers/schedulers/scheduling_utils.py +12 -5
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +214 -30
- diffusers/utils/__init__.py +17 -1
- diffusers/utils/constants.py +3 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +592 -7
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
- diffusers/utils/dynamic_modules_utils.py +34 -29
- diffusers/utils/export_utils.py +50 -6
- diffusers/utils/hub_utils.py +131 -17
- diffusers/utils/import_utils.py +210 -8
- diffusers/utils/loading_utils.py +118 -5
- diffusers/utils/logging.py +4 -2
- diffusers/utils/peft_utils.py +37 -7
- diffusers/utils/state_dict_utils.py +13 -2
- diffusers/utils/testing_utils.py +193 -11
- diffusers/utils/torch_utils.py +4 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
- diffusers-0.32.2.dist-info/RECORD +550 -0
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1349
- diffusers/models/prior_transformer.py +0 -12
- diffusers/models/t5_film_transformer.py +0 -70
- diffusers/models/transformer_2d.py +0 -25
- diffusers/models/transformer_temporal.py +0 -34
- diffusers/models/unet_1d.py +0 -26
- diffusers/models/unet_1d_blocks.py +0 -203
- diffusers/models/unet_2d.py +0 -27
- diffusers/models/unet_2d_blocks.py +0 -375
- diffusers/models/unet_2d_condition.py +0 -25
- diffusers-0.27.1.dist-info/RECORD +0 -399
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
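The per-file diff below covers diffusers/models/embeddings.py, where the sinusoidal positional-embedding helpers gain torch-native code paths (`output_type="pt"`, optional `device`) while the NumPy paths move to `*_np` variants behind a deprecation warning. A minimal usage sketch of the updated helpers (illustrative dimensions; the import path and example values are assumptions, not part of the recorded diff):

import torch
from diffusers.models.embeddings import get_2d_sincos_pos_embed, get_timestep_embedding

# 2D sincos positional embeddings computed directly as a torch.Tensor;
# output_type="np" still works in 0.32.x but now emits a deprecation warning.
pos_embed = get_2d_sincos_pos_embed(
    embed_dim=768,
    grid_size=16,            # assumed 16x16 patch grid for illustration
    base_size=16,
    interpolation_scale=1.0,
    device=torch.device("cpu"),
    output_type="pt",
)
print(pos_embed.shape)       # torch.Size([256, 768])

# Sinusoidal timestep embeddings, per the expanded docstring in this diff.
timesteps = torch.arange(4)
t_emb = get_timestep_embedding(timesteps, embedding_dim=320, flip_sin_to_cos=True, downscale_freq_shift=0)
print(t_emb.shape)           # torch.Size([4, 320])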
diffusers/models/embeddings.py
CHANGED
@@ -16,10 +16,11 @@ from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 from torch import nn
 
 from ..utils import deprecate
-from .activations import get_activation
+from .activations import FP32SiLU, get_activation
 from .attention_processor import Attention
 
 
@@ -34,10 +35,21 @@ def get_timestep_embedding(
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
 
-
-
-
-
+    Args
+        timesteps (torch.Tensor):
+            a 1-D Tensor of N indices, one per batch element. These may be fractional.
+        embedding_dim (int):
+            the dimension of the output.
+        flip_sin_to_cos (bool):
+            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
+        downscale_freq_shift (float):
+            Controls the delta between frequencies between dimensions
+        scale (float):
+            Scaling factor applied to the embeddings.
+        max_period (int):
+            Controls the maximum frequency of the embeddings
+    Returns
+        torch.Tensor: an [N x dim] Tensor of positional embeddings.
     """
     assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
 
@@ -66,12 +78,303 @@ def get_timestep_embedding(
     return emb
 
 
+def get_3d_sincos_pos_embed(
+    embed_dim: int,
+    spatial_size: Union[int, Tuple[int, int]],
+    temporal_size: int,
+    spatial_interpolation_scale: float = 1.0,
+    temporal_interpolation_scale: float = 1.0,
+    device: Optional[torch.device] = None,
+    output_type: str = "np",
+) -> torch.Tensor:
+    r"""
+    Creates 3D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension of inputs. It must be divisible by 16.
+        spatial_size (`int` or `Tuple[int, int]`):
+            The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
+            spatial dimensions (height and width).
+        temporal_size (`int`):
+            The temporal dimension of postional embeddings (number of frames).
+        spatial_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for spatial grid interpolation.
+        temporal_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for temporal grid interpolation.
+
+    Returns:
+        `torch.Tensor`:
+            The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+            embed_dim]`.
+    """
+    if output_type == "np":
+        return _get_3d_sincos_pos_embed_np(
+            embed_dim=embed_dim,
+            spatial_size=spatial_size,
+            temporal_size=temporal_size,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+        )
+    if embed_dim % 4 != 0:
+        raise ValueError("`embed_dim` must be divisible by 4")
+    if isinstance(spatial_size, int):
+        spatial_size = (spatial_size, spatial_size)
+
+    embed_dim_spatial = 3 * embed_dim // 4
+    embed_dim_temporal = embed_dim // 4
+
+    # 1. Spatial
+    grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
+    grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
+    grid = torch.stack(grid, dim=0)
+
+    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt")
+
+    # 2. Temporal
+    grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
+    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t, output_type="pt")
+
+    # 3. Concat
+    pos_embed_spatial = pos_embed_spatial[None, :, :]
+    pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0)  # [T, H*W, D // 4 * 3]
+
+    pos_embed_temporal = pos_embed_temporal[:, None, :]
+    pos_embed_temporal = pos_embed_temporal.repeat_interleave(
+        spatial_size[0] * spatial_size[1], dim=1
+    )  # [T, H*W, D // 4]
+
+    pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1)  # [T, H*W, D]
+    return pos_embed
+
+
+def _get_3d_sincos_pos_embed_np(
+    embed_dim: int,
+    spatial_size: Union[int, Tuple[int, int]],
+    temporal_size: int,
+    spatial_interpolation_scale: float = 1.0,
+    temporal_interpolation_scale: float = 1.0,
+) -> np.ndarray:
+    r"""
+    Creates 3D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension of inputs. It must be divisible by 16.
+        spatial_size (`int` or `Tuple[int, int]`):
+            The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
+            spatial dimensions (height and width).
+        temporal_size (`int`):
+            The temporal dimension of postional embeddings (number of frames).
+        spatial_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for spatial grid interpolation.
+        temporal_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for temporal grid interpolation.
+
+    Returns:
+        `np.ndarray`:
+            The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+            embed_dim]`.
+    """
+    deprecation_message = (
+        "`get_3d_sincos_pos_embed` uses `torch` and supports `device`."
+        " `from_numpy` is no longer required."
+        " Pass `output_type='pt' to use the new version now."
+    )
+    deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+    if embed_dim % 4 != 0:
+        raise ValueError("`embed_dim` must be divisible by 4")
+    if isinstance(spatial_size, int):
+        spatial_size = (spatial_size, spatial_size)
+
+    embed_dim_spatial = 3 * embed_dim // 4
+    embed_dim_temporal = embed_dim // 4
+
+    # 1. Spatial
+    grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
+    grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
+
+    # 2. Temporal
+    grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
+    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
+
+    # 3. Concat
+    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
+    pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)  # [T, H*W, D // 4 * 3]
+
+    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
+    pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1)  # [T, H*W, D // 4]
+
+    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)  # [T, H*W, D]
+    return pos_embed
+
+
 def get_2d_sincos_pos_embed(
+    embed_dim,
+    grid_size,
+    cls_token=False,
+    extra_tokens=0,
+    interpolation_scale=1.0,
+    base_size=16,
+    device: Optional[torch.device] = None,
+    output_type: str = "np",
+):
+    """
+    Creates 2D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension.
+        grid_size (`int`):
+            The size of the grid height and width.
+        cls_token (`bool`, defaults to `False`):
+            Whether or not to add a classification token.
+        extra_tokens (`int`, defaults to `0`):
+            The number of extra tokens to add.
+        interpolation_scale (`float`, defaults to `1.0`):
+            The scale of the interpolation.
+
+    Returns:
+        pos_embed (`torch.Tensor`):
+            Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
+            embed_dim]` if using cls_token
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt' to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_2d_sincos_pos_embed_np(
+            embed_dim=embed_dim,
+            grid_size=grid_size,
+            cls_token=cls_token,
+            extra_tokens=extra_tokens,
+            interpolation_scale=interpolation_scale,
+            base_size=base_size,
+        )
+    if isinstance(grid_size, int):
+        grid_size = (grid_size, grid_size)
+
+    grid_h = (
+        torch.arange(grid_size[0], device=device, dtype=torch.float32)
+        / (grid_size[0] / base_size)
+        / interpolation_scale
+    )
+    grid_w = (
+        torch.arange(grid_size[1], device=device, dtype=torch.float32)
+        / (grid_size[1] / base_size)
+        / interpolation_scale
+    )
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
+    grid = torch.stack(grid, dim=0)
+
+    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type)
+    if cls_token and extra_tokens > 0:
+        pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
+    r"""
+    This function generates 2D sinusoidal positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension.
+        grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
+
+    Returns:
+        `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt' to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_2d_sincos_pos_embed_from_grid_np(
+            embed_dim=embed_dim,
+            grid=grid,
+        )
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type)  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type)  # (H*W, D/2)
+
+    emb = torch.concat([emb_h, emb_w], dim=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
+    """
+    This function generates 1D positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension `D`
+        pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
+
+    Returns:
+        `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt' to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = torch.outer(pos, omega)  # (M, D/2), outer product
+
+    emb_sin = torch.sin(out)  # (M, D/2)
+    emb_cos = torch.cos(out)  # (M, D/2)
+
+    emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)
+    return emb
+
+
+def get_2d_sincos_pos_embed_np(
     embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
 ):
     """
-
-
+    Creates 2D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension.
+        grid_size (`int`):
+            The size of the grid height and width.
+        cls_token (`bool`, defaults to `False`):
+            Whether or not to add a classification token.
+        extra_tokens (`int`, defaults to `0`):
+            The number of extra tokens to add.
+        interpolation_scale (`float`, defaults to `1.0`):
+            The scale of the interpolation.
+
+    Returns:
+        pos_embed (`np.ndarray`):
+            Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
+            embed_dim]` if using cls_token
     """
     if isinstance(grid_size, int):
         grid_size = (grid_size, grid_size)
@@ -82,27 +385,44 @@ def get_2d_sincos_pos_embed(
     grid = np.stack(grid, axis=0)
 
     grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
-    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid)
     if cls_token and extra_tokens > 0:
         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
     return pos_embed
 
 
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid):
+    r"""
+    This function generates 2D sinusoidal positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension.
+        grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.
+
+    Returns:
+        `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
+    """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
 
     # use half of dimensions to encode grid_h
-    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
-    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+    emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1])  # (H*W, D/2)
 
     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
     return emb
 
 
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+def get_1d_sincos_pos_embed_from_grid_np(embed_dim, pos):
     """
-
+    This function generates 1D positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension `D`
+        pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)`
+
+    Returns:
+        `numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
     """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
@@ -122,7 +442,22 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
 
 
 class PatchEmbed(nn.Module):
-    """2D Image to Patch Embedding"""
+    """
+    2D Image to Patch Embedding with support for SD3 cropping.
+
+    Args:
+        height (`int`, defaults to `224`): The height of the image.
+        width (`int`, defaults to `224`): The width of the image.
+        patch_size (`int`, defaults to `16`): The size of the patches.
+        in_channels (`int`, defaults to `3`): The number of input channels.
+        embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
+        layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
+        flatten (`bool`, defaults to `True`): Whether or not to flatten the output.
+        bias (`bool`, defaults to `True`): Whether or not to use bias.
+        interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
+        pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
+        pos_embed_max_size (`int`, defaults to `None`): The maximum size of the positional embedding.
+    """
 
     def __init__(
         self,
@@ -135,12 +470,15 @@ class PatchEmbed(nn.Module):
         flatten=True,
         bias=True,
         interpolation_scale=1,
+        pos_embed_type="sincos",
+        pos_embed_max_size=None,  # For SD3 cropping
     ):
         super().__init__()
 
         num_patches = (height // patch_size) * (width // patch_size)
         self.flatten = flatten
         self.layer_norm = layer_norm
+        self.pos_embed_max_size = pos_embed_max_size
 
         self.proj = nn.Conv2d(
             in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
@@ -151,40 +489,780 @@ class PatchEmbed(nn.Module):
|
|
151
489
|
self.norm = None
|
152
490
|
|
153
491
|
self.patch_size = patch_size
|
154
|
-
# See:
|
155
|
-
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
|
156
492
|
self.height, self.width = height // patch_size, width // patch_size
|
157
493
|
self.base_size = height // patch_size
|
158
494
|
self.interpolation_scale = interpolation_scale
|
159
|
-
pos_embed = get_2d_sincos_pos_embed(
|
160
|
-
embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
|
161
|
-
)
|
162
|
-
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
|
163
495
|
|
164
|
-
|
165
|
-
|
496
|
+
# Calculate positional embeddings based on max size or default
|
497
|
+
if pos_embed_max_size:
|
498
|
+
grid_size = pos_embed_max_size
|
499
|
+
else:
|
500
|
+
grid_size = int(num_patches**0.5)
|
501
|
+
|
502
|
+
if pos_embed_type is None:
|
503
|
+
self.pos_embed = None
|
504
|
+
elif pos_embed_type == "sincos":
|
505
|
+
pos_embed = get_2d_sincos_pos_embed(
|
506
|
+
embed_dim,
|
507
|
+
grid_size,
|
508
|
+
base_size=self.base_size,
|
509
|
+
interpolation_scale=self.interpolation_scale,
|
510
|
+
output_type="pt",
|
511
|
+
)
|
512
|
+
persistent = True if pos_embed_max_size else False
|
513
|
+
self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
|
514
|
+
else:
|
515
|
+
raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
|
166
516
|
|
517
|
+
def cropped_pos_embed(self, height, width):
|
518
|
+
"""Crops positional embeddings for SD3 compatibility."""
|
519
|
+
if self.pos_embed_max_size is None:
|
520
|
+
raise ValueError("`pos_embed_max_size` must be set for cropping.")
|
521
|
+
|
522
|
+
height = height // self.patch_size
|
523
|
+
width = width // self.patch_size
|
524
|
+
if height > self.pos_embed_max_size:
|
525
|
+
raise ValueError(
|
526
|
+
f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
527
|
+
)
|
528
|
+
if width > self.pos_embed_max_size:
|
529
|
+
raise ValueError(
|
530
|
+
f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
531
|
+
)
|
532
|
+
|
533
|
+
top = (self.pos_embed_max_size - height) // 2
|
534
|
+
left = (self.pos_embed_max_size - width) // 2
|
535
|
+
spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
|
536
|
+
spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
|
537
|
+
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
538
|
+
return spatial_pos_embed
|
539
|
+
|
540
|
+
def forward(self, latent):
|
541
|
+
if self.pos_embed_max_size is not None:
|
542
|
+
height, width = latent.shape[-2:]
|
543
|
+
else:
|
544
|
+
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
|
167
545
|
latent = self.proj(latent)
|
168
546
|
if self.flatten:
|
169
547
|
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
170
548
|
if self.layer_norm:
|
171
549
|
latent = self.norm(latent)
|
550
|
+
if self.pos_embed is None:
|
551
|
+
return latent.to(latent.dtype)
|
552
|
+
# Interpolate or crop positional embeddings as needed
|
553
|
+
if self.pos_embed_max_size:
|
554
|
+
pos_embed = self.cropped_pos_embed(height, width)
|
555
|
+
else:
|
556
|
+
if self.height != height or self.width != width:
|
557
|
+
pos_embed = get_2d_sincos_pos_embed(
|
558
|
+
embed_dim=self.pos_embed.shape[-1],
|
559
|
+
grid_size=(height, width),
|
560
|
+
base_size=self.base_size,
|
561
|
+
interpolation_scale=self.interpolation_scale,
|
562
|
+
device=latent.device,
|
563
|
+
output_type="pt",
|
564
|
+
)
|
565
|
+
pos_embed = pos_embed.float().unsqueeze(0)
|
566
|
+
else:
|
567
|
+
pos_embed = self.pos_embed
|
172
568
|
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
569
|
+
return (latent + pos_embed).to(latent.dtype)
|
570
|
+
|
571
|
+
|
572
|
+
class LuminaPatchEmbed(nn.Module):
|
573
|
+
"""
|
574
|
+
2D Image to Patch Embedding with support for Lumina-T2X
|
575
|
+
|
576
|
+
Args:
|
577
|
+
patch_size (`int`, defaults to `2`): The size of the patches.
|
578
|
+
in_channels (`int`, defaults to `4`): The number of input channels.
|
579
|
+
embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
|
580
|
+
bias (`bool`, defaults to `True`): Whether or not to use bias.
|
581
|
+
"""
|
582
|
+
|
583
|
+
def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
|
584
|
+
super().__init__()
|
585
|
+
self.patch_size = patch_size
|
586
|
+
self.proj = nn.Linear(
|
587
|
+
in_features=patch_size * patch_size * in_channels,
|
588
|
+
out_features=embed_dim,
|
589
|
+
bias=bias,
|
590
|
+
)
|
591
|
+
|
592
|
+
def forward(self, x, freqs_cis):
|
593
|
+
"""
|
594
|
+
Patchifies and embeds the input tensor(s).
|
595
|
+
|
596
|
+
Args:
|
597
|
+
x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded.
|
598
|
+
|
599
|
+
Returns:
|
600
|
+
Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified
|
601
|
+
and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the
|
602
|
+
frequency tensor(s).
|
603
|
+
"""
|
604
|
+
freqs_cis = freqs_cis.to(x[0].device)
|
605
|
+
patch_height = patch_width = self.patch_size
|
606
|
+
batch_size, channel, height, width = x.size()
|
607
|
+
height_tokens, width_tokens = height // patch_height, width // patch_width
|
608
|
+
|
609
|
+
x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute(
|
610
|
+
0, 2, 4, 1, 3, 5
|
611
|
+
)
|
612
|
+
x = x.flatten(3)
|
613
|
+
x = self.proj(x)
|
614
|
+
x = x.flatten(1, 2)
|
615
|
+
|
616
|
+
mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
|
617
|
+
|
618
|
+
return (
|
619
|
+
x,
|
620
|
+
mask,
|
621
|
+
[(height, width)] * batch_size,
|
622
|
+
freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
|
623
|
+
)
|
624
|
+
|
625
|
+
|
626
|
+
class CogVideoXPatchEmbed(nn.Module):
|
627
|
+
def __init__(
|
628
|
+
self,
|
629
|
+
patch_size: int = 2,
|
630
|
+
patch_size_t: Optional[int] = None,
|
631
|
+
in_channels: int = 16,
|
632
|
+
embed_dim: int = 1920,
|
633
|
+
text_embed_dim: int = 4096,
|
634
|
+
bias: bool = True,
|
635
|
+
sample_width: int = 90,
|
636
|
+
sample_height: int = 60,
|
637
|
+
sample_frames: int = 49,
|
638
|
+
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_positional_embeddings: bool = True,
        use_learned_positional_embeddings: bool = True,
    ) -> None:
        super().__init__()

        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.embed_dim = embed_dim
        self.sample_height = sample_height
        self.sample_width = sample_width
        self.sample_frames = sample_frames
        self.temporal_compression_ratio = temporal_compression_ratio
        self.max_text_seq_length = max_text_seq_length
        self.spatial_interpolation_scale = spatial_interpolation_scale
        self.temporal_interpolation_scale = temporal_interpolation_scale
        self.use_positional_embeddings = use_positional_embeddings
        self.use_learned_positional_embeddings = use_learned_positional_embeddings

        if patch_size_t is None:
            # CogVideoX 1.0 checkpoints
            self.proj = nn.Conv2d(
                in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
            )
        else:
            # CogVideoX 1.5 checkpoints
            self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)

        self.text_proj = nn.Linear(text_embed_dim, embed_dim)

        if use_positional_embeddings or use_learned_positional_embeddings:
            persistent = use_learned_positional_embeddings
            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)

    def _get_positional_embeddings(
        self, sample_height: int, sample_width: int, sample_frames: int, device: Optional[torch.device] = None
    ) -> torch.Tensor:
        post_patch_height = sample_height // self.patch_size
        post_patch_width = sample_width // self.patch_size
        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
        num_patches = post_patch_height * post_patch_width * post_time_compression_frames

        pos_embedding = get_3d_sincos_pos_embed(
            self.embed_dim,
            (post_patch_width, post_patch_height),
            post_time_compression_frames,
            self.spatial_interpolation_scale,
            self.temporal_interpolation_scale,
            device=device,
            output_type="pt",
        )
        pos_embedding = pos_embedding.flatten(0, 1)
        joint_pos_embedding = pos_embedding.new_zeros(
            1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
        )
        joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)

        return joint_pos_embedding

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        r"""
        Args:
            text_embeds (`torch.Tensor`):
                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
        """
        text_embeds = self.text_proj(text_embeds)

        batch_size, num_frames, channels, height, width = image_embeds.shape

        if self.patch_size_t is None:
            image_embeds = image_embeds.reshape(-1, channels, height, width)
            image_embeds = self.proj(image_embeds)
            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
            image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
            image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
        else:
            p = self.patch_size
            p_t = self.patch_size_t

            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
            image_embeds = image_embeds.reshape(
                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
            )
            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
            image_embeds = self.proj(image_embeds)

        embeds = torch.cat(
            [text_embeds, image_embeds], dim=1
        ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]

        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
                raise ValueError(
                    "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
                )

            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1

            if (
                self.sample_height != height
                or self.sample_width != width
                or self.sample_frames != pre_time_compression_frames
            ):
                pos_embedding = self._get_positional_embeddings(
                    height, width, pre_time_compression_frames, device=embeds.device
                )
            else:
                pos_embedding = self.pos_embedding

            pos_embedding = pos_embedding.to(dtype=embeds.dtype)
            embeds = embeds + pos_embedding

        return embeds

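A minimal usage sketch of the new CogVideoXPatchEmbed forward pass (not part of the diff; the sizes are illustrative and the constructor arguments that are not visible in this hunk are assumed to keep their upstream defaults):

import torch
from diffusers.models.embeddings import CogVideoXPatchEmbed

# Tiny illustrative configuration; real checkpoints use much larger dims.
embed = CogVideoXPatchEmbed(
    patch_size=2,
    patch_size_t=None,            # CogVideoX 1.0 branch: 2D conv patchifier
    in_channels=16,
    embed_dim=128,
    text_embed_dim=32,
    sample_width=16,
    sample_height=16,
    sample_frames=9,
    temporal_compression_ratio=4,
    max_text_seq_length=8,
    use_positional_embeddings=True,
    use_learned_positional_embeddings=False,
)

text_embeds = torch.randn(1, 8, 32)           # (batch, seq_len, text_embed_dim)
image_embeds = torch.randn(1, 3, 16, 16, 16)  # (batch, frames, channels, height, width)

out = embed(text_embeds, image_embeds)
# 8 text tokens + 3 frames * (16/2) * (16/2) image patches = 200 tokens
print(out.shape)  # torch.Size([1, 200, 128])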
class CogView3PlusPatchEmbed(nn.Module):
    def __init__(
        self,
        in_channels: int = 16,
        hidden_size: int = 2560,
        patch_size: int = 2,
        text_hidden_size: int = 4096,
        pos_embed_max_size: int = 128,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.text_hidden_size = text_hidden_size
        self.pos_embed_max_size = pos_embed_max_size
        # Linear projection for image patches
        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)

        # Linear projection for text embeddings
        self.text_proj = nn.Linear(text_hidden_size, hidden_size)

        pos_embed = get_2d_sincos_pos_embed(
            hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
        )
        pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
        self.register_buffer("pos_embed", pos_embed.float(), persistent=False)

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, channel, height, width = hidden_states.shape

        if height % self.patch_size != 0 or width % self.patch_size != 0:
            raise ValueError("Height and width must be divisible by patch size")

        height = height // self.patch_size
        width = width // self.patch_size
        hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
        hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)

        # Project the patches
        hidden_states = self.proj(hidden_states)
        encoder_hidden_states = self.text_proj(encoder_hidden_states)
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # Calculate text_length
        text_length = encoder_hidden_states.shape[1]

        image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
        text_pos_embed = torch.zeros(
            (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
        )
        pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...]

        return (hidden_states + pos_embed).to(hidden_states.dtype)

def get_3d_rotary_pos_embed(
    embed_dim,
    crops_coords,
    grid_size,
    temporal_size,
    theta: int = 10000,
    use_real: bool = True,
    grid_type: str = "linspace",
    max_size: Optional[Tuple[int, int]] = None,
    device: Optional[torch.device] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    RoPE for video tokens with 3D structure.

    Args:
        embed_dim: (`int`):
            The embedding dimension size, corresponding to hidden_size_head.
        crops_coords (`Tuple[int]`):
            The top-left and bottom-right coordinates of the crop.
        grid_size (`Tuple[int]`):
            The grid size of the spatial positional embedding (height, width).
        temporal_size (`int`):
            The size of the temporal dimension.
        theta (`float`):
            Scaling factor for frequency computation.
        grid_type (`str`):
            Whether to use "linspace" or "slice" to compute grids.

    Returns:
        `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
    """
    if use_real is not True:
        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")

    if grid_type == "linspace":
        start, stop = crops_coords
        grid_size_h, grid_size_w = grid_size
        grid_h = torch.linspace(
            start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
        )
        grid_w = torch.linspace(
            start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
        )
        grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
        grid_t = torch.linspace(
            0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
        )
    elif grid_type == "slice":
        max_h, max_w = max_size
        grid_size_h, grid_size_w = grid_size
        grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
        grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
        grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
    else:
        raise ValueError("Invalid value passed for `grid_type`.")

    # Compute dimensions for each axis
    dim_t = embed_dim // 4
    dim_h = embed_dim // 8 * 3
    dim_w = embed_dim // 8 * 3

    # Temporal frequencies
    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
    # Spatial frequencies for height and width
    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)

    # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
        freqs_t = freqs_t[:, None, None, :].expand(
            -1, grid_size_h, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_t
        freqs_h = freqs_h[None, :, None, :].expand(
            temporal_size, -1, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_2, dim_h
        freqs_w = freqs_w[None, None, :, :].expand(
            temporal_size, grid_size_h, -1, -1
        )  # temporal_size, grid_size_h, grid_size_2, dim_w

        freqs = torch.cat(
            [freqs_t, freqs_h, freqs_w], dim=-1
        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
        freqs = freqs.view(
            temporal_size * grid_size_h * grid_size_w, -1
        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
        return freqs

    t_cos, t_sin = freqs_t  # both t_cos and t_sin has shape: temporal_size, dim_t
    h_cos, h_sin = freqs_h  # both h_cos and h_sin has shape: grid_size_h, dim_h
    w_cos, w_sin = freqs_w  # both w_cos and w_sin has shape: grid_size_w, dim_w

    if grid_type == "slice":
        t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
        h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
        w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]

    cos = combine_time_height_width(t_cos, h_cos, w_cos)
    sin = combine_time_height_width(t_sin, h_sin, w_sin)
    return cos, sin

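A hedged sketch of how the 3D RoPE table is typically built (not part of the diff; the grid and crop values are illustrative). With embed_dim=64 the per-axis splits are 16 temporal and 24 for each spatial axis, which sum back to the head size:

import torch
from diffusers.models.embeddings import get_3d_rotary_pos_embed

cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,
    crops_coords=((0, 0), (480, 720)),  # top-left and bottom-right of the crop
    grid_size=(30, 45),                 # (height, width) of the latent patch grid
    temporal_size=13,
    use_real=True,
    grid_type="linspace",
)
print(cos.shape, sin.shape)  # each torch.Size([17550, 64]): 13 * 30 * 45 tokens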
def get_3d_rotary_pos_embed_allegro(
    embed_dim,
    crops_coords,
    grid_size,
    temporal_size,
    interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
    theta: int = 10000,
    device: Optional[torch.device] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # TODO(aryan): docs
    start, stop = crops_coords
    grid_size_h, grid_size_w = grid_size
    interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
    grid_t = torch.linspace(
        0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
    )
    grid_h = torch.linspace(
        start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
    )
    grid_w = torch.linspace(
        start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
    )

    # Compute dimensions for each axis
    dim_t = embed_dim // 3
    dim_h = embed_dim // 3
    dim_w = embed_dim // 3

    # Temporal frequencies
    freqs_t = get_1d_rotary_pos_embed(
        dim_t, grid_t / interpolation_scale_t, theta=theta, use_real=True, repeat_interleave_real=False
    )
    # Spatial frequencies for height and width
    freqs_h = get_1d_rotary_pos_embed(
        dim_h, grid_h / interpolation_scale_h, theta=theta, use_real=True, repeat_interleave_real=False
    )
    freqs_w = get_1d_rotary_pos_embed(
        dim_w, grid_w / interpolation_scale_w, theta=theta, use_real=True, repeat_interleave_real=False
    )

    return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w

def get_2d_rotary_pos_embed(
    embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None, output_type: str = "np"
):
    """
    RoPE for image tokens with 2d structure.

    Args:
        embed_dim: (`int`):
            The embedding dimension size
        crops_coords (`Tuple[int]`)
            The top-left and bottom-right coordinates of the crop.
        grid_size (`Tuple[int]`):
            The grid size of the positional embedding.
        use_real (`bool`):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        device: (`torch.device`, **optional**):
            The device used to create tensors.

    Returns:
        `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
    """
    if output_type == "np":
        deprecation_message = (
            "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
            " `from_numpy` is no longer required."
            " Pass `output_type='pt' to use the new version now."
        )
        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
        return _get_2d_rotary_pos_embed_np(
            embed_dim=embed_dim,
            crops_coords=crops_coords,
            grid_size=grid_size,
            use_real=use_real,
        )
    start, stop = crops_coords
    # scale end by (steps−1)/steps matches np.linspace(..., endpoint=False)
    grid_h = torch.linspace(
        start[0], stop[0] * (grid_size[0] - 1) / grid_size[0], grid_size[0], device=device, dtype=torch.float32
    )
    grid_w = torch.linspace(
        start[1], stop[1] * (grid_size[1] - 1) / grid_size[1], grid_size[1], device=device, dtype=torch.float32
    )
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
    grid = torch.stack(grid, dim=0)  # [2, W, H]

    grid = grid.reshape([2, 1, *grid.shape[1:]])
    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
    return pos_embed


def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=True):
    """
    RoPE for image tokens with 2d structure.

    Args:
        embed_dim: (`int`):
            The embedding dimension size
        crops_coords (`Tuple[int]`)
            The top-left and bottom-right coordinates of the crop.
        grid_size (`Tuple[int]`):
            The grid size of the positional embedding.
        use_real (`bool`):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.

    Returns:
        `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
    """
    start, stop = crops_coords
    grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
    grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)  # [2, W, H]

    grid = grid.reshape([2, 1, *grid.shape[1:]])
    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
    return pos_embed

def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
    """
    Get 2D RoPE from grid.

    Args:
        embed_dim: (`int`):
            The embedding dimension size, corresponding to hidden_size_head.
        grid (`np.ndarray`):
            The grid of the positional embedding.
        use_real (`bool`):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.

    Returns:
        `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
    """
    assert embed_dim % 4 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_rotary_pos_embed(
        embed_dim // 2, grid[0].reshape(-1), use_real=use_real
    )  # (H*W, D/2) if use_real else (H*W, D/4)
    emb_w = get_1d_rotary_pos_embed(
        embed_dim // 2, grid[1].reshape(-1), use_real=use_real
    )  # (H*W, D/2) if use_real else (H*W, D/4)

    if use_real:
        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D)
        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D)
        return cos, sin
    else:
        emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
        return emb


def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
    """
    Get 2D RoPE from grid.

    Args:
        embed_dim: (`int`):
            The embedding dimension size, corresponding to hidden_size_head.
        grid (`np.ndarray`):
            The grid of the positional embedding.
        linear_factor (`float`):
            The linear factor of the positional embedding, which is used to scale the positional embedding in the
            linear layer.
        ntk_factor (`float`):
            The ntk factor of the positional embedding, which is used to scale the positional embedding in the ntk
            layer.

    Returns:
        `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
    """
    assert embed_dim % 4 == 0

    emb_h = get_1d_rotary_pos_embed(
        embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
    )  # (H, D/4)
    emb_w = get_1d_rotary_pos_embed(
        embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
    )  # (W, D/4)
    emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1)  # (H, W, D/4, 1)
    emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1)  # (H, W, D/4, 1)

    emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2)  # (H, W, D/2)
    return emb

def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
    data type.

    Args:
        dim (`int`): Dimension of the frequency tensor.
        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (`bool`, *optional*):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        linear_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the context extrapolation. Defaults to 1.0.
        ntk_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
            Otherwise, they are concateanted with themselves.
        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
            the dtype of the frequency tensor.
    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore  # [S]

    theta = theta * ntk_factor
    freqs = (
        1.0
        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
        / linear_factor
    )  # [D/2]
    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
    if use_real and repeat_interleave_real:
        # flux, hunyuan-dit, cogvideox
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        # stable audio, allegro
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
        return freqs_cis

def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
    tensors contain rotary embeddings and are returned as real tensors.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # used for lumina
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)


def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
    # TODO(aryan): rewrite
    def apply_1d_rope(tokens, pos, cos, sin):
        cos = F.embedding(pos, cos)[:, None, :, :]
        sin = F.embedding(pos, sin)[:, None, :, :]
        x1, x2 = tokens[..., : tokens.shape[-1] // 2], tokens[..., tokens.shape[-1] // 2 :]
        tokens_rotated = torch.cat((-x2, x1), dim=-1)
        return (tokens.float() * cos + tokens_rotated.float() * sin).to(tokens.dtype)

    (t_cos, t_sin), (h_cos, h_sin), (w_cos, w_sin) = freqs_cis
    t, h, w = x.chunk(3, dim=-1)
    t = apply_1d_rope(t, positions[0], t_cos, t_sin)
    h = apply_1d_rope(h, positions[1], h_cos, h_sin)
    w = apply_1d_rope(w, positions[2], w_cos, w_sin)
    x = torch.cat([t, h, w], dim=-1)
    return x

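A hedged sketch pairing the 1D table builder with apply_rotary_emb (not part of the diff; sizes are illustrative). With use_real=True and the default use_real_unbind_dim=-1 the expected input layout is (batch, heads, seq, head_dim):

import torch
from diffusers.models.embeddings import apply_rotary_emb, get_1d_rotary_pos_embed

head_dim, seq_len = 64, 128
# Interleaved cos/sin tables, each of shape (seq_len, head_dim)
cos, sin = get_1d_rotary_pos_embed(head_dim, seq_len, use_real=True)

q = torch.randn(2, 8, seq_len, head_dim)  # (batch, heads, seq, head_dim)
q_rotated = apply_rotary_emb(q, (cos, sin), use_real=True)
print(q_rotated.shape)  # torch.Size([2, 8, 128, 64])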
class FluxPosEmbed(nn.Module):
    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.float()
        is_mps = ids.device.type == "mps"
        freqs_dtype = torch.float32 if is_mps else torch.float64
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin

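A hedged sketch of FluxPosEmbed (not part of the diff; the axes split and the 4x4 token grid are illustrative). Each column of ids holds one coordinate axis, and the per-axis cos/sin tables are concatenated along the last dimension:

import torch
from diffusers.models.embeddings import FluxPosEmbed

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])  # sums to a 128-dim head

# One (t, y, x) row per token, here a 4x4 latent grid with no extra text tokens.
ys, xs = torch.meshgrid(torch.arange(4), torch.arange(4), indexing="ij")
ids = torch.stack([torch.zeros(16), ys.flatten().float(), xs.flatten().float()], dim=-1)

cos, sin = pos_embed(ids)
print(cos.shape, sin.shape)  # each torch.Size([16, 128])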
class TimestepEmbedding(nn.Module):

@@ -199,9 +1277,8 @@ class TimestepEmbedding(nn.Module):
         sample_proj_bias=True,
     ):
         super().__init__()
-        linear_cls = nn.Linear

-        self.linear_1 =
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

         if cond_proj_dim is not None:
             self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)

@@ -214,7 +1291,7 @@ class TimestepEmbedding(nn.Module):
             time_embed_dim_out = out_dim
         else:
             time_embed_dim_out = time_embed_dim
-        self.linear_2 =
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

         if post_act_fn is None:
             self.post_act = None
@@ -237,11 +1314,12 @@ class Timesteps(nn.Module):


 class Timesteps(nn.Module):
-    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
         super().__init__()
         self.num_channels = num_channels
         self.flip_sin_to_cos = flip_sin_to_cos
         self.downscale_freq_shift = downscale_freq_shift
+        self.scale = scale

     def forward(self, timesteps):
         t_emb = get_timestep_embedding(
@@ -249,6 +1327,7 @@ class Timesteps(nn.Module):
             self.num_channels,
             flip_sin_to_cos=self.flip_sin_to_cos,
             downscale_freq_shift=self.downscale_freq_shift,
+            scale=self.scale,
         )
         return t_emb

@@ -266,9 +1345,10 @@ class GaussianFourierProjection(nn.Module):

         if set_W_to_weight:
             # to delete later
+            del self.weight
             self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
-
             self.weight = self.W
+            del self.W

     def forward(self, x):
         if self.log:
@@ -392,106 +1472,368 @@ class LabelEmbedding(nn.Module):
         self.num_classes = num_classes
         self.dropout_prob = dropout_prob

-    def token_drop(self, labels, force_drop_ids=None):
-        """
-        Drops labels to enable classifier-free guidance.
-        """
-        if force_drop_ids is None:
-            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
-        else:
-            drop_ids = torch.tensor(force_drop_ids == 1)
-        labels = torch.where(drop_ids, self.num_classes, labels)
-        return labels
+    def token_drop(self, labels, force_drop_ids=None):
+        """
+        Drops labels to enable classifier-free guidance.
+        """
+        if force_drop_ids is None:
+            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+        else:
+            drop_ids = torch.tensor(force_drop_ids == 1)
+        labels = torch.where(drop_ids, self.num_classes, labels)
+        return labels
+
+    def forward(self, labels: torch.LongTensor, force_drop_ids=None):
+        use_dropout = self.dropout_prob > 0
+        if (self.training and use_dropout) or (force_drop_ids is not None):
+            labels = self.token_drop(labels, force_drop_ids)
+        embeddings = self.embedding_table(labels)
+        return embeddings
+
+
+class TextImageProjection(nn.Module):
+    def __init__(
+        self,
+        text_embed_dim: int = 1024,
+        image_embed_dim: int = 768,
+        cross_attention_dim: int = 768,
+        num_image_text_embeds: int = 10,
+    ):
+        super().__init__()
+
+        self.num_image_text_embeds = num_image_text_embeds
+        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
+        self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
+
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+        batch_size = text_embeds.shape[0]
+
+        # image
+        image_text_embeds = self.image_embeds(image_embeds)
+        image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
+
+        # text
+        text_embeds = self.text_proj(text_embeds)
+
+        return torch.cat([image_text_embeds, text_embeds], dim=1)
+
+
+class ImageProjection(nn.Module):
+    def __init__(
+        self,
+        image_embed_dim: int = 768,
+        cross_attention_dim: int = 768,
+        num_image_text_embeds: int = 32,
+    ):
+        super().__init__()
+
+        self.num_image_text_embeds = num_image_text_embeds
+        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
+        self.norm = nn.LayerNorm(cross_attention_dim)
+
+    def forward(self, image_embeds: torch.Tensor):
+        batch_size = image_embeds.shape[0]
+
+        # image
+        image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype))
+        image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
+        image_embeds = self.norm(image_embeds)
+        return image_embeds
+
+
+class IPAdapterFullImageProjection(nn.Module):
+    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
+        super().__init__()
+        from .attention import FeedForward
+
+        self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
+        self.norm = nn.LayerNorm(cross_attention_dim)
+
+    def forward(self, image_embeds: torch.Tensor):
+        return self.norm(self.ff(image_embeds))
+
+
+class IPAdapterFaceIDImageProjection(nn.Module):
+    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
+        super().__init__()
+        from .attention import FeedForward
+
+        self.num_tokens = num_tokens
+        self.cross_attention_dim = cross_attention_dim
+        self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
+        self.norm = nn.LayerNorm(cross_attention_dim)
+
+    def forward(self, image_embeds: torch.Tensor):
+        x = self.ff(image_embeds)
+        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+        return self.norm(x)
+
+
+class CombinedTimestepLabelEmbeddings(nn.Module):
+    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
+
+    def forward(self, timestep, class_labels, hidden_dtype=None):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
+
+        class_labels = self.class_embedder(class_labels)  # (N, D)
+
+        conditioning = timesteps_emb + class_labels  # (N, D)
+
+        return conditioning
+
+
+class CombinedTimestepTextProjEmbeddings(nn.Module):
+    def __init__(self, embedding_dim, pooled_projection_dim):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(self, timestep, pooled_projection):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+        pooled_projections = self.text_embedder(pooled_projection)
+
+        conditioning = timesteps_emb + pooled_projections
+
+        return conditioning
+
+
+class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
+    def __init__(self, embedding_dim, pooled_projection_dim):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(self, timestep, guidance, pooled_projection):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+        guidance_proj = self.time_proj(guidance)
+        guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+        time_guidance_emb = timesteps_emb + guidance_emb
+
+        pooled_projections = self.text_embedder(pooled_projection)
+        conditioning = time_guidance_emb + pooled_projections
+
+        return conditioning

-            labels = self.token_drop(labels, force_drop_ids)
-        embeddings = self.embedding_table(labels)
-        return embeddings
+class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
+    def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
+        super().__init__()

+        self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
+        self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

-    def __init__(
+    def forward(
         self,
+        timestep: torch.Tensor,
+        original_size: torch.Tensor,
+        target_size: torch.Tensor,
+        crop_coords: torch.Tensor,
+        hidden_dtype: torch.dtype,
+    ) -> torch.Tensor:
+        timesteps_proj = self.time_proj(timestep)

-        self.
+        original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1)
+        crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
+        target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)

+        # (B, 3 * condition_dim)
+        condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1)

-        #
-        image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)
+        condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)

+        conditioning = timesteps_emb + condition_emb
+        return conditioning

-        return torch.cat([image_text_embeds, text_embeds], dim=1)

+class HunyuanDiTAttentionPool(nn.Module):
+    # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6

+    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+        self.num_heads = num_heads
+
+    def forward(self, x):
+        x = x.permute(1, 0, 2)  # NLC -> LNC
+        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
+        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
+        x, _ = F.multi_head_attention_forward(
+            query=x[:1],
+            key=x,
+            value=x,
+            embed_dim_to_check=x.shape[-1],
+            num_heads=self.num_heads,
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+            in_proj_weight=None,
+            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+            bias_k=None,
+            bias_v=None,
+            add_zero_attn=False,
+            dropout_p=0,
+            out_proj_weight=self.c_proj.weight,
+            out_proj_bias=self.c_proj.bias,
+            use_separate_proj_weight=True,
+            training=self.training,
+            need_weights=False,
+        )
+        return x.squeeze(0)
+
+
+class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
     def __init__(
         self,
+        embedding_dim,
+        pooled_projection_dim=1024,
+        seq_len=256,
+        cross_attention_dim=2048,
+        use_style_cond_and_image_meta_size=True,
     ):
         super().__init__()

-        self.
-        self.
-        self.norm = nn.LayerNorm(cross_attention_dim)
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

-        batch_size = image_embeds.shape[0]
+        self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)

-        image_embeds = self.norm(image_embeds)
-        return image_embeds
+        self.pooler = HunyuanDiTAttentionPool(
+            seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
+        )

+        # Here we use a default learned embedder layer for future extension.
+        self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
+        if use_style_cond_and_image_meta_size:
+            self.style_embedder = nn.Embedding(1, embedding_dim)
+            extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        else:
+            extra_in_dim = pooled_projection_dim

+        self.extra_embedder = PixArtAlphaTextProjection(
+            in_features=extra_in_dim,
+            hidden_size=embedding_dim * 4,
+            out_features=embedding_dim,
+            act_fn="silu_fp32",
+        )

+    def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, 256)

+        # extra condition1: text
+        pooled_projections = self.pooler(encoder_hidden_states)  # (N, 1024)

+        if self.use_style_cond_and_image_meta_size:
+            # extra condition2: image meta size embedding
+            image_meta_size = self.size_proj(image_meta_size.view(-1))
+            image_meta_size = image_meta_size.to(dtype=hidden_dtype)
+            image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)

+            # extra condition3: style embedding
+            style_embedding = self.style_embedder(style)  # (N, embedding_dim)
+
+            # Concatenate all extra vectors
+            extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+        else:
+            extra_cond = torch.cat([pooled_projections], dim=1)
+
+        conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # [B, D]
+
+        return conditioning
+
+
+class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
+    def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
         super().__init__()
+        self.time_proj = Timesteps(
+            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
+        )

-        self.
-        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
-        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
+        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)

+        self.caption_embedder = nn.Sequential(
+            nn.LayerNorm(cross_attention_dim),
+            nn.Linear(
+                cross_attention_dim,
+                hidden_size,
+                bias=True,
+            ),
+        )

+    def forward(self, timestep, caption_feat, caption_mask):
+        # timestep embedding:
+        time_freq = self.time_proj(timestep)
+        time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))

+        # caption condition embedding:
+        caption_mask_float = caption_mask.float().unsqueeze(-1)
+        caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1)
+        caption_feats_pool = caption_feats_pool.to(caption_feat)
+        caption_embed = self.caption_embedder(caption_feats_pool)
+
+        conditioning = time_embed + caption_embed

         return conditioning


+class MochiCombinedTimestepCaptionEmbedding(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        pooled_projection_dim: int,
+        text_embed_dim: int,
+        time_embed_dim: int = 256,
+        num_attention_heads: int = 8,
+    ) -> None:
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim)
+        self.pooler = MochiAttentionPool(
+            num_attention_heads=num_attention_heads, embed_dim=text_embed_dim, output_dim=embedding_dim
+        )
+        self.caption_proj = nn.Linear(text_embed_dim, pooled_projection_dim)
+
+    def forward(
+        self,
+        timestep: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_attention_mask: torch.Tensor,
+        hidden_dtype: Optional[torch.dtype] = None,
+    ):
+        time_proj = self.time_proj(timestep)
+        time_emb = self.timestep_embedder(time_proj.to(dtype=hidden_dtype))
+
+        pooled_projections = self.pooler(encoder_hidden_states, encoder_attention_mask)
+        caption_proj = self.caption_proj(encoder_hidden_states)
+
+        conditioning = time_emb + pooled_projections
+        return conditioning, caption_proj
+
+
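A hedged sketch of the new combined timestep plus pooled-text conditioning module (not part of the diff; the dimensions below are illustrative):

import torch
from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings

cond_embed = CombinedTimestepTextProjEmbeddings(embedding_dim=64, pooled_projection_dim=32)

timestep = torch.tensor([999, 500])   # one timestep per batch element
pooled_text = torch.randn(2, 32)      # pooled text-encoder output
conditioning = cond_embed(timestep, pooled_text)
print(conditioning.shape)  # torch.Size([2, 64])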
 class TextTimeEmbedding(nn.Module):
     def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
         super().__init__()

@@ -515,7 +1857,7 @@ class TextImageTimeEmbedding(nn.Module):
         self.text_norm = nn.LayerNorm(time_embed_dim)
         self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)

-    def forward(self, text_embeds: torch.
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         # text
         time_text_embeds = self.text_proj(text_embeds)
         time_text_embeds = self.text_norm(time_text_embeds)

@@ -532,7 +1874,7 @@ class ImageTimeEmbedding(nn.Module):
         self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
         self.image_norm = nn.LayerNorm(time_embed_dim)

-    def forward(self, image_embeds: torch.
+    def forward(self, image_embeds: torch.Tensor):
         # image
         time_image_embeds = self.image_proj(image_embeds)
         time_image_embeds = self.image_norm(time_image_embeds)

@@ -562,7 +1904,7 @@ class ImageHintTimeEmbedding(nn.Module):
             nn.Conv2d(256, 4, 3, padding=1),
         )

-    def forward(self, image_embeds: torch.
+    def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
         # image
         time_image_embeds = self.image_proj(image_embeds)
         time_image_embeds = self.image_norm(time_image_embeds)
@@ -620,6 +1962,88 @@ class AttentionPooling(nn.Module):
         return a[:, 0, :]  # cls_token


+class MochiAttentionPool(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        embed_dim: int,
+        output_dim: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+
+        self.output_dim = output_dim or embed_dim
+        self.num_attention_heads = num_attention_heads
+
+        self.to_kv = nn.Linear(embed_dim, 2 * embed_dim)
+        self.to_q = nn.Linear(embed_dim, embed_dim)
+        self.to_out = nn.Linear(embed_dim, self.output_dim)
+
+    @staticmethod
+    def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
+        """
+        Pool tokens in x using mask.
+
+        NOTE: We assume x does not require gradients.
+
+        Args:
+            x: (B, L, D) tensor of tokens.
+            mask: (B, L) boolean tensor indicating which tokens are not padding.
+
+        Returns:
+            pooled: (B, D) tensor of pooled tokens.
+        """
+        assert x.size(1) == mask.size(1)  # Expected mask to have same length as tokens.
+        assert x.size(0) == mask.size(0)  # Expected mask to have same batch size as tokens.
+        mask = mask[:, :, None].to(dtype=x.dtype)
+        mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
+        pooled = (x * mask).sum(dim=1, keepdim=keepdim)
+        return pooled
+
+    def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+        r"""
+        Args:
+            x (`torch.Tensor`):
+                Tensor of shape `(B, S, D)` of input tokens.
+            mask (`torch.Tensor`):
+                Boolean ensor of shape `(B, S)` indicating which tokens are not padding.
+
+        Returns:
+            `torch.Tensor`:
+                `(B, D)` tensor of pooled tokens.
+        """
+        D = x.size(2)
+
+        # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
+        attn_mask = mask[:, None, None, :].bool()  # (B, 1, 1, L).
+        attn_mask = F.pad(attn_mask, (1, 0), value=True)  # (B, 1, 1, 1+L).
+
+        # Average non-padding token features. These will be used as the query.
+        x_pool = self.pool_tokens(x, mask, keepdim=True)  # (B, 1, D)
+
+        # Concat pooled features to input sequence.
+        x = torch.cat([x_pool, x], dim=1)  # (B, L+1, D)
+
+        # Compute queries, keys, values. Only the mean token is used to create a query.
+        kv = self.to_kv(x)  # (B, L+1, 2 * D)
+        q = self.to_q(x[:, 0])  # (B, D)
+
+        # Extract heads.
+        head_dim = D // self.num_attention_heads
+        kv = kv.unflatten(2, (2, self.num_attention_heads, head_dim))  # (B, 1+L, 2, H, head_dim)
+        kv = kv.transpose(1, 3)  # (B, H, 2, 1+L, head_dim)
+        k, v = kv.unbind(2)  # (B, H, 1+L, head_dim)
+        q = q.unflatten(1, (self.num_attention_heads, head_dim))  # (B, H, head_dim)
+        q = q.unsqueeze(2)  # (B, H, 1, head_dim)
+
+        # Compute attention.
+        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)  # (B, H, 1, head_dim)
+
+        # Concatenate heads and run output.
+        x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
+        x = self.to_out(x)
+        return x
+
+
 def get_fourier_embeds_from_boundingbox(embed_dim, box):
     """
     Args:
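A hedged sketch of the masked pooling performed by MochiAttentionPool (not part of the diff; sizes are illustrative). Padding positions are excluded both from the mean-token query and from the attention keys:

import torch
from diffusers.models.embeddings import MochiAttentionPool

pool = MochiAttentionPool(num_attention_heads=4, embed_dim=64, output_dim=32)

tokens = torch.randn(2, 10, 64)              # (B, L, D)
mask = torch.zeros(2, 10, dtype=torch.bool)
mask[0, :7] = True                           # first sample: 7 real tokens
mask[1, :3] = True                           # second sample: 3 real tokens

pooled = pool(tokens, mask)
print(pooled.shape)  # torch.Size([2, 32])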
@@ -714,7 +2138,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):

             objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))

-        # positionet with text and image
+        # positionet with text and image information
         else:
             phrases_masks = phrases_masks.unsqueeze(-1)
             image_masks = image_masks.unsqueeze(-1)
@@ -778,11 +2202,20 @@ class PixArtAlphaTextProjection(nn.Module):
     Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
     """

-    def __init__(self, in_features, hidden_size,
+    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
         super().__init__()
+        if out_features is None:
+            out_features = hidden_size
         self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
+        if act_fn == "gelu_tanh":
+            self.act_1 = nn.GELU(approximate="tanh")
+        elif act_fn == "silu":
+            self.act_1 = nn.SiLU()
+        elif act_fn == "silu_fp32":
+            self.act_1 = FP32SiLU()
+        else:
+            raise ValueError(f"Unknown activation function: {act_fn}")
+        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)

     def forward(self, caption):
         hidden_states = self.linear_1(caption)
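A hedged sketch of the widened PixArtAlphaTextProjection constructor (not part of the diff; sizes are illustrative). Leaving out_features as None keeps the output width equal to hidden_size, and act_fn selects between GELU-tanh, SiLU, and the FP32 SiLU variant:

import torch
from diffusers.models.embeddings import PixArtAlphaTextProjection

proj = PixArtAlphaTextProjection(in_features=32, hidden_size=64, act_fn="silu_fp32")
caption = torch.randn(2, 77, 32)
print(proj(caption).shape)  # torch.Size([2, 77, 64])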
@@ -791,21 +2224,52 @@ class PixArtAlphaTextProjection(nn.Module):
         return hidden_states


+class IPAdapterPlusImageProjectionBlock(nn.Module):
+    def __init__(
+        self,
+        embed_dims: int = 768,
+        dim_head: int = 64,
+        heads: int = 16,
+        ffn_ratio: float = 4,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.ln0 = nn.LayerNorm(embed_dims)
+        self.ln1 = nn.LayerNorm(embed_dims)
+        self.attn = Attention(
+            query_dim=embed_dims,
+            dim_head=dim_head,
+            heads=heads,
+            out_bias=False,
+        )
+        self.ff = nn.Sequential(
+            nn.LayerNorm(embed_dims),
+            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
+        )
+
+    def forward(self, x, latents, residual):
+        encoder_hidden_states = self.ln0(x)
+        latents = self.ln1(latents)
+        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
+        latents = self.attn(latents, encoder_hidden_states) + residual
+        latents = self.ff(latents) + latents
+        return latents
+
+
 class IPAdapterPlusImageProjection(nn.Module):
     """Resampler of IP-Adapter Plus.

     Args:
-        num_queries (int): The number of queries. Defaults to 8.
-        ffn_ratio (float): The expansion ratio of feedforward network hidden
+        embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
+        that is the same
+            number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
+        hidden_dims (int):
+            The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
+        to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
+        Defaults to 16. num_queries (int):
+            The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio
+        of feedforward network hidden
             layer channels. Defaults to 4.
     """

@@ -821,8 +2285,6 @@ class IPAdapterPlusImageProjection(nn.Module):
         ffn_ratio: float = 4,
     ) -> None:
         super().__init__()
-        from .attention import FeedForward  # Lazy import to avoid circular import
-
         self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)

         self.proj_in = nn.Linear(embed_dims, hidden_dims)
@@ -830,61 +2292,297 @@ class IPAdapterPlusImageProjection(nn.Module):
|
|
830
2292
|
self.proj_out = nn.Linear(hidden_dims, output_dims)
|
831
2293
|
self.norm_out = nn.LayerNorm(output_dims)
|
832
2294
|
|
833
|
-
self.layers = nn.ModuleList(
|
834
|
-
|
835
|
-
|
836
|
-
nn.ModuleList(
|
837
|
-
[
|
838
|
-
nn.LayerNorm(hidden_dims),
|
839
|
-
nn.LayerNorm(hidden_dims),
|
840
|
-
Attention(
|
841
|
-
query_dim=hidden_dims,
|
842
|
-
dim_head=dim_head,
|
843
|
-
heads=heads,
|
844
|
-
out_bias=False,
|
845
|
-
),
|
846
|
-
nn.Sequential(
|
847
|
-
nn.LayerNorm(hidden_dims),
|
848
|
-
FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
|
849
|
-
),
|
850
|
-
]
|
851
|
-
)
|
852
|
-
)
|
2295
|
+
self.layers = nn.ModuleList(
|
2296
|
+
[IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
|
2297
|
+
)
|
853
2298
|
|
854
2299
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
855
2300
|
"""Forward pass.
|
856
2301
|
|
857
2302
|
Args:
|
858
|
-
----
|
859
2303
|
x (torch.Tensor): Input Tensor.
|
860
|
-
|
861
2304
|
Returns:
|
862
|
-
-------
|
863
2305
|
torch.Tensor: Output Tensor.
|
864
2306
|
"""
|
865
2307
|
latents = self.latents.repeat(x.size(0), 1, 1)
|
866
2308
|
|
867
2309
|
x = self.proj_in(x)
|
868
2310
|
|
869
|
-
for
|
2311
|
+
for block in self.layers:
|
870
2312
|
residual = latents
|
871
|
-
|
872
|
-
encoder_hidden_states = ln0(x)
|
873
|
-
latents = ln1(latents)
|
874
|
-
encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
|
875
|
-
latents = attn(latents, encoder_hidden_states) + residual
|
876
|
-
latents = ff(latents) + latents
|
2313
|
+
latents = block(x, latents, residual)
|
877
2314
|
|
878
2315
|
latents = self.proj_out(latents)
|
879
2316
|
return self.norm_out(latents)
|
880
2317
|
|
881
2318
|
|
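With the block extracted, `IPAdapterPlusImageProjection` becomes a thin wrapper: `proj_in` lifts the image features to `hidden_dims`, `depth` stacked blocks refine the learned queries, and `proj_out`/`norm_out` map them to the cross-attention width. A hedged usage sketch (constructor values and shapes are assumptions for illustration; defaults follow the docstring above):

```python
import torch
from diffusers.models.embeddings import IPAdapterPlusImageProjection

resampler = IPAdapterPlusImageProjection(
    embed_dims=1280,    # width of the incoming image features (assumed)
    output_dims=1024,   # typically unet.config.cross_attention_dim (assumed)
    hidden_dims=1280,
    depth=4,
    num_queries=16,
)
image_embeds = torch.randn(2, 257, 1280)  # per-patch image features, assumed shape
ip_tokens = resampler(image_embeds)
print(ip_tokens.shape)                    # torch.Size([2, 16, 1024])
```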
+class IPAdapterFaceIDPlusImageProjection(nn.Module):
+    """FacePerceiverResampler of IP-Adapter Plus.
+
+    Args:
+        embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
+        that is the same
+            number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
+        hidden_dims (int):
+            The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
+        to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
+        Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8.
+        ffn_ratio (float): The expansion ratio of feedforward network hidden
+            layer channels. Defaults to 4.
+        ffproj_ratio (float): The expansion ratio of feedforward network hidden
+            layer channels (for ID embeddings). Defaults to 4.
+    """
+
+    def __init__(
+        self,
+        embed_dims: int = 768,
+        output_dims: int = 768,
+        hidden_dims: int = 1280,
+        id_embeddings_dim: int = 512,
+        depth: int = 4,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_tokens: int = 4,
+        num_queries: int = 8,
+        ffn_ratio: float = 4,
+        ffproj_ratio: int = 2,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.num_tokens = num_tokens
+        self.embed_dim = embed_dims
+        self.clip_embeds = None
+        self.shortcut = False
+        self.shortcut_scale = 1.0
+
+        self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
+        self.norm = nn.LayerNorm(embed_dims)
+
+        self.proj_in = nn.Linear(hidden_dims, embed_dims)
+
+        self.proj_out = nn.Linear(embed_dims, output_dims)
+        self.norm_out = nn.LayerNorm(output_dims)
+
+        self.layers = nn.ModuleList(
+            [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
+        )
+
+    def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+            id_embeds (torch.Tensor): Input Tensor (ID embeds).
+        Returns:
+            torch.Tensor: Output Tensor.
+        """
+        id_embeds = id_embeds.to(self.clip_embeds.dtype)
+        id_embeds = self.proj(id_embeds)
+        id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
+        id_embeds = self.norm(id_embeds)
+        latents = id_embeds
+
+        clip_embeds = self.proj_in(self.clip_embeds)
+        x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])
+
+        for block in self.layers:
+            residual = latents
+            latents = block(x, latents, residual)
+
+        latents = self.proj_out(latents)
+        out = self.norm_out(latents)
+        if self.shortcut:
+            out = id_embeds + self.shortcut_scale * out
+        return out
+
+
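`IPAdapterFaceIDPlusImageProjection` first expands the face-ID embedding into `num_tokens` tokens via a feed-forward, then refines those tokens against CLIP features that the caller is expected to stash on `clip_embeds` before calling `forward`. A sketch under assumed shapes (none of the concrete values below come from the diff):

```python
import torch
from diffusers.models.embeddings import IPAdapterFaceIDPlusImageProjection

proj = IPAdapterFaceIDPlusImageProjection(
    embed_dims=768, output_dims=768, hidden_dims=1280, id_embeddings_dim=512, num_tokens=4
)
# [batch, num_images, seq_len, hidden_dims] -- must be set before calling forward
proj.clip_embeds = torch.randn(2, 1, 257, 1280)
id_embeds = torch.randn(2, 512)  # face-recognition embeddings, assumed shape
tokens = proj(id_embeds)
print(tokens.shape)              # torch.Size([2, 4, 768])
```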
+class IPAdapterTimeImageProjectionBlock(nn.Module):
+    """Block for IPAdapterTimeImageProjection.
+
+    Args:
+        hidden_dim (`int`, defaults to 1280):
+            The number of hidden channels.
+        dim_head (`int`, defaults to 64):
+            The number of head channels.
+        heads (`int`, defaults to 20):
+            Parallel attention heads.
+        ffn_ratio (`int`, defaults to 4):
+            The expansion ratio of feedforward network hidden layer channels.
+    """
+
+    def __init__(
+        self,
+        hidden_dim: int = 1280,
+        dim_head: int = 64,
+        heads: int = 20,
+        ffn_ratio: int = 4,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.ln0 = nn.LayerNorm(hidden_dim)
+        self.ln1 = nn.LayerNorm(hidden_dim)
+        self.attn = Attention(
+            query_dim=hidden_dim,
+            cross_attention_dim=hidden_dim,
+            dim_head=dim_head,
+            heads=heads,
+            bias=False,
+            out_bias=False,
+        )
+        self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False)
+
+        # AdaLayerNorm
+        self.adaln_silu = nn.SiLU()
+        self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
+        self.adaln_norm = nn.LayerNorm(hidden_dim)
+
+        # Set attention scale and fuse KV
+        self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head))
+        self.attn.fuse_projections()
+        self.attn.to_k = None
+        self.attn.to_v = None
+
+    def forward(self, x: torch.Tensor, latents: torch.Tensor, timestep_emb: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+            x (`torch.Tensor`):
+                Image features.
+            latents (`torch.Tensor`):
+                Latent features.
+            timestep_emb (`torch.Tensor`):
+                Timestep embedding.
+
+        Returns:
+            `torch.Tensor`: Output latent features.
+        """
+
+        # Shift and scale for AdaLayerNorm
+        emb = self.adaln_proj(self.adaln_silu(timestep_emb))
+        shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1)
+
+        # Fused Attention
+        residual = latents
+        x = self.ln0(x)
+        latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+
+        batch_size = latents.shape[0]
+
+        query = self.attn.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        key, value = self.attn.to_kv(kv_input).chunk(2, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // self.attn.heads
+
+        query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+
+        weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1)
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        latents = weight @ value
+
+        latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim)
+        latents = self.attn.to_out[0](latents)
+        latents = self.attn.to_out[1](latents)
+        latents = latents + residual
+
+        ## FeedForward
+        residual = latents
+        latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        return self.ff(latents) + residual
+
+
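A side note on the `1 / math.sqrt(math.sqrt(dim_head))` scale set on the fused attention above: applying that factor to both query and key is algebraically the same as the usual `1 / sqrt(dim_head)` softmax scaling, just split across the two operands (presumably to keep intermediate magnitudes smaller in reduced precision). A quick numerical check of that identity:

```python
import math
import torch

dim_head = 64
scale = 1 / math.sqrt(math.sqrt(dim_head))      # dim_head ** -0.25
q, k = torch.randn(8, dim_head), torch.randn(8, dim_head)

reference = (q @ k.T) / math.sqrt(dim_head)     # standard attention scaling
split = (q * scale) @ (k * scale).T             # scale applied to both operands
print(torch.allclose(reference, split, atol=1e-6))  # True
```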
+# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+class IPAdapterTimeImageProjection(nn.Module):
+    """Resampler of SD3 IP-Adapter with timestep embedding.
+
+    Args:
+        embed_dim (`int`, defaults to 1152):
+            The feature dimension.
+        output_dim (`int`, defaults to 2432):
+            The number of output channels.
+        hidden_dim (`int`, defaults to 1280):
+            The number of hidden channels.
+        depth (`int`, defaults to 4):
+            The number of blocks.
+        dim_head (`int`, defaults to 64):
+            The number of head channels.
+        heads (`int`, defaults to 20):
+            Parallel attention heads.
+        num_queries (`int`, defaults to 64):
+            The number of queries.
+        ffn_ratio (`int`, defaults to 4):
+            The expansion ratio of feedforward network hidden layer channels.
+        timestep_in_dim (`int`, defaults to 320):
+            The number of input channels for timestep embedding.
+        timestep_flip_sin_to_cos (`bool`, defaults to True):
+            Flip the timestep embedding order to `cos, sin` (if True) or `sin, cos` (if False).
+        timestep_freq_shift (`int`, defaults to 0):
+            Controls the timestep delta between frequencies between dimensions.
+    """
+
+    def __init__(
+        self,
+        embed_dim: int = 1152,
+        output_dim: int = 2432,
+        hidden_dim: int = 1280,
+        depth: int = 4,
+        dim_head: int = 64,
+        heads: int = 20,
+        num_queries: int = 64,
+        ffn_ratio: int = 4,
+        timestep_in_dim: int = 320,
+        timestep_flip_sin_to_cos: bool = True,
+        timestep_freq_shift: int = 0,
+    ) -> None:
+        super().__init__()
+        self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5)
+        self.proj_in = nn.Linear(embed_dim, hidden_dim)
+        self.proj_out = nn.Linear(hidden_dim, output_dim)
+        self.norm_out = nn.LayerNorm(output_dim)
+        self.layers = nn.ModuleList(
+            [IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)]
+        )
+        self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
+        self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")
+
+    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass.
+
+        Args:
+            x (`torch.Tensor`):
+                Image features.
+            timestep (`torch.Tensor`):
+                Timestep in denoising process.
+        Returns:
+            `Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb).
+        """
+        timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
+        timestep_emb = self.time_embedding(timestep_emb)
+
+        latents = self.latents.repeat(x.size(0), 1, 1)
+
+        x = self.proj_in(x)
+        x = x + timestep_emb[:, None]
+
+        for block in self.layers:
+            latents = block(x, latents, timestep_emb)
+
+        latents = self.proj_out(latents)
+        latents = self.norm_out(latents)
+
+        return latents, timestep_emb
+
+
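`IPAdapterTimeImageProjection` wires the same resampler pattern to a timestep embedding: the embedding is added to the projected image features and also drives the AdaLayerNorm shift/scale inside each block, and both the IP tokens and the timestep embedding are returned. A hedged call sketch (shapes and the timestep value are assumptions):

```python
import torch
from diffusers.models.embeddings import IPAdapterTimeImageProjection

proj = IPAdapterTimeImageProjection()        # defaults: embed_dim=1152, output_dim=2432, num_queries=64
image_features = torch.randn(2, 729, 1152)   # image-encoder patch features, assumed shape
timestep = torch.tensor([500, 500])          # one timestep per sample, assumed value
ip_tokens, timestep_emb = proj(image_features, timestep)
print(ip_tokens.shape, timestep_emb.shape)   # torch.Size([2, 64, 2432]) torch.Size([2, 1280])
```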
 class MultiIPAdapterImageProjection(nn.Module):
     def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
         super().__init__()
         self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
 
-    def forward(self, image_embeds: List[torch.
+    def forward(self, image_embeds: List[torch.Tensor]):
         projected_image_embeds = []
 
         # currently, we accept `image_embeds` as
@@ -893,7 +2591,7 @@ class MultiIPAdapterImageProjection(nn.Module):
         if not isinstance(image_embeds, list):
             deprecation_message = (
                 "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
-                " Please make sure to update your script to pass `image_embeds` as a list of tensors to
+                " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning."
             )
             deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
             image_embeds = [image_embeds.unsqueeze(1)]