diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
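The largest additions outside the model and pipeline modules are the new diffusers/quantizers package (bitsandbytes, GGUF and torchao backends plus quantization_config.py) and the new video/image backbones (Allegro, CogView3, HunyuanVideo, LTX, Mochi, Sana). As a rough sketch of what the quantizer files enable — assuming the bitsandbytes extra is installed and that BitsAndBytesConfig and the quantization_config argument to from_pretrained behave as in the diffusers 0.31+ quantization docs, since neither is shown in this diff — loading a transformer in 4-bit looks roughly like this:

    # Sketch only: BitsAndBytesConfig / quantization_config usage is assumed from the
    # diffusers quantization documentation, not from the files listed above.
    import torch
    from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

    quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,
    )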
diffusers/models/embeddings.py
CHANGED
@@ -84,15 +84,106 @@ def get_3d_sincos_pos_embed(
     temporal_size: int,
     spatial_interpolation_scale: float = 1.0,
     temporal_interpolation_scale: float = 1.0,
+    device: Optional[torch.device] = None,
+    output_type: str = "np",
+) -> torch.Tensor:
+    r"""
+    Creates 3D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension of inputs. It must be divisible by 16.
+        spatial_size (`int` or `Tuple[int, int]`):
+            The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
+            spatial dimensions (height and width).
+        temporal_size (`int`):
+            The temporal dimension of postional embeddings (number of frames).
+        spatial_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for spatial grid interpolation.
+        temporal_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for temporal grid interpolation.
+
+    Returns:
+        `torch.Tensor`:
+            The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+            embed_dim]`.
+    """
+    if output_type == "np":
+        return _get_3d_sincos_pos_embed_np(
+            embed_dim=embed_dim,
+            spatial_size=spatial_size,
+            temporal_size=temporal_size,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+        )
+    if embed_dim % 4 != 0:
+        raise ValueError("`embed_dim` must be divisible by 4")
+    if isinstance(spatial_size, int):
+        spatial_size = (spatial_size, spatial_size)
+
+    embed_dim_spatial = 3 * embed_dim // 4
+    embed_dim_temporal = embed_dim // 4
+
+    # 1. Spatial
+    grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
+    grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
+    grid = torch.stack(grid, dim=0)
+
+    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt")
+
+    # 2. Temporal
+    grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
+    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t, output_type="pt")
+
+    # 3. Concat
+    pos_embed_spatial = pos_embed_spatial[None, :, :]
+    pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0)  # [T, H*W, D // 4 * 3]
+
+    pos_embed_temporal = pos_embed_temporal[:, None, :]
+    pos_embed_temporal = pos_embed_temporal.repeat_interleave(
+        spatial_size[0] * spatial_size[1], dim=1
+    )  # [T, H*W, D // 4]
+
+    pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1)  # [T, H*W, D]
+    return pos_embed
+
+
+def _get_3d_sincos_pos_embed_np(
+    embed_dim: int,
+    spatial_size: Union[int, Tuple[int, int]],
+    temporal_size: int,
+    spatial_interpolation_scale: float = 1.0,
+    temporal_interpolation_scale: float = 1.0,
 ) -> np.ndarray:
     r"""
+    Creates 3D sinusoidal positional embeddings.
+
     Args:
         embed_dim (`int`):
+            The embedding dimension of inputs. It must be divisible by 16.
         spatial_size (`int` or `Tuple[int, int]`):
+            The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
+            spatial dimensions (height and width).
         temporal_size (`int`):
+            The temporal dimension of postional embeddings (number of frames).
         spatial_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for spatial grid interpolation.
         temporal_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for temporal grid interpolation.
+
+    Returns:
+        `np.ndarray`:
+            The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+            embed_dim]`.
     """
+    deprecation_message = (
+        "`get_3d_sincos_pos_embed` uses `torch` and supports `device`."
+        " `from_numpy` is no longer required."
+        " Pass `output_type='pt' to use the new version now."
+    )
+    deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
     if embed_dim % 4 != 0:
         raise ValueError("`embed_dim` must be divisible by 4")
     if isinstance(spatial_size, int):
@@ -126,11 +217,164 @@ def get_3d_sincos_pos_embed(
 
 
 def get_2d_sincos_pos_embed(
+    embed_dim,
+    grid_size,
+    cls_token=False,
+    extra_tokens=0,
+    interpolation_scale=1.0,
+    base_size=16,
+    device: Optional[torch.device] = None,
+    output_type: str = "np",
+):
+    """
+    Creates 2D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension.
+        grid_size (`int`):
+            The size of the grid height and width.
+        cls_token (`bool`, defaults to `False`):
+            Whether or not to add a classification token.
+        extra_tokens (`int`, defaults to `0`):
+            The number of extra tokens to add.
+        interpolation_scale (`float`, defaults to `1.0`):
+            The scale of the interpolation.
+
+    Returns:
+        pos_embed (`torch.Tensor`):
+            Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
+            embed_dim]` if using cls_token
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt' to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_2d_sincos_pos_embed_np(
+            embed_dim=embed_dim,
+            grid_size=grid_size,
+            cls_token=cls_token,
+            extra_tokens=extra_tokens,
+            interpolation_scale=interpolation_scale,
+            base_size=base_size,
+        )
+    if isinstance(grid_size, int):
+        grid_size = (grid_size, grid_size)
+
+    grid_h = (
+        torch.arange(grid_size[0], device=device, dtype=torch.float32)
+        / (grid_size[0] / base_size)
+        / interpolation_scale
+    )
+    grid_w = (
+        torch.arange(grid_size[1], device=device, dtype=torch.float32)
+        / (grid_size[1] / base_size)
+        / interpolation_scale
+    )
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
+    grid = torch.stack(grid, dim=0)
+
+    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type)
+    if cls_token and extra_tokens > 0:
+        pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
+    r"""
+    This function generates 2D sinusoidal positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension.
+        grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
+
+    Returns:
+        `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt' to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_2d_sincos_pos_embed_from_grid_np(
+            embed_dim=embed_dim,
+            grid=grid,
+        )
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type)  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type)  # (H*W, D/2)
+
+    emb = torch.concat([emb_h, emb_w], dim=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
+    """
+    This function generates 1D positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension `D`
+        pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
+
+    Returns:
+        `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt' to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = torch.outer(pos, omega)  # (M, D/2), outer product
+
+    emb_sin = torch.sin(out)  # (M, D/2)
+    emb_cos = torch.cos(out)  # (M, D/2)
+
+    emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)
+    return emb
+
+
+def get_2d_sincos_pos_embed_np(
     embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
 ):
     """
-
-
+    Creates 2D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension.
+        grid_size (`int`):
+            The size of the grid height and width.
+        cls_token (`bool`, defaults to `False`):
+            Whether or not to add a classification token.
+        extra_tokens (`int`, defaults to `0`):
+            The number of extra tokens to add.
+        interpolation_scale (`float`, defaults to `1.0`):
+            The scale of the interpolation.
+
+    Returns:
+        pos_embed (`np.ndarray`):
+            Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
+            embed_dim]` if using cls_token
     """
     if isinstance(grid_size, int):
         grid_size = (grid_size, grid_size)
@@ -141,27 +385,44 @@ def get_2d_sincos_pos_embed(
     grid = np.stack(grid, axis=0)
 
     grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
-    pos_embed =
+    pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid)
     if cls_token and extra_tokens > 0:
         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
     return pos_embed
 
 
-def
+def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid):
+    r"""
+    This function generates 2D sinusoidal positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension.
+        grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.
+
+    Returns:
+        `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
+    """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
 
     # use half of dimensions to encode grid_h
-    emb_h =
-    emb_w =
+    emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1])  # (H*W, D/2)
 
     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
     return emb
 
 
-def
+def get_1d_sincos_pos_embed_from_grid_np(embed_dim, pos):
     """
-
+    This function generates 1D positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension `D`
+        pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)`
+
+    Returns:
+        `numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
     """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
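The three sincos helpers above now build their grids with torch, accept device and output_type arguments, and keep the original NumPy implementations behind output_type="np" (deprecated and scheduled for removal in 0.33.0 per the deprecate() calls). A minimal migration sketch for downstream callers, using only the signatures shown in these hunks:

    import torch
    from diffusers.models.embeddings import get_2d_sincos_pos_embed

    # Old path: still works, but emits a deprecation warning and returns a NumPy array
    pos_np = get_2d_sincos_pos_embed(embed_dim=64, grid_size=16)

    # New path: the grid is created directly on `device`, no NumPy round-trip
    pos_pt = get_2d_sincos_pos_embed(
        embed_dim=64, grid_size=16, device=torch.device("cpu"), output_type="pt"
    )
    print(pos_pt.shape)  # torch.Size([256, 64]) -- [grid_size * grid_size, embed_dim]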
@@ -181,7 +442,22 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
 
 
 class PatchEmbed(nn.Module):
-    """
+    """
+    2D Image to Patch Embedding with support for SD3 cropping.
+
+    Args:
+        height (`int`, defaults to `224`): The height of the image.
+        width (`int`, defaults to `224`): The width of the image.
+        patch_size (`int`, defaults to `16`): The size of the patches.
+        in_channels (`int`, defaults to `3`): The number of input channels.
+        embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
+        layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
+        flatten (`bool`, defaults to `True`): Whether or not to flatten the output.
+        bias (`bool`, defaults to `True`): Whether or not to use bias.
+        interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
+        pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
+        pos_embed_max_size (`int`, defaults to `None`): The maximum size of the positional embedding.
+    """
 
     def __init__(
         self,
@@ -227,10 +503,14 @@ class PatchEmbed(nn.Module):
             self.pos_embed = None
         elif pos_embed_type == "sincos":
             pos_embed = get_2d_sincos_pos_embed(
-                embed_dim,
+                embed_dim,
+                grid_size,
+                base_size=self.base_size,
+                interpolation_scale=self.interpolation_scale,
+                output_type="pt",
             )
             persistent = True if pos_embed_max_size else False
-            self.register_buffer("pos_embed",
+            self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
         else:
             raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
 
@@ -262,7 +542,6 @@ class PatchEmbed(nn.Module):
             height, width = latent.shape[-2:]
         else:
             height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
-
         latent = self.proj(latent)
         if self.flatten:
             latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
@@ -280,8 +559,10 @@ class PatchEmbed(nn.Module):
                 grid_size=(height, width),
                 base_size=self.base_size,
                 interpolation_scale=self.interpolation_scale,
+                device=latent.device,
+                output_type="pt",
             )
-            pos_embed =
+            pos_embed = pos_embed.float().unsqueeze(0)
         else:
             pos_embed = self.pos_embed
 
@@ -289,7 +570,15 @@ class PatchEmbed(nn.Module):
 
 
 class LuminaPatchEmbed(nn.Module):
-    """
+    """
+    2D Image to Patch Embedding with support for Lumina-T2X
+
+    Args:
+        patch_size (`int`, defaults to `2`): The size of the patches.
+        in_channels (`int`, defaults to `4`): The number of input channels.
+        embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
+        bias (`bool`, defaults to `True`): Whether or not to use bias.
+    """
 
     def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
         super().__init__()
@@ -338,6 +627,7 @@ class CogVideoXPatchEmbed(nn.Module):
     def __init__(
         self,
         patch_size: int = 2,
+        patch_size_t: Optional[int] = None,
         in_channels: int = 16,
        embed_dim: int = 1920,
         text_embed_dim: int = 4096,
@@ -355,6 +645,7 @@ class CogVideoXPatchEmbed(nn.Module):
         super().__init__()
 
         self.patch_size = patch_size
+        self.patch_size_t = patch_size_t
         self.embed_dim = embed_dim
         self.sample_height = sample_height
         self.sample_width = sample_width
@@ -366,9 +657,15 @@ class CogVideoXPatchEmbed(nn.Module):
         self.use_positional_embeddings = use_positional_embeddings
         self.use_learned_positional_embeddings = use_learned_positional_embeddings
 
-
-
-
+        if patch_size_t is None:
+            # CogVideoX 1.0 checkpoints
+            self.proj = nn.Conv2d(
+                in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+            )
+        else:
+            # CogVideoX 1.5 checkpoints
+            self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
+
         self.text_proj = nn.Linear(text_embed_dim, embed_dim)
 
         if use_positional_embeddings or use_learned_positional_embeddings:
@@ -376,7 +673,9 @@ class CogVideoXPatchEmbed(nn.Module):
             pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
             self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
 
-    def _get_positional_embeddings(
+    def _get_positional_embeddings(
+        self, sample_height: int, sample_width: int, sample_frames: int, device: Optional[torch.device] = None
+    ) -> torch.Tensor:
         post_patch_height = sample_height // self.patch_size
         post_patch_width = sample_width // self.patch_size
         post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
@@ -388,9 +687,11 @@ class CogVideoXPatchEmbed(nn.Module):
             post_time_compression_frames,
             self.spatial_interpolation_scale,
             self.temporal_interpolation_scale,
+            device=device,
+            output_type="pt",
         )
-        pos_embedding =
-        joint_pos_embedding =
+        pos_embedding = pos_embedding.flatten(0, 1)
+        joint_pos_embedding = pos_embedding.new_zeros(
             1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
         )
         joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
@@ -407,12 +708,24 @@ class CogVideoXPatchEmbed(nn.Module):
         """
         text_embeds = self.text_proj(text_embeds)
 
-
-
-
-
-
-
+        batch_size, num_frames, channels, height, width = image_embeds.shape
+
+        if self.patch_size_t is None:
+            image_embeds = image_embeds.reshape(-1, channels, height, width)
+            image_embeds = self.proj(image_embeds)
+            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
+            image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
+            image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+        else:
+            p = self.patch_size
+            p_t = self.patch_size_t
+
+            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
+            image_embeds = image_embeds.reshape(
+                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
+            )
+            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
+            image_embeds = self.proj(image_embeds)
 
         embeds = torch.cat(
             [text_embeds, image_embeds], dim=1
@@ -432,18 +745,84 @@ class CogVideoXPatchEmbed(nn.Module):
                 or self.sample_width != width
                 or self.sample_frames != pre_time_compression_frames
             ):
-                pos_embedding = self._get_positional_embeddings(
-
+                pos_embedding = self._get_positional_embeddings(
+                    height, width, pre_time_compression_frames, device=embeds.device
+                )
             else:
                 pos_embedding = self.pos_embedding
 
+            pos_embedding = pos_embedding.to(dtype=embeds.dtype)
             embeds = embeds + pos_embedding
 
         return embeds
 
 
+class CogView3PlusPatchEmbed(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 16,
+        hidden_size: int = 2560,
+        patch_size: int = 2,
+        text_hidden_size: int = 4096,
+        pos_embed_max_size: int = 128,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_size = hidden_size
+        self.patch_size = patch_size
+        self.text_hidden_size = text_hidden_size
+        self.pos_embed_max_size = pos_embed_max_size
+        # Linear projection for image patches
+        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
+
+        # Linear projection for text embeddings
+        self.text_proj = nn.Linear(text_hidden_size, hidden_size)
+
+        pos_embed = get_2d_sincos_pos_embed(
+            hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
+        )
+        pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
+        self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
+
+    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, channel, height, width = hidden_states.shape
+
+        if height % self.patch_size != 0 or width % self.patch_size != 0:
+            raise ValueError("Height and width must be divisible by patch size")
+
+        height = height // self.patch_size
+        width = width // self.patch_size
+        hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
+        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
+        hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
+
+        # Project the patches
+        hidden_states = self.proj(hidden_states)
+        encoder_hidden_states = self.text_proj(encoder_hidden_states)
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        # Calculate text_length
+        text_length = encoder_hidden_states.shape[1]
+
+        image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
+        text_pos_embed = torch.zeros(
+            (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
+        )
+        pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...]
+
+        return (hidden_states + pos_embed).to(hidden_states.dtype)
+
+
 def get_3d_rotary_pos_embed(
-    embed_dim,
+    embed_dim,
+    crops_coords,
+    grid_size,
+    temporal_size,
+    theta: int = 10000,
+    use_real: bool = True,
+    grid_type: str = "linspace",
+    max_size: Optional[Tuple[int, int]] = None,
+    device: Optional[torch.device] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     RoPE for video tokens with 3D structure.
@@ -459,16 +838,36 @@ def get_3d_rotary_pos_embed(
             The size of the temporal dimension.
         theta (`float`):
             Scaling factor for frequency computation.
-
-
+        grid_type (`str`):
+            Whether to use "linspace" or "slice" to compute grids.
 
     Returns:
         `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
     """
-
-
-
-
+    if use_real is not True:
+        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
+
+    if grid_type == "linspace":
+        start, stop = crops_coords
+        grid_size_h, grid_size_w = grid_size
+        grid_h = torch.linspace(
+            start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
+        )
+        grid_w = torch.linspace(
+            start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
+        )
+        grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
+        grid_t = torch.linspace(
+            0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
+        )
+    elif grid_type == "slice":
+        max_h, max_w = max_size
+        grid_size_h, grid_size_w = grid_size
+        grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
+        grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
+        grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
+    else:
+        raise ValueError("Invalid value passed for `grid_type`.")
 
     # Compute dimensions for each axis
     dim_t = embed_dim // 4
@@ -476,57 +875,139 @@ def get_3d_rotary_pos_embed(
     dim_w = embed_dim // 8 * 3
 
     # Temporal frequencies
-    freqs_t =
-
-
-
+    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
+    # Spatial frequencies for height and width
+    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
+    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)
+
+    # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
+    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
+        freqs_t = freqs_t[:, None, None, :].expand(
+            -1, grid_size_h, grid_size_w, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_t
+        freqs_h = freqs_h[None, :, None, :].expand(
+            temporal_size, -1, grid_size_w, -1
+        )  # temporal_size, grid_size_h, grid_size_2, dim_h
+        freqs_w = freqs_w[None, None, :, :].expand(
+            temporal_size, grid_size_h, -1, -1
+        )  # temporal_size, grid_size_h, grid_size_2, dim_w
+
+        freqs = torch.cat(
+            [freqs_t, freqs_h, freqs_w], dim=-1
+        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
+        freqs = freqs.view(
+            temporal_size * grid_size_h * grid_size_w, -1
+        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
+        return freqs
+
+    t_cos, t_sin = freqs_t  # both t_cos and t_sin has shape: temporal_size, dim_t
+    h_cos, h_sin = freqs_h  # both h_cos and h_sin has shape: grid_size_h, dim_h
+    w_cos, w_sin = freqs_w  # both w_cos and w_sin has shape: grid_size_w, dim_w
+
+    if grid_type == "slice":
+        t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
+        h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
+        w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
+
+    cos = combine_time_height_width(t_cos, h_cos, w_cos)
+    sin = combine_time_height_width(t_sin, h_sin, w_sin)
+    return cos, sin
+
+
+def get_3d_rotary_pos_embed_allegro(
+    embed_dim,
+    crops_coords,
+    grid_size,
+    temporal_size,
+    interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
+    theta: int = 10000,
+    device: Optional[torch.device] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    # TODO(aryan): docs
+    start, stop = crops_coords
+    grid_size_h, grid_size_w = grid_size
+    interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
+    grid_t = torch.linspace(
+        0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
+    )
+    grid_h = torch.linspace(
+        start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
+    )
+    grid_w = torch.linspace(
+        start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
+    )
+
+    # Compute dimensions for each axis
+    dim_t = embed_dim // 3
+    dim_h = embed_dim // 3
+    dim_w = embed_dim // 3
 
+    # Temporal frequencies
+    freqs_t = get_1d_rotary_pos_embed(
+        dim_t, grid_t / interpolation_scale_t, theta=theta, use_real=True, repeat_interleave_real=False
+    )
     # Spatial frequencies for height and width
-    freqs_h =
-
-
-
-
-
-    freqs_h = freqs_h.repeat_interleave(2, dim=-1)
-    freqs_w = freqs_w.repeat_interleave(2, dim=-1)
-
-    # Broadcast and concatenate tensors along specified dimension
-    def broadcast(tensors, dim=-1):
-        num_tensors = len(tensors)
-        shape_lens = {len(t.shape) for t in tensors}
-        assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
-        shape_len = list(shape_lens)[0]
-        dim = (dim + shape_len) if dim < 0 else dim
-        dims = list(zip(*(list(t.shape) for t in tensors)))
-        expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
-        assert all(
-            [*(len(set(t[1])) <= 2 for t in expandable_dims)]
-        ), "invalid dimensions for broadcastable concatenation"
-        max_dims = [(t[0], max(t[1])) for t in expandable_dims]
-        expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
-        expanded_dims.insert(dim, (dim, dims[dim]))
-        expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
-        tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
-        return torch.cat(tensors, dim=dim)
-
-    freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
-
-    t, h, w, d = freqs.shape
-    freqs = freqs.view(t * h * w, d)
-
-    # Generate sine and cosine components
-    sin = freqs.sin()
-    cos = freqs.cos()
+    freqs_h = get_1d_rotary_pos_embed(
+        dim_h, grid_h / interpolation_scale_h, theta=theta, use_real=True, repeat_interleave_real=False
+    )
+    freqs_w = get_1d_rotary_pos_embed(
+        dim_w, grid_w / interpolation_scale_w, theta=theta, use_real=True, repeat_interleave_real=False
+    )
 
-
-
-
-
-
+    return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w
+
+
+def get_2d_rotary_pos_embed(
+    embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None, output_type: str = "np"
+):
+    """
+    RoPE for image tokens with 2d structure.
+
+    Args:
+    embed_dim: (`int`):
+        The embedding dimension size
+    crops_coords (`Tuple[int]`)
+        The top-left and bottom-right coordinates of the crop.
+    grid_size (`Tuple[int]`):
+        The grid size of the positional embedding.
+    use_real (`bool`):
+        If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+    device: (`torch.device`, **optional**):
+        The device used to create tensors.
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt' to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return _get_2d_rotary_pos_embed_np(
+            embed_dim=embed_dim,
+            crops_coords=crops_coords,
+            grid_size=grid_size,
+            use_real=use_real,
+        )
+    start, stop = crops_coords
+    # scale end by (steps−1)/steps matches np.linspace(..., endpoint=False)
+    grid_h = torch.linspace(
+        start[0], stop[0] * (grid_size[0] - 1) / grid_size[0], grid_size[0], device=device, dtype=torch.float32
+    )
+    grid_w = torch.linspace(
+        start[1], stop[1] * (grid_size[1] - 1) / grid_size[1], grid_size[1], device=device, dtype=torch.float32
+    )
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
+    grid = torch.stack(grid, dim=0)  # [2, W, H]
+
+    grid = grid.reshape([2, 1, *grid.shape[1:]])
+    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+    return pos_embed
 
 
-def
+def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=True):
     """
     RoPE for image tokens with 2d structure.
 
@@ -555,6 +1036,20 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
 
 
 def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
+    """
+    Get 2D RoPE from grid.
+
+    Args:
+    embed_dim: (`int`):
+        The embedding dimension size, corresponding to hidden_size_head.
+    grid (`np.ndarray`):
+        The grid of the positional embedding.
+    use_real (`bool`):
+        If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
+    """
     assert embed_dim % 4 == 0
 
     # use half of dimensions to encode grid_h
@@ -575,6 +1070,23 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
 
 
 def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
+    """
+    Get 2D RoPE from grid.
+
+    Args:
+    embed_dim: (`int`):
+        The embedding dimension size, corresponding to hidden_size_head.
+    grid (`np.ndarray`):
+        The grid of the positional embedding.
+    linear_factor (`float`):
+        The linear factor of the positional embedding, which is used to scale the positional embedding in the linear
+        layer.
+    ntk_factor (`float`):
+        The ntk factor of the positional embedding, which is used to scale the positional embedding in the ntk layer.
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
+    """
     assert embed_dim % 4 == 0
 
     emb_h = get_1d_rotary_pos_embed(
@@ -598,6 +1110,7 @@ def get_1d_rotary_pos_embed(
     linear_factor=1.0,
     ntk_factor=1.0,
     repeat_interleave_real=True,
+    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
 ):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -620,26 +1133,37 @@ def get_1d_rotary_pos_embed(
         repeat_interleave_real (`bool`, *optional*, defaults to `True`):
             If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
             Otherwise, they are concateanted with themselves.
+        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+            the dtype of the frequency tensor.
     Returns:
         `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
     """
     assert dim % 2 == 0
 
     if isinstance(pos, int):
-        pos =
+        pos = torch.arange(pos)
+    if isinstance(pos, np.ndarray):
+        pos = torch.from_numpy(pos)  # type: ignore  # [S]
+
     theta = theta * ntk_factor
-    freqs =
-
-
+    freqs = (
+        1.0
+        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
+        / linear_factor
+    )  # [D/2]
+    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
     if use_real and repeat_interleave_real:
-
-
+        # flux, hunyuan-dit, cogvideox
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
         return freqs_cos, freqs_sin
     elif use_real:
-
-
+        # stable audio, allegro
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
     else:
+        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
        return freqs_cis
 
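get_1d_rotary_pos_embed gains a freqs_dtype argument so the rotary frequencies can be computed in float64 (as Flux does) while the returned cos/sin tables are still cast to float32. A small sketch based only on the signature and return paths shown above:

    import torch
    from diffusers.models.embeddings import get_1d_rotary_pos_embed

    # 128 positions, head dimension 64; frequencies computed in float64
    cos, sin = get_1d_rotary_pos_embed(64, 128, use_real=True, freqs_dtype=torch.float64)
    print(cos.shape, cos.dtype)  # torch.Size([128, 64]) torch.float32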
@@ -671,11 +1195,11 @@ def apply_rotary_emb(
         cos, sin = cos.to(x.device), sin.to(x.device)
 
         if use_real_unbind_dim == -1:
-            #
+            # Used for flux, cogvideox, hunyuan-dit
             x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
             x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
         elif use_real_unbind_dim == -2:
-            #
+            # Used for Stable Audio
             x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
             x_rotated = torch.cat([-x_imag, x_real], dim=-1)
         else:
@@ -685,6 +1209,7 @@ def apply_rotary_emb(
 
         return out
     else:
+        # used for lumina
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
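The comments added to `apply_rotary_emb` only label which models use each branch. A rough sketch of the flux-style interleaved branch, with assumed shapes that are not taken from the diff:

```python
# Illustrative only: apply the interleaved (use_real_unbind_dim=-1) rotary embedding
# to a dummy attention query of shape (batch, heads, seq_len, head_dim).
import torch
from diffusers.models.embeddings import apply_rotary_emb, get_1d_rotary_pos_embed

cos, sin = get_1d_rotary_pos_embed(64, 128, use_real=True, repeat_interleave_real=True)
query = torch.randn(2, 8, 128, 64)
query = apply_rotary_emb(query, (cos, sin), use_real=True, use_real_unbind_dim=-1)
print(query.shape)  # torch.Size([2, 8, 128, 64])
```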
@@ -692,6 +1217,54 @@ def apply_rotary_emb(
         return x_out.type_as(x)
 
 
+def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
+    # TODO(aryan): rewrite
+    def apply_1d_rope(tokens, pos, cos, sin):
+        cos = F.embedding(pos, cos)[:, None, :, :]
+        sin = F.embedding(pos, sin)[:, None, :, :]
+        x1, x2 = tokens[..., : tokens.shape[-1] // 2], tokens[..., tokens.shape[-1] // 2 :]
+        tokens_rotated = torch.cat((-x2, x1), dim=-1)
+        return (tokens.float() * cos + tokens_rotated.float() * sin).to(tokens.dtype)
+
+    (t_cos, t_sin), (h_cos, h_sin), (w_cos, w_sin) = freqs_cis
+    t, h, w = x.chunk(3, dim=-1)
+    t = apply_1d_rope(t, positions[0], t_cos, t_sin)
+    h = apply_1d_rope(h, positions[1], h_cos, h_sin)
+    w = apply_1d_rope(w, positions[2], w_cos, w_sin)
+    x = torch.cat([t, h, w], dim=-1)
+    return x
+
+
+class FluxPosEmbed(nn.Module):
+    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        cos_out = []
+        sin_out = []
+        pos = ids.float()
+        is_mps = ids.device.type == "mps"
+        freqs_dtype = torch.float32 if is_mps else torch.float64
+        for i in range(n_axes):
+            cos, sin = get_1d_rotary_pos_embed(
+                self.axes_dim[i],
+                pos[:, i],
+                theta=self.theta,
+                repeat_interleave_real=True,
+                use_real=True,
+                freqs_dtype=freqs_dtype,
+            )
+            cos_out.append(cos)
+            sin_out.append(sin)
+        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+        return freqs_cos, freqs_sin
+
+
 class TimestepEmbedding(nn.Module):
     def __init__(
         self,
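A rough usage sketch of the new `FluxPosEmbed` module (the sequence length and row/column index layout below are assumptions for illustration): one position value per rotary axis goes in, and per-axis cos/sin tables concatenated over the channel dimension come out.

```python
# Illustrative only: 3-axis rotary embedding as used by Flux-style transformers.
import torch
from diffusers.models.embeddings import FluxPosEmbed

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])  # axes_dim sums to 128 channels
ids = torch.zeros(4096, 3)             # (seq_len, n_axes) packed latent ids, assumed layout
ids[:, 1] = torch.arange(4096) // 64   # hypothetical row index
ids[:, 2] = torch.arange(4096) % 64    # hypothetical column index
cos, sin = pos_embed(ids)
print(cos.shape, sin.shape)  # torch.Size([4096, 128]) torch.Size([4096, 128])
```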
@@ -962,7 +1535,7 @@ class ImageProjection(nn.Module):
         batch_size = image_embeds.shape[0]
 
         # image
-        image_embeds = self.image_embeds(image_embeds)
+        image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype))
         image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
         image_embeds = self.norm(image_embeds)
         return image_embeds
@@ -1058,6 +1631,39 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
         return conditioning
 
 
+class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
+    def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
+        self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(
+        self,
+        timestep: torch.Tensor,
+        original_size: torch.Tensor,
+        target_size: torch.Tensor,
+        crop_coords: torch.Tensor,
+        hidden_dtype: torch.dtype,
+    ) -> torch.Tensor:
+        timesteps_proj = self.time_proj(timestep)
+
+        original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1)
+        crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
+        target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
+
+        # (B, 3 * condition_dim)
+        condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1)
+
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)
+        condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)
+
+        conditioning = timesteps_emb + condition_emb
+        return conditioning
+
+
 class HunyuanDiTAttentionPool(nn.Module):
     # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
 
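A small sketch (dummy sizes, not from the diff) of how the new `CogView3CombinedTimestepSizeEmbeddings` folds the timestep and the three size/crop conditions into one conditioning vector; since each condition has two values and each value is projected to `condition_dim` channels, `pooled_projection_dim` has to equal `6 * condition_dim` for the shapes to line up.

```python
# Illustrative only: combine timestep + (original_size, crop_coords, target_size) conditions.
import torch
from diffusers.models.embeddings import CogView3CombinedTimestepSizeEmbeddings

emb = CogView3CombinedTimestepSizeEmbeddings(
    embedding_dim=512,
    condition_dim=256,
    pooled_projection_dim=6 * 256,  # 3 conditions x 2 values x condition_dim
)
conditioning = emb(
    timestep=torch.tensor([999]),
    original_size=torch.tensor([[1024, 1024]]),
    target_size=torch.tensor([[1024, 1024]]),
    crop_coords=torch.tensor([[0, 0]]),
    hidden_dtype=torch.float32,
)
print(conditioning.shape)  # torch.Size([1, 512])
```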
@@ -1193,6 +1799,41 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
         return conditioning
 
 
+class MochiCombinedTimestepCaptionEmbedding(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        pooled_projection_dim: int,
+        text_embed_dim: int,
+        time_embed_dim: int = 256,
+        num_attention_heads: int = 8,
+    ) -> None:
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim)
+        self.pooler = MochiAttentionPool(
+            num_attention_heads=num_attention_heads, embed_dim=text_embed_dim, output_dim=embedding_dim
+        )
+        self.caption_proj = nn.Linear(text_embed_dim, pooled_projection_dim)
+
+    def forward(
+        self,
+        timestep: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_attention_mask: torch.Tensor,
+        hidden_dtype: Optional[torch.dtype] = None,
+    ):
+        time_proj = self.time_proj(timestep)
+        time_emb = self.timestep_embedder(time_proj.to(dtype=hidden_dtype))
+
+        pooled_projections = self.pooler(encoder_hidden_states, encoder_attention_mask)
+        caption_proj = self.caption_proj(encoder_hidden_states)
+
+        conditioning = time_emb + pooled_projections
+        return conditioning, caption_proj
+
+
 class TextTimeEmbedding(nn.Module):
     def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
         super().__init__()
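A sketch with made-up dimensions (the shipped Mochi config differs) of what the new combined embedding returns: a global conditioning vector plus per-token caption projections.

```python
# Illustrative only: timestep + T5-style encoder states -> (conditioning, caption_proj).
import torch
from diffusers.models.embeddings import MochiCombinedTimestepCaptionEmbedding

embed = MochiCombinedTimestepCaptionEmbedding(
    embedding_dim=1024, pooled_projection_dim=1536, text_embed_dim=4096
)
text_states = torch.randn(1, 256, 4096)            # (B, L, text_embed_dim)
text_mask = torch.ones(1, 256, dtype=torch.bool)   # no padding in this toy example
conditioning, caption_proj = embed(
    torch.tensor([500]), text_states, text_mask, hidden_dtype=torch.float32
)
print(conditioning.shape, caption_proj.shape)  # torch.Size([1, 1024]) torch.Size([1, 256, 1536])
```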
@@ -1321,6 +1962,88 @@ class AttentionPooling(nn.Module):
         return a[:, 0, :]  # cls_token
 
 
+class MochiAttentionPool(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        embed_dim: int,
+        output_dim: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+
+        self.output_dim = output_dim or embed_dim
+        self.num_attention_heads = num_attention_heads
+
+        self.to_kv = nn.Linear(embed_dim, 2 * embed_dim)
+        self.to_q = nn.Linear(embed_dim, embed_dim)
+        self.to_out = nn.Linear(embed_dim, self.output_dim)
+
+    @staticmethod
+    def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
+        """
+        Pool tokens in x using mask.
+
+        NOTE: We assume x does not require gradients.
+
+        Args:
+            x: (B, L, D) tensor of tokens.
+            mask: (B, L) boolean tensor indicating which tokens are not padding.
+
+        Returns:
+            pooled: (B, D) tensor of pooled tokens.
+        """
+        assert x.size(1) == mask.size(1)  # Expected mask to have same length as tokens.
+        assert x.size(0) == mask.size(0)  # Expected mask to have same batch size as tokens.
+        mask = mask[:, :, None].to(dtype=x.dtype)
+        mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
+        pooled = (x * mask).sum(dim=1, keepdim=keepdim)
+        return pooled
+
+    def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+        r"""
+        Args:
+            x (`torch.Tensor`):
+                Tensor of shape `(B, S, D)` of input tokens.
+            mask (`torch.Tensor`):
+                Boolean ensor of shape `(B, S)` indicating which tokens are not padding.
+
+        Returns:
+            `torch.Tensor`:
+                `(B, D)` tensor of pooled tokens.
+        """
+        D = x.size(2)
+
+        # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
+        attn_mask = mask[:, None, None, :].bool()  # (B, 1, 1, L).
+        attn_mask = F.pad(attn_mask, (1, 0), value=True)  # (B, 1, 1, 1+L).
+
+        # Average non-padding token features. These will be used as the query.
+        x_pool = self.pool_tokens(x, mask, keepdim=True)  # (B, 1, D)
+
+        # Concat pooled features to input sequence.
+        x = torch.cat([x_pool, x], dim=1)  # (B, L+1, D)
+
+        # Compute queries, keys, values. Only the mean token is used to create a query.
+        kv = self.to_kv(x)  # (B, L+1, 2 * D)
+        q = self.to_q(x[:, 0])  # (B, D)
+
+        # Extract heads.
+        head_dim = D // self.num_attention_heads
+        kv = kv.unflatten(2, (2, self.num_attention_heads, head_dim))  # (B, 1+L, 2, H, head_dim)
+        kv = kv.transpose(1, 3)  # (B, H, 2, 1+L, head_dim)
+        k, v = kv.unbind(2)  # (B, H, 1+L, head_dim)
+        q = q.unflatten(1, (self.num_attention_heads, head_dim))  # (B, H, head_dim)
+        q = q.unsqueeze(2)  # (B, H, 1, head_dim)
+
+        # Compute attention.
+        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)  # (B, H, 1, head_dim)
+
+        # Concatenate heads and run output.
+        x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
+        x = self.to_out(x)
+        return x
+
+
 def get_fourier_embeds_from_boundingbox(embed_dim, box):
     """
     Args:
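For the pooling class itself, a minimal sketch (arbitrary sizes) of masked attention pooling over a padded token sequence:

```python
# Illustrative only: pool a padded (B, L, D) sequence down to (B, output_dim).
import torch
from diffusers.models.embeddings import MochiAttentionPool

pool = MochiAttentionPool(num_attention_heads=8, embed_dim=256, output_dim=512)
tokens = torch.randn(2, 77, 256)
mask = torch.zeros(2, 77, dtype=torch.bool)
mask[0, :50] = True   # first sequence: 50 real tokens
mask[1, :20] = True   # second sequence: 20 real tokens
pooled = pool(tokens, mask)
print(pooled.shape)  # torch.Size([2, 512])
```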
@@ -1673,6 +2396,187 @@ class IPAdapterFaceIDPlusImageProjection(nn.Module):
         return out
 
 
+class IPAdapterTimeImageProjectionBlock(nn.Module):
+    """Block for IPAdapterTimeImageProjection.
+
+    Args:
+        hidden_dim (`int`, defaults to 1280):
+            The number of hidden channels.
+        dim_head (`int`, defaults to 64):
+            The number of head channels.
+        heads (`int`, defaults to 20):
+            Parallel attention heads.
+        ffn_ratio (`int`, defaults to 4):
+            The expansion ratio of feedforward network hidden layer channels.
+    """
+
+    def __init__(
+        self,
+        hidden_dim: int = 1280,
+        dim_head: int = 64,
+        heads: int = 20,
+        ffn_ratio: int = 4,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.ln0 = nn.LayerNorm(hidden_dim)
+        self.ln1 = nn.LayerNorm(hidden_dim)
+        self.attn = Attention(
+            query_dim=hidden_dim,
+            cross_attention_dim=hidden_dim,
+            dim_head=dim_head,
+            heads=heads,
+            bias=False,
+            out_bias=False,
+        )
+        self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False)
+
+        # AdaLayerNorm
+        self.adaln_silu = nn.SiLU()
+        self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
+        self.adaln_norm = nn.LayerNorm(hidden_dim)
+
+        # Set attention scale and fuse KV
+        self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head))
+        self.attn.fuse_projections()
+        self.attn.to_k = None
+        self.attn.to_v = None
+
+    def forward(self, x: torch.Tensor, latents: torch.Tensor, timestep_emb: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+            x (`torch.Tensor`):
+                Image features.
+            latents (`torch.Tensor`):
+                Latent features.
+            timestep_emb (`torch.Tensor`):
+                Timestep embedding.
+
+        Returns:
+            `torch.Tensor`: Output latent features.
+        """
+
+        # Shift and scale for AdaLayerNorm
+        emb = self.adaln_proj(self.adaln_silu(timestep_emb))
+        shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1)
+
+        # Fused Attention
+        residual = latents
+        x = self.ln0(x)
+        latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+
+        batch_size = latents.shape[0]
+
+        query = self.attn.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        key, value = self.attn.to_kv(kv_input).chunk(2, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // self.attn.heads
+
+        query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+
+        weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1)
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        latents = weight @ value
+
+        latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim)
+        latents = self.attn.to_out[0](latents)
+        latents = self.attn.to_out[1](latents)
+        latents = latents + residual
+
+        ## FeedForward
+        residual = latents
+        latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        return self.ff(latents) + residual
+
+
+# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+class IPAdapterTimeImageProjection(nn.Module):
+    """Resampler of SD3 IP-Adapter with timestep embedding.
+
+    Args:
+        embed_dim (`int`, defaults to 1152):
+            The feature dimension.
+        output_dim (`int`, defaults to 2432):
+            The number of output channels.
+        hidden_dim (`int`, defaults to 1280):
+            The number of hidden channels.
+        depth (`int`, defaults to 4):
+            The number of blocks.
+        dim_head (`int`, defaults to 64):
+            The number of head channels.
+        heads (`int`, defaults to 20):
+            Parallel attention heads.
+        num_queries (`int`, defaults to 64):
+            The number of queries.
+        ffn_ratio (`int`, defaults to 4):
+            The expansion ratio of feedforward network hidden layer channels.
+        timestep_in_dim (`int`, defaults to 320):
+            The number of input channels for timestep embedding.
+        timestep_flip_sin_to_cos (`bool`, defaults to True):
+            Flip the timestep embedding order to `cos, sin` (if True) or `sin, cos` (if False).
+        timestep_freq_shift (`int`, defaults to 0):
+            Controls the timestep delta between frequencies between dimensions.
+    """
+
+    def __init__(
+        self,
+        embed_dim: int = 1152,
+        output_dim: int = 2432,
+        hidden_dim: int = 1280,
+        depth: int = 4,
+        dim_head: int = 64,
+        heads: int = 20,
+        num_queries: int = 64,
+        ffn_ratio: int = 4,
+        timestep_in_dim: int = 320,
+        timestep_flip_sin_to_cos: bool = True,
+        timestep_freq_shift: int = 0,
+    ) -> None:
+        super().__init__()
+        self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5)
+        self.proj_in = nn.Linear(embed_dim, hidden_dim)
+        self.proj_out = nn.Linear(hidden_dim, output_dim)
+        self.norm_out = nn.LayerNorm(output_dim)
+        self.layers = nn.ModuleList(
+            [IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)]
+        )
+        self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
+        self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")
+
+    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass.
+
+        Args:
+            x (`torch.Tensor`):
+                Image features.
+            timestep (`torch.Tensor`):
+                Timestep in denoising process.
+        Returns:
+            `Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb).
+        """
+        timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
+        timestep_emb = self.time_embedding(timestep_emb)
+
+        latents = self.latents.repeat(x.size(0), 1, 1)
+
+        x = self.proj_in(x)
+        x = x + timestep_emb[:, None]
+
+        for block in self.layers:
+            latents = block(x, latents, timestep_emb)
+
+        latents = self.proj_out(latents)
+        latents = self.norm_out(latents)
+
+        return latents, timestep_emb
+
+
 class MultiIPAdapterImageProjection(nn.Module):
     def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
         super().__init__()
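Finally, a sketch of the new SD3 IP-Adapter resampler with its default configuration (the 729-token input length below is an assumption for illustration, not from the diff):

```python
# Illustrative only: image features + denoising timestep -> (image tokens, timestep embedding).
import torch
from diffusers.models.embeddings import IPAdapterTimeImageProjection

proj = IPAdapterTimeImageProjection()       # defaults: embed_dim=1152, output_dim=2432, num_queries=64
image_features = torch.randn(1, 729, 1152)  # (B, num_patches, embed_dim), assumed patch count
latents, timestep_emb = proj(image_features, torch.tensor([25]))
print(latents.shape, timestep_emb.shape)  # torch.Size([1, 64, 2432]) torch.Size([1, 1280])
```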