diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +38 -2
- diffusers/configuration_utils.py +12 -0
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +257 -54
- diffusers/loaders/__init__.py +2 -0
- diffusers/loaders/ip_adapter.py +5 -1
- diffusers/loaders/lora_base.py +14 -7
- diffusers/loaders/lora_conversion_utils.py +332 -0
- diffusers/loaders/lora_pipeline.py +707 -41
- diffusers/loaders/peft.py +1 -0
- diffusers/loaders/single_file_utils.py +81 -4
- diffusers/loaders/textual_inversion.py +2 -0
- diffusers/loaders/unet.py +39 -8
- diffusers/models/__init__.py +4 -0
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +86 -10
- diffusers/models/attention_processor.py +169 -133
- diffusers/models/autoencoders/autoencoder_kl.py +71 -11
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
- diffusers/models/controlnet_flux.py +536 -0
- diffusers/models/controlnet_sd3.py +7 -3
- diffusers/models/controlnet_sparsectrl.py +0 -1
- diffusers/models/embeddings.py +238 -61
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +182 -14
- diffusers/models/modeling_utils.py +283 -46
- diffusers/models/normalization.py +79 -0
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
- diffusers/models/transformers/pixart_transformer_2d.py +9 -1
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +161 -44
- diffusers/models/transformers/transformer_sd3.py +7 -1
- diffusers/models/unets/unet_2d_condition.py +8 -8
- diffusers/models/unets/unet_motion_model.py +41 -63
- diffusers/models/upsampling.py +6 -6
- diffusers/pipelines/__init__.py +40 -7
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
- diffusers/pipelines/auto_pipeline.py +39 -8
- diffusers/pipelines/cogvideo/__init__.py +6 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +10 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -20
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
- diffusers/pipelines/pag/__init__.py +6 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_loading_utils.py +225 -27
- diffusers/pipelines/pipeline_utils.py +123 -180
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +126 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/quantization_config.py +391 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +4 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
- diffusers/schedulers/scheduling_deis_multistep.py +78 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_sasolver.py +78 -1
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
- diffusers/training_utils.py +48 -18
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
- diffusers/utils/hub_utils.py +16 -4
- diffusers/utils/import_utils.py +31 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +3 -3
- diffusers/utils/testing_utils.py +59 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
@@ -25,12 +25,14 @@ import safetensors
|
|
25
25
|
import torch
|
26
26
|
from huggingface_hub.utils import EntryNotFoundError
|
27
27
|
|
28
|
+
from ..quantizers.quantization_config import QuantizationMethod
|
28
29
|
from ..utils import (
|
29
30
|
SAFE_WEIGHTS_INDEX_NAME,
|
30
31
|
SAFETENSORS_FILE_EXTENSION,
|
31
32
|
WEIGHTS_INDEX_NAME,
|
32
33
|
_add_variant,
|
33
34
|
_get_model_file,
|
35
|
+
deprecate,
|
34
36
|
is_accelerate_available,
|
35
37
|
is_torch_version,
|
36
38
|
logging,
|
@@ -53,11 +55,36 @@ if is_accelerate_available():
|
|
53
55
|
|
54
56
|
|
55
57
|
# Adapted from `transformers` (see modeling_utils.py)
|
56
|
-
def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
|
58
|
+
def _determine_device_map(
|
59
|
+
model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
|
60
|
+
):
|
57
61
|
if isinstance(device_map, str):
|
62
|
+
special_dtypes = {}
|
63
|
+
if hf_quantizer is not None:
|
64
|
+
special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
|
65
|
+
special_dtypes.update(
|
66
|
+
{
|
67
|
+
name: torch.float32
|
68
|
+
for name, _ in model.named_parameters()
|
69
|
+
if any(m in name for m in keep_in_fp32_modules)
|
70
|
+
}
|
71
|
+
)
|
72
|
+
|
73
|
+
target_dtype = torch_dtype
|
74
|
+
if hf_quantizer is not None:
|
75
|
+
target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)
|
76
|
+
|
58
77
|
no_split_modules = model._get_no_split_modules(device_map)
|
59
78
|
device_map_kwargs = {"no_split_module_classes": no_split_modules}
|
60
79
|
|
80
|
+
if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
|
81
|
+
device_map_kwargs["special_dtypes"] = special_dtypes
|
82
|
+
elif len(special_dtypes) > 0:
|
83
|
+
logger.warning(
|
84
|
+
"This model has some weights that should be kept in higher precision, you need to upgrade "
|
85
|
+
"`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
|
86
|
+
)
|
87
|
+
|
61
88
|
if device_map != "sequential":
|
62
89
|
max_memory = get_balanced_memory(
|
63
90
|
model,
|
@@ -69,8 +96,14 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_
|
|
69
96
|
else:
|
70
97
|
max_memory = get_max_memory(max_memory)
|
71
98
|
|
99
|
+
if hf_quantizer is not None:
|
100
|
+
max_memory = hf_quantizer.adjust_max_memory(max_memory)
|
101
|
+
|
72
102
|
device_map_kwargs["max_memory"] = max_memory
|
73
|
-
device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
|
103
|
+
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
|
104
|
+
|
105
|
+
if hf_quantizer is not None:
|
106
|
+
hf_quantizer.validate_environment(device_map=device_map)
|
74
107
|
|
75
108
|
return device_map
|
76
109
|
|
@@ -99,6 +132,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
|
99
132
|
"""
|
100
133
|
Reads a checkpoint file, returning properly formatted errors if they arise.
|
101
134
|
"""
|
135
|
+
# TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change
|
136
|
+
# when refactoring the _merge_sharded_checkpoints() method later.
|
137
|
+
if isinstance(checkpoint_file, dict):
|
138
|
+
return checkpoint_file
|
102
139
|
try:
|
103
140
|
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
|
104
141
|
if file_extension == SAFETENSORS_FILE_EXTENSION:
|
@@ -136,29 +173,67 @@ def load_model_dict_into_meta(
|
|
136
173
|
device: Optional[Union[str, torch.device]] = None,
|
137
174
|
dtype: Optional[Union[str, torch.dtype]] = None,
|
138
175
|
model_name_or_path: Optional[str] = None,
|
176
|
+
hf_quantizer=None,
|
177
|
+
keep_in_fp32_modules=None,
|
139
178
|
) -> List[str]:
|
140
|
-
|
179
|
+
if hf_quantizer is None:
|
180
|
+
device = device or torch.device("cpu")
|
141
181
|
dtype = dtype or torch.float32
|
182
|
+
is_quantized = hf_quantizer is not None
|
183
|
+
is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
|
142
184
|
|
143
185
|
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
|
144
|
-
|
145
|
-
unexpected_keys = []
|
146
186
|
empty_state_dict = model.state_dict()
|
187
|
+
unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]
|
188
|
+
|
147
189
|
for param_name, param in state_dict.items():
|
148
190
|
if param_name not in empty_state_dict:
|
149
|
-
unexpected_keys.append(param_name)
|
150
191
|
continue
|
151
192
|
|
193
|
+
set_module_kwargs = {}
|
194
|
+
# We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
|
195
|
+
# in int/uint/bool and not cast them.
|
196
|
+
# TODO: revisit cases when param.dtype == torch.float8_e4m3fn
|
197
|
+
if torch.is_floating_point(param):
|
198
|
+
if (
|
199
|
+
keep_in_fp32_modules is not None
|
200
|
+
and any(
|
201
|
+
module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
|
202
|
+
)
|
203
|
+
and dtype == torch.float16
|
204
|
+
):
|
205
|
+
param = param.to(torch.float32)
|
206
|
+
if accepts_dtype:
|
207
|
+
set_module_kwargs["dtype"] = torch.float32
|
208
|
+
else:
|
209
|
+
param = param.to(dtype)
|
210
|
+
if accepts_dtype:
|
211
|
+
set_module_kwargs["dtype"] = dtype
|
212
|
+
|
213
|
+
# bnb params are flattened.
|
152
214
|
if empty_state_dict[param_name].shape != param.shape:
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
215
|
+
if (
|
216
|
+
is_quant_method_bnb
|
217
|
+
and hf_quantizer.pre_quantized
|
218
|
+
and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
|
219
|
+
):
|
220
|
+
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
|
221
|
+
elif not is_quant_method_bnb:
|
222
|
+
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
|
223
|
+
raise ValueError(
|
224
|
+
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
225
|
+
)
|
226
|
+
|
227
|
+
if is_quantized and (
|
228
|
+
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
|
229
|
+
):
|
230
|
+
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
|
160
231
|
else:
|
161
|
-
|
232
|
+
if accepts_dtype:
|
233
|
+
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
|
234
|
+
else:
|
235
|
+
set_module_tensor_to_device(model, param_name, device, value=param)
|
236
|
+
|
162
237
|
return unexpected_keys
|
163
238
|
|
164
239
|
|
@@ -228,3 +303,96 @@ def _fetch_index_file(
|
|
228
303
|
index_file = None
|
229
304
|
|
230
305
|
return index_file
|
306
|
+
|
307
|
+
|
308
|
+
# Adapted from
|
309
|
+
# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
|
310
|
+
def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
|
311
|
+
weight_map = sharded_metadata.get("weight_map", None)
|
312
|
+
if weight_map is None:
|
313
|
+
raise KeyError("'weight_map' key not found in the shard index file.")
|
314
|
+
|
315
|
+
# Collect all unique safetensors files from weight_map
|
316
|
+
files_to_load = set(weight_map.values())
|
317
|
+
is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
|
318
|
+
merged_state_dict = {}
|
319
|
+
|
320
|
+
# Load tensors from each unique file
|
321
|
+
for file_name in files_to_load:
|
322
|
+
part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
|
323
|
+
if not os.path.exists(part_file_path):
|
324
|
+
raise FileNotFoundError(f"Part file {file_name} not found.")
|
325
|
+
|
326
|
+
if is_safetensors:
|
327
|
+
with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
|
328
|
+
for tensor_key in f.keys():
|
329
|
+
if tensor_key in weight_map:
|
330
|
+
merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
|
331
|
+
else:
|
332
|
+
merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
|
333
|
+
|
334
|
+
return merged_state_dict
|
335
|
+
|
336
|
+
|
337
|
+
def _fetch_index_file_legacy(
|
338
|
+
is_local,
|
339
|
+
pretrained_model_name_or_path,
|
340
|
+
subfolder,
|
341
|
+
use_safetensors,
|
342
|
+
cache_dir,
|
343
|
+
variant,
|
344
|
+
force_download,
|
345
|
+
proxies,
|
346
|
+
local_files_only,
|
347
|
+
token,
|
348
|
+
revision,
|
349
|
+
user_agent,
|
350
|
+
commit_hash,
|
351
|
+
):
|
352
|
+
if is_local:
|
353
|
+
index_file = Path(
|
354
|
+
pretrained_model_name_or_path,
|
355
|
+
subfolder or "",
|
356
|
+
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
|
357
|
+
).as_posix()
|
358
|
+
splits = index_file.split(".")
|
359
|
+
split_index = -3 if ".cache" in index_file else -2
|
360
|
+
splits = splits[:-split_index] + [variant] + splits[-split_index:]
|
361
|
+
index_file = ".".join(splits)
|
362
|
+
if os.path.exists(index_file):
|
363
|
+
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
|
364
|
+
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
|
365
|
+
index_file = Path(index_file)
|
366
|
+
else:
|
367
|
+
index_file = None
|
368
|
+
else:
|
369
|
+
if variant is not None:
|
370
|
+
index_file_in_repo = Path(
|
371
|
+
subfolder or "",
|
372
|
+
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
|
373
|
+
).as_posix()
|
374
|
+
splits = index_file_in_repo.split(".")
|
375
|
+
split_index = -2
|
376
|
+
splits = splits[:-split_index] + [variant] + splits[-split_index:]
|
377
|
+
index_file_in_repo = ".".join(splits)
|
378
|
+
try:
|
379
|
+
index_file = _get_model_file(
|
380
|
+
pretrained_model_name_or_path,
|
381
|
+
weights_name=index_file_in_repo,
|
382
|
+
cache_dir=cache_dir,
|
383
|
+
force_download=force_download,
|
384
|
+
proxies=proxies,
|
385
|
+
local_files_only=local_files_only,
|
386
|
+
token=token,
|
387
|
+
revision=revision,
|
388
|
+
subfolder=None,
|
389
|
+
user_agent=user_agent,
|
390
|
+
commit_hash=commit_hash,
|
391
|
+
)
|
392
|
+
index_file = Path(index_file)
|
393
|
+
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
|
394
|
+
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
|
395
|
+
except (EntryNotFoundError, EnvironmentError):
|
396
|
+
index_file = None
|
397
|
+
|
398
|
+
return index_file
|