diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +38 -2
- diffusers/configuration_utils.py +12 -0
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +257 -54
- diffusers/loaders/__init__.py +2 -0
- diffusers/loaders/ip_adapter.py +5 -1
- diffusers/loaders/lora_base.py +14 -7
- diffusers/loaders/lora_conversion_utils.py +332 -0
- diffusers/loaders/lora_pipeline.py +707 -41
- diffusers/loaders/peft.py +1 -0
- diffusers/loaders/single_file_utils.py +81 -4
- diffusers/loaders/textual_inversion.py +2 -0
- diffusers/loaders/unet.py +39 -8
- diffusers/models/__init__.py +4 -0
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +86 -10
- diffusers/models/attention_processor.py +169 -133
- diffusers/models/autoencoders/autoencoder_kl.py +71 -11
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
- diffusers/models/controlnet_flux.py +536 -0
- diffusers/models/controlnet_sd3.py +7 -3
- diffusers/models/controlnet_sparsectrl.py +0 -1
- diffusers/models/embeddings.py +238 -61
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +182 -14
- diffusers/models/modeling_utils.py +283 -46
- diffusers/models/normalization.py +79 -0
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
- diffusers/models/transformers/pixart_transformer_2d.py +9 -1
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +161 -44
- diffusers/models/transformers/transformer_sd3.py +7 -1
- diffusers/models/unets/unet_2d_condition.py +8 -8
- diffusers/models/unets/unet_motion_model.py +41 -63
- diffusers/models/upsampling.py +6 -6
- diffusers/pipelines/__init__.py +40 -7
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
- diffusers/pipelines/auto_pipeline.py +39 -8
- diffusers/pipelines/cogvideo/__init__.py +6 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +10 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -20
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
- diffusers/pipelines/pag/__init__.py +6 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_loading_utils.py +225 -27
- diffusers/pipelines/pipeline_utils.py +123 -180
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +126 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/quantization_config.py +391 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +4 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
- diffusers/schedulers/scheduling_deis_multistep.py +78 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_sasolver.py +78 -1
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
- diffusers/training_utils.py +48 -18
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
- diffusers/utils/hub_utils.py +16 -4
- diffusers/utils/import_utils.py +31 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +3 -3
- diffusers/utils/testing_utils.py +59 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
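Beyond the per-file churn above, the standout structural change is the brand-new diffusers/quantizers package. A minimal sketch of the workflow it enables — assuming bitsandbytes and a CUDA device are available; the model id and argument values are illustrative, not taken from this diff:

import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# Load a single model component in 4-bit NF4 via the new quantizers package (0.31.0+)
quant_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", quantization_config=quant_config
)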
diffusers/models/embeddings.py
CHANGED
@@ -342,15 +342,61 @@ class CogVideoXPatchEmbed(nn.Module):
         embed_dim: int = 1920,
         text_embed_dim: int = 4096,
         bias: bool = True,
+        sample_width: int = 90,
+        sample_height: int = 60,
+        sample_frames: int = 49,
+        temporal_compression_ratio: int = 4,
+        max_text_seq_length: int = 226,
+        spatial_interpolation_scale: float = 1.875,
+        temporal_interpolation_scale: float = 1.0,
+        use_positional_embeddings: bool = True,
+        use_learned_positional_embeddings: bool = True,
     ) -> None:
         super().__init__()
+
         self.patch_size = patch_size
+        self.embed_dim = embed_dim
+        self.sample_height = sample_height
+        self.sample_width = sample_width
+        self.sample_frames = sample_frames
+        self.temporal_compression_ratio = temporal_compression_ratio
+        self.max_text_seq_length = max_text_seq_length
+        self.spatial_interpolation_scale = spatial_interpolation_scale
+        self.temporal_interpolation_scale = temporal_interpolation_scale
+        self.use_positional_embeddings = use_positional_embeddings
+        self.use_learned_positional_embeddings = use_learned_positional_embeddings
 
         self.proj = nn.Conv2d(
             in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
         )
         self.text_proj = nn.Linear(text_embed_dim, embed_dim)
 
+        if use_positional_embeddings or use_learned_positional_embeddings:
+            persistent = use_learned_positional_embeddings
+            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
+            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
+
+    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+        post_patch_height = sample_height // self.patch_size
+        post_patch_width = sample_width // self.patch_size
+        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
+        num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+        pos_embedding = get_3d_sincos_pos_embed(
+            self.embed_dim,
+            (post_patch_width, post_patch_height),
+            post_time_compression_frames,
+            self.spatial_interpolation_scale,
+            self.temporal_interpolation_scale,
+        )
+        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
+        joint_pos_embedding = torch.zeros(
+            1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
+        )
+        joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
+
+        return joint_pos_embedding
+
     def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         r"""
         Args:
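CogVideoXPatchEmbed now computes a joint text+video positional embedding up front and registers it as a buffer (persistent only when learned). A minimal sketch of the new arguments at their defaults — the shapes follow the code above, but the call itself is illustrative rather than taken from this diff:

import torch
from diffusers.models.embeddings import CogVideoXPatchEmbed

embed = CogVideoXPatchEmbed(patch_size=2, in_channels=16, embed_dim=1920, text_embed_dim=4096)
text_embeds = torch.randn(1, 226, 4096)        # max_text_seq_length text tokens
image_embeds = torch.randn(1, 13, 16, 60, 90)  # 13 latent frames = (sample_frames - 1) // 4 + 1
out = embed(text_embeds, image_embeds)         # [1, 226 + 13 * 30 * 45, 1920]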
@@ -371,9 +417,85 @@ class CogVideoXPatchEmbed(nn.Module):
         embeds = torch.cat(
             [text_embeds, image_embeds], dim=1
         ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
+
+        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
+            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
+                raise ValueError(
+                    "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
+                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
+                )
+
+            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+            if (
+                self.sample_height != height
+                or self.sample_width != width
+                or self.sample_frames != pre_time_compression_frames
+            ):
+                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
+                pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+            else:
+                pos_embedding = self.pos_embedding
+
+            embeds = embeds + pos_embedding
+
         return embeds
 
 
+class CogView3PlusPatchEmbed(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 16,
+        hidden_size: int = 2560,
+        patch_size: int = 2,
+        text_hidden_size: int = 4096,
+        pos_embed_max_size: int = 128,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_size = hidden_size
+        self.patch_size = patch_size
+        self.text_hidden_size = text_hidden_size
+        self.pos_embed_max_size = pos_embed_max_size
+        # Linear projection for image patches
+        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
+
+        # Linear projection for text embeddings
+        self.text_proj = nn.Linear(text_hidden_size, hidden_size)
+
+        pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
+        pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
+        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
+
+    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, channel, height, width = hidden_states.shape
+
+        if height % self.patch_size != 0 or width % self.patch_size != 0:
+            raise ValueError("Height and width must be divisible by patch size")
+
+        height = height // self.patch_size
+        width = width // self.patch_size
+        hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
+        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
+        hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
+
+        # Project the patches
+        hidden_states = self.proj(hidden_states)
+        encoder_hidden_states = self.text_proj(encoder_hidden_states)
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        # Calculate text_length
+        text_length = encoder_hidden_states.shape[1]
+
+        image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
+        text_pos_embed = torch.zeros(
+            (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
+        )
+        pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...]
+
+        return (hidden_states + pos_embed).to(hidden_states.dtype)
+
+
 def get_3d_rotary_pos_embed(
     embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
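The forward pass now validates the requested resolution and recomputes the sincos table when it differs from the configured sample size, and the new CogView3PlusPatchEmbed patchifies latents with a linear projection, prepends projected text tokens, and adds a cropped 2D sincos table. A hedged sketch of the latter (the latent and prompt shapes are assumptions for illustration):

import torch
from diffusers.models.embeddings import CogView3PlusPatchEmbed

patch_embed = CogView3PlusPatchEmbed(in_channels=16, hidden_size=2560, patch_size=2, text_hidden_size=4096)
latents = torch.randn(1, 16, 128, 128)        # [B, C, H, W]; H and W must be divisible by patch_size
prompt_embeds = torch.randn(1, 224, 4096)     # [B, text_seq_len, text_hidden_size]
tokens = patch_embed(latents, prompt_embeds)  # [1, 224 + 64 * 64, 2560]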
@@ -391,15 +513,16 @@ def get_3d_rotary_pos_embed(
         The size of the temporal dimension.
     theta (`float`):
         Scaling factor for frequency computation.
-    use_real (`bool`):
-        If True, return real part and imaginary part separately. Otherwise, return complex numbers.
 
     Returns:
         `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
     """
+    if use_real is not True:
+        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
     start, stop = crops_coords
-    grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
-    grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+    grid_size_h, grid_size_w = grid_size
+    grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
+    grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
     grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
 
     # Compute dimensions for each axis
@@ -408,54 +531,37 @@ def get_3d_rotary_pos_embed(
     dim_w = embed_dim // 8 * 3
 
     # Temporal frequencies
-    freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
-    grid_t = torch.from_numpy(grid_t).float()
-    freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
-    freqs_t = freqs_t.repeat_interleave(2, dim=-1)
-
+    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
     # Spatial frequencies for height and width
-    freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
-    freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
-    # [~28 further removed lines (the old broadcast-and-concatenate helper) were not captured in this diff view]
-    t, h, w, d = freqs.shape
-    freqs = freqs.view(t * h * w, d)
-
-    # Generate sine and cosine components
-    sin = freqs.sin()
-    cos = freqs.cos()
-
-    if use_real:
-        return cos, sin
-    else:
-        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
-        return freqs_cis
+    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
+    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
+
+    # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
+    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
+        freqs_t = freqs_t[:, None, None, :].expand(
+            -1, grid_size_h, grid_size_w, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_t
+        freqs_h = freqs_h[None, :, None, :].expand(
+            temporal_size, -1, grid_size_w, -1
+        )  # temporal_size, grid_size_h, grid_size_2, dim_h
+        freqs_w = freqs_w[None, None, :, :].expand(
+            temporal_size, grid_size_h, -1, -1
+        )  # temporal_size, grid_size_h, grid_size_2, dim_w
+
+        freqs = torch.cat(
+            [freqs_t, freqs_h, freqs_w], dim=-1
+        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
+        freqs = freqs.view(
+            temporal_size * grid_size_h * grid_size_w, -1
+        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
+        return freqs
+
+    t_cos, t_sin = freqs_t  # both t_cos and t_sin has shape: temporal_size, dim_t
+    h_cos, h_sin = freqs_h  # both h_cos and h_sin has shape: grid_size_h, dim_h
+    w_cos, w_sin = freqs_w  # both w_cos and w_sin has shape: grid_size_w, dim_w
+    cos = combine_time_height_width(t_cos, h_cos, w_cos)
+    sin = combine_time_height_width(t_sin, h_sin, w_sin)
+    return cos, sin
 
 
 def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
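The refactor routes all three axes through get_1d_rotary_pos_embed and always returns a (cos, sin) pair. A quick sketch of calling the refactored helper (the values are illustrative):

import torch
from diffusers.models.embeddings import get_3d_rotary_pos_embed

# embed_dim=64 splits into dim_t=16 and dim_h=dim_w=24, so cos/sin have 16+24+24 = 64 columns
cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,
    crops_coords=((0, 0), (30, 45)),  # (start, stop) grid coordinates
    grid_size=(30, 45),
    temporal_size=13,
    use_real=True,  # use_real=False now raises a ValueError
)
print(cos.shape)  # torch.Size([17550, 64]) == (13 * 30 * 45, 64)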
@@ -530,6 +636,7 @@ def get_1d_rotary_pos_embed(
     linear_factor=1.0,
     ntk_factor=1.0,
     repeat_interleave_real=True,
+    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
 ):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -552,26 +659,37 @@ def get_1d_rotary_pos_embed(
     repeat_interleave_real (`bool`, *optional*, defaults to `True`):
         If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
         Otherwise, they are concateanted with themselves.
+    freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+        the dtype of the frequency tensor.
     Returns:
         `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
     """
     assert dim % 2 == 0
 
     if isinstance(pos, int):
-        pos = np.arange(pos)
+        pos = torch.arange(pos)
+    if isinstance(pos, np.ndarray):
+        pos = torch.from_numpy(pos)  # type: ignore # [S]
+
     theta = theta * ntk_factor
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor  # [D/2]
-    t = torch.from_numpy(pos).to(freqs.device)  # type: ignore # [S]
-    freqs = torch.outer(t, freqs)  # type: ignore # [S, D/2]
+    freqs = (
+        1.0
+        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
+        / linear_factor
+    )  # [D/2]
+    freqs = torch.outer(pos, freqs)  # type: ignore # [S, D/2]
     if use_real and repeat_interleave_real:
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
+        # flux, hunyuan-dit, cogvideox
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
         return freqs_cos, freqs_sin
     elif use_real:
-        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)  # [S, D]
-        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)  # [S, D]
+        # stable audio
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
         return freqs_cos, freqs_sin
     else:
+        # lumina
         freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64 # [S, D/2]
         return freqs_cis
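With the new freqs_dtype argument, the frequencies can be computed in float64 (as Flux requests off-MPS) before the .float() casts bring cos/sin back to float32. A small sketch:

import torch
from diffusers.models.embeddings import get_1d_rotary_pos_embed

cos, sin = get_1d_rotary_pos_embed(64, 128, use_real=True, freqs_dtype=torch.float64)
print(cos.shape, cos.dtype)  # torch.Size([128, 64]) torch.float32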
@@ -603,11 +721,11 @@ def apply_rotary_emb(
         cos, sin = cos.to(x.device), sin.to(x.device)
 
         if use_real_unbind_dim == -1:
-            #
+            # Used for flux, cogvideox, hunyuan-dit
             x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
             x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
         elif use_real_unbind_dim == -2:
-            #
+            # Used for Stable Audio
             x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
             x_rotated = torch.cat([-x_imag, x_real], dim=-1)
         else:
@@ -617,6 +735,7 @@ def apply_rotary_emb(
 
         return out
     else:
+        # used for lumina
         x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
         freqs_cis = freqs_cis.unsqueeze(2)
         x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
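The rewritten comments pin down which branch each model family takes. A sketch of the interleaved path (use_real_unbind_dim=-1) used by flux/cogvideox/hunyuan-dit, with illustrative shapes:

import torch
from diffusers.models.embeddings import apply_rotary_emb, get_1d_rotary_pos_embed

q = torch.randn(1, 8, 512, 64)  # [batch, heads, seq, head_dim]
freqs = get_1d_rotary_pos_embed(64, 512, use_real=True, repeat_interleave_real=True)
q_rot = apply_rotary_emb(q, freqs, use_real=True, use_real_unbind_dim=-1)  # same shape as q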
@@ -624,6 +743,31 @@ def apply_rotary_emb(
         return x_out.type_as(x)
 
 
+class FluxPosEmbed(nn.Module):
+    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        cos_out = []
+        sin_out = []
+        pos = ids.float()
+        is_mps = ids.device.type == "mps"
+        freqs_dtype = torch.float32 if is_mps else torch.float64
+        for i in range(n_axes):
+            cos, sin = get_1d_rotary_pos_embed(
+                self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
+            )
+            cos_out.append(cos)
+            sin_out.append(sin)
+        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+        return freqs_cos, freqs_sin
+
+
 class TimestepEmbedding(nn.Module):
     def __init__(
         self,
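FluxPosEmbed simply loops get_1d_rotary_pos_embed over the id axes and concatenates the per-axis tables. A sketch using the axis split from the Flux transformer config (the axes_dim values are an assumption here, not part of this diff):

import torch
from diffusers.models.embeddings import FluxPosEmbed

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])  # sums to the 128-dim attention head
ids = torch.zeros(512, 3)  # one (t, h, w) position-id triple per token
cos, sin = pos_embed(ids)  # each [512, 16 + 56 + 56] == [512, 128]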
@@ -990,6 +1134,39 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
         return conditioning
 
 
+class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
+    def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
+        self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(
+        self,
+        timestep: torch.Tensor,
+        original_size: torch.Tensor,
+        target_size: torch.Tensor,
+        crop_coords: torch.Tensor,
+        hidden_dtype: torch.dtype,
+    ) -> torch.Tensor:
+        timesteps_proj = self.time_proj(timestep)
+
+        original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1)
+        crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
+        target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
+
+        # (B, 3 * condition_dim)
+        condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1)
+
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)
+        condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)
+
+        conditioning = timesteps_emb + condition_emb
+        return conditioning
+
+
 class HunyuanDiTAttentionPool(nn.Module):
     # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
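CogView3CombinedTimestepSizeEmbeddings projects the timestep and the SDXL-style size/crop conditions through sinusoidal embeddings and sums the two results. A hedged sketch with illustrative dimensions; note that each of the three conditions contributes 2 * condition_dim features, so pooled_projection_dim must equal 3 * 2 * condition_dim:

import torch
from diffusers.models.embeddings import CogView3CombinedTimestepSizeEmbeddings

embedder = CogView3CombinedTimestepSizeEmbeddings(
    embedding_dim=512, condition_dim=256, pooled_projection_dim=3 * 2 * 256
)
timestep = torch.tensor([999])
original_size = torch.tensor([[1024, 1024]])  # (height, width) per batch element
target_size = torch.tensor([[1024, 1024]])
crop_coords = torch.tensor([[0, 0]])
cond = embedder(timestep, original_size, target_size, crop_coords, hidden_dtype=torch.float32)
print(cond.shape)  # torch.Size([1, 512])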
diffusers/models/embeddings_flax.py
CHANGED
@@ -29,11 +29,21 @@ def get_sinusoidal_embeddings(
     """Returns the positional encoding (same as Tensor2Tensor).
 
     Args:
-        timesteps: a 1-D Tensor of N indices, one per batch element.
-            These may be fractional.
-        embedding_dim: The number of output channels.
-        min_timescale: The smallest time unit (should probably be 0.0).
-        max_timescale: the largest time unit.
+        timesteps (`jnp.ndarray` of shape `(N,)`):
+            A 1-D array of N indices, one per batch element. These may be fractional.
+        embedding_dim (`int`):
+            The number of output channels.
+        freq_shift (`float`, *optional*, defaults to `1`):
+            Shift applied to the frequency scaling of the embeddings.
+        min_timescale (`float`, *optional*, defaults to `1`):
+            The smallest time unit used in the sinusoidal calculation (should probably be 0.0).
+        max_timescale (`float`, *optional*, defaults to `1.0e4`):
+            The largest time unit used in the sinusoidal calculation.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+            Whether to flip the order of sinusoidal components to cosine first.
+        scale (`float`, *optional*, defaults to `1.0`):
+            A scaling factor applied to the positional embeddings.
+
     Returns:
         a Tensor of timing signals [N, num_channels]
     """
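With the arguments now documented, a quick smoke test of the Flax helper (assumes jax and flax are installed; the values are illustrative):

import jax.numpy as jnp
from diffusers.models.embeddings_flax import get_sinusoidal_embeddings

timesteps = jnp.array([0.0, 10.0, 500.0])
emb = get_sinusoidal_embeddings(timesteps, embedding_dim=32, flip_sin_to_cos=True)
print(emb.shape)  # (3, 32)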
@@ -61,9 +71,9 @@ class FlaxTimestepEmbedding(nn.Module):
 
     Args:
         time_embed_dim (`int`, *optional*, defaults to `32`):
-            Time step embedding dimension
-        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
-            Parameters `dtype`
+            Time step embedding dimension.
+        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
+            The data type for the embedding parameters.
     """
 
     time_embed_dim: int = 32
@@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module):
 
     Args:
         dim (`int`, *optional*, defaults to `32`):
-            Time step embedding dimension
+            Time step embedding dimension.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+            Whether to flip the sinusoidal function from sine to cosine.
+        freq_shift (`float`, *optional*, defaults to `1`):
+            Frequency shift applied to the sinusoidal embeddings.
     """
 
     dim: int = 32