diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. diffusers/__init__.py +38 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +238 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +40 -7
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +6 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
  53. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  54. diffusers/pipelines/cogview3/__init__.py +47 -0
  55. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  56. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  57. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  58. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  60. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  62. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  63. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  64. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  66. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  67. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  68. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  70. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  71. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  72. diffusers/pipelines/flux/__init__.py +10 -0
  73. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  74. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  76. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  77. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  78. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  79. diffusers/pipelines/free_noise_utils.py +365 -5
  80. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  81. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  82. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  83. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  84. diffusers/pipelines/kolors/tokenizer.py +4 -0
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  86. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  87. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  89. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  90. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  91. diffusers/pipelines/pag/__init__.py +6 -0
  92. diffusers/pipelines/pag/pag_utils.py +8 -2
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  96. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  97. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  98. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  100. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  101. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  102. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  103. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  106. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  107. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  108. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  109. diffusers/pipelines/pipeline_utils.py +123 -180
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  111. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  113. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  117. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  120. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  121. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  122. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  123. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  129. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  131. diffusers/quantizers/__init__.py +16 -0
  132. diffusers/quantizers/auto.py +126 -0
  133. diffusers/quantizers/base.py +233 -0
  134. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  135. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  136. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  137. diffusers/quantizers/quantization_config.py +391 -0
  138. diffusers/schedulers/scheduling_ddim.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  140. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm.py +4 -1
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  143. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  148. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  149. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  150. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  151. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  152. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  155. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  156. diffusers/schedulers/scheduling_sasolver.py +78 -1
  157. diffusers/schedulers/scheduling_unclip.py +4 -1
  158. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  159. diffusers/training_utils.py +48 -18
  160. diffusers/utils/__init__.py +2 -1
  161. diffusers/utils/dummy_pt_objects.py +60 -0
  162. diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
  163. diffusers/utils/hub_utils.py +16 -4
  164. diffusers/utils/import_utils.py +31 -8
  165. diffusers/utils/loading_utils.py +28 -4
  166. diffusers/utils/peft_utils.py +3 -3
  167. diffusers/utils/testing_utils.py +59 -0
  168. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  169. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
  170. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  172. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  173. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
diffusers/models/embeddings.py
@@ -342,15 +342,61 @@ class CogVideoXPatchEmbed(nn.Module):
         embed_dim: int = 1920,
         text_embed_dim: int = 4096,
         bias: bool = True,
+        sample_width: int = 90,
+        sample_height: int = 60,
+        sample_frames: int = 49,
+        temporal_compression_ratio: int = 4,
+        max_text_seq_length: int = 226,
+        spatial_interpolation_scale: float = 1.875,
+        temporal_interpolation_scale: float = 1.0,
+        use_positional_embeddings: bool = True,
+        use_learned_positional_embeddings: bool = True,
     ) -> None:
         super().__init__()
+
         self.patch_size = patch_size
+        self.embed_dim = embed_dim
+        self.sample_height = sample_height
+        self.sample_width = sample_width
+        self.sample_frames = sample_frames
+        self.temporal_compression_ratio = temporal_compression_ratio
+        self.max_text_seq_length = max_text_seq_length
+        self.spatial_interpolation_scale = spatial_interpolation_scale
+        self.temporal_interpolation_scale = temporal_interpolation_scale
+        self.use_positional_embeddings = use_positional_embeddings
+        self.use_learned_positional_embeddings = use_learned_positional_embeddings
 
         self.proj = nn.Conv2d(
             in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
         )
         self.text_proj = nn.Linear(text_embed_dim, embed_dim)
 
+        if use_positional_embeddings or use_learned_positional_embeddings:
+            persistent = use_learned_positional_embeddings
+            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
+            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
+
+    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+        post_patch_height = sample_height // self.patch_size
+        post_patch_width = sample_width // self.patch_size
+        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
+        num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+        pos_embedding = get_3d_sincos_pos_embed(
+            self.embed_dim,
+            (post_patch_width, post_patch_height),
+            post_time_compression_frames,
+            self.spatial_interpolation_scale,
+            self.temporal_interpolation_scale,
+        )
+        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
+        joint_pos_embedding = torch.zeros(
+            1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
+        )
+        joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
+
+        return joint_pos_embedding
+
     def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         r"""
         Args:
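Taken together, the new constructor arguments let CogVideoXPatchEmbed pre-compute a joint text + video positional-embedding buffer at init time. As a rough sanity check of the sizing arithmetic under the defaults above (a hypothetical sketch, not part of the diff; the import path follows the file being diffed):

    from diffusers.models.embeddings import CogVideoXPatchEmbed

    # All arguments default: 90x60 latent grid, 49 frames, patch_size=2,
    # temporal_compression_ratio=4, max_text_seq_length=226, embed_dim=1920.
    embed = CogVideoXPatchEmbed()

    # post-patch grid: (60 // 2) x (90 // 2) = 30 x 45
    # compressed frames: (49 - 1) // 4 + 1 = 13
    # num_patches = 30 * 45 * 13 = 17550; with 226 text tokens -> 17776
    print(embed.pos_embedding.shape)  # expected: torch.Size([1, 17776, 1920])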
@@ -371,9 +417,85 @@ class CogVideoXPatchEmbed(nn.Module):
         embeds = torch.cat(
             [text_embeds, image_embeds], dim=1
         ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
+
+        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
+            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
+                raise ValueError(
+                    "It is currently not possible to generate videos at a different resolution than the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'. "
+                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
+                )
+
+            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+            if (
+                self.sample_height != height
+                or self.sample_width != width
+                or self.sample_frames != pre_time_compression_frames
+            ):
+                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
+                pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+            else:
+                pos_embedding = self.pos_embedding
+
+            embeds = embeds + pos_embedding
+
         return embeds
 
 
+class CogView3PlusPatchEmbed(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 16,
+        hidden_size: int = 2560,
+        patch_size: int = 2,
+        text_hidden_size: int = 4096,
+        pos_embed_max_size: int = 128,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_size = hidden_size
+        self.patch_size = patch_size
+        self.text_hidden_size = text_hidden_size
+        self.pos_embed_max_size = pos_embed_max_size
+        # Linear projection for image patches
+        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
+
+        # Linear projection for text embeddings
+        self.text_proj = nn.Linear(text_hidden_size, hidden_size)
+
+        pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
+        pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
+        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
+
+    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, channel, height, width = hidden_states.shape
+
+        if height % self.patch_size != 0 or width % self.patch_size != 0:
+            raise ValueError("Height and width must be divisible by patch size")
+
+        height = height // self.patch_size
+        width = width // self.patch_size
+        hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
+        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
+        hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
+
+        # Project the patches
+        hidden_states = self.proj(hidden_states)
+        encoder_hidden_states = self.text_proj(encoder_hidden_states)
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        # Calculate text_length
+        text_length = encoder_hidden_states.shape[1]
+
+        image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
+        text_pos_embed = torch.zeros(
+            (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
+        )
+        pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...]
+
+        return (hidden_states + pos_embed).to(hidden_states.dtype)
+
+
 def get_3d_rotary_pos_embed(
     embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
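The new CogView3PlusPatchEmbed projects image patches and text embeddings into a single joint sequence and adds a cropped slice of a 2D sincos grid. A hypothetical shape walk-through (assumed import path and input sizes):

    import torch
    from diffusers.models.embeddings import CogView3PlusPatchEmbed

    patch_embed = CogView3PlusPatchEmbed()  # in_channels=16, hidden_size=2560, patch_size=2

    latents = torch.randn(1, 16, 64, 64)  # (B, C, H, W); H and W must divide by patch_size
    text = torch.randn(1, 224, 4096)      # (B, text_seq_len, text_hidden_size)

    out = patch_embed(latents, text)
    # 64x64 latents -> 32x32 = 1024 patches; text tokens are prepended:
    print(out.shape)  # expected: torch.Size([1, 224 + 1024, 2560]) == [1, 1248, 2560]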
@@ -391,15 +513,16 @@ def get_3d_rotary_pos_embed(
         The size of the temporal dimension.
     theta (`float`):
         Scaling factor for frequency computation.
-    use_real (`bool`):
-        If True, return real part and imaginary part separately. Otherwise, return complex numbers.
 
     Returns:
         `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
     """
+    if use_real is not True:
+        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
     start, stop = crops_coords
-    grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
-    grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+    grid_size_h, grid_size_w = grid_size
+    grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
+    grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
     grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
 
     # Compute dimensions for each axis
@@ -408,54 +531,37 @@ def get_3d_rotary_pos_embed(
     dim_w = embed_dim // 8 * 3
 
     # Temporal frequencies
-    freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
-    grid_t = torch.from_numpy(grid_t).float()
-    freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
-    freqs_t = freqs_t.repeat_interleave(2, dim=-1)
-
+    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
     # Spatial frequencies for height and width
-    freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
-    freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
-    grid_h = torch.from_numpy(grid_h).float()
-    grid_w = torch.from_numpy(grid_w).float()
-    freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
-    freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
-    freqs_h = freqs_h.repeat_interleave(2, dim=-1)
-    freqs_w = freqs_w.repeat_interleave(2, dim=-1)
-
-    # Broadcast and concatenate tensors along specified dimension
-    def broadcast(tensors, dim=-1):
-        num_tensors = len(tensors)
-        shape_lens = {len(t.shape) for t in tensors}
-        assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
-        shape_len = list(shape_lens)[0]
-        dim = (dim + shape_len) if dim < 0 else dim
-        dims = list(zip(*(list(t.shape) for t in tensors)))
-        expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
-        assert all(
-            [*(len(set(t[1])) <= 2 for t in expandable_dims)]
-        ), "invalid dimensions for broadcastable concatenation"
-        max_dims = [(t[0], max(t[1])) for t in expandable_dims]
-        expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
-        expanded_dims.insert(dim, (dim, dims[dim]))
-        expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
-        tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
-        return torch.cat(tensors, dim=dim)
-
-    freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
-
-    t, h, w, d = freqs.shape
-    freqs = freqs.view(t * h * w, d)
-
-    # Generate sine and cosine components
-    sin = freqs.sin()
-    cos = freqs.cos()
-
-    if use_real:
-        return cos, sin
-    else:
-        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
-        return freqs_cis
+    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
+    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
+
+    # Broadcast and concatenate the temporal and spatial (height and width) frequencies into a 3D tensor
+    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
+        freqs_t = freqs_t[:, None, None, :].expand(
+            -1, grid_size_h, grid_size_w, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_t
+        freqs_h = freqs_h[None, :, None, :].expand(
+            temporal_size, -1, grid_size_w, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_h
+        freqs_w = freqs_w[None, None, :, :].expand(
+            temporal_size, grid_size_h, -1, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_w
+
+        freqs = torch.cat(
+            [freqs_t, freqs_h, freqs_w], dim=-1
+        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
+        freqs = freqs.view(
+            temporal_size * grid_size_h * grid_size_w, -1
+        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
+        return freqs
+
+    t_cos, t_sin = freqs_t  # both t_cos and t_sin have shape: temporal_size, dim_t
+    h_cos, h_sin = freqs_h  # both h_cos and h_sin have shape: grid_size_h, dim_h
+    w_cos, w_sin = freqs_w  # both w_cos and w_sin have shape: grid_size_w, dim_w
+    cos = combine_time_height_width(t_cos, h_cos, w_cos)
+    sin = combine_time_height_width(t_sin, h_sin, w_sin)
+    return cos, sin
 
 
 def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
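The rewrite routes all three axes through get_1d_rotary_pos_embed and then expands them onto a shared (t, h, w) grid, so the returned cos/sin pair covers every spatio-temporal position. A small sketch of the assumed call pattern:

    from diffusers.models.embeddings import get_3d_rotary_pos_embed

    # dim_t = 64 // 4 = 16, dim_h = dim_w = 64 // 8 * 3 = 24; 16 + 24 + 24 = 64
    cos, sin = get_3d_rotary_pos_embed(
        embed_dim=64,
        crops_coords=((0, 0), (30, 45)),  # (start, stop) coordinates for the (h, w) grid
        grid_size=(30, 45),
        temporal_size=13,
    )
    print(cos.shape, sin.shape)  # expected: torch.Size([17550, 64]) each (13 * 30 * 45 positions)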
@@ -530,6 +636,7 @@ def get_1d_rotary_pos_embed(
     linear_factor=1.0,
     ntk_factor=1.0,
     repeat_interleave_real=True,
+    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
 ):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -552,26 +659,37 @@ def get_1d_rotary_pos_embed(
     repeat_interleave_real (`bool`, *optional*, defaults to `True`):
         If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
         Otherwise, they are concatenated with themselves.
+    freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+        the dtype of the frequency tensor.
     Returns:
         `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
     """
     assert dim % 2 == 0
 
     if isinstance(pos, int):
-        pos = np.arange(pos)
+        pos = torch.arange(pos)
+    if isinstance(pos, np.ndarray):
+        pos = torch.from_numpy(pos)  # type: ignore  # [S]
+
     theta = theta * ntk_factor
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor  # [D/2]
-    t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
-    freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
+    freqs = (
+        1.0
+        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
+        / linear_factor
+    )  # [D/2]
+    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
     if use_real and repeat_interleave_real:
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
+        # flux, hunyuan-dit, cogvideox
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
         return freqs_cos, freqs_sin
     elif use_real:
-        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)  # [S, D]
-        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)  # [S, D]
+        # stable audio
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
         return freqs_cos, freqs_sin
     else:
+        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
        return freqs_cis
@@ -603,11 +721,11 @@ def apply_rotary_emb(
         cos, sin = cos.to(x.device), sin.to(x.device)
 
         if use_real_unbind_dim == -1:
-            # Use for example in Lumina
+            # Used for flux, cogvideox, hunyuan-dit
             x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
             x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
         elif use_real_unbind_dim == -2:
-            # Use for example in Stable Audio
+            # Used for Stable Audio
             x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
             x_rotated = torch.cat([-x_imag, x_real], dim=-1)
         else:
@@ -617,6 +735,7 @@ def apply_rotary_emb(
 
         return out
     else:
+        # used for lumina
         x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
         freqs_cis = freqs_cis.unsqueeze(2)
         x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
@@ -624,6 +743,31 @@ def apply_rotary_emb(
         return x_out.type_as(x)
 
 
+class FluxPosEmbed(nn.Module):
+    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        cos_out = []
+        sin_out = []
+        pos = ids.float()
+        is_mps = ids.device.type == "mps"
+        freqs_dtype = torch.float32 if is_mps else torch.float64
+        for i in range(n_axes):
+            cos, sin = get_1d_rotary_pos_embed(
+                self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
+            )
+            cos_out.append(cos)
+            sin_out.append(sin)
+        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+        return freqs_cos, freqs_sin
+
+
 class TimestepEmbedding(nn.Module):
     def __init__(
         self,
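FluxPosEmbed simply stacks per-axis 1D rotary tables for the packed (t, h, w) token ids; the result plugs straight into apply_rotary_emb above. A hedged usage sketch (the axes_dim values mirror what Flux-style models typically use; treat them and the shapes as assumptions):

    import torch
    from diffusers.models.embeddings import FluxPosEmbed, apply_rotary_emb

    rope = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])  # per-axis dims; 16 + 56 + 56 = 128

    ids = torch.zeros(512, 3)  # one (t, h, w) position triple per token
    cos, sin = rope(ids)
    print(cos.shape)  # expected: torch.Size([512, 128])

    query = torch.randn(1, 24, 512, 128)  # (B, heads, seq, head_dim)
    query = apply_rotary_emb(query, (cos, sin))  # shape unchanged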
@@ -990,6 +1134,39 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
         return conditioning
 
 
+class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
+    def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
+        self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(
+        self,
+        timestep: torch.Tensor,
+        original_size: torch.Tensor,
+        target_size: torch.Tensor,
+        crop_coords: torch.Tensor,
+        hidden_dtype: torch.dtype,
+    ) -> torch.Tensor:
+        timesteps_proj = self.time_proj(timestep)
+
+        original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1)
+        crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
+        target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
+
+        # (B, 3 * 2 * condition_dim): three conditions, two scalars each
+        condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1)
+
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)
+        condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)
+
+        conditioning = timesteps_emb + condition_emb
+        return conditioning
+
+
 class HunyuanDiTAttentionPool(nn.Module):
     # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
 
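CogView3CombinedTimestepSizeEmbeddings folds the timestep together with SDXL-style size/crop micro-conditioning. Each of the three conditions carries two scalars, and each scalar is sinusoidally expanded to condition_dim, so pooled_projection_dim must equal 3 * 2 * condition_dim. A hypothetical sketch (the concrete dimensions are assumptions, not from the diff):

    import torch
    from diffusers.models.embeddings import CogView3CombinedTimestepSizeEmbeddings

    cond_embed = CogView3CombinedTimestepSizeEmbeddings(
        embedding_dim=2560,
        condition_dim=256,
        pooled_projection_dim=3 * 2 * 256,  # original_size, crop_coords, target_size x 2 scalars each
    )

    conditioning = cond_embed(
        timestep=torch.tensor([999]),
        original_size=torch.tensor([[1024.0, 1024.0]]),
        target_size=torch.tensor([[1024.0, 1024.0]]),
        crop_coords=torch.tensor([[0.0, 0.0]]),
        hidden_dtype=torch.float32,
    )
    print(conditioning.shape)  # expected: torch.Size([1, 2560])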
diffusers/models/embeddings_flax.py
@@ -29,11 +29,21 @@ def get_sinusoidal_embeddings(
     """Returns the positional encoding (same as Tensor2Tensor).
 
     Args:
-        timesteps: a 1-D Tensor of N indices, one per batch element.
-            These may be fractional.
-        embedding_dim: The number of output channels.
-        min_timescale: The smallest time unit (should probably be 0.0).
-        max_timescale: The largest time unit.
+        timesteps (`jnp.ndarray` of shape `(N,)`):
+            A 1-D array of N indices, one per batch element. These may be fractional.
+        embedding_dim (`int`):
+            The number of output channels.
+        freq_shift (`float`, *optional*, defaults to `1`):
+            Shift applied to the frequency scaling of the embeddings.
+        min_timescale (`float`, *optional*, defaults to `1`):
+            The smallest time unit used in the sinusoidal calculation (should probably be 0.0).
+        max_timescale (`float`, *optional*, defaults to `1.0e4`):
+            The largest time unit used in the sinusoidal calculation.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+            Whether to flip the order of sinusoidal components to cosine first.
+        scale (`float`, *optional*, defaults to `1.0`):
+            A scaling factor applied to the positional embeddings.
+
     Returns:
         a Tensor of timing signals [N, num_channels]
     """
@@ -61,9 +71,9 @@ class FlaxTimestepEmbedding(nn.Module):
 
     Args:
         time_embed_dim (`int`, *optional*, defaults to `32`):
-            Time step embedding dimension
-        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
-            Parameters `dtype`
+            Time step embedding dimension.
+        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
+            The data type for the embedding parameters.
     """
 
     time_embed_dim: int = 32
@@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module):
 
     Args:
         dim (`int`, *optional*, defaults to `32`):
-            Time step embedding dimension
+            Time step embedding dimension.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+            Whether to flip the sinusoidal function from sine to cosine.
+        freq_shift (`float`, *optional*, defaults to `1`):
+            Frequency shift applied to the sinusoidal embeddings.
     """
 
     dim: int = 32
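The three Flax hunks above are documentation-only. For reference, a minimal sketch of how the documented pieces fit together, assuming the standard Flax init/apply flow (FlaxTimesteps holds no parameters, so an empty variables dict is assumed to suffice):

    import jax
    import jax.numpy as jnp
    from diffusers.models.embeddings_flax import (
        FlaxTimestepEmbedding,
        FlaxTimesteps,
        get_sinusoidal_embeddings,
    )

    timesteps = jnp.array([0.0, 10.0, 999.0])

    # The bare helper: one sinusoidal row per timestep.
    emb = get_sinusoidal_embeddings(timesteps, embedding_dim=32, flip_sin_to_cos=True)
    print(emb.shape)  # expected: (3, 32)

    # FlaxTimesteps wraps the same helper as a parameter-free module.
    t_proj = FlaxTimesteps(dim=32, flip_sin_to_cos=True, freq_shift=1).apply({}, timesteps)

    # FlaxTimestepEmbedding learns a two-layer MLP on top of the projection.
    embed = FlaxTimestepEmbedding(time_embed_dim=128)
    params = embed.init(jax.random.PRNGKey(0), t_proj)
    t_emb = embed.apply(params, t_proj)
    print(t_emb.shape)  # expected: (3, 128)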