diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220):
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2222 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +1 -12
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +262 -2
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1795 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +319 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +1 -4
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +19 -16
  210. diffusers/utils/loading_utils.py +76 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
@@ -35,10 +35,21 @@ def get_timestep_embedding(
35
35
  """
36
36
  This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
37
37
 
38
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
39
- These may be fractional.
40
- :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
41
- embeddings. :return: an [N x dim] Tensor of positional embeddings.
38
+ Args
39
+ timesteps (torch.Tensor):
40
+ a 1-D Tensor of N indices, one per batch element. These may be fractional.
41
+ embedding_dim (int):
42
+ the dimension of the output.
43
+ flip_sin_to_cos (bool):
44
+ Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
45
+ downscale_freq_shift (float):
46
+ Controls the delta between frequencies between dimensions
47
+ scale (float):
48
+ Scaling factor applied to the embeddings.
49
+ max_period (int):
50
+ Controls the maximum frequency of the embeddings
51
+ Returns
52
+ torch.Tensor: an [N x dim] Tensor of positional embeddings.
42
53
  """
43
54
  assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
44
55
 
@@ -67,6 +78,53 @@ def get_timestep_embedding(
67
78
  return emb
68
79
 
69
80
 
81
+ def get_3d_sincos_pos_embed(
82
+ embed_dim: int,
83
+ spatial_size: Union[int, Tuple[int, int]],
84
+ temporal_size: int,
85
+ spatial_interpolation_scale: float = 1.0,
86
+ temporal_interpolation_scale: float = 1.0,
87
+ ) -> np.ndarray:
88
+ r"""
89
+ Args:
90
+ embed_dim (`int`):
91
+ spatial_size (`int` or `Tuple[int, int]`):
92
+ temporal_size (`int`):
93
+ spatial_interpolation_scale (`float`, defaults to 1.0):
94
+ temporal_interpolation_scale (`float`, defaults to 1.0):
95
+ """
96
+ if embed_dim % 4 != 0:
97
+ raise ValueError("`embed_dim` must be divisible by 4")
98
+ if isinstance(spatial_size, int):
99
+ spatial_size = (spatial_size, spatial_size)
100
+
101
+ embed_dim_spatial = 3 * embed_dim // 4
102
+ embed_dim_temporal = embed_dim // 4
103
+
104
+ # 1. Spatial
105
+ grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
106
+ grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
107
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
108
+ grid = np.stack(grid, axis=0)
109
+
110
+ grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
111
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
112
+
113
+ # 2. Temporal
114
+ grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
115
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
116
+
117
+ # 3. Concat
118
+ pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
119
+ pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) # [T, H*W, D // 4 * 3]
120
+
121
+ pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
122
+ pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1) # [T, H*W, D // 4]
123
+
124
+ pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) # [T, H*W, D]
125
+ return pos_embed
126
+
127
+
70
128
  def get_2d_sincos_pos_embed(
71
129
  embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
72
130
  ):
@@ -230,6 +288,92 @@ class PatchEmbed(nn.Module):
230
288
  return (latent + pos_embed).to(latent.dtype)
231
289
 
232
290
 
