diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2252 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +3 -14
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +293 -8
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1937 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +403 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +50 -6
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +37 -15
  210. diffusers/utils/loading_utils.py +80 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
@@ -35,10 +35,21 @@ def get_timestep_embedding(
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
 
-    :param timesteps: a 1-D Tensor of N indices, one per batch element.
-    These may be fractional.
-    :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
-    embeddings. :return: an [N x dim] Tensor of positional embeddings.
+    Args
+        timesteps (torch.Tensor):
+            a 1-D Tensor of N indices, one per batch element. These may be fractional.
+        embedding_dim (int):
+            the dimension of the output.
+        flip_sin_to_cos (bool):
+            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
+        downscale_freq_shift (float):
+            Controls the delta between frequencies between dimensions
+        scale (float):
+            Scaling factor applied to the embeddings.
+        max_period (int):
+            Controls the maximum frequency of the embeddings
+    Returns
+        torch.Tensor: an [N x dim] Tensor of positional embeddings.
     """
     assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
 
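The reworked docstring swaps the old RST-style `:param:` list for the structured `Args`/`Returns` block used elsewhere in diffusers, and now documents `flip_sin_to_cos`, `downscale_freq_shift`, and `scale`. As a quick orientation, a minimal usage sketch (values illustrative, not from the diff):

    import torch
    from diffusers.models.embeddings import get_timestep_embedding

    timesteps = torch.tensor([0.0, 250.5, 999.0])  # 1-D, one (possibly fractional) index per batch element
    emb = get_timestep_embedding(timesteps, embedding_dim=320, flip_sin_to_cos=True, downscale_freq_shift=0.0)
    print(emb.shape)  # torch.Size([3, 320]) -> [N x dim]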
@@ -67,6 +78,53 @@ def get_timestep_embedding(
     return emb
 
 
+def get_3d_sincos_pos_embed(
+    embed_dim: int,
+    spatial_size: Union[int, Tuple[int, int]],
+    temporal_size: int,
+    spatial_interpolation_scale: float = 1.0,
+    temporal_interpolation_scale: float = 1.0,
+) -> np.ndarray:
+    r"""
+    Args:
+        embed_dim (`int`):
+        spatial_size (`int` or `Tuple[int, int]`):
+        temporal_size (`int`):
+        spatial_interpolation_scale (`float`, defaults to 1.0):
+        temporal_interpolation_scale (`float`, defaults to 1.0):
+    """
+    if embed_dim % 4 != 0:
+        raise ValueError("`embed_dim` must be divisible by 4")
+    if isinstance(spatial_size, int):
+        spatial_size = (spatial_size, spatial_size)
+
+    embed_dim_spatial = 3 * embed_dim // 4
+    embed_dim_temporal = embed_dim // 4
+
+    # 1. Spatial
+    grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
+    grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
+
+    # 2. Temporal
+    grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
+    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
+
+    # 3. Concat
+    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
+    pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)  # [T, H*W, D // 4 * 3]
+
+    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
+    pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1)  # [T, H*W, D // 4]
+
+    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)  # [T, H*W, D]
+    return pos_embed
+
+
 def get_2d_sincos_pos_embed(
     embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
 ):
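`get_3d_sincos_pos_embed` (added for CogVideoX-style video transformers) splits the embedding width 3:1 between space and time (`3 * embed_dim // 4` spatial channels, `embed_dim // 4` temporal) and tiles both over a `[T, H*W, D]` grid. A minimal shape check, with illustrative sizes:

    from diffusers.models.embeddings import get_3d_sincos_pos_embed

    pos = get_3d_sincos_pos_embed(embed_dim=64, spatial_size=(8, 8), temporal_size=4)
    print(pos.shape)  # (4, 64, 64) -> [T, H*W, D]; first 16 channels temporal, last 48 spatial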
@@ -230,6 +288,176 @@ class PatchEmbed(nn.Module):
         return (latent + pos_embed).to(latent.dtype)
 
 
+class LuminaPatchEmbed(nn.Module):
+    """2D Image to Patch Embedding with support for Lumina-T2X"""
+
+    def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
+        super().__init__()
+        self.patch_size = patch_size
+        self.proj = nn.Linear(
+            in_features=patch_size * patch_size * in_channels,
+            out_features=embed_dim,
+            bias=bias,
+        )
+
+    def forward(self, x, freqs_cis):
+        """
+        Patchifies and embeds the input tensor(s).
+
+        Args:
+            x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified
+            and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the
+            frequency tensor(s).
+        """
+        freqs_cis = freqs_cis.to(x[0].device)
+        patch_height = patch_width = self.patch_size
+        batch_size, channel, height, width = x.size()
+        height_tokens, width_tokens = height // patch_height, width // patch_width
+
+        x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute(
+            0, 2, 4, 1, 3, 5
+        )
+        x = x.flatten(3)
+        x = self.proj(x)
+        x = x.flatten(1, 2)
+
+        mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
+
+        return (
+            x,
+            mask,
+            [(height, width)] * batch_size,
+            freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
+        )
+
+
+class CogVideoXPatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        embed_dim: int = 1920,
+        text_embed_dim: int = 4096,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+
+        self.proj = nn.Conv2d(
+            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+        )
+        self.text_proj = nn.Linear(text_embed_dim, embed_dim)
+
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+        r"""
+        Args:
+            text_embeds (`torch.Tensor`):
+                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
+            image_embeds (`torch.Tensor`):
+                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
+        """
+        text_embeds = self.text_proj(text_embeds)
+
+        batch, num_frames, channels, height, width = image_embeds.shape
+        image_embeds = image_embeds.reshape(-1, channels, height, width)
+        image_embeds = self.proj(image_embeds)
+        image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
+        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
+        image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+
+        embeds = torch.cat(
+            [text_embeds, image_embeds], dim=1
+        ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
+        return embeds
+
+
+def get_3d_rotary_pos_embed(
+    embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    """
+    RoPE for video tokens with 3D structure.
+
+    Args:
+        embed_dim: (`int`):
+            The embedding dimension size, corresponding to hidden_size_head.
+        crops_coords (`Tuple[int]`):
+            The top-left and bottom-right coordinates of the crop.
+        grid_size (`Tuple[int]`):
+            The grid size of the spatial positional embedding (height, width).
+        temporal_size (`int`):
+            The size of the temporal dimension.
+        theta (`float`):
+            Scaling factor for frequency computation.
+        use_real (`bool`):
+            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
+    """
+    start, stop = crops_coords
+    grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
+    grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+
+    # Compute dimensions for each axis
+    dim_t = embed_dim // 4
+    dim_h = embed_dim // 8 * 3
+    dim_w = embed_dim // 8 * 3
+
+    # Temporal frequencies
+    freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
+    grid_t = torch.from_numpy(grid_t).float()
+    freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
+    freqs_t = freqs_t.repeat_interleave(2, dim=-1)
+
+    # Spatial frequencies for height and width
+    freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
+    freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
+    grid_h = torch.from_numpy(grid_h).float()
+    grid_w = torch.from_numpy(grid_w).float()
+    freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
+    freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
+    freqs_h = freqs_h.repeat_interleave(2, dim=-1)
+    freqs_w = freqs_w.repeat_interleave(2, dim=-1)
+
+    # Broadcast and concatenate tensors along specified dimension
+    def broadcast(tensors, dim=-1):
+        num_tensors = len(tensors)
+        shape_lens = {len(t.shape) for t in tensors}
+        assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
+        shape_len = list(shape_lens)[0]
+        dim = (dim + shape_len) if dim < 0 else dim
+        dims = list(zip(*(list(t.shape) for t in tensors)))
+        expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+        assert all(
+            [*(len(set(t[1])) <= 2 for t in expandable_dims)]
+        ), "invalid dimensions for broadcastable concatenation"
+        max_dims = [(t[0], max(t[1])) for t in expandable_dims]
+        expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
+        expanded_dims.insert(dim, (dim, dims[dim]))
+        expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
+        tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
+        return torch.cat(tensors, dim=dim)
+
+    freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
+
+    t, h, w, d = freqs.shape
+    freqs = freqs.view(t * h * w, d)
+
+    # Generate sine and cosine components
+    sin = freqs.sin()
+    cos = freqs.cos()
+
+    if use_real:
+        return cos, sin
+    else:
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs_cis
+
+
 def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
     """
     RoPE for image tokens with 2d structure.
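Note that with `use_real=True` (the default) the new `get_3d_rotary_pos_embed` above returns a `(cos, sin)` pair, each of shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim)`; the single-tensor `Returns:` line describes the complex path. A hedged sketch, with grid sizes loosely modeled on a CogVideoX-style latent (illustrative values):

    from diffusers.models.embeddings import get_3d_rotary_pos_embed

    cos, sin = get_3d_rotary_pos_embed(
        embed_dim=64,                     # per-attention-head dimension
        crops_coords=((0, 0), (30, 45)),  # top-left / bottom-right of the crop
        grid_size=(30, 45),               # spatial latent grid (height, width)
        temporal_size=13,
    )
    print(cos.shape)  # torch.Size([17550, 64]) -> (13 * 30 * 45, embed_dim)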
@@ -245,7 +473,7 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
             If True, return real part and imaginary part separately. Otherwise, return complex numbers.
 
     Returns:
-        `torch.Tensor`: positional embdding with shape `( grid_size * grid_size, embed_dim/2)`.
+        `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
     """
     start, stop = crops_coords
     grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
@@ -262,19 +490,47 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
     assert embed_dim % 4 == 0
 
     # use half of dimensions to encode grid_h
-    emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real)  # (H*W, D/4)
-    emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real)  # (H*W, D/4)
+    emb_h = get_1d_rotary_pos_embed(
+        embed_dim // 2, grid[0].reshape(-1), use_real=use_real
+    )  # (H*W, D/2) if use_real else (H*W, D/4)
+    emb_w = get_1d_rotary_pos_embed(
+        embed_dim // 2, grid[1].reshape(-1), use_real=use_real
+    )  # (H*W, D/2) if use_real else (H*W, D/4)
 
     if use_real:
-        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D/2)
-        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D/2)
+        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D)
+        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D)
         return cos, sin
     else:
         emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
         return emb
 
 
-def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
+def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
+    assert embed_dim % 4 == 0
+
+    emb_h = get_1d_rotary_pos_embed(
+        embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
+    )  # (H, D/4)
+    emb_w = get_1d_rotary_pos_embed(
+        embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
+    )  # (W, D/4)
+    emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1)  # (H, W, D/4, 1)
+    emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1)  # (H, W, D/4, 1)
+
+    emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2)  # (H, W, D/2)
+    return emb
+
+
+def get_1d_rotary_pos_embed(
+    dim: int,
+    pos: Union[np.ndarray, int],
+    theta: float = 10000.0,
+    use_real=False,
+    linear_factor=1.0,
+    ntk_factor=1.0,
+    repeat_interleave_real=True,
+):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
 
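`get_2d_rotary_pos_embed_lumina` stays in the complex `freqs_cis` form and tiles the per-axis frequencies over the full grid before flattening, yielding an `(H, W, D/2)` complex tensor rather than a `(cos, sin)` pair. A minimal shape check (illustrative sizes):

    from diffusers.models.embeddings import get_2d_rotary_pos_embed_lumina

    emb = get_2d_rotary_pos_embed_lumina(embed_dim=64, len_h=32, len_w=32)
    print(emb.shape, emb.dtype)  # torch.Size([32, 32, 32]) torch.complex64 -> (H, W, D/2)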
@@ -289,19 +545,32 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float
             Scaling factor for frequency computation. Defaults to 10000.0.
         use_real (`bool`, *optional*):
             If True, return real part and imaginary part separately. Otherwise, return complex numbers.
-
+        linear_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for the context extrapolation. Defaults to 1.0.
+        ntk_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
+        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
+            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
+            Otherwise, they are concatenated with themselves.
     Returns:
         `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
     """
+    assert dim % 2 == 0
+
     if isinstance(pos, int):
         pos = np.arange(pos)
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
+    theta = theta * ntk_factor
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor  # [D/2]
     t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
     freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
-    if use_real:
+    if use_real and repeat_interleave_real:
         freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
         freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
         return freqs_cos, freqs_sin
+    elif use_real:
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)  # [S, D]
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)  # [S, D]
+        return freqs_cos, freqs_sin
     else:
         freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
         return freqs_cis
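The extended `get_1d_rotary_pos_embed` now has three output modes: complex `freqs_cis`, interleaved real parts (the pre-existing `use_real=True` behavior), and block-concatenated real parts (used by Stable Audio). `ntk_factor` rescales `theta` for NTK-aware RoPE and `linear_factor` divides the frequencies for context extrapolation. A sketch of the three modes (sizes illustrative):

    from diffusers.models.embeddings import get_1d_rotary_pos_embed

    freqs_cis = get_1d_rotary_pos_embed(dim=64, pos=16)                # complex64, [S, D/2]
    cos, sin = get_1d_rotary_pos_embed(dim=64, pos=16, use_real=True)  # each [S, D], interleaved halves
    cos2, sin2 = get_1d_rotary_pos_embed(
        dim=64, pos=16, use_real=True, repeat_interleave_real=False
    )                                                                  # each [S, D], concatenated halves
    print(freqs_cis.shape, cos.shape, cos2.shape)  # [16, 32] [16, 64] [16, 64]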
@@ -310,6 +579,8 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float
 def apply_rotary_emb(
     x: torch.Tensor,
     freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+    use_real: bool = True,
+    use_real_unbind_dim: int = -1,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -325,16 +596,32 @@ def apply_rotary_emb(
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
     """
-    cos, sin = freqs_cis  # [S, D]
-    cos = cos[None, None]
-    sin = sin[None, None]
-    cos, sin = cos.to(x.device), sin.to(x.device)
+    if use_real:
+        cos, sin = freqs_cis  # [S, D]
+        cos = cos[None, None]
+        sin = sin[None, None]
+        cos, sin = cos.to(x.device), sin.to(x.device)
+
+        if use_real_unbind_dim == -1:
+            # Use for example in Lumina
+            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+        elif use_real_unbind_dim == -2:
+            # Use for example in Stable Audio
+            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
+            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+        else:
+            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
+
+        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
 
-    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
-    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
-    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+        return out
+    else:
+        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+        freqs_cis = freqs_cis.unsqueeze(2)
+        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
 
-    return out
+        return x_out.type_as(x)
 
 
 class TimestepEmbedding(nn.Module):
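`apply_rotary_emb` now selects between the interleaved real pairing used by Lumina (`use_real_unbind_dim=-1`), the half-split pairing used by Stable Audio (`-2`), and a complex `freqs_cis` path (`use_real=False`). A hedged usage sketch on a `[batch, heads, seq, head_dim]` query (shapes illustrative):

    import torch
    from diffusers.models.embeddings import apply_rotary_emb, get_1d_rotary_pos_embed

    query = torch.randn(2, 8, 16, 64)  # [batch, heads, seq, head_dim]
    freqs = get_1d_rotary_pos_embed(dim=64, pos=16, use_real=True)  # (cos, sin), each [16, 64]
    query = apply_rotary_emb(query, freqs)  # defaults: use_real=True, use_real_unbind_dim=-1
    print(query.shape)  # torch.Size([2, 8, 16, 64])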
@@ -386,11 +673,12 @@ class TimestepEmbedding(nn.Module):
 
 
 class Timesteps(nn.Module):
-    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
         super().__init__()
         self.num_channels = num_channels
         self.flip_sin_to_cos = flip_sin_to_cos
         self.downscale_freq_shift = downscale_freq_shift
+        self.scale = scale
 
     def forward(self, timesteps):
         t_emb = get_timestep_embedding(
@@ -398,6 +686,7 @@ class Timesteps(nn.Module):
             self.num_channels,
             flip_sin_to_cos=self.flip_sin_to_cos,
             downscale_freq_shift=self.downscale_freq_shift,
+            scale=self.scale,
         )
         return t_emb
 
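`Timesteps` now forwards an optional `scale`, which `get_timestep_embedding` multiplies into the sinusoid phases before taking sin/cos; this helps models whose timesteps live in a small range such as `[0, 1]`. A hedged sketch with an illustrative `scale`:

    import torch
    from diffusers.models.embeddings import Timesteps

    time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
    t_emb = time_proj(torch.tensor([0.25, 0.5]))  # fractional timesteps, amplified by scale before sin/cos
    print(t_emb.shape)  # torch.Size([2, 256])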
@@ -415,9 +704,10 @@ class GaussianFourierProjection(nn.Module):
 
         if set_W_to_weight:
             # to delete later
+            del self.weight
             self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
-
             self.weight = self.W
+            del self.W
 
     def forward(self, x):
         if self.log:
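Previously, `set_W_to_weight=True` registered the same Gaussian projection under two names (`weight` and `W`), so checkpoints carried a duplicate entry; the added `del` statements leave only `weight`, holding the freshly drawn values. A quick check one could run against 0.30.1 (hedged, not part of the diff):

    from diffusers.models.embeddings import GaussianFourierProjection

    proj = GaussianFourierProjection(embedding_size=256, set_W_to_weight=True)
    print(list(proj.state_dict().keys()))  # ['weight']; earlier releases also carried an aliased 'W'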
@@ -676,6 +966,30 @@ class CombinedTimestepTextProjEmbeddings(nn.Module):
         return conditioning
 
 
+class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
+    def __init__(self, embedding_dim, pooled_projection_dim):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(self, timestep, guidance, pooled_projection):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+        guidance_proj = self.time_proj(guidance)
+        guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+        time_guidance_emb = timesteps_emb + guidance_emb
+
+        pooled_projections = self.text_embedder(pooled_projection)
+        conditioning = time_guidance_emb + pooled_projections
+
+        return conditioning
+
+
 class HunyuanDiTAttentionPool(nn.Module):
     # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
 
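`CombinedTimestepGuidanceTextProjEmbeddings` mirrors `CombinedTimestepTextProjEmbeddings` but additionally runs a guidance scale through the same 256-channel sinusoidal projection, as used by guidance-distilled models such as Flux. A hedged shape sketch (dimensions illustrative; the real Flux transformer is much wider):

    import torch
    from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings

    embedder = CombinedTimestepGuidanceTextProjEmbeddings(embedding_dim=128, pooled_projection_dim=96)
    conditioning = embedder(
        timestep=torch.tensor([981.0]),
        guidance=torch.tensor([3.5]),          # distilled guidance scale, embedded like a timestep
        pooled_projection=torch.randn(1, 96),
    )
    print(conditioning.shape)  # torch.Size([1, 128])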
@@ -717,18 +1031,33 @@ class HunyuanDiTAttentionPool(nn.Module):
 
 
 class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
-    def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
+    def __init__(
+        self,
+        embedding_dim,
+        pooled_projection_dim=1024,
+        seq_len=256,
+        cross_attention_dim=2048,
+        use_style_cond_and_image_meta_size=True,
+    ):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
 
+        self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+
         self.pooler = HunyuanDiTAttentionPool(
             seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
         )
+
         # Here we use a default learned embedder layer for future extension.
-        self.style_embedder = nn.Embedding(1, embedding_dim)
-        extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
+        if use_style_cond_and_image_meta_size:
+            self.style_embedder = nn.Embedding(1, embedding_dim)
+            extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        else:
+            extra_in_dim = pooled_projection_dim
+
         self.extra_embedder = PixArtAlphaTextProjection(
             in_features=extra_in_dim,
             hidden_size=embedding_dim * 4,
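The new flag exists because later HunyuanDiT checkpoints drop the style and image-meta-size conditions, which changes the width of the vector fed to `extra_embedder`. Working through `extra_in_dim` with illustrative dimensions (assumed values, not from the diff):

    embedding_dim, pooled_projection_dim = 1408, 1024  # assumed example values

    # use_style_cond_and_image_meta_size=True: six size scalars x 256 sinusoidal
    # channels, plus the style embedding, plus the pooled text projection
    extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim  # 3968
    # use_style_cond_and_image_meta_size=False: pooled text projection only
    extra_in_dim = pooled_projection_dim  # 1024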
@@ -743,21 +1072,59 @@ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
         # extra condition1: text
         pooled_projections = self.pooler(encoder_hidden_states)  # (N, 1024)
 
-        # extra condition2: image meta size embdding
-        image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
-        image_meta_size = image_meta_size.to(dtype=hidden_dtype)
-        image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
+        if self.use_style_cond_and_image_meta_size:
+            # extra condition2: image meta size embedding
+            image_meta_size = self.size_proj(image_meta_size.view(-1))
+            image_meta_size = image_meta_size.to(dtype=hidden_dtype)
+            image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
+
+            # extra condition3: style embedding
+            style_embedding = self.style_embedder(style)  # (N, embedding_dim)
 
-        # extra condition3: style embedding
-        style_embedding = self.style_embedder(style)  # (N, embedding_dim)
+            # Concatenate all extra vectors
+            extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+        else:
+            extra_cond = torch.cat([pooled_projections], dim=1)
 
-        # Concatenate all extra vectors
-        extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
         conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # [B, D]
 
         return conditioning
 
 
+class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
+    def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
+        super().__init__()
+        self.time_proj = Timesteps(
+            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
+        )
+
+        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
+
+        self.caption_embedder = nn.Sequential(
+            nn.LayerNorm(cross_attention_dim),
+            nn.Linear(
+                cross_attention_dim,
+                hidden_size,
+                bias=True,
+            ),
+        )
+
+    def forward(self, timestep, caption_feat, caption_mask):
+        # timestep embedding:
+        time_freq = self.time_proj(timestep)
+        time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
+
+        # caption condition embedding:
+        caption_mask_float = caption_mask.float().unsqueeze(-1)
+        caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1)
+        caption_feats_pool = caption_feats_pool.to(caption_feat)
+        caption_embed = self.caption_embedder(caption_feats_pool)
+
+        conditioning = time_embed + caption_embed
+
+        return conditioning
+
+
 class TextTimeEmbedding(nn.Module):
     def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
         super().__init__()
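The caption pooling inside `LuminaCombinedTimestepCaptionEmbedding.forward` is a masked mean: padded token positions are zeroed out, then the sum is divided by the count of valid tokens. The same step in isolation, with hypothetical shapes:

    import torch

    caption_feat = torch.randn(2, 12, 2048)            # [batch, seq_len, cross_attention_dim]
    caption_mask = torch.tensor([[1] * 9 + [0] * 3,    # first sample: 9 valid tokens, 3 padding
                                 [1] * 12])            # second sample: no padding
    mask = caption_mask.float().unsqueeze(-1)          # [batch, seq_len, 1]
    pooled = (caption_feat * mask).sum(dim=1) / mask.sum(dim=1)
    print(pooled.shape)  # torch.Size([2, 2048])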
@@ -980,7 +1347,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):
 
             objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
 
-        # positionet with text and image infomation
+        # positionet with text and image information
         else:
             phrases_masks = phrases_masks.unsqueeze(-1)
             image_masks = image_masks.unsqueeze(-1)
@@ -1252,7 +1619,7 @@ class MultiIPAdapterImageProjection(nn.Module):
         if not isinstance(image_embeds, list):
             deprecation_message = (
                 "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
-                " Please make sure to update your script to pass `image_embeds` as a list of tensors to supress this warning."
+                " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning."
             )
             deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
             image_embeds = [image_embeds.unsqueeze(1)]
@@ -191,7 +191,6 @@ def _fetch_index_file(
     cache_dir,
     variant,
     force_download,
-    resume_download,
     proxies,
     local_files_only,
     token,
@@ -216,12 +215,11 @@ def _fetch_index_file(
             weights_name=index_file_in_repo,
             cache_dir=cache_dir,
             force_download=force_download,
-            resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
             token=token,
             revision=revision,
-            subfolder=subfolder,
+            subfolder=None,
             user_agent=user_agent,
             commit_hash=commit_hash,
         )
@@ -245,9 +245,7 @@ class FlaxModelMixin(PushToHubMixin):
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
+
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -296,7 +294,6 @@ class FlaxModelMixin(PushToHubMixin):
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         from_pt = kwargs.pop("from_pt", False)
-        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
        token = kwargs.pop("token", None)
@@ -316,7 +313,6 @@ class FlaxModelMixin(PushToHubMixin):
             cache_dir=cache_dir,
             return_unused_kwargs=True,
             force_download=force_download,
-            resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
             token=token,
@@ -362,7 +358,6 @@ class FlaxModelMixin(PushToHubMixin):
             cache_dir=cache_dir,
             force_download=force_download,
             proxies=proxies,
-            resume_download=resume_download,
             local_files_only=local_files_only,
             token=token,
             user_agent=user_agent,
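These `_fetch_index_file` and `FlaxModelMixin` hunks track huggingface_hub's deprecation of `resume_download`: interrupted downloads now resume automatically, so the argument is no longer popped or forwarded. Callers can simply drop it; a hedged sketch with a placeholder repo id:

    from diffusers import FlaxUNet2DConditionModel

    # resume_download=True used to be forwarded here; in 0.30.x it is gone and unnecessary,
    # since huggingface_hub resumes interrupted downloads on its own.
    model, params = FlaxUNet2DConditionModel.from_pretrained(
        "some-org/some-flax-checkpoint",  # hypothetical repo id
        subfolder="unet",
    )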