diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. diffusers/__init__.py +34 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +170 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +35 -6
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +2 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
  53. diffusers/pipelines/cogview3/__init__.py +47 -0
  54. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  55. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  56. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  57. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  58. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  60. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  62. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  63. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  64. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  66. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  67. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  68. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  70. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  71. diffusers/pipelines/flux/__init__.py +10 -0
  72. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  73. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  74. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  76. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  77. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  78. diffusers/pipelines/free_noise_utils.py +365 -5
  79. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  80. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  81. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  82. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  83. diffusers/pipelines/kolors/tokenizer.py +4 -0
  84. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  86. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  87. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  89. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  90. diffusers/pipelines/pag/__init__.py +6 -0
  91. diffusers/pipelines/pag/pag_utils.py +8 -2
  92. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  96. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  97. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  98. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  100. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  101. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  102. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  103. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  106. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  107. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  108. diffusers/pipelines/pipeline_utils.py +123 -180
  109. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  111. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  117. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  120. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  121. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  122. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  126. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  127. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  130. diffusers/quantizers/__init__.py +16 -0
  131. diffusers/quantizers/auto.py +126 -0
  132. diffusers/quantizers/base.py +233 -0
  133. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  134. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  135. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  136. diffusers/quantizers/quantization_config.py +391 -0
  137. diffusers/schedulers/scheduling_ddim.py +4 -1
  138. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  140. diffusers/schedulers/scheduling_ddpm.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  142. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  143. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  145. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  146. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  147. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  149. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  150. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  151. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  152. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  153. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  154. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  155. diffusers/schedulers/scheduling_sasolver.py +78 -1
  156. diffusers/schedulers/scheduling_unclip.py +4 -1
  157. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  158. diffusers/training_utils.py +48 -18
  159. diffusers/utils/__init__.py +2 -1
  160. diffusers/utils/dummy_pt_objects.py +60 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
  162. diffusers/utils/hub_utils.py +16 -4
  163. diffusers/utils/import_utils.py +31 -8
  164. diffusers/utils/loading_utils.py +28 -4
  165. diffusers/utils/peft_utils.py +3 -3
  166. diffusers/utils/testing_utils.py +59 -0
  167. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  168. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
  169. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  170. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
  171. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  172. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
diffusers/models/embeddings.py

@@ -442,6 +442,60 @@ class CogVideoXPatchEmbed(nn.Module):
         return embeds
 
 
+class CogView3PlusPatchEmbed(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 16,
+        hidden_size: int = 2560,
+        patch_size: int = 2,
+        text_hidden_size: int = 4096,
+        pos_embed_max_size: int = 128,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_size = hidden_size
+        self.patch_size = patch_size
+        self.text_hidden_size = text_hidden_size
+        self.pos_embed_max_size = pos_embed_max_size
+        # Linear projection for image patches
+        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
+
+        # Linear projection for text embeddings
+        self.text_proj = nn.Linear(text_hidden_size, hidden_size)
+
+        pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
+        pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
+        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
+
+    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, channel, height, width = hidden_states.shape
+
+        if height % self.patch_size != 0 or width % self.patch_size != 0:
+            raise ValueError("Height and width must be divisible by patch size")
+
+        height = height // self.patch_size
+        width = width // self.patch_size
+        hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
+        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
+        hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
+
+        # Project the patches
+        hidden_states = self.proj(hidden_states)
+        encoder_hidden_states = self.text_proj(encoder_hidden_states)
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        # Calculate text_length
+        text_length = encoder_hidden_states.shape[1]
+
+        image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
+        text_pos_embed = torch.zeros(
+            (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
+        )
+        pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...]
+
+        return (hidden_states + pos_embed).to(hidden_states.dtype)
+
+
 def get_3d_rotary_pos_embed(
     embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
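A quick shape walk-through of the new patch embed, as a sketch against the hunk above. The sizes are illustrative, and the import path assumes the class is exported from `diffusers.models.embeddings`, where the hunk places it:

```python
import torch
from diffusers.models.embeddings import CogView3PlusPatchEmbed

embed = CogView3PlusPatchEmbed(in_channels=16, hidden_size=2560, patch_size=2)

latents = torch.randn(1, 16, 32, 32)  # (B, C, H, W) latent grid
text = torch.randn(1, 224, 4096)      # (B, text_seq_len, text_hidden_size)

out = embed(latents, text)
# 32x32 latents with patch_size=2 -> 16*16 = 256 image tokens,
# prepended by the 224 projected text tokens.
print(out.shape)  # torch.Size([1, 480, 2560])
```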
@@ -459,15 +513,16 @@ def get_3d_rotary_pos_embed(
         The size of the temporal dimension.
     theta (`float`):
         Scaling factor for frequency computation.
-    use_real (`bool`):
-        If True, return real part and imaginary part separately. Otherwise, return complex numbers.
 
     Returns:
         `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
     """
+    if use_real is not True:
+        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
     start, stop = crops_coords
-    grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
-    grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+    grid_size_h, grid_size_w = grid_size
+    grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
+    grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
     grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
 
     # Compute dimensions for each axis
@@ -476,54 +531,37 @@
     dim_w = embed_dim // 8 * 3
 
     # Temporal frequencies
-    freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
-    grid_t = torch.from_numpy(grid_t).float()
-    freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
-    freqs_t = freqs_t.repeat_interleave(2, dim=-1)
-
+    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
     # Spatial frequencies for height and width
-    freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
-    freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
-    grid_h = torch.from_numpy(grid_h).float()
-    grid_w = torch.from_numpy(grid_w).float()
-    freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
-    freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
-    freqs_h = freqs_h.repeat_interleave(2, dim=-1)
-    freqs_w = freqs_w.repeat_interleave(2, dim=-1)
-
-    # Broadcast and concatenate tensors along specified dimension
-    def broadcast(tensors, dim=-1):
-        num_tensors = len(tensors)
-        shape_lens = {len(t.shape) for t in tensors}
-        assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
-        shape_len = list(shape_lens)[0]
-        dim = (dim + shape_len) if dim < 0 else dim
-        dims = list(zip(*(list(t.shape) for t in tensors)))
-        expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
-        assert all(
-            [*(len(set(t[1])) <= 2 for t in expandable_dims)]
-        ), "invalid dimensions for broadcastable concatenation"
-        max_dims = [(t[0], max(t[1])) for t in expandable_dims]
-        expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
-        expanded_dims.insert(dim, (dim, dims[dim]))
-        expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
-        tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
-        return torch.cat(tensors, dim=dim)
-
-    freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
-
-    t, h, w, d = freqs.shape
-    freqs = freqs.view(t * h * w, d)
-
-    # Generate sine and cosine components
-    sin = freqs.sin()
-    cos = freqs.cos()
-
-    if use_real:
-        return cos, sin
-    else:
-        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
-        return freqs_cis
+    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
+    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
+
+    # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
+    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
+        freqs_t = freqs_t[:, None, None, :].expand(
+            -1, grid_size_h, grid_size_w, -1
+        )  # temporal_size, grid_size_h, grid_size_w, dim_t
+        freqs_h = freqs_h[None, :, None, :].expand(
+            temporal_size, -1, grid_size_w, -1
+        )  # temporal_size, grid_size_h, grid_size_2, dim_h
+        freqs_w = freqs_w[None, None, :, :].expand(
+            temporal_size, grid_size_h, -1, -1
+        )  # temporal_size, grid_size_h, grid_size_2, dim_w
+
+        freqs = torch.cat(
+            [freqs_t, freqs_h, freqs_w], dim=-1
+        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
+        freqs = freqs.view(
+            temporal_size * grid_size_h * grid_size_w, -1
+        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
+        return freqs
+
+    t_cos, t_sin = freqs_t  # both t_cos and t_sin has shape: temporal_size, dim_t
+    h_cos, h_sin = freqs_h  # both h_cos and h_sin has shape: grid_size_h, dim_h
+    w_cos, w_sin = freqs_w  # both w_cos and w_sin has shape: grid_size_w, dim_w
+    cos = combine_time_height_width(t_cos, h_cos, w_cos)
+    sin = combine_time_height_width(t_sin, h_sin, w_sin)
+    return cos, sin
 
 
 def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
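The rewrite delegates each axis to `get_1d_rotary_pos_embed` and expands/concatenates the per-axis cos/sin tables in place of the old `broadcast` helper. A minimal shape check, with CogVideoX-like sizes chosen purely for illustration (the axis split assumes the `embed_dim // 4` / `embed_dim // 8 * 3` partition computed just above this hunk):

```python
import torch
from diffusers.models.embeddings import get_3d_rotary_pos_embed

cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,
    crops_coords=((0, 0), (30, 45)),
    grid_size=(30, 45),
    temporal_size=13,
)
# dim_t = 64 // 4 = 16 and dim_h = dim_w = 64 // 8 * 3 = 24, so 16 + 24 + 24 = 64.
print(cos.shape)  # torch.Size([17550, 64]) == (13 * 30 * 45, 64)
```

Note that `use_real=False` now raises a `ValueError` instead of returning complex frequencies.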
@@ -598,6 +636,7 @@ def get_1d_rotary_pos_embed(
     linear_factor=1.0,
     ntk_factor=1.0,
     repeat_interleave_real=True,
+    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
 ):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -620,26 +659,37 @@ def get_1d_rotary_pos_embed(
         repeat_interleave_real (`bool`, *optional*, defaults to `True`):
             If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
             Otherwise, they are concateanted with themselves.
+        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+            the dtype of the frequency tensor.
     Returns:
         `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
     """
    assert dim % 2 == 0
 
     if isinstance(pos, int):
-        pos = np.arange(pos)
+        pos = torch.arange(pos)
+    if isinstance(pos, np.ndarray):
+        pos = torch.from_numpy(pos)  # type: ignore  # [S]
+
     theta = theta * ntk_factor
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor  # [D/2]
-    t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
-    freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
+    freqs = (
+        1.0
+        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
+        / linear_factor
+    )  # [D/2]
+    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
     if use_real and repeat_interleave_real:
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
+        # flux, hunyuan-dit, cogvideox
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
         return freqs_cos, freqs_sin
     elif use_real:
-        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)  # [S, D]
-        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)  # [S, D]
+        # stable audio
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
         return freqs_cos, freqs_sin
     else:
+        # lumina
         freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
         return freqs_cis
 
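The practical effect of the new `freqs_dtype` argument: the position/frequency outer product can be accumulated in float64 (as the Flux embed below requests), while the returned cos/sin tables are still cast back to float32. A small sketch:

```python
import torch
from diffusers.models.embeddings import get_1d_rotary_pos_embed

pos = torch.arange(8)  # positions may now be a torch.Tensor, np.ndarray, or int
cos, sin = get_1d_rotary_pos_embed(16, pos, use_real=True, freqs_dtype=torch.float64)
# Frequencies are accumulated in float64, but the outputs come back as float32.
print(cos.dtype, cos.shape)  # torch.float32 torch.Size([8, 16])
```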
@@ -671,11 +721,11 @@ def apply_rotary_emb(
         cos, sin = cos.to(x.device), sin.to(x.device)
 
         if use_real_unbind_dim == -1:
-            # Use for example in Lumina
+            # Used for flux, cogvideox, hunyuan-dit
             x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
             x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
         elif use_real_unbind_dim == -2:
-            # Use for example in Stable Audio
+            # Used for Stable Audio
             x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
             x_rotated = torch.cat([-x_imag, x_real], dim=-1)
         else:
@@ -685,6 +735,7 @@ def apply_rotary_emb(
 
         return out
     else:
+        # used for lumina
         x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
         freqs_cis = freqs_cis.unsqueeze(2)
         x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
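The relabeled comments now match where each branch is actually used. For reference, a sketch of the default interleaved path (`use_real_unbind_dim=-1`, the Flux/CogVideoX/HunyuanDiT layout) rotating `[B, H, S, D]` query/key states with a table from `get_1d_rotary_pos_embed`:

```python
import torch
from diffusers.models.embeddings import apply_rotary_emb, get_1d_rotary_pos_embed

cos, sin = get_1d_rotary_pos_embed(64, torch.arange(10), use_real=True)

x = torch.randn(2, 8, 10, 64)  # (batch, heads, seq, head_dim)
out = apply_rotary_emb(x, (cos, sin))  # accepts the (cos, sin) tuple directly
print(out.shape)  # torch.Size([2, 8, 10, 64])
```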
@@ -692,6 +743,31 @@ def apply_rotary_emb(
         return x_out.type_as(x)
 
 
+class FluxPosEmbed(nn.Module):
+    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        cos_out = []
+        sin_out = []
+        pos = ids.float()
+        is_mps = ids.device.type == "mps"
+        freqs_dtype = torch.float32 if is_mps else torch.float64
+        for i in range(n_axes):
+            cos, sin = get_1d_rotary_pos_embed(
+                self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
+            )
+            cos_out.append(cos)
+            sin_out.append(sin)
+        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+        return freqs_cos, freqs_sin
+
+
 class TimestepEmbedding(nn.Module):
     def __init__(
         self,
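`FluxPosEmbed` turns packed latent `ids` (one coordinate per axis) into concatenated per-axis RoPE tables, falling back to float32 frequencies on MPS, which lacks float64. A sketch; the configuration (`theta=10000`, `axes_dim=[16, 56, 56]`, summing to the 128-dim attention head) is the one the Flux transformer ships with, assumed here for illustration:

```python
import torch
from diffusers.models.embeddings import FluxPosEmbed

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])

# One row per token; columns are (text/batch index, latent row, latent column).
ids = torch.zeros(64, 3)
ids[:, 1] = torch.arange(64) // 8
ids[:, 2] = torch.arange(64) % 8

cos, sin = pos_embed(ids)
print(cos.shape, sin.shape)  # torch.Size([64, 128]) torch.Size([64, 128])
```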
@@ -1058,6 +1134,39 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
         return conditioning
 
 
+class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
+    def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
+        self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(
+        self,
+        timestep: torch.Tensor,
+        original_size: torch.Tensor,
+        target_size: torch.Tensor,
+        crop_coords: torch.Tensor,
+        hidden_dtype: torch.dtype,
+    ) -> torch.Tensor:
+        timesteps_proj = self.time_proj(timestep)
+
+        original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1)
+        crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
+        target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
+
+        # (B, 3 * condition_dim)
+        condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1)
+
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)
+        condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)
+
+        conditioning = timesteps_emb + condition_emb
+        return conditioning
+
+
 class HunyuanDiTAttentionPool(nn.Module):
     # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
 
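Each of the three SDXL-style conditions (original size, crop coordinates, target size) carries two components, so the concatenated projection is `3 * 2 * condition_dim` wide, and `pooled_projection_dim` must match that width. A shape sketch with illustrative CogView3-like dimensions:

```python
import torch
from diffusers.models.embeddings import CogView3CombinedTimestepSizeEmbeddings

emb = CogView3CombinedTimestepSizeEmbeddings(
    embedding_dim=512,
    condition_dim=256,
    pooled_projection_dim=1536,  # 3 conditions x 2 components x condition_dim
)

t = torch.tensor([999])
original_size = torch.tensor([[1024.0, 1024.0]])
target_size = torch.tensor([[1024.0, 1024.0]])
crop_coords = torch.tensor([[0.0, 0.0]])

cond = emb(t, original_size, target_size, crop_coords, hidden_dtype=torch.float32)
print(cond.shape)  # torch.Size([1, 512])
```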
diffusers/models/embeddings_flax.py

@@ -29,11 +29,21 @@ def get_sinusoidal_embeddings(
     """Returns the positional encoding (same as Tensor2Tensor).
 
     Args:
-        timesteps: a 1-D Tensor of N indices, one per batch element.
-            These may be fractional.
-        embedding_dim: The number of output channels.
-        min_timescale: The smallest time unit (should probably be 0.0).
-        max_timescale: The largest time unit.
+        timesteps (`jnp.ndarray` of shape `(N,)`):
+            A 1-D array of N indices, one per batch element. These may be fractional.
+        embedding_dim (`int`):
+            The number of output channels.
+        freq_shift (`float`, *optional*, defaults to `1`):
+            Shift applied to the frequency scaling of the embeddings.
+        min_timescale (`float`, *optional*, defaults to `1`):
+            The smallest time unit used in the sinusoidal calculation (should probably be 0.0).
+        max_timescale (`float`, *optional*, defaults to `1.0e4`):
+            The largest time unit used in the sinusoidal calculation.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+            Whether to flip the order of sinusoidal components to cosine first.
+        scale (`float`, *optional*, defaults to `1.0`):
+            A scaling factor applied to the positional embeddings.
+
     Returns:
         a Tensor of timing signals [N, num_channels]
     """
@@ -61,9 +71,9 @@ class FlaxTimestepEmbedding(nn.Module):
 
     Args:
         time_embed_dim (`int`, *optional*, defaults to `32`):
-            Time step embedding dimension
-        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
-            Parameters `dtype`
+            Time step embedding dimension.
+        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
+            The data type for the embedding parameters.
     """
 
     time_embed_dim: int = 32
@@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module):
 
     Args:
         dim (`int`, *optional*, defaults to `32`):
-            Time step embedding dimension
+            Time step embedding dimension.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+            Whether to flip the sinusoidal function from sine to cosine.
+        freq_shift (`float`, *optional*, defaults to `1`):
+            Frequency shift applied to the sinusoidal embeddings.
     """
 
     dim: int = 32
diffusers/models/model_loading_utils.py

@@ -25,12 +25,14 @@ import safetensors
 import torch
 from huggingface_hub.utils import EntryNotFoundError
 
+from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
     WEIGHTS_INDEX_NAME,
     _add_variant,
     _get_model_file,
+    deprecate,
     is_accelerate_available,
     is_torch_version,
     logging,
@@ -53,11 +55,36 @@ if is_accelerate_available():
 
 
 # Adapted from `transformers` (see modeling_utils.py)
-def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
+def _determine_device_map(
+    model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
+):
     if isinstance(device_map, str):
+        special_dtypes = {}
+        if hf_quantizer is not None:
+            special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
+        special_dtypes.update(
+            {
+                name: torch.float32
+                for name, _ in model.named_parameters()
+                if any(m in name for m in keep_in_fp32_modules)
+            }
+        )
+
+        target_dtype = torch_dtype
+        if hf_quantizer is not None:
+            target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)
+
         no_split_modules = model._get_no_split_modules(device_map)
         device_map_kwargs = {"no_split_module_classes": no_split_modules}
 
+        if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
+            device_map_kwargs["special_dtypes"] = special_dtypes
+        elif len(special_dtypes) > 0:
+            logger.warning(
+                "This model has some weights that should be kept in higher precision, you need to upgrade "
+                "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
+            )
+
         if device_map != "sequential":
             max_memory = get_balanced_memory(
                 model,
@@ -69,8 +96,14 @@ def _determine_device_map(
         else:
             max_memory = get_max_memory(max_memory)
 
+        if hf_quantizer is not None:
+            max_memory = hf_quantizer.adjust_max_memory(max_memory)
+
         device_map_kwargs["max_memory"] = max_memory
-        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
+        device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
+
+        if hf_quantizer is not None:
+            hf_quantizer.validate_environment(device_map=device_map)
 
     return device_map
 
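The dtype-pinning handshake with `accelerate` is easier to see standalone. A sketch with a hypothetical toy model and module list; only the `special_dtypes` probing mirrors the hunk above:

```python
import inspect

import torch
import torch.nn as nn
from accelerate import infer_auto_device_map

model = nn.Sequential(nn.Linear(64, 64), nn.LayerNorm(64), nn.Linear(64, 64))
keep_in_fp32_modules = ["1"]  # hypothetical: pin the LayerNorm submodule

# Parameters matched by keep_in_fp32_modules are mapped to float32 so the
# planner budgets memory for them at full precision.
special_dtypes = {
    name: torch.float32
    for name, _ in model.named_parameters()
    if any(m in name for m in keep_in_fp32_modules)
}

kwargs = {}
if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
    kwargs["special_dtypes"] = special_dtypes  # needs a recent accelerate

device_map = infer_auto_device_map(model, dtype=torch.float16, **kwargs)
print(device_map)
```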
@@ -99,6 +132,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
     """
+    # TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change
+    # when refactoring the _merge_sharded_checkpoints() method later.
+    if isinstance(checkpoint_file, dict):
+        return checkpoint_file
     try:
         file_extension = os.path.basename(checkpoint_file).split(".")[-1]
         if file_extension == SAFETENSORS_FILE_EXTENSION:
@@ -136,29 +173,67 @@ def load_model_dict_into_meta(
     device: Optional[Union[str, torch.device]] = None,
     dtype: Optional[Union[str, torch.dtype]] = None,
     model_name_or_path: Optional[str] = None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
 ) -> List[str]:
-    device = device or torch.device("cpu")
+    if hf_quantizer is None:
+        device = device or torch.device("cpu")
     dtype = dtype or torch.float32
+    is_quantized = hf_quantizer is not None
+    is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
 
     accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
-
-    unexpected_keys = []
     empty_state_dict = model.state_dict()
+    unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]
+
     for param_name, param in state_dict.items():
         if param_name not in empty_state_dict:
-            unexpected_keys.append(param_name)
             continue
 
+        set_module_kwargs = {}
+        # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
+        # in int/uint/bool and not cast them.
+        # TODO: revisit cases when param.dtype == torch.float8_e4m3fn
+        if torch.is_floating_point(param):
+            if (
+                keep_in_fp32_modules is not None
+                and any(
+                    module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
+                )
+                and dtype == torch.float16
+            ):
+                param = param.to(torch.float32)
+                if accepts_dtype:
+                    set_module_kwargs["dtype"] = torch.float32
+            else:
+                param = param.to(dtype)
+                if accepts_dtype:
+                    set_module_kwargs["dtype"] = dtype
+
+        # bnb params are flattened.
         if empty_state_dict[param_name].shape != param.shape:
-            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
-            raise ValueError(
-                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
-            )
-
-        if accepts_dtype:
-            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+            if (
+                is_quant_method_bnb
+                and hf_quantizer.pre_quantized
+                and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+            ):
+                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
+            elif not is_quant_method_bnb:
+                model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
+                raise ValueError(
+                    f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
                )
+
+        if is_quantized and (
+            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+        ):
+            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
         else:
-            set_module_tensor_to_device(model, param_name, device, value=param)
+            if accepts_dtype:
+                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
+            else:
+                set_module_tensor_to_device(model, param_name, device, value=param)
+
     return unexpected_keys
 
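The new casting rule in `load_model_dict_into_meta` reads more clearly in isolation. A hypothetical standalone rendition, not the library function itself:

```python
import torch

def cast_param(param, param_name, dtype, keep_in_fp32_modules):
    """Floating-point params are cast to `dtype`, except those whose dotted
    name contains a keep_in_fp32_modules entry when loading in float16;
    int/uint/bool buffers are left untouched."""
    if not torch.is_floating_point(param):
        return param
    if (
        keep_in_fp32_modules is not None
        and dtype == torch.float16
        and any(m in param_name.split(".") for m in keep_in_fp32_modules)
    ):
        return param.to(torch.float32)
    return param.to(dtype)

p = torch.randn(4)
print(cast_param(p, "encoder.norm.weight", torch.float16, ["norm"]).dtype)  # torch.float32
print(cast_param(p, "encoder.proj.weight", torch.float16, ["norm"]).dtype)  # torch.float16
```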
@@ -228,3 +303,96 @@ def _fetch_index_file(
         index_file = None
 
     return index_file
+
+
+# Adapted from
+# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
+def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
+    weight_map = sharded_metadata.get("weight_map", None)
+    if weight_map is None:
+        raise KeyError("'weight_map' key not found in the shard index file.")
+
+    # Collect all unique safetensors files from weight_map
+    files_to_load = set(weight_map.values())
+    is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
+    merged_state_dict = {}
+
+    # Load tensors from each unique file
+    for file_name in files_to_load:
+        part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
+        if not os.path.exists(part_file_path):
+            raise FileNotFoundError(f"Part file {file_name} not found.")
+
+        if is_safetensors:
+            with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
+                for tensor_key in f.keys():
+                    if tensor_key in weight_map:
+                        merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
+        else:
+            merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
+
+    return merged_state_dict
+
+
+def _fetch_index_file_legacy(
+    is_local,
+    pretrained_model_name_or_path,
+    subfolder,
+    use_safetensors,
+    cache_dir,
+    variant,
+    force_download,
+    proxies,
+    local_files_only,
+    token,
+    revision,
+    user_agent,
+    commit_hash,
+):
+    if is_local:
+        index_file = Path(
+            pretrained_model_name_or_path,
+            subfolder or "",
+            SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
+        ).as_posix()
+        splits = index_file.split(".")
+        split_index = -3 if ".cache" in index_file else -2
+        splits = splits[:-split_index] + [variant] + splits[-split_index:]
+        index_file = ".".join(splits)
+        if os.path.exists(index_file):
+            deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
+            deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
+            index_file = Path(index_file)
+        else:
+            index_file = None
+    else:
+        if variant is not None:
+            index_file_in_repo = Path(
+                subfolder or "",
+                SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
+            ).as_posix()
+            splits = index_file_in_repo.split(".")
+            split_index = -2
+            splits = splits[:-split_index] + [variant] + splits[-split_index:]
+            index_file_in_repo = ".".join(splits)
+            try:
+                index_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=index_file_in_repo,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=None,
+                    user_agent=user_agent,
+                    commit_hash=commit_hash,
+                )
+                index_file = Path(index_file)
+                deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
+                deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
+            except (EntryNotFoundError, EnvironmentError):
+                index_file = None
+
+    return index_file
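How `_merge_sharded_checkpoints` is meant to be fed, as a sketch: the folder layout is what `save_pretrained()` writes for sharded checkpoints, the index file name is the safetensors default, and the function is private, so treat the import as an implementation detail:

```python
import json
import os

from diffusers.models.model_loading_utils import _merge_sharded_checkpoints

folder = "path/to/sharded/transformer"  # hypothetical local checkpoint folder
with open(os.path.join(folder, "diffusion_pytorch_model.safetensors.index.json")) as f:
    sharded_metadata = json.load(f)

# Loads every shard referenced by the index's weight_map into one state dict.
state_dict = _merge_sharded_checkpoints(folder, sharded_metadata)
print(f"merged {len(state_dict)} tensors")
```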