diffusers-0.30.2-py3-none-any.whl → diffusers-0.31.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. diffusers/__init__.py +38 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +238 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +40 -7
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +6 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
  53. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  54. diffusers/pipelines/cogview3/__init__.py +47 -0
  55. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  56. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  57. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  58. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  60. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  62. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  63. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  64. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  66. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  67. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  68. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  70. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  71. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  72. diffusers/pipelines/flux/__init__.py +10 -0
  73. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  74. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  76. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  77. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  78. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  79. diffusers/pipelines/free_noise_utils.py +365 -5
  80. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  81. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  82. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  83. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  84. diffusers/pipelines/kolors/tokenizer.py +4 -0
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  86. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  87. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  89. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  90. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  91. diffusers/pipelines/pag/__init__.py +6 -0
  92. diffusers/pipelines/pag/pag_utils.py +8 -2
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  96. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  97. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  98. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  100. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  101. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  102. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  103. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  106. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  107. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  108. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  109. diffusers/pipelines/pipeline_utils.py +123 -180
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  111. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  113. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  117. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  120. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  121. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  122. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  123. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  129. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  131. diffusers/quantizers/__init__.py +16 -0
  132. diffusers/quantizers/auto.py +126 -0
  133. diffusers/quantizers/base.py +233 -0
  134. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  135. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  136. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  137. diffusers/quantizers/quantization_config.py +391 -0
  138. diffusers/schedulers/scheduling_ddim.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  140. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm.py +4 -1
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  143. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  148. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  149. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  150. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  151. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  152. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  155. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  156. diffusers/schedulers/scheduling_sasolver.py +78 -1
  157. diffusers/schedulers/scheduling_unclip.py +4 -1
  158. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  159. diffusers/training_utils.py +48 -18
  160. diffusers/utils/__init__.py +2 -1
  161. diffusers/utils/dummy_pt_objects.py +60 -0
  162. diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
  163. diffusers/utils/hub_utils.py +16 -4
  164. diffusers/utils/import_utils.py +31 -8
  165. diffusers/utils/loading_utils.py +28 -4
  166. diffusers/utils/peft_utils.py +3 -3
  167. diffusers/utils/testing_utils.py +59 -0
  168. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  169. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
  170. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  172. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  173. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
@@ -25,12 +25,14 @@ import safetensors
25
25
  import torch
26
26
  from huggingface_hub.utils import EntryNotFoundError
27
27
 
28
+ from ..quantizers.quantization_config import QuantizationMethod
28
29
  from ..utils import (
29
30
  SAFE_WEIGHTS_INDEX_NAME,
30
31
  SAFETENSORS_FILE_EXTENSION,
31
32
  WEIGHTS_INDEX_NAME,
32
33
  _add_variant,
33
34
  _get_model_file,
35
+ deprecate,
34
36
  is_accelerate_available,
35
37
  is_torch_version,
36
38
  logging,
@@ -53,11 +55,36 @@ if is_accelerate_available():
53
55
 
54
56
 
55
57
  # Adapted from `transformers` (see modeling_utils.py)
56
- def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
58
+ def _determine_device_map(
59
+ model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
60
+ ):
57
61
  if isinstance(device_map, str):
62
+ special_dtypes = {}
63
+ if hf_quantizer is not None:
64
+ special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
65
+ special_dtypes.update(
66
+ {
67
+ name: torch.float32
68
+ for name, _ in model.named_parameters()
69
+ if any(m in name for m in keep_in_fp32_modules)
70
+ }
71
+ )
72
+
73
+ target_dtype = torch_dtype
74
+ if hf_quantizer is not None:
75
+ target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)
76
+
58
77
  no_split_modules = model._get_no_split_modules(device_map)
59
78
  device_map_kwargs = {"no_split_module_classes": no_split_modules}
60
79
 
80
+ if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
81
+ device_map_kwargs["special_dtypes"] = special_dtypes
82
+ elif len(special_dtypes) > 0:
83
+ logger.warning(
84
+ "This model has some weights that should be kept in higher precision, you need to upgrade "
85
+ "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
86
+ )
87
+
61
88
  if device_map != "sequential":
62
89
  max_memory = get_balanced_memory(
63
90
  model,
@@ -69,8 +96,14 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_
69
96
  else:
70
97
  max_memory = get_max_memory(max_memory)
71
98
 
99
+ if hf_quantizer is not None:
100
+ max_memory = hf_quantizer.adjust_max_memory(max_memory)
101
+
72
102
  device_map_kwargs["max_memory"] = max_memory
73
- device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
103
+ device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
104
+
105
+ if hf_quantizer is not None:
106
+ hf_quantizer.validate_environment(device_map=device_map)
74
107
 
75
108
  return device_map
76
109
 
@@ -99,6 +132,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
99
132
  """
100
133
  Reads a checkpoint file, returning properly formatted errors if they arise.
101
134
  """
135
+ # TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change
136
+ # when refactoring the _merge_sharded_checkpoints() method later.
137
+ if isinstance(checkpoint_file, dict):
138
+ return checkpoint_file
102
139
  try:
103
140
  file_extension = os.path.basename(checkpoint_file).split(".")[-1]
104
141
  if file_extension == SAFETENSORS_FILE_EXTENSION:
@@ -136,29 +173,67 @@ def load_model_dict_into_meta(
136
173
  device: Optional[Union[str, torch.device]] = None,
137
174
  dtype: Optional[Union[str, torch.dtype]] = None,
138
175
  model_name_or_path: Optional[str] = None,
176
+ hf_quantizer=None,
177
+ keep_in_fp32_modules=None,
139
178
  ) -> List[str]:
140
- device = device or torch.device("cpu")
179
+ if hf_quantizer is None:
180
+ device = device or torch.device("cpu")
141
181
  dtype = dtype or torch.float32
182
+ is_quantized = hf_quantizer is not None
183
+ is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
142
184
 
143
185
  accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
144
-
145
- unexpected_keys = []
146
186
  empty_state_dict = model.state_dict()
187
+ unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]
188
+
147
189
  for param_name, param in state_dict.items():
