diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Files changed (120)
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
diffusers/pipelines/stable_diffusion/__init__.py
@@ -50,8 +50,10 @@ else:
     from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
     from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline
     from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline
+    from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
     from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline
     from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline
+    from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline
     from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
     from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
     from .pipeline_stable_unclip import StableUnCLIPPipeline
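
The two pipelines registered above are also re-exported from the top-level `diffusers` package in 0.18.x. A minimal usage sketch for the LDM3D pipeline follows; the checkpoint id and output handling are assumptions for illustration, not taken from this diff:

    # Sketch: text-to-(RGB + depth) generation with the new LDM3D pipeline.
    import torch
    from diffusers import StableDiffusionLDM3DPipeline

    # "Intel/ldm3d" is assumed here as an example checkpoint id.
    pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d", torch_dtype=torch.float16)
    pipe = pipe.to("cuda")

    output = pipe("a photo of an old library, volumetric light")
    rgb_images, depth_images = output.rgb, output.depth  # lists of PIL images
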
diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -24,6 +24,7 @@ from transformers import (
     AutoFeatureExtractor,
     BertTokenizerFast,
     CLIPImageProcessor,
+    CLIPTextConfig,
     CLIPTextModel,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
@@ -48,7 +49,7 @@ from ...schedulers import (
     PNDMScheduler,
     UnCLIPScheduler,
 )
-from ...utils import is_omegaconf_available, is_safetensors_available, logging
+from ...utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available, logging
 from ...utils.import_utils import BACKENDS_MAPPING
 from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
 from ..paint_by_example import PaintByExampleImageEncoder
@@ -57,6 +58,10 @@ from .safety_checker import StableDiffusionSafetyChecker
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer


+if is_accelerate_available():
+    from accelerate import init_empty_weights
+    from accelerate.utils import set_module_tensor_to_device
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

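
Throughout the converter, models are now instantiated under accelerate's `init_empty_weights()` and filled tensor by tensor with `set_module_tensor_to_device`, instead of being randomly initialized and then overwritten via `load_state_dict`. A small self-contained sketch of that pattern; the module and state dict below are placeholders, not taken from this diff:

    import torch
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    # Build the module skeleton on the "meta" device: no memory is allocated and
    # no random initialization is performed.
    with init_empty_weights():
        model = torch.nn.Linear(4, 2)

    # Stand-in for a converted checkpoint, keyed by parameter name.
    state_dict = {"weight": torch.ones(2, 4), "bias": torch.zeros(2)}

    # Materialize each tensor on CPU with the checkpoint value, one parameter at a time.
    for param_name, param in state_dict.items():
        set_module_tensor_to_device(model, param_name, "cpu", value=param)
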
@@ -233,7 +238,10 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
     if controlnet:
         unet_params = original_config.model.params.control_stage_config.params
     else:
-        unet_params = original_config.model.params.unet_config.params
+        if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
+            unet_params = original_config.model.params.unet_config.params
+        else:
+            unet_params = original_config.model.params.network_config.params

     vae_params = original_config.model.params.first_stage_config.params.ddconfig

@@ -253,6 +261,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
         up_block_types.append(block_type)
         resolution //= 2

+    if unet_params.transformer_depth is not None:
+        transformer_layers_per_block = (
+            unet_params.transformer_depth
+            if isinstance(unet_params.transformer_depth, int)
+            else list(unet_params.transformer_depth)
+        )
+    else:
+        transformer_layers_per_block = 1
+
     vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)

     head_dim = unet_params.num_heads if "num_heads" in unet_params else None
@@ -262,14 +279,28 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
     if use_linear_projection:
         # stable diffusion 2-base-512 and 2-768
         if head_dim is None:
-            head_dim = [5, 10, 20, 20]
+            head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
+            head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]

     class_embed_type = None
+    addition_embed_type = None
+    addition_time_embed_dim = None
     projection_class_embeddings_input_dim = None
+    context_dim = None
+
+    if unet_params.context_dim is not None:
+        context_dim = (
+            unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
+        )

     if "num_classes" in unet_params:
         if unet_params.num_classes == "sequential":
-            class_embed_type = "projection"
+            if context_dim in [2048, 1280]:
+                # SDXL
+                addition_embed_type = "text_time"
+                addition_time_embed_dim = 256
+            else:
+                class_embed_type = "projection"
             assert "adm_in_channels" in unet_params
             projection_class_embeddings_input_dim = unet_params.adm_in_channels
         else:
@@ -281,14 +312,19 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
         "down_block_types": tuple(down_block_types),
         "block_out_channels": tuple(block_out_channels),
         "layers_per_block": unet_params.num_res_blocks,
-        "cross_attention_dim": unet_params.context_dim,
+        "cross_attention_dim": context_dim,
         "attention_head_dim": head_dim,
         "use_linear_projection": use_linear_projection,
         "class_embed_type": class_embed_type,
+        "addition_embed_type": addition_embed_type,
+        "addition_time_embed_dim": addition_time_embed_dim,
         "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
+        "transformer_layers_per_block": transformer_layers_per_block,
     }

-    if not controlnet:
+    if controlnet:
+        config["conditioning_channels"] = unet_params.hint_channels
+    else:
         config["out_channels"] = unet_params.out_channels
         config["up_block_types"] = tuple(up_block_types)

@@ -360,8 +396,8 @@ def convert_ldm_unet_checkpoint(

     # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
     if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
-        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
-        print(
+        logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
+        logger.warning(
             "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
             " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
         )
@@ -371,7 +407,7 @@
                 unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
     else:
         if sum(k.startswith("model_ema") for k in keys) > 100:
-            print(
+            logger.warning(
                 "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                 " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
             )
@@ -398,6 +434,12 @@
     else:
         raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

+    if config["addition_embed_type"] == "text_time":
+        new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
+        new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+        new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
+        new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+
     new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
     new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

@@ -732,27 +774,37 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
     return hf_model


-def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False):
-    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
+def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
+    if text_encoder is None:
+        config_name = "openai/clip-vit-large-patch14"
+        config = CLIPTextConfig.from_pretrained(config_name)
+
+        with init_empty_weights():
+            text_model = CLIPTextModel(config)

     keys = list(checkpoint.keys())

     text_model_dict = {}

+    remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"]
+
     for key in keys:
-        if key.startswith("cond_stage_model.transformer"):
-            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
+        for prefix in remove_prefixes:
+            if key.startswith(prefix):
+                text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]

-    text_model.load_state_dict(text_model_dict)
+    for param_name, param in text_model_dict.items():
+        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)

     return text_model


 textenc_conversion_lst = [
-    ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
-    ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
-    ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
-    ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
+    ("positional_embedding", "text_model.embeddings.position_embedding.weight"),
+    ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
+    ("ln_final.weight", "text_model.final_layer_norm.weight"),
+    ("ln_final.bias", "text_model.final_layer_norm.bias"),
+    ("text_projection", "text_projection.weight"),
 ]
 textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}

@@ -839,27 +891,48 @@ def convert_paint_by_example_checkpoint(checkpoint)
     return model


-def convert_open_clip_checkpoint(checkpoint):
-    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
+def convert_open_clip_checkpoint(
+    checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, **config_kwargs
+):
+    # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
+    # text_model = CLIPTextModelWithProjection.from_pretrained(
+    #     "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
+    # )
+    config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs)
+
+    with init_empty_weights():
+        text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)

     keys = list(checkpoint.keys())

+    keys_to_ignore = []
+    if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23:
+        # make sure to remove all keys > 22
+        keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")]
+        keys_to_ignore += ["cond_stage_model.model.text_projection"]
+
     text_model_dict = {}

-    if "cond_stage_model.model.text_projection" in checkpoint:
-        d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
+    if prefix + "text_projection" in checkpoint:
+        d_model = int(checkpoint[prefix + "text_projection"].shape[0])
     else:
         d_model = 1024

     text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")

     for key in keys:
-        if "resblocks.23" in key:  # Diffusers drops the final layer and only uses the penultimate layer
+        if key in keys_to_ignore:
             continue
-        if key in textenc_conversion_map:
-            text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
-        if key.startswith("cond_stage_model.model.transformer."):
-            new_key = key[len("cond_stage_model.model.transformer.") :]
+        if key[len(prefix) :] in textenc_conversion_map:
+            if key.endswith("text_projection"):
+                value = checkpoint[key].T
+            else:
+                value = checkpoint[key]
+
+            text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value
+
+        if key.startswith(prefix + "transformer."):
+            new_key = key[len(prefix + "transformer.") :]
             if new_key.endswith(".in_proj_weight"):
                 new_key = new_key[: -len(".in_proj_weight")]
                 new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
@@ -877,7 +950,8 @@ def convert_open_clip_checkpoint(checkpoint):

             text_model_dict[new_key] = checkpoint[key]

-    text_model.load_state_dict(text_model_dict)
+    for param_name, param in text_model_dict.items():
+        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)

     return text_model

@@ -1007,7 +1081,7 @@ def convert_controlnet_checkpoint(
 def download_from_original_stable_diffusion_ckpt(
     checkpoint_path: str,
     original_config_file: str = None,
-    image_size: int = 512,
+    image_size: Optional[int] = None,
     prediction_type: str = None,
     model_type: str = None,
     extract_ema: bool = False,
@@ -1023,6 +1097,9 @@
     load_safety_checker: bool = True,
     pipeline_class: DiffusionPipeline = None,
     local_files_only=False,
+    vae_path=None,
+    text_encoder=None,
+    tokenizer=None,
 ) -> DiffusionPipeline:
     """
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -1070,15 +1147,27 @@
            The pipeline class to use. Pass `None` to determine automatically.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether or not to only look at local files (i.e., do not try to download the model).
+       text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
+           An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)
+           to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
+           variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed.
+       tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
+           An instance of
+           [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
+           to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if
+           needed.
    return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
    """

-   # import pipelines here to avoid circular import error when using from_ckpt method
+   # import pipelines here to avoid circular import error when using from_single_file method
    from diffusers import (
        LDMTextToImagePipeline,
        PaintByExamplePipeline,
        StableDiffusionControlNetPipeline,
+       StableDiffusionInpaintPipeline,
        StableDiffusionPipeline,
+       StableDiffusionXLImg2ImgPipeline,
+       StableDiffusionXLPipeline,
        StableUnCLIPImg2ImgPipeline,
        StableUnCLIPPipeline,
    )
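
The new `text_encoder`, `tokenizer`, and `vae_path` arguments let callers reuse already loaded components instead of having the converter fetch them again. A hedged sketch of such a call; the checkpoint path is a placeholder and only parameters documented in this diff or in the existing signature are used:

    from transformers import CLIPTextModel, CLIPTokenizer
    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
        download_from_original_stable_diffusion_ckpt,
    )

    # Reuse components that are already in memory (or cached locally).
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    pipe = download_from_original_stable_diffusion_ckpt(
        checkpoint_path="my_model.safetensors",  # placeholder path
        from_safetensors=True,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
    )
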
@@ -1098,12 +1187,9 @@
        if not is_safetensors_available():
            raise ValueError(BACKENDS_MAPPING["safetensors"][1])

-       from safetensors import safe_open
+       from safetensors.torch import load_file as safe_load

-       checkpoint = {}
-       with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
-           for key in f.keys():
-               checkpoint[key] = f.get_tensor(key)
+       checkpoint = safe_load(checkpoint_path, device="cpu")
    else:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -1115,7 +1201,7 @@
    if "global_step" in checkpoint:
        global_step = checkpoint["global_step"]
    else:
-       print("global_step key not found in model")
+       logger.debug("global_step key not found in model")
        global_step = None

    # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
@@ -1124,24 +1210,53 @@
        checkpoint = checkpoint["state_dict"]

    if original_config_file is None:
-       key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+       key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+       key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
+       key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"

        # model_type = "v1"
        config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"

-       if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+       if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
            # model_type = "v2"
            config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"

            if global_step == 110000:
                # v2.1 needs to upcast attention
                upcast_attention = True
+       elif key_name_sd_xl_base in checkpoint:
+           # only base xl has two text embedders
+           config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
+       elif key_name_sd_xl_refiner in checkpoint:
+           # only refiner xl has embedder and one text embedders
+           config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"

        original_config_file = BytesIO(requests.get(config_url).content)

    original_config = OmegaConf.load(original_config_file)

-   if num_in_channels is not None:
+   # Convert the text model.
+   if (
+       model_type is None
+       and "cond_stage_config" in original_config.model.params
+       and original_config.model.params.cond_stage_config is not None
+   ):
+       model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
+       logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
+   elif model_type is None and original_config.model.params.network_config is not None:
+       if original_config.model.params.network_config.params.context_dim == 2048:
+           model_type = "SDXL"
+       else:
+           model_type = "SDXL-Refiner"
+       if image_size is None:
+           image_size = 1024
+
+   if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
+       num_in_channels = 9
+   elif num_in_channels is None:
+       num_in_channels = 4
+
+   if "unet_config" in original_config.model.params:
        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

    if (
@@ -1170,20 +1285,37 @@
            checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
        )

-   num_train_timesteps = original_config.model.params.timesteps
-   beta_start = original_config.model.params.linear_start
-   beta_end = original_config.model.params.linear_end
-
-   scheduler = DDIMScheduler(
-       beta_end=beta_end,
-       beta_schedule="scaled_linear",
-       beta_start=beta_start,
-       num_train_timesteps=num_train_timesteps,
-       steps_offset=1,
-       clip_sample=False,
-       set_alpha_to_one=False,
-       prediction_type=prediction_type,
-   )
+   num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
+
+   if model_type in ["SDXL", "SDXL-Refiner"]:
+       scheduler_dict = {
+           "beta_schedule": "scaled_linear",
+           "beta_start": 0.00085,
+           "beta_end": 0.012,
+           "interpolation_type": "linear",
+           "num_train_timesteps": num_train_timesteps,
+           "prediction_type": "epsilon",
+           "sample_max_value": 1.0,
+           "set_alpha_to_one": False,
+           "skip_prk_steps": True,
+           "steps_offset": 1,
+           "timestep_spacing": "leading",
+       }
+       scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
+       scheduler_type = "euler"
+   else:
+       beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
+       beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
+       scheduler = DDIMScheduler(
+           beta_end=beta_end,
+           beta_schedule="scaled_linear",
+           beta_start=beta_start,
+           num_train_timesteps=num_train_timesteps,
+           steps_offset=1,
+           clip_sample=False,
+           set_alpha_to_one=False,
+           prediction_type=prediction_type,
+       )
    # make sure scheduler works correctly with DDIM
    scheduler.register_to_config(clip_sample=False)

@@ -1209,28 +1341,45 @@
    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
    unet_config["upcast_attention"] = upcast_attention
-   unet = UNet2DConditionModel(**unet_config)
+   with init_empty_weights():
+       unet = UNet2DConditionModel(**unet_config)

    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
    )

-   unet.load_state_dict(converted_unet_checkpoint)
+   for param_name, param in converted_unet_checkpoint.items():
+       set_module_tensor_to_device(unet, param_name, "cpu", value=param)

    # Convert the VAE model.
-   vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
-   converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
+   if vae_path is None:
+       vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
+       converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
+
+       if (
+           "model" in original_config
+           and "params" in original_config.model
+           and "scale_factor" in original_config.model.params
+       ):
+           vae_scaling_factor = original_config.model.params.scale_factor
+       else:
+           vae_scaling_factor = 0.18215  # default SD scaling factor

-   vae = AutoencoderKL(**vae_config)
-   vae.load_state_dict(converted_vae_checkpoint)
+       vae_config["scaling_factor"] = vae_scaling_factor

-   # Convert the text model.
-   if model_type is None:
-       model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-       logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
+       with init_empty_weights():
+           vae = AutoencoderKL(**vae_config)
+
+       for param_name, param in converted_vae_checkpoint.items():
+           set_module_tensor_to_device(vae, param_name, "cpu", value=param)
+   else:
+       vae = AutoencoderKL.from_pretrained(vae_path)

    if model_type == "FrozenOpenCLIPEmbedder":
-       text_model = convert_open_clip_checkpoint(checkpoint)
+       config_name = "stabilityai/stable-diffusion-2"
+       config_kwargs = {"subfolder": "text_encoder"}
+
+       text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")

        if stable_unclip is None:
1325
1474
  feature_extractor=feature_extractor,
1326
1475
  )
1327
1476
  elif model_type == "FrozenCLIPEmbedder":
1328
- text_model = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
1329
- tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
1477
+ text_model = convert_ldm_clip_checkpoint(
1478
+ checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
1479
+ )
1480
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") if tokenizer is None else tokenizer
1330
1481
 
1331
1482
  if load_safety_checker:
1332
1483
  safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
@@ -1356,6 +1507,50 @@ def download_from_original_stable_diffusion_ckpt(
1356
1507
  safety_checker=safety_checker,
1357
1508
  feature_extractor=feature_extractor,
1358
1509
  )
1510
+ elif model_type in ["SDXL", "SDXL-Refiner"]:
1511
+ if model_type == "SDXL":
1512
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
1513
+ text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
1514
+ tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
1515
+
1516
+ config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
1517
+ config_kwargs = {"projection_dim": 1280}
1518
+ text_encoder_2 = convert_open_clip_checkpoint(
1519
+ checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
1520
+ )
1521
+
1522
+ pipe = StableDiffusionXLPipeline(
1523
+ vae=vae,
1524
+ text_encoder=text_encoder,
1525
+ tokenizer=tokenizer,
1526
+ text_encoder_2=text_encoder_2,
1527
+ tokenizer_2=tokenizer_2,
1528
+ unet=unet,
1529
+ scheduler=scheduler,
1530
+ force_zeros_for_empty_prompt=True,
1531
+ )
1532
+ else:
1533
+ tokenizer = None
1534
+ text_encoder = None
1535
+ tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
1536
+
1537
+ config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
1538
+ config_kwargs = {"projection_dim": 1280}
1539
+ text_encoder_2 = convert_open_clip_checkpoint(
1540
+ checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs
1541
+ )
1542
+
1543
+ pipe = StableDiffusionXLImg2ImgPipeline(
1544
+ vae=vae,
1545
+ text_encoder=text_encoder,
1546
+ tokenizer=tokenizer,
1547
+ text_encoder_2=text_encoder_2,
1548
+ tokenizer_2=tokenizer_2,
1549
+ unet=unet,
1550
+ scheduler=scheduler,
1551
+ requires_aesthetics_score=True,
1552
+ force_zeros_for_empty_prompt=False,
1553
+ )
1359
1554
  else:
1360
1555
  text_config = create_ldm_bert_config(original_config)
1361
1556
  text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -22,7 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -69,7 +69,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg


-class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
+class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.

@@ -79,7 +79,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
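
At the call site, the mixin rename above surfaces as `from_single_file` replacing the older `from_ckpt` entry point. A short sketch; the file path is a placeholder:

    from diffusers import StableDiffusionPipeline

    # Loads all pipeline components from a single CompVis-style checkpoint file.
    pipe = StableDiffusionPipeline.from_single_file("my_model.safetensors")
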
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -98,7 +98,9 @@ def preprocess(image):
     return image


-class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
+class StableDiffusionImg2ImgPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-guided image to image generation using Stable Diffusion.

@@ -108,7 +110,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
-        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

     as well as the following saving methods:
         - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -153,7 +153,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
     return mask, masked_image


-class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionInpaintPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion.