diffusers 0.28.2__py3-none-any.whl → 0.29.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. diffusers/__init__.py +9 -1
  2. diffusers/commands/env.py +1 -5
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +2 -1
  5. diffusers/loaders/__init__.py +2 -2
  6. diffusers/loaders/lora.py +406 -140
  7. diffusers/loaders/lora_conversion_utils.py +7 -1
  8. diffusers/loaders/single_file.py +1 -1
  9. diffusers/loaders/single_file_model.py +5 -0
  10. diffusers/loaders/single_file_utils.py +242 -2
  11. diffusers/loaders/unet.py +307 -272
  12. diffusers/models/__init__.py +5 -3
  13. diffusers/models/attention.py +125 -1
  14. diffusers/models/attention_processor.py +169 -1
  15. diffusers/models/autoencoders/__init__.py +1 -0
  16. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  17. diffusers/models/autoencoders/autoencoder_kl.py +17 -6
  18. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -2
  19. diffusers/models/autoencoders/consistency_decoder_vae.py +9 -9
  20. diffusers/models/autoencoders/vq_model.py +182 -0
  21. diffusers/models/controlnet_xs.py +6 -6
  22. diffusers/models/embeddings.py +112 -84
  23. diffusers/models/model_loading_utils.py +55 -0
  24. diffusers/models/modeling_utils.py +128 -17
  25. diffusers/models/normalization.py +11 -6
  26. diffusers/models/transformers/__init__.py +1 -0
  27. diffusers/models/transformers/dual_transformer_2d.py +5 -4
  28. diffusers/models/transformers/hunyuan_transformer_2d.py +149 -2
  29. diffusers/models/transformers/prior_transformer.py +5 -5
  30. diffusers/models/transformers/transformer_2d.py +2 -2
  31. diffusers/models/transformers/transformer_sd3.py +344 -0
  32. diffusers/models/transformers/transformer_temporal.py +12 -10
  33. diffusers/models/unets/unet_1d.py +3 -3
  34. diffusers/models/unets/unet_2d.py +3 -3
  35. diffusers/models/unets/unet_2d_condition.py +4 -15
  36. diffusers/models/unets/unet_3d_condition.py +5 -17
  37. diffusers/models/unets/unet_i2vgen_xl.py +4 -4
  38. diffusers/models/unets/unet_motion_model.py +4 -4
  39. diffusers/models/unets/unet_spatio_temporal_condition.py +3 -3
  40. diffusers/models/vq_model.py +8 -165
  41. diffusers/pipelines/__init__.py +2 -0
  42. diffusers/pipelines/animatediff/pipeline_animatediff.py +4 -3
  43. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +4 -3
  44. diffusers/pipelines/controlnet/pipeline_controlnet.py +4 -3
  45. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +4 -3
  46. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +4 -3
  47. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +4 -3
  48. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  49. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +4 -3
  50. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +4 -3
  51. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +4 -3
  52. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +4 -3
  53. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +4 -3
  54. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +24 -5
  55. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +4 -3
  56. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +4 -3
  57. diffusers/pipelines/marigold/marigold_image_processing.py +35 -20
  58. diffusers/pipelines/pia/pipeline_pia.py +4 -3
  59. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  60. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  61. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +17 -17
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -3
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -3
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +4 -3
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +4 -3
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -6
  69. diffusers/pipelines/stable_diffusion_3/__init__.py +52 -0
  70. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  71. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +886 -0
  72. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +923 -0
  73. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +4 -3
  74. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +10 -11
  75. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +4 -3
  76. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  77. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +4 -3
  78. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +4 -3
  79. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +4 -3
  80. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +4 -3
  81. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +4 -3
  82. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +4 -3
  83. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +4 -3
  84. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -3
  85. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  86. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +4 -3
  87. diffusers/schedulers/__init__.py +2 -0
  88. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  89. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -3
  90. diffusers/schedulers/scheduling_edm_euler.py +2 -4
  91. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +287 -0
  92. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  93. diffusers/training_utils.py +4 -4
  94. diffusers/utils/__init__.py +3 -0
  95. diffusers/utils/constants.py +2 -0
  96. diffusers/utils/dummy_pt_objects.py +30 -0
  97. diffusers/utils/dummy_torch_and_transformers_objects.py +30 -0
  98. diffusers/utils/dynamic_modules_utils.py +15 -13
  99. diffusers/utils/hub_utils.py +106 -0
  100. diffusers/utils/import_utils.py +0 -1
  101. diffusers/utils/logging.py +3 -1
  102. diffusers/utils/state_dict_utils.py +2 -0
  103. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/METADATA +45 -45
  104. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/RECORD +108 -111
  105. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/WHEEL +1 -1
  106. diffusers/models/dual_transformer_2d.py +0 -20
  107. diffusers/models/prior_transformer.py +0 -12
  108. diffusers/models/t5_film_transformer.py +0 -70
  109. diffusers/models/transformer_2d.py +0 -25
  110. diffusers/models/transformer_temporal.py +0 -34
  111. diffusers/models/unet_1d.py +0 -26
  112. diffusers/models/unet_1d_blocks.py +0 -203
  113. diffusers/models/unet_2d.py +0 -27
  114. diffusers/models/unet_2d_blocks.py +0 -375
  115. diffusers/models/unet_2d_condition.py +0 -25
  116. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/LICENSE +0 -0
  117. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/entry_points.txt +0 -0
  118. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/top_level.txt +0 -0
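The headline change visible in this file list is initial Stable Diffusion 3 support: a new SD3Transformer2DModel, the stable_diffusion_3 pipelines, a FlowMatchEulerDiscreteScheduler, and SD3-aware single-file and LoRA loading. As a hedged orientation sketch (not part of the diff itself), basic text-to-image use of the new pipeline in 0.29.0 looks roughly like this; the repo id is the one referenced later in this diff, and the prompt and settings are illustrative:

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Sketch only: the repo id appears in DIFFUSERS_DEFAULT_PIPELINE_PATHS below;
# prompt and generation settings are illustrative.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
image.save("sd3_sample.png")
```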
diffusers/loaders/lora_conversion_utils.py
@@ -226,6 +226,8 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
             diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
             diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
             diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+            diffusers_name = diffusers_name.replace("text.projection", "text_projection")
+
             if "self_attn" in diffusers_name:
                 if lora_name.startswith(("lora_te_", "lora_te1_")):
                     te_state_dict[diffusers_name] = state_dict.pop(key)
@@ -243,6 +245,10 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
                 else:
                     te2_state_dict[diffusers_name] = state_dict.pop(key)
                     te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+            # OneTrainer specificity
+            elif "text_projection" in diffusers_name and lora_name.startswith("lora_te2_"):
+                te2_state_dict[diffusers_name] = state_dict.pop(key)
+                te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
 
             if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
                 dora_scale_key_to_replace_te = (
@@ -270,7 +276,7 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
             network_alphas.update({new_name: alpha})
 
     if len(state_dict) > 0:
-        raise ValueError(f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}")
+        raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
 
     logger.info("Kohya-style checkpoint detected.")
     unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
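These conversion tweaks only come into play when a Kohya- or OneTrainer-style LoRA file is loaded through the public LoRA API, which calls `_convert_kohya_lora_to_diffusers` internally. A minimal, hypothetical usage sketch (the LoRA path is made up):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Loading a Kohya/OneTrainer .safetensors LoRA triggers the key renaming shown
# above (e.g. "text.projection" -> "text_projection" and the lora_te2_ handling).
pipe.load_lora_weights("path/to/onetrainer_sdxl_lora.safetensors")  # hypothetical path
```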
diffusers/loaders/single_file.py
@@ -234,7 +234,7 @@ def _download_diffusers_model_config_from_hub(
     local_files_only=None,
     token=None,
 ):
-    allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt"]
+    allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
     cached_model_path = snapshot_download(
         pretrained_model_name_or_path,
         cache_dir=cache_dir,
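The new `**/*.model` glob presumably exists so that sentencepiece tokenizer files (for example the T5 tokenizer's `spiece.model` used by SD3) are included in the lightweight config snapshot; that rationale is an inference, not stated in the diff. For reference, the underlying call pattern is plain `huggingface_hub.snapshot_download` with `allow_patterns`:

```python
from huggingface_hub import snapshot_download

# Illustrative only: mirror the allow_patterns used above to fetch configs,
# text files, and *.model tokenizer files without downloading any weights.
path = snapshot_download(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"],
)
print(path)
```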
diffusers/loaders/single_file_model.py
@@ -24,6 +24,7 @@ from .single_file_utils import (
     convert_controlnet_checkpoint,
     convert_ldm_unet_checkpoint,
     convert_ldm_vae_checkpoint,
+    convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
     create_controlnet_diffusers_config_from_ldm,
     create_unet_diffusers_config_from_ldm,
@@ -64,6 +65,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
         "checkpoint_mapping_fn": convert_controlnet_checkpoint,
         "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
     },
+    "SD3Transformer2DModel": {
+        "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }
 
 
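Registering `SD3Transformer2DModel` in `SINGLE_FILE_LOADABLE_CLASSES` is what routes its single-file loading through `convert_sd3_transformer_checkpoint_to_diffusers` (added further down in this diff). A hedged sketch, assuming a locally downloaded single-file SD3 checkpoint (the filename is hypothetical):

```python
import torch
from diffusers import SD3Transformer2DModel

# Sketch only: the checkpoint path is hypothetical; conversion of the original
# state dict is handled by the mapping function registered above.
transformer = SD3Transformer2DModel.from_single_file(
    "checkpoints/sd3_medium.safetensors", torch_dtype=torch.float16
)
```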
diffusers/loaders/single_file_utils.py
@@ -21,6 +21,7 @@ from io import BytesIO
 from urllib.parse import urlparse
 
 import requests
+import torch
 import yaml
 
 from ..models.modeling_utils import load_state_dict
@@ -65,11 +66,14 @@ CHECKPOINT_KEY_NAMES = {
     "inpainting": "model.diffusion_model.input_blocks.0.0.weight",
     "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
     "clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight",
+    "clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight",
     "open_clip": "cond_stage_model.model.token_embedding.weight",
     "open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding",
     "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection",
+    "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
     "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
     "stable_cascade_stage_c": "clip_txt_mapper.weight",
+    "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -96,6 +100,9 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
         "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
         "subfolder": "prior_lite",
     },
+    "sd3": {
+        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
+    },
 }
 
 # Use to configure model sample size when original config is provided
@@ -242,7 +249,11 @@ LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
 PLAYGROUND_VAE_SCALING_FACTOR = 0.5
 LDM_UNET_KEY = "model.diffusion_model."
 LDM_CONTROLNET_KEY = "control_model."
-LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
+LDM_CLIP_PREFIX_TO_REMOVE = [
+    "cond_stage_model.transformer.",
+    "conditioner.embedders.0.transformer.",
+    "text_encoders.clip_l.transformer.",
+]
 OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
 LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
 
@@ -366,6 +377,13 @@ def is_clip_sdxl_model(checkpoint):
     return False
 
 
+def is_clip_sd3_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["clip_sd3"] in checkpoint:
+        return True
+
+    return False
+
+
 def is_open_clip_model(checkpoint):
     if CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint:
         return True
@@ -380,8 +398,12 @@ def is_open_clip_sdxl_model(checkpoint):
     return False
 
 
+def is_open_clip_sd3_model(checkpoint):
+    is_open_clip_sdxl_refiner_model(checkpoint)
+
+
 def is_open_clip_sdxl_refiner_model(checkpoint):
-    if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
+    if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
         return True
 
     return False
@@ -391,9 +413,11 @@ def is_clip_model_in_single_file(class_obj, checkpoint):
     is_clip_in_checkpoint = any(
         [
             is_clip_model(checkpoint),
+            is_clip_sd3_model(checkpoint),
             is_open_clip_model(checkpoint),
             is_open_clip_sdxl_model(checkpoint),
             is_open_clip_sdxl_refiner_model(checkpoint),
+            is_open_clip_sd3_model(checkpoint),
         ]
     )
     if (
@@ -456,6 +480,9 @@ def infer_diffusers_model_type(checkpoint):
     ):
         model_type = "stable_cascade_stage_b"
 
+    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
+        model_type = "sd3"
+
     else:
         model_type = "v1"
 
@@ -1364,6 +1391,10 @@ def create_diffusers_clip_model_from_ldm(
         prefix = "conditioner.embedders.0.model."
         diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
 
+    elif is_open_clip_sd3_model(checkpoint):
+        prefix = "text_encoders.clip_g.transformer."
+        diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
+
     else:
         raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
 
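Together with the detection helpers above, these branches let a monolithic SD3 checkpoint (transformer plus bundled CLIP-L/CLIP-G encoders) be split back into diffusers components. A hedged end-to-end sketch, assuming single-file pipeline loading is used and the local filename is hypothetical:

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Sketch only: single-file SD3 loading relies on the helpers in this diff
# (infer_diffusers_model_type -> "sd3", create_diffusers_clip_model_from_ldm,
# convert_sd3_transformer_checkpoint_to_diffusers).
pipe = StableDiffusion3Pipeline.from_single_file(
    "checkpoints/sd3_medium_incl_clips.safetensors",  # hypothetical local path
    torch_dtype=torch.float16,
)
```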
diffusers/loaders/single_file_utils.py
@@ -1559,3 +1590,212 @@ def _legacy_load_safety_checker(local_files_only, torch_dtype):
         )
 
     return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}
+
+
+# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
+# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
+def swap_scale_shift(weight, dim):
+    shift, scale = weight.chunk(2, dim=0)
+    new_weight = torch.cat([scale, shift], dim=0)
+    return new_weight
+
+
+def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1  # noqa: C401
+    caption_projection_dim = 1536
+
+    # Positional and patch embeddings.
+    converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
+    converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+    # Timestep embeddings.
+    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+        "t_embedder.mlp.0.weight"
+    )
+    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+        "t_embedder.mlp.2.weight"
+    )
+    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+
+    # Context projections.
+    converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight")
+    converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias")
+
+    # Pooled context projection.
+    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight")
+    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias")
+    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight")
+    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias")
+
+    # Transformer blocks 🎸.
+    for i in range(num_layers):
+        # Q, K, V
+        sample_q, sample_k, sample_v = torch.chunk(
+            checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0
+        )
+        context_q, context_k, context_v = torch.chunk(
+            checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0
+        )
+        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+            checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0
+        )
+        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+            checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0
+        )
+
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q])
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias])
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k])
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias])
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v])
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias])
+
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q])
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k])
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
+        converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+
+        # output projections.
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.attn.proj.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.attn.proj.bias"
+        )
+        if not (i == num_layers - 1):
+            converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.proj.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.proj.bias"
+            )
+
+        # norms.
+        converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias"
+        )
+        if not (i == num_layers - 1):
+            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"
+            )
+        else:
+            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift(
+                checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"),
+                dim=caption_projection_dim,
+            )
+            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift(
+                checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"),
+                dim=caption_projection_dim,
+            )
+
+        # ffs.
+        converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.mlp.fc1.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.mlp.fc1.bias"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.mlp.fc2.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = checkpoint.pop(
+            f"joint_blocks.{i}.x_block.mlp.fc2.bias"
+        )
+        if not (i == num_layers - 1):
+            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.mlp.fc1.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.mlp.fc1.bias"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.mlp.fc2.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.mlp.fc2.bias"
+            )
+
+    # Final blocks.
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+        checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim
+    )
+    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+        checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
+    )
+
+    return converted_state_dict
+
+
+def is_t5_in_single_file(checkpoint):
+    if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint:
+        return True
+
+    return False
+
+
+def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
+    keys = list(checkpoint.keys())
+    text_model_dict = {}
+
+    remove_prefixes = ["text_encoders.t5xxl.transformer.encoder."]
+
+    for key in keys:
+        for prefix in remove_prefixes:
+            if key.startswith(prefix):
+                diffusers_key = key.replace(prefix, "")
+                text_model_dict[diffusers_key] = checkpoint.get(key)
+
+    return text_model_dict
+
+
+def create_diffusers_t5_model_from_checkpoint(
+    cls,
+    checkpoint,
+    subfolder="",
+    config=None,
+    torch_dtype=None,
+    local_files_only=None,
+):
+    if config:
+        config = {"pretrained_model_name_or_path": config}
+    else:
+        config = fetch_diffusers_config(checkpoint)
+
+    model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
+        model = cls(model_config)
+
+    diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
+
+    if is_accelerate_available():
+        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+        if model._keys_to_ignore_on_load_unexpected is not None:
+            for pat in model._keys_to_ignore_on_load_unexpected:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+            )
+
+    else:
+        model.load_state_dict(diffusers_format_checkpoint)
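The `swap_scale_shift` helper at the top of this hunk exists because the reference SD3 implementation orders the AdaLayerNorm projection output as (shift, scale) while diffusers expects (scale, shift). A tiny self-contained illustration of the reordering (values are made up; as in the helper above, the `dim` argument is accepted but unused):

```python
import torch


def swap_scale_shift(weight, dim):
    # Same logic as the converter helper: swap the two halves of the projection
    # weight so a (shift, scale) ordering becomes (scale, shift).
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)


w = torch.tensor([1.0, 2.0, 3.0, 4.0])  # pretend the halves are [shift | scale]
print(swap_scale_shift(w, dim=4))  # tensor([3., 4., 1., 2.])
```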