diffusers 0.28.2__py3-none-any.whl → 0.29.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. diffusers/__init__.py +15 -1
  2. diffusers/commands/env.py +1 -5
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +2 -1
  5. diffusers/loaders/__init__.py +2 -2
  6. diffusers/loaders/lora.py +406 -140
  7. diffusers/loaders/lora_conversion_utils.py +7 -1
  8. diffusers/loaders/single_file.py +13 -1
  9. diffusers/loaders/single_file_model.py +15 -8
  10. diffusers/loaders/single_file_utils.py +267 -17
  11. diffusers/loaders/unet.py +307 -272
  12. diffusers/models/__init__.py +7 -3
  13. diffusers/models/attention.py +125 -1
  14. diffusers/models/attention_processor.py +169 -1
  15. diffusers/models/autoencoders/__init__.py +1 -0
  16. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  17. diffusers/models/autoencoders/autoencoder_kl.py +17 -6
  18. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -2
  19. diffusers/models/autoencoders/consistency_decoder_vae.py +9 -9
  20. diffusers/models/autoencoders/vq_model.py +182 -0
  21. diffusers/models/controlnet_sd3.py +418 -0
  22. diffusers/models/controlnet_xs.py +6 -6
  23. diffusers/models/embeddings.py +112 -84
  24. diffusers/models/model_loading_utils.py +55 -0
  25. diffusers/models/modeling_utils.py +138 -20
  26. diffusers/models/normalization.py +11 -6
  27. diffusers/models/transformers/__init__.py +1 -0
  28. diffusers/models/transformers/dual_transformer_2d.py +5 -4
  29. diffusers/models/transformers/hunyuan_transformer_2d.py +149 -2
  30. diffusers/models/transformers/prior_transformer.py +5 -5
  31. diffusers/models/transformers/transformer_2d.py +2 -2
  32. diffusers/models/transformers/transformer_sd3.py +353 -0
  33. diffusers/models/transformers/transformer_temporal.py +12 -10
  34. diffusers/models/unets/unet_1d.py +3 -3
  35. diffusers/models/unets/unet_2d.py +3 -3
  36. diffusers/models/unets/unet_2d_condition.py +4 -15
  37. diffusers/models/unets/unet_3d_condition.py +5 -17
  38. diffusers/models/unets/unet_i2vgen_xl.py +4 -4
  39. diffusers/models/unets/unet_motion_model.py +4 -4
  40. diffusers/models/unets/unet_spatio_temporal_condition.py +3 -3
  41. diffusers/models/vq_model.py +8 -165
  42. diffusers/pipelines/__init__.py +11 -0
  43. diffusers/pipelines/animatediff/pipeline_animatediff.py +4 -3
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +4 -3
  45. diffusers/pipelines/auto_pipeline.py +8 -0
  46. diffusers/pipelines/controlnet/pipeline_controlnet.py +4 -3
  47. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +4 -3
  48. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +4 -3
  49. diffusers/pipelines/controlnet_sd3/__init__.py +53 -0
  50. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1062 -0
  51. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +4 -3
  52. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  53. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +4 -3
  54. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +4 -3
  55. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +4 -3
  56. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +4 -3
  57. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +4 -3
  58. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +24 -5
  59. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +4 -3
  60. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +4 -3
  61. diffusers/pipelines/marigold/marigold_image_processing.py +35 -20
  62. diffusers/pipelines/pia/pipeline_pia.py +4 -3
  63. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  64. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  65. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +17 -17
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -3
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -3
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -3
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +4 -3
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +4 -3
  72. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -6
  73. diffusers/pipelines/stable_diffusion_3/__init__.py +52 -0
  74. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  75. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +904 -0
  76. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +941 -0
  77. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +4 -3
  78. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +10 -11
  79. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +4 -3
  80. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  81. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +4 -3
  82. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +4 -3
  83. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +4 -3
  84. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +4 -3
  85. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +4 -3
  86. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +4 -3
  87. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +4 -3
  88. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -3
  89. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  90. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +4 -3
  91. diffusers/schedulers/__init__.py +2 -0
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  93. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -3
  94. diffusers/schedulers/scheduling_edm_euler.py +2 -4
  95. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +287 -0
  96. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  97. diffusers/training_utils.py +4 -4
  98. diffusers/utils/__init__.py +3 -0
  99. diffusers/utils/constants.py +2 -0
  100. diffusers/utils/dummy_pt_objects.py +60 -0
  101. diffusers/utils/dummy_torch_and_transformers_objects.py +45 -0
  102. diffusers/utils/dynamic_modules_utils.py +15 -13
  103. diffusers/utils/hub_utils.py +106 -0
  104. diffusers/utils/import_utils.py +0 -1
  105. diffusers/utils/logging.py +3 -1
  106. diffusers/utils/state_dict_utils.py +2 -0
  107. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/METADATA +3 -3
  108. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/RECORD +112 -112
  109. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/WHEEL +1 -1
  110. diffusers/models/dual_transformer_2d.py +0 -20
  111. diffusers/models/prior_transformer.py +0 -12
  112. diffusers/models/t5_film_transformer.py +0 -70
  113. diffusers/models/transformer_2d.py +0 -25
  114. diffusers/models/transformer_temporal.py +0 -34
  115. diffusers/models/unet_1d.py +0 -26
  116. diffusers/models/unet_1d_blocks.py +0 -203
  117. diffusers/models/unet_2d.py +0 -27
  118. diffusers/models/unet_2d_blocks.py +0 -375
  119. diffusers/models/unet_2d_condition.py +0 -25
  120. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/LICENSE +0 -0
  121. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/entry_points.txt +0 -0
  122. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/top_level.txt +0 -0
@@ -226,6 +226,8 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
  diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
  diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
  diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+ diffusers_name = diffusers_name.replace("text.projection", "text_projection")
+
  if "self_attn" in diffusers_name:
  if lora_name.startswith(("lora_te_", "lora_te1_")):
  te_state_dict[diffusers_name] = state_dict.pop(key)
@@ -243,6 +245,10 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
  else:
  te2_state_dict[diffusers_name] = state_dict.pop(key)
  te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+ # OneTrainer specificity
+ elif "text_projection" in diffusers_name and lora_name.startswith("lora_te2_"):
+ te2_state_dict[diffusers_name] = state_dict.pop(key)
+ te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)

  if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
  dora_scale_key_to_replace_te = (
@@ -270,7 +276,7 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
  network_alphas.update({new_name: alpha})

  if len(state_dict) > 0:
- raise ValueError(f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}")
+ raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")

  logger.info("Kohya-style checkpoint detected.")
  unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
@@ -28,9 +28,11 @@ from .single_file_utils import (
  _legacy_load_safety_checker,
  _legacy_load_scheduler,
  create_diffusers_clip_model_from_ldm,
+ create_diffusers_t5_model_from_checkpoint,
  fetch_diffusers_config,
  fetch_original_config,
  is_clip_model_in_single_file,
+ is_t5_in_single_file,
  load_single_file_checkpoint,
  )

@@ -118,6 +120,16 @@ def load_single_file_sub_model(
  is_legacy_loading=is_legacy_loading,
  )

+ elif is_transformers_model and is_t5_in_single_file(checkpoint):
+ loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
+ class_obj,
+ checkpoint=checkpoint,
+ config=cached_model_config_path,
+ subfolder=name,
+ torch_dtype=torch_dtype,
+ local_files_only=local_files_only,
+ )
+
  elif is_tokenizer and is_legacy_loading:
  loaded_sub_model = _legacy_load_clip_tokenizer(
  class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
@@ -234,7 +246,7 @@ def _download_diffusers_model_config_from_hub(
  local_files_only=None,
  token=None,
  ):
- allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt"]
+ allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
  cached_model_path = snapshot_download(
  pretrained_model_name_or_path,
  cache_dir=cache_dir,
@@ -24,6 +24,7 @@ from .single_file_utils import (
  convert_controlnet_checkpoint,
  convert_ldm_unet_checkpoint,
  convert_ldm_vae_checkpoint,
+ convert_sd3_transformer_checkpoint_to_diffusers,
  convert_stable_cascade_unet_single_file_to_diffusers,
  create_controlnet_diffusers_config_from_ldm,
  create_unet_diffusers_config_from_ldm,
@@ -64,6 +65,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
  "checkpoint_mapping_fn": convert_controlnet_checkpoint,
  "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
  },
+ "SD3Transformer2DModel": {
+ "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
  }


@@ -271,16 +276,18 @@ class FromOriginalModelMixin:

  if is_accelerate_available():
  unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
- if model._keys_to_ignore_on_load_unexpected is not None:
- for pat in model._keys_to_ignore_on_load_unexpected:
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

- if len(unexpected_keys) > 0:
- logger.warning(
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
- )
  else:
- model.load_state_dict(diffusers_format_checkpoint)
+ _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
+
+ if model._keys_to_ignore_on_load_unexpected is not None:
+ for pat in model._keys_to_ignore_on_load_unexpected:
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+ )

  if torch_dtype is not None:
  model.to(torch_dtype)
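The change above makes the non-accelerate branch behave like the accelerate one: load_state_dict is now called with strict=False so unexpected keys are collected rather than raising, and the filtering against _keys_to_ignore_on_load_unexpected plus the warning now run for both paths. A minimal standalone illustration of the underlying PyTorch behaviour (toy module and keys, not diffusers code):

    import torch

    model = torch.nn.Linear(2, 2)
    state = {"weight": torch.zeros(2, 2), "bias": torch.zeros(2), "extra.scale": torch.ones(1)}

    # With strict=False the call returns (missing_keys, unexpected_keys) instead of raising on "extra.scale".
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(unexpected)  # ['extra.scale']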
@@ -21,6 +21,7 @@ from io import BytesIO
  from urllib.parse import urlparse

  import requests
+ import torch
  import yaml

  from ..models.modeling_utils import load_state_dict
@@ -65,11 +66,14 @@ CHECKPOINT_KEY_NAMES = {
  "inpainting": "model.diffusion_model.input_blocks.0.0.weight",
  "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
  "clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight",
+ "clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight",
  "open_clip": "cond_stage_model.model.token_embedding.weight",
  "open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding",
  "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection",
+ "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
  "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
  "stable_cascade_stage_c": "clip_txt_mapper.weight",
+ "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
  }

  DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -96,6 +100,9 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
  "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
  "subfolder": "prior_lite",
  },
+ "sd3": {
+ "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
+ },
  }

  # Use to configure model sample size when original config is provided
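Together with the SD3Transformer2DModel entry added to SINGLE_FILE_LOADABLE_CLASSES above, this mapping lets an original-format SD3 checkpoint be resolved against the stabilityai/stable-diffusion-3-medium-diffusers configs when it is loaded through the single-file API. A minimal sketch of that usage (the local filename and dtype are placeholders, not taken from this diff):

    import torch
    from diffusers import SD3Transformer2DModel

    # Hypothetical path to an original (non-diffusers) SD3 checkpoint.
    transformer = SD3Transformer2DModel.from_single_file(
        "sd3_medium.safetensors",
        torch_dtype=torch.float16,
    )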
@@ -242,7 +249,10 @@ LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
  PLAYGROUND_VAE_SCALING_FACTOR = 0.5
  LDM_UNET_KEY = "model.diffusion_model."
  LDM_CONTROLNET_KEY = "control_model."
- LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
+ LDM_CLIP_PREFIX_TO_REMOVE = [
+ "cond_stage_model.transformer.",
+ "conditioner.embedders.0.transformer.",
+ ]
  OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
  LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024

@@ -366,6 +376,13 @@ def is_clip_sdxl_model(checkpoint):
  return False


+ def is_clip_sd3_model(checkpoint):
+ if CHECKPOINT_KEY_NAMES["clip_sd3"] in checkpoint:
+ return True
+
+ return False
+
+
  def is_open_clip_model(checkpoint):
  if CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint:
  return True
@@ -380,6 +397,13 @@ def is_open_clip_sdxl_model(checkpoint):
  return False


+ def is_open_clip_sd3_model(checkpoint):
+ if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
+ return True
+
+ return False
+
+
  def is_open_clip_sdxl_refiner_model(checkpoint):
  if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
  return True
@@ -391,9 +415,11 @@ def is_clip_model_in_single_file(class_obj, checkpoint):
  is_clip_in_checkpoint = any(
  [
  is_clip_model(checkpoint),
+ is_clip_sd3_model(checkpoint),
  is_open_clip_model(checkpoint),
  is_open_clip_sdxl_model(checkpoint),
  is_open_clip_sdxl_refiner_model(checkpoint),
+ is_open_clip_sd3_model(checkpoint),
  ]
  )
  if (
@@ -456,6 +482,9 @@ def infer_diffusers_model_type(checkpoint):
  ):
  model_type = "stable_cascade_stage_b"

+ elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
+ model_type = "sd3"
+
  else:
  model_type = "v1"

@@ -1206,11 +1235,14 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
  return new_checkpoint


- def convert_ldm_clip_checkpoint(checkpoint):
+ def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
  keys = list(checkpoint.keys())
  text_model_dict = {}

- remove_prefixes = LDM_CLIP_PREFIX_TO_REMOVE
+ remove_prefixes = []
+ remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
+ if remove_prefix:
+ remove_prefixes.append(remove_prefix)

  for key in keys:
  for prefix in remove_prefixes:
@@ -1236,8 +1268,6 @@ def convert_open_clip_checkpoint(
  else:
  text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM

- text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
-
  keys = list(checkpoint.keys())
  keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE

@@ -1286,9 +1316,6 @@ def convert_open_clip_checkpoint(
  else:
  text_model_dict[diffusers_key] = checkpoint.get(key)

- if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
- text_model_dict.pop("text_model.embeddings.position_ids", None)
-
  return text_model_dict


@@ -1349,6 +1376,13 @@ def create_diffusers_clip_model_from_ldm(
  ):
  diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)

+ elif (
+ is_clip_sd3_model(checkpoint)
+ and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
+ ):
+ diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
+ diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
+
  elif is_open_clip_model(checkpoint):
  prefix = "cond_stage_model.model."
  diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
@@ -1364,22 +1398,28 @@ def create_diffusers_clip_model_from_ldm(
  prefix = "conditioner.embedders.0.model."
  diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)

+ elif (
+ is_open_clip_sd3_model(checkpoint)
+ and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
+ ):
+ diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")
+
  else:
  raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")

  if is_accelerate_available():
  unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
- if model._keys_to_ignore_on_load_unexpected is not None:
- for pat in model._keys_to_ignore_on_load_unexpected:
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+ else:
+ _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)

- if len(unexpected_keys) > 0:
- logger.warning(
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
- )
+ if model._keys_to_ignore_on_load_unexpected is not None:
+ for pat in model._keys_to_ignore_on_load_unexpected:
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

- else:
- model.load_state_dict(diffusers_format_checkpoint)
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+ )

  if torch_dtype is not None:
  model.to(torch_dtype)
@@ -1559,3 +1599,213 @@ def _legacy_load_safety_checker(local_files_only, torch_dtype):
  )

  return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}
+
+
+ # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
+ # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
+ def swap_scale_shift(weight, dim):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+
+ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+ keys = list(checkpoint.keys())
+ for k in keys:
+ if "model.diffusion_model." in k:
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+ num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1 # noqa: C401
+ caption_projection_dim = 1536
+
+ # Positional and patch embeddings.
+ converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
+ converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+ converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+ # Timestep embeddings.
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+ "t_embedder.mlp.0.weight"
+ )
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+ "t_embedder.mlp.2.weight"
+ )
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+
+ # Context projections.
+ converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight")
+ converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias")
+
+ # Pooled context projection.
+ converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight")
+ converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias")
+ converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight")
+ converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias")
+
+ # Transformer blocks 🎸.
+ for i in range(num_layers):
+ # Q, K, V
+ sample_q, sample_k, sample_v = torch.chunk(
+ checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0
+ )
+ context_q, context_k, context_v = torch.chunk(
+ checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0
+ )
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+ checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0
+ )
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+ checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0
+ )
+
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias])
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias])
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v])
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias])
+
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q])
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k])
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
+ converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+
+ # output projections.
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn.proj.bias"
+ )
+ if not (i == num_layers - 1):
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.attn.proj.bias"
+ )
+
+ # norms.
+ converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias"
+ )
+ if not (i == num_layers - 1):
+ converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"
+ )
+ else:
+ converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift(
+ checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"),
+ dim=caption_projection_dim,
+ )
+ converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift(
+ checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"),
+ dim=caption_projection_dim,
+ )
+
+ # ffs.
+ converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.mlp.fc1.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.mlp.fc1.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.mlp.fc2.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.mlp.fc2.bias"
+ )
+ if not (i == num_layers - 1):
+ converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.mlp.fc1.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.mlp.fc1.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.mlp.fc2.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.mlp.fc2.bias"
+ )
+
+ # Final blocks.
+ converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+ checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim
+ )
+ converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+ checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
+ )
+
+ return converted_state_dict
+
+
+ def is_t5_in_single_file(checkpoint):
+ if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint:
+ return True
+
+ return False
+
+
+ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
+ keys = list(checkpoint.keys())
+ text_model_dict = {}
+
+ remove_prefixes = ["text_encoders.t5xxl.transformer."]
+
+ for key in keys:
+ for prefix in remove_prefixes:
+ if key.startswith(prefix):
+ diffusers_key = key.replace(prefix, "")
+ text_model_dict[diffusers_key] = checkpoint.get(key)
+
+ return text_model_dict
+
+
+ def create_diffusers_t5_model_from_checkpoint(
+ cls,
+ checkpoint,
+ subfolder="",
+ config=None,
+ torch_dtype=None,
+ local_files_only=None,
+ ):
+ if config:
+ config = {"pretrained_model_name_or_path": config}
+ else:
+ config = fetch_diffusers_config(checkpoint)
+
+ model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ model = cls(model_config)
+
+ diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
+
+ if is_accelerate_available():
+ unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+ if model._keys_to_ignore_on_load_unexpected is not None:
+ for pat in model._keys_to_ignore_on_load_unexpected:
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+ )
+
+ else:
+ model.load_state_dict(diffusers_format_checkpoint)
+ return model
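As the comment above swap_scale_shift notes, the original SD3 AdaLayerNormContinuous projects to [shift, scale] while the diffusers implementation expects [scale, shift], so the two halves of the projection weights are swapped during conversion. A toy, standalone illustration of that reordering (values are made up):

    import torch

    weight = torch.arange(8.0)                 # first half plays the role of shift, second half scale
    shift, scale = weight.chunk(2, dim=0)
    swapped = torch.cat([scale, shift], dim=0)
    print(swapped)                             # tensor([4., 5., 6., 7., 0., 1., 2., 3.])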