diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +94 -3
- diffusers/commands/env.py +1 -5
- diffusers/configuration_utils.py +4 -9
- diffusers/dependency_versions_table.py +2 -2
- diffusers/image_processor.py +1 -2
- diffusers/loaders/__init__.py +17 -2
- diffusers/loaders/ip_adapter.py +10 -7
- diffusers/loaders/lora_base.py +752 -0
- diffusers/loaders/lora_pipeline.py +2252 -0
- diffusers/loaders/peft.py +213 -5
- diffusers/loaders/single_file.py +3 -14
- diffusers/loaders/single_file_model.py +31 -10
- diffusers/loaders/single_file_utils.py +293 -8
- diffusers/loaders/textual_inversion.py +1 -6
- diffusers/loaders/unet.py +23 -208
- diffusers/models/__init__.py +20 -0
- diffusers/models/activations.py +22 -0
- diffusers/models/attention.py +386 -7
- diffusers/models/attention_processor.py +1937 -629
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_kl.py +14 -3
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vq_model.py +4 -4
- diffusers/models/controlnet.py +2 -3
- diffusers/models/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnet_sd3.py +11 -11
- diffusers/models/controlnet_sparsectrl.py +789 -0
- diffusers/models/controlnet_xs.py +40 -10
- diffusers/models/downsampling.py +68 -0
- diffusers/models/embeddings.py +403 -36
- diffusers/models/model_loading_utils.py +1 -3
- diffusers/models/modeling_flax_utils.py +1 -6
- diffusers/models/modeling_utils.py +4 -16
- diffusers/models/normalization.py +203 -12
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +102 -1
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/transformer_flux.py +455 -0
- diffusers/models/transformers/transformer_sd3.py +18 -4
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +8 -1
- diffusers/models/unets/unet_3d_blocks.py +51 -920
- diffusers/models/unets/unet_3d_condition.py +4 -1
- diffusers/models/unets/unet_i2vgen_xl.py +4 -1
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +1330 -84
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +1 -3
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +64 -0
- diffusers/models/vq_model.py +8 -4
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +100 -3
- diffusers/pipelines/animatediff/__init__.py +4 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
- diffusers/pipelines/auto_pipeline.py +97 -19
- diffusers/pipelines/cogvideo/__init__.py +48 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +749 -0
- diffusers/pipelines/flux/pipeline_output.py +21 -0
- diffusers/pipelines/free_init_utils.py +2 -0
- diffusers/pipelines/free_noise_utils.py +236 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +334 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
- diffusers/pipelines/pag/__init__.py +67 -0
- diffusers/pipelines/pag/pag_utils.py +237 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
- diffusers/pipelines/pia/pipeline_pia.py +30 -37
- diffusers/pipelines/pipeline_flax_utils.py +4 -9
- diffusers/pipelines/pipeline_loading_utils.py +0 -3
- diffusers/pipelines/pipeline_utils.py +2 -14
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
- diffusers/schedulers/__init__.py +8 -0
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +1 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
- diffusers/schedulers/scheduling_ddpm.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +2 -2
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
- diffusers/schedulers/scheduling_ipndm.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
- diffusers/schedulers/scheduling_utils.py +1 -3
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/training_utils.py +99 -14
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +210 -0
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
- diffusers/utils/dynamic_modules_utils.py +1 -11
- diffusers/utils/export_utils.py +50 -6
- diffusers/utils/hub_utils.py +45 -42
- diffusers/utils/import_utils.py +37 -15
- diffusers/utils/loading_utils.py +80 -3
- diffusers/utils/testing_utils.py +11 -8
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1728
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
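The headline additions in 0.30 are visible in the listing above: whole new model families (Flux, CogVideoX, AuraFlow, Kolors, Latte, Lumina, Stable Audio), a perturbed-attention guidance (PAG) pipeline family, and a LoRA loader split into lora_base.py/lora_pipeline.py replacing the old loaders/lora.py. As a minimal sketch of one new entry point (the model id, offload call, and sampler settings are assumptions from general diffusers usage, not taken from this diff):

import torch
from diffusers import FluxPipeline  # new in 0.30 (diffusers/pipelines/flux)

# FLUX.1-schnell is a few-step distilled checkpoint; id and settings are assumptions.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # the Flux transformer is large; offload to fit smaller GPUs

image = pipe(
    "a capybara reading a changelog",
    num_inference_steps=4,   # distilled model: few steps
    guidance_scale=0.0,      # schnell ignores classifier-free guidance
).images[0]
image.save("flux.png")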
diffusers/models/embeddings.py
CHANGED
@@ -35,10 +35,21 @@ def get_timestep_embedding(
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

-
-
-
-
+    Args
+        timesteps (torch.Tensor):
+            a 1-D Tensor of N indices, one per batch element. These may be fractional.
+        embedding_dim (int):
+            the dimension of the output.
+        flip_sin_to_cos (bool):
+            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
+        downscale_freq_shift (float):
+            Controls the delta between frequencies between dimensions
+        scale (float):
+            Scaling factor applied to the embeddings.
+        max_period (int):
+            Controls the maximum frequency of the embeddings
+    Returns
+        torch.Tensor: an [N x dim] Tensor of positional embeddings.
     """
     assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

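The rewritten docstring is behavior-neutral. A quick shape check of the documented contract (illustrative values, not part of the upstream diff):

import torch
from diffusers.models.embeddings import get_timestep_embedding

t = torch.tensor([0.0, 10.5, 999.0])  # N = 3 indices; fractional values are allowed
emb = get_timestep_embedding(t, embedding_dim=320, flip_sin_to_cos=True, downscale_freq_shift=0)
print(emb.shape)  # torch.Size([3, 320]), the [N x dim] tensor the docstring promises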
@@ -67,6 +78,53 @@ def get_timestep_embedding(
     return emb


+def get_3d_sincos_pos_embed(
+    embed_dim: int,
+    spatial_size: Union[int, Tuple[int, int]],
+    temporal_size: int,
+    spatial_interpolation_scale: float = 1.0,
+    temporal_interpolation_scale: float = 1.0,
+) -> np.ndarray:
+    r"""
+    Args:
+        embed_dim (`int`):
+        spatial_size (`int` or `Tuple[int, int]`):
+        temporal_size (`int`):
+        spatial_interpolation_scale (`float`, defaults to 1.0):
+        temporal_interpolation_scale (`float`, defaults to 1.0):
+    """
+    if embed_dim % 4 != 0:
+        raise ValueError("`embed_dim` must be divisible by 4")
+    if isinstance(spatial_size, int):
+        spatial_size = (spatial_size, spatial_size)
+
+    embed_dim_spatial = 3 * embed_dim // 4
+    embed_dim_temporal = embed_dim // 4
+
+    # 1. Spatial
+    grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
+    grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
+
+    # 2. Temporal
+    grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
+    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
+
+    # 3. Concat
+    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
+    pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)  # [T, H*W, D // 4 * 3]
+
+    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
+    pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1)  # [T, H*W, D // 4]
+
+    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)  # [T, H*W, D]
+    return pos_embed
+
+
 def get_2d_sincos_pos_embed(
     embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
 ):
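A shape sketch for the new helper (sizes illustrative): three quarters of embed_dim encode the 2D spatial grid and one quarter the frame index, broadcast into a [T, H*W, D] table:

from diffusers.models.embeddings import get_3d_sincos_pos_embed

pos = get_3d_sincos_pos_embed(embed_dim=64, spatial_size=(8, 8), temporal_size=4)
print(pos.shape)  # (4, 64, 64): [T, H*W, D], with D split 16 temporal + 48 spatial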
@@ -230,6 +288,176 @@ class PatchEmbed(nn.Module):
         return (latent + pos_embed).to(latent.dtype)


+class LuminaPatchEmbed(nn.Module):
+    """2D Image to Patch Embedding with support for Lumina-T2X"""
+
+    def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
+        super().__init__()
+        self.patch_size = patch_size
+        self.proj = nn.Linear(
+            in_features=patch_size * patch_size * in_channels,
+            out_features=embed_dim,
+            bias=bias,
+        )
+
+    def forward(self, x, freqs_cis):
+        """
+        Patchifies and embeds the input tensor(s).
+
+        Args:
+            x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified
+            and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the
+            frequency tensor(s).
+        """
+        freqs_cis = freqs_cis.to(x[0].device)
+        patch_height = patch_width = self.patch_size
+        batch_size, channel, height, width = x.size()
+        height_tokens, width_tokens = height // patch_height, width // patch_width
+
+        x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute(
+            0, 2, 4, 1, 3, 5
+        )
+        x = x.flatten(3)
+        x = self.proj(x)
+        x = x.flatten(1, 2)
+
+        mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
+
+        return (
+            x,
+            mask,
+            [(height, width)] * batch_size,
+            freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
+        )
+
+
+class CogVideoXPatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        embed_dim: int = 1920,
+        text_embed_dim: int = 4096,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+
+        self.proj = nn.Conv2d(
+            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+        )
+        self.text_proj = nn.Linear(text_embed_dim, embed_dim)
+
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+        r"""
+        Args:
+            text_embeds (`torch.Tensor`):
+                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
+            image_embeds (`torch.Tensor`):
+                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
+        """
+        text_embeds = self.text_proj(text_embeds)
+
+        batch, num_frames, channels, height, width = image_embeds.shape
+        image_embeds = image_embeds.reshape(-1, channels, height, width)
+        image_embeds = self.proj(image_embeds)
+        image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
+        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
+        image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+
+        embeds = torch.cat(
+            [text_embeds, image_embeds], dim=1
+        ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
+        return embeds
+
+
+def get_3d_rotary_pos_embed(
+    embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    """
+    RoPE for video tokens with 3D structure.
+
+    Args:
+        embed_dim: (`int`):
+            The embedding dimension size, corresponding to hidden_size_head.
+        crops_coords (`Tuple[int]`):
+            The top-left and bottom-right coordinates of the crop.
+        grid_size (`Tuple[int]`):
+            The grid size of the spatial positional embedding (height, width).
+        temporal_size (`int`):
+            The size of the temporal dimension.
+        theta (`float`):
+            Scaling factor for frequency computation.
+        use_real (`bool`):
+            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
+    """
+    start, stop = crops_coords
+    grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
+    grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+
+    # Compute dimensions for each axis
+    dim_t = embed_dim // 4
+    dim_h = embed_dim // 8 * 3
+    dim_w = embed_dim // 8 * 3
+
+    # Temporal frequencies
+    freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
+    grid_t = torch.from_numpy(grid_t).float()
+    freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
+    freqs_t = freqs_t.repeat_interleave(2, dim=-1)
+
+    # Spatial frequencies for height and width
+    freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
+    freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
+    grid_h = torch.from_numpy(grid_h).float()
+    grid_w = torch.from_numpy(grid_w).float()
+    freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
+    freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
+    freqs_h = freqs_h.repeat_interleave(2, dim=-1)
+    freqs_w = freqs_w.repeat_interleave(2, dim=-1)
+
+    # Broadcast and concatenate tensors along specified dimension
+    def broadcast(tensors, dim=-1):
+        num_tensors = len(tensors)
+        shape_lens = {len(t.shape) for t in tensors}
+        assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
+        shape_len = list(shape_lens)[0]
+        dim = (dim + shape_len) if dim < 0 else dim
+        dims = list(zip(*(list(t.shape) for t in tensors)))
+        expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+        assert all(
+            [*(len(set(t[1])) <= 2 for t in expandable_dims)]
+        ), "invalid dimensions for broadcastable concatenation"
+        max_dims = [(t[0], max(t[1])) for t in expandable_dims]
+        expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
+        expanded_dims.insert(dim, (dim, dims[dim]))
+        expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
+        tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
+        return torch.cat(tensors, dim=dim)
+
+    freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
+
+    t, h, w, d = freqs.shape
+    freqs = freqs.view(t * h * w, d)
+
+    # Generate sine and cosine components
+    sin = freqs.sin()
+    cos = freqs.cos()
+
+    if use_real:
+        return cos, sin
+    else:
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs_cis
+
+
 def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
     """
     RoPE for image tokens with 2d structure.
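The three additions above support the new video models: LuminaPatchEmbed returns linear patch embeddings plus a matching RoPE slice, CogVideoXPatchEmbed prepends projected text tokens to per-frame conv patches, and get_3d_rotary_pos_embed splits the per-head dimension across time (1/4) and height/width (3/8 each). A shape sketch with illustrative sizes:

import torch
from diffusers.models.embeddings import get_3d_rotary_pos_embed

cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,                   # per-head dim: 16 temporal + 24 + 24 spatial freqs
    crops_coords=((0, 0), (8, 8)),  # top-left / bottom-right of the crop
    grid_size=(8, 8),
    temporal_size=4,
)
print(cos.shape)  # torch.Size([256, 64]): one row per (t, h, w) position, 4 * 8 * 8 = 256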
@@ -245,7 +473,7 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
         If True, return real part and imaginary part separately. Otherwise, return complex numbers.

     Returns:
-        `torch.Tensor`: positional
+        `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
     """
     start, stop = crops_coords
     grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
@@ -262,19 +490,47 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
     assert embed_dim % 4 == 0

     # use half of dimensions to encode grid_h
-    emb_h = get_1d_rotary_pos_embed(
-
+    emb_h = get_1d_rotary_pos_embed(
+        embed_dim // 2, grid[0].reshape(-1), use_real=use_real
+    )  # (H*W, D/2) if use_real else (H*W, D/4)
+    emb_w = get_1d_rotary_pos_embed(
+        embed_dim // 2, grid[1].reshape(-1), use_real=use_real
+    )  # (H*W, D/2) if use_real else (H*W, D/4)

     if use_real:
-        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D
-        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D
+        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D)
+        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D)
         return cos, sin
     else:
         emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
         return emb


-def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
+def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
+    assert embed_dim % 4 == 0
+
+    emb_h = get_1d_rotary_pos_embed(
+        embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
+    )  # (H, D/4)
+    emb_w = get_1d_rotary_pos_embed(
+        embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
+    )  # (W, D/4)
+    emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1)  # (H, W, D/4, 1)
+    emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1)  # (H, W, D/4, 1)
+
+    emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2)  # (H, W, D/2)
+    return emb
+
+
+def get_1d_rotary_pos_embed(
+    dim: int,
+    pos: Union[np.ndarray, int],
+    theta: float = 10000.0,
+    use_real=False,
+    linear_factor=1.0,
+    ntk_factor=1.0,
+    repeat_interleave_real=True,
+):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

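Sketch of the new Lumina 2D variant (illustrative sizes). Since it calls get_1d_rotary_pos_embed with use_real left at its False default, the result is a complex-valued grid:

from diffusers.models.embeddings import get_2d_rotary_pos_embed_lumina

freqs = get_2d_rotary_pos_embed_lumina(embed_dim=64, len_h=8, len_w=8)
print(freqs.shape, freqs.dtype)  # torch.Size([8, 8, 32]) torch.complex64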
@@ -289,19 +545,32 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float
             Scaling factor for frequency computation. Defaults to 10000.0.
         use_real (`bool`, *optional*):
             If True, return real part and imaginary part separately. Otherwise, return complex numbers.
-
+        linear_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for the context extrapolation. Defaults to 1.0.
+        ntk_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
+        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
+            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
+            Otherwise, they are concateanted with themselves.
     Returns:
         `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
     """
+    assert dim % 2 == 0
+
     if isinstance(pos, int):
         pos = np.arange(pos)
-
+    theta = theta * ntk_factor
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor  # [D/2]
     t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
     freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
-    if use_real:
+    if use_real and repeat_interleave_real:
         freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
         freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
         return freqs_cos, freqs_sin
+    elif use_real:
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)  # [S, D]
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)  # [S, D]
+        return freqs_cos, freqs_sin
     else:
         freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
         return freqs_cis
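The extended signature adds context-extension knobs: ntk_factor rescales theta before the frequency table is built, linear_factor divides the resulting frequencies, and repeat_interleave_real picks between interleaved and concatenated cos/sin layouts. A shape sketch (illustrative values):

from diffusers.models.embeddings import get_1d_rotary_pos_embed

cos, sin = get_1d_rotary_pos_embed(dim=64, pos=128, use_real=True, ntk_factor=2.0)
print(cos.shape)  # torch.Size([128, 64]): interleaved layout (repeat_interleave_real=True)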
@@ -310,6 +579,8 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float
 def apply_rotary_emb(
     x: torch.Tensor,
     freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+    use_real: bool = True,
+    use_real_unbind_dim: int = -1,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -325,16 +596,32 @@ def apply_rotary_emb(
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
     """
-
-
-
-
+    if use_real:
+        cos, sin = freqs_cis  # [S, D]
+        cos = cos[None, None]
+        sin = sin[None, None]
+        cos, sin = cos.to(x.device), sin.to(x.device)
+
+        if use_real_unbind_dim == -1:
+            # Use for example in Lumina
+            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+        elif use_real_unbind_dim == -2:
+            # Use for example in Stable Audio
+            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
+            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+        else:
+            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
+
+        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

-
-
-
+        return out
+    else:
+        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+        freqs_cis = freqs_cis.unsqueeze(2)
+        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

-
+        return x_out.type_as(x)


 class TimestepEmbedding(nn.Module):
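Sketch of the new real-valued path (illustrative shapes): with the defaults use_real=True, use_real_unbind_dim=-1, adjacent channel pairs are rotated, matching the interleaved tables produced above:

import torch
from diffusers.models.embeddings import apply_rotary_emb, get_1d_rotary_pos_embed

q = torch.randn(2, 8, 128, 64)  # [batch, heads, seq, head_dim]
freqs = get_1d_rotary_pos_embed(dim=64, pos=128, use_real=True)  # (cos, sin), each [128, 64]
q_rot = apply_rotary_emb(q, freqs)
print(q_rot.shape)  # torch.Size([2, 8, 128, 64])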
@@ -386,11 +673,12 @@ class TimestepEmbedding(nn.Module):


 class Timesteps(nn.Module):
-    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
         super().__init__()
         self.num_channels = num_channels
         self.flip_sin_to_cos = flip_sin_to_cos
         self.downscale_freq_shift = downscale_freq_shift
+        self.scale = scale

     def forward(self, timesteps):
         t_emb = get_timestep_embedding(
@@ -398,6 +686,7 @@ class Timesteps(nn.Module):
             self.num_channels,
             flip_sin_to_cos=self.flip_sin_to_cos,
             downscale_freq_shift=self.downscale_freq_shift,
+            scale=self.scale,
         )
         return t_emb

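The new scale argument is simply forwarded to get_timestep_embedding, which multiplies the sinusoidal table by it; the default of 1 preserves prior behavior. Sketch (illustrative values):

import torch
from diffusers.models.embeddings import Timesteps

ts = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
print(ts(torch.tensor([0.5])).shape)  # torch.Size([1, 256])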
@@ -415,9 +704,10 @@ class GaussianFourierProjection(nn.Module):

         if set_W_to_weight:
             # to delete later
+            del self.weight
             self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
-
             self.weight = self.W
+            del self.W

     def forward(self, x):
         if self.log:
@@ -676,6 +966,30 @@ class CombinedTimestepTextProjEmbeddings(nn.Module):
         return conditioning


+class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
+    def __init__(self, embedding_dim, pooled_projection_dim):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(self, timestep, guidance, pooled_projection):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+        guidance_proj = self.time_proj(guidance)
+        guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+        time_guidance_emb = timesteps_emb + guidance_emb
+
+        pooled_projections = self.text_embedder(pooled_projection)
+        conditioning = time_guidance_emb + pooled_projections
+
+        return conditioning
+
+
 class HunyuanDiTAttentionPool(nn.Module):
     # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6

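This block embeds a guidance scale exactly like a timestep, through the same 256-channel sinusoidal projection but a separate MLP; given transformer_flux.py in the file list, it presumably serves the guidance-distilled Flux transformer. Sketch with illustrative dimensions:

import torch
from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings

embedder = CombinedTimestepGuidanceTextProjEmbeddings(embedding_dim=128, pooled_projection_dim=96)
cond = embedder(
    timestep=torch.tensor([999.0]),
    guidance=torch.tensor([3.5]),          # guidance scale, embedded like a timestep
    pooled_projection=torch.randn(1, 96),
)
print(cond.shape)  # torch.Size([1, 128])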
@@ -717,18 +1031,33 @@


 class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        embedding_dim,
+        pooled_projection_dim=1024,
+        seq_len=256,
+        cross_attention_dim=2048,
+        use_style_cond_and_image_meta_size=True,
+    ):
         super().__init__()

         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

+        self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+
         self.pooler = HunyuanDiTAttentionPool(
             seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
         )
+
         # Here we use a default learned embedder layer for future extension.
-        self.
-
+        self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
+        if use_style_cond_and_image_meta_size:
+            self.style_embedder = nn.Embedding(1, embedding_dim)
+            extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        else:
+            extra_in_dim = pooled_projection_dim
+
         self.extra_embedder = PixArtAlphaTextProjection(
             in_features=extra_in_dim,
             hidden_size=embedding_dim * 4,
@@ -743,21 +1072,59 @@ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
         # extra condition1: text
         pooled_projections = self.pooler(encoder_hidden_states)  # (N, 1024)

-
-
-
-
+        if self.use_style_cond_and_image_meta_size:
+            # extra condition2: image meta size embedding
+            image_meta_size = self.size_proj(image_meta_size.view(-1))
+            image_meta_size = image_meta_size.to(dtype=hidden_dtype)
+            image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
+
+            # extra condition3: style embedding
+            style_embedding = self.style_embedder(style)  # (N, embedding_dim)

-
-
+            # Concatenate all extra vectors
+            extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+        else:
+            extra_cond = torch.cat([pooled_projections], dim=1)

-        # Concatenate all extra vectors
-        extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
         conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # [B, D]

         return conditioning


+class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
+    def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
+        super().__init__()
+        self.time_proj = Timesteps(
+            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
+        )
+
+        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
+
+        self.caption_embedder = nn.Sequential(
+            nn.LayerNorm(cross_attention_dim),
+            nn.Linear(
+                cross_attention_dim,
+                hidden_size,
+                bias=True,
+            ),
+        )
+
+    def forward(self, timestep, caption_feat, caption_mask):
+        # timestep embedding:
+        time_freq = self.time_proj(timestep)
+        time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
+
+        # caption condition embedding:
+        caption_mask_float = caption_mask.float().unsqueeze(-1)
+        caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1)
+        caption_feats_pool = caption_feats_pool.to(caption_feat)
+        caption_embed = self.caption_embedder(caption_feats_pool)
+
+        conditioning = time_embed + caption_embed
+
+        return conditioning
+
+
 class TextTimeEmbedding(nn.Module):
     def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
         super().__init__()
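One detail worth noting in LuminaCombinedTimestepCaptionEmbedding.forward is the masked mean-pool: padding positions are zeroed before the sum and the divisor is the per-sample token count, so padding never dilutes the pooled caption. The same arithmetic in isolation (illustrative shapes):

import torch

caption_feat = torch.randn(2, 7, 2048)                 # [batch, tokens, cross_attention_dim]
caption_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0],    # 5 real tokens, 2 padding
                             [1, 1, 1, 1, 1, 1, 1]])
mask = caption_mask.float().unsqueeze(-1)              # [2, 7, 1]
pooled = (caption_feat * mask).sum(dim=1) / mask.sum(dim=1)
print(pooled.shape)  # torch.Size([2, 2048])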
@@ -980,7 +1347,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):

             objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))

-        # positionet with text and image
+        # positionet with text and image information
         else:
             phrases_masks = phrases_masks.unsqueeze(-1)
             image_masks = image_masks.unsqueeze(-1)
@@ -1252,7 +1619,7 @@ class MultiIPAdapterImageProjection(nn.Module):
         if not isinstance(image_embeds, list):
             deprecation_message = (
                 "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
-                " Please make sure to update your script to pass `image_embeds` as a list of tensors to
+                " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning."
             )
             deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
             image_embeds = [image_embeds.unsqueeze(1)]
diffusers/models/model_loading_utils.py
CHANGED
@@ -191,7 +191,6 @@ def _fetch_index_file(
     cache_dir,
     variant,
     force_download,
-    resume_download,
     proxies,
     local_files_only,
     token,
@@ -216,12 +215,11 @@ def _fetch_index_file(
             weights_name=index_file_in_repo,
             cache_dir=cache_dir,
             force_download=force_download,
-            resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
             token=token,
             revision=revision,
-            subfolder=
+            subfolder=None,
             user_agent=user_agent,
             commit_hash=commit_hash,
         )
diffusers/models/modeling_flax_utils.py
CHANGED
@@ -245,9 +245,7 @@ class FlaxModelMixin(PushToHubMixin):
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
+
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -296,7 +294,6 @@ class FlaxModelMixin(PushToHubMixin):
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         from_pt = kwargs.pop("from_pt", False)
-        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
         token = kwargs.pop("token", None)
@@ -316,7 +313,6 @@ class FlaxModelMixin(PushToHubMixin):
             cache_dir=cache_dir,
             return_unused_kwargs=True,
             force_download=force_download,
-            resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
             token=token,
@@ -362,7 +358,6 @@ class FlaxModelMixin(PushToHubMixin):
             cache_dir=cache_dir,
             force_download=force_download,
             proxies=proxies,
-            resume_download=resume_download,
             local_files_only=local_files_only,
             token=token,
             user_agent=user_agent,
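All of the removals in these last hunks track one upstream change: huggingface_hub deprecated resume_download and resumes interrupted downloads by default, so diffusers stops threading the flag through its loaders (the deleted Flax docstring text said as much). An illustrative call with an arbitrarily chosen repo and file:

from huggingface_hub import hf_hub_download

# No resume_download flag needed; interrupted downloads resume automatically.
path = hf_hub_download("stabilityai/stable-diffusion-2-1", "model_index.json")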