diffusers-0.30.3-py3-none-any.whl → diffusers-0.32.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/lora_base.py
CHANGED
@@ -51,6 +51,9 @@ if is_accelerate_available():
|
|
51
51
|
|
52
52
|
logger = logging.get_logger(__name__)
|
53
53
|
|
54
|
+
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
55
|
+
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
56
|
+
|
54
57
|
|
55
58
|
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
56
59
|
"""
|
@@ -181,6 +184,119 @@ def _remove_text_encoder_monkey_patch(text_encoder):
|
|
181
184
|
text_encoder._hf_peft_config_loaded = None
|
182
185
|
|
183
186
|
|
187
|
+
def _fetch_state_dict(
|
188
|
+
pretrained_model_name_or_path_or_dict,
|
189
|
+
weight_name,
|
190
|
+
use_safetensors,
|
191
|
+
local_files_only,
|
192
|
+
cache_dir,
|
193
|
+
force_download,
|
194
|
+
proxies,
|
195
|
+
token,
|
196
|
+
revision,
|
197
|
+
subfolder,
|
198
|
+
user_agent,
|
199
|
+
allow_pickle,
|
200
|
+
):
|
201
|
+
model_file = None
|
202
|
+
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
203
|
+
# Let's first try to load .safetensors weights
|
204
|
+
if (use_safetensors and weight_name is None) or (
|
205
|
+
weight_name is not None and weight_name.endswith(".safetensors")
|
206
|
+
):
|
207
|
+
try:
|
208
|
+
# Here we're relaxing the loading check to enable more Inference API
|
209
|
+
# friendliness where sometimes, it's not at all possible to automatically
|
210
|
+
# determine `weight_name`.
|
211
|
+
if weight_name is None:
|
212
|
+
weight_name = _best_guess_weight_name(
|
213
|
+
pretrained_model_name_or_path_or_dict,
|
214
|
+
file_extension=".safetensors",
|
215
|
+
local_files_only=local_files_only,
|
216
|
+
)
|
217
|
+
model_file = _get_model_file(
|
218
|
+
pretrained_model_name_or_path_or_dict,
|
219
|
+
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
220
|
+
cache_dir=cache_dir,
|
221
|
+
force_download=force_download,
|
222
|
+
proxies=proxies,
|
223
|
+
local_files_only=local_files_only,
|
224
|
+
token=token,
|
225
|
+
revision=revision,
|
226
|
+
subfolder=subfolder,
|
227
|
+
user_agent=user_agent,
|
228
|
+
)
|
229
|
+
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
230
|
+
except (IOError, safetensors.SafetensorError) as e:
|
231
|
+
if not allow_pickle:
|
232
|
+
raise e
|
233
|
+
# try loading non-safetensors weights
|
234
|
+
model_file = None
|
235
|
+
pass
|
236
|
+
|
237
|
+
if model_file is None:
|
238
|
+
if weight_name is None:
|
239
|
+
weight_name = _best_guess_weight_name(
|
240
|
+
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
|
241
|
+
)
|
242
|
+
model_file = _get_model_file(
|
243
|
+
pretrained_model_name_or_path_or_dict,
|
244
|
+
weights_name=weight_name or LORA_WEIGHT_NAME,
|
245
|
+
cache_dir=cache_dir,
|
246
|
+
force_download=force_download,
|
247
|
+
proxies=proxies,
|
248
|
+
local_files_only=local_files_only,
|
249
|
+
token=token,
|
250
|
+
revision=revision,
|
251
|
+
subfolder=subfolder,
|
252
|
+
user_agent=user_agent,
|
253
|
+
)
|
254
|
+
state_dict = load_state_dict(model_file)
|
255
|
+
else:
|
256
|
+
state_dict = pretrained_model_name_or_path_or_dict
|
257
|
+
|
258
|
+
return state_dict
|
259
|
+
|
260
|
+
|
261
|
+
def _best_guess_weight_name(
|
262
|
+
pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
|
263
|
+
):
|
264
|
+
if local_files_only or HF_HUB_OFFLINE:
|
265
|
+
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
|
266
|
+
|
267
|
+
targeted_files = []
|
268
|
+
|
269
|
+
if os.path.isfile(pretrained_model_name_or_path_or_dict):
|
270
|
+
return
|
271
|
+
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
|
272
|
+
targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
|
273
|
+
else:
|
274
|
+
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
|
275
|
+
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
|
276
|
+
if len(targeted_files) == 0:
|
277
|
+
return
|
278
|
+
|
279
|
+
# "scheduler" does not correspond to a LoRA checkpoint.
|
280
|
+
# "optimizer" does not correspond to a LoRA checkpoint
|
281
|
+
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
|
282
|
+
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
|
283
|
+
targeted_files = list(
|
284
|
+
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
|
285
|
+
)
|
286
|
+
|
287
|
+
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
|
288
|
+
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
|
289
|
+
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
|
290
|
+
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
|
291
|
+
|
292
|
+
if len(targeted_files) > 1:
|
293
|
+
raise ValueError(
|
294
|
+
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
|
295
|
+
)
|
296
|
+
weight_name = targeted_files[0]
|
297
|
+
return weight_name
|
298
|
+
|
299
|
+
|
184
300
|
class LoraBaseMixin:
|
185
301
|
"""Utility class for handling LoRAs."""
|
186
302
|
|
@@ -234,124 +350,16 @@ class LoraBaseMixin:
|
|
234
350
|
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
235
351
|
|
236
352
|
@classmethod
|
237
|
-
def _fetch_state_dict(
|
238
|
-
cls,
|
239
|
-
pretrained_model_name_or_path_or_dict,
|
240
|
-
weight_name,
|
241
|
-
use_safetensors,
|
242
|
-
local_files_only,
|
243
|
-
cache_dir,
|
244
|
-
force_download,
|
245
|
-
proxies,
|
246
|
-
token,
|
247
|
-
revision,
|
248
|
-
subfolder,
|
249
|
-
user_agent,
|
250
|
-
allow_pickle,
|
251
|
-
):
|
252
|
-
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
|
253
|
-
|
254
|
-
model_file = None
|
255
|
-
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
256
|
-
# Let's first try to load .safetensors weights
|
257
|
-
if (use_safetensors and weight_name is None) or (
|
258
|
-
weight_name is not None and weight_name.endswith(".safetensors")
|
259
|
-
):
|
260
|
-
try:
|
261
|
-
# Here we're relaxing the loading check to enable more Inference API
|
262
|
-
# friendliness where sometimes, it's not at all possible to automatically
|
263
|
-
# determine `weight_name`.
|
264
|
-
if weight_name is None:
|
265
|
-
weight_name = cls._best_guess_weight_name(
|
266
|
-
pretrained_model_name_or_path_or_dict,
|
267
|
-
file_extension=".safetensors",
|
268
|
-
local_files_only=local_files_only,
|
269
|
-
)
|
270
|
-
model_file = _get_model_file(
|
271
|
-
pretrained_model_name_or_path_or_dict,
|
272
|
-
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
273
|
-
cache_dir=cache_dir,
|
274
|
-
force_download=force_download,
|
275
|
-
proxies=proxies,
|
276
|
-
local_files_only=local_files_only,
|
277
|
-
token=token,
|
278
|
-
revision=revision,
|
279
|
-
subfolder=subfolder,
|
280
|
-
user_agent=user_agent,
|
281
|
-
)
|
282
|
-
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
283
|
-
except (IOError, safetensors.SafetensorError) as e:
|
284
|
-
if not allow_pickle:
|
285
|
-
raise e
|
286
|
-
# try loading non-safetensors weights
|
287
|
-
model_file = None
|
288
|
-
pass
|
289
|
-
|
290
|
-
if model_file is None:
|
291
|
-
if weight_name is None:
|
292
|
-
weight_name = cls._best_guess_weight_name(
|
293
|
-
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
|
294
|
-
)
|
295
|
-
model_file = _get_model_file(
|
296
|
-
pretrained_model_name_or_path_or_dict,
|
297
|
-
weights_name=weight_name or LORA_WEIGHT_NAME,
|
298
|
-
cache_dir=cache_dir,
|
299
|
-
force_download=force_download,
|
300
|
-
proxies=proxies,
|
301
|
-
local_files_only=local_files_only,
|
302
|
-
token=token,
|
303
|
-
revision=revision,
|
304
|
-
subfolder=subfolder,
|
305
|
-
user_agent=user_agent,
|
306
|
-
)
|
307
|
-
state_dict = load_state_dict(model_file)
|
308
|
-
else:
|
309
|
-
state_dict = pretrained_model_name_or_path_or_dict
|
310
|
-
|
311
|
-
return state_dict
|
353
|
+
def _fetch_state_dict(cls, *args, **kwargs):
|
354
|
+
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
|
355
|
+
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
|
356
|
+
return _fetch_state_dict(*args, **kwargs)
|
312
357
|
|
313
358
|
@classmethod
|
314
|
-
def _best_guess_weight_name(
|
315
|
-
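cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
|
316
|
-
):
|
317
|
-
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE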
|
318
|
-
|
319
|
-
if local_files_only or HF_HUB_OFFLINE:
|
320
|
-
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
|
321
|
-
|
322
|
-
targeted_files = []
|
323
|
-
|
324
|
-
if os.path.isfile(pretrained_model_name_or_path_or_dict):
|
325
|
-
return
|
326
|
-
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
|
327
|
-
targeted_files = [
|
328
|
-
f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
|
329
|
-
]
|
330
|
-
else:
|
331
|
-
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
|
332
|
-
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
|
333
|
-
if len(targeted_files) == 0:
|
334
|
-
return
|
335
|
-
|
336
|
-
# "scheduler" does not correspond to a LoRA checkpoint.
|
337
|
-
# "optimizer" does not correspond to a LoRA checkpoint
|
338
|
-
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
|
339
|
-
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
|
340
|
-
targeted_files = list(
|
341
|
-
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
|
342
|
-
)
|
343
|
-
|
344
|
-
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
|
345
|
-
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
|
346
|
-
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
|
347
|
-
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
|
348
|
-
|
349
|
-
if len(targeted_files) > 1:
|
350
|
-
raise ValueError(
|
351
|
-
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
|
352
|
-
)
|
353
|
-
weight_name = targeted_files[0]
|
354
|
-
return weight_name
|
359
|
+
def _best_guess_weight_name(cls, *args, **kwargs):
|
360
|
+
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
|
361
|
+
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
|
362
|
+
return _best_guess_weight_name(*args, **kwargs)
|
355
363
|
|
356
364
|
def unload_lora_weights(self):
|
357
365
|
"""
|
@@ -532,13 +540,19 @@ class LoraBaseMixin:
|
|
532
540
|
)
|
533
541
|
|
534
542
|
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
|
535
|
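-
all_adapters = {
|
536
|
-
adapter for adapters in list_adapters.values() for adapter in adapters
|
537
|
-
}  # eg ["adapter1", "adapter2"]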
|
543
|
+
# eg ["adapter1", "adapter2"]
|
544
|
+
all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
|
545
|
+
missing_adapters = set(adapter_names) - all_adapters
|
546
|
+
if len(missing_adapters) > 0:
|
547
|
+
raise ValueError(
|
548
|
+
f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
|
549
|
+
)
|
550
|
+
|
551
|
+
# eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
|
538
552
|
invert_list_adapters = {
|
539
553
|
adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
|
540
554
|
for adapter in all_adapters
|
541
|
-
}  # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
|
555
|
+
}
|
542
556
|
|
543
557
|
# Decompose weights into weights for denoiser and text encoders.
|
544
558
|
_component_adapter_weights = {}
|
@@ -699,9 +713,10 @@ class LoraBaseMixin:
|
|
699
713
|
module.lora_B[adapter_name].to(device)
|
700
714
|
# this is a param, not a module, so device placement is not in-place -> re-assign
|
701
715
|
if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
|
702
|
-
module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
|
703
|
-
adapter_name
|
704
|
-
].to(device)
|
716
|
+
if adapter_name in module.lora_magnitude_vector:
|
717
|
+
module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
|
718
|
+
adapter_name
|
719
|
+
].to(device)
|
705
720
|
|
706
721
|
@staticmethod
|
707
722
|
def pack_weights(layers, prefix):
|
@@ -718,8 +733,6 @@ class LoraBaseMixin:
|
|
718
733
|
save_function: Callable,
|
719
734
|
safe_serialization: bool,
|
720
735
|
):
|
721
|
-
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
|
722
|
-
|
723
736
|
if os.path.isfile(save_directory):
|
724
737
|
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
725
738
|
return
|