diffusers 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +74 -28
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/lora_base.py
CHANGED
@@ -51,6 +51,9 @@ if is_accelerate_available():
 
 logger = logging.get_logger(__name__)
 
+LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+
 
 def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
     """
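Note: a minimal sketch (not part of the diff) of what these new module-level constants mean in practice; they are the default filenames the LoRA helpers fall back to when no `weight_name` is given.

from diffusers.loaders.lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE

assert LORA_WEIGHT_NAME == "pytorch_lora_weights.bin"
assert LORA_WEIGHT_NAME_SAFE == "pytorch_lora_weights.safetensors"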
@@ -181,6 +184,119 @@ def _remove_text_encoder_monkey_patch(text_encoder):
     text_encoder._hf_peft_config_loaded = None
 
 
+def _fetch_state_dict(
+    pretrained_model_name_or_path_or_dict,
+    weight_name,
+    use_safetensors,
+    local_files_only,
+    cache_dir,
+    force_download,
+    proxies,
+    token,
+    revision,
+    subfolder,
+    user_agent,
+    allow_pickle,
+):
+    model_file = None
+    if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+        # Let's first try to load .safetensors weights
+        if (use_safetensors and weight_name is None) or (
+            weight_name is not None and weight_name.endswith(".safetensors")
+        ):
+            try:
+                # Here we're relaxing the loading check to enable more Inference API
+                # friendliness where sometimes, it's not at all possible to automatically
+                # determine `weight_name`.
+                if weight_name is None:
+                    weight_name = _best_guess_weight_name(
+                        pretrained_model_name_or_path_or_dict,
+                        file_extension=".safetensors",
+                        local_files_only=local_files_only,
+                    )
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                state_dict = safetensors.torch.load_file(model_file, device="cpu")
+            except (IOError, safetensors.SafetensorError) as e:
+                if not allow_pickle:
+                    raise e
+                # try loading non-safetensors weights
+                model_file = None
+                pass
+
+        if model_file is None:
+            if weight_name is None:
+                weight_name = _best_guess_weight_name(
+                    pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
+                )
+            model_file = _get_model_file(
+                pretrained_model_name_or_path_or_dict,
+                weights_name=weight_name or LORA_WEIGHT_NAME,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+            )
+            state_dict = load_state_dict(model_file)
+    else:
+        state_dict = pretrained_model_name_or_path_or_dict
+
+    return state_dict
+
+
+def _best_guess_weight_name(
+    pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
+):
+    if local_files_only or HF_HUB_OFFLINE:
+        raise ValueError("When using the offline mode, you must specify a `weight_name`.")
+
+    targeted_files = []
+
+    if os.path.isfile(pretrained_model_name_or_path_or_dict):
+        return
+    elif os.path.isdir(pretrained_model_name_or_path_or_dict):
+        targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
+    else:
+        files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
+        targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
+    if len(targeted_files) == 0:
+        return
+
+    # "scheduler" does not correspond to a LoRA checkpoint.
+    # "optimizer" does not correspond to a LoRA checkpoint
+    # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
+    unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
+    targeted_files = list(
+        filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
+    )
+
+    if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
+        targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
+    elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
+        targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
+
+    if len(targeted_files) > 1:
+        raise ValueError(
+            f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
+        )
+    weight_name = targeted_files[0]
+    return weight_name
+
+
 class LoraBaseMixin:
     """Utility class for handling LoRAs."""
 
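Note: a minimal usage sketch (not part of the diff) of the now module-level helper; the local path below is hypothetical and the keyword values simply spell out the signature added above.

from diffusers.loaders.lora_base import _fetch_state_dict

state_dict = _fetch_state_dict(
    "path/to/lora_dir",   # hypothetical: local dir, Hub repo id, or an already-loaded dict
    weight_name=None,     # None lets _best_guess_weight_name pick the file
    use_safetensors=True,
    local_files_only=False,
    cache_dir=None,
    force_download=False,
    proxies=None,
    token=None,
    revision=None,
    subfolder=None,
    user_agent={},
    allow_pickle=False,
)
print(len(state_dict))    # number of LoRA tensors found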
@@ -234,124 +350,16 @@ class LoraBaseMixin:
         return (is_model_cpu_offload, is_sequential_cpu_offload)
 
     @classmethod
-    def _fetch_state_dict(
-        cls,
-        pretrained_model_name_or_path_or_dict,
-        weight_name,
-        use_safetensors,
-        local_files_only,
-        cache_dir,
-        force_download,
-        proxies,
-        token,
-        revision,
-        subfolder,
-        user_agent,
-        allow_pickle,
-    ):
-        from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
-
-        model_file = None
-        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-            # Let's first try to load .safetensors weights
-            if (use_safetensors and weight_name is None) or (
-                weight_name is not None and weight_name.endswith(".safetensors")
-            ):
-                try:
-                    # Here we're relaxing the loading check to enable more Inference API
-                    # friendliness where sometimes, it's not at all possible to automatically
-                    # determine `weight_name`.
-                    if weight_name is None:
-                        weight_name = cls._best_guess_weight_name(
-                            pretrained_model_name_or_path_or_dict,
-                            file_extension=".safetensors",
-                            local_files_only=local_files_only,
-                        )
-                    model_file = _get_model_file(
-                        pretrained_model_name_or_path_or_dict,
-                        weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        proxies=proxies,
-                        local_files_only=local_files_only,
-                        token=token,
-                        revision=revision,
-                        subfolder=subfolder,
-                        user_agent=user_agent,
-                    )
-                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
-                except (IOError, safetensors.SafetensorError) as e:
-                    if not allow_pickle:
-                        raise e
-                    # try loading non-safetensors weights
-                    model_file = None
-                    pass
-
-            if model_file is None:
-                if weight_name is None:
-                    weight_name = cls._best_guess_weight_name(
-                        pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
-                    )
-                model_file = _get_model_file(
-                    pretrained_model_name_or_path_or_dict,
-                    weights_name=weight_name or LORA_WEIGHT_NAME,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    local_files_only=local_files_only,
-                    token=token,
-                    revision=revision,
-                    subfolder=subfolder,
-                    user_agent=user_agent,
-                )
-                state_dict = load_state_dict(model_file)
-        else:
-            state_dict = pretrained_model_name_or_path_or_dict
-
-        return state_dict
+    def _fetch_state_dict(cls, *args, **kwargs):
+        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
+        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
+        return _fetch_state_dict(*args, **kwargs)
 
     @classmethod
-    def _best_guess_weight_name(
-        cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
-    ):
-        if local_files_only or HF_HUB_OFFLINE:
-            raise ValueError("When using the offline mode, you must specify a `weight_name`.")
-
-        targeted_files = []
-
-        if os.path.isfile(pretrained_model_name_or_path_or_dict):
-            return
-        elif os.path.isdir(pretrained_model_name_or_path_or_dict):
-            targeted_files = [
-                f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
-            ]
-        else:
-            files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
-            targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
-        if len(targeted_files) == 0:
-            return
-
-        # "scheduler" does not correspond to a LoRA checkpoint.
-        # "optimizer" does not correspond to a LoRA checkpoint
-        # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
-        unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
-        targeted_files = list(
-            filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
-        )
-
-        if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
-            targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
-        elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
-            targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
-
-        if len(targeted_files) > 1:
-            raise ValueError(
-                f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
-            )
-        weight_name = targeted_files[0]
-        return weight_name
+    def _best_guess_weight_name(cls, *args, **kwargs):
+        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
+        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
+        return _best_guess_weight_name(*args, **kwargs)
 
     def unload_lora_weights(self):
         """
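Note: the old classmethods survive only as deprecation shims. A small sketch (not part of the diff) of both call paths, using a throwaway directory so the lookup actually has something to find; the forwarding behavior and warning are taken from the shim bodies above.

import os
import tempfile
import warnings

from diffusers.loaders.lora_base import LoraBaseMixin, _best_guess_weight_name

# Toy directory containing a single (empty) LoRA weights file, just to exercise the lookup.
tmp = tempfile.mkdtemp()
open(os.path.join(tmp, "pytorch_lora_weights.safetensors"), "wb").close()

# Preferred path going forward: the module-level function.
print(_best_guess_weight_name(tmp))  # -> pytorch_lora_weights.safetensors

# The old classmethod still works until 0.35.0, but routes through `deprecate`,
# which emits a deprecation warning before forwarding to the function above.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    print(LoraBaseMixin._best_guess_weight_name(tmp))
print(caught[0].message)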
@@ -725,8 +733,6 @@ class LoraBaseMixin:
         save_function: Callable,
         safe_serialization: bool,
     ):
-        from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
-
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
diffusers/loaders/lora_conversion_utils.py
CHANGED
@@ -636,10 +636,19 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
             block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
             new_key = f"transformer.single_transformer_blocks.{block_num}"
 
-            if "
+            if "proj_lora" in old_key:
                 new_key += ".proj_out"
-            elif "
-
+            elif "qkv_lora" in old_key and "up" not in old_key:
+                handle_qkv(
+                    old_state_dict,
+                    new_state_dict,
+                    old_key,
+                    [
+                        f"transformer.single_transformer_blocks.{block_num}.attn.to_q",
+                        f"transformer.single_transformer_blocks.{block_num}.attn.to_k",
+                        f"transformer.single_transformer_blocks.{block_num}.attn.to_v",
+                    ],
+                )
 
             if "down" in old_key:
                 new_key += ".lora_A.weight"
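Note: `handle_qkv` is defined elsewhere in lora_conversion_utils.py and is not shown in this diff. The sketch below (not the library's implementation) only illustrates the idea the hunk relies on: a fused qkv LoRA up-weight is split into three equal chunks along dim 0, one per to_q / to_k / to_v target key.

import torch

def split_fused_qkv(fused_up_weight, target_keys):
    # One chunk per target, in q/k/v order.
    chunks = torch.chunk(fused_up_weight, len(target_keys), dim=0)
    return {f"{key}.lora_B.weight": chunk for key, chunk in zip(target_keys, chunks)}

fused = torch.randn(3 * 64, 16)  # toy shape: 3 x out_features, rank 16
converted = split_fused_qkv(
    fused,
    [
        "transformer.single_transformer_blocks.0.attn.to_q",
        "transformer.single_transformer_blocks.0.attn.to_k",
        "transformer.single_transformer_blocks.0.attn.to_v",
    ],
)
print({k: tuple(v.shape) for k, v in converted.items()})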
@@ -658,3 +667,309 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
         raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
 
     return new_state_dict
+
+
+def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
+    converted_state_dict = {}
+    original_state_dict_keys = list(original_state_dict.keys())
+    num_layers = 19
+    num_single_layers = 38
+    inner_dim = 3072
+    mlp_ratio = 4.0
+
+    def swap_scale_shift(weight):
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        return new_weight
+
+    for lora_key in ["lora_A", "lora_B"]:
+        ## time_text_embed.timestep_embedder <- time_in
+        converted_state_dict[
+            f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
+        ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
+        if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[
+                f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
+            ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
+
+        converted_state_dict[
+            f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
+        ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
+        if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[
+                f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
+            ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
+
+        ## time_text_embed.text_embedder <- vector_in
+        converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
+            f"vector_in.in_layer.{lora_key}.weight"
+        )
+        if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.bias"] = original_state_dict.pop(
+                f"vector_in.in_layer.{lora_key}.bias"
+            )
+
+        converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = original_state_dict.pop(
+            f"vector_in.out_layer.{lora_key}.weight"
+        )
+        if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.bias"] = original_state_dict.pop(
+                f"vector_in.out_layer.{lora_key}.bias"
+            )
+
+        # guidance
+        has_guidance = any("guidance" in k for k in original_state_dict)
+        if has_guidance:
+            converted_state_dict[
+                f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
+            ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
+            if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[
+                    f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
+                ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
+
+            converted_state_dict[
+                f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
+            ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
+            if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[
+                    f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
+                ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
+
+        # context_embedder
+        converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
+            f"txt_in.{lora_key}.weight"
+        )
+        if f"txt_in.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"context_embedder.{lora_key}.bias"] = original_state_dict.pop(
+                f"txt_in.{lora_key}.bias"
+            )
+
+        # x_embedder
+        converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop(f"img_in.{lora_key}.weight")
+        if f"img_in.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop(f"img_in.{lora_key}.bias")
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norms
+            converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_mod.lin.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.img_mod.lin.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias"
+                )
+
+            # Q, K, V
+            if lora_key == "lora_A":
+                sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight")
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+
+                context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight")
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+            else:
+                sample_q, sample_k, sample_v = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+
+                context_q, context_k, context_v = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
+
+            if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
+
+            if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
+
+            # ff img_mlp
+            converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.img_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
+                )
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.img_attn.proj.{lora_key}.bias"
+                )
+            converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
+                )
+
+        # qk_norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+        )
+
+    # single transfomer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norm.linear <- single_blocks.0.modulation.lin
+            converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"single_blocks.{i}.modulation.lin.{lora_key}.weight"
+            )
+            if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"single_blocks.{i}.modulation.lin.{lora_key}.bias"
+                )
+
+            # Q, K, V, mlp
+            mlp_hidden_dim = int(inner_dim * mlp_ratio)
+            split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+
+            if lora_key == "lora_A":
+                lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight")
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
+
+                if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
+            else:
+                q, k, v, mlp = torch.split(
+                    original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
+
+                if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    q_bias, k_bias, v_bias, mlp_bias = torch.split(
+                        original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0
+                    )
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"single_blocks.{i}.linear2.{lora_key}.weight"
+            )
+            if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"single_blocks.{i}.linear2.{lora_key}.bias"
+                )
+
+        # qk norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+            f"single_blocks.{i}.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+            f"single_blocks.{i}.norm.key_norm.scale"
+        )
+
+    for lora_key in ["lora_A", "lora_B"]:
+        converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
+            f"final_layer.linear.{lora_key}.weight"
+        )
+        if f"final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                f"final_layer.linear.{lora_key}.bias"
+            )
+
+        converted_state_dict[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift(
+            original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight")
+        )
+        if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"norm_out.linear.{lora_key}.bias"] = swap_scale_shift(
+                original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias")
+            )
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict