diffusers 0.34.0__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -752,7 +752,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             condition = self.controlnet_cond_embedding(cond)
             feat_seq = torch.mean(condition, dim=(2, 3))
             feat_seq = feat_seq + self.task_embedding[control_idx]
-            if from_multi:
+            if from_multi or len(control_type_idx) == 1:
                 inputs.append(feat_seq.unsqueeze(1))
                 condition_list.append(condition)
             else:
@@ -772,7 +772,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
             alpha = self.spatial_ch_projs(x[:, idx])
             alpha = alpha.unsqueeze(-1).unsqueeze(-1)
-            if from_multi:
+            if from_multi or len(control_type_idx) == 1:
                 controlnet_cond_fuser += condition + alpha
             else:
                 controlnet_cond_fuser += condition + alpha * scale
@@ -819,11 +819,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         # 6. scaling
         if guess_mode and not self.config.global_pool_conditions:
             scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
-            if from_multi:
+            if from_multi or len(control_type_idx) == 1:
                 scales = scales * conditioning_scale[0]
             down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
             mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
-        elif from_multi:
+        elif from_multi or len(control_type_idx) == 1:
            down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
            mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]

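The three `controlnet_union.py` hunks above share one fix: calling `ControlNetUnionModel` directly (`from_multi=False`) with exactly one control type now takes the same scaling path as calls routed through the multi-controlnet wrapper, so `conditioning_scale[0]` is applied consistently. A minimal sketch (not from the diff) of the case this targets; the model IDs are the usual public checkpoints and the control image file is hypothetical:

import torch
from diffusers import ControlNetUnionModel, StableDiffusionXLControlNetUnionPipeline
from diffusers.utils import load_image

controlnet = ControlNetUnionModel.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

depth = load_image("depth.png")  # hypothetical local control image
image = pipe(
    "a cozy reading room",
    control_image=[depth],
    control_mode=[1],  # a single control type, so len(control_type_idx) == 1
    controlnet_conditioning_scale=0.75,
).images[0]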
@@ -319,7 +319,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
     return emb


-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False):
     """
     This function generates 1D positional embeddings from a grid.

@@ -352,6 +352,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
     emb_cos = torch.cos(out)  # (M, D/2)

     emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)
+
+    # flip sine and cosine embeddings
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, embed_dim // 2 :], emb[:, : embed_dim // 2]], dim=1)
+
     return emb

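A quick way to see what the new `flip_sin_to_cos` flag does (a sketch, not from the diff; it assumes the `"pt"` output path takes a torch tensor for `pos`, which the torch-based body above suggests):

import torch
from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid

pos = torch.arange(8, dtype=torch.float32)
emb = get_1d_sincos_pos_embed_from_grid(16, pos, output_type="pt")
flipped = get_1d_sincos_pos_embed_from_grid(16, pos, output_type="pt", flip_sin_to_cos=True)

# concat order is [sin, cos]; the flag swaps the halves so cos comes first
assert torch.equal(flipped[:, :8], emb[:, 8:])
assert torch.equal(flipped[:, 8:], emb[:, :8])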
@@ -1176,6 +1181,7 @@ def apply_rotary_emb(
     freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
     use_real: bool = True,
     use_real_unbind_dim: int = -1,
+    sequence_dim: int = 2,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -1193,8 +1199,15 @@ def apply_rotary_emb(
     """
     if use_real:
         cos, sin = freqs_cis  # [S, D]
-        cos = cos[None, None]
-        sin = sin[None, None]
+        if sequence_dim == 2:
+            cos = cos[None, None, :, :]
+            sin = sin[None, None, :, :]
+        elif sequence_dim == 1:
+            cos = cos[None, :, None, :]
+            sin = sin[None, :, None, :]
+        else:
+            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
+
         cos, sin = cos.to(x.device), sin.to(x.device)

         if use_real_unbind_dim == -1:
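The new `sequence_dim` argument selects where the sequence axis lives in `x`. A shape-level sketch (not from the diff; random `cos`/`sin`, so the values are not meaningful rotary frequencies, only the broadcasting is exercised):

import torch
from diffusers.models.embeddings import apply_rotary_emb

B, H, S, D = 2, 4, 10, 64
cos, sin = torch.randn(S, D), torch.randn(S, D)

x = torch.randn(B, H, S, D)
out = apply_rotary_emb(x, (cos, sin))  # default sequence_dim=2: [batch, heads, seq, dim]

x = torch.randn(B, S, H, D)
out = apply_rotary_emb(x, (cos, sin), sequence_dim=1)  # heads-last layout: [batch, seq, heads, dim]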
@@ -1238,37 +1251,6 @@ def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
     return x


-class FluxPosEmbed(nn.Module):
-    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
-    def __init__(self, theta: int, axes_dim: List[int]):
-        super().__init__()
-        self.theta = theta
-        self.axes_dim = axes_dim
-
-    def forward(self, ids: torch.Tensor) -> torch.Tensor:
-        n_axes = ids.shape[-1]
-        cos_out = []
-        sin_out = []
-        pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        is_npu = ids.device.type == "npu"
-        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
-        for i in range(n_axes):
-            cos, sin = get_1d_rotary_pos_embed(
-                self.axes_dim[i],
-                pos[:, i],
-                theta=self.theta,
-                repeat_interleave_real=True,
-                use_real=True,
-                freqs_dtype=freqs_dtype,
-            )
-            cos_out.append(cos)
-            sin_out.append(sin)
-        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
-        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
-        return freqs_cos, freqs_sin
-
-
 class TimestepEmbedding(nn.Module):
     def __init__(
         self,
@@ -2619,3 +2601,13 @@ class MultiIPAdapterImageProjection(nn.Module):
             projected_image_embeds.append(image_embed)

         return projected_image_embeds
+
+
+class FluxPosEmbed(nn.Module):
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`."
+        deprecate("FluxPosEmbed", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxPosEmbed
+
+        return FluxPosEmbed(*args, **kwargs)
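`FluxPosEmbed` itself moved to `transformer_flux.py` (it is the class removed in the `-1238,37` hunk above); this shim keeps the old path importable. A short sketch of the two paths; the `theta`/`axes_dim` values are illustrative Flux-style settings, not taken from the diff:

from diffusers.models.transformers.transformer_flux import FluxPosEmbed  # new location

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])

# The old import still resolves, but constructing through it goes via the
# shim's __new__, emitting a deprecation warning; removal is slated for 1.0.0.
from diffusers.models.embeddings import FluxPosEmbed as OldFluxPosEmbed

_ = OldFluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])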
@@ -14,11 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import functools
 import importlib
 import inspect
 import os
 from array import array
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
@@ -30,6 +32,7 @@ from huggingface_hub.utils import EntryNotFoundError

 from ..quantizers import DiffusersQuantizer
 from ..utils import (
+    DEFAULT_HF_PARALLEL_LOADING_WORKERS,
     GGUF_FILE_EXTENSION,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
@@ -38,6 +41,7 @@ from ..utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_accelerate_version,
     is_gguf_available,
     is_torch_available,
     is_torch_version,
@@ -252,6 +256,10 @@ def load_model_dict_into_meta(
             param = param.to(dtype)
             set_module_kwargs["dtype"] = dtype

+        if is_accelerate_version(">", "1.8.1"):
+            set_module_kwargs["non_blocking"] = True
+            set_module_kwargs["clear_cache"] = False
+
         # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
         # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
         # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
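The accelerate version gate above follows a common loader pattern: keyword arguments that only newer accelerate releases understand are added conditionally. A minimal sketch of the same check, isolated from the surrounding function:

from diffusers.utils import is_accelerate_version

set_module_kwargs = {}
if is_accelerate_version(">", "1.8.1"):
    # per the hunk above, newer accelerate accepts these when setting tensors
    set_module_kwargs["non_blocking"] = True
    set_module_kwargs["clear_cache"] = False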
@@ -304,6 +312,161 @@ def load_model_dict_into_meta(
     return offload_index, state_dict_index


+def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+    """
+    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
+    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
+    parameters.
+
+    """
+    if model_to_load.device.type == "meta":
+        return False
+
+    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
+        return False
+
+    # Some models explicitly do not support param buffer assignment
+    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
+        logger.debug(
+            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+        )
+        return False
+
+    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
+    first_key = next(iter(model_to_load.state_dict().keys()))
+    if start_prefix + first_key in state_dict:
+        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
+
+    return False
+
+
+def _load_shard_file(
+    shard_file,
+    model,
+    model_state_dict,
+    device_map=None,
+    dtype=None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
+    dduf_entries=None,
+    loaded_keys=None,
+    unexpected_keys=None,
+    offload_index=None,
+    offload_folder=None,
+    state_dict_index=None,
+    state_dict_folder=None,
+    ignore_mismatched_sizes=False,
+    low_cpu_mem_usage=False,
+):
+    state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+    mismatched_keys = _find_mismatched_keys(
+        state_dict,
+        model_state_dict,
+        loaded_keys,
+        ignore_mismatched_sizes,
+    )
+    error_msgs = []
+    if low_cpu_mem_usage:
+        offload_index, state_dict_index = load_model_dict_into_meta(
+            model,
+            state_dict,
+            device_map=device_map,
+            dtype=dtype,
+            hf_quantizer=hf_quantizer,
+            keep_in_fp32_modules=keep_in_fp32_modules,
+            unexpected_keys=unexpected_keys,
+            offload_folder=offload_folder,
+            offload_index=offload_index,
+            state_dict_index=state_dict_index,
+            state_dict_folder=state_dict_folder,
+        )
+    else:
+        assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+
+        error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+    return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def _load_shard_files_with_threadpool(
+    shard_files,
+    model,
+    model_state_dict,
+    device_map=None,
+    dtype=None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
+    dduf_entries=None,
+    loaded_keys=None,
+    unexpected_keys=None,
+    offload_index=None,
+    offload_folder=None,
+    state_dict_index=None,
+    state_dict_folder=None,
+    ignore_mismatched_sizes=False,
+    low_cpu_mem_usage=False,
+):
+    # Do not spawn anymore workers than you need
+    num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
+
+    logger.info(f"Loading model weights in parallel with {num_workers} workers...")
+
+    error_msgs = []
+    mismatched_keys = []
+
+    load_one = functools.partial(
+        _load_shard_file,
+        model=model,
+        model_state_dict=model_state_dict,
+        device_map=device_map,
+        dtype=dtype,
+        hf_quantizer=hf_quantizer,
+        keep_in_fp32_modules=keep_in_fp32_modules,
+        dduf_entries=dduf_entries,
+        loaded_keys=loaded_keys,
+        unexpected_keys=unexpected_keys,
+        offload_index=offload_index,
+        offload_folder=offload_folder,
+        state_dict_index=state_dict_index,
+        state_dict_folder=state_dict_folder,
+        ignore_mismatched_sizes=ignore_mismatched_sizes,
+        low_cpu_mem_usage=low_cpu_mem_usage,
+    )
+
+    with ThreadPoolExecutor(max_workers=num_workers) as executor:
+        with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+            futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
+            for future in as_completed(futures):
+                result = future.result()
+                offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
+                error_msgs += _error_msgs
+                mismatched_keys += _mismatched_keys
+                pbar.update(1)
+
+    return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def _find_mismatched_keys(
+    state_dict,
+    model_state_dict,
+    loaded_keys,
+    ignore_mismatched_sizes,
+):
+    mismatched_keys = []
+    if ignore_mismatched_sizes:
+        for checkpoint_key in loaded_keys:
+            model_key = checkpoint_key
+            # If the checkpoint is sharded, we may not have the key here.
+            if checkpoint_key not in state_dict:
+                continue
+
+            if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+                mismatched_keys.append(
+                    (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                )
+                del state_dict[checkpoint_key]
+    return mismatched_keys
+
+
 def _load_state_dict_into_model(
     model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
 ) -> List[str]:
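These helpers in `model_loading_utils.py` let sharded checkpoints be read by a thread pool, capped at `DEFAULT_HF_PARALLEL_LOADING_WORKERS`. A hedged usage sketch: the `HF_ENABLE_PARALLEL_LOADING` opt-in environment variable mirrors the equivalent transformers feature and is an assumption here, not visible in this hunk:

import os

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"  # assumed opt-in switch; set before loading

import torch
from diffusers import FluxTransformer2DModel

# With the switch on, sharded weights are read by up to
# DEFAULT_HF_PARALLEL_LOADING_WORKERS threads via
# _load_shard_files_with_threadpool instead of sequentially.
model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)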
@@ -520,3 +683,72 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
         parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights

     return parsed_parameters
+
+
+def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
+    mismatched_keys = []
+    if not ignore_mismatched_sizes:
+        return mismatched_keys
+    for checkpoint_key in loaded_keys:
+        model_key = checkpoint_key
+        # If the checkpoint is sharded, we may not have the key here.
+        if checkpoint_key not in state_dict:
+            continue
+
+        if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+            mismatched_keys.append(
+                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+            )
+            del state_dict[checkpoint_key]
+    return mismatched_keys
+
+
+def _expand_device_map(device_map, param_names):
+    """
+    Expand a device map to return the correspondence parameter name to device.
+    """
+    new_device_map = {}
+    for module, device in device_map.items():
+        new_device_map.update(
+            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+        )
+    return new_device_map
+
+
+# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
+def _caching_allocator_warmup(
+    model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
+) -> None:
+    """
+    This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
+    device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model,
+    which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a
+    very large margin.
+    """
+    factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
+
+    # Keep only accelerator devices
+    accelerator_device_map = {
+        param: torch.device(device)
+        for param, device in expanded_device_map.items()
+        if str(device) not in ["cpu", "disk"]
+    }
+    if not accelerator_device_map:
+        return
+
+    elements_per_device = defaultdict(int)
+    for param_name, device in accelerator_device_map.items():
+        try:
+            p = model.get_parameter(param_name)
+        except AttributeError:
+            try:
+                p = model.get_buffer(param_name)
+            except AttributeError:
+                raise AttributeError(f"Parameter or buffer with name={param_name} not found in model")
+        # TODO: account for TP when needed.
+        elements_per_device[device] += p.numel()
+
+    # This will kick off the caching allocator to avoid having to Malloc afterwards
+    for device, elem_count in elements_per_device.items():
+        warmup_elems = max(1, elem_count // factor)
+        _ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)
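The warmup's effect, reduced to a standalone sketch (not from the diff): one large `torch.empty` primes the CUDA caching allocator, so the many small per-parameter allocations made while copying weights are served from the cache rather than by separate cudaMalloc calls.

import torch

def warmup_allocator(total_elems: int, dtype: torch.dtype, device: str = "cuda") -> None:
    # factor=2 is the unquantized divisor used above; quantizers supply their own
    factor = 2
    buf = torch.empty(max(1, total_elems // factor), dtype=dtype, device=device, requires_grad=False)
    del buf  # freed back to the caching allocator, not to the driver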
@@ -369,8 +369,7 @@ class FlaxModelMixin(PushToHubMixin):
                 raise EnvironmentError(
                     f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                     "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
-                    "token having permission to this repo with `token` or log in with `huggingface-cli "
-                    "login`."
+                    "token having permission to this repo with `token` or log in with `hf auth login`."
                 )
             except RevisionNotFoundError:
                 raise EnvironmentError(