diffusers 0.28.2__py3-none-any.whl → 0.29.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +15 -1
- diffusers/commands/env.py +1 -5
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +2 -1
- diffusers/loaders/__init__.py +2 -2
- diffusers/loaders/lora.py +406 -140
- diffusers/loaders/lora_conversion_utils.py +7 -1
- diffusers/loaders/single_file.py +13 -1
- diffusers/loaders/single_file_model.py +15 -8
- diffusers/loaders/single_file_utils.py +267 -17
- diffusers/loaders/unet.py +307 -272
- diffusers/models/__init__.py +7 -3
- diffusers/models/attention.py +125 -1
- diffusers/models/attention_processor.py +169 -1
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl.py +17 -6
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -2
- diffusers/models/autoencoders/consistency_decoder_vae.py +9 -9
- diffusers/models/autoencoders/vq_model.py +182 -0
- diffusers/models/controlnet_sd3.py +418 -0
- diffusers/models/controlnet_xs.py +6 -6
- diffusers/models/embeddings.py +112 -84
- diffusers/models/model_loading_utils.py +55 -0
- diffusers/models/modeling_utils.py +138 -20
- diffusers/models/normalization.py +11 -6
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/dual_transformer_2d.py +5 -4
- diffusers/models/transformers/hunyuan_transformer_2d.py +149 -2
- diffusers/models/transformers/prior_transformer.py +5 -5
- diffusers/models/transformers/transformer_2d.py +2 -2
- diffusers/models/transformers/transformer_sd3.py +353 -0
- diffusers/models/transformers/transformer_temporal.py +12 -10
- diffusers/models/unets/unet_1d.py +3 -3
- diffusers/models/unets/unet_2d.py +3 -3
- diffusers/models/unets/unet_2d_condition.py +4 -15
- diffusers/models/unets/unet_3d_condition.py +5 -17
- diffusers/models/unets/unet_i2vgen_xl.py +4 -4
- diffusers/models/unets/unet_motion_model.py +4 -4
- diffusers/models/unets/unet_spatio_temporal_condition.py +3 -3
- diffusers/models/vq_model.py +8 -165
- diffusers/pipelines/__init__.py +11 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +4 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +4 -3
- diffusers/pipelines/auto_pipeline.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +4 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +53 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1062 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +4 -3
- diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +4 -3
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +24 -5
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +4 -3
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +4 -3
- diffusers/pipelines/marigold/marigold_image_processing.py +35 -20
- diffusers/pipelines/pia/pipeline_pia.py +4 -3
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +17 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +52 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +904 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +941 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +4 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +10 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +4 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +4 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +4 -3
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +4 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +4 -3
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +4 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +4 -3
- diffusers/schedulers/__init__.py +2 -0
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -3
- diffusers/schedulers/scheduling_edm_euler.py +2 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +287 -0
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/training_utils.py +4 -4
- diffusers/utils/__init__.py +3 -0
- diffusers/utils/constants.py +2 -0
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +45 -0
- diffusers/utils/dynamic_modules_utils.py +15 -13
- diffusers/utils/hub_utils.py +106 -0
- diffusers/utils/import_utils.py +0 -1
- diffusers/utils/logging.py +3 -1
- diffusers/utils/state_dict_utils.py +2 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/METADATA +3 -3
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/RECORD +112 -112
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/WHEEL +1 -1
- diffusers/models/dual_transformer_2d.py +0 -20
- diffusers/models/prior_transformer.py +0 -12
- diffusers/models/t5_film_transformer.py +0 -70
- diffusers/models/transformer_2d.py +0 -25
- diffusers/models/transformer_temporal.py +0 -34
- diffusers/models/unet_1d.py +0 -26
- diffusers/models/unet_1d_blocks.py +0 -203
- diffusers/models/unet_2d.py +0 -27
- diffusers/models/unet_2d_blocks.py +0 -375
- diffusers/models/unet_2d_condition.py +0 -25
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/LICENSE +0 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/top_level.txt +0 -0
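
The headline change in this release is Stable Diffusion 3 support: a new SD3Transformer2DModel, the StableDiffusion3Pipeline text-to-image and img2img pipelines, an SD3 ControlNet, the FlowMatchEulerDiscreteScheduler, and single-file loading for SD3 checkpoints (including the bundled CLIP and T5 text encoders). A minimal usage sketch of the new pipeline; the prompt, dtype, and step count are illustrative rather than taken from the diff:

    import torch
    from diffusers import StableDiffusion3Pipeline

    # StableDiffusion3Pipeline is new in 0.29.x (see diffusers/pipelines/stable_diffusion_3/ above).
    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
    ).to("cuda")
    image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
    image.save("sd3.png")
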
diffusers/loaders/lora_conversion_utils.py
CHANGED
@@ -226,6 +226,8 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
             diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
             diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
             diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+            diffusers_name = diffusers_name.replace("text.projection", "text_projection")
+
             if "self_attn" in diffusers_name:
                 if lora_name.startswith(("lora_te_", "lora_te1_")):
                     te_state_dict[diffusers_name] = state_dict.pop(key)
@@ -243,6 +245,10 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
                 else:
                     te2_state_dict[diffusers_name] = state_dict.pop(key)
                     te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+            # OneTrainer specificity
+            elif "text_projection" in diffusers_name and lora_name.startswith("lora_te2_"):
+                te2_state_dict[diffusers_name] = state_dict.pop(key)
+                te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)

         if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
             dora_scale_key_to_replace_te = (
@@ -270,7 +276,7 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
         network_alphas.update({new_name: alpha})

     if len(state_dict) > 0:
-        raise ValueError(f"The following keys have not been correctly
+        raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")

     logger.info("Kohya-style checkpoint detected.")
     unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
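
The converter changes above map "text.projection" keys to "text_projection" and route "lora_te2_" text-projection entries into the second text encoder's state dict, so OneTrainer-exported SDXL LoRAs that train the text projection now convert cleanly. Nothing changes on the caller's side; a hedged sketch with a hypothetical file name:

    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    # The Kohya/OneTrainer state dict is converted by _convert_kohya_lora_to_diffusers
    # when load_lora_weights detects the lora_te*/lora_unet key layout.
    pipe.load_lora_weights("onetrainer_sdxl_lora.safetensors")  # hypothetical file name
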
diffusers/loaders/single_file.py
CHANGED
@@ -28,9 +28,11 @@ from .single_file_utils import (
     _legacy_load_safety_checker,
     _legacy_load_scheduler,
     create_diffusers_clip_model_from_ldm,
+    create_diffusers_t5_model_from_checkpoint,
     fetch_diffusers_config,
     fetch_original_config,
     is_clip_model_in_single_file,
+    is_t5_in_single_file,
     load_single_file_checkpoint,
 )

@@ -118,6 +120,16 @@ def load_single_file_sub_model(
             is_legacy_loading=is_legacy_loading,
         )

+    elif is_transformers_model and is_t5_in_single_file(checkpoint):
+        loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
+            class_obj,
+            checkpoint=checkpoint,
+            config=cached_model_config_path,
+            subfolder=name,
+            torch_dtype=torch_dtype,
+            local_files_only=local_files_only,
+        )
+
     elif is_tokenizer and is_legacy_loading:
         loaded_sub_model = _legacy_load_clip_tokenizer(
             class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
@@ -234,7 +246,7 @@ def _download_diffusers_model_config_from_hub(
     local_files_only=None,
     token=None,
 ):
-    allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt"]
+    allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
     cached_model_path = snapshot_download(
         pretrained_model_name_or_path,
         cache_dir=cache_dir,
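
With is_t5_in_single_file and create_diffusers_t5_model_from_checkpoint wired into load_single_file_sub_model, and *.model tokenizer files added to the download allow-list, an SD3 checkpoint that bundles the T5-XXL encoder should load through from_single_file. A sketch under that assumption; the local file name is illustrative:

    import torch
    from diffusers import StableDiffusion3Pipeline

    # Assumes an SD3 single-file checkpoint that contains the text_encoders.* weights.
    pipe = StableDiffusion3Pipeline.from_single_file(
        "sd3_medium_incl_clips_t5xxlfp16.safetensors", torch_dtype=torch.float16
    )
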
diffusers/loaders/single_file_model.py
CHANGED
@@ -24,6 +24,7 @@ from .single_file_utils import (
     convert_controlnet_checkpoint,
     convert_ldm_unet_checkpoint,
     convert_ldm_vae_checkpoint,
+    convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
     create_controlnet_diffusers_config_from_ldm,
     create_unet_diffusers_config_from_ldm,
@@ -64,6 +65,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
         "checkpoint_mapping_fn": convert_controlnet_checkpoint,
         "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
     },
+    "SD3Transformer2DModel": {
+        "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }


@@ -271,16 +276,18 @@ class FromOriginalModelMixin:

         if is_accelerate_available():
             unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-            if model._keys_to_ignore_on_load_unexpected is not None:
-                for pat in model._keys_to_ignore_on_load_unexpected:
-                    unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

-            if len(unexpected_keys) > 0:
-                logger.warning(
-                    f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-                )
         else:
-            model.load_state_dict(diffusers_format_checkpoint)
+            _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
+
+        if model._keys_to_ignore_on_load_unexpected is not None:
+            for pat in model._keys_to_ignore_on_load_unexpected:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+            )

         if torch_dtype is not None:
             model.to(torch_dtype)
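
Registering SD3Transformer2DModel in SINGLE_FILE_LOADABLE_CLASSES means the diffusion transformer alone can also be pulled out of an original checkpoint via FromOriginalModelMixin. A hedged sketch; the path is illustrative:

    import torch
    from diffusers import SD3Transformer2DModel

    # from_single_file dispatches to convert_sd3_transformer_checkpoint_to_diffusers
    # through the mapping added above.
    transformer = SD3Transformer2DModel.from_single_file(
        "sd3_medium.safetensors", torch_dtype=torch.float16
    )
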
diffusers/loaders/single_file_utils.py
CHANGED
@@ -21,6 +21,7 @@ from io import BytesIO
 from urllib.parse import urlparse

 import requests
+import torch
 import yaml

 from ..models.modeling_utils import load_state_dict
@@ -65,11 +66,14 @@ CHECKPOINT_KEY_NAMES = {
     "inpainting": "model.diffusion_model.input_blocks.0.0.weight",
     "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
     "clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight",
+    "clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight",
     "open_clip": "cond_stage_model.model.token_embedding.weight",
     "open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding",
     "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection",
+    "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
     "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
     "stable_cascade_stage_c": "clip_txt_mapper.weight",
+    "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
 }

 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -96,6 +100,9 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
         "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
         "subfolder": "prior_lite",
     },
+    "sd3": {
+        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
+    },
 }

 # Use to configure model sample size when original config is provided
@@ -242,7 +249,10 @@ LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
 PLAYGROUND_VAE_SCALING_FACTOR = 0.5
 LDM_UNET_KEY = "model.diffusion_model."
 LDM_CONTROLNET_KEY = "control_model."
-LDM_CLIP_PREFIX_TO_REMOVE = [
+LDM_CLIP_PREFIX_TO_REMOVE = [
+    "cond_stage_model.transformer.",
+    "conditioner.embedders.0.transformer.",
+]
 OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
 LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024

@@ -366,6 +376,13 @@ def is_clip_sdxl_model(checkpoint):
     return False


+def is_clip_sd3_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["clip_sd3"] in checkpoint:
+        return True
+
+    return False
+
+
 def is_open_clip_model(checkpoint):
     if CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint:
         return True
@@ -380,6 +397,13 @@ def is_open_clip_sdxl_model(checkpoint):
     return False


+def is_open_clip_sd3_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
+        return True
+
+    return False
+
+
 def is_open_clip_sdxl_refiner_model(checkpoint):
     if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
         return True
@@ -391,9 +415,11 @@ def is_clip_model_in_single_file(class_obj, checkpoint):
     is_clip_in_checkpoint = any(
         [
             is_clip_model(checkpoint),
+            is_clip_sd3_model(checkpoint),
             is_open_clip_model(checkpoint),
             is_open_clip_sdxl_model(checkpoint),
             is_open_clip_sdxl_refiner_model(checkpoint),
+            is_open_clip_sd3_model(checkpoint),
         ]
     )
     if (
@@ -456,6 +482,9 @@ def infer_diffusers_model_type(checkpoint):
     ):
         model_type = "stable_cascade_stage_b"

+    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
+        model_type = "sd3"
+
     else:
         model_type = "v1"

@@ -1206,11 +1235,14 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
     return new_checkpoint


-def convert_ldm_clip_checkpoint(checkpoint):
+def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
     keys = list(checkpoint.keys())
     text_model_dict = {}

-    remove_prefixes =
+    remove_prefixes = []
+    remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
+    if remove_prefix:
+        remove_prefixes.append(remove_prefix)

     for key in keys:
         for prefix in remove_prefixes:
@@ -1236,8 +1268,6 @@ def convert_open_clip_checkpoint(
     else:
         text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM

-    text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
-
     keys = list(checkpoint.keys())
     keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE

@@ -1286,9 +1316,6 @@ def convert_open_clip_checkpoint(
     else:
         text_model_dict[diffusers_key] = checkpoint.get(key)

-    if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
-        text_model_dict.pop("text_model.embeddings.position_ids", None)
-
     return text_model_dict


@@ -1349,6 +1376,13 @@ def create_diffusers_clip_model_from_ldm(
     ):
         diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)

+    elif (
+        is_clip_sd3_model(checkpoint)
+        and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
+    ):
+        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
+        diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
+
     elif is_open_clip_model(checkpoint):
         prefix = "cond_stage_model.model."
         diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
@@ -1364,22 +1398,28 @@ def create_diffusers_clip_model_from_ldm(
         prefix = "conditioner.embedders.0.model."
         diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)

+    elif (
+        is_open_clip_sd3_model(checkpoint)
+        and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
+    ):
+        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")
+
     else:
         raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")

     if is_accelerate_available():
         unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-        if model._keys_to_ignore_on_load_unexpected is not None:
-            for pat in model._keys_to_ignore_on_load_unexpected:
-                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+    else:
+        _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)

-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-            )
+    if model._keys_to_ignore_on_load_unexpected is not None:
+        for pat in model._keys_to_ignore_on_load_unexpected:
+            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

-    else:
-        model.load_state_dict(diffusers_format_checkpoint)
+    if len(unexpected_keys) > 0:
+        logger.warning(
+            f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+        )

     if torch_dtype is not None:
         model.to(torch_dtype)
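
infer_diffusers_model_type now classifies a checkpoint as "sd3" by probing a single characteristic key, and DIFFUSERS_DEFAULT_PIPELINE_PATHS supplies the default repo for that type. A standalone sketch of the same probe; the key string is copied from the diff, while the helper name and path are hypothetical:

    from safetensors.torch import load_file

    SD3_KEY = "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias"

    def looks_like_sd3(path: str) -> bool:
        # Mirrors the CHECKPOINT_KEY_NAMES["sd3"] membership test: load the
        # state dict and check for one SD3-specific joint-block key.
        return SD3_KEY in load_file(path)

    print(looks_like_sd3("some_checkpoint.safetensors"))
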
@@ -1559,3 +1599,213 @@ def _legacy_load_safety_checker(local_files_only, torch_dtype):
     )

     return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}
+
+
+# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
+# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
+def swap_scale_shift(weight, dim):
+    shift, scale = weight.chunk(2, dim=0)
+    new_weight = torch.cat([scale, shift], dim=0)
+    return new_weight
+
+
+def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1  # noqa: C401
+    caption_projection_dim = 1536
+
+    # Positional and patch embeddings.
+    converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
+    converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+    # Timestep embeddings.
+    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+        "t_embedder.mlp.0.weight"
+    )
+    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+        "t_embedder.mlp.2.weight"
+    )
+    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+
+    # Context projections.
+    converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight")
+    converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias")
+
+    # Pooled context projection.
+    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight")
+    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias")
+    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight")
+    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias")
+
+    # Transformer blocks 🎸.
+    for i in range(num_layers):
+        # Q, K, V
+        sample_q, sample_k, sample_v = torch.chunk(
+            checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0
+        )
+        context_q, context_k, context_v = torch.chunk(
+            checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0
+        )
+        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+            checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0
+        )
+        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+            checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0
+        )
+
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q])
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias])
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k])
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias])
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v])
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias])
+
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q])
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k])
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+
+        # output projections.
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.attn.proj.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.attn.proj.bias"
+        )
+        if not (i == num_layers - 1):
+            converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.proj.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.proj.bias"
+            )
+
+        # norms.
+        converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias"
+        )
+        if not (i == num_layers - 1):
+            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"
+            )
+        else:
+            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift(
+                checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"),
+                dim=caption_projection_dim,
+            )
+            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift(
+                checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"),
+                dim=caption_projection_dim,
+            )
+
+        # ffs.
+        converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.mlp.fc1.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.mlp.fc1.bias"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.mlp.fc2.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.mlp.fc2.bias"
+        )
+        if not (i == num_layers - 1):
+            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.mlp.fc1.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.mlp.fc1.bias"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.mlp.fc2.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.mlp.fc2.bias"
+            )
+
+    # Final blocks.
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+        checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim
+    )
+    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+        checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
+    )
+
+    return converted_state_dict
+
+
+def is_t5_in_single_file(checkpoint):
+    if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint:
+        return True
+
+    return False
+
+
+def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
+    keys = list(checkpoint.keys())
+    text_model_dict = {}
+
+    remove_prefixes = ["text_encoders.t5xxl.transformer."]
+
+    for key in keys:
+        for prefix in remove_prefixes:
+            if key.startswith(prefix):
+                diffusers_key = key.replace(prefix, "")
+                text_model_dict[diffusers_key] = checkpoint.get(key)
+
+    return text_model_dict
+
+
+def create_diffusers_t5_model_from_checkpoint(
+    cls,
+    checkpoint,
+    subfolder="",
+    config=None,
+    torch_dtype=None,
+    local_files_only=None,
+):
+    if config:
+        config = {"pretrained_model_name_or_path": config}
+    else:
+        config = fetch_diffusers_config(checkpoint)
+
+    model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
+        model = cls(model_config)
+
+    diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
+
+    if is_accelerate_available():
+        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+        if model._keys_to_ignore_on_load_unexpected is not None:
+            for pat in model._keys_to_ignore_on_load_unexpected:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+            )
+
+    else:
+        model.load_state_dict(diffusers_format_checkpoint)
+    return model