diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2222 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +1 -12
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +262 -2
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1795 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +319 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +1 -4
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +19 -16
  210. diffusers/utils/loading_utils.py +76 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

@@ -1370,6 +1370,8 @@ def download_from_original_stable_diffusion_ckpt(
 
     if "unet_config" in original_config["model"]["params"]:
         original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
+    elif "network_config" in original_config["model"]["params"]:
+        original_config["model"]["params"]["network_config"]["params"]["in_channels"] = num_in_channels
 
     if (
         "parameterization" in original_config["model"]["params"]
diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

@@ -55,7 +55,7 @@ EXAMPLE_DOC_STRING = """
         >>> from diffusers import FlaxStableDiffusionPipeline
 
         >>> pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
-        ...     "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jax.numpy.bfloat16
+        ...     "runwayml/stable-diffusion-v1-5", variant="bf16", dtype=jax.numpy.bfloat16
         ... )
 
         >>> prompt = "a photo of an astronaut riding a horse on mars"
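
This docstring fix matches the usual Hub semantics: `revision` selects a git ref (branch, tag, or commit) of the model repository, while `variant` selects an alternative weight serialization such as the bf16 files. A sketch of the corrected call, same model id and variant as the docstring:

    import jax.numpy as jnp
    from diffusers import FlaxStableDiffusionPipeline

    # `variant="bf16"` picks up the *.bf16 weight files; `revision` would instead
    # check out a different git ref of the repository.
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", variant="bf16", dtype=jnp.bfloat16
    )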
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

@@ -21,7 +21,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers

@@ -133,7 +133,7 @@ class StableDiffusionPipeline(
     DiffusionPipeline,
     StableDiffusionMixin,
     TextualInversionLoaderMixin,
-    LoraLoaderMixin,
+    StableDiffusionLoraLoaderMixin,
     IPAdapterMixin,
     FromSingleFileMixin,
 ):

@@ -145,8 +145,8 @@ class StableDiffusionPipeline(
 
     The pipeline also inherits the following loading methods:
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
-        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
-        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
         - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 

@@ -342,7 +342,7 @@ class StableDiffusionPipeline(
         """
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
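
Behind this rename: 0.30.0 splits the old catch-all `LoraLoaderMixin` into per-pipeline loaders (see the new `diffusers/loaders/lora_base.py` and `diffusers/loaders/lora_pipeline.py` in the file list above), and the Stable Diffusion pipelines now inherit `StableDiffusionLoraLoaderMixin`. The user-facing entry points keep their names, so existing code along these lines should keep working; a minimal sketch, where the LoRA repo id is a placeholder:

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # Same public API as before; now provided by StableDiffusionLoraLoaderMixin.
    # "some-user/some-sd15-lora" is a placeholder repo id, not a real checkpoint.
    pipe.load_lora_weights("some-user/some-sd15-lora")
    image = pipe("a photo of an astronaut riding a horse on mars").images[0]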
@@ -475,7 +475,7 @@ class StableDiffusionPipeline(
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
         if self.text_encoder is not None:
-            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)
 

@@ -508,6 +508,9 @@ class StableDiffusionPipeline(
     def prepare_ip_adapter_image_embeds(
         self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
     ):
+        image_embeds = []
+        if do_classifier_free_guidance:
+            negative_image_embeds = []
         if ip_adapter_image_embeds is None:
             if not isinstance(ip_adapter_image, list):
                 ip_adapter_image = [ip_adapter_image]

@@ -517,7 +520,6 @@ class StableDiffusionPipeline(
                     f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                 )
 
-            image_embeds = []
             for single_ip_adapter_image, image_proj_layer in zip(
                 ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
             ):

@@ -525,36 +527,28 @@ class StableDiffusionPipeline(
                 single_image_embeds, single_negative_image_embeds = self.encode_image(
                     single_ip_adapter_image, device, 1, output_hidden_state
                 )
-                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
-                single_negative_image_embeds = torch.stack(
-                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
-                )
 
+                image_embeds.append(single_image_embeds[None, :])
                 if do_classifier_free_guidance:
-                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
-                    single_image_embeds = single_image_embeds.to(device)
-
-                image_embeds.append(single_image_embeds)
+                    negative_image_embeds.append(single_negative_image_embeds[None, :])
         else:
-            repeat_dims = [1]
-            image_embeds = []
             for single_image_embeds in ip_adapter_image_embeds:
                 if do_classifier_free_guidance:
                     single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-                    single_image_embeds = single_image_embeds.repeat(
-                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
-                    )
-                    single_negative_image_embeds = single_negative_image_embeds.repeat(
-                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
-                    )
-                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
-                else:
-                    single_image_embeds = single_image_embeds.repeat(
-                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
-                    )
+                    negative_image_embeds.append(single_negative_image_embeds)
                 image_embeds.append(single_image_embeds)
 
-        return image_embeds
+        ip_adapter_image_embeds = []
+        for i, single_image_embeds in enumerate(image_embeds):
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds
 
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
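
The rewritten `prepare_ip_adapter_image_embeds` now only collects per-adapter positive and negative embeddings in its two branches, deferring both the `num_images_per_prompt` duplication and the classifier-free-guidance concatenation to a single final loop, so freshly encoded images and precomputed `ip_adapter_image_embeds` share one output-shaping path. A toy shape check of that final loop, with stand-in tensors in place of real image embeddings:

    import torch

    num_images_per_prompt, do_classifier_free_guidance = 2, True
    image_embeds = [torch.randn(1, 4, 8)]          # one adapter: (1, seq, dim)
    negative_image_embeds = [torch.zeros(1, 4, 8)]

    out = []
    for i, emb in enumerate(image_embeds):
        emb = torch.cat([emb] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            neg = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
            emb = torch.cat([neg, emb], dim=0)  # negative first, matching the CFG layout
        out.append(emb)

    assert out[0].shape == (4, 4, 8)  # 2 negative + 2 positive stacked on the batch dim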
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py

@@ -20,11 +20,11 @@ import numpy as np
 import PIL.Image
 import torch
 from packaging import version
-from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
+from transformers import CLIPTextModel, CLIPTokenizer, DPTForDepthEstimation, DPTImageProcessor
 
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers

@@ -74,7 +74,7 @@ def preprocess(image):
     return image
 
 
-class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin):
     r"""
     Pipeline for text-guided depth-based image-to-image generation using Stable Diffusion.
 

@@ -83,8 +83,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
 
     The pipeline also inherits the following loading methods:
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
-        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
-        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
 
     Args:
         vae ([`AutoencoderKL`]):

@@ -111,7 +111,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
         depth_estimator: DPTForDepthEstimation,
-        feature_extractor: DPTFeatureExtractor,
+        feature_extractor: DPTImageProcessor,
     ):
         super().__init__()
 
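The `DPTFeatureExtractor` → `DPTImageProcessor` swap follows the transformers rename of vision `FeatureExtractor` classes to `ImageProcessor`, with the old names kept as deprecated aliases. A hedged construction sketch; `Intel/dpt-large` is an illustrative DPT checkpoint, not something this pipeline pins:

    from transformers import DPTForDepthEstimation, DPTImageProcessor

    # Same objects as before; only the processor class name changed upstream.
    depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
    feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")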
@@ -225,7 +225,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
         """
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale

@@ -358,7 +358,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
         if self.text_encoder is not None:
-            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)
 

@@ -494,6 +494,13 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
             )
 
         elif isinstance(generator, list):
+            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                )
+
             init_latents = [
                 retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(batch_size)
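
This generator-list branch previously encoded `image[i : i + 1]` for each of the `batch_size` generators and assumed `image` already had that many entries; the new guard duplicates a smaller image batch when it divides `batch_size` evenly and raises otherwise. A toy illustration of the rule:

    import torch

    image = torch.randn(1, 3, 64, 64)  # a single init image
    batch_size = 4                     # e.g. one prompt with num_images_per_prompt=4

    if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
        image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
    elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
        )

    assert image.shape[0] == batch_size  # one slice per generator can now be encoded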
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers

@@ -175,7 +175,7 @@ class StableDiffusionImg2ImgPipeline(
     StableDiffusionMixin,
     TextualInversionLoaderMixin,
     IPAdapterMixin,
-    LoraLoaderMixin,
+    StableDiffusionLoraLoaderMixin,
     FromSingleFileMixin,
 ):
     r"""

@@ -186,8 +186,8 @@ class StableDiffusionImg2ImgPipeline(
 
     The pipeline also inherits the following loading methods:
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
-        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
-        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
         - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 

@@ -385,7 +385,7 @@ class StableDiffusionImg2ImgPipeline(
         """
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale

@@ -518,7 +518,7 @@ class StableDiffusionImg2ImgPipeline(
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
         if self.text_encoder is not None:
-            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)
 

@@ -553,6 +553,9 @@ class StableDiffusionImg2ImgPipeline(
     def prepare_ip_adapter_image_embeds(
         self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
     ):
+        image_embeds = []
+        if do_classifier_free_guidance:
+            negative_image_embeds = []
         if ip_adapter_image_embeds is None:
             if not isinstance(ip_adapter_image, list):
                 ip_adapter_image = [ip_adapter_image]

@@ -562,7 +565,6 @@ class StableDiffusionImg2ImgPipeline(
                     f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                 )
 
-            image_embeds = []
             for single_ip_adapter_image, image_proj_layer in zip(
                 ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
             ):

@@ -570,36 +572,28 @@ class StableDiffusionImg2ImgPipeline(
                 single_image_embeds, single_negative_image_embeds = self.encode_image(
                     single_ip_adapter_image, device, 1, output_hidden_state
                 )
-                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
-                single_negative_image_embeds = torch.stack(
-                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
-                )
 
+                image_embeds.append(single_image_embeds[None, :])
                 if do_classifier_free_guidance:
-                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
-                    single_image_embeds = single_image_embeds.to(device)
-
-                image_embeds.append(single_image_embeds)
+                    negative_image_embeds.append(single_negative_image_embeds[None, :])
         else:
-            repeat_dims = [1]
-            image_embeds = []
             for single_image_embeds in ip_adapter_image_embeds:
                 if do_classifier_free_guidance:
                     single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-                    single_image_embeds = single_image_embeds.repeat(
-                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
-                    )
-                    single_negative_image_embeds = single_negative_image_embeds.repeat(
-                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
-                    )
-                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
-                else:
-                    single_image_embeds = single_image_embeds.repeat(
-                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
-                    )
+                    negative_image_embeds.append(single_negative_image_embeds)
                 image_embeds.append(single_image_embeds)
 
-        return image_embeds
+        ip_adapter_image_embeds = []
+        for i, single_image_embeds in enumerate(image_embeds):
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):

@@ -746,6 +740,13 @@ class StableDiffusionImg2ImgPipeline(
             )
 
         elif isinstance(generator, list):
+            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                )
+
             init_latents = [
                 retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(batch_size)
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

@@ -15,7 +15,6 @@
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
 
-import numpy as np
 import PIL.Image
 import torch
 from packaging import version

@@ -24,7 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers

@@ -38,128 +37,6 @@ from .safety_checker import StableDiffusionSafetyChecker
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
-    """
-    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
-    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
-    ``image`` and ``1`` for the ``mask``.
-
-    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
-    binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
-
-    Args:
-        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
-            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
-            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
-        mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
-            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
-            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
-
-
-    Raises:
-        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
-            should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
-            TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
-            (ot the other way around).
-
-    Returns:
-        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
-        dimensions: ``batch x channels x height x width``.
-    """
-    deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
-    deprecate(
-        "prepare_mask_and_masked_image",
-        "0.30.0",
-        deprecation_message,
-    )
-    if image is None:
-        raise ValueError("`image` input cannot be undefined.")
-
-    if mask is None:
-        raise ValueError("`mask_image` input cannot be undefined.")
-
-    if isinstance(image, torch.Tensor):
-        if not isinstance(mask, torch.Tensor):
-            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
-
-        # Batch single image
-        if image.ndim == 3:
-            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
-            image = image.unsqueeze(0)
-
-        # Batch and add channel dim for single mask
-        if mask.ndim == 2:
-            mask = mask.unsqueeze(0).unsqueeze(0)
-
-        # Batch single mask or add channel dim
-        if mask.ndim == 3:
-            # Single batched mask, no channel dim or single mask not batched but channel dim
-            if mask.shape[0] == 1:
-                mask = mask.unsqueeze(0)
-
-            # Batched masks no channel dim
-            else:
-                mask = mask.unsqueeze(1)
-
-        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
-        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
-        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
-
-        # Check image is in [-1, 1]
-        if image.min() < -1 or image.max() > 1:
-            raise ValueError("Image should be in [-1, 1] range")
-
-        # Check mask is in [0, 1]
-        if mask.min() < 0 or mask.max() > 1:
-            raise ValueError("Mask should be in [0, 1] range")
-
-        # Binarize mask
-        mask[mask < 0.5] = 0
-        mask[mask >= 0.5] = 1
-
-        # Image as float32
-        image = image.to(dtype=torch.float32)
-    elif isinstance(mask, torch.Tensor):
-        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
-    else:
-        # preprocess image
-        if isinstance(image, (PIL.Image.Image, np.ndarray)):
-            image = [image]
-        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
-            # resize all images w.r.t passed height an width
-            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
-            image = [np.array(i.convert("RGB"))[None, :] for i in image]
-            image = np.concatenate(image, axis=0)
-        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
-            image = np.concatenate([i[None, :] for i in image], axis=0)
-
-        image = image.transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
-
-        # preprocess mask
-        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
-            mask = [mask]
-
-        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
-            mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
-            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
-            mask = mask.astype(np.float32) / 255.0
-        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
-            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
-
-        mask[mask < 0.5] = 0
-        mask[mask >= 0.5] = 1
-        mask = torch.from_numpy(mask)
-
-        masked_image = image * (mask < 0.5)
-
-    # n.b. ensure backwards compatibility as old function does not return image
-    if return_image:
-        return mask, masked_image, image
-
-    return mask, masked_image
-
-
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
     encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
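
With the long-deprecated helper gone, its own deprecation message points to `VaeImageProcessor.preprocess` as the replacement. A minimal sketch of that path; the processor arguments mirror how the inpaint pipeline configures its `image_processor` and `mask_processor`, but treat the exact kwargs as an assumption rather than a guaranteed contract:

    import PIL.Image
    from diffusers.image_processor import VaeImageProcessor

    pil_image = PIL.Image.new("RGB", (512, 512))  # placeholder inputs
    pil_mask = PIL.Image.new("L", (512, 512))

    image_processor = VaeImageProcessor(vae_scale_factor=8)
    mask_processor = VaeImageProcessor(
        vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True
    )

    init_image = image_processor.preprocess(pil_image, height=512, width=512)  # float tensor in [-1, 1]
    mask = mask_processor.preprocess(pil_mask, height=512, width=512)          # binarized {0, 1} tensor
    masked_image = init_image * (mask < 0.5)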
@@ -239,7 +116,7 @@ class StableDiffusionInpaintPipeline(
     StableDiffusionMixin,
     TextualInversionLoaderMixin,
     IPAdapterMixin,
-    LoraLoaderMixin,
+    StableDiffusionLoraLoaderMixin,
     FromSingleFileMixin,
 ):
     r"""

@@ -250,8 +127,8 @@ class StableDiffusionInpaintPipeline(
 
     The pipeline also inherits the following loading methods:
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
-        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
-        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
 

@@ -457,7 +334,7 @@ class StableDiffusionInpaintPipeline(
         """
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale

@@ -590,7 +467,7 @@ class StableDiffusionInpaintPipeline(
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
         if self.text_encoder is not None:
-            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)
 

@@ -625,6 +502,9 @@ class StableDiffusionInpaintPipeline(
     def prepare_ip_adapter_image_embeds(
         self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
     ):
+        image_embeds = []
+        if do_classifier_free_guidance:
+            negative_image_embeds = []
         if ip_adapter_image_embeds is None:
             if not isinstance(ip_adapter_image, list):
                 ip_adapter_image = [ip_adapter_image]

@@ -634,7 +514,6 @@ class StableDiffusionInpaintPipeline(
                     f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                 )
 
-            image_embeds = []
             for single_ip_adapter_image, image_proj_layer in zip(
                 ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
             ):

@@ -642,36 +521,28 @@ class StableDiffusionInpaintPipeline(
                 single_image_embeds, single_negative_image_embeds = self.encode_image(
                     single_ip_adapter_image, device, 1, output_hidden_state
                 )
-                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
-                single_negative_image_embeds = torch.stack(
-                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
-                )
 
+                image_embeds.append(single_image_embeds[None, :])
                 if do_classifier_free_guidance:
-                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
-                    single_image_embeds = single_image_embeds.to(device)
-
-                image_embeds.append(single_image_embeds)
+                    negative_image_embeds.append(single_negative_image_embeds[None, :])
         else:
-            repeat_dims = [1]
-            image_embeds = []
             for single_image_embeds in ip_adapter_image_embeds:
                 if do_classifier_free_guidance:
                     single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-                    single_image_embeds = single_image_embeds.repeat(
-                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
-                    )
-                    single_negative_image_embeds = single_negative_image_embeds.repeat(
-                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
-                    )
-                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
-                else:
-                    single_image_embeds = single_image_embeds.repeat(
-                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
-                    )
+                    negative_image_embeds.append(single_negative_image_embeds)
                 image_embeds.append(single_image_embeds)
 
-        return image_embeds
+        ip_adapter_image_embeds = []
+        for i, single_image_embeds in enumerate(image_embeds):
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

@@ -22,7 +22,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import PIL_INTERPOLATION, deprecate, logging

@@ -74,7 +74,11 @@ def retrieve_latents(
 
 
 class StableDiffusionInstructPix2PixPipeline(
-    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin
+    DiffusionPipeline,
+    StableDiffusionMixin,
+    TextualInversionLoaderMixin,
+    StableDiffusionLoraLoaderMixin,
+    IPAdapterMixin,
 ):
     r"""
     Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).

@@ -84,8 +88,8 @@ class StableDiffusionInstructPix2PixPipeline(
 
     The pipeline also inherits the following loading methods:
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
-        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
-        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args: