diffusers 0.28.2__py3-none-any.whl → 0.29.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +9 -1
- diffusers/commands/env.py +1 -5
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +2 -1
- diffusers/loaders/__init__.py +2 -2
- diffusers/loaders/lora.py +406 -140
- diffusers/loaders/lora_conversion_utils.py +7 -1
- diffusers/loaders/single_file.py +1 -1
- diffusers/loaders/single_file_model.py +5 -0
- diffusers/loaders/single_file_utils.py +242 -2
- diffusers/loaders/unet.py +307 -272
- diffusers/models/__init__.py +5 -3
- diffusers/models/attention.py +125 -1
- diffusers/models/attention_processor.py +169 -1
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl.py +17 -6
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -2
- diffusers/models/autoencoders/consistency_decoder_vae.py +9 -9
- diffusers/models/autoencoders/vq_model.py +182 -0
- diffusers/models/controlnet_xs.py +6 -6
- diffusers/models/embeddings.py +112 -84
- diffusers/models/model_loading_utils.py +55 -0
- diffusers/models/modeling_utils.py +128 -17
- diffusers/models/normalization.py +11 -6
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/dual_transformer_2d.py +5 -4
- diffusers/models/transformers/hunyuan_transformer_2d.py +149 -2
- diffusers/models/transformers/prior_transformer.py +5 -5
- diffusers/models/transformers/transformer_2d.py +2 -2
- diffusers/models/transformers/transformer_sd3.py +344 -0
- diffusers/models/transformers/transformer_temporal.py +12 -10
- diffusers/models/unets/unet_1d.py +3 -3
- diffusers/models/unets/unet_2d.py +3 -3
- diffusers/models/unets/unet_2d_condition.py +4 -15
- diffusers/models/unets/unet_3d_condition.py +5 -17
- diffusers/models/unets/unet_i2vgen_xl.py +4 -4
- diffusers/models/unets/unet_motion_model.py +4 -4
- diffusers/models/unets/unet_spatio_temporal_condition.py +3 -3
- diffusers/models/vq_model.py +8 -165
- diffusers/pipelines/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +4 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +4 -3
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +4 -3
- diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +4 -3
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +24 -5
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +4 -3
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +4 -3
- diffusers/pipelines/marigold/marigold_image_processing.py +35 -20
- diffusers/pipelines/pia/pipeline_pia.py +4 -3
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +17 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +52 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +886 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +923 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +4 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +10 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +4 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +4 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +4 -3
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +4 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +4 -3
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +4 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +4 -3
- diffusers/schedulers/__init__.py +2 -0
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -3
- diffusers/schedulers/scheduling_edm_euler.py +2 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +287 -0
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/training_utils.py +4 -4
- diffusers/utils/__init__.py +3 -0
- diffusers/utils/constants.py +2 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +30 -0
- diffusers/utils/dynamic_modules_utils.py +15 -13
- diffusers/utils/hub_utils.py +106 -0
- diffusers/utils/import_utils.py +0 -1
- diffusers/utils/logging.py +3 -1
- diffusers/utils/state_dict_utils.py +2 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/METADATA +45 -45
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/RECORD +108 -111
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/WHEEL +1 -1
- diffusers/models/dual_transformer_2d.py +0 -20
- diffusers/models/prior_transformer.py +0 -12
- diffusers/models/t5_film_transformer.py +0 -70
- diffusers/models/transformer_2d.py +0 -25
- diffusers/models/transformer_temporal.py +0 -34
- diffusers/models/unet_1d.py +0 -26
- diffusers/models/unet_1d_blocks.py +0 -203
- diffusers/models/unet_2d.py +0 -27
- diffusers/models/unet_2d_blocks.py +0 -375
- diffusers/models/unet_2d_condition.py +0 -25
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/LICENSE +0 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/top_level.txt +0 -0
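The headline change in 0.29.0 is initial Stable Diffusion 3 support: new StableDiffusion3Pipeline text-to-image and image-to-image pipelines, the SD3Transformer2DModel, a FlowMatchEulerDiscreteScheduler, and single-file conversion utilities for SD3 checkpoints (all visible in the file list above). A minimal, hedged usage sketch of the new pipeline; the prompt and generation arguments are illustrative and not part of this diff:

    import torch
    from diffusers import StableDiffusion3Pipeline

    # Load the SD3 pipeline added in 0.29.0; the Hub repository is gated and requires accepting the license.
    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
    ).to("cuda")

    image = pipe("a photo of a corgi reading a changelog", num_inference_steps=28).images[0]
    image.save("sd3.png")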
diffusers/loaders/lora_conversion_utils.py
CHANGED
@@ -226,6 +226,8 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
             diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
             diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
             diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+            diffusers_name = diffusers_name.replace("text.projection", "text_projection")
+
             if "self_attn" in diffusers_name:
                 if lora_name.startswith(("lora_te_", "lora_te1_")):
                     te_state_dict[diffusers_name] = state_dict.pop(key)
@@ -243,6 +245,10 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
                 else:
                     te2_state_dict[diffusers_name] = state_dict.pop(key)
                     te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+            # OneTrainer specificity
+            elif "text_projection" in diffusers_name and lora_name.startswith("lora_te2_"):
+                te2_state_dict[diffusers_name] = state_dict.pop(key)
+                te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
 
         if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
             dora_scale_key_to_replace_te = (
@@ -270,7 +276,7 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
             network_alphas.update({new_name: alpha})
 
     if len(state_dict) > 0:
-        raise ValueError(f"The following keys have not been correctly
+        raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
 
     logger.info("Kohya-style checkpoint detected.")
     unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
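For context, the new text.projection replacement and the OneTrainer branch above handle LoRA checkpoints that also train the second text encoder's text_projection layer. A rough, hedged illustration of the string rewriting this enables (the key name below is hypothetical, not taken from a real checkpoint):

    # Hypothetical intermediate name produced while converting an OneTrainer-style te2 key.
    diffusers_name = "text.projection.lora.down.weight"
    diffusers_name = diffusers_name.replace("text.projection", "text_projection")
    # -> "text_projection.lora.down.weight", which the new elif branch routes into te2_state_dict.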
diffusers/loaders/single_file.py
CHANGED
@@ -234,7 +234,7 @@ def _download_diffusers_model_config_from_hub(
     local_files_only=None,
     token=None,
 ):
-    allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt"]
+    allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
     cached_model_path = snapshot_download(
         pretrained_model_name_or_path,
         cache_dir=cache_dir,
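The added **/*.model pattern lets the config snapshot also pull SentencePiece tokenizer files (SD3 ships a T5 tokenizer whose vocabulary lives in a .model file), which the previous JSON/TXT-only patterns skipped. A hedged sketch of the equivalent huggingface_hub call; the repo id is illustrative:

    from huggingface_hub import snapshot_download

    # Same allow_patterns as _download_diffusers_model_config_from_hub now uses.
    path = snapshot_download(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"],
    )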
diffusers/loaders/single_file_model.py
CHANGED
@@ -24,6 +24,7 @@ from .single_file_utils import (
     convert_controlnet_checkpoint,
     convert_ldm_unet_checkpoint,
     convert_ldm_vae_checkpoint,
+    convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
     create_controlnet_diffusers_config_from_ldm,
     create_unet_diffusers_config_from_ldm,
@@ -64,6 +65,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
         "checkpoint_mapping_fn": convert_controlnet_checkpoint,
         "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
     },
+    "SD3Transformer2DModel": {
+        "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }
 
 
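Registering SD3Transformer2DModel here is what lets the transformer be pulled directly out of a monolithic SD3 checkpoint. A hedged sketch of the usage this enables, assuming the class exposes from_single_file through the same single-file loader mixin as the other registered models (the local file path is hypothetical):

    from diffusers import SD3Transformer2DModel

    # from_single_file routes through convert_sd3_transformer_checkpoint_to_diffusers
    # and the "transformer" default_subfolder registered above.
    transformer = SD3Transformer2DModel.from_single_file("./sd3_medium.safetensors")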
diffusers/loaders/single_file_utils.py
CHANGED
@@ -21,6 +21,7 @@ from io import BytesIO
 from urllib.parse import urlparse
 
 import requests
+import torch
 import yaml
 
 from ..models.modeling_utils import load_state_dict
@@ -65,11 +66,14 @@ CHECKPOINT_KEY_NAMES = {
     "inpainting": "model.diffusion_model.input_blocks.0.0.weight",
     "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
     "clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight",
+    "clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight",
     "open_clip": "cond_stage_model.model.token_embedding.weight",
     "open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding",
     "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection",
+    "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
     "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
     "stable_cascade_stage_c": "clip_txt_mapper.weight",
+    "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -96,6 +100,9 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
         "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
         "subfolder": "prior_lite",
     },
+    "sd3": {
+        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
+    },
 }
 
 # Use to configure model sample size when original config is provided
@@ -242,7 +249,11 @@ LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
 PLAYGROUND_VAE_SCALING_FACTOR = 0.5
 LDM_UNET_KEY = "model.diffusion_model."
 LDM_CONTROLNET_KEY = "control_model."
-LDM_CLIP_PREFIX_TO_REMOVE = [
+LDM_CLIP_PREFIX_TO_REMOVE = [
+    "cond_stage_model.transformer.",
+    "conditioner.embedders.0.transformer.",
+    "text_encoders.clip_l.transformer.",
+]
 OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
 LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
 
@@ -366,6 +377,13 @@ def is_clip_sdxl_model(checkpoint):
     return False
 
 
+def is_clip_sd3_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["clip_sd3"] in checkpoint:
+        return True
+
+    return False
+
+
 def is_open_clip_model(checkpoint):
     if CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint:
         return True
@@ -380,8 +398,12 @@ def is_open_clip_sdxl_model(checkpoint):
     return False
 
 
+def is_open_clip_sd3_model(checkpoint):
+    is_open_clip_sdxl_refiner_model(checkpoint)
+
+
 def is_open_clip_sdxl_refiner_model(checkpoint):
-    if CHECKPOINT_KEY_NAMES["
+    if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
         return True
 
     return False
@@ -391,9 +413,11 @@ def is_clip_model_in_single_file(class_obj, checkpoint):
     is_clip_in_checkpoint = any(
         [
             is_clip_model(checkpoint),
+            is_clip_sd3_model(checkpoint),
             is_open_clip_model(checkpoint),
             is_open_clip_sdxl_model(checkpoint),
             is_open_clip_sdxl_refiner_model(checkpoint),
+            is_open_clip_sd3_model(checkpoint),
         ]
     )
     if (
@@ -456,6 +480,9 @@ def infer_diffusers_model_type(checkpoint):
     ):
         model_type = "stable_cascade_stage_b"
 
+    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
+        model_type = "sd3"
+
     else:
         model_type = "v1"
 
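With the new detection key, a checkpoint containing the SD3 joint_blocks marker is classified as "sd3" and mapped to the stabilityai/stable-diffusion-3-medium-diffusers defaults registered above. A toy, hedged check using the internal helpers (internal API, so names may change between releases):

    import torch
    from diffusers.loaders.single_file_utils import CHECKPOINT_KEY_NAMES, infer_diffusers_model_type

    # A real checkpoint has thousands of keys; one marker key is enough for this illustration.
    fake_sd3_state_dict = {CHECKPOINT_KEY_NAMES["sd3"]: torch.zeros(1536)}
    print(infer_diffusers_model_type(fake_sd3_state_dict))  # expected: "sd3"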
@@ -1364,6 +1391,10 @@ def create_diffusers_clip_model_from_ldm(
         prefix = "conditioner.embedders.0.model."
         diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
 
+    elif is_open_clip_sd3_model(checkpoint):
+        prefix = "text_encoders.clip_g.transformer."
+        diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
+
     else:
         raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
 
@@ -1559,3 +1590,212 @@ def _legacy_load_safety_checker(local_files_only, torch_dtype):
         )
 
     return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}
+
+
+# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
+# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
+def swap_scale_shift(weight, dim):
+    shift, scale = weight.chunk(2, dim=0)
+    new_weight = torch.cat([scale, shift], dim=0)
+    return new_weight
+
+
+def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1  # noqa: C401
+    caption_projection_dim = 1536
+
+    # Positional and patch embeddings.
+    converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
+    converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+    # Timestep embeddings.
+    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+        "t_embedder.mlp.0.weight"
+    )
+    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+        "t_embedder.mlp.2.weight"
+    )
+    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+
+    # Context projections.
+    converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight")
+    converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias")
+
+    # Pooled context projection.
+    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight")
+    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias")
+    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight")
+    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias")
+
+    # Transformer blocks 🎸.
+    for i in range(num_layers):
+        # Q, K, V
+        sample_q, sample_k, sample_v = torch.chunk(
+            checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0
+        )
+        context_q, context_k, context_v = torch.chunk(
+            checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0
+        )
+        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+            checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0
+        )
+        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+            checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0
+        )
+
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q])
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias])
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k])
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias])
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v])
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias])
+
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q])
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k])
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+
+        # output projections.
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.attn.proj.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.attn.proj.bias"
+        )
+        if not (i == num_layers - 1):
+            converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.proj.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.proj.bias"
+            )
+
+        # norms.
+        converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias"
+        )
+        if not (i == num_layers - 1):
+            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"
+            )
+        else:
+            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift(
+                checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"),
+                dim=caption_projection_dim,
+            )
+            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift(
+                checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"),
+                dim=caption_projection_dim,
+            )
+
+        # ffs.
+        converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.mlp.fc1.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.mlp.fc1.bias"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.mlp.fc2.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.mlp.fc2.bias"
+        )
+        if not (i == num_layers - 1):
+            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.mlp.fc1.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.mlp.fc1.bias"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.mlp.fc2.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.mlp.fc2.bias"
+            )
+
+    # Final blocks.
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+        checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim
+    )
+    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+        checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
+    )
+
+    return converted_state_dict
+
+
+def is_t5_in_single_file(checkpoint):
+    if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint:
+        return True
+
+    return False
+
+
+def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
+    keys = list(checkpoint.keys())
+    text_model_dict = {}
+
+    remove_prefixes = ["text_encoders.t5xxl.transformer.encoder."]
+
+    for key in keys:
+        for prefix in remove_prefixes:
+            if key.startswith(prefix):
+                diffusers_key = key.replace(prefix, "")
+                text_model_dict[diffusers_key] = checkpoint.get(key)
+
+    return text_model_dict
+
+
+def create_diffusers_t5_model_from_checkpoint(
+    cls,
+    checkpoint,
+    subfolder="",
+    config=None,
+    torch_dtype=None,
+    local_files_only=None,
+):
+    if config:
+        config = {"pretrained_model_name_or_path": config}
+    else:
+        config = fetch_diffusers_config(checkpoint)
+
+    model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
+        model = cls(model_config)
+
+    diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
+
+    if is_accelerate_available():
+        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+        if model._keys_to_ignore_on_load_unexpected is not None:
+            for pat in model._keys_to_ignore_on_load_unexpected:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+            )
+
+    else:
+        model.load_state_dict(diffusers_format_checkpoint)
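To make the shift/scale reordering concrete, here is a small hedged demonstration of the swap_scale_shift helper defined above on a toy 1-D tensor (values invented; the dim argument is accepted but unused, exactly as in the converter):

    import torch

    def swap_scale_shift(weight, dim):
        # [shift | scale] along dim 0 becomes [scale | shift].
        shift, scale = weight.chunk(2, dim=0)
        return torch.cat([scale, shift], dim=0)

    w = torch.tensor([1.0, 2.0, 3.0, 4.0])  # first half shift, second half scale
    print(swap_scale_shift(w, dim=1536))    # tensor([3., 4., 1., 2.])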