diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
@@ -41,6 +41,7 @@ from .lora_base import ( # noqa
 )
 from .lora_conversion_utils import (
     _convert_bfl_flux_control_lora_to_diffusers,
+    _convert_fal_kontext_lora_to_diffusers,
    _convert_hunyuan_video_lora_to_diffusers,
    _convert_kohya_flux_lora_to_diffusers,
    _convert_musubi_wan_lora_to_diffusers,
@@ -48,6 +49,7 @@ from .lora_conversion_utils import (
     _convert_non_diffusers_lora_to_diffusers,
     _convert_non_diffusers_ltxv_lora_to_diffusers,
     _convert_non_diffusers_lumina2_lora_to_diffusers,
+    _convert_non_diffusers_qwen_lora_to_diffusers,
     _convert_non_diffusers_wan_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
     _maybe_map_sgm_blocks_to_diffusers,
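The newly imported `_convert_non_diffusers_qwen_lora_to_diffusers` backs the `QwenImageLoraLoaderMixin` added near the end of this diff: `lora_state_dict` detects checkpoints that use `.alpha` keys or a `lora_unet_` prefix and converts them before loading. A minimal usage sketch, not an official example; the LoRA repo id and file name are placeholders and the base checkpoint id is an assumption:

```python
# Minimal sketch: loading a (possibly non-diffusers) LoRA into the new
# QwenImage pipeline from diffusers 0.35.x. Repo ids are placeholders/assumptions.
import torch
from diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Conversion to the diffusers key layout happens inside lora_state_dict() when the
# checkpoint uses ".alpha" suffixes or "lora_unet_"-prefixed keys.
pipe.load_lora_weights("your-org/your-qwen-image-lora", weight_name="lora.safetensors")  # placeholder

image = pipe("a cozy cabin in a snowy forest", num_inference_steps=30).images[0]
image.save("qwen_image_lora.png")
```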
@@ -2062,6 +2064,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 return_metadata=return_lora_metadata,
             )
 
+        is_fal_kontext = any("base_model" in k for k in state_dict)
+        if is_fal_kontext:
+            state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict)
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )
+
         # For state dicts like
         # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
         keys = list(state_dict.keys())
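With the change above, `FluxLoraLoaderMixin.lora_state_dict` recognizes fal-style Kontext LoRAs (checkpoints whose keys contain `base_model`) and routes them through `_convert_fal_kontext_lora_to_diffusers` before returning. A minimal sketch of the user-facing flow; the LoRA repo id, weight file, and input image URL are placeholders:

```python
# Minimal sketch, assuming a fal-trained Flux Kontext LoRA; detection and conversion
# happen transparently inside lora_state_dict(), so user code is unchanged.
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("your-org/fal-kontext-lora", weight_name="pytorch_lora_weights.safetensors")  # placeholder

init_image = load_image("https://example.com/input.png")  # placeholder URL
result = pipe(image=init_image, prompt="turn the scene into a watercolor painting").images[0]
result.save("kontext_lora_edit.png")
```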
@@ -5052,7 +5065,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
     Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
     """
 
-    _lora_loadable_modules = ["transformer"]
+    _lora_loadable_modules = ["transformer", "transformer_2"]
     transformer_name = TRANSFORMER_NAME
 
     @classmethod
@@ -5257,15 +5270,35 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            if not hasattr(self, "transformer_2"):
+                raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute transformer_2"
+                    "Note that Wan2.1 models do not have a transformer_2 component."
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+                )
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
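Adding `"transformer_2"` to `_lora_loadable_modules` and the new `load_into_transformer_2` kwarg let a LoRA target the second transformer of two-expert Wan 2.2 pipelines; Wan 2.1 pipelines, which have no `transformer_2`, raise the `AttributeError` shown above. A minimal sketch, with the base and LoRA repo ids as assumptions/placeholders:

```python
# Minimal sketch, assuming a Wan 2.2 checkpoint that exposes both `transformer`
# and `transformer_2`; repo ids below are assumptions/placeholders.
import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Default behavior (unchanged): the LoRA is loaded into pipe.transformer.
pipe.load_lora_weights("your-org/wan22-style-lora", adapter_name="style_stage_1")

# New in 0.35: load a LoRA into the second denoising expert as well.
pipe.load_lora_weights(
    "your-org/wan22-style-lora", adapter_name="style_stage_2", load_into_transformer_2=True
)

frames = pipe("a red panda rolling down a grassy hill", num_frames=81).frames[0]
```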
@@ -5442,9 +5475,9 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         super().unfuse_lora(components=components, **kwargs)
 
 
-class CogView4LoraLoaderMixin(LoraBaseMixin):
+class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
     r"""
-    Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`].
+    Load LoRA layers into [`SkyReelsV2Transformer3DModel`].
     """
 
     _lora_loadable_modules = ["transformer"]
@@ -5452,7 +5485,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
 
     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.
+    # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -5503,7 +5536,6 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
             return_lora_metadata (`bool`, *optional*, defaults to False):
                 When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
-
         """
         # Load the main state dict first which has the LoRA layers for either of
         # transformer and text encoder or both.
@@ -5539,6 +5571,10 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        if any(k.startswith("diffusion_model.") for k in state_dict):
+            state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
+        elif any(k.startswith("lora_unet_") for k in state_dict):
+            state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)
 
         is_dora_scale_present = any("dora_scale" in k for k in state_dict)
         if is_dora_scale_present:
|
|
5549
5585
|
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
5550
5586
|
return out
|
5551
5587
|
|
5552
|
-
|
5588
|
+
@classmethod
|
5589
|
+
# Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin._maybe_expand_t2v_lora_for_i2v
|
5590
|
+
def _maybe_expand_t2v_lora_for_i2v(
|
5591
|
+
cls,
|
5592
|
+
transformer: torch.nn.Module,
|
5593
|
+
state_dict,
|
5594
|
+
):
|
5595
|
+
if transformer.config.image_dim is None:
|
5596
|
+
return state_dict
|
5597
|
+
|
5598
|
+
target_device = transformer.device
|
5599
|
+
|
5600
|
+
if any(k.startswith("transformer.blocks.") for k in state_dict):
|
5601
|
+
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
|
5602
|
+
is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
|
5603
|
+
has_bias = any(".lora_B.bias" in k for k in state_dict)
|
5604
|
+
|
5605
|
+
if is_i2v_lora:
|
5606
|
+
return state_dict
|
5607
|
+
|
5608
|
+
for i in range(num_blocks):
|
5609
|
+
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
|
5610
|
+
# These keys should exist if the block `i` was part of the T2V LoRA.
|
5611
|
+
ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
|
5612
|
+
ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
|
5613
|
+
|
5614
|
+
if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
|
5615
|
+
continue
|
5616
|
+
|
5617
|
+
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
|
5618
|
+
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
|
5619
|
+
)
|
5620
|
+
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
|
5621
|
+
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
|
5622
|
+
)
|
5623
|
+
|
5624
|
+
# If the original LoRA had biases (indicated by has_bias)
|
5625
|
+
# AND the specific reference bias key exists for this block.
|
5626
|
+
|
5627
|
+
ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
|
5628
|
+
if has_bias and ref_key_lora_B_bias in state_dict:
|
5629
|
+
ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
|
5630
|
+
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
|
5631
|
+
ref_lora_B_bias_tensor,
|
5632
|
+
device=target_device,
|
5633
|
+
)
|
5634
|
+
|
5635
|
+
return state_dict
|
5636
|
+
|
5637
|
+
# Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.load_lora_weights
|
5553
5638
|
def load_lora_weights(
|
5554
5639
|
self,
|
5555
5640
|
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
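The `_maybe_expand_t2v_lora_for_i2v` helper copied in above lets a text-to-video LoRA load into an image-to-video transformer: for every block that carries LoRA factors on `attn2.to_k`, it inserts zero-initialized `add_k_proj`/`add_v_proj` entries for the image cross-attention projections a T2V checkpoint never trains, so loading finds no missing keys and the extra branches are no-ops. A toy illustration of the key expansion with hypothetical shapes:

```python
# Toy illustration (hypothetical tensor shapes): expand a T2V-style LoRA state dict
# with zero factors for the image-attention projections, mirroring the helper above.
import torch

rank, dim = 4, 16
lora_sd = {
    "transformer.blocks.0.attn2.to_k.lora_A.weight": torch.randn(rank, dim),
    "transformer.blocks.0.attn2.to_k.lora_B.weight": torch.randn(dim, rank),
}

for proj in ("add_k_proj", "add_v_proj"):
    for factor in ("lora_A", "lora_B"):
        ref = lora_sd[f"transformer.blocks.0.attn2.to_k.{factor}.weight"]
        # Zero factors contribute nothing to the output, so behavior is unchanged.
        lora_sd[f"transformer.blocks.0.attn2.{proj}.{factor}.weight"] = torch.zeros_like(ref)

print(sorted(lora_sd))  # now also lists add_k_proj / add_v_proj keys
```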
@@ -5594,23 +5679,47 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         kwargs["return_lora_metadata"] = True
         state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
+        # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
+        state_dict = self._maybe_expand_t2v_lora_for_i2v(
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            state_dict=state_dict,
+        )
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            if not hasattr(self, "transformer_2"):
+                raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute transformer_2"
+                    "Note that Wan2.1 models do not have a transformer_2 component."
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+                )
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
 
     @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel
     def load_lora_into_transformer(
         cls,
         state_dict,
@@ -5629,7 +5738,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
-            transformer (`CogView4Transformer2DModel`):
+            transformer (`SkyReelsV2Transformer3DModel`):
                 The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
@@ -5784,9 +5893,9 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
         super().unfuse_lora(components=components, **kwargs)
 
 
-class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
+class CogView4LoraLoaderMixin(LoraBaseMixin):
     r"""
-    Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`].
+    Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`].
     """
 
     _lora_loadable_modules = ["transformer"]
@@ -5794,6 +5903,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
 
     @classmethod
     @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -5844,6 +5954,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
             return_lora_metadata (`bool`, *optional*, defaults to False):
                 When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+
         """
         # Load the main state dict first which has the LoRA layers for either of
         # transformer and text encoder or both.
@@ -5886,10 +5997,6 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
-        is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
-        if is_non_diffusers_format:
-            state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
-
         out = (state_dict, metadata) if return_lora_metadata else state_dict
         return out
 
@@ -5954,7 +6061,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
         )
 
     @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
     def load_lora_into_transformer(
         cls,
         state_dict,
|
|
5973
6080
|
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
5974
6081
|
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
5975
6082
|
encoder lora layers.
|
5976
|
-
transformer (`
|
6083
|
+
transformer (`CogView4Transformer2DModel`):
|
5977
6084
|
The Transformer model to load the LoRA layers into.
|
5978
6085
|
adapter_name (`str`, *optional*):
|
5979
6086
|
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
@@ -6061,7 +6168,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
             lora_adapter_metadata=lora_adapter_metadata,
         )
 
-    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -6109,7 +6216,697 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
             **kwargs,
         )
 
-    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+        """
+        super().unfuse_lora(components=components, **kwargs)
+
+
+class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+        state_dict, metadata = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
+        if is_non_diffusers_format:
+            state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
+
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out
+
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+    def load_lora_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name: Optional[str] = None,
+        hotswap: bool = False,
+        **kwargs,
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            metadata=metadata,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel
+    def load_lora_into_transformer(
+        cls,
+        state_dict,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+        metadata=None,
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`HiDreamImageTransformer2DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            metadata=metadata,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the transformer.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training and you
+                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            transformer_lora_adapter_metadata:
+                LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+        """
+        state_dict = {}
+        lora_adapter_metadata = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        if transformer_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
+            )
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+        """
+        super().unfuse_lora(components=components, **kwargs)
+
+
+class QwenImageLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`QwenImageTransformer2DModel`]. Specific to [`QwenImagePipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+        state_dict, metadata = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
+        has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+        if has_alphas_in_sd or has_lora_unet:
+            state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
+
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out
+
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+    def load_lora_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name: Optional[str] = None,
+        hotswap: bool = False,
+        **kwargs,
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            metadata=metadata,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->QwenImageTransformer2DModel
+    def load_lora_into_transformer(
+        cls,
+        state_dict,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+        metadata=None,
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`QwenImageTransformer2DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            metadata=metadata,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the transformer.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training and you
+                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            transformer_lora_adapter_metadata:
+                LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+        """
+        state_dict = {}
+        lora_adapter_metadata = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        if transformer_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
+            )
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
     def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of