291
+ class LuminaPatchEmbed(nn.Module):
292
+ """2D Image to Patch Embedding with support for Lumina-T2X"""
293
+
294
+ def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
295
+ super().__init__()
296
+ self.patch_size = patch_size
297
+ self.proj = nn.Linear(
298
+ in_features=patch_size * patch_size * in_channels,
299
+ out_features=embed_dim,
300
+ bias=bias,
301
+ )
302
+
303
+ def forward(self, x, freqs_cis):
304
+ """
305
+ Patchifies and embeds the input tensor(s).
306
+
307
+ Args:
308
+ x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded.
309
+
310
+ Returns:
311
+ Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified
312
+ and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the
313
+ frequency tensor(s).
314
+ """
315
+ freqs_cis = freqs_cis.to(x[0].device)
316
+ patch_height = patch_width = self.patch_size
317
+ batch_size, channel, height, width = x.size()
318
+ height_tokens, width_tokens = height // patch_height, width // patch_width
319
+
320
+ x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute(
321
+ 0, 2, 4, 1, 3, 5
322
+ )
323
+ x = x.flatten(3)
324
+ x = self.proj(x)
325
+ x = x.flatten(1, 2)
326
+
327
+ mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
328
+
329
+ return (
330
+ x,
331
+ mask,
332
+ [(height, width)] * batch_size,
333
+ freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
334
+ )
335
+
336
+
337
+ class CogVideoXPatchEmbed(nn.Module):
338
+ def __init__(
339
+ self,
340
+ patch_size: int = 2,
341
+ in_channels: int = 16,
342
+ embed_dim: int = 1920,
343
+ text_embed_dim: int = 4096,
344
+ bias: bool = True,
345
+ ) -> None:
346
+ super().__init__()
347
+ self.patch_size = patch_size
348
+
349
+ self.proj = nn.Conv2d(
350
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
351
+ )
352
+ self.text_proj = nn.Linear(text_embed_dim, embed_dim)
353
+
354
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
355
+ r"""
356
+ Args:
357
+ text_embeds (`torch.Tensor`):
358
+ Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
359
+ image_embeds (`torch.Tensor`):
360
+ Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
361
+ """
362
+ text_embeds = self.text_proj(text_embeds)
363
+
364
+ batch, num_frames, channels, height, width = image_embeds.shape
365
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
366
+ image_embeds = self.proj(image_embeds)
367
+ image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
368
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
369
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
370
+
371
+ embeds = torch.cat(
372
+ [text_embeds, image_embeds], dim=1
373
+ ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
374
+ return embeds
375
+
376
+
233
377
  def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
234
378
  """
235
379
  RoPE for image tokens with 2d structure.
@@ -245,7 +389,7 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
245
389
  If True, return real part and imaginary part separately. Otherwise, return complex numbers.
246
390
 
247
391
  Returns:
248
- `torch.Tensor`: positional embdding with shape `( grid_size * grid_size, embed_dim/2)`.
392
+ `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
249
393
  """
250
394
  start, stop = crops_coords
251
395
  grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
@@ -262,19 +406,47 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
262
406
  assert embed_dim % 4 == 0
263
407
 
264
408
  # use half of dimensions to encode grid_h
265
- emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
266
- emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
409
+ emb_h = get_1d_rotary_pos_embed(
410
+ embed_dim // 2, grid[0].reshape(-1), use_real=use_real
411
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
412
+ emb_w = get_1d_rotary_pos_embed(
413
+ embed_dim // 2, grid[1].reshape(-1), use_real=use_real
414
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
267
415
 
268
416
  if use_real:
269
- cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
270
- sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
417
+ cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
418
+ sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
271
419
  return cos, sin
272
420
  else:
273
421
  emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
274
422
  return emb
275
423
 
276
424
 
277
- def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
425
+ def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
426
+ assert embed_dim % 4 == 0
427
+
428
+ emb_h = get_1d_rotary_pos_embed(
429
+ embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
430
+ ) # (H, D/4)
431
+ emb_w = get_1d_rotary_pos_embed(
432
+ embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
433
+ ) # (W, D/4)
434
+ emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1)
435
+ emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1)
436
+
437
+ emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2)
438
+ return emb
439
+
440
+
441
+ def get_1d_rotary_pos_embed(
442
+ dim: int,
443
+ pos: Union[np.ndarray, int],
444
+ theta: float = 10000.0,
445
+ use_real=False,
446
+ linear_factor=1.0,
447
+ ntk_factor=1.0,
448
+ repeat_interleave_real=True,
449
+ ):
278
450
  """
279
451
  Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
280
452
 
@@ -289,19 +461,32 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float
289
461
  Scaling factor for frequency computation. Defaults to 10000.0.
290
462
  use_real (`bool`, *optional*):
291
463
  If True, return real part and imaginary part separately. Otherwise, return complex numbers.
292
-
464
+ linear_factor (`float`, *optional*, defaults to 1.0):
465
+ Scaling factor for the context extrapolation. Defaults to 1.0.
466
+ ntk_factor (`float`, *optional*, defaults to 1.0):
467
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
468
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
469
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
470
+ Otherwise, they are concatenated with themselves.
293
471
  Returns:
294
472
  `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
295
473
  """
