diffusers 0.24.0__py3-none-any.whl → 0.25.0__py3-none-any.whl

Files changed (174)
  1. diffusers/__init__.py +11 -1
  2. diffusers/commands/fp16_safetensors.py +10 -11
  3. diffusers/configuration_utils.py +12 -8
  4. diffusers/dependency_versions_table.py +2 -1
  5. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  6. diffusers/image_processor.py +286 -46
  7. diffusers/loaders/ip_adapter.py +11 -9
  8. diffusers/loaders/lora.py +198 -60
  9. diffusers/loaders/single_file.py +24 -18
  10. diffusers/loaders/textual_inversion.py +10 -14
  11. diffusers/loaders/unet.py +130 -37
  12. diffusers/models/__init__.py +18 -12
  13. diffusers/models/activations.py +9 -6
  14. diffusers/models/attention.py +137 -16
  15. diffusers/models/attention_processor.py +133 -46
  16. diffusers/models/autoencoders/__init__.py +5 -0
  17. diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +4 -4
  18. diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +45 -6
  19. diffusers/models/{autoencoder_kl_temporal_decoder.py → autoencoders/autoencoder_kl_temporal_decoder.py} +8 -8
  20. diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +4 -4
  21. diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +14 -14
  22. diffusers/models/{vae.py → autoencoders/vae.py} +9 -5
  23. diffusers/models/downsampling.py +338 -0
  24. diffusers/models/embeddings.py +112 -29
  25. diffusers/models/modeling_flax_utils.py +12 -7
  26. diffusers/models/modeling_utils.py +10 -10
  27. diffusers/models/normalization.py +108 -2
  28. diffusers/models/resnet.py +15 -699
  29. diffusers/models/transformer_2d.py +2 -2
  30. diffusers/models/unet_2d_condition.py +37 -0
  31. diffusers/models/{unet_kandi3.py → unet_kandinsky3.py} +105 -159
  32. diffusers/models/upsampling.py +454 -0
  33. diffusers/models/uvit_2d.py +471 -0
  34. diffusers/models/vq_model.py +9 -2
  35. diffusers/pipelines/__init__.py +81 -73
  36. diffusers/pipelines/amused/__init__.py +62 -0
  37. diffusers/pipelines/amused/pipeline_amused.py +328 -0
  38. diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
  39. diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +38 -10
  41. diffusers/pipelines/auto_pipeline.py +17 -13
  42. diffusers/pipelines/controlnet/pipeline_controlnet.py +27 -10
  43. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +47 -5
  44. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +25 -8
  45. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +4 -6
  46. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +26 -10
  47. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +4 -3
  48. diffusers/pipelines/deprecated/__init__.py +153 -0
  49. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
  50. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +91 -18
  51. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +91 -18
  52. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
  53. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
  54. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
  55. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
  56. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
  57. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
  58. diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
  59. diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
  60. diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
  61. diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
  62. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
  63. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +4 -4
  64. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
  65. diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
  66. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
  67. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
  68. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +7 -7
  69. diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
  70. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +16 -11
  71. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +6 -6
  72. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +11 -11
  73. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +16 -11
  74. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +10 -10
  75. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +13 -13
  76. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
  77. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
  78. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
  79. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +54 -11
  80. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
  81. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +6 -6
  82. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +6 -6
  83. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +6 -6
  84. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
  85. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
  86. diffusers/pipelines/kandinsky3/__init__.py +4 -4
  87. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
  88. diffusers/pipelines/kandinsky3/{kandinsky3_pipeline.py → pipeline_kandinsky3.py} +172 -35
  89. diffusers/pipelines/kandinsky3/{kandinsky3img2img_pipeline.py → pipeline_kandinsky3_img2img.py} +228 -34
  90. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +46 -5
  91. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +47 -6
  92. diffusers/pipelines/onnx_utils.py +8 -5
  93. diffusers/pipelines/pipeline_flax_utils.py +7 -6
  94. diffusers/pipelines/pipeline_utils.py +30 -29
  95. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +51 -2
  96. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
  97. diffusers/pipelines/stable_diffusion/__init__.py +1 -72
  98. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +67 -75
  99. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +92 -8
  100. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -8
  101. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +138 -10
  102. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +57 -7
  103. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -0
  104. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +6 -0
  105. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
  106. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
  107. diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
  108. diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +5 -2
  109. diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
  110. diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +2 -3
  111. diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
  112. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +2 -2
  113. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +3 -3
  114. diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
  115. diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +6 -1
  116. diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
  117. diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +50 -7
  118. diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
  119. diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +56 -8
  120. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
  121. diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
  122. diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +67 -10
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +97 -15
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +97 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +7 -5
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +12 -9
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +6 -0
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -0
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +5 -0
  131. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +331 -9
  132. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +468 -9
  133. diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
  134. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
  135. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  136. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +4 -0
  137. diffusers/schedulers/__init__.py +2 -0
  138. diffusers/schedulers/scheduling_amused.py +162 -0
  139. diffusers/schedulers/scheduling_consistency_models.py +2 -0
  140. diffusers/schedulers/scheduling_ddim_inverse.py +1 -4
  141. diffusers/schedulers/scheduling_ddpm.py +46 -0
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +46 -0
  143. diffusers/schedulers/scheduling_deis_multistep.py +13 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +13 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +13 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -0
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -1
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -0
  149. diffusers/schedulers/scheduling_euler_discrete.py +62 -3
  150. diffusers/schedulers/scheduling_heun_discrete.py +2 -0
  151. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -0
  152. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -0
  153. diffusers/schedulers/scheduling_lms_discrete.py +2 -0
  154. diffusers/schedulers/scheduling_unipc_multistep.py +13 -1
  155. diffusers/schedulers/scheduling_utils.py +3 -1
  156. diffusers/schedulers/scheduling_utils_flax.py +3 -1
  157. diffusers/training_utils.py +1 -1
  158. diffusers/utils/__init__.py +0 -2
  159. diffusers/utils/constants.py +2 -5
  160. diffusers/utils/dummy_pt_objects.py +30 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +45 -0
  162. diffusers/utils/dynamic_modules_utils.py +14 -18
  163. diffusers/utils/hub_utils.py +24 -36
  164. diffusers/utils/logging.py +1 -1
  165. diffusers/utils/state_dict_utils.py +8 -0
  166. diffusers/utils/testing_utils.py +199 -1
  167. diffusers/utils/torch_utils.py +3 -3
  168. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/METADATA +54 -53
  169. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/RECORD +174 -155
  170. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
  172. /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
  173. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
  174. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

@@ -447,7 +447,8 @@ def convert_ldm_unet_checkpoint(
 
     # Relevant to StableDiffusionUpscalePipeline
     if "num_class_embeds" in config:
-        new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
+        if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
+            new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
 
     new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
     new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
@@ -1152,7 +1153,9 @@ def download_from_original_stable_diffusion_ckpt(
     vae_path=None,
     vae=None,
     text_encoder=None,
+    text_encoder_2=None,
     tokenizer=None,
+    tokenizer_2=None,
     config_files=None,
 ) -> DiffusionPipeline:
     """
@@ -1231,7 +1234,9 @@ def download_from_original_stable_diffusion_ckpt(
         StableDiffusionInpaintPipeline,
         StableDiffusionPipeline,
         StableDiffusionUpscalePipeline,
+        StableDiffusionXLControlNetInpaintPipeline,
         StableDiffusionXLImg2ImgPipeline,
+        StableDiffusionXLInpaintPipeline,
         StableDiffusionXLPipeline,
         StableUnCLIPImg2ImgPipeline,
         StableUnCLIPPipeline,
@@ -1338,7 +1343,11 @@ def download_from_original_stable_diffusion_ckpt(
     else:
         pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline
 
-    if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
+    if num_in_channels is None and pipeline_class in [
+        StableDiffusionInpaintPipeline,
+        StableDiffusionXLInpaintPipeline,
+        StableDiffusionXLControlNetInpaintPipeline,
+    ]:
         num_in_channels = 9
     if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
         num_in_channels = 7
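
In practice, single-file loading of the SDXL inpainting pipelines now gets the 9-channel UNet input by default, the same treatment the SD 1.x inpaint pipeline already had. A small sketch, assuming a local checkpoint path (placeholder):

from diffusers import StableDiffusionXLInpaintPipeline

# num_in_channels no longer has to be passed explicitly for SDXL inpainting checkpoints.
pipe = StableDiffusionXLInpaintPipeline.from_single_file("sd_xl_inpainting_0.1.safetensors")
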
@@ -1480,9 +1489,12 @@ def download_from_original_stable_diffusion_ckpt(
         config_name = "stabilityai/stable-diffusion-2"
         config_kwargs = {"subfolder": "text_encoder"}
 
-        text_model = convert_open_clip_checkpoint(
-            checkpoint, config_name, local_files_only=local_files_only, **config_kwargs
-        )
+        if text_encoder is None:
+            text_model = convert_open_clip_checkpoint(
+                checkpoint, config_name, local_files_only=local_files_only, **config_kwargs
+            )
+        else:
+            text_model = text_encoder
 
         try:
             tokenizer = CLIPTokenizer.from_pretrained(
@@ -1682,7 +1694,9 @@ def download_from_original_stable_diffusion_ckpt(
             feature_extractor=feature_extractor,
         )
     elif model_type in ["SDXL", "SDXL-Refiner"]:
-        if model_type == "SDXL":
+        is_refiner = model_type == "SDXL-Refiner"
+
+        if (is_refiner is False) and (tokenizer is None):
             try:
                 tokenizer = CLIPTokenizer.from_pretrained(
                     "openai/clip-vit-large-patch14", local_files_only=local_files_only
@@ -1691,7 +1705,11 @@ def download_from_original_stable_diffusion_ckpt(
                 raise ValueError(
                     f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
                 )
+
+        if (is_refiner is False) and (text_encoder is None):
             text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
+
+        if tokenizer_2 is None:
             try:
                 tokenizer_2 = CLIPTokenizer.from_pretrained(
                     "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
@@ -1701,95 +1719,69 @@ def download_from_original_stable_diffusion_ckpt(
                     f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
                 )
 
+        if text_encoder_2 is None:
             config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
             config_kwargs = {"projection_dim": 1280}
-            text_encoder_2 = convert_open_clip_checkpoint(
-                checkpoint,
-                config_name,
-                prefix="conditioner.embedders.1.model.",
-                has_projection=True,
-                local_files_only=local_files_only,
-                **config_kwargs,
-            )
-
-            if is_accelerate_available():  # SBM Now move model to cpu.
-                if model_type in ["SDXL", "SDXL-Refiner"]:
-                    for param_name, param in converted_unet_checkpoint.items():
-                        set_module_tensor_to_device(unet, param_name, "cpu", value=param)
+            prefix = "conditioner.embedders.0.model." if is_refiner else "conditioner.embedders.1.model."
 
-            if controlnet:
-                pipe = pipeline_class(
-                    vae=vae,
-                    text_encoder=text_encoder,
-                    tokenizer=tokenizer,
-                    text_encoder_2=text_encoder_2,
-                    tokenizer_2=tokenizer_2,
-                    unet=unet,
-                    controlnet=controlnet,
-                    scheduler=scheduler,
-                    force_zeros_for_empty_prompt=True,
-                )
-            elif adapter:
-                pipe = pipeline_class(
-                    vae=vae,
-                    text_encoder=text_encoder,
-                    tokenizer=tokenizer,
-                    text_encoder_2=text_encoder_2,
-                    tokenizer_2=tokenizer_2,
-                    unet=unet,
-                    adapter=adapter,
-                    scheduler=scheduler,
-                    force_zeros_for_empty_prompt=True,
-                )
-            else:
-                pipe = pipeline_class(
-                    vae=vae,
-                    text_encoder=text_encoder,
-                    tokenizer=tokenizer,
-                    text_encoder_2=text_encoder_2,
-                    tokenizer_2=tokenizer_2,
-                    unet=unet,
-                    scheduler=scheduler,
-                    force_zeros_for_empty_prompt=True,
-                )
-        else:
-            tokenizer = None
-            text_encoder = None
-            try:
-                tokenizer_2 = CLIPTokenizer.from_pretrained(
-                    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
-                )
-            except Exception:
-                raise ValueError(
-                    f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
-                )
-            config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
-            config_kwargs = {"projection_dim": 1280}
             text_encoder_2 = convert_open_clip_checkpoint(
                 checkpoint,
                 config_name,
-                prefix="conditioner.embedders.0.model.",
+                prefix=prefix,
                 has_projection=True,
                 local_files_only=local_files_only,
                 **config_kwargs,
             )
 
-            if is_accelerate_available():  # SBM Now move model to cpu.
-                if model_type in ["SDXL", "SDXL-Refiner"]:
-                    for param_name, param in converted_unet_checkpoint.items():
-                        set_module_tensor_to_device(unet, param_name, "cpu", value=param)
+        if is_accelerate_available():  # SBM Now move model to cpu.
+            for param_name, param in converted_unet_checkpoint.items():
+                set_module_tensor_to_device(unet, param_name, "cpu", value=param)
 
-            pipe = StableDiffusionXLImg2ImgPipeline(
+        if controlnet:
+            pipe = pipeline_class(
                 vae=vae,
                 text_encoder=text_encoder,
                 tokenizer=tokenizer,
                 text_encoder_2=text_encoder_2,
                 tokenizer_2=tokenizer_2,
                 unet=unet,
+                controlnet=controlnet,
+                scheduler=scheduler,
+                force_zeros_for_empty_prompt=True,
+            )
+        elif adapter:
+            pipe = pipeline_class(
+                vae=vae,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                text_encoder_2=text_encoder_2,
+                tokenizer_2=tokenizer_2,
+                unet=unet,
+                adapter=adapter,
                 scheduler=scheduler,
-                requires_aesthetics_score=True,
-                force_zeros_for_empty_prompt=False,
+                force_zeros_for_empty_prompt=True,
             )
+
+        else:
+            pipeline_kwargs = {
+                "vae": vae,
+                "text_encoder": text_encoder,
+                "tokenizer": tokenizer,
+                "text_encoder_2": text_encoder_2,
+                "tokenizer_2": tokenizer_2,
+                "unet": unet,
+                "scheduler": scheduler,
+            }
+
+            if (pipeline_class == StableDiffusionXLImg2ImgPipeline) or (
+                pipeline_class == StableDiffusionXLInpaintPipeline
+            ):
+                pipeline_kwargs.update({"requires_aesthetics_score": is_refiner})
+
+            if is_refiner:
+                pipeline_kwargs.update({"force_zeros_for_empty_prompt": False})
+
+            pipe = pipeline_class(**pipeline_kwargs)
     else:
         text_config = create_ldm_bert_config(original_config)
         text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
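
Net effect of the rewrite: `requires_aesthetics_score` and `force_zeros_for_empty_prompt` now follow the new `is_refiner` flag instead of being hard-coded in a separate refiner-only branch, and a `text_encoder_2`/`tokenizer_2` supplied by the caller is reused. A hedged sketch of the two single-file calls this path covers; the checkpoint paths are placeholders:

from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline

# Base checkpoint: built without the refiner-only kwargs.
base = StableDiffusionXLPipeline.from_single_file("sd_xl_base_1.0.safetensors")

# Refiner checkpoint: gets requires_aesthetics_score=True and force_zeros_for_empty_prompt=False.
refiner = StableDiffusionXLImg2ImgPipeline.from_single_file("sd_xl_refiner_1.0.safetensors")
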
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

@@ -22,7 +22,8 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from ...models.attention_processor import FusedAttnProcessor2_0
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -150,7 +151,7 @@ class StableDiffusionPipeline(
             A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
     """
 
-    model_cpu_offload_seq = "text_encoder->unet->vae"
+    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
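
With `image_encoder` added to the offload sequence, `enable_model_cpu_offload()` now includes the optional IP-Adapter image encoder in the hand-off order instead of leaving it out:

# Assumes `pipe` already has an image_encoder attached (e.g. via load_ip_adapter).
pipe.enable_model_cpu_offload()
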
@@ -489,18 +490,29 @@ class StableDiffusionPipeline(
 
         return prompt_embeds, negative_prompt_embeds
 
-    def encode_image(self, image, device, num_images_per_prompt):
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
         dtype = next(self.image_encoder.parameters()).dtype
 
         if not isinstance(image, torch.Tensor):
             image = self.feature_extractor(image, return_tensors="pt").pixel_values
 
         image = image.to(device=device, dtype=dtype)
-        image_embeds = self.image_encoder(image).image_embeds
-        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
 
-        uncond_image_embeds = torch.zeros_like(image_embeds)
-        return image_embeds, uncond_image_embeds
+            return image_embeds, uncond_image_embeds
 
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
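
The new `output_hidden_states` branch returns the image encoder's penultimate hidden states (for both the conditional image and an all-zeros unconditional image), the input expected by the IP-Adapter "plus" projection layers; the original `image_embeds` path is kept for the plain IP-Adapter. A rough usage sketch; the adapter repo, weight name, and image path are illustrative:

from diffusers.utils import load_image

# Assumes `pipe` is a StableDiffusionPipeline; load_ip_adapter comes from IPAdapterMixin.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
reference = load_image("ip_adapter_reference.png")  # placeholder image path

# __call__ routes ip_adapter_image through encode_image and picks the hidden-states
# branch automatically when the projection is not a plain ImageProjection (see below).
image = pipe("best quality, a cozy reading nook", ip_adapter_image=reference).images[0]
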
@@ -639,6 +651,67 @@ class StableDiffusionPipeline(
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
+    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+        """
+        self.fusing_unet = False
+        self.fusing_vae = False
+
+        if unet:
+            self.fusing_unet = True
+            self.unet.fuse_qkv_projections()
+            self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+        if vae:
+            if not isinstance(self.vae, AutoencoderKL):
+                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+            self.fusing_vae = True
+            self.vae.fuse_qkv_projections()
+            self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
+    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """Disable QKV projection fusion if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+
+        """
+        if unet:
+            if not self.fusing_unet:
+                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.unet.unfuse_qkv_projections()
+                self.fusing_unet = False
+
+        if vae:
+            if not self.fusing_vae:
+                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.vae.unfuse_qkv_projections()
+                self.fusing_vae = False
+
     # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
     def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
         """
@@ -695,6 +768,10 @@ class StableDiffusionPipeline(
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -836,6 +913,7 @@ class StableDiffusionPipeline(
         self._guidance_rescale = guidance_rescale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -871,7 +949,10 @@ class StableDiffusionPipeline(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         if ip_adapter_image is not None:
-            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
             if self.do_classifier_free_guidance:
                 image_embeds = torch.cat([negative_image_embeds, image_embeds])
 
@@ -910,6 +991,9 @@ class StableDiffusionPipeline(
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
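
Together with the `_interrupt` flag and `interrupt` property added above, this check lets a step-end callback abort the denoising loop early. A small sketch, reusing a `pipe` like the one above; the halfway threshold is arbitrary:

num_inference_steps = 50

def interrupt_callback(pipeline, step, timestep, callback_kwargs):
    # Flip the flag part-way through; the loop then skips every remaining step.
    if step == num_inference_steps // 2:
        pipeline._interrupt = True
    return callback_kwargs

result = pipe(
    "a photo of a cat",
    num_inference_steps=num_inference_steps,
    callback_on_step_end=interrupt_callback,
)
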
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

@@ -24,7 +24,8 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from ...models.attention_processor import FusedAttnProcessor2_0
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -190,7 +191,7 @@ class StableDiffusionImg2ImgPipeline(
             A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
     """
 
-    model_cpu_offload_seq = "text_encoder->unet->vae"
+    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
@@ -503,18 +504,29 @@ class StableDiffusionImg2ImgPipeline(
         return prompt_embeds, negative_prompt_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
-    def encode_image(self, image, device, num_images_per_prompt):
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
         dtype = next(self.image_encoder.parameters()).dtype
 
         if not isinstance(image, torch.Tensor):
             image = self.feature_extractor(image, return_tensors="pt").pixel_values
 
         image = image.to(device=device, dtype=dtype)
-        image_embeds = self.image_encoder(image).image_embeds
-        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
 
-        uncond_image_embeds = torch.zeros_like(image_embeds)
-        return image_embeds, uncond_image_embeds
+            return image_embeds, uncond_image_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
@@ -707,6 +719,67 @@ class StableDiffusionImg2ImgPipeline(
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
+    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+        """
+        self.fusing_unet = False
+        self.fusing_vae = False
+
+        if unet:
+            self.fusing_unet = True
+            self.unet.fuse_qkv_projections()
+            self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+        if vae:
+            if not isinstance(self.vae, AutoencoderKL):
+                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+            self.fusing_vae = True
+            self.vae.fuse_qkv_projections()
+            self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
+    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """Disable QKV projection fusion if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+
+        """
+        if unet:
+            if not self.fusing_unet:
+                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.unet.unfuse_qkv_projections()
+                self.fusing_unet = False
+
+        if vae:
+            if not self.fusing_vae:
+                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.vae.unfuse_qkv_projections()
+                self.fusing_vae = False
+
     # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
     def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
         """
@@ -759,6 +832,10 @@ class StableDiffusionImg2ImgPipeline(
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -890,6 +967,7 @@ class StableDiffusionImg2ImgPipeline(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -923,7 +1001,10 @@ class StableDiffusionImg2ImgPipeline(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         if ip_adapter_image is not None:
-            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
             if self.do_classifier_free_guidance:
                 image_embeds = torch.cat([negative_image_embeds, image_embeds])
 
@@ -965,6 +1046,9 @@ class StableDiffusionImg2ImgPipeline(
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)