diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +72 -26
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
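The hunks below are from diffusers/loaders/lora_pipeline.py, where the 0.32 LoRA loaders gain a `low_cpu_mem_usage` option and route loading through `load_lora_adapter`. As a rough sketch of the user-facing effect (the LoRA repo id below is a placeholder, not taken from this diff; `low_cpu_mem_usage=True` needs a recent `peft`, >= 0.13.1 per the checks in the hunks):

# Illustrative only: the LoRA repo id is a placeholder.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Only the pretrained LoRA weights are materialized; random init is skipped.
pipe.load_lora_weights(
    "some-user/some-sdxl-lora",  # placeholder repo id
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="default_0",
    low_cpu_mem_usage=True,
)
image = pipe("a prompt", num_inference_steps=30).images[0]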
diffusers/loaders/lora_pipeline.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import os
 from typing import Callable, Dict, List, Optional, Union
 
@@ -21,7 +22,6 @@ from ..utils import (
     USE_PEFT_BACKEND,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
-    convert_unet_state_dict_to_peft,
     deprecate,
     get_adapter_name,
     get_peft_kwargs,
@@ -33,8 +33,9 @@ from ..utils import (
     logging,
     scale_lora_layers,
 )
-from .lora_base import LoraBaseMixin
+from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict  # noqa
 from .lora_conversion_utils import (
+    _convert_bfl_flux_control_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
@@ -62,8 +63,7 @@ TEXT_ENCODER_NAME = "text_encoder"
 UNET_NAME = "unet"
 TRANSFORMER_NAME = "transformer"
 
-
-LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
 
 
 class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
@@ -222,7 +222,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             "framework": "pytorch",
         }
 
-        state_dict =
+        state_dict = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -282,7 +282,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
-
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -300,8 +302,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         if not only_text_encoder:
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
-            unet.
+            unet.load_lora_adapter(
                 state_dict,
+                prefix=cls.unet_name,
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
                 _pipeline=_pipeline,
@@ -341,7 +344,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
-
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -407,6 +412,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                     }
 
                 lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
                 if "use_dora" in lora_config_kwargs:
                     if lora_config_kwargs["use_dora"]:
                         if is_peft_version("<", "0.9.0"):
@@ -416,6 +422,17 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                     else:
                         if is_peft_version("<", "0.9.0"):
                             lora_config_kwargs.pop("use_dora")
+
+                if "lora_bias" in lora_config_kwargs:
+                    if lora_config_kwargs["lora_bias"]:
+                        if is_peft_version("<=", "0.13.2"):
+                            raise ValueError(
+                                "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                            )
+                    else:
+                        if is_peft_version("<=", "0.13.2"):
+                            lora_config_kwargs.pop("lora_bias")
+
                 lora_config = LoraConfig(**lora_config_kwargs)
 
                 # adapter_name
@@ -601,7 +618,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
-
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
             kwargs (`dict`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
@@ -744,7 +763,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             "framework": "pytorch",
         }
 
-        state_dict =
+        state_dict = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -805,7 +824,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
-
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -823,8 +844,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         if not only_text_encoder:
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
-            unet.
+            unet.load_lora_adapter(
                 state_dict,
+                prefix=cls.unet_name,
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
                 _pipeline=_pipeline,
@@ -865,7 +887,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
-
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -931,6 +955,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                     }
 
                 lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
                 if "use_dora" in lora_config_kwargs:
                     if lora_config_kwargs["use_dora"]:
                         if is_peft_version("<", "0.9.0"):
@@ -940,6 +965,17 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                     else:
                         if is_peft_version("<", "0.9.0"):
                             lora_config_kwargs.pop("use_dora")
+
+                if "lora_bias" in lora_config_kwargs:
+                    if lora_config_kwargs["lora_bias"]:
+                        if is_peft_version("<=", "0.13.2"):
+                            raise ValueError(
+                                "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                            )
+                    else:
+                        if is_peft_version("<=", "0.13.2"):
+                            lora_config_kwargs.pop("lora_bias")
+
                 lora_config = LoraConfig(**lora_config_kwargs)
 
                 # adapter_name
@@ -1182,7 +1218,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             "framework": "pytorch",
         }
 
-        state_dict =
+        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -1226,7 +1262,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
-
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
             kwargs (`dict`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
@@ -1250,13 +1288,17 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-
-
-
-
-
-
-
+        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
+        if len(transformer_state_dict) > 0:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
 
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
@@ -1301,94 +1343,24 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
-
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
             )
 
-
-
-
-
-
-
-
-
-
-        if len(state_dict.keys()) > 0:
-            # check with first key if is not in peft format
-            first_key = next(iter(state_dict.keys()))
-            if "lora_A" not in first_key:
-                state_dict = convert_unet_state_dict_to_peft(state_dict)
-
-            if adapter_name in getattr(transformer, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
-                )
-
-            rank = {}
-            for key, val in state_dict.items():
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
-
-            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
-                else:
-                    lora_config_kwargs.pop("use_dora")
-            lora_config = LoraConfig(**lora_config_kwargs)
-
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(transformer)
-
-            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-            peft_kwargs = {}
-            if is_peft_version(">=", "0.13.1"):
-                peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
-            inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
-            incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
-
-            warn_msg = ""
-            if incompatible_keys is not None:
-                # Check only for unexpected keys.
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
-                    if lora_unexpected_keys:
-                        warn_msg = (
-                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
-                            f" {', '.join(lora_unexpected_keys)}. "
-                        )
-
-                # Filter missing keys specific to the current adapter.
-                missing_keys = getattr(incompatible_keys, "missing_keys", None)
-                if missing_keys:
-                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
-                    if lora_missing_keys:
-                        warn_msg += (
-                            f"Loading adapter weights from state_dict led to missing keys in the model:"
-                            f" {', '.join(lora_missing_keys)}."
-                        )
-
-            if warn_msg:
-                logger.warning(warn_msg)
-
-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-        # Unsafe code />
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -1424,7 +1396,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
-
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -1490,6 +1464,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                     }
 
                 lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
                 if "use_dora" in lora_config_kwargs:
                     if lora_config_kwargs["use_dora"]:
                         if is_peft_version("<", "0.9.0"):
@@ -1499,6 +1474,17 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                     else:
                         if is_peft_version("<", "0.9.0"):
                             lora_config_kwargs.pop("use_dora")
+
+                if "lora_bias" in lora_config_kwargs:
+                    if lora_config_kwargs["lora_bias"]:
+                        if is_peft_version("<=", "0.13.2"):
+                            raise ValueError(
+                                "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                            )
+                    else:
+                        if is_peft_version("<=", "0.13.2"):
+                            lora_config_kwargs.pop("lora_bias")
+
                 lora_config = LoraConfig(**lora_config_kwargs)
 
                 # adapter_name
@@ -1666,6 +1652,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
     _lora_loadable_modules = ["transformer", "text_encoder"]
     transformer_name = TRANSFORMER_NAME
     text_encoder_name = TEXT_ENCODER_NAME
+    _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
 
     @classmethod
     @validate_hf_hub_args
@@ -1742,7 +1729,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             "framework": "pytorch",
         }
 
-        state_dict =
+        state_dict = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -1775,6 +1762,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             # xlabs doesn't use `alpha`.
             return (state_dict, None) if return_alphas else state_dict
 
+        is_bfl_control = any("query_norm.scale" in k for k in state_dict)
+        if is_bfl_control:
+            state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
+            return (state_dict, None) if return_alphas else state_dict
+
         # For state dicts like
         # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
         keys = list(state_dict.keys())
@@ -1819,7 +1811,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
-
+            low_cpu_mem_usage (`bool`, *optional*):
+                `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -1839,19 +1833,57 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
         )
 
-
-
+        has_lora_keys = any("lora" in key for key in state_dict.keys())
+
+        # Flux Control LoRAs also have norm keys
+        has_norm_keys = any(
+            norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys
+        )
+
+        if not (has_lora_keys or has_norm_keys):
             raise ValueError("Invalid LoRA checkpoint.")
 
-
-        state_dict
-
-
-
-
+        transformer_lora_state_dict = {
+            k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
+        }
+        transformer_norm_state_dict = {
+            k: state_dict.pop(k)
+            for k in list(state_dict.keys())
+            if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
+        }
+
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
+            transformer, transformer_lora_state_dict, transformer_norm_state_dict
+        )
+
+        if has_param_with_expanded_shape:
+            logger.info(
+                "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
+                "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
+                "To get a comprehensive list of parameter names that were modified, enable debug logging."
+            )
+        transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
+            transformer=transformer, lora_state_dict=transformer_lora_state_dict
         )
 
+        if len(transformer_lora_state_dict) > 0:
+            self.load_lora_into_transformer(
+                transformer_lora_state_dict,
+                network_alphas=network_alphas,
+                transformer=transformer,
+                adapter_name=adapter_name,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
+
+        if len(transformer_norm_state_dict) > 0:
+            transformer._transformer_norm_layers = self._load_norm_into_transformer(
+                transformer_norm_state_dict,
+                transformer=transformer,
+                discard_original_layers=False,
+            )
+
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
@@ -1881,104 +1913,86 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 The value of the network alpha used for stable learning and preventing underflow. This value has the
                 same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                 link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
-            transformer (`
+            transformer (`FluxTransformer2DModel`):
                 The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
-
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
             raise ValueError(
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )
 
-
-
+        # Load the layers corresponding to transformer.
         keys = list(state_dict.keys())
+        transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
+        if transformer_present:
+            logger.info(f"Loading {cls.transformer_name}.")
+            transformer.load_lora_adapter(
+                state_dict,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
 
-
-
-
-
-
-
-
-
-
-
-
-        if
-
-
+    @classmethod
+    def _load_norm_into_transformer(
+        cls,
+        state_dict,
+        transformer,
+        prefix=None,
+        discard_original_layers=False,
+    ) -> Dict[str, torch.Tensor]:
+        # Remove prefix if present
+        prefix = prefix or cls.transformer_name
+        for key in list(state_dict.keys()):
+            if key.split(".")[0] == prefix:
+                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+
+        # Find invalid keys
+        transformer_state_dict = transformer.state_dict()
+        transformer_keys = set(transformer_state_dict.keys())
+        state_dict_keys = set(state_dict.keys())
+        extra_keys = list(state_dict_keys - transformer_keys)
+
+        if extra_keys:
+            logger.warning(
+                f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}."
+            )
 
-
-
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
+        for key in extra_keys:
+            state_dict.pop(key)
 
-
-
-
+        # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected
+        overwritten_layers_state_dict = {}
+        if not discard_original_layers:
+            for key in state_dict.keys():
+                overwritten_layers_state_dict[key] = transformer_state_dict[key].clone()
 
-
-
-
-
-
-
-                else:
-                    lora_config_kwargs.pop("use_dora")
-            lora_config = LoraConfig(**lora_config_kwargs)
-
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(transformer)
-
-            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-            peft_kwargs = {}
-            if is_peft_version(">=", "0.13.1"):
-                peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
-            inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
-            incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
-
-            warn_msg = ""
-            if incompatible_keys is not None:
-                # Check only for unexpected keys.
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
-                    if lora_unexpected_keys:
-                        warn_msg = (
-                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
-                            f" {', '.join(lora_unexpected_keys)}. "
-                        )
+            logger.info(
+                "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
+                'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly '
+                "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. "
+                "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues."
+            )
 
-
-
-
-                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
-                    if lora_missing_keys:
-                        warn_msg += (
-                            f"Loading adapter weights from state_dict led to missing keys in the model:"
-                            f" {', '.join(lora_missing_keys)}."
-                        )
+        # We can't load with strict=True because the current state_dict does not contain all the transformer keys
+        incompatible_keys = transformer.load_state_dict(state_dict, strict=False)
+        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
 
-
-
+        # We shouldn't expect to see the supported norm keys here being present in the unexpected keys.
+        if unexpected_keys:
+            if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys):
+                raise ValueError(
+                    f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer."
+                )
 
-
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-        # Unsafe code />
+        return overwritten_layers_state_dict
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2014,7 +2028,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
-
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -2080,6 +2096,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                     }
 
                 lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
                 if "use_dora" in lora_config_kwargs:
                     if lora_config_kwargs["use_dora"]:
                         if is_peft_version("<", "0.9.0"):
@@ -2089,6 +2106,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                     else:
                         if is_peft_version("<", "0.9.0"):
                             lora_config_kwargs.pop("use_dora")
+
+                if "lora_bias" in lora_config_kwargs:
+                    if lora_config_kwargs["lora_bias"]:
+                        if is_peft_version("<=", "0.13.2"):
+                            raise ValueError(
+                                "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                            )
+                    else:
+                        if is_peft_version("<=", "0.13.2"):
+                            lora_config_kwargs.pop("lora_bias")
+
                 lora_config = LoraConfig(**lora_config_kwargs)
 
                 # adapter_name
@@ -2173,7 +2201,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             safe_serialization=safe_serialization,
         )
 
-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
     def fuse_lora(
         self,
         components: List[str] = ["transformer", "text_encoder"],
@@ -2213,6 +2240,19 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
        pipeline.fuse_lora(lora_scale=0.7)
         ```
         """
+
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        if (
+            hasattr(transformer, "_transformer_norm_layers")
+            and isinstance(transformer._transformer_norm_layers, dict)
+            and len(transformer._transformer_norm_layers.keys()) > 0
+        ):
+            logger.info(
+                "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer "
+                "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly "
+                "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed."
+            )
+
         super().fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )
@@ -2231,8 +2271,168 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         Args:
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
         """
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
+            transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
+
         super().unfuse_lora(components=components)
 
+    # We override this here account for `_transformer_norm_layers`.
+    def unload_lora_weights(self):
+        super().unload_lora_weights()
+
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
+            transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
+            transformer._transformer_norm_layers = None
+
+    @classmethod
+    def _maybe_expand_transformer_param_shape_or_error_(
+        cls,
+        transformer: torch.nn.Module,
+        lora_state_dict=None,
+        norm_state_dict=None,
+        prefix=None,
+    ) -> bool:
+        """
+        Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and
+        generalizes things a bit so that any parameter that needs expansion receives appropriate treatement.
+        """
+        state_dict = {}
+        if lora_state_dict is not None:
+            state_dict.update(lora_state_dict)
+        if norm_state_dict is not None:
+            state_dict.update(norm_state_dict)
+
+        # Remove prefix if present
+        prefix = prefix or cls.transformer_name
+        for key in list(state_dict.keys()):
+            if key.split(".")[0] == prefix:
+                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+
+        # Expand transformer parameter shapes if they don't match lora
+        has_param_with_shape_update = False
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        for name, module in transformer.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                module_weight = module.weight.data
+                module_bias = module.bias.data if module.bias is not None else None
+                bias = module_bias is not None
+
+                lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
+                lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
+                lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
+                if lora_A_weight_name not in state_dict:
+                    continue
+
+                in_features = state_dict[lora_A_weight_name].shape[1]
+                out_features = state_dict[lora_B_weight_name].shape[0]
+
+                # This means there's no need for an expansion in the params, so we simply skip.
+                if tuple(module_weight.shape) == (out_features, in_features):
+                    continue
+
+                module_out_features, module_in_features = module_weight.shape
+                debug_message = ""
+                if in_features > module_in_features:
+                    debug_message += (
+                        f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
+                        f"checkpoint contains higher number of features than expected. The number of input_features will be "
+                        f"expanded from {module_in_features} to {in_features}"
+                    )
+                if out_features > module_out_features:
+                    debug_message += (
+                        ", and the number of output features will be "
+                        f"expanded from {module_out_features} to {out_features}."
+                    )
+                else:
+                    debug_message += "."
+                if debug_message:
+                    logger.debug(debug_message)
+
+                if out_features > module_out_features or in_features > module_in_features:
+                    has_param_with_shape_update = True
+                    parent_module_name, _, current_module_name = name.rpartition(".")
+                    parent_module = transformer.get_submodule(parent_module_name)
+
+                    with torch.device("meta"):
+                        expanded_module = torch.nn.Linear(
+                            in_features, out_features, bias=bias, dtype=module_weight.dtype
+                        )
+                    # Only weights are expanded and biases are not. This is because only the input dimensions
+                    # are changed while the output dimensions remain the same. The shape of the weight tensor
+                    # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
+                    # explains the reason why only weights are expanded.
+                    new_weight = torch.zeros_like(
+                        expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+                    )
+                    slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                    new_weight[slices] = module_weight
+                    tmp_state_dict = {"weight": new_weight}
+                    if module_bias is not None:
+                        tmp_state_dict["bias"] = module_bias
+                    expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)
+
+                    setattr(parent_module, current_module_name, expanded_module)
+
+                    del tmp_state_dict
+
+                    if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+                        attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+                        new_value = int(expanded_module.weight.data.shape[1])
+                        old_value = getattr(transformer.config, attribute_name)
+                        setattr(transformer.config, attribute_name, new_value)
+                        logger.info(
+                            f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+                        )
+
+        return has_param_with_shape_update
+
+    @classmethod
+    def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
+        expanded_module_names = set()
+        transformer_state_dict = transformer.state_dict()
+        prefix = f"{cls.transformer_name}."
+
+        lora_module_names = [
+            key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
+        ]
+        lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
+        lora_module_names = sorted(set(lora_module_names))
+        transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
+        unexpected_modules = set(lora_module_names) - set(transformer_module_names)
+        if unexpected_modules:
+            logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
+
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        for k in lora_module_names:
+            if k in unexpected_modules:
+                continue
+
+            base_param_name = (
+                f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
+            )
+            base_weight_param = transformer_state_dict[base_param_name]
+            lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
+
+            if base_weight_param.shape[1] > lora_A_param.shape[1]:
+                shape = (lora_A_param.shape[0], base_weight_param.shape[1])
+                expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
+                expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
+                lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
+                expanded_module_names.add(k)
+            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+                raise NotImplementedError(
+                    f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
+                )
+
+        if expanded_module_names:
+            logger.info(
+                f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
+            )
+
+        return lora_state_dict
+
 
 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
@@ -2242,7 +2442,10 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     text_encoder_name = TEXT_ENCODER_NAME
 
     @classmethod
-
+    # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
+    def load_lora_into_transformer(
+        cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
 
@@ -2255,93 +2458,32 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
                 The value of the network alpha used for stable learning and preventing underflow. This value has the
                 same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                 link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
-
-                The
+            transformer (`UVit2DModel`):
+                The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
-        if not
-            raise ValueError(
-
-
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
 
+        # Load the layers corresponding to transformer.
         keys = list(state_dict.keys())
-
-
-
-
-
-
-
-
-
-
-            }
-
-        if len(state_dict.keys()) > 0:
-            if adapter_name in getattr(transformer, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
-                )
-
-            rank = {}
-            for key, val in state_dict.items():
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
-
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
-                else:
-                    lora_config_kwargs.pop("use_dora")
-            lora_config = LoraConfig(**lora_config_kwargs)
-
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(transformer)
-
-            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-            inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
-            incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
-
-            warn_msg = ""
-            if incompatible_keys is not None:
-                # Check only for unexpected keys.
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
-                    if lora_unexpected_keys:
-                        warn_msg = (
-                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
-                            f" {', '.join(lora_unexpected_keys)}. "
-                        )
-
-                # Filter missing keys specific to the current adapter.
-                missing_keys = getattr(incompatible_keys, "missing_keys", None)
-                if missing_keys:
-                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
-                    if lora_missing_keys:
-                        warn_msg += (
-                            f"Loading adapter weights from state_dict led to missing keys in the model:"
-                            f" {', '.join(lora_missing_keys)}."
-                        )
-
-            if warn_msg:
-                logger.warning(warn_msg)
-
-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-            # Unsafe code />
+        transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
+        if transformer_present:
+            logger.info(f"Loading {cls.transformer_name}.")
+            transformer.load_lora_adapter(
+                state_dict,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
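The rewritten method above no longer injects and validates the adapter itself; it only checks for transformer-prefixed keys and hands the whole state dict to the model-level `load_lora_adapter` entry point. A minimal, self-contained sketch of that delegation pattern, using hypothetical names rather than the real diffusers classes:

```py
# Sketch of the "check prefix, then delegate" pattern introduced in this diff.
# `DummyTransformer` and the key below are placeholders, not diffusers APIs.
def load_lora_into_transformer(state_dict, transformer, transformer_name="transformer"):
    keys = list(state_dict.keys())
    if any(key.startswith(transformer_name) for key in keys):
        transformer.load_lora_adapter(state_dict)


class DummyTransformer:
    def load_lora_adapter(self, state_dict):
        print(f"received {len(state_dict)} LoRA tensors")


load_lora_into_transformer({"transformer.blocks.0.lora_A.weight": None}, DummyTransformer())
```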
@@ -2377,7 +2519,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
-
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -2443,6 +2587,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
             }
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
                     if is_peft_version("<", "0.9.0"):
@@ -2452,6 +2597,17 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
                 else:
                     if is_peft_version("<", "0.9.0"):
                         lora_config_kwargs.pop("use_dora")
+
+            if "lora_bias" in lora_config_kwargs:
+                if lora_config_kwargs["lora_bias"]:
+                    if is_peft_version("<=", "0.13.2"):
+                        raise ValueError(
+                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<=", "0.13.2"):
+                        lora_config_kwargs.pop("lora_bias")
+
             lora_config = LoraConfig(**lora_config_kwargs)
 
             # adapter_name
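The new `lora_bias` gate follows the same pattern as the existing DoRA check: reject the flag on older `peft` releases when it is set, silently drop it when it is not. A standalone sketch of that pattern, assuming `is_peft_version` is importable from `diffusers.utils` and that `peft` is installed:

```py
from diffusers.utils import is_peft_version  # requires `peft` to be installed

lora_config_kwargs = {"r": 16, "lora_alpha": 16, "lora_bias": False}

# Same gating idea as the hunk above: `lora_bias` needs peft >= 0.14.0, so on
# older versions the flag is either rejected (if True) or removed (if False).
if "lora_bias" in lora_config_kwargs:
    if lora_config_kwargs["lora_bias"]:
        if is_peft_version("<=", "0.13.2"):
            raise ValueError("You need `peft` 0.14.0 at least to use `bias` in LoRAs.")
    else:
        if is_peft_version("<=", "0.13.2"):
            lora_config_kwargs.pop("lora_bias")

print(lora_config_kwargs)
```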
@@ -2538,7 +2694,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
 
 class CogVideoXLoraLoaderMixin(LoraBaseMixin):
     r"""
-    Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`
+    Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`].
     """
 
     _lora_loadable_modules = ["transformer"]
@@ -2619,7 +2775,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
             "framework": "pytorch",
         }
 
-        state_dict =
+        state_dict = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
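For reference, the CogVideoX loader touched above is normally exercised through the pipeline-level API rather than by calling `_fetch_state_dict` directly. A hedged usage sketch; the model id and LoRA file path are placeholders:

```py
import torch
from diffusers import CogVideoXPipeline

# Placeholder identifiers; substitute your own base model and LoRA checkpoint.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("path/to/cogvideox_lora.safetensors", adapter_name="my_lora")

# `lora_state_dict` can also be called on its own to inspect the resolved weights.
state_dict = pipe.lora_state_dict("path/to/cogvideox_lora.safetensors")
print(len(state_dict), "LoRA tensors")
```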
@@ -2658,7 +2814,9 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
-
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
             kwargs (`dict`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
@@ -2691,7 +2849,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
         )
 
     @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
     def load_lora_into_transformer(
         cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
     ):
@@ -2703,99 +2861,29 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
-            transformer (`
+            transformer (`CogVideoXTransformer3DModel`):
                 The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
-
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
             )
 
-
-
-
-
-
-
-
-
-
-        if len(state_dict.keys()) > 0:
-            # check with first key if is not in peft format
-            first_key = next(iter(state_dict.keys()))
-            if "lora_A" not in first_key:
-                state_dict = convert_unet_state_dict_to_peft(state_dict)
-
-            if adapter_name in getattr(transformer, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
-                )
-
-            rank = {}
-            for key, val in state_dict.items():
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
-
-            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
-                else:
-                    lora_config_kwargs.pop("use_dora")
-            lora_config = LoraConfig(**lora_config_kwargs)
-
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(transformer)
-
-            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-            peft_kwargs = {}
-            if is_peft_version(">=", "0.13.1"):
-                peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
-            inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
-            incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
-
-            warn_msg = ""
-            if incompatible_keys is not None:
-                # Check only for unexpected keys.
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
-                    if lora_unexpected_keys:
-                        warn_msg = (
-                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
-                            f" {', '.join(lora_unexpected_keys)}. "
-                        )
-
-                # Filter missing keys specific to the current adapter.
-                missing_keys = getattr(incompatible_keys, "missing_keys", None)
-                if missing_keys:
-                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
-                    if lora_missing_keys:
-                        warn_msg += (
-                            f"Loading adapter weights from state_dict led to missing keys in the model:"
-                            f" {', '.join(lora_missing_keys)}."
-                        )
-
-            if warn_msg:
-                logger.warning(warn_msg)
-
-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-            # Unsafe code />
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
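The large addition that follows introduces analogous loader mixins for the new image and video transformers (Mochi, LTX-Video, Sana, HunyuanVideo), each delegating to `load_lora_adapter` in the same way as above. A hedged usage sketch for one of them; the model id and LoRA path are placeholders:

```py
import torch
from diffusers import MochiPipeline

# Placeholder identifiers; any Mochi-format LoRA saved with `save_lora_weights` should work.
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("path/to/mochi_lora.safetensors", adapter_name="my_lora")

# Optionally bake the adapter into the base weights, as documented in the fuse_lora docstring.
pipe.fuse_lora(components=["transformer"], lora_scale=1.0)
```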
@@ -2911,6 +2999,1238 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
|
2911
2999
|
super().unfuse_lora(components=components)
|
2912
3000
|
|
2913
3001
|
|
3002
|
+
class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
3003
|
+
r"""
|
3004
|
+
Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`].
|
3005
|
+
"""
|
3006
|
+
|
3007
|
+
_lora_loadable_modules = ["transformer"]
|
3008
|
+
transformer_name = TRANSFORMER_NAME
|
3009
|
+
|
3010
|
+
@classmethod
|
3011
|
+
@validate_hf_hub_args
|
3012
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
3013
|
+
def lora_state_dict(
|
3014
|
+
cls,
|
3015
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
3016
|
+
**kwargs,
|
3017
|
+
):
|
3018
|
+
r"""
|
3019
|
+
Return state dict for lora weights and the network alphas.
|
3020
|
+
|
3021
|
+
<Tip warning={true}>
|
3022
|
+
|
3023
|
+
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
3024
|
+
|
3025
|
+
This function is experimental and might change in the future.
|
3026
|
+
|
3027
|
+
</Tip>
|
3028
|
+
|
3029
|
+
Parameters:
|
3030
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
3031
|
+
Can be either:
|
3032
|
+
|
3033
|
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
3034
|
+
the Hub.
|
3035
|
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
3036
|
+
with [`ModelMixin.save_pretrained`].
|
3037
|
+
- A [torch state
|
3038
|
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
3039
|
+
|
3040
|
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
3041
|
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
3042
|
+
is not used.
|
3043
|
+
force_download (`bool`, *optional*, defaults to `False`):
|
3044
|
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
3045
|
+
cached versions if they exist.
|
3046
|
+
|
3047
|
+
proxies (`Dict[str, str]`, *optional*):
|
3048
|
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
3049
|
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
3050
|
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
3051
|
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
3052
|
+
won't be downloaded from the Hub.
|
3053
|
+
token (`str` or *bool*, *optional*):
|
3054
|
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
3055
|
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
3056
|
+
revision (`str`, *optional*, defaults to `"main"`):
|
3057
|
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
3058
|
+
allowed by Git.
|
3059
|
+
subfolder (`str`, *optional*, defaults to `""`):
|
3060
|
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
3061
|
+
|
3062
|
+
"""
|
3063
|
+
# Load the main state dict first which has the LoRA layers for either of
|
3064
|
+
# transformer and text encoder or both.
|
3065
|
+
cache_dir = kwargs.pop("cache_dir", None)
|
3066
|
+
force_download = kwargs.pop("force_download", False)
|
3067
|
+
proxies = kwargs.pop("proxies", None)
|
3068
|
+
local_files_only = kwargs.pop("local_files_only", None)
|
3069
|
+
token = kwargs.pop("token", None)
|
3070
|
+
revision = kwargs.pop("revision", None)
|
3071
|
+
subfolder = kwargs.pop("subfolder", None)
|
3072
|
+
weight_name = kwargs.pop("weight_name", None)
|
3073
|
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
3074
|
+
|
3075
|
+
allow_pickle = False
|
3076
|
+
if use_safetensors is None:
|
3077
|
+
use_safetensors = True
|
3078
|
+
allow_pickle = True
|
3079
|
+
|
3080
|
+
user_agent = {
|
3081
|
+
"file_type": "attn_procs_weights",
|
3082
|
+
"framework": "pytorch",
|
3083
|
+
}
|
3084
|
+
|
3085
|
+
state_dict = _fetch_state_dict(
|
3086
|
+
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
3087
|
+
weight_name=weight_name,
|
3088
|
+
use_safetensors=use_safetensors,
|
3089
|
+
local_files_only=local_files_only,
|
3090
|
+
cache_dir=cache_dir,
|
3091
|
+
force_download=force_download,
|
3092
|
+
proxies=proxies,
|
3093
|
+
token=token,
|
3094
|
+
revision=revision,
|
3095
|
+
subfolder=subfolder,
|
3096
|
+
user_agent=user_agent,
|
3097
|
+
allow_pickle=allow_pickle,
|
3098
|
+
)
|
3099
|
+
|
3100
|
+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
3101
|
+
if is_dora_scale_present:
|
3102
|
+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
3103
|
+
logger.warning(warn_msg)
|
3104
|
+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
3105
|
+
|
3106
|
+
return state_dict
|
3107
|
+
|
3108
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
3109
|
+
def load_lora_weights(
|
3110
|
+
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
3111
|
+
):
|
3112
|
+
"""
|
3113
|
+
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
3114
|
+
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
3115
|
+
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
3116
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
3117
|
+
dict is loaded into `self.transformer`.
|
3118
|
+
|
3119
|
+
Parameters:
|
3120
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
3121
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
3122
|
+
adapter_name (`str`, *optional*):
|
3123
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
3124
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
3125
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
3126
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3127
|
+
weights.
|
3128
|
+
kwargs (`dict`, *optional*):
|
3129
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
3130
|
+
"""
|
3131
|
+
if not USE_PEFT_BACKEND:
|
3132
|
+
raise ValueError("PEFT backend is required for this method.")
|
3133
|
+
|
3134
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
3135
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3136
|
+
raise ValueError(
|
3137
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
3138
|
+
)
|
3139
|
+
|
3140
|
+
# if a dict is passed, copy it instead of modifying it inplace
|
3141
|
+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
3142
|
+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
3143
|
+
|
3144
|
+
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
3145
|
+
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
3146
|
+
|
3147
|
+
is_correct_format = all("lora" in key for key in state_dict.keys())
|
3148
|
+
if not is_correct_format:
|
3149
|
+
raise ValueError("Invalid LoRA checkpoint.")
|
3150
|
+
|
3151
|
+
self.load_lora_into_transformer(
|
3152
|
+
state_dict,
|
3153
|
+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
3154
|
+
adapter_name=adapter_name,
|
3155
|
+
_pipeline=self,
|
3156
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
3157
|
+
)
|
3158
|
+
|
3159
|
+
@classmethod
|
3160
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
|
3161
|
+
def load_lora_into_transformer(
|
3162
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
3163
|
+
):
|
3164
|
+
"""
|
3165
|
+
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
3166
|
+
|
3167
|
+
Parameters:
|
3168
|
+
state_dict (`dict`):
|
3169
|
+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
3170
|
+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
3171
|
+
encoder lora layers.
|
3172
|
+
transformer (`MochiTransformer3DModel`):
|
3173
|
+
The Transformer model to load the LoRA layers into.
|
3174
|
+
adapter_name (`str`, *optional*):
|
3175
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
3176
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
3177
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
3178
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3179
|
+
weights.
|
3180
|
+
"""
|
3181
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3182
|
+
raise ValueError(
|
3183
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
3184
|
+
)
|
3185
|
+
|
3186
|
+
# Load the layers corresponding to transformer.
|
3187
|
+
logger.info(f"Loading {cls.transformer_name}.")
|
3188
|
+
transformer.load_lora_adapter(
|
3189
|
+
state_dict,
|
3190
|
+
network_alphas=None,
|
3191
|
+
adapter_name=adapter_name,
|
3192
|
+
_pipeline=_pipeline,
|
3193
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
3194
|
+
)
|
3195
|
+
|
3196
|
+
@classmethod
|
3197
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
3198
|
+
def save_lora_weights(
|
3199
|
+
cls,
|
3200
|
+
save_directory: Union[str, os.PathLike],
|
3201
|
+
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
3202
|
+
is_main_process: bool = True,
|
3203
|
+
weight_name: str = None,
|
3204
|
+
save_function: Callable = None,
|
3205
|
+
safe_serialization: bool = True,
|
3206
|
+
):
|
3207
|
+
r"""
|
3208
|
+
Save the LoRA parameters corresponding to the UNet and text encoder.
|
3209
|
+
|
3210
|
+
Arguments:
|
3211
|
+
save_directory (`str` or `os.PathLike`):
|
3212
|
+
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
3213
|
+
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
3214
|
+
State dict of the LoRA layers corresponding to the `transformer`.
|
3215
|
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
3216
|
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
3217
|
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
3218
|
+
process to avoid race conditions.
|
3219
|
+
save_function (`Callable`):
|
3220
|
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
3221
|
+
replace `torch.save` with another method. Can be configured with the environment variable
|
3222
|
+
`DIFFUSERS_SAVE_MODE`.
|
3223
|
+
safe_serialization (`bool`, *optional*, defaults to `True`):
|
3224
|
+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
3225
|
+
"""
|
3226
|
+
state_dict = {}
|
3227
|
+
|
3228
|
+
if not transformer_lora_layers:
|
3229
|
+
raise ValueError("You must pass `transformer_lora_layers`.")
|
3230
|
+
|
3231
|
+
if transformer_lora_layers:
|
3232
|
+
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
3233
|
+
|
3234
|
+
# Save the model
|
3235
|
+
cls.write_lora_layers(
|
3236
|
+
state_dict=state_dict,
|
3237
|
+
save_directory=save_directory,
|
3238
|
+
is_main_process=is_main_process,
|
3239
|
+
weight_name=weight_name,
|
3240
|
+
save_function=save_function,
|
3241
|
+
safe_serialization=safe_serialization,
|
3242
|
+
)
|
3243
|
+
|
3244
|
+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
|
3245
|
+
def fuse_lora(
|
3246
|
+
self,
|
3247
|
+
components: List[str] = ["transformer", "text_encoder"],
|
3248
|
+
lora_scale: float = 1.0,
|
3249
|
+
safe_fusing: bool = False,
|
3250
|
+
adapter_names: Optional[List[str]] = None,
|
3251
|
+
**kwargs,
|
3252
|
+
):
|
3253
|
+
r"""
|
3254
|
+
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
3255
|
+
|
3256
|
+
<Tip warning={true}>
|
3257
|
+
|
3258
|
+
This is an experimental API.
|
3259
|
+
|
3260
|
+
</Tip>
|
3261
|
+
|
3262
|
+
Args:
|
3263
|
+
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
3264
|
+
lora_scale (`float`, defaults to 1.0):
|
3265
|
+
Controls how much to influence the outputs with the LoRA parameters.
|
3266
|
+
safe_fusing (`bool`, defaults to `False`):
|
3267
|
+
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
3268
|
+
adapter_names (`List[str]`, *optional*):
|
3269
|
+
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
3270
|
+
|
3271
|
+
Example:
|
3272
|
+
|
3273
|
+
```py
|
3274
|
+
from diffusers import DiffusionPipeline
|
3275
|
+
import torch
|
3276
|
+
|
3277
|
+
pipeline = DiffusionPipeline.from_pretrained(
|
3278
|
+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
3279
|
+
).to("cuda")
|
3280
|
+
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
3281
|
+
pipeline.fuse_lora(lora_scale=0.7)
|
3282
|
+
```
|
3283
|
+
"""
|
3284
|
+
super().fuse_lora(
|
3285
|
+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
3286
|
+
)
|
3287
|
+
|
3288
|
+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
|
3289
|
+
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
3290
|
+
r"""
|
3291
|
+
Reverses the effect of
|
3292
|
+
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
3293
|
+
|
3294
|
+
<Tip warning={true}>
|
3295
|
+
|
3296
|
+
This is an experimental API.
|
3297
|
+
|
3298
|
+
</Tip>
|
3299
|
+
|
3300
|
+
Args:
|
3301
|
+
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
3302
|
+
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
3303
|
+
unfuse_text_encoder (`bool`, defaults to `True`):
|
3304
|
+
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
3305
|
+
LoRA parameters then it won't have any effect.
|
3306
|
+
"""
|
3307
|
+
super().unfuse_lora(components=components)
|
3308
|
+
|
3309
|
+
|
3310
|
+
class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
3311
|
+
r"""
|
3312
|
+
Load LoRA layers into [`LTXVideoTransformer3DModel`]. Specific to [`LTXPipeline`].
|
3313
|
+
"""
|
3314
|
+
|
3315
|
+
_lora_loadable_modules = ["transformer"]
|
3316
|
+
transformer_name = TRANSFORMER_NAME
|
3317
|
+
|
3318
|
+
@classmethod
|
3319
|
+
@validate_hf_hub_args
|
3320
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
|
3321
|
+
def lora_state_dict(
|
3322
|
+
cls,
|
3323
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
3324
|
+
**kwargs,
|
3325
|
+
):
|
3326
|
+
r"""
|
3327
|
+
Return state dict for lora weights and the network alphas.
|
3328
|
+
|
3329
|
+
<Tip warning={true}>
|
3330
|
+
|
3331
|
+
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
3332
|
+
|
3333
|
+
This function is experimental and might change in the future.
|
3334
|
+
|
3335
|
+
</Tip>
|
3336
|
+
|
3337
|
+
Parameters:
|
3338
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
3339
|
+
Can be either:
|
3340
|
+
|
3341
|
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
3342
|
+
the Hub.
|
3343
|
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
3344
|
+
with [`ModelMixin.save_pretrained`].
|
3345
|
+
- A [torch state
|
3346
|
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
3347
|
+
|
3348
|
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
3349
|
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
3350
|
+
is not used.
|
3351
|
+
force_download (`bool`, *optional*, defaults to `False`):
|
3352
|
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
3353
|
+
cached versions if they exist.
|
3354
|
+
|
3355
|
+
proxies (`Dict[str, str]`, *optional*):
|
3356
|
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
3357
|
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
3358
|
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
3359
|
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
3360
|
+
won't be downloaded from the Hub.
|
3361
|
+
token (`str` or *bool*, *optional*):
|
3362
|
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
3363
|
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
3364
|
+
revision (`str`, *optional*, defaults to `"main"`):
|
3365
|
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
3366
|
+
allowed by Git.
|
3367
|
+
subfolder (`str`, *optional*, defaults to `""`):
|
3368
|
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
3369
|
+
|
3370
|
+
"""
|
3371
|
+
# Load the main state dict first which has the LoRA layers for either of
|
3372
|
+
# transformer and text encoder or both.
|
3373
|
+
cache_dir = kwargs.pop("cache_dir", None)
|
3374
|
+
force_download = kwargs.pop("force_download", False)
|
3375
|
+
proxies = kwargs.pop("proxies", None)
|
3376
|
+
local_files_only = kwargs.pop("local_files_only", None)
|
3377
|
+
token = kwargs.pop("token", None)
|
3378
|
+
revision = kwargs.pop("revision", None)
|
3379
|
+
subfolder = kwargs.pop("subfolder", None)
|
3380
|
+
weight_name = kwargs.pop("weight_name", None)
|
3381
|
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
3382
|
+
|
3383
|
+
allow_pickle = False
|
3384
|
+
if use_safetensors is None:
|
3385
|
+
use_safetensors = True
|
3386
|
+
allow_pickle = True
|
3387
|
+
|
3388
|
+
user_agent = {
|
3389
|
+
"file_type": "attn_procs_weights",
|
3390
|
+
"framework": "pytorch",
|
3391
|
+
}
|
3392
|
+
|
3393
|
+
state_dict = _fetch_state_dict(
|
3394
|
+
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
3395
|
+
weight_name=weight_name,
|
3396
|
+
use_safetensors=use_safetensors,
|
3397
|
+
local_files_only=local_files_only,
|
3398
|
+
cache_dir=cache_dir,
|
3399
|
+
force_download=force_download,
|
3400
|
+
proxies=proxies,
|
3401
|
+
token=token,
|
3402
|
+
revision=revision,
|
3403
|
+
subfolder=subfolder,
|
3404
|
+
user_agent=user_agent,
|
3405
|
+
allow_pickle=allow_pickle,
|
3406
|
+
)
|
3407
|
+
|
3408
|
+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
3409
|
+
if is_dora_scale_present:
|
3410
|
+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
3411
|
+
logger.warning(warn_msg)
|
3412
|
+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
3413
|
+
|
3414
|
+
return state_dict
|
3415
|
+
|
3416
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
3417
|
+
def load_lora_weights(
|
3418
|
+
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
3419
|
+
):
|
3420
|
+
"""
|
3421
|
+
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
3422
|
+
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
3423
|
+
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
3424
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
3425
|
+
dict is loaded into `self.transformer`.
|
3426
|
+
|
3427
|
+
Parameters:
|
3428
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
3429
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
3430
|
+
adapter_name (`str`, *optional*):
|
3431
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
3432
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
3433
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
3434
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3435
|
+
weights.
|
3436
|
+
kwargs (`dict`, *optional*):
|
3437
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
3438
|
+
"""
|
3439
|
+
if not USE_PEFT_BACKEND:
|
3440
|
+
raise ValueError("PEFT backend is required for this method.")
|
3441
|
+
|
3442
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
3443
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3444
|
+
raise ValueError(
|
3445
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
3446
|
+
)
|
3447
|
+
|
3448
|
+
# if a dict is passed, copy it instead of modifying it inplace
|
3449
|
+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
3450
|
+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
3451
|
+
|
3452
|
+
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
3453
|
+
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
3454
|
+
|
3455
|
+
is_correct_format = all("lora" in key for key in state_dict.keys())
|
3456
|
+
if not is_correct_format:
|
3457
|
+
raise ValueError("Invalid LoRA checkpoint.")
|
3458
|
+
|
3459
|
+
self.load_lora_into_transformer(
|
3460
|
+
state_dict,
|
3461
|
+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
3462
|
+
adapter_name=adapter_name,
|
3463
|
+
_pipeline=self,
|
3464
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
3465
|
+
)
|
3466
|
+
|
3467
|
+
@classmethod
|
3468
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
|
3469
|
+
def load_lora_into_transformer(
|
3470
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
3471
|
+
):
|
3472
|
+
"""
|
3473
|
+
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
3474
|
+
|
3475
|
+
Parameters:
|
3476
|
+
state_dict (`dict`):
|
3477
|
+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
3478
|
+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
3479
|
+
encoder lora layers.
|
3480
|
+
transformer (`LTXVideoTransformer3DModel`):
|
3481
|
+
The Transformer model to load the LoRA layers into.
|
3482
|
+
adapter_name (`str`, *optional*):
|
3483
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
3484
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
3485
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
3486
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3487
|
+
weights.
|
3488
|
+
"""
|
3489
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3490
|
+
raise ValueError(
|
3491
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
3492
|
+
)
|
3493
|
+
|
3494
|
+
# Load the layers corresponding to transformer.
|
3495
|
+
logger.info(f"Loading {cls.transformer_name}.")
|
3496
|
+
transformer.load_lora_adapter(
|
3497
|
+
state_dict,
|
3498
|
+
network_alphas=None,
|
3499
|
+
adapter_name=adapter_name,
|
3500
|
+
_pipeline=_pipeline,
|
3501
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
3502
|
+
)
|
3503
|
+
|
3504
|
+
@classmethod
|
3505
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
3506
|
+
def save_lora_weights(
|
3507
|
+
cls,
|
3508
|
+
save_directory: Union[str, os.PathLike],
|
3509
|
+
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
3510
|
+
is_main_process: bool = True,
|
3511
|
+
weight_name: str = None,
|
3512
|
+
save_function: Callable = None,
|
3513
|
+
safe_serialization: bool = True,
|
3514
|
+
):
|
3515
|
+
r"""
|
3516
|
+
Save the LoRA parameters corresponding to the UNet and text encoder.
|
3517
|
+
|
3518
|
+
Arguments:
|
3519
|
+
save_directory (`str` or `os.PathLike`):
|
3520
|
+
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
3521
|
+
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
3522
|
+
State dict of the LoRA layers corresponding to the `transformer`.
|
3523
|
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
3524
|
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
3525
|
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
3526
|
+
process to avoid race conditions.
|
3527
|
+
save_function (`Callable`):
|
3528
|
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
3529
|
+
replace `torch.save` with another method. Can be configured with the environment variable
|
3530
|
+
`DIFFUSERS_SAVE_MODE`.
|
3531
|
+
safe_serialization (`bool`, *optional*, defaults to `True`):
|
3532
|
+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
3533
|
+
"""
|
3534
|
+
state_dict = {}
|
3535
|
+
|
3536
|
+
if not transformer_lora_layers:
|
3537
|
+
raise ValueError("You must pass `transformer_lora_layers`.")
|
3538
|
+
|
3539
|
+
if transformer_lora_layers:
|
3540
|
+
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
3541
|
+
|
3542
|
+
# Save the model
|
3543
|
+
cls.write_lora_layers(
|
3544
|
+
state_dict=state_dict,
|
3545
|
+
save_directory=save_directory,
|
3546
|
+
is_main_process=is_main_process,
|
3547
|
+
weight_name=weight_name,
|
3548
|
+
save_function=save_function,
|
3549
|
+
safe_serialization=safe_serialization,
|
3550
|
+
)
|
3551
|
+
|
3552
|
+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
|
3553
|
+
def fuse_lora(
|
3554
|
+
self,
|
3555
|
+
components: List[str] = ["transformer", "text_encoder"],
|
3556
|
+
lora_scale: float = 1.0,
|
3557
|
+
safe_fusing: bool = False,
|
3558
|
+
adapter_names: Optional[List[str]] = None,
|
3559
|
+
**kwargs,
|
3560
|
+
):
|
3561
|
+
r"""
|
3562
|
+
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
3563
|
+
|
3564
|
+
<Tip warning={true}>
|
3565
|
+
|
3566
|
+
This is an experimental API.
|
3567
|
+
|
3568
|
+
</Tip>
|
3569
|
+
|
3570
|
+
Args:
|
3571
|
+
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
3572
|
+
lora_scale (`float`, defaults to 1.0):
|
3573
|
+
Controls how much to influence the outputs with the LoRA parameters.
|
3574
|
+
safe_fusing (`bool`, defaults to `False`):
|
3575
|
+
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
3576
|
+
adapter_names (`List[str]`, *optional*):
|
3577
|
+
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
3578
|
+
|
3579
|
+
Example:
|
3580
|
+
|
3581
|
+
```py
|
3582
|
+
from diffusers import DiffusionPipeline
|
3583
|
+
import torch
|
3584
|
+
|
3585
|
+
pipeline = DiffusionPipeline.from_pretrained(
|
3586
|
+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
3587
|
+
).to("cuda")
|
3588
|
+
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
3589
|
+
pipeline.fuse_lora(lora_scale=0.7)
|
3590
|
+
```
|
3591
|
+
"""
|
3592
|
+
super().fuse_lora(
|
3593
|
+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
3594
|
+
)
|
3595
|
+
|
3596
|
+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
|
3597
|
+
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
3598
|
+
r"""
|
3599
|
+
Reverses the effect of
|
3600
|
+
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
3601
|
+
|
3602
|
+
<Tip warning={true}>
|
3603
|
+
|
3604
|
+
This is an experimental API.
|
3605
|
+
|
3606
|
+
</Tip>
|
3607
|
+
|
3608
|
+
Args:
|
3609
|
+
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
3610
|
+
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
3611
|
+
unfuse_text_encoder (`bool`, defaults to `True`):
|
3612
|
+
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
3613
|
+
LoRA parameters then it won't have any effect.
|
3614
|
+
"""
|
3615
|
+
super().unfuse_lora(components=components)
|
3616
|
+
|
3617
|
+
|
3618
|
+
class SanaLoraLoaderMixin(LoraBaseMixin):
|
3619
|
+
r"""
|
3620
|
+
Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
|
3621
|
+
"""
|
3622
|
+
|
3623
|
+
_lora_loadable_modules = ["transformer"]
|
3624
|
+
transformer_name = TRANSFORMER_NAME
|
3625
|
+
|
3626
|
+
@classmethod
|
3627
|
+
@validate_hf_hub_args
|
3628
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
3629
|
+
def lora_state_dict(
|
3630
|
+
cls,
|
3631
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
3632
|
+
**kwargs,
|
3633
|
+
):
|
3634
|
+
r"""
|
3635
|
+
Return state dict for lora weights and the network alphas.
|
3636
|
+
|
3637
|
+
<Tip warning={true}>
|
3638
|
+
|
3639
|
+
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
3640
|
+
|
3641
|
+
This function is experimental and might change in the future.
|
3642
|
+
|
3643
|
+
</Tip>
|
3644
|
+
|
3645
|
+
Parameters:
|
3646
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
3647
|
+
Can be either:
|
3648
|
+
|
3649
|
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
3650
|
+
the Hub.
|
3651
|
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
3652
|
+
with [`ModelMixin.save_pretrained`].
|
3653
|
+
- A [torch state
|
3654
|
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
3655
|
+
|
3656
|
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
3657
|
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
3658
|
+
is not used.
|
3659
|
+
force_download (`bool`, *optional*, defaults to `False`):
|
3660
|
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
3661
|
+
cached versions if they exist.
|
3662
|
+
|
3663
|
+
proxies (`Dict[str, str]`, *optional*):
|
3664
|
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
3665
|
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
3666
|
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
3667
|
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
3668
|
+
won't be downloaded from the Hub.
|
3669
|
+
token (`str` or *bool*, *optional*):
|
3670
|
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
3671
|
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
3672
|
+
revision (`str`, *optional*, defaults to `"main"`):
|
3673
|
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
3674
|
+
allowed by Git.
|
3675
|
+
subfolder (`str`, *optional*, defaults to `""`):
|
3676
|
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
3677
|
+
|
3678
|
+
"""
|
3679
|
+
# Load the main state dict first which has the LoRA layers for either of
|
3680
|
+
# transformer and text encoder or both.
|
3681
|
+
cache_dir = kwargs.pop("cache_dir", None)
|
3682
|
+
force_download = kwargs.pop("force_download", False)
|
3683
|
+
proxies = kwargs.pop("proxies", None)
|
3684
|
+
local_files_only = kwargs.pop("local_files_only", None)
|
3685
|
+
token = kwargs.pop("token", None)
|
3686
|
+
revision = kwargs.pop("revision", None)
|
3687
|
+
subfolder = kwargs.pop("subfolder", None)
|
3688
|
+
weight_name = kwargs.pop("weight_name", None)
|
3689
|
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
3690
|
+
|
3691
|
+
allow_pickle = False
|
3692
|
+
if use_safetensors is None:
|
3693
|
+
use_safetensors = True
|
3694
|
+
allow_pickle = True
|
3695
|
+
|
3696
|
+
user_agent = {
|
3697
|
+
"file_type": "attn_procs_weights",
|
3698
|
+
"framework": "pytorch",
|
3699
|
+
}
|
3700
|
+
|
3701
|
+
state_dict = _fetch_state_dict(
|
3702
|
+
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
3703
|
+
weight_name=weight_name,
|
3704
|
+
use_safetensors=use_safetensors,
|
3705
|
+
local_files_only=local_files_only,
|
3706
|
+
cache_dir=cache_dir,
|
3707
|
+
force_download=force_download,
|
3708
|
+
proxies=proxies,
|
3709
|
+
token=token,
|
3710
|
+
revision=revision,
|
3711
|
+
subfolder=subfolder,
|
3712
|
+
user_agent=user_agent,
|
3713
|
+
allow_pickle=allow_pickle,
|
3714
|
+
)
|
3715
|
+
|
3716
|
+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
3717
|
+
if is_dora_scale_present:
|
3718
|
+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
3719
|
+
logger.warning(warn_msg)
|
3720
|
+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
3721
|
+
|
3722
|
+
return state_dict
|
3723
|
+
|
3724
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
3725
|
+
def load_lora_weights(
|
3726
|
+
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
3727
|
+
):
|
3728
|
+
"""
|
3729
|
+
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
3730
|
+
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
3731
|
+
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
3732
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
3733
|
+
dict is loaded into `self.transformer`.
|
3734
|
+
|
3735
|
+
Parameters:
|
3736
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
3737
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
3738
|
+
adapter_name (`str`, *optional*):
|
3739
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
3740
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
3741
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
3742
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3743
|
+
weights.
|
3744
|
+
kwargs (`dict`, *optional*):
|
3745
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
3746
|
+
"""
|
3747
|
+
if not USE_PEFT_BACKEND:
|
3748
|
+
raise ValueError("PEFT backend is required for this method.")
|
3749
|
+
|
3750
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
3751
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3752
|
+
raise ValueError(
|
3753
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
3754
|
+
)
|
3755
|
+
|
3756
|
+
# if a dict is passed, copy it instead of modifying it inplace
|
3757
|
+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
3758
|
+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
3759
|
+
|
3760
|
+
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
3761
|
+
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
3762
|
+
|
3763
|
+
is_correct_format = all("lora" in key for key in state_dict.keys())
|
3764
|
+
if not is_correct_format:
|
3765
|
+
raise ValueError("Invalid LoRA checkpoint.")
|
3766
|
+
|
3767
|
+
self.load_lora_into_transformer(
|
3768
|
+
state_dict,
|
3769
|
+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
3770
|
+
adapter_name=adapter_name,
|
3771
|
+
_pipeline=self,
|
3772
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
3773
|
+
)
|
3774
|
+
|
3775
|
+
@classmethod
|
3776
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
|
3777
|
+
def load_lora_into_transformer(
|
3778
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
3779
|
+
):
|
3780
|
+
"""
|
3781
|
+
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
3782
|
+
|
3783
|
+
Parameters:
|
3784
|
+
state_dict (`dict`):
|
3785
|
+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
3786
|
+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
3787
|
+
encoder lora layers.
|
3788
|
+
transformer (`SanaTransformer2DModel`):
|
3789
|
+
The Transformer model to load the LoRA layers into.
|
3790
|
+
adapter_name (`str`, *optional*):
|
3791
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
3792
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
3793
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
3794
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3795
|
+
weights.
|
3796
|
+
"""
|
3797
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3798
|
+
raise ValueError(
|
3799
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
3800
|
+
)
|
3801
|
+
|
3802
|
+
# Load the layers corresponding to transformer.
|
3803
|
+
logger.info(f"Loading {cls.transformer_name}.")
|
3804
|
+
transformer.load_lora_adapter(
|
3805
|
+
state_dict,
|
3806
|
+
network_alphas=None,
|
3807
|
+
adapter_name=adapter_name,
|
3808
|
+
_pipeline=_pipeline,
|
3809
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
3810
|
+
)
|
3811
|
+
|
3812
|
+
@classmethod
|
3813
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
3814
|
+
def save_lora_weights(
|
3815
|
+
cls,
|
3816
|
+
save_directory: Union[str, os.PathLike],
|
3817
|
+
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
3818
|
+
is_main_process: bool = True,
|
3819
|
+
weight_name: str = None,
|
3820
|
+
save_function: Callable = None,
|
3821
|
+
safe_serialization: bool = True,
|
3822
|
+
):
|
3823
|
+
r"""
|
3824
|
+
+        Save the LoRA parameters corresponding to the UNet and text encoder.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training and you
+                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        if transformer_lora_layers:
+            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer", "text_encoder"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
+    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+            unfuse_text_encoder (`bool`, defaults to `True`):
+                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+                LoRA parameters then it won't have any effect.
+        """
+        super().unfuse_lora(components=components)
+
+
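Beyond the docstring example embedded in `fuse_lora` above, a fuse/unfuse round trip with an explicit adapter name might look like the sketch below. It reuses the SDXL checkpoint and LoRA from that docstring; the prompt is arbitrary, and the snippet is an illustration of the `adapter_names` argument rather than part of this diff.

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

# Fuse only the named adapter at reduced strength, run inference, then restore the base weights.
pipeline.fuse_lora(lora_scale=0.7, adapter_names=["pixel"])
image = pipeline("pixel art of a corgi astronaut").images[0]
pipeline.unfuse_lora()
```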
+class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        return state_dict
+
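Because `lora_state_dict` is a classmethod that only fetches and sanitises the checkpoint (including the `dora_scale` filtering above), it can be exercised without building a full pipeline. A minimal sketch, assuming `HunyuanVideoPipeline` picks up this mixin and that the repo id below (a placeholder) points at a LoRA trained for the HunyuanVideo transformer:

```py
from diffusers import HunyuanVideoPipeline

# Placeholder repo id or local directory; no model weights are loaded by this call.
state_dict = HunyuanVideoPipeline.lora_state_dict("<your-hunyuanvideo-lora>")

print(f"{len(state_dict)} LoRA tensors")
print(sorted(state_dict)[:3])  # every key should contain "lora"; "dora_scale" keys are filtered out
```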
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
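At the pipeline level the new method is used like any other diffusers LoRA loader. A sketch, where both repo ids are placeholders for a diffusers-format HunyuanVideo checkpoint and a LoRA trained for its transformer:

```py
import torch
from diffusers import HunyuanVideoPipeline

pipe = HunyuanVideoPipeline.from_pretrained("<hunyuanvideo-checkpoint>", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("<your-hunyuanvideo-lora>", adapter_name="my-lora")

# Optionally weight the adapter before running inference.
pipe.set_adapters(["my-lora"], [0.9])
```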
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`HunyuanVideoTransformer3DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the UNet and text encoder.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training and you
+                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        if transformer_lora_layers:
+            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer", "text_encoder"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
+    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+            unfuse_text_encoder (`bool`, defaults to `True`):
+                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+                LoRA parameters then it won't have any effect.
+        """
+        super().unfuse_lora(components=components)
+
+
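Since `save_lora_weights` is also a classmethod that only packs and serialises the tensors it is given, it can be sketched without instantiating the pipeline. The key names and shapes below are illustrative placeholders, not taken from the real transformer; in practice the tensors would come from your own training code:

```py
import torch
from diffusers import HunyuanVideoPipeline

# Illustrative LoRA tensors keyed by module path (placeholders only).
lora_layers = {
    "transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 3072),
    "transformer_blocks.0.attn.to_q.lora_B.weight": torch.zeros(3072, 4),
}

HunyuanVideoPipeline.save_lora_weights(
    save_directory="./hunyuanvideo-lora",  # created if it doesn't exist
    transformer_lora_layers=lora_layers,   # packed under the "transformer." prefix
    safe_serialization=True,               # writes a .safetensors file
)
```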
 class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     def __init__(self, *args, **kwargs):
         deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."