474
+ assert dim % 2 == 0
475
+
296
476
  if isinstance(pos, int):
297
477
  pos = np.arange(pos)
298
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
478
+ theta = theta * ntk_factor
479
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2]
299
480
  t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
300
481
  freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
301
- if use_real:
482
+ if use_real and repeat_interleave_real:
302
483
  freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
303
484
  freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
304
485
  return freqs_cos, freqs_sin
486
+ elif use_real:
487
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D]
488
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D]
489
+ return freqs_cos, freqs_sin
305
490
  else:
306
491
  freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
307
492
  return freqs_cis
@@ -310,6 +495,8 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float
310
495
  def apply_rotary_emb(
311
496
  x: torch.Tensor,
312
497
  freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
498
+ use_real: bool = True,
499
+ use_real_unbind_dim: int = -1,
313
500
  ) -> Tuple[torch.Tensor, torch.Tensor]:
314
501
  """
315
502
  Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -325,16 +512,32 @@ def apply_rotary_emb(
325
512
  Returns:
326
513
  Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
327
514
  """
328
- cos, sin = freqs_cis # [S, D]
329
- cos = cos[None, None]
330
- sin = sin[None, None]
331
- cos, sin = cos.to(x.device), sin.to(x.device)
515
+ if use_real:
516
+ cos, sin = freqs_cis # [S, D]
517
+ cos = cos[None, None]
518
+ sin = sin[None, None]
519
+ cos, sin = cos.to(x.device), sin.to(x.device)
520
+
521
+ if use_real_unbind_dim == -1:
522
+ # Use for example in Lumina
523
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
524
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
525
+ elif use_real_unbind_dim == -2:
526
+ # Use for example in Stable Audio
527
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
528
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
529
+ else:
530
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
332
531
 
333
- x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
334
- x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
335
- out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
532
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
533
+
534
+ return out
535
+ else:
536
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
537
+ freqs_cis = freqs_cis.unsqueeze(2)
538
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
336
539
 
337
- return out
540
+ return x_out.type_as(x)
338
541
 
339
542
 
340
543
  class TimestepEmbedding(nn.Module):
@@ -386,11 +589,12 @@ class TimestepEmbedding(nn.Module):
386
589
 
387
590
 
388
591
  class Timesteps(nn.Module):
389
- def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
592
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
390
593
  super().__init__()
391
594
  self.num_channels = num_channels
392
595
  self.flip_sin_to_cos = flip_sin_to_cos
393
596
  self.downscale_freq_shift = downscale_freq_shift
597
+ self.scale = scale
394
598
 
395
599
  def forward(self, timesteps):
396
600
  t_emb = get_timestep_embedding(
@@ -398,6 +602,7 @@ class Timesteps(nn.Module):
398
602
  self.num_channels,
399
603
  flip_sin_to_cos=self.flip_sin_to_cos,
400
604
  downscale_freq_shift=self.downscale_freq_shift,
605
+ scale=self.scale,
401
606
  )
402
607
  return t_emb
403
608
 
@@ -415,9 +620,10 @@ class GaussianFourierProjection(nn.Module):
415
620
 
416
621
  if set_W_to_weight:
417
622
  # to delete later
623
+ del self.weight
418
624
  self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
419
-
420
625
  self.weight = self.W
626
+ del self.W
421
627
 
422
628
  def forward(self, x):
423
629
  if self.log:
@@ -676,6 +882,30 @@ class CombinedTimestepTextProjEmbeddings(nn.Module):
676
882
  return conditioning
677
883
 
678
884
 
885
+ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
886
+ def __init__(self, embedding_dim, pooled_projection_dim):
887
+ super().__init__()
888
+
889
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
890
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
891
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
892
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
893
+
894
+ def forward(self, timestep, guidance, pooled_projection):
895
+ timesteps_proj = self.time_proj(timestep)
896
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
897
+
898
+ guidance_proj = self.time_proj(guidance)
899
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
900
+
901
+ time_guidance_emb = timesteps_emb + guidance_emb
902
+
903
+ pooled_projections = self.text_embedder(pooled_projection)
904
+ conditioning = time_guidance_emb + pooled_projections
905
+
906
+ return conditioning
907
+
908
+
679
909
  class HunyuanDiTAttentionPool(nn.Module):
680
910
  # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
681
911
 
@@ -717,18 +947,33 @@ class HunyuanDiTAttentionPool(nn.Module):
717
947
 
718
948
 
719
949
  class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
720
- def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
950
+ def __init__(
951
+ self,
952
+ embedding_dim,
953
+ pooled_projection_dim=1024,
954
+ seq_len=256,
955
+ cross_attention_dim=2048,
956
+ use_style_cond_and_image_meta_size=True,
957
+ ):
721
958
  super().__init__()
722
959
 
723
960
  self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
724
961
  self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
725
962
 
963
+ self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
964
+
726
965
  self.pooler = HunyuanDiTAttentionPool(
727
966
  seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
728
967
  )
968
+
729
969
  # Here we use a default learned embedder layer for future extension.
730
- self.style_embedder = nn.Embedding(1, embedding_dim)
731
- extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
970
+ self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
971
+ if use_style_cond_and_image_meta_size:
972
+ self.style_embedder = nn.Embedding(1, embedding_dim)
973
+ extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
974
+ else:
975
+ extra_in_dim = pooled_projection_dim
976
+
732
977
  self.extra_embedder = PixArtAlphaTextProjection(
733
978
  in_features=extra_in_dim,
734
979
  hidden_size=embedding_dim * 4,
@@ -743,21 +988,59 @@ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
743
988
  # extra condition1: text
744
989
  pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024)
745
990
 
746
- # extra condition2: image meta size embdding
747
- image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
748
- image_meta_size = image_meta_size.to(dtype=hidden_dtype)
749
- image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
991
+ if self.use_style_cond_and_image_meta_size:
992
+ # extra condition2: image meta size embedding
993
+ image_meta_size = self.size_proj(image_meta_size.view(-1))
994
+ image_meta_size = image_meta_size.to(dtype=hidden_dtype)
995
+ image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
996
+
997
+ # extra condition3: style embedding
998
+ style_embedding = self.style_embedder(style) # (N, embedding_dim)
750
999
 
751
- # extra condition3: style embedding
752
- style_embedding = self.style_embedder(style) # (N, embedding_dim)
1000
+ # Concatenate all extra vectors
1001
+ extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
1002
+ else:
1003
+ extra_cond = torch.cat([pooled_projections], dim=1)
753
1004
 
754
- # Concatenate all extra vectors
755
- extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
756
1005
  conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D]
757
1006
 
758
1007
  return conditioning
759
1008
 
760
1009
 
1010
+ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
1011
+ def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
1012
+ super().__init__()
1013
+ self.time_proj = Timesteps(
1014
+ num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
1015
+ )
1016
+
1017
+ self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
1018
+
1019
+ self.caption_embedder = nn.Sequential(
1020
+ nn.LayerNorm(cross_attention_dim),
1021
+ nn.Linear(
1022
+ cross_attention_dim,
1023
+ hidden_size,
1024
+ bias=True,
1025
+ ),
1026
+ )
1027
+
1028
+ def forward(self, timestep, caption_feat, caption_mask):
1029
+ # timestep embedding:
1030
+ time_freq = self.time_proj(timestep)
1031
+ time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
1032
+
1033
+ # caption condition embedding:
1034
+ caption_mask_float = caption_mask.float().unsqueeze(-1)
1035
+ caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1)
1036
+ caption_feats_pool = caption_feats_pool.to(caption_feat)
1037
+ caption_embed = self.caption_embedder(caption_feats_pool)
1038
+
1039
+ conditioning = time_embed + caption_embed
1040
+
1041
+ return conditioning
1042
+
1043
+
761
1044
  class TextTimeEmbedding(nn.Module):
762
1045
  def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
763
1046
  super().__init__()
@@ -980,7 +1263,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):
980
1263
 
981
1264
  objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
982
1265
 
983
- # positionet with text and image infomation
1266
+ # positionet with text and image information
984
1267
  else:
985
1268
  phrases_masks = phrases_masks.unsqueeze(-1)
986
1269
  image_masks = image_masks.unsqueeze(-1)
@@ -1252,7 +1535,7 @@ class MultiIPAdapterImageProjection(nn.Module):
1252
1535
  if not isinstance(image_embeds, list):
1253
1536
  deprecation_message = (
1254
1537
  "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
1255
- " Please make sure to update your script to pass `image_embeds` as a list of tensors to supress this warning."
1538
+ " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning."
1256
1539
  )
1257
1540
  deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
1258
1541
  image_embeds = [image_embeds.unsqueeze(1)]
@@ -191,7 +191,6 @@ def _fetch_index_file(
191
191
  cache_dir,
192
192
  variant,
193
193
  force_download,
194
- resume_download,
195
194
  proxies,
196
195
  local_files_only,
197
196
  token,
@@ -216,12 +215,11 @@ def _fetch_index_file(
216
215
  weights_name=index_file_in_repo,
217
216
  cache_dir=cache_dir,
218
217
  force_download=force_download,
219
- resume_download=resume_download,
220
218
  proxies=proxies,
221
219
  local_files_only=local_files_only,
222
220
  token=token,
223
221
  revision=revision,
224
- subfolder=subfolder,
222
+ subfolder=None,
225
223
  user_agent=user_agent,
226
224
  commit_hash=commit_hash,
227
225
  )
@@ -245,9 +245,7 @@ class FlaxModelMixin(PushToHubMixin):
245
245
  force_download (`bool`, *optional*, defaults to `False`):
246
246
  Whether or not to force the (re-)download of the model weights and configuration files, overriding the
247
247
  cached versions if they exist.
248
- resume_download:
249
- Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
250
- of Diffusers.
248
+
251
249
  proxies (`Dict[str, str]`, *optional*):
252
250
  A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
253
251
  'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -296,7 +294,6 @@ class FlaxModelMixin(PushToHubMixin):
296
294
  cache_dir = kwargs.pop("cache_dir", None)
297
295
  force_download = kwargs.pop("force_download", False)
298
296
  from_pt = kwargs.pop("from_pt", False)
299
- resume_download = kwargs.pop("resume_download", None)
300
297
  proxies = kwargs.pop("proxies", None)
301
298
  local_files_only = kwargs.pop("local_files_only", False)
302
299
  token = kwargs.pop("token", None)
@@ -316,7 +313,6 @@ class FlaxModelMixin(PushToHubMixin):
316
313
  cache_dir=cache_dir,
317
314
  return_unused_kwargs=True,
318
315
  force_download=force_download,
319
- resume_download=resume_download,
320
316
  proxies=proxies,
321
317
  local_files_only=local_files_only,
322
318
  token=token,
@@ -362,7 +358,6 @@ class FlaxModelMixin(PushToHubMixin):
362
358
  cache_dir=cache_dir,
363
359
  force_download=force_download,
364
360
  proxies=proxies,
365
- resume_download=resume_download,
366
361
  local_files_only=local_files_only,
367
362
  token=token,
368
363
  user_agent=user_agent,
@@ -434,9 +434,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
434
434
  force_download (`bool`, *optional*, defaults to `False`):
435
435
  Whether or not to force the (re-)download of the model weights and configuration files, overriding the
436
436
  cached versions if they exist.
437
- resume_download:
438
- Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
439
- of Diffusers.
440
437
  proxies (`Dict[str, str]`, *optional*):
441
438
  A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
442
439
  'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -518,7 +515,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
518
515
  ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
519
516
  force_download = kwargs.pop("force_download", False)
520
517
  from_flax = kwargs.pop("from_flax", False)
521
- resume_download = kwargs.pop("resume_download", None)
522
518
  proxies = kwargs.pop("proxies", None)
523
519
  output_loading_info = kwargs.pop("output_loading_info", False)
524
520
  local_files_only = kwargs.pop("local_files_only", None)
@@ -619,7 +615,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
619
615
  return_unused_kwargs=True,
620
616
  return_commit_hash=True,
621
617
  force_download=force_download,
622
- resume_download=resume_download,
623
618
  proxies=proxies,
624
619
  local_files_only=local_files_only,
625
620
  token=token,
@@ -641,7 +636,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
641
636
  cache_dir=cache_dir,
642
637
  variant=variant,
643
638
  force_download=force_download,
644
- resume_download=resume_download,
645
639
  proxies=proxies,
646
640
  local_files_only=local_files_only,
647
641
  token=token,
@@ -663,7 +657,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
663
657
  weights_name=FLAX_WEIGHTS_NAME,
664
658
  cache_dir=cache_dir,
665
659
  force_download=force_download,
666
- resume_download=resume_download,
667
660
  proxies=proxies,
668
661
  local_files_only=local_files_only,
669
662
  token=token,
@@ -685,7 +678,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
685
678
  index_file,
686
679
  cache_dir=cache_dir,
687
680
  proxies=proxies,
688
- resume_download=resume_download,
689
681
  local_files_only=local_files_only,
690
682
  token=token,
691
683
  user_agent=user_agent,
@@ -700,7 +692,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
700
692
  weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
701
693
  cache_dir=cache_dir,
702
694
  force_download=force_download,
703
- resume_download=resume_download,
704
695
  proxies=proxies,
705
696
  local_files_only=local_files_only,
706
697
  token=token,
@@ -724,7 +715,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
724
715
  weights_name=_add_variant(WEIGHTS_NAME, variant),
725
716
  cache_dir=cache_dir,
726
717
  force_download=force_download,
727
- resume_download=resume_download,
728
718
  proxies=proxies,
729
719
  local_files_only=local_files_only,
730
720
  token=token,
@@ -783,7 +773,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
783
773
  try:
784
774
  accelerate.load_checkpoint_and_dispatch(
785
775
  model,
786
- model_file if not is_sharded else sharded_ckpt_cached_folder,
776
+ model_file if not is_sharded else index_file,
787
777
  device_map,
788
778
  max_memory=max_memory,
789
779
  offload_folder=offload_folder,
@@ -813,13 +803,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
813
803
  model._temp_convert_self_to_deprecated_attention_blocks()
814
804
  accelerate.load_checkpoint_and_dispatch(
815
805
  model,
816
- model_file if not is_sharded else sharded_ckpt_cached_folder,
806
+ model_file if not is_sharded else index_file,
817
807
  device_map,
818
808
  max_memory=max_memory,
819
809
  offload_folder=offload_folder,
820
810
  offload_state_dict=offload_state_dict,
821
811
  dtype=torch_dtype,
822
- force_hook=force_hook,
812
+ force_hooks=force_hook,
823
813
  strict=True,
824
814
  )
825
815
  model._undo_temp_convert_self_to_deprecated_attention_blocks()
@@ -1169,7 +1159,7 @@ class LegacyModelMixin(ModelMixin):
1169
1159
  @classmethod
1170
1160
  @validate_hf_hub_args
1171
1161
  def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
1172
- # To prevent depedency import problem.
1162
+ # To prevent dependency import problem.
1173
1163
  from .model_loading_utils import _fetch_remapped_cls_from_config
1174
1164
 
1175
1165
  # Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
@@ -1177,7 +1167,6 @@ class LegacyModelMixin(ModelMixin):
1177
1167
 
1178
1168
  cache_dir = kwargs.pop("cache_dir", None)
1179
1169
  force_download = kwargs.pop("force_download", False)
1180
- resume_download = kwargs.pop("resume_download", None)
1181
1170
  proxies = kwargs.pop("proxies", None)
1182
1171
  local_files_only = kwargs.pop("local_files_only", None)
1183
1172
  token = kwargs.pop("token", None)
@@ -1200,7 +1189,6 @@ class LegacyModelMixin(ModelMixin):
1200
1189
  return_unused_kwargs=True,
1201
1190
  return_commit_hash=True,
1202
1191
  force_download=force_download,
1203
- resume_download=resume_download,
1204
1192
  proxies=proxies,
1205
1193
  local_files_only=local_files_only,
1206
1194
  token=token,