148
190
  if param_name not in empty_state_dict:
149
- unexpected_keys.append(param_name)
150
191
  continue
151
192
 
193
+ set_module_kwargs = {}
194
+ # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
195
+ # in int/uint/bool and not cast them.
196
+ # TODO: revisit cases when param.dtype == torch.float8_e4m3fn
197
+ if torch.is_floating_point(param):
198
+ if (
199
+ keep_in_fp32_modules is not None
200
+ and any(
201
+ module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
202
+ )
203
+ and dtype == torch.float16
204
+ ):
205
+ param = param.to(torch.float32)
206
+ if accepts_dtype:
207
+ set_module_kwargs["dtype"] = torch.float32
208
+ else:
209
+ param = param.to(dtype)
210
+ if accepts_dtype:
211
+ set_module_kwargs["dtype"] = dtype
212
+
213
+ # bnb params are flattened.
152
214
  if empty_state_dict[param_name].shape != param.shape:
153
- model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
154
- raise ValueError(
155
- f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
156
- )
157
-
158
- if accepts_dtype:
159
- set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
215
+ if (
216
+ is_quant_method_bnb
217
+ and hf_quantizer.pre_quantized
218
+ and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
219
+ ):
220
+ hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
221
+ elif not is_quant_method_bnb:
222
+ model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
223
+ raise ValueError(
224
+ f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
225
+ )
226
+
227
+ if is_quantized and (
228
+ hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
229
+ ):
230
+ hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
160
231
  else:
161
- set_module_tensor_to_device(model, param_name, device, value=param)
232
+ if accepts_dtype:
233
+ set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
234
+ else:
235
+ set_module_tensor_to_device(model, param_name, device, value=param)
236
+
162
237
  return unexpected_keys
163
238
 
164
239
 
@@ -228,3 +303,96 @@ def _fetch_index_file(
228
303
  index_file = None
229
304
 
230
305
  return index_file
306
+
307
+
308
+ # Adapted from
309
+ # https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
310
+ def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
311
+ weight_map = sharded_metadata.get("weight_map", None)
312
+ if weight_map is None:
313
+ raise KeyError("'weight_map' key not found in the shard index file.")
314
+
315
+ # Collect all unique safetensors files from weight_map
316
+ files_to_load = set(weight_map.values())
317
+ is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
318
+ merged_state_dict = {}
319
+
320
+ # Load tensors from each unique file
321
+ for file_name in files_to_load:
322
+ part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
323
+ if not os.path.exists(part_file_path):
324
+ raise FileNotFoundError(f"Part file {file_name} not found.")
325
+
326
+ if is_safetensors:
327
+ with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
328
+ for tensor_key in f.keys():
329
+ if tensor_key in weight_map:
330
+ merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
331
+ else:
332
+ merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
333
+
334
+ return merged_state_dict
335
+
336
+
337
+ def _fetch_index_file_legacy(
338
+ is_local,
339
+ pretrained_model_name_or_path,
340
+ subfolder,
341
+ use_safetensors,
342
+ cache_dir,
343
+ variant,
344
+ force_download,
345
+ proxies,
346
+ local_files_only,
347
+ token,
348
+ revision,
349
+ user_agent,
350
+ commit_hash,
351
+ ):
352
+ if is_local:
353
+ index_file = Path(
354
+ pretrained_model_name_or_path,
355
+ subfolder or "",
356
+ SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
357
+ ).as_posix()
358
+ splits = index_file.split(".")
359
+ split_index = -3 if ".cache" in index_file else -2
360
+ splits = splits[:-split_index] + [variant] + splits[-split_index:]
361
+ index_file = ".".join(splits)
362
+ if os.path.exists(index_file):
363
+ deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
364
+ deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
365
+ index_file = Path(index_file)
366
+ else:
367
+ index_file = None
368
+ else:
369
+ if variant is not None:
370
+ index_file_in_repo = Path(
371
+ subfolder or "",
372
+ SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
373
+ ).as_posix()
374
+ splits = index_file_in_repo.split(".")
375
+ split_index = -2
376
+ splits = splits[:-split_index] + [variant] + splits[-split_index:]
377
+ index_file_in_repo = ".".join(splits)
378
+ try:
379
+ index_file = _get_model_file(
380
+ pretrained_model_name_or_path,
381
+ weights_name=index_file_in_repo,
382
+ cache_dir=cache_dir,
383
+ force_download=force_download,
384
+ proxies=proxies,
385
+ local_files_only=local_files_only,
386
+ token=token,
387
+ revision=revision,
388
+ subfolder=None,
389
+ user_agent=user_agent,
390
+ commit_hash=commit_hash,
391
+ )
392
+ index_file = Path(index_file)
393
+ deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
394
+ deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
395
+ except (EntryNotFoundError, EnvironmentError):
396
+ index_file = None
397
+
398
+ return